"""Command-line entry point for zero-shot chest X-ray classification.

Looks up a positive/negative text prompt pair for the requested diagnosis and
either scores a single raw image or evaluates the full precomputed dataset.
"""
import os
import argparse
import logging
import sys

# Suppress TensorFlow and system warnings.
# NOTE(review): these must stay above the project imports below; presumably
# `model` imports TensorFlow, which reads these variables at import time.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

# Configure logging first
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress absl logging from TensorFlow
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    # absl is optional; nothing to silence when it is not installed.
    pass

# Suppress TensorFlow Python logging
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from model import PrecomputedModel, RawImageModel
from evaluate import evaluate_predictions

# Maps each supported diagnosis key to its (positive prompt, negative prompt).
DIAGNOSIS_PROMPTS = {
    'AIRSPACE_OPACITY': ('Airspace Opacity', 'no evidence of airspace disease'),
    'PNEUMOTHORAX': ('small pneumothorax', 'no pneumothorax'),
    'EFFUSION': ('large pleural effusion', 'no pleural effusion'),
    'PULMONARY_EDEMA': ('moderate pulmonary edema', 'no pulmonary edema'),
}


def _parse_args():
    """Build the argument parser and return the parsed command-line arguments."""
    parser = argparse.ArgumentParser(description="Zero-Shot Chest X-Ray Classification")
    parser.add_argument("--diagnosis", type=str, choices=list(DIAGNOSIS_PROMPTS), required=True,
                        help="Diagnosis to evaluate")
    parser.add_argument("--data-dir", type=str, default="data", help="Path to data directory")
    parser.add_argument("--raw-image", type=str,
                        help="Path to a raw image file for inference (optional)")
    return parser.parse_args()


def _score_raw_image(image_path, diagnosis, pos_emb, neg_emb):
    """Compute and print the zero-shot score for a single raw image.

    Exits the process with status 1 if embedding or scoring fails.
    """
    logger.info(f"Running inference on raw image: {image_path}")
    raw_model = RawImageModel()

    try:
        image_emb = raw_model.compute_embeddings(image_path)
        # image_emb shape is likely (1, 32, 128) or (32, 128)
        # PrecomputedModel.zero_shot expects flattened or (32, 128)
        score = PrecomputedModel.zero_shot(image_emb, pos_emb, neg_emb)
        logger.info(f"Zero-shot score for {image_path}: {score:.4f}")

        # Since we only have one image, we can't calculate AUC meaningfully
        # unless we run it against the full validation set which takes time.
        # For this demo, just output the score.
        print(f"Score for {diagnosis}: {score}")
    except Exception as e:
        # Top-level boundary: log with traceback so the failure is diagnosable.
        logger.exception(f"Failed to process raw image: {e}")
        sys.exit(1)


def _evaluate_dataset(precomputed_model, diagnosis, pos_emb, neg_emb):
    """Score every labelled image for ``diagnosis`` and evaluate the predictions.

    Exits the process with status 1 if no scores are produced or if the
    scores cannot be paired one-to-one with labels.
    """
    logger.info("Running evaluation on full precomputed dataset...")

    # Filter labels for the target diagnosis (0 or 1)
    labels_df = precomputed_model.labels
    has_binary_label = labels_df[diagnosis].isin([0, 1])
    target_df = labels_df[has_binary_label][['image_id', diagnosis]].copy()

    image_ids = target_df['image_id'].tolist()
    true_labels = target_df[diagnosis].tolist()

    # Compute scores
    valid_ids, scores = precomputed_model.compute_scores(image_ids, pos_emb, neg_emb)

    # Explicit emptiness test: `not scores` is ambiguous for array-like results.
    if scores is None or len(scores) == 0:
        logger.error("No valid scores computed. Check embedding match.")
        sys.exit(1)

    # Filter labels to match valid_ids found in embeddings. The set is built
    # once so each membership test is O(1).
    # NOTE(review): assumes compute_scores returns scores in the same order as
    # the ids it was given - confirm against PrecomputedModel.
    valid_id_set = set(valid_ids)
    final_labels = [
        label
        for img_id, label in zip(image_ids, true_labels)
        if img_id in valid_id_set
    ]

    # Refuse to evaluate when scores and labels cannot be paired one-to-one.
    if len(final_labels) != len(scores):
        logger.error(
            f"Label/score count mismatch: {len(final_labels)} labels vs {len(scores)} scores."
        )
        sys.exit(1)

    # Evaluate
    evaluate_predictions(scores, final_labels, diagnosis)


def main():
    """Parse CLI arguments and run raw-image scoring or full-dataset evaluation."""
    args = _parse_args()

    # Get prompts
    pos_txt, neg_txt = DIAGNOSIS_PROMPTS[args.diagnosis]
    logger.info(f"Diagnosis: {args.diagnosis}")
    logger.info(f"Positive query: '{pos_txt}'")
    logger.info(f"Negative query: '{neg_txt}'")

    # Load precomputed model for text embeddings (and image embeddings if no raw image)
    precomputed_model = PrecomputedModel(data_dir=args.data_dir)
    pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

    if args.raw_image:
        # Raw Image Inference Mode
        _score_raw_image(args.raw_image, args.diagnosis, pos_emb, neg_emb)
    else:
        # Precomputed Embeddings Evaluation Mode (Full Dataset)
        _evaluate_dataset(precomputed_model, args.diagnosis, pos_emb, neg_emb)


if __name__ == "__main__":
    main()