# rohitium's picture
# Fix HF ZeroGPU support: Add spaces decorator
# 1c0808a
import gradio as gr
import os
import sys
import logging
import numpy as np
from PIL import Image
# Application-wide logging: timestamped records at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Suppress TensorFlow logging: set the TF_CPP_MIN_LOG_LEVEL environment
# variable and cap the absl / 'tensorflow' Python loggers at ERROR.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    # absl is optional; there is nothing to silence when it is missing.
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)
from model import RawImageModel, PrecomputedModel
# Module-level singletons; they stay None until load_models() fills them in.
raw_model = precomputed_model = None
pos_emb = neg_emb = None

# Decision cut-off for the zero-shot score (optimal threshold from Kaggle
# validation); predict() reports a positive finding at or above this value.
THRESHOLD = -0.1173
def load_models():
    """Populate the module-level models and text embeddings.

    Does nothing when ``raw_model`` is already set. Otherwise it builds a
    ``PrecomputedModel`` and a ``RawImageModel`` and fetches the positive /
    negative diagnosis text embeddings.

    Everything is built into local variables first and only published to the
    globals once every step has succeeded, with ``raw_model`` assigned last.
    ``raw_model`` is the "already loaded" sentinel, so a failure part-way
    through can never leave the module looking loaded while ``pos_emb`` /
    ``neg_emb`` are still ``None``.

    Raises:
        Exception: whatever the model constructors or the embedding lookup
            raise; the error is logged with its traceback and re-raised.
    """
    global raw_model, precomputed_model, pos_emb, neg_emb
    if raw_model is not None:
        return

    logger.info("Loading models...")
    try:
        loaded_precomputed = PrecomputedModel()
        loaded_raw = RawImageModel()
        # Pre-fetch text embeddings
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        loaded_pos, loaded_neg = loaded_precomputed.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        # logger.exception keeps the traceback; bare raise preserves it too.
        logger.exception("Failed to load models: %s", e)
        raise

    # Publish atomically from the caller's point of view; the sentinel goes last.
    precomputed_model = loaded_precomputed
    pos_emb, neg_emb = loaded_pos, loaded_neg
    raw_model = loaded_raw
    logger.info("Models loaded.")
# ZeroGPU compatibility for Hugging Face Spaces
try:
    import spaces
except ImportError:
    class spaces:
        """Local stand-in used when the ``spaces`` package is not installed."""

        @staticmethod
        def GPU(func=None, **_options):
            """No-op replacement for the GPU decorator.

            Supports both decorator forms so code runs unchanged locally:
            bare (``@spaces.GPU``) returns ``func`` itself, and parametrised
            (``@spaces.GPU(duration=60)``) returns an identity decorator.
            Any keyword options are accepted and ignored.
            """
            if func is None:
                return lambda wrapped: wrapped
            return func
@spaces.GPU
def predict(image):
    """Classify an uploaded image against THRESHOLD using the zero-shot score.

    Args:
        image: the uploaded image object (must provide ``save(path)``), or
            ``None`` when nothing was uploaded.

    Returns:
        A 3-tuple ``(html, score, details)``: an HTML snippet with the
        verdict, the score as a float, and a human-readable details string.
        When no image is given, or when inference fails, the score is ``0.0``
        and the other two fields carry the explanation / error text.
    """
    import tempfile

    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."
    try:
        # The model ingests a file path. Use a private per-call directory so
        # concurrent requests cannot overwrite each other's upload, and so the
        # file is removed afterwards instead of lingering in the working dir.
        with tempfile.TemporaryDirectory() as work_dir:
            temp_path = os.path.join(work_dir, "temp_gradio_upload.png")
            image.save(temp_path)
            # Run Inference
            img_emb = raw_model.compute_embeddings(temp_path)
            score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)
            score = float(score)
        # Binary Classification
        if score >= THRESHOLD:
            prediction = '<p style="color: red; font-size: 24px; font-weight: bold;">PNEUMOTHORAX ⚠️</p>'
        else:
            prediction = '<p style="color: green; font-size: 24px; font-weight: bold;">NORMAL ✅</p>'
        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"
    except Exception as e:
        # UI boundary: log with traceback, then report the error to the user
        # instead of letting the request crash.
        logger.exception("Prediction failed: %s", e)
        return "<p style='color:red'>Error</p>", 0.0, str(e)
# Load models at startup (runs at import time, before the UI is built), so
# any load failure raised by load_models() stops the app immediately.
load_models()
# UI Layout: header, an upload/results row, benchmark tabs, then event wiring.
with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect Pneumothorax from raw X-ray images using a pre-trained foundation model.")

    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column():
            gr.Markdown("### 1. Upload X-Ray")
            input_image = gr.Image(
                type="pil",
                label="Upload Image (PNG/JPG/DICOM converted)",
            )
            predict_btn = gr.Button("Analyze Image", variant="primary")

        # Right column: the three outputs that predict() returns, in order.
        with gr.Column():
            gr.Markdown("### 2. Results")
            output_label = gr.HTML(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")

    gr.Markdown("---")
    gr.Markdown("### Performance Context")
    gr.Markdown("This model uses a **zero-shot** approach. The threshold was calibrated using a local Kaggle dataset.")

    # Static benchmark figures loaded from the results/ directory.
    with gr.Tabs():
        with gr.TabItem("Local Kaggle Benchmark"):
            gr.Image("results/kaggle_roc_curve.png", label="local ROC Curve")
            gr.Markdown("**AUC: 0.88** on 250 local samples.")
        with gr.TabItem("Google Benchmark"):
            gr.Image("results/roc_PNEUMOTHORAX.png", label="Reference ROC")

    # Button click -> predict(image) -> (HTML verdict, score, details).
    predict_btn.click(
        predict,
        inputs=input_image,
        outputs=[output_label, output_score, output_msg],
    )
# Start the Gradio app only when this file is executed as a script;
# share=True asks Gradio for a shareable link in addition to the local server.
if __name__ == "__main__":
    demo.launch(share=True)