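# NOTE: non-code header text captured together with this file; kept as
# comments so the module parses.
#   rohitium's picture
#   Deploy Chest X-Ray App (LFS)
#   b412062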
import os
import logging
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
# Module-level logger named after this module; used by the classes below
# for progress, warning, and error messages.
logger = logging.getLogger(__name__)
class PrecomputedModel:
    """Zero-shot scoring backed by embeddings precomputed on disk.

    On construction the model reads three files from ``data_dir``:

    * ``precomputed_text_embeddings.npz``  -- required
    * ``precomputed_image_embeddings.npz`` -- optional
    * ``cxr14_subset_labels.csv``          -- optional

    Optional inputs that are missing leave the matching attribute set to
    ``None`` and emit a warning.
    """

    def __init__(self, data_dir="data"):
        """Store *data_dir* and immediately load whatever data it contains.

        Raises:
            FileNotFoundError: if the required text-embedding file is absent.
        """
        self.data_dir = data_dir
        self.image_embeddings = None
        self.text_embeddings = None
        self.labels = None
        self._load_data()

    @staticmethod
    def _load_npz(path):
        """Read every array stored in the ``.npz`` archive at *path* into a dict."""
        with np.load(path) as archive:
            # Materialise inside the ``with`` so the archive can be closed.
            return {key: archive[key] for key in archive.files}

    def _load_data(self):
        """Loads precomputed embeddings and labels.

        Raises:
            FileNotFoundError: if the text-embedding archive does not exist.
        """
        img_emb_path = os.path.join(self.data_dir, "precomputed_image_embeddings.npz")
        txt_emb_path = os.path.join(self.data_dir, "precomputed_text_embeddings.npz")
        labels_path = os.path.join(self.data_dir, "cxr14_subset_labels.csv")

        # Text embeddings are strictly required for Zero-Shot
        if not os.path.exists(txt_emb_path):
            raise FileNotFoundError(f"Missing required text embeddings: {txt_emb_path}")
        logger.info("Loading precomputed text embeddings...")
        self.text_embeddings = self._load_npz(txt_emb_path)

        # Image embeddings (Optional, only for benchmarking)
        if os.path.exists(img_emb_path):
            logger.info("Loading precomputed image embeddings...")
            self.image_embeddings = self._load_npz(img_emb_path)
        else:
            logger.warning("Precomputed image embeddings not found. Benchmarking features will be disabled.")

        # Labels (Optional)
        if os.path.exists(labels_path):
            logger.info("Loading labels...")
            self.labels = pd.read_csv(labels_path)
        else:
            logger.warning("Labels file not found.")

    def get_diagnosis_embeddings(self, pos_txt, neg_txt):
        """Retrieves embeddings for positive and negative text queries.

        Args:
            pos_txt: key of the positive query in the text-embedding table.
            neg_txt: key of the negative query in the text-embedding table.

        Returns:
            A ``(positive_embedding, negative_embedding)`` tuple.

        Raises:
            ValueError: if either query has no precomputed embedding.
        """
        if pos_txt not in self.text_embeddings:
            raise ValueError(f"Positive query '{pos_txt}' not found in precomputed embeddings.")
        if neg_txt not in self.text_embeddings:
            raise ValueError(f"Negative query '{neg_txt}' not found in precomputed embeddings.")
        return self.text_embeddings[pos_txt], self.text_embeddings[neg_txt]

    def compute_scores(self, image_ids, pos_emb, neg_emb):
        """Computes zero-shot scores for a list of image IDs.

        IDs without a precomputed image embedding are skipped silently.

        Args:
            image_ids: iterable of keys into the image-embedding table.
            pos_emb: embedding of the positive text query.
            neg_emb: embedding of the negative text query.

        Returns:
            ``(valid_ids, scores)``: two parallel lists covering only the IDs
            that were found, in input order.

        Raises:
            RuntimeError: if the optional image embeddings were never loaded.
        """
        # The image-embedding file is optional, so the table may be None;
        # fail with a clear message instead of a TypeError from ``in``.
        if self.image_embeddings is None:
            raise RuntimeError(
                "Precomputed image embeddings are not loaded; "
                "cannot compute scores for image IDs."
            )

        valid_ids = []
        scores = []
        for img_id in image_ids:
            if img_id not in self.image_embeddings:
                continue
            img_emb = self.image_embeddings[img_id]
            scores.append(self.zero_shot(img_emb, pos_emb, neg_emb))
            valid_ids.append(img_id)
        return valid_ids, scores

    @staticmethod
    def compute_image_text_similarity(image_emb, txt_emb):
        """Computes cosine similarity between image and text embeddings.

        The image embedding is reshaped to ``(32, 128)`` and each of the 32
        row vectors is compared to the text embedding; the largest cosine
        similarity is returned.

        Args:
            image_emb: array with exactly ``32 * 128`` elements (for example
                shape ``(1, 32, 128)`` or ``(32, 128)``).
            txt_emb: text embedding with 128 elements; it is flattened to
                1-D before use.

        Returns:
            The maximum cosine similarity as a NumPy scalar.
        """
        image_vectors = np.reshape(image_emb, (32, 128))
        # Flatten so that (128,) and (128, 1) inputs give the same scalar.
        text_vector = np.reshape(txt_emb, -1)

        dot_products = image_vectors @ text_vector
        norm_products = np.linalg.norm(image_vectors, axis=1) * np.linalg.norm(text_vector)
        return np.max(dot_products / norm_products)

    @classmethod
    def zero_shot(cls, image_emb, pos_txt_emb, neg_txt_emb):
        """Computes the zero-shot score (pos_sim - neg_sim).

        Args:
            image_emb: image embedding accepted by
                ``compute_image_text_similarity``.
            pos_txt_emb: embedding of the positive text query.
            neg_txt_emb: embedding of the negative text query.

        Returns:
            Similarity to the positive query minus similarity to the
            negative query.
        """
        pos_cosine = cls.compute_image_text_similarity(image_emb, pos_txt_emb)
        neg_cosine = cls.compute_image_text_similarity(image_emb, neg_txt_emb)
        return pos_cosine - neg_cosine
class RawImageModel:
    """Computes image embeddings from raw image files with TensorFlow models.

    Two saved models are downloaded from the ``google/cxr-foundation``
    Hugging Face repository and loaded at construction time:
    ``elixr-c-v2-pooled`` (stored as ``elixrc_model``) and
    ``pax-elixr-b-text`` (stored as ``qformer_model``).
    """

    def __init__(self):
        """Initialise both model slots and load the models immediately."""
        self.elixrc_model = None
        self.qformer_model = None
        self._load_model()

    def _load_model(self):
        """Loads the TensorFlow model from Hugging Face.

        Raises:
            ImportError: if ``tensorflow`` or ``tensorflow_text`` cannot be
                imported; the original import failure is chained as the cause.
        """
        try:
            import tensorflow as tf
            # Imported only for its side effect: it registers the ops.
            import tensorflow_text  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "TensorFlow or tensorflow-text is not installed. Use precomputed mode or install them."
            ) from exc

        logger.info("Checking for GPU acceleration...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            logger.info("Running on GPU: %s", gpus)
        else:
            logger.info("Running on CPU. Expect slower inference.")

        logger.info("Downloading model weights from Hugging Face...")
        model_path = snapshot_download(
            repo_id="google/cxr-foundation",
            allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        )

        logger.info("Loading ELIXR-C (Image Encoder)...")
        self.elixrc_model = tf.saved_model.load(os.path.join(model_path, 'elixr-c-v2-pooled'))

        logger.info("Loading QFormer (Adapter)...")
        self.qformer_model = tf.saved_model.load(os.path.join(model_path, 'pax-elixr-b-text'))

    def compute_embeddings(self, image_path):
        """Generates embeddings for a raw image file.

        The file's bytes are wrapped, unmodified, in a ``tf.train.Example``
        tagged with format ``png``, passed through the ELIXR-C model, and
        the resulting feature map is fed to the QFormer model with all-zero
        text inputs.

        Args:
            image_path: path to the image file. The bytes are not decoded or
                validated here; they are always labelled as PNG.

        Returns:
            The ``all_contrastive_img_emb`` output of the QFormer model as a
            NumPy array (shape ``(1, 32, 128)`` according to the original
            author's note).

        Raises:
            Exception: any failure is logged and then re-raised unchanged.
        """
        import tensorflow as tf

        try:
            # The file is read directly; no PNG re-encoding step is needed.
            with open(image_path, 'rb') as f:
                image_bytes = f.read()

            # Create TF Example
            example = tf.train.Example()
            features = example.features.feature
            features['image/encoded'].bytes_list.value.append(image_bytes)
            features['image/format'].bytes_list.value.append(b'png')
            serialized_example = example.SerializeToString()

            # Step 1: ELIXR-C
            elixrc_infer = self.elixrc_model.signatures['serving_default']
            elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
            # Author's note: shape (1, 8, 8, 1376)
            elixrc_embedding = elixrc_output['feature_maps_0'].numpy()

            # Step 2: QFormer
            # Initialize text inputs with zeros (as we only want image embeddings)
            qformer_input = {
                'image_feature': elixrc_embedding.tolist(),
                'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
                'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
            }
            qformer_output = self.qformer_model.signatures['serving_default'](**qformer_input)
            # Author's note: shape (1, 32, 128)
            elixrb_embeddings = qformer_output['all_contrastive_img_emb'].numpy()

            return elixrb_embeddings
        except Exception as e:
            logger.error("Error computing raw embeddings: %s", e)
            raise