"""Evaluate zero-shot pneumothorax detection on a Kaggle DICOM dataset.

Reads a labels CSV, embeds each DICOM image with the raw-image model,
scores it against fixed positive/negative text prompts, and writes
per-image prediction scores to a CSV (with incremental checkpointing).
"""

import argparse
import logging
import os
import sys

import numpy as np
import pandas as pd
from tqdm import tqdm

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Suppress TensorFlow logging.  The env var must be set before TensorFlow
# is first imported (which happens transitively via `model` below).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass
logging.getLogger("tensorflow").setLevel(logging.ERROR)

from model import RawImageModel, PrecomputedModel
from dicom_utils import read_dicom_image
from PIL import Image


def main():
    """Run zero-shot pneumothorax inference over every image listed in the CSV."""
    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv",
                        help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle",
                        help="Root directory for images if relative paths in CSV")
    parser.add_argument("--output", default="results/kaggle_predictions.csv",
                        help="Output predictions file")
    args = parser.parse_args()

    # Create output directory
    os.makedirs(os.path.dirname(args.output), exist_ok=True)

    # Load dataset
    try:
        df = pd.read_csv(args.csv)
        logger.info(f"Loaded {len(df)} records from {args.csv}")
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return

    # Locate the column holding image paths; CSVs in the wild use either name.
    file_col = next((c for c in ("file", "dicom_file") if c in df.columns), None)
    if file_col is None:
        logger.error(f"Missing file column in CSV. Found: {df.columns}")
        return

    # Initialize Models
    try:
        # PrecomputedModel supplies text embeddings (labels);
        # RawImageModel embeds the images themselves.
        precomputed_model = PrecomputedModel()
        raw_model = RawImageModel()
        logger.info("Models loaded successfully.")
    except Exception as e:
        logger.critical(f"Failed to initialize models: {e}")
        return

    # Get text embeddings for diagnosis
    diagnosis = 'PNEUMOTHORAX'
    try:
        # Hardcoded prompts matching main.py
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.critical(f"Failed to get text embeddings: {e}")
        return

    predictions = []
    print(f"Running inference for {diagnosis} on {len(df)} images...")

    # Scratch PNG required by the RawImageModel/TF pipeline; cleaned up in
    # the `finally` below even if inference aborts partway.
    temp_path = "temp_inference.png"
    try:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            file_path = row[file_col]
            # Construct full path (CSV may contain absolute or relative paths).
            full_path = (file_path if os.path.isabs(file_path)
                         else os.path.join(args.data_dir, file_path))

            if not os.path.exists(full_path):
                logger.warning(f"File not found: {full_path}")
                predictions.append({
                    'file': file_path,
                    'true_label': None,
                    'pneumothorax_score': None,
                    'error': 'File not found',
                })
                continue

            true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))

            try:
                # 1. Read DICOM
                image_array = read_dicom_image(full_path)
                # 2. Save as temp PNG (required by the TF pipeline currently)
                Image.fromarray(image_array).save(temp_path)
                # 3. Compute image embedding
                img_emb = raw_model.compute_embeddings(temp_path)
                # 4. Compute zero-shot score against the text prompts
                score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': float(score),
                })
            except Exception as e:
                # Record the failure per-row so one bad image doesn't abort the run.
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': None,
                    'error': str(e),
                })

            # Incremental save every 10 items so a crash loses little work.
            if len(predictions) % 10 == 0:
                pd.DataFrame(predictions).to_csv(args.output, index=False)

        # Final save
        results_df = pd.DataFrame(predictions)
        results_df.to_csv(args.output, index=False)
        logger.info(f"Predictions saved to {args.output}")
    finally:
        # Cleanup: remove the scratch PNG on every exit path.
        if os.path.exists(temp_path):
            os.remove(temp_path)


if __name__ == "__main__":
    main()