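# NOTE(review): the three lines below look like web-page header residue
# (uploader caption, commit message, short commit hash); kept as comments.
# rohitium's picture
# Deploy Chest X-Ray App (LFS)
# b412062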
import os
import argparse
import logging
import sys
# Suppress TensorFlow and system warnings.
# NOTE(review): these variables are set before the project imports further
# down, presumably because TensorFlow reads them at import time -- confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # FATAL
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import warnings
# Blanket filter: every Python warning is ignored from this point on.
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
# Configure logging first, so the root logger's level and format are in place
# before the remaining imports run.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Suppress absl logging from TensorFlow.
# absl is treated as optional: if it is not installed, nothing is changed.
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass
# Suppress TensorFlow Python logging (only ERROR and above get through).
logging.getLogger('tensorflow').setLevel(logging.ERROR)
# Project-local modules, imported after the environment/logging setup above.
from model import PrecomputedModel, RawImageModel
from evaluate import evaluate_predictions
# Zero-shot text prompts per diagnosis: label -> (positive query, negative query).
# The labels are also the accepted values of the --diagnosis flag and the
# column names looked up in the labels table.
DIAGNOSIS_PROMPTS = {
    label: (positive_query, negative_query)
    for label, positive_query, negative_query in (
        ('AIRSPACE_OPACITY', 'Airspace Opacity', 'no evidence of airspace disease'),
        ('PNEUMOTHORAX', 'small pneumothorax', 'no pneumothorax'),
        ('EFFUSION', 'large pleural effusion', 'no pleural effusion'),
        ('PULMONARY_EDEMA', 'moderate pulmonary edema', 'no pulmonary edema'),
    )
}
def _align_labels(valid_ids, image_ids, true_labels):
    """Return the ground-truth label for each id in ``valid_ids``, in that order.

    Args:
        valid_ids: Image ids that were actually scored.
        image_ids: Image ids that were requested for scoring.
        true_labels: Labels parallel to ``image_ids``.

    Returns:
        A list holding one label per entry of ``valid_ids``.

    Raises:
        KeyError: If ``valid_ids`` contains an id that is not in ``image_ids``.
    """
    # One dict lookup per id instead of a list membership test per id.
    label_by_id = dict(zip(image_ids, true_labels))
    return [label_by_id[img_id] for img_id in valid_ids]


def main():
    """Command-line entry point for zero-shot chest X-ray classification.

    Two modes, selected by ``--raw-image``:

    * With ``--raw-image``: embed that single image, compute its zero-shot
      score for the chosen diagnosis, then log and print the score.
    * Without it: score every image that has a 0/1 label for the chosen
      diagnosis using the precomputed embeddings, then hand scores and labels
      to ``evaluate_predictions``.

    Exits with status 1 if the raw image cannot be processed, if no scores
    were produced, or if the scored ids cannot be matched to labels.
    """
    parser = argparse.ArgumentParser(description="Zero-Shot Chest X-Ray Classification")
    parser.add_argument(
        "--diagnosis",
        type=str,
        choices=list(DIAGNOSIS_PROMPTS),
        required=True,
        help="Diagnosis to evaluate",
    )
    parser.add_argument("--data-dir", type=str, default="data", help="Path to data directory")
    parser.add_argument("--raw-image", type=str, help="Path to a raw image file for inference (optional)")
    args = parser.parse_args()

    # Get prompts
    pos_txt, neg_txt = DIAGNOSIS_PROMPTS[args.diagnosis]
    logger.info(f"Diagnosis: {args.diagnosis}")
    logger.info(f"Positive query: '{pos_txt}'")
    logger.info(f"Negative query: '{neg_txt}'")

    # The precomputed model supplies the text embeddings in both modes (and the
    # image embeddings when no raw image is given).
    precomputed_model = PrecomputedModel(data_dir=args.data_dir)
    pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

    if args.raw_image:
        # Raw Image Inference Mode
        logger.info(f"Running inference on raw image: {args.raw_image}")
        raw_model = RawImageModel()
        try:
            # NOTE(review): embedding shape is presumably (1, 32, 128) or
            # (32, 128), which zero_shot is expected to accept -- confirm.
            image_emb = raw_model.compute_embeddings(args.raw_image)
            score = PrecomputedModel.zero_shot(image_emb, pos_emb, neg_emb)
            logger.info(f"Zero-shot score for {args.raw_image}: {score:.4f}")
            # A single image cannot yield a meaningful AUC, so only the score
            # is reported.
            print(f"Score for {args.diagnosis}: {score}")
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Failed to process raw image: {e}")
            sys.exit(1)
        return

    # Precomputed Embeddings Evaluation Mode (Full Dataset)
    logger.info("Running evaluation on full precomputed dataset...")

    # Keep only rows whose label for the target diagnosis is 0 or 1.
    labels_df = precomputed_model.labels
    has_binary_label = labels_df[args.diagnosis].isin([0, 1])
    target_df = labels_df.loc[has_binary_label, ['image_id', args.diagnosis]]
    image_ids = target_df['image_id'].tolist()
    true_labels = target_df[args.diagnosis].tolist()

    # Compute scores
    valid_ids, scores = precomputed_model.compute_scores(image_ids, pos_emb, neg_emb)

    # len() instead of truthiness: the container type of ``scores`` is not
    # visible here, and ``not array`` raises for a multi-element NumPy array.
    if len(scores) == 0:
        logger.error("No valid scores computed. Check embedding match.")
        sys.exit(1)

    # Look labels up per scored id so labels[i] always belongs to scores[i],
    # independent of the order in which compute_scores returns the ids.
    try:
        final_labels = _align_labels(valid_ids, image_ids, true_labels)
    except KeyError as exc:
        logger.error(f"Scored image id {exc} is missing from the requested ids; cannot align labels.")
        sys.exit(1)

    # Evaluate
    evaluate_predictions(scores, final_labels, args.diagnosis)


if __name__ == "__main__":
    main()