"""Gradio app for zero-shot pneumothorax detection on chest X-ray images."""
import logging
import os
import sys
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress TensorFlow logging. This must happen before `model` is imported
# below so the setting is in place when that module is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from model import RawImageModel, PrecomputedModel  # noqa: E402  (must follow env setup)

# Global Model Instances (populated by load_models)
raw_model = None
precomputed_model = None
pos_emb = None
neg_emb = None

# Optimal Threshold from Kaggle validation
THRESHOLD = -0.1173

# NOTE(review): in the source file these literals were split across physical
# lines (a syntax error; surrounding markup appears to have been stripped).
# They are reconstructed as simple HTML headings for the gr.HTML output below —
# confirm the intended styling.
POSITIVE_LABEL_HTML = "<h2>PNEUMOTHORAX ⚠️</h2>"
NEGATIVE_LABEL_HTML = "<h2>NORMAL ✅</h2>"
ERROR_LABEL_HTML = "<h2>Error</h2>"


def load_models():
    """Load the image/text models and the diagnosis text embeddings once.

    Populates the module globals ``precomputed_model``, ``raw_model``,
    ``pos_emb`` and ``neg_emb``. Does nothing if ``raw_model`` is already set.

    Raises:
        Exception: re-raises whatever the model constructors or the embedding
            call raise, after logging it.
    """
    global raw_model, precomputed_model, pos_emb, neg_emb
    if raw_model is not None:
        return

    logger.info("Loading models...")
    try:
        loaded_precomputed = PrecomputedModel()
        loaded_raw = RawImageModel()
        # Pre-fetch text embeddings
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        loaded_pos, loaded_neg = loaded_precomputed.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.exception(f"Failed to load models: {e}")
        raise

    # Publish the globals only after every step succeeded, so a partial failure
    # cannot leave raw_model set (which would skip future loads) while the
    # embeddings are still None.
    precomputed_model = loaded_precomputed
    pos_emb, neg_emb = loaded_pos, loaded_neg
    raw_model = loaded_raw
    logger.info("Models loaded.")


# ZeroGPU compatibility for Hugging Face Spaces
try:
    import spaces
except ImportError:
    # Dummy decorator if running locally without spaces installed
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            # Support both the bare `@spaces.GPU` and the `@spaces.GPU(...)` forms.
            if func is None:
                return lambda f: f
            return func


@spaces.GPU
def predict(image):
    """Classify an uploaded X-ray as pneumothorax or normal.

    Args:
        image: a PIL image from the Gradio ``Image(type="pil")`` input, or None
            when nothing was uploaded.

    Returns:
        A 3-tuple ``(prediction_html, score, details)``. ``score`` is the float
        returned by ``PrecomputedModel.zero_shot``; it is ``0.0`` when there is
        no image or an error occurs. The image is labelled PNEUMOTHORAX when
        ``score >= THRESHOLD``.
    """
    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."

    if raw_model is None or pos_emb is None or neg_emb is None:
        return ERROR_LABEL_HTML, 0.0, "Models are not loaded."

    temp_path = None
    try:
        # Use a unique temporary file per request. A single fixed filename would
        # let concurrent requests overwrite each other's upload.
        fd, temp_path = tempfile.mkstemp(prefix="gradio_upload_", suffix=".png")
        os.close(fd)  # PIL reopens the path itself; close our handle first.
        image.save(temp_path)

        # Run Inference
        img_emb = raw_model.compute_embeddings(temp_path)
        score = float(PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb))

        # Binary Classification
        if score >= THRESHOLD:
            prediction = POSITIVE_LABEL_HTML
        else:
            prediction = NEGATIVE_LABEL_HTML

        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        logger.exception(f"Prediction failed: {e}")
        return ERROR_LABEL_HTML, 0.0, str(e)
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", temp_path)


# Load models at startup
load_models()

# UI Layout
with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect Pneumothorax from raw X-ray images using a pre-trained foundation model.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Upload X-Ray")
            input_image = gr.Image(type="pil", label="Upload Image (PNG/JPG/DICOM converted)")
            predict_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column():
            gr.Markdown("### 2. Results")
            output_label = gr.HTML(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")

    gr.Markdown("---")
    gr.Markdown("### Performance Context")
    gr.Markdown("This model uses a **zero-shot** approach. The threshold was calibrated using a local Kaggle dataset.")

    with gr.Tabs():
        with gr.TabItem("Local Kaggle Benchmark"):
            gr.Image("results/kaggle_roc_curve.png", label="local ROC Curve")
            gr.Markdown("**AUC: 0.88** on 250 local samples.")
        with gr.TabItem("Google Benchmark"):
            gr.Image("results/roc_PNEUMOTHORAX.png", label="Reference ROC")

    predict_btn.click(predict, inputs=input_image, outputs=[output_label, output_score, output_msg])

if __name__ == "__main__":
    # NOTE: share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)