Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import roc_curve, auc | |
| import logging | |
| import os | |
# Logging setup: timestamped, level-tagged records at INFO and above,
# plus a module-scoped logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def plot_roc_curve(results_path, output_image_path):
    """
    Read a predictions CSV, compute the ROC AUC, and save the ROC plot.

    The CSV must provide a 'true_label' column and a 'pneumothorax_score'
    column. Rows whose label equals 'Pneumothorax' form the positive class;
    every other label value is treated as negative. Rows with a missing
    score are dropped before evaluation.

    Args:
        results_path: Path of the predictions CSV to read.
        output_image_path: Path the ROC figure is written to.

    Returns:
        None. Failures (missing file, missing columns, no scored rows, a
        single class in the labels, or any unexpected exception) are logged
        instead of raised, and no image is written in those cases.
    """
    if not os.path.exists(results_path):
        logger.error(f"Results file not found: {results_path}")
        return

    try:
        df = pd.read_csv(results_path)
        logger.info(f"Loaded {len(df)} predictions from {results_path}")

        # Validate the schema up front so a malformed file produces a clear
        # message rather than a bare KeyError from pandas.
        required_columns = ('true_label', 'pneumothorax_score')
        missing_columns = [c for c in required_columns if c not in df.columns]
        if missing_columns:
            logger.error(f"Results file is missing required columns: {missing_columns}")
            return

        # Filter out errors (rows that have no score).
        df = df.dropna(subset=['pneumothorax_score'])
        if df.empty:
            logger.error("No valid predictions found.")
            return

        # Prepare True Labels (Binary)
        # Kaggle Labels: 'Pneumothorax' vs 'No Pneumothorax'
        y_true = (df['true_label'] == 'Pneumothorax').astype(int)
        y_scores = df['pneumothorax_score']

        # A ROC curve needs both positives and negatives; with a single class
        # the rates are undefined and the resulting plot would be meaningless.
        if y_true.nunique() < 2:
            logger.error("Cannot compute ROC curve: true labels contain only one class.")
            return

        # Calculate ROC and AUC
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)
        logger.info(f"Calculated AUC: {roc_auc:.4f}")

        # Plot
        fig = plt.figure(figsize=(8, 6))
        try:
            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
            # Diagonal reference line from (0, 0) to (1, 1).
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('ROC Curve - Zero-Shot Pneumothorax Classification (Kaggle)')
            plt.legend(loc="lower right")
            plt.grid(True, alpha=0.3)
            plt.savefig(output_image_path)
            logger.info(f"ROC curve saved to {output_image_path}")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
    except Exception as e:
        # Best-effort boundary: report the failure with its traceback and
        # return normally instead of propagating.
        logger.exception(f"Failed to plot ROC curve: {e}")
if __name__ == "__main__":
    # Script entry point: plot the ROC curve for the Kaggle prediction run.
    plot_roc_curve(
        results_path="results/kaggle_predictions.csv",
        output_image_path="results/kaggle_roc_curve.png",
    )