File size: 7,175 Bytes
b412062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import logging
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download

# Module-level logger named after this module. No handlers or levels are set
# here, so what gets emitted depends on the application's logging configuration.
logger = logging.getLogger(__name__)

class PrecomputedModel:
    """Zero-shot scorer backed by embeddings loaded from ``.npz`` files.

    Text embeddings are required. Image embeddings and labels are optional;
    image embeddings are only needed by ``compute_scores`` (benchmarking).
    """

    def __init__(self, data_dir="data"):
        """Load precomputed data from *data_dir*.

        Args:
            data_dir: Directory containing ``precomputed_text_embeddings.npz``
                (required), ``precomputed_image_embeddings.npz`` and
                ``cxr14_subset_labels.csv`` (both optional).

        Raises:
            FileNotFoundError: If the text-embedding file is missing.
        """
        self.data_dir = data_dir
        # Mapping of image id -> embedding array; stays None if the file is absent.
        self.image_embeddings = None
        # Mapping of query text -> embedding array.
        self.text_embeddings = None
        # pandas DataFrame read from the labels CSV; stays None if it is absent.
        self.labels = None
        self._load_data()

    @staticmethod
    def _load_npz(path):
        """Read every array of an ``.npz`` archive into a plain dict, closing the file."""
        with np.load(path) as archive:
            return {key: archive[key] for key in archive}

    def _load_data(self):
        """Load text embeddings (required), image embeddings and labels (optional)."""
        img_emb_path = os.path.join(self.data_dir, "precomputed_image_embeddings.npz")
        txt_emb_path = os.path.join(self.data_dir, "precomputed_text_embeddings.npz")
        labels_path = os.path.join(self.data_dir, "cxr14_subset_labels.csv")

        # Text embeddings are strictly required for zero-shot scoring.
        if not os.path.exists(txt_emb_path):
            raise FileNotFoundError(f"Missing required text embeddings: {txt_emb_path}")

        logger.info("Loading precomputed text embeddings...")
        self.text_embeddings = self._load_npz(txt_emb_path)

        # Image embeddings are optional; only compute_scores needs them.
        if os.path.exists(img_emb_path):
            logger.info("Loading precomputed image embeddings...")
            self.image_embeddings = self._load_npz(img_emb_path)
        else:
            logger.warning("Precomputed image embeddings not found. Benchmarking features will be disabled.")

        # Labels are optional.
        if os.path.exists(labels_path):
            logger.info("Loading labels...")
            self.labels = pd.read_csv(labels_path)
        else:
            logger.warning("Labels file not found.")

    def get_diagnosis_embeddings(self, pos_txt, neg_txt):
        """Return the precomputed embeddings for a positive and a negative query.

        Args:
            pos_txt: Positive query text (a key of ``self.text_embeddings``).
            neg_txt: Negative query text (a key of ``self.text_embeddings``).

        Returns:
            Tuple ``(pos_embedding, neg_embedding)``.

        Raises:
            ValueError: If either query has no precomputed embedding.
        """
        if pos_txt not in self.text_embeddings:
            raise ValueError(f"Positive query '{pos_txt}' not found in precomputed embeddings.")
        if neg_txt not in self.text_embeddings:
            raise ValueError(f"Negative query '{neg_txt}' not found in precomputed embeddings.")

        return self.text_embeddings[pos_txt], self.text_embeddings[neg_txt]

    def compute_scores(self, image_ids, pos_emb, neg_emb):
        """Compute zero-shot scores for the given image ids.

        Ids without a precomputed image embedding are skipped silently.

        Args:
            image_ids: Iterable of image ids to score.
            pos_emb: Positive-query text embedding.
            neg_emb: Negative-query text embedding.

        Returns:
            Tuple ``(valid_ids, scores)`` of equal-length lists, in input order.

        Raises:
            RuntimeError: If image embeddings were not loaded (file absent).
        """
        # Without this guard, `in None` below would raise an opaque TypeError.
        if self.image_embeddings is None:
            raise RuntimeError(
                "Precomputed image embeddings are not loaded; "
                "compute_scores requires precomputed_image_embeddings.npz."
            )

        scores = []
        valid_ids = []
        for img_id in image_ids:
            if img_id not in self.image_embeddings:
                continue
            scores.append(self.zero_shot(self.image_embeddings[img_id], pos_emb, neg_emb))
            valid_ids.append(img_id)

        return valid_ids, scores

    @staticmethod
    def compute_image_text_similarity(image_emb, txt_emb):
        """Return the maximum cosine similarity between a text embedding and image rows.

        Args:
            image_emb: Array with 32 * 128 elements, e.g. shape (1, 32, 128) or
                (32, 128); it is reshaped to (32, 128).
            txt_emb: Array with 128 elements, e.g. shape (128,), (1, 128) or
                (128, 1); it is flattened to (128,).

        Returns:
            The largest of the 32 row-wise cosine similarities (NumPy scalar).
            A zero-norm row or text vector yields NaN, as a plain division would.
        """
        rows = np.reshape(image_emb, (32, 128))
        # Flatten so every 128-element text layout is handled the same way.
        text = np.reshape(txt_emb, -1)
        similarities = (rows @ text) / (np.linalg.norm(rows, axis=1) * np.linalg.norm(text))
        return np.max(similarities)

    @classmethod
    def zero_shot(cls, image_emb, pos_txt_emb, neg_txt_emb):
        """Return the zero-shot score: similarity(positive) - similarity(negative)."""
        pos_cosine = cls.compute_image_text_similarity(image_emb, pos_txt_emb)
        neg_cosine = cls.compute_image_text_similarity(image_emb, neg_txt_emb)
        return pos_cosine - neg_cosine


