import os
import argparse
import logging
import sys

# TensorFlow-related environment switches. They are assigned before any of the
# heavier imports below so they are already in place when those modules load.
# NOTE(review): presumably TensorFlow is pulled in indirectly by `model`; confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import warnings

# Suppress every Python warning for the lifetime of the process so the CLI
# output only contains this script's own log lines.
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

# Root logging configuration: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# absl is an optional dependency. When it is importable, restrict it to ERROR
# and above; when it is not, there is nothing to silence, so the ImportError
# is deliberately ignored.
try:
    import absl.logging
except ImportError:
    pass
else:
    absl.logging.set_verbosity(absl.logging.ERROR)

# Only let ERROR (and above) records through the 'tensorflow' logger.
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Project-local modules used by main(). They are imported only after the
# environment variables and log levels have been configured.
# NOTE(review): presumably because `model` loads TensorFlow on import; confirm.
from model import PrecomputedModel, RawImageModel
from evaluate import evaluate_predictions

# Maps each supported --diagnosis value to its (positive, negative) text query.
# main() unpacks the pair and passes both strings to
# PrecomputedModel.get_diagnosis_embeddings().
# NOTE(review): the strings are presumably looked up verbatim among precomputed
# text embeddings, so spelling and capitalisation may matter; confirm before
# editing any of them.
DIAGNOSIS_PROMPTS = {
    'AIRSPACE_OPACITY': ('Airspace Opacity', 'no evidence of airspace disease'),
    'PNEUMOTHORAX': ('small pneumothorax', 'no pneumothorax'),
    'EFFUSION': ('large pleural effusion', 'no pleural effusion'),
    'PULMONARY_EDEMA': ('moderate pulmonary edema', 'no pulmonary edema'),
}

def main():
    """Command-line entry point for zero-shot chest X-ray classification.

    Arguments (parsed from ``sys.argv``):
        --diagnosis  Required; one of the keys of ``DIAGNOSIS_PROMPTS``.
        --data-dir   Directory handed to ``PrecomputedModel`` (default "data").
        --raw-image  Optional path to a single image file.

    With ``--raw-image`` the image is embedded by ``RawImageModel``, scored
    against the positive/negative prompt embeddings, and the score is logged
    and printed.  Without it, every row of ``precomputed_model.labels`` whose
    value for the chosen diagnosis is 0 or 1 is scored and the results are
    passed to ``evaluate_predictions``.

    Exits with status 1 when the raw image cannot be processed, when no
    scores were produced, or when the number of labels does not match the
    number of scores.
    """
    parser = argparse.ArgumentParser(description="Zero-Shot Chest X-Ray Classification")
    parser.add_argument(
        "--diagnosis",
        type=str,
        choices=list(DIAGNOSIS_PROMPTS),
        required=True,
        help="Diagnosis to evaluate",
    )
    parser.add_argument("--data-dir", type=str, default="data", help="Path to data directory")
    parser.add_argument(
        "--raw-image",
        type=str,
        help="Path to a raw image file for inference (optional)",
    )
    args = parser.parse_args()

    pos_txt, neg_txt = DIAGNOSIS_PROMPTS[args.diagnosis]
    logger.info(f"Diagnosis: {args.diagnosis}")
    logger.info(f"Positive query: '{pos_txt}'")
    logger.info(f"Negative query: '{neg_txt}'")

    # The prompt embeddings are taken from the precomputed model in both modes.
    precomputed_model = PrecomputedModel(data_dir=args.data_dir)
    pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

    if args.raw_image:
        # Single-image mode.
        logger.info(f"Running inference on raw image: {args.raw_image}")
        raw_model = RawImageModel()
        try:
            image_emb = raw_model.compute_embeddings(args.raw_image)
            score = PrecomputedModel.zero_shot(image_emb, pos_emb, neg_emb)
            logger.info(f"Zero-shot score for {args.raw_image}: {score:.4f}")
            print(f"Score for {args.diagnosis}: {score}")
        except Exception as exc:
            # Top-level boundary: report the failure and exit non-zero.
            logger.error(f"Failed to process raw image: {exc}")
            sys.exit(1)
        return

    # Dataset mode.
    logger.info("Running evaluation on full precomputed dataset...")

    labels_df = precomputed_model.labels
    # Keep only rows with a definite 0/1 label for this diagnosis.
    has_binary_label = labels_df[args.diagnosis].isin([0, 1])
    target_df = labels_df.loc[has_binary_label, ['image_id', args.diagnosis]].copy()

    image_ids = target_df['image_id'].tolist()
    true_labels = target_df[args.diagnosis].tolist()

    valid_ids, scores = precomputed_model.compute_scores(image_ids, pos_emb, neg_emb)

    # Use an explicit length test: `not scores` raises ValueError when scores
    # is a NumPy array holding more than one element.
    if scores is None or len(scores) == 0:
        logger.error("No valid scores computed. Check embedding match.")
        sys.exit(1)

    # Build the lookup set once so each membership test is O(1) instead of a
    # linear scan of valid_ids per image.
    scored_ids = set(valid_ids)
    # NOTE(review): labels are kept in image_ids order; this assumes
    # compute_scores returns its scores in that same order — confirm.
    final_labels = [
        label
        for img_id, label in zip(image_ids, true_labels)
        if img_id in scored_ids
    ]

    if len(final_labels) != len(scores):
        # Mismatched lengths cannot be paired up; fail loudly instead of
        # evaluating misaligned data.
        logger.error(
            f"Label/score count mismatch: {len(final_labels)} labels vs {len(scores)} scores."
        )
        sys.exit(1)

    evaluate_predictions(scores, final_labels, args.diagnosis)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()