|
|
import os |
|
|
import logging |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from sklearn.metrics import roc_curve, auc |
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__) |
|
|
|
|
|
def evaluate_predictions(scores, true_labels, diagnosis, output_dir="results"):
    """Compute the ROC AUC for one diagnosis and save its ROC plot.

    The ROC curve is obtained from ``roc_curve(true_labels, scores)`` and the
    area under it from ``auc(fpr, tpr)``. The curve is drawn together with a
    dashed diagonal reference line and written to
    ``<output_dir>/roc_<diagnosis>.png``. The diagnosis and the AUC are
    logged at INFO level, as is the path of the saved plot.

    Args:
        scores: Prediction scores, passed to ``roc_curve`` as its second
            argument.
        true_labels: Ground-truth labels, passed to ``roc_curve`` as its
            first argument.
        diagnosis: Label used in the log messages, the plot title and the
            output file name. It is interpolated into the file name
            unchanged.
        output_dir: Directory the plot is written to; created if missing.

    Returns:
        The area under the ROC curve as returned by ``auc``.
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail
    # if the directory was created between the check and the call.
    os.makedirs(output_dir, exist_ok=True)

    # roc_curve also returns the thresholds, which are not needed here.
    fpr, tpr, _ = roc_curve(true_labels, scores)
    roc_auc = auc(fpr, tpr)

    logger.info(f"Diagnosis: {diagnosis}")
    logger.info(f"AUC: {roc_auc:.4f}")

    plot_path = os.path.join(output_dir, f"roc_{diagnosis}.png")
    lw = 2

    # Use explicit figure/axes objects instead of pyplot's implicit
    # "current figure" so the figure can be closed reliably afterwards.
    fig, ax = plt.subplots()
    try:
        ax.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=lw,
            label=f"ROC curve (area = {roc_auc:.2f})",
        )
        # Diagonal reference line from (0, 0) to (1, 1).
        ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(f"ROC for {diagnosis}")
        ax.legend(loc="lower right")
        fig.savefig(plot_path)
    finally:
        # Without this every call would leave a figure registered with
        # pyplot, growing memory when the function is called repeatedly.
        plt.close(fig)

    logger.info(f"ROC plot saved to {plot_path}")

    return roc_auc
|
|
|