"""Zero-shot chest X-ray scoring from precomputed or freshly computed embeddings.

``PrecomputedModel`` works purely from ``.npz``/``.csv`` files on disk and needs
only NumPy and pandas.  ``RawImageModel`` downloads the CXR Foundation models
from Hugging Face and runs them with TensorFlow; its heavy dependencies are
imported lazily so that the precomputed mode stays usable without them.
"""

import logging
import os

import numpy as np
import pandas as pd

try:
    from huggingface_hub import snapshot_download
except ImportError:  # Only RawImageModel needs it; precomputed mode must still import.
    snapshot_download = None

# Configure logging
logger = logging.getLogger(__name__)


class PrecomputedModel:
    """Zero-shot scorer backed by embeddings that were computed ahead of time.

    Attributes set by ``_load_data``:
        text_embeddings: dict mapping a text query to its embedding (required).
        image_embeddings: dict mapping an image ID to its embedding, or ``None``
            when the optional file is absent.
        labels: ``pandas.DataFrame`` read from the labels CSV, or ``None`` when
            the optional file is absent.
    """

    def __init__(self, data_dir="data"):
        """Load every available precomputed artifact from *data_dir*.

        Raises:
            FileNotFoundError: if the text-embedding file is missing.
        """
        self.data_dir = data_dir
        self.image_embeddings = None
        self.text_embeddings = None
        self.labels = None
        self._load_data()

    @staticmethod
    def _load_npz(path):
        """Read every array of an ``.npz`` archive into a plain dict."""
        # Materialise inside the ``with`` so the archive is closed afterwards.
        with np.load(path) as archive:
            return {key: archive[key] for key in archive}

    def _load_data(self):
        """Loads precomputed embeddings and labels.

        Text embeddings are mandatory; image embeddings and labels are optional
        and only produce a warning when missing.

        Raises:
            FileNotFoundError: if the text-embedding file does not exist.
        """
        img_emb_path = os.path.join(self.data_dir, "precomputed_image_embeddings.npz")
        txt_emb_path = os.path.join(self.data_dir, "precomputed_text_embeddings.npz")
        labels_path = os.path.join(self.data_dir, "cxr14_subset_labels.csv")

        # Text embeddings are strictly required for Zero-Shot
        if not os.path.exists(txt_emb_path):
            raise FileNotFoundError(f"Missing required text embeddings: {txt_emb_path}")

        logger.info("Loading precomputed text embeddings...")
        self.text_embeddings = self._load_npz(txt_emb_path)

        # Image embeddings (Optional, only for benchmarking)
        if os.path.exists(img_emb_path):
            logger.info("Loading precomputed image embeddings...")
            self.image_embeddings = self._load_npz(img_emb_path)
        else:
            logger.warning(
                "Precomputed image embeddings not found. Benchmarking features will be disabled."
            )

        # Labels (Optional)
        if os.path.exists(labels_path):
            logger.info("Loading labels...")
            self.labels = pd.read_csv(labels_path)
        else:
            logger.warning("Labels file not found.")

    def get_diagnosis_embeddings(self, pos_txt, neg_txt):
        """Retrieves embeddings for positive and negative text queries.

        Returns:
            Tuple ``(positive_embedding, negative_embedding)``.

        Raises:
            ValueError: if either query has no precomputed embedding.
        """
        if pos_txt not in self.text_embeddings:
            raise ValueError(f"Positive query '{pos_txt}' not found in precomputed embeddings.")
        if neg_txt not in self.text_embeddings:
            raise ValueError(f"Negative query '{neg_txt}' not found in precomputed embeddings.")
        return self.text_embeddings[pos_txt], self.text_embeddings[neg_txt]

    def compute_scores(self, image_ids, pos_emb, neg_emb):
        """Computes zero-shot scores for a list of image IDs.

        IDs without a precomputed image embedding are skipped silently.

        Returns:
            Tuple ``(valid_ids, scores)`` of equal-length lists, in input order.

        Raises:
            RuntimeError: if the optional image embeddings were never loaded.
        """
        # The image-embedding file is optional, so this can legitimately be
        # None; fail with a clear message instead of "TypeError: argument of
        # type 'NoneType' is not iterable".
        if self.image_embeddings is None:
            raise RuntimeError(
                "Precomputed image embeddings are not loaded; cannot compute scores."
            )

        valid_ids = [img_id for img_id in image_ids if img_id in self.image_embeddings]
        scores = [
            self.zero_shot(self.image_embeddings[img_id], pos_emb, neg_emb)
            for img_id in valid_ids
        ]
        return valid_ids, scores

    @staticmethod
    def compute_image_text_similarity(image_emb, txt_emb):
        """Computes the best cosine similarity between image tokens and a text embedding.

        *txt_emb* is flattened to a vector of length ``D``.  *image_emb* is
        reshaped to ``(-1, D)``, so every leading axis is treated as part of a
        single token axis; ``(1, 32, 128)`` and ``(32, 128)`` inputs with a
        128-dim text embedding behave exactly as before.

        Returns:
            The maximum cosine similarity over all image tokens (NumPy scalar).
            A zero-norm token or text vector yields ``nan``/``inf`` entries, as
            plain division would.

        Raises:
            ValueError: if the image embedding size is not a multiple of ``D``.
        """
        txt_vec = np.ravel(txt_emb)
        tokens = np.reshape(image_emb, (-1, txt_vec.size))
        # One matrix-vector product replaces the per-token Python loop.
        similarities = (tokens @ txt_vec) / (
            np.linalg.norm(tokens, axis=1) * np.linalg.norm(txt_vec)
        )
        return np.max(similarities)

    @classmethod
    def zero_shot(cls, image_emb, pos_txt_emb, neg_txt_emb):
        """Computes the zero-shot score (pos_sim - neg_sim)."""
        pos_cosine = cls.compute_image_text_similarity(image_emb, pos_txt_emb)
        neg_cosine = cls.compute_image_text_similarity(image_emb, neg_txt_emb)
        return pos_cosine - neg_cosine


