import gradio as gr
import os
import sys
import logging
import numpy as np
from PIL import Image
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Suppress TensorFlow logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from model import RawImageModel, PrecomputedModel
# Global Model Instances
raw_model = None
precomputed_model = None
pos_emb = None
neg_emb = None
# Optimal Threshold from Kaggle validation
THRESHOLD = -0.1173
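# Decision rule used in predict(): scores at or above THRESHOLD are flagged as
# pneumothorax, scores below as normal. The score comes from
# PrecomputedModel.zero_shot(); presumably it compares the image embedding's
# similarity to the positive prompt against the negative prompt, but the exact
# formulation lives in the imported model module.
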
def load_models():
    global raw_model, precomputed_model, pos_emb, neg_emb
    if raw_model is None:
        logger.info("Loading models...")
        try:
            precomputed_model = PrecomputedModel()
            raw_model = RawImageModel()

            # Pre-fetch text embeddings
            pos_txt = 'small pneumothorax'
            neg_txt = 'no pneumothorax'
            pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

            logger.info("Models loaded.")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise e

# ZeroGPU compatibility for Hugging Face Spaces
try:
    import spaces
except ImportError:
    # Dummy decorator if running locally without spaces installed
    class spaces:
        @staticmethod
        def GPU(func):
            return func

@spaces.GPU
def predict(image):
    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."
    try:
        # Save temp image for model ingestion
        temp_path = "temp_gradio_upload.png"
        image.save(temp_path)

        # Run inference
        img_emb = raw_model.compute_embeddings(temp_path)
        score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)
        score = float(score)

        # Binary classification
        if score >= THRESHOLD:
            prediction = '<p style="color: red; font-size: 24px; font-weight: bold;">PNEUMOTHORAX ⚠️</p>'
        else:
            prediction = '<p style="color: green; font-size: 24px; font-weight: bold;">NORMAL ✅</p>'

        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        return "<p style='color:red'>Error</p>", 0.0, str(e)

# Load models at startup
load_models()
# UI Layout
with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect pneumothorax from raw X-ray images using a pre-trained foundation model.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Upload X-Ray")
            input_image = gr.Image(type="pil", label="Upload Image (PNG/JPG; DICOM must be converted first)")
            predict_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            gr.Markdown("### 2. Results")
            output_label = gr.HTML(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")

    gr.Markdown("---")
    gr.Markdown("### Performance Context")
    gr.Markdown("This model uses a **zero-shot** approach: no fine-tuning on pneumothorax labels. The decision threshold was calibrated on a local Kaggle dataset.")

    with gr.Tabs():
        with gr.TabItem("Local Kaggle Benchmark"):
            gr.Image("results/kaggle_roc_curve.png", label="Local ROC Curve")
            gr.Markdown("**AUC: 0.88** on 250 local samples.")
        with gr.TabItem("Google Benchmark"):
            gr.Image("results/roc_PNEUMOTHORAX.png", label="Reference ROC")

    predict_btn.click(predict, inputs=input_image, outputs=[output_label, output_score, output_msg])

if __name__ == "__main__":
    demo.launch(share=True)
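
# A minimal local smoke test sketch: importing this module runs load_models(),
# after which predict() can be called directly with a PIL image. The sample
# path below is hypothetical.
#
#   from PIL import Image
#   label_html, score, details = predict(Image.open("samples/chest_xray.png"))
#   print(score, details)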