import os
import logging

import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download

# Module-level logger
logger = logging.getLogger(__name__)


class PrecomputedModel:
    """Zero-shot scoring backed by precomputed ELIXR embeddings."""

    def __init__(self, data_dir="data"):
        self.data_dir = data_dir
        self.image_embeddings = None
        self.text_embeddings = None
        self.labels = None
        self._load_data()

    def _load_data(self):
        """Loads precomputed embeddings and labels."""
        img_emb_path = os.path.join(self.data_dir, "precomputed_image_embeddings.npz")
        txt_emb_path = os.path.join(self.data_dir, "precomputed_text_embeddings.npz")
        labels_path = os.path.join(self.data_dir, "cxr14_subset_labels.csv")

        # Text embeddings are strictly required for zero-shot scoring.
        if not os.path.exists(txt_emb_path):
            raise FileNotFoundError(f"Missing required text embeddings: {txt_emb_path}")
        logger.info("Loading precomputed text embeddings...")
        with np.load(txt_emb_path) as data:
            self.text_embeddings = {key: data[key] for key in data}

        # Image embeddings (optional, only needed for benchmarking).
        if os.path.exists(img_emb_path):
            logger.info("Loading precomputed image embeddings...")
            with np.load(img_emb_path) as data:
                self.image_embeddings = {key: data[key] for key in data}
        else:
            logger.warning("Precomputed image embeddings not found. Benchmarking features will be disabled.")

        # Labels (optional).
        if os.path.exists(labels_path):
            logger.info("Loading labels...")
            self.labels = pd.read_csv(labels_path)
        else:
            logger.warning("Labels file not found.")

    def get_diagnosis_embeddings(self, pos_txt, neg_txt):
        """Retrieves embeddings for positive and negative text queries."""
        if pos_txt not in self.text_embeddings:
            raise ValueError(f"Positive query '{pos_txt}' not found in precomputed embeddings.")
        if neg_txt not in self.text_embeddings:
            raise ValueError(f"Negative query '{neg_txt}' not found in precomputed embeddings.")
        return self.text_embeddings[pos_txt], self.text_embeddings[neg_txt]

    def compute_scores(self, image_ids, pos_emb, neg_emb):
        """Computes zero-shot scores for a list of image IDs."""
        if self.image_embeddings is None:
            raise RuntimeError("Precomputed image embeddings are not loaded; benchmarking is unavailable.")
        scores = []
        valid_ids = []
        for img_id in image_ids:
            # Skip IDs that have no precomputed embedding.
            if img_id not in self.image_embeddings:
                continue
            img_emb = self.image_embeddings[img_id]
            score = self.zero_shot(img_emb, pos_emb, neg_emb)
            scores.append(score)
            valid_ids.append(img_id)
        return valid_ids, scores

    @staticmethod
    def compute_image_text_similarity(image_emb, txt_emb):
        """Computes the maximum cosine similarity between the 32 image token embeddings and a text embedding."""
        # The contrastive image embedding arrives as (1, 32, 128) (or already
        # flattened); reshape it to 32 token vectors of dimension 128.
        image_emb = np.reshape(image_emb, (32, 128))
        similarities = []
        for i in range(32):
            # Cosine similarity between the i-th image token and the text embedding.
            similarity = np.dot(image_emb[i], txt_emb) / (
                np.linalg.norm(image_emb[i]) * np.linalg.norm(txt_emb)
            )
            similarities.append(similarity)
        # Score the image by its best-matching token.
        return np.max(similarities)

    @classmethod
    def zero_shot(cls, image_emb, pos_txt_emb, neg_txt_emb):
        """Computes the zero-shot score (positive similarity minus negative similarity)."""
        pos_cosine = cls.compute_image_text_similarity(image_emb, pos_txt_emb)
        neg_cosine = cls.compute_image_text_similarity(image_emb, neg_txt_emb)
        return pos_cosine - neg_cosine
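

# Minimal usage sketch of the zero-shot scoring above (illustrative only: the
# random arrays stand in for real embeddings and merely match the expected
# shapes, (1, 32, 128) for an image and (128,) for a text query):
#
#     rng = np.random.default_rng(0)
#     img_emb = rng.normal(size=(1, 32, 128))
#     pos_emb = rng.normal(size=(128,))
#     neg_emb = rng.normal(size=(128,))
#     score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)
#     # score > 0 means the image tokens match the positive query more closely
#     # than the negative one.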


class RawImageModel:
    """Computes ELIXR embeddings from raw PNG images using the TensorFlow models."""

    def __init__(self):
        self.elixrc_model = None
        self.qformer_model = None
        self._load_model()

    def _load_model(self):
        """Loads the TensorFlow models from Hugging Face."""
        try:
            import tensorflow as tf
            import tensorflow_text  # noqa: F401 -- registers the text ops required by the saved models
        except ImportError:
            raise ImportError(
                "TensorFlow or tensorflow-text is not installed. Use precomputed mode or install them."
            )

        logger.info("Checking for GPU acceleration...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            logger.info(f"Running on GPU: {gpus}")
        else:
            logger.info("Running on CPU. Expect slower inference.")

        logger.info("Downloading model weights from Hugging Face...")
        model_path = snapshot_download(
            repo_id="google/cxr-foundation",
            allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        )

        logger.info("Loading ELIXR-C (image encoder)...")
        self.elixrc_model = tf.saved_model.load(os.path.join(model_path, 'elixr-c-v2-pooled'))
        logger.info("Loading QFormer (adapter)...")
        self.qformer_model = tf.saved_model.load(os.path.join(model_path, 'pax-elixr-b-text'))

    def compute_embeddings(self, image_path):
        """Generates ELIXR embeddings for a raw PNG image file."""
        import tensorflow as tf

        try:
            # Read the encoded PNG bytes directly. (The reference notebook converts
            # the pixel array with pypng and writes it to a BytesIO buffer first,
            # but a PNG file on disk can be passed through as-is.)
            with open(image_path, 'rb') as f:
                image_bytes = f.read()

            # Wrap the encoded image in a tf.train.Example, the input format
            # expected by the ELIXR-C serving signature.
            example = tf.train.Example()
            features = example.features.feature
            features['image/encoded'].bytes_list.value.append(image_bytes)
            features['image/format'].bytes_list.value.append(b'png')
            serialized_example = example.SerializeToString()

            # Step 1: ELIXR-C image encoder.
            elixrc_infer = self.elixrc_model.signatures['serving_default']
            elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
            elixrc_embedding = elixrc_output['feature_maps_0'].numpy()  # Shape (1, 8, 8, 1376)

            # Step 2: QFormer adapter. Text inputs are zeroed out because only the
            # image embeddings are needed here.
            qformer_input = {
                'image_feature': elixrc_embedding.tolist(),
                'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
                'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
            }
            qformer_output = self.qformer_model.signatures['serving_default'](**qformer_input)
            elixrb_embeddings = qformer_output['all_contrastive_img_emb'].numpy()  # Shape (1, 32, 128)
            return elixrb_embeddings
        except Exception as e:
            logger.error(f"Error computing raw embeddings: {e}")
            raise
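

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the pipeline: the data directory and
    # the query strings below are placeholders and must match keys that actually
    # exist in precomputed_text_embeddings.npz.
    logging.basicConfig(level=logging.INFO)
    model = PrecomputedModel(data_dir="data")
    pos_emb, neg_emb = model.get_diagnosis_embeddings(
        "pneumothorax",      # hypothetical positive query
        "no pneumothorax",   # hypothetical negative query
    )
    if model.image_embeddings is not None:
        ids, scores = model.compute_scores(list(model.image_embeddings), pos_emb, neg_emb)
        logger.info(f"Scored {len(ids)} images; first scores: {scores[:5]}")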