|
|
import os |
|
|
import sys |
|
|
import pandas as pd |
|
|
import logging |
|
|
import argparse |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Configure script-wide logging: timestamped INFO-level messages to stderr.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


logger = logging.getLogger(__name__)


# Quiet TensorFlow's C++ backend (3 = errors only). This must be set before
# TensorFlow is imported (it is pulled in transitively by the model module
# imported below), which is why it lives up here rather than in main().
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


# absl ships with TensorFlow but is optional here; silence it when present.
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass


# Suppress TensorFlow's Python-side logger as well. (The redundant
# `import logging` that used to sit here was dropped: logging is already
# imported at the top of the file.)
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
|
|
|
|
from model import RawImageModel, PrecomputedModel |
|
|
from dicom_utils import read_dicom_image |
|
|
from PIL import Image |
|
|
|
|
|
def _pick_file_column(df):
    """Return the first recognized file-path column name in *df*, or None.

    The CSV may label the path column either 'file' or 'dicom_file';
    'file' wins when both are present.
    """
    for candidate in ('file', 'dicom_file'):
        if candidate in df.columns:
            return candidate
    return None


def _score_image(raw_model, dicom_path, temp_path, pos_emb, neg_emb):
    """Zero-shot score one DICOM image; returns the raw model score.

    The DICOM pixels are round-tripped through a PNG on disk because
    RawImageModel.compute_embeddings takes an image file path.
    Raises whatever the underlying readers/models raise; the caller
    records per-row failures.
    """
    image_array = read_dicom_image(dicom_path)
    Image.fromarray(image_array).save(temp_path)
    img_emb = raw_model.compute_embeddings(temp_path)
    return PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)


def main():
    """Evaluate zero-shot pneumothorax predictions on a Kaggle DICOM dataset.

    Reads a labels CSV, scores every listed image against the
    positive/negative text embeddings, and writes a predictions CSV
    (checkpointed every 10 rows so a crash does not lose all progress).
    Fatal setup failures are logged and the function returns early.
    """
    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv", help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle", help="Root directory for images if relative paths in CSV")
    parser.add_argument("--output", default="results/kaggle_predictions.csv", help="Output predictions file")
    args = parser.parse_args()

    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when --output actually contains one (a bare filename has no dirname).
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    try:
        df = pd.read_csv(args.csv)
        logger.info(f"Loaded {len(df)} records from {args.csv}")
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return

    file_col = _pick_file_column(df)
    if file_col is None:
        logger.error(f"Missing file column in CSV. Found: {df.columns}")
        return

    try:
        precomputed_model = PrecomputedModel()
        raw_model = RawImageModel()
        logger.info("Models loaded successfully.")
    except Exception as e:
        # logger.critical, not the undocumented .fatal alias.
        logger.critical(f"Failed to initialize models: {e}")
        return

    diagnosis = 'PNEUMOTHORAX'
    try:
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.critical(f"Failed to get text embeddings: {e}")
        return

    predictions = []
    print(f"Running inference for {diagnosis} on {len(df)} images...")

    # Scratch PNG reused for every row; removed in the finally below.
    temp_path = "temp_inference.png"
    try:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            file_path = row[file_col]
            full_path = file_path if os.path.isabs(file_path) else os.path.join(args.data_dir, file_path)

            if not os.path.exists(full_path):
                logger.warning(f"File not found: {full_path}")
                predictions.append({
                    'file': file_path,
                    'true_label': None,
                    'pneumothorax_score': None,
                    'error': 'File not found'
                })
                continue

            true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))

            try:
                score = _score_image(raw_model, full_path, temp_path, pos_emb, neg_emb)
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': float(score)
                })
            except Exception as e:
                # Record per-row failures instead of aborting the whole run.
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': None,
                    'error': str(e)
                })

            # Periodic checkpoint so partial results survive a crash.
            if len(predictions) % 10 == 0:
                pd.DataFrame(predictions).to_csv(args.output, index=False)

        results_df = pd.DataFrame(predictions)
        results_df.to_csv(args.output, index=False)
        logger.info(f"Predictions saved to {args.output}")
    finally:
        # Clean up via the temp_path variable (the old code hardcoded the
        # filename here, which would leak the file if temp_path changed),
        # and do it in a finally so an unexpected error still cleans up.
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
|
|
|
|
# Entry-point guard: run the evaluation only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":


    main()
|
|
|