class RawImageModel:
    """Computes image embeddings from raw PNG files with the CXR Foundation models.

    Attributes set by ``_load_model``:
        elixrc_model: the loaded ``elixr-c-v2-pooled`` SavedModel (image encoder).
        qformer_model: the loaded ``pax-elixr-b-text`` SavedModel (adapter).
    """

    def __init__(self):
        """Download (if needed) and load both models.

        Raises:
            ImportError: if TensorFlow, tensorflow-text or huggingface_hub is
                not installed.
        """
        self.elixrc_model = None
        self.qformer_model = None
        self._load_model()

    def _load_model(self):
        """Loads the TensorFlow model from Hugging Face.

        Raises:
            ImportError: if a required optional dependency is missing.
        """
        try:
            import tensorflow as tf
            import tensorflow_text as text  # noqa: F401  # Registers the ops
        except ImportError as err:
            raise ImportError(
                "TensorFlow or tensorflow-text is not installed. Use precomputed mode or install them."
            ) from err

        if snapshot_download is None:
            raise ImportError(
                "huggingface_hub is not installed. Use precomputed mode or install it."
            )

        logger.info("Checking for GPU acceleration...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            logger.info("Running on GPU: %s", gpus)
        else:
            logger.info("Running on CPU. Expect slower inference.")

        logger.info("Downloading model weights from Hugging Face...")
        model_path = snapshot_download(
            repo_id="google/cxr-foundation",
            allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        )

        logger.info("Loading ELIXR-C (Image Encoder)...")
        self.elixrc_model = tf.saved_model.load(os.path.join(model_path, 'elixr-c-v2-pooled'))

        logger.info("Loading QFormer (Adapter)...")
        self.qformer_model = tf.saved_model.load(os.path.join(model_path, 'pax-elixr-b-text'))

    def compute_embeddings(self, image_path):
        """Generates embeddings for a raw image file.

        The file's bytes are wrapped, undecoded, in a ``tf.train.Example``
        tagged as PNG, passed through ELIXR-C, and the resulting feature map is
        fed to the QFormer together with all-zero text inputs.

        Returns:
            The ``all_contrastive_img_emb`` output as a NumPy array
            (shape (1, 32, 128) according to the original author's note).

        Raises:
            Exception: any failure is logged and then re-raised unchanged.
        """
        import tensorflow as tf

        try:
            # The bytes are handed to the model as-is, so the file must already
            # be a PNG; no decoding or re-encoding happens here.
            with open(image_path, 'rb') as f:
                image_bytes = f.read()

            # Create TF Example
            example = tf.train.Example()
            features = example.features.feature
            features['image/encoded'].bytes_list.value.append(image_bytes)
            features['image/format'].bytes_list.value.append(b'png')
            serialized_example = example.SerializeToString()

            # Step 1: ELIXR-C
            elixrc_infer = self.elixrc_model.signatures['serving_default']
            elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
            # Author's note: shape (1, 8, 8, 1376)
            elixrc_embedding = elixrc_output['feature_maps_0'].numpy()

            # Step 2: QFormer
            # Initialize text inputs with zeros (as we only want image embeddings)
            qformer_input = {
                'image_feature': elixrc_embedding.tolist(),
                'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
                'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
            }
            qformer_output = self.qformer_model.signatures['serving_default'](**qformer_input)
            # Author's note: shape (1, 32, 128)
            elixrb_embeddings = qformer_output['all_contrastive_img_emb'].numpy()

            return elixrb_embeddings
        except Exception as e:
            logger.error("Error computing raw embeddings: %s", e)
            raise