|
|
import os |
|
|
import logging |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
# Module-level logger, named after this module, used by the model classes below.
logger = logging.getLogger(__name__) |
|
|
|
|
|
class PrecomputedModel:
    """Zero-shot scoring on top of embeddings that were precomputed to disk.

    On construction the model reads three files from ``data_dir``:

    * ``precomputed_text_embeddings.npz``  -- required
    * ``precomputed_image_embeddings.npz`` -- optional
    * ``cxr14_subset_labels.csv``          -- optional

    Attributes:
        data_dir: Directory the files are read from.
        text_embeddings: Dict mapping each key of the text archive to its array.
        image_embeddings: Dict mapping each key of the image archive to its
            array, or ``None`` when the image archive is missing.
        labels: ``pandas.DataFrame`` read from the labels CSV, or ``None``
            when the CSV is missing.
    """

    def __init__(self, data_dir="data"):
        """Remember ``data_dir`` and immediately load everything from it.

        Raises:
            FileNotFoundError: If the text embeddings archive does not exist.
        """
        self.data_dir = data_dir
        self.image_embeddings = None
        self.text_embeddings = None
        self.labels = None
        self._load_data()

    @staticmethod
    def _read_npz(path):
        """Return every array stored in the ``.npz`` archive at ``path`` as a dict."""
        with np.load(path) as archive:
            # Index every entry now: the arrays must be materialized before
            # the archive is closed on leaving the ``with`` block.
            return {name: archive[name] for name in archive.files}

    def _load_data(self):
        """Loads precomputed embeddings and labels.

        Text embeddings are mandatory; image embeddings and labels are
        optional and only produce a warning when absent.

        Raises:
            FileNotFoundError: If the text embeddings archive does not exist.
        """
        img_emb_path = os.path.join(self.data_dir, "precomputed_image_embeddings.npz")
        txt_emb_path = os.path.join(self.data_dir, "precomputed_text_embeddings.npz")
        labels_path = os.path.join(self.data_dir, "cxr14_subset_labels.csv")

        # Fail fast, before any other work, when the one required file is missing.
        if not os.path.exists(txt_emb_path):
            raise FileNotFoundError(f"Missing required text embeddings: {txt_emb_path}")

        logger.info("Loading precomputed text embeddings...")
        self.text_embeddings = self._read_npz(txt_emb_path)

        if os.path.exists(img_emb_path):
            logger.info("Loading precomputed image embeddings...")
            self.image_embeddings = self._read_npz(img_emb_path)
        else:
            logger.warning("Precomputed image embeddings not found. Benchmarking features will be disabled.")

        if os.path.exists(labels_path):
            logger.info("Loading labels...")
            self.labels = pd.read_csv(labels_path)
        else:
            logger.warning("Labels file not found.")

    def get_diagnosis_embeddings(self, pos_txt, neg_txt):
        """Retrieves embeddings for positive and negative text queries.

        Args:
            pos_txt: Key of the positive query in ``self.text_embeddings``.
            neg_txt: Key of the negative query in ``self.text_embeddings``.

        Returns:
            Tuple ``(positive_embedding, negative_embedding)``.

        Raises:
            ValueError: If either query has no precomputed embedding. The
                positive query is checked first.
        """
        for kind, query in (("Positive", pos_txt), ("Negative", neg_txt)):
            if query not in self.text_embeddings:
                raise ValueError(f"{kind} query '{query}' not found in precomputed embeddings.")

        return self.text_embeddings[pos_txt], self.text_embeddings[neg_txt]

    def compute_scores(self, image_ids, pos_emb, neg_emb):
        """Computes zero-shot scores for a list of image IDs.

        IDs without a precomputed image embedding are skipped silently.

        Args:
            image_ids: Iterable of keys into ``self.image_embeddings``.
            pos_emb: Text embedding of the positive query.
            neg_emb: Text embedding of the negative query.

        Returns:
            Tuple ``(valid_ids, scores)`` of two equally long lists:
            ``scores[i]`` is the zero-shot score of ``valid_ids[i]``.

        Raises:
            RuntimeError: If no image embeddings were loaded (the image
                archive was missing when the model was constructed).
        """
        if self.image_embeddings is None:
            # Without this guard the membership test below would fail with an
            # opaque "argument of type 'NoneType' is not iterable" TypeError.
            raise RuntimeError(
                "Precomputed image embeddings are not loaded; cannot compute scores."
            )

        valid_ids = []
        scores = []
        for img_id in image_ids:
            if img_id not in self.image_embeddings:
                continue
            valid_ids.append(img_id)
            scores.append(self.zero_shot(self.image_embeddings[img_id], pos_emb, neg_emb))

        return valid_ids, scores

    @staticmethod
    def compute_image_text_similarity(image_emb, txt_emb):
        """Computes cosine similarity between image and text embeddings.

        The image embedding is viewed as 32 vectors of 128 values each; the
        cosine similarity of every vector with the text embedding is computed
        and the largest one is returned.

        Args:
            image_emb: Array-like with exactly 32 * 128 elements (any shape).
            txt_emb: Array-like with exactly 128 elements (any shape; it is
                flattened before use).

        Returns:
            The maximum cosine similarity as a NumPy scalar. The result is
            NaN when a zero-norm vector is involved, because the division is
            not guarded.

        Raises:
            ValueError: If ``image_emb`` cannot be reshaped to ``(32, 128)``
                or ``txt_emb`` does not hold 128 values.
        """
        # Fixed layout: 32 rows of 128 values (presumably one row per image
        # token -- TODO confirm against the embedding producer).
        tokens = np.reshape(image_emb, (32, 128))
        txt_vec = np.ravel(txt_emb)

        # One matrix-vector product replaces the per-row Python loop; the
        # text norm is computed once instead of once per row.
        dots = tokens @ txt_vec
        norms = np.linalg.norm(tokens, axis=1) * np.linalg.norm(txt_vec)

        return np.max(dots / norms)

    @classmethod
    def zero_shot(cls, image_emb, pos_txt_emb, neg_txt_emb):
        """Computes the zero-shot score (pos_sim - neg_sim).

        Args:
            image_emb: Image embedding accepted by
                ``compute_image_text_similarity``.
            pos_txt_emb: Text embedding of the positive query.
            neg_txt_emb: Text embedding of the negative query.

        Returns:
            Similarity to the positive text minus similarity to the negative
            text.
        """
        pos_cosine = cls.compute_image_text_similarity(image_emb, pos_txt_emb)
        neg_cosine = cls.compute_image_text_similarity(image_emb, neg_txt_emb)
        return pos_cosine - neg_cosine
