File size: 4,113 Bytes
b412062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c0808a
 
 
 
 
 
 
 
 
 
 
b412062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45f1909
 
 
 
b412062
 
 
 
 
45f1909
b412062
 
 
 
 
 
 
45f1909
b412062
 
 
 
 
 
 
 
 
45f1909
b412062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import logging
import os
import sys
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

# Application-wide logging: INFO and above, timestamped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Quiet TensorFlow's output. The environment variable is set before the
# ``model`` import further down, presumably so TensorFlow sees it when it is
# first loaded — confirm if the import order ever changes.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import absl.logging
except ImportError:
    pass  # absl is optional; nothing to silence if it is not installed
else:
    absl.logging.set_verbosity(absl.logging.ERROR)
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from model import RawImageModel, PrecomputedModel

# Global Model Instances
raw_model = None
precomputed_model = None
pos_emb = None
neg_emb = None

# Optimal Threshold from Kaggle validation
THRESHOLD = -0.1173

def load_models():
    """Load both models and cache the zero-shot text embeddings in module globals.

    Calling it again after a successful load does nothing. The four globals
    (``precomputed_model``, ``raw_model``, ``pos_emb``, ``neg_emb``) are
    assigned together, and only after every step has succeeded. A failure
    part-way through therefore leaves the module unloaded, so a later call
    retries from scratch. It does not skip the load while the embeddings are
    still ``None``.

    Raises:
        Exception: whatever ``PrecomputedModel()``, ``RawImageModel()`` or
            ``get_diagnosis_embeddings`` raised. It is logged with its
            traceback and re-raised unchanged.
    """
    global raw_model, precomputed_model, pos_emb, neg_emb
    if raw_model is not None:
        return

    logger.info("Loading models...")
    try:
        new_precomputed = PrecomputedModel()
        new_raw = RawImageModel()

        # Pre-fetch the text embeddings for the positive and negative
        # prompts; predict() scores each image embedding against this pair.
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        new_pos, new_neg = new_precomputed.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.exception("Failed to load models: %s", e)
        raise

    # Commit only after everything above succeeded (see docstring).
    precomputed_model, raw_model = new_precomputed, new_raw
    pos_emb, neg_emb = new_pos, new_neg
    logger.info("Models loaded.")

# ZeroGPU compatibility for Hugging Face Spaces
try:
    import spaces
except ImportError:
    # The ``spaces`` package is absent (e.g. a local run). Substitute a
    # stand-in whose ``GPU`` decorator hands the function back untouched.
    class spaces:
        """Local no-op replacement for the ``spaces`` package."""

        @staticmethod
        def GPU(func):
            """Identity decorator: return *func* unchanged."""
            decorated = func
            return decorated

@spaces.GPU
def predict(image):
    """Run zero-shot pneumothorax classification on one uploaded image.

    Args:
        image: The uploaded image as a PIL ``Image`` (the Gradio input
            component is declared with ``type="pil"``), or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(prediction_html, score, details)`` tuple for the three output
        components:
        - an HTML verdict;
        - the raw zero-shot score as a float;
        - a detail string.
        On failure the score is ``0.0`` and ``details`` holds the exception
        message.
    """
    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."

    temp_path = None
    try:
        # The embedding model is called with a file path, so persist the
        # upload first. Use a unique temp file per request. A fixed filename
        # in the working directory would let concurrent requests overwrite
        # each other's input.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            temp_path = tmp.name
        image.save(temp_path)

        # Run Inference
        img_emb = raw_model.compute_embeddings(temp_path)
        score = float(PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb))

        # Binary classification against the calibrated threshold.
        if score >= THRESHOLD:
            prediction = '<p style="color: red; font-size: 24px; font-weight: bold;">PNEUMOTHORAX ⚠️</p>'
        else:
            prediction = '<p style="color: green; font-size: 24px; font-weight: bold;">NORMAL ✅</p>'

        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"

    except Exception as e:
        # UI boundary: report the failure in the outputs instead of crashing
        # the request, but keep the traceback in the logs.
        logger.exception("Prediction failed: %s", e)
        return "<p style='color:red'>Error</p>", 0.0, str(e)

    finally:
        # Best-effort cleanup. A leftover temp file should not turn a
        # finished prediction into an error.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass

# Load models at startup, at import time and before the UI is built.
# predict() reads the module-level models and embeddings populated here.
# load_models() re-raises on failure, so a loading error stops the app from
# starting.
load_models()

# UI Layout
# NOTE: Gradio renders components in the order they are defined inside each
# layout context (Row / Column / Tabs). The statement order below is the
# page layout.
with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect Pneumothorax from raw X-ray images using a pre-trained foundation model.")
    
    with gr.Row():
        # Left column: image input and the button that triggers predict().
        with gr.Column():
            gr.Markdown("### 1. Upload X-Ray")
            # type="pil": predict() receives a PIL Image, or None when nothing
            # is uploaded (predict() handles that case explicitly).
            input_image = gr.Image(type="pil", label="Upload Image (PNG/JPG/DICOM converted)")
            predict_btn = gr.Button("Analyze Image", variant="primary")
            
        # Right column: the three outputs filled by predict(), followed by
        # static benchmark context.
        with gr.Column():
            gr.Markdown("### 2. Results")
            output_label = gr.HTML(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")
            
            gr.Markdown("---")
            gr.Markdown("### Performance Context")
            gr.Markdown("This model uses a **zero-shot** approach. The threshold was calibrated using a local Kaggle dataset.")
            
            # NOTE(review): these image paths are relative to the current
            # working directory, not to this file. Confirm the app is always
            # launched from the project root, or the images will not be found.
            with gr.Tabs():
                with gr.TabItem("Local Kaggle Benchmark"):
                    gr.Image("results/kaggle_roc_curve.png", label="local ROC Curve")
                    gr.Markdown("**AUC: 0.88** on 250 local samples.")
                with gr.TabItem("Google Benchmark"):
                    gr.Image("results/roc_PNEUMOTHORAX.png", label="Reference ROC")

    # predict() returns a 3-tuple, mapped positionally onto these outputs.
    predict_btn.click(predict, inputs=input_image, outputs=[output_label, output_score, output_msg])

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local
    # server. NOTE(review): confirm public exposure is intended for local runs.
    demo.launch(share=True)