# chest-xray-classification / src / evaluate_kaggle.py
# rohitium's picture
# Deploy Chest X-Ray App (LFS)
# b412062
import os
import sys
import pandas as pd
import logging
import argparse
import numpy as np
from tqdm import tqdm
# Configure logging
# Root logger: INFO and above, formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by main() for all status and error reporting.
logger = logging.getLogger(__name__)
# Suppress TensorFlow logging
# NOTE(review): presumably this has to be set before TensorFlow is first
# imported for it to take effect, which is why it sits above the project
# imports that follow -- confirm before reordering.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    # absl is treated as optional: when it is importable its verbosity is
    # lowered to ERROR, otherwise this step is skipped.
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass
# NOTE(review): `logging` is already imported at the top of the file, so this
# second import is redundant (and harmless).
import logging
# Only ERROR and above from the Python-side 'tensorflow' logger.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
from model import RawImageModel, PrecomputedModel
from dicom_utils import read_dicom_image
from PIL import Image
# Column order of the predictions CSV. 'error' is only filled for failed rows.
_RESULT_COLUMNS = ['file', 'true_label', 'pneumothorax_score', 'error']

# A checkpoint CSV is written every time this many records have accumulated.
_CHECKPOINT_EVERY = 10


def _ensure_parent_dir(path):
    """Create the directory that will contain *path*, if *path* names one."""
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs('') raises
    # FileNotFoundError; the current directory never needs creating.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _resolve_file_column(columns):
    """Return the name of the column holding image paths, or None if absent."""
    for candidate in ('file', 'dicom_file'):
        if candidate in columns:
            return candidate
    return None


def _save_predictions(predictions, output_path):
    """Write the prediction records collected so far to *output_path* as CSV."""
    # Pinning the columns keeps the header identical between checkpoints and
    # the final save, and still produces a header when there are no records.
    pd.DataFrame(predictions, columns=_RESULT_COLUMNS).to_csv(output_path, index=False)


def _score_image(raw_model, full_path, temp_path, pos_emb, neg_emb):
    """Return the zero-shot score of the DICOM file at *full_path* as a float.

    The image is read with ``read_dicom_image``, written to *temp_path* as a
    PNG (``RawImageModel.compute_embeddings`` is given a file path), embedded,
    and scored against the positive/negative text embeddings.
    """
    image_array = read_dicom_image(full_path)
    Image.fromarray(image_array).save(temp_path)
    img_emb = raw_model.compute_embeddings(temp_path)
    return float(PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb))


def main():
    """Score every image listed in a labels CSV for pneumothorax.

    Command-line options:
        --csv       Path to the labels CSV. It must contain a ``file`` or
                    ``dicom_file`` column; the true label is read from a
                    ``label`` or ``PNEUMOTHORAX`` column when present.
        --data-dir  Directory that relative paths in the CSV are joined to.
        --output    Destination CSV for the predictions.

    One record per input row is written with the columns ``file``,
    ``true_label``, ``pneumothorax_score`` and ``error``. Rows whose file is
    missing or whose processing raised keep an empty score and carry the
    reason in ``error``. The output file is rewritten every
    ``_CHECKPOINT_EVERY`` records and once more at the end.

    Setup failures (unreadable CSV, missing file column, model or text
    embedding initialisation) are logged and make the function return early.

    Returns:
        None.
    """
    # Local import: tempfile is only needed by this entry point.
    import tempfile

    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv", help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle", help="Root directory for images if relative paths in CSV")
    parser.add_argument("--output", default="results/kaggle_predictions.csv", help="Output predictions file")
    args = parser.parse_args()

    # Create output directory
    _ensure_parent_dir(args.output)

    # Load dataset
    try:
        df = pd.read_csv(args.csv)
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return
    logger.info(f"Loaded {len(df)} records from {args.csv}")

    # Check for file column
    file_col = _resolve_file_column(df.columns)
    if file_col is None:
        logger.error(f"Missing file column in CSV. Found: {df.columns}")
        return

    # Initialize models: PrecomputedModel supplies the text (label)
    # embeddings, RawImageModel embeds the images themselves.
    try:
        precomputed_model = PrecomputedModel()
        raw_model = RawImageModel()
    except Exception as e:
        logger.critical(f"Failed to initialize models: {e}")
        return
    logger.info("Models loaded successfully.")

    # Get text embeddings for diagnosis
    diagnosis = 'PNEUMOTHORAX'
    try:
        # Hardcoded prompts matching main.py
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.critical(f"Failed to get text embeddings: {e}")
        return

    predictions = []

    # Iterate and predict
    print(f"Running inference for {diagnosis} on {len(df)} images...")

    # The scratch PNG lives in a private temporary directory so that
    # concurrent runs cannot overwrite each other's file and so that it is
    # removed even when the loop is interrupted by an exception.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, "temp_inference.png")

        for _, row in tqdm(df.iterrows(), total=len(df)):
            file_path = row[file_col]

            # Construct full path
            if os.path.isabs(file_path):
                full_path = file_path
            else:
                full_path = os.path.join(args.data_dir, file_path)

            if not os.path.exists(full_path):
                logger.warning(f"File not found: {full_path}")
                predictions.append({
                    'file': file_path,
                    'true_label': None,
                    'pneumothorax_score': None,
                    'error': 'File not found',
                })
            else:
                true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))
                try:
                    score = _score_image(raw_model, full_path, temp_path, pos_emb, neg_emb)
                except Exception as e:
                    # Best effort: one unreadable image must not abort the
                    # run. The reason is kept in the 'error' column and the
                    # failures are summarised once the loop has finished.
                    predictions.append({
                        'file': file_path,
                        'true_label': true_label,
                        'pneumothorax_score': None,
                        'error': str(e),
                    })
                else:
                    predictions.append({
                        'file': file_path,
                        'true_label': true_label,
                        'pneumothorax_score': score,
                    })

            # Incremental save; reached for every record, including the
            # "file not found" ones.
            if len(predictions) % _CHECKPOINT_EVERY == 0:
                _save_predictions(predictions, args.output)

    # Final Save
    _save_predictions(predictions, args.output)
    logger.info(f"Predictions saved to {args.output}")

    failed = sum(1 for record in predictions if record['pneumothorax_score'] is None)
    if failed:
        logger.warning(
            f"{failed} of {len(predictions)} images could not be scored; "
            f"see the 'error' column in {args.output}"
        )


if __name__ == "__main__":
    main()