class RawImageModel:
    """Compute image embeddings from raw image files using TensorFlow SavedModels.

    At construction time, the ELIXR-C image encoder and the QFormer adapter are
    fetched from the ``google/cxr-foundation`` Hugging Face repository via
    ``snapshot_download`` and loaded. TensorFlow and tensorflow-text are
    imported lazily, so the rest of the module works without them.
    """

    def __init__(self):
        """Fetch and load both models.

        Raises:
            ImportError: If TensorFlow or tensorflow-text is not installed.
        """
        self.elixrc_model = None
        self.qformer_model = None
        self._load_model()

    def _load_model(self):
        """Load the ELIXR-C and QFormer SavedModels from Hugging Face.

        Raises:
            ImportError: If TensorFlow or tensorflow-text is not installed; the
                original import error is chained as ``__cause__``.
        """
        try:
            import tensorflow as tf
            # Imported only for its side effect of registering the TF text ops.
            import tensorflow_text  # noqa: F401
        except ImportError as err:
            raise ImportError("TensorFlow or tensorflow-text is not installed. Use precomputed mode or install them.") from err

        logger.info("Checking for GPU acceleration...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            logger.info("Running on GPU: %s", gpus)
        else:
            logger.info("Running on CPU. Expect slower inference.")

        logger.info("Downloading model weights from Hugging Face...")
        model_path = snapshot_download(
            repo_id="google/cxr-foundation",
            allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        )

        logger.info("Loading ELIXR-C (Image Encoder)...")
        self.elixrc_model = tf.saved_model.load(os.path.join(model_path, 'elixr-c-v2-pooled'))

        logger.info("Loading QFormer (Adapter)...")
        self.qformer_model = tf.saved_model.load(os.path.join(model_path, 'pax-elixr-b-text'))

    def compute_embeddings(self, image_path):
        """Return the QFormer contrastive image embedding for one image file.

        The file's bytes are wrapped in a serialized ``tf.train.Example``
        (``image/encoded`` plus ``image/format = b'png'``) and run through
        ELIXR-C. The resulting feature map is then fed to the QFormer with
        zero-filled text inputs.

        Args:
            image_path: Path to the image file (expected to be a PNG).

        Returns:
            NumPy array from the QFormer's ``all_contrastive_img_emb`` output,
            expected shape (1, 32, 128).

        Raises:
            OSError: If the file cannot be read.
            Exception: Any error from TensorFlow; it is logged with its
                traceback and re-raised unchanged.
        """
        import tensorflow as tf

        try:
            # NOTE(review): the raw file bytes are forwarded unmodified and
            # labelled PNG. No grayscale conversion or bit-depth rescaling
            # happens here; confirm this matches the model's expected input.
            with open(image_path, 'rb') as f:
                image_bytes = f.read()

            example = tf.train.Example()
            features = example.features.feature
            features['image/encoded'].bytes_list.value.append(image_bytes)
            features['image/format'].bytes_list.value.append(b'png')
            serialized_example = example.SerializeToString()

            # Step 1: ELIXR-C image encoder.
            elixrc_infer = self.elixrc_model.signatures['serving_default']
            elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
            elixrc_embedding = elixrc_output['feature_maps_0'].numpy()  # expected (1, 8, 8, 1376)

            # Step 2: QFormer adapter. The text inputs are zero-filled because
            # only the image-embedding output is used.
            qformer_input = {
                'image_feature': elixrc_embedding.tolist(),
                'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
                'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
            }
            qformer_output = self.qformer_model.signatures['serving_default'](**qformer_input)
            return qformer_output['all_contrastive_img_emb'].numpy()  # expected (1, 32, 128)

        except Exception as e:
            # logger.exception keeps the traceback; the error is re-raised unchanged.
            logger.exception("Error computing raw embeddings: %s", e)
            raise