|
|
import gradio as gr |
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# App-wide logging: INFO and above, timestamped, to the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Quiet TensorFlow's native and Python-side logging. This runs before the
# `from model import ...` line below; presumably that import is what pulls in
# TensorFlow, which reads TF_CPP_MIN_LOG_LEVEL at load time (TODO confirm).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

try:
    import absl.logging
except ImportError:
    pass  # absl is optional; nothing to silence when it is not installed
else:
    absl.logging.set_verbosity(absl.logging.ERROR)

logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
|
|
|
|
from model import RawImageModel, PrecomputedModel |
|
|
|
|
|
|
|
|
# Model handles and cached diagnosis embeddings, populated by load_models().
# None means "not loaded yet".
raw_model = precomputed_model = None
pos_emb = neg_emb = None

# Decision cutoff on the zero-shot score: predict() reports PNEUMOTHORAX when
# score >= THRESHOLD and NORMAL otherwise. The UI text states it was calibrated
# on a local Kaggle dataset.
THRESHOLD = -0.1173
|
|
|
|
|
def load_models():
    """Load the models and the diagnosis prompt embeddings into module globals.

    Populates ``precomputed_model``, ``raw_model``, ``pos_emb`` and ``neg_emb``.
    ``pos_emb``/``neg_emb`` are whatever ``get_diagnosis_embeddings`` returns for
    the positive ('small pneumothorax') and negative ('no pneumothorax') prompts.
    Once everything is loaded, further calls are no-ops.

    Raises:
        Exception: anything raised while constructing the models or computing
            the embeddings. It is logged with its traceback and re-raised. The
            globals are only assigned after every step succeeds, so a failed
            load leaves them untouched and a later call retries from scratch.
    """
    global raw_model, precomputed_model, pos_emb, neg_emb

    # Check every cached value (with `is None`, since the embeddings may be
    # arrays that do not support truthiness / == None). Checking raw_model
    # alone would skip reloading after a partial failure.
    if all(v is not None for v in (raw_model, precomputed_model, pos_emb, neg_emb)):
        return

    logger.info("Loading models...")
    try:
        new_precomputed = PrecomputedModel()
        new_raw = RawImageModel()

        # Text prompts the zero-shot score contrasts against each other.
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        new_pos, new_neg = new_precomputed.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.exception("Failed to load models: %s", e)
        raise

    # Commit only after every step succeeded, so the globals stay all-or-nothing.
    precomputed_model, raw_model = new_precomputed, new_raw
    pos_emb, neg_emb = new_pos, new_neg
    logger.info("Models loaded.")
|
|
|
|
|
|
|
|
try:
    import spaces
except ImportError:
    # Not running on Hugging Face Spaces: provide a no-op stand-in so the
    # `@spaces.GPU` decorator below still works locally.
    class spaces:  # noqa: N801 - deliberately mimics the `spaces` module name
        """Minimal local substitute for the Hugging Face `spaces` module."""

        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for `spaces.GPU`.

            Supports both the bare form (`@spaces.GPU`) and the called form
            (`@spaces.GPU(duration=...)`). Keyword arguments are accepted and
            ignored. The decorated function is returned unchanged.
            """
            if func is None:
                # Called form: return a decorator that leaves its target as-is.
                return lambda f: f
            return func
|
|
|
|
|
@spaces.GPU
def predict(image):
    """Classify an uploaded chest X-ray as pneumothorax or normal.

    Args:
        image: The uploaded image as a PIL image (the input component is
            ``gr.Image(type="pil")``), or None when nothing was uploaded.

    Returns:
        A ``(prediction_html, score, details)`` tuple for the three output
        components. When nothing was uploaded, returns a plain notice, 0.0 and
        an upload hint. On failure, returns an error HTML snippet, 0.0 and the
        exception message.
    """
    import tempfile  # local import keeps this fix self-contained

    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."

    temp_path = None
    try:
        # raw_model.compute_embeddings takes a file path, so persist the upload.
        # Use a unique file per request: a single fixed filename would be
        # shared (and overwritten) by concurrent requests.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)  # PIL reopens the path itself
        image.save(temp_path)

        img_emb = raw_model.compute_embeddings(temp_path)
        score = float(PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb))

        # Scores at or above the calibrated cutoff are flagged as pneumothorax.
        if score >= THRESHOLD:
            prediction = '<p style="color: red; font-size: 24px; font-weight: bold;">PNEUMOTHORAX ⚠️</p>'
        else:
            prediction = '<p style="color: green; font-size: 24px; font-weight: bold;">NORMAL ✅</p>'

        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"

    except Exception as e:
        # UI boundary: log the full traceback, show the message to the user.
        logger.exception("Prediction failed: %s", e)
        return "<p style='color:red'>Error</p>", 0.0, str(e)

    finally:
        # Always delete the per-request upload so temp files do not accumulate.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", temp_path)
|
|
|
|
|
|
|
|
# predict() reads the module-level models directly and never loads them
# itself, so load them before the UI is built and served.
load_models()

# Resolve the benchmark figures relative to this script rather than the
# current working directory, so launching the app from another directory
# still finds them.
_RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")

with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect Pneumothorax from raw X-ray images using a pre-trained foundation model.")

    with gr.Row():
        # Left column: input image and trigger button.
        with gr.Column():
            gr.Markdown("### 1. Upload X-Ray")
            input_image = gr.Image(type="pil", label="Upload Image (PNG/JPG/DICOM converted)")
            predict_btn = gr.Button("Analyze Image", variant="primary")

        # Right column: the three values returned by predict().
        with gr.Column():
            gr.Markdown("### 2. Results")
            output_label = gr.HTML(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")

    gr.Markdown("---")
    gr.Markdown("### Performance Context")
    gr.Markdown("This model uses a **zero-shot** approach. The threshold was calibrated using a local Kaggle dataset.")

    with gr.Tabs():
        with gr.TabItem("Local Kaggle Benchmark"):
            gr.Image(os.path.join(_RESULTS_DIR, "kaggle_roc_curve.png"), label="local ROC Curve")
            gr.Markdown("**AUC: 0.88** on 250 local samples.")
        with gr.TabItem("Google Benchmark"):
            gr.Image(os.path.join(_RESULTS_DIR, "roc_PNEUMOTHORAX.png"), label="Reference ROC")

    # Output order must match predict()'s (html, score, details) return tuple.
    predict_btn.click(predict, inputs=input_image, outputs=[output_label, output_score, output_msg])


if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)
|
|
|