|
|
|
|
|
|
|
|
class RawImageModel:
    """Computes embeddings for raw image files with TensorFlow SavedModels.

    On construction the model downloads the ``elixr-c-v2-pooled`` and
    ``pax-elixr-b-text`` folders of the ``google/cxr-foundation`` Hugging Face
    repository and loads both as SavedModels.

    Attributes:
        elixrc_model: SavedModel loaded from ``elixr-c-v2-pooled``.
        qformer_model: SavedModel loaded from ``pax-elixr-b-text``.
    """

    def __init__(self):
        """Load both models immediately.

        Raises:
            ImportError: If TensorFlow or tensorflow-text is not installed.
        """
        self.elixrc_model = None
        self.qformer_model = None
        self._load_model()

    def _load_model(self):
        """Loads the TensorFlow model from Hugging Face.

        Raises:
            ImportError: If TensorFlow or tensorflow-text cannot be imported;
                the original import failure is chained as ``__cause__``.
        """
        try:
            import tensorflow as tf
            # Never referenced by name; presumably imported for its import-time
            # side effects (registering ops the text model needs) -- keep it.
            import tensorflow_text as text  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "TensorFlow or tensorflow-text is not installed. Use precomputed mode or install them."
            ) from err

        logger.info("Checking for GPU acceleration...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            logger.info("Running on GPU: %s", gpus)
        else:
            logger.info("Running on CPU. Expect slower inference.")

        logger.info("Downloading model weights from Hugging Face...")
        # Only the two model folders that are loaded below are fetched.
        model_path = snapshot_download(
            repo_id="google/cxr-foundation",
            allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        )

        logger.info("Loading ELIXR-C (Image Encoder)...")
        self.elixrc_model = tf.saved_model.load(os.path.join(model_path, 'elixr-c-v2-pooled'))

        logger.info("Loading QFormer (Adapter)...")
        self.qformer_model = tf.saved_model.load(os.path.join(model_path, 'pax-elixr-b-text'))

    @staticmethod
    def _serialize_image_example(tf, image_path):
        """Wrap the raw bytes of ``image_path`` in a serialized ``tf.train.Example``."""
        with open(image_path, 'rb') as f:
            image_bytes = f.read()

        example = tf.train.Example()
        features = example.features.feature
        features['image/encoded'].bytes_list.value.append(image_bytes)
        # The bytes are always tagged as PNG; the file is not inspected or
        # converted, so callers are expected to pass a PNG file.
        features['image/format'].bytes_list.value.append(b'png')
        return example.SerializeToString()

    def _run_image_encoder(self, tf, serialized_example):
        """Run ELIXR-C on one serialized example and return ``feature_maps_0`` as NumPy."""
        elixrc_infer = self.elixrc_model.signatures['serving_default']
        elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
        return elixrc_output['feature_maps_0'].numpy()

    def _run_qformer(self, elixrc_embedding):
        """Run the QFormer on an ELIXR-C embedding and return ``all_contrastive_img_emb`` as NumPy."""
        qformer_input = {
            'image_feature': elixrc_embedding.tolist(),
            # All-zero text inputs: presumably placeholders because only the
            # image-side output is read below -- TODO confirm.
            'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
            'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
        }
        qformer_output = self.qformer_model.signatures['serving_default'](**qformer_input)
        return qformer_output['all_contrastive_img_emb'].numpy()

    def compute_embeddings(self, image_path):
        """Generates embeddings for a raw image file.

        The file is read as-is, passed through the ELIXR-C image encoder and
        then through the QFormer.

        Args:
            image_path: Path of the image file; its bytes are submitted to the
                model labelled as PNG.

        Returns:
            NumPy array taken from the QFormer's ``all_contrastive_img_emb``
            output.

        Raises:
            Exception: Any error raised while reading the file or running the
                models is logged and re-raised unchanged.
        """
        import tensorflow as tf

        try:
            serialized_example = self._serialize_image_example(tf, image_path)
            elixrc_embedding = self._run_image_encoder(tf, serialized_example)
            return self._run_qformer(elixrc_embedding)
        except Exception as e:
            # Log for visibility, then let the caller decide how to handle it.
            logger.error("Error computing raw embeddings: %s", e)
            raise
|
|
|