borso271 committed
Commit a6ffb65 · 1 Parent(s): efade39

Add dynamic label management with Hub persistence


- Implement upsert_labels and reload_labels admin operations
- Add versioned snapshot persistence to HF dataset repo
- Maintain backward compatibility for classification API
- Add fingerprinting for model compatibility checks
- Enable incremental embedding updates without re-computation
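The request contract these operations ride on is sketched below. This is a minimal client illustration, not part of the commit: the endpoint URL, bearer token, and the example label item are placeholders, and requests is used purely for illustration; the payload fields (op, token, items, version, image, top_k, min_labels_version) come from the new handler.py shown in the diff.

import base64, requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <hf_api_token>", "Content-Type": "application/json"}  # placeholder

# Classification (backward-compatible): base64 image in, ranked labels out.
with open("photo.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")
r = requests.post(ENDPOINT_URL, headers=HEADERS,
                  json={"inputs": {"image": img_b64, "top_k": 5}})
print(r.json())  # [{"id": ..., "label": ..., "score": ...}, ...]

# Admin: upsert new labels; only the new prompts are embedded, then a
# versioned snapshot is pushed to the HF dataset repo.
r = requests.post(ENDPOINT_URL, headers=HEADERS, json={"inputs": {
    "op": "upsert_labels",
    "token": "<ADMIN_TOKEN>",
    "items": [{"id": 101, "name": "espresso machine",            # example item
               "prompt": "a photo of an espresso machine"}],
}})
print(r.json())  # {"status": "ok", "added": 1, "labels_version": ...}

# Admin: make this replica load a specific published snapshot version.
r = requests.post(ENDPOINT_URL, headers=HEADERS, json={"inputs": {
    "op": "reload_labels", "token": "<ADMIN_TOKEN>", "version": 2,
}})
print(r.json())  # {"status": "ok", "labels_version": 2} on success, else "nochange"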

Files changed (2)
  1. handler.py +216 -45
  2. requirements.txt +3 -1
handler.py CHANGED
@@ -1,75 +1,246 @@
-import contextlib, io, base64, torch, json
+import contextlib, io, base64, torch, json, os, threading
 from PIL import Image
 import open_clip
+from huggingface_hub import hf_hub_download, create_commit, CommitOperationAdd
+from safetensors.torch import save_file, load_file
 from reparam import reparameterize_model
 
+ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")
+HF_LABEL_REPO = os.getenv("HF_LABEL_REPO", "")  # e.g. "org/mobileclip-labels"
+HF_WRITE_TOKEN = os.getenv("HF_WRITE_TOKEN", "")
+HF_READ_TOKEN = os.getenv("HF_READ_TOKEN", HF_WRITE_TOKEN)
+
+
+def _fingerprint(device: str, dtype: torch.dtype) -> dict:
+    return {
+        "model_id": "MobileCLIP-B",
+        "pretrained": "datacompdr",
+        "open_clip": getattr(open_clip, "__version__", "unknown"),
+        "torch": torch.__version__,
+        "cuda": torch.version.cuda if torch.cuda.is_available() else None,
+        "dtype_runtime": str(dtype),
+        "text_norm": "L2",
+        "logit_scale": 100.0,
+    }
+
+
 class EndpointHandler:
     def __init__(self, path: str = ""):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
 
-        # 1. Load the model (happens only once at startup)
+        # 1) Load model + transforms
         model, _, self.preprocess = open_clip.create_model_and_transforms(
-            "MobileCLIP-B", pretrained='datacompdr'
+            "MobileCLIP-B", pretrained="datacompdr"
         )
         model.eval()
-        self.model = reparameterize_model(model)
-        tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
-        self.model.to(self.device)
-
+        model = reparameterize_model(model)
+        model.to(self.device)
         if self.device == "cuda":
-            self.model.to(torch.float16)
-
-        # --- OPTIMIZATION: Pre-compute text features from your JSON ---
-
-        # 2. Load your rich class definitions from the file
-        with open(f"{path}/items.json", "r", encoding="utf-8") as f:
-            class_definitions = json.load(f)
-
-        # 3. Prepare the data for encoding and for the final response
-        #    - Use the 'prompt' field for creating the embeddings
-        #    - Keep 'name' and 'id' to structure the response later
-        prompts = [item['prompt'] for item in class_definitions]
-        self.class_ids = [item['id'] for item in class_definitions]
-        self.class_names = [item['name'] for item in class_definitions]
-
-        # 4. Tokenize and encode all prompts at once
-        with torch.no_grad():
-            text_tokens = tokenizer(prompts).to(self.device)
-            self.text_features = self.model.encode_text(text_tokens)
-            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
+            model = model.to(torch.float16)
+        self.model = model
+        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
+        self.fingerprint = _fingerprint(self.device, self.dtype)
+        self._lock = threading.Lock()
+
+        # 2) Try to load snapshot from Hub; else seed from items.json
+        loaded = False
+        if HF_LABEL_REPO:
+            with contextlib.suppress(Exception):
+                loaded = self._load_snapshot_from_hub_latest()
+        if not loaded:
+            with open(f"{path}/items.json", "r", encoding="utf-8") as f:
+                items = json.load(f)
+            prompts = [it["prompt"] for it in items]
+            self.class_ids = [int(it["id"]) for it in items]
+            self.class_names = [it["name"] for it in items]
+            with torch.no_grad():
+                toks = self.tokenizer(prompts).to(self.device)
+                feats = self.model.encode_text(toks)
+                feats = feats / feats.norm(dim=-1, keepdim=True)
+            self.text_features_cpu = feats.detach().cpu().to(torch.float32).contiguous()
+            self._to_device()
+            self.labels_version = 1
 
     def __call__(self, data):
-        # The payload only needs the image now
         payload = data.get("inputs", data)
-        img_b64 = payload["image"]
 
-        # ---------------- decode image ----------------
+        # Admin op: upsert_labels
+        op = payload.get("op")
+        if op == "upsert_labels":
+            if payload.get("token") != ADMIN_TOKEN:
+                return {"error": "unauthorized"}
+            items = payload.get("items", []) or []
+            added = self._upsert_items(items)
+            if added > 0:
+                new_ver = int(getattr(self, "labels_version", 1)) + 1
+                try:
+                    self._persist_snapshot_to_hub(new_ver)
+                    self.labels_version = new_ver
+                except Exception as e:
+                    return {"status": "error", "added": added, "detail": str(e)}
+            return {"status": "ok", "added": added, "labels_version": getattr(self, "labels_version", 1)}
+
+        # Admin op: reload_labels
+        if op == "reload_labels":
+            if payload.get("token") != ADMIN_TOKEN:
+                return {"error": "unauthorized"}
+            try:
+                ver = int(payload.get("version"))
+            except Exception:
+                return {"error": "invalid_version"}
+            ok = self._load_snapshot_from_hub_version(ver)
+            return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}
+
+        # Freshness guard (optional)
+        min_ver = payload.get("min_labels_version")
+        if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
+            with contextlib.suppress(Exception):
+                self._load_snapshot_from_hub_version(min_ver)
+
+        # Classification path (unchanged contract)
+        img_b64 = payload["image"]
         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
-
         if self.device == "cuda":
             img_tensor = img_tensor.to(torch.float16)
-
-        # ---------------- forward pass (very fast) -----------------
         with torch.no_grad():
-            # 1. Encode only the image
             img_feat = self.model.encode_image(img_tensor)
             img_feat /= img_feat.norm(dim=-1, keepdim=True)
-
-            # 2. Compute similarity against the pre-computed text features
-            probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]
-
-        # 3. Combine the results with your stored class IDs and names
-        #    and convert the tensor of probabilities to a list of floats
-        results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
-
-        # 4. Create a sorted list of dictionaries for a clean JSON response
+            probs = (100.0 * img_feat @ self.text_features.T).softmax(dim=-1)[0]
+        results = zip(self.class_ids, self.class_names, probs.detach().cpu().tolist())
+        top_k = int(payload.get("top_k", len(self.class_ids)))
         return sorted(
             [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
             key=lambda x: x["score"],
-            reverse=True
-        )
+            reverse=True,
+        )[:top_k]
+
+    # ------------- helpers -------------
+    def _encode_text(self, prompts):
+        with torch.no_grad():
+            toks = self.tokenizer(prompts).to(self.device)
+            feats = self.model.encode_text(toks)
+            feats = feats / feats.norm(dim=-1, keepdim=True)
+        return feats
+
+    def _to_device(self):
+        self.text_features = self.text_features_cpu.to(
+            self.device, dtype=(torch.float16 if self.device == "cuda" else torch.float32)
+        )
+
+    def _upsert_items(self, new_items):
+        if not new_items:
+            return 0
+        with self._lock:
+            known = set(getattr(self, "class_ids", []))
+            batch = [it for it in new_items if int(it.get("id")) not in known]
+            if not batch:
+                return 0
+            prompts = [it["prompt"] for it in batch]
+            feats = self._encode_text(prompts).detach().cpu().to(torch.float32)
+            if not hasattr(self, "text_features_cpu"):
+                self.text_features_cpu = feats.contiguous()
+                self.class_ids = [int(it["id"]) for it in batch]
+                self.class_names = [it["name"] for it in batch]
+            else:
+                self.text_features_cpu = torch.cat([self.text_features_cpu, feats], dim=0).contiguous()
+                self.class_ids.extend([int(it["id"]) for it in batch])
+                self.class_names.extend([it["name"] for it in batch])
+            self._to_device()
+            return len(batch)
+
+    def _persist_snapshot_to_hub(self, version: int):
+        if not HF_LABEL_REPO:
+            raise RuntimeError("HF_LABEL_REPO not set")
+        if not HF_WRITE_TOKEN:
+            raise RuntimeError("HF_WRITE_TOKEN not set for publishing")
+
+        emb_path = "/tmp/embeddings.safetensors"
+        meta_path = "/tmp/meta.json"
+        latest_bytes = io.BytesIO(json.dumps({"version": int(version)}).encode("utf-8"))
+
+        save_file({"embeddings": self.text_features_cpu.to(torch.float32)}, emb_path)
+        meta = {
+            "items": [{"id": int(i), "name": n} for i, n in zip(self.class_ids, self.class_names)],
+            "fingerprint": self.fingerprint,
+            "dims": int(self.text_features_cpu.shape[1]),
+            "count": int(self.text_features_cpu.shape[0]),
+            "version": int(version),
+        }
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta, f)
+
+        ops = [
+            # note: CommitOperationAdd has no per-file LFS flag; the Hub decides
+            # LFS handling server-side (e.g. via the repo's .gitattributes).
+            CommitOperationAdd(
+                path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
+                path_or_fileobj=emb_path,
+            ),
+            CommitOperationAdd(
+                path_in_repo=f"snapshots/v{version}/meta.json",
+                path_or_fileobj=meta_path,
+            ),
+            CommitOperationAdd(
+                path_in_repo="snapshots/latest.json",
+                path_or_fileobj=latest_bytes,
+            ),
+        ]
+        create_commit(
+            repo_id=HF_LABEL_REPO,
+            repo_type="dataset",
+            operations=ops,
+            token=HF_WRITE_TOKEN,
+            commit_message=f"labels v{version}",
+        )
+
+    def _load_snapshot_from_hub_version(self, version: int) -> bool:
+        if not HF_LABEL_REPO:
+            return False
+        with self._lock:
+            emb_p = hf_hub_download(
+                HF_LABEL_REPO,
+                f"snapshots/v{version}/embeddings.safetensors",
+                repo_type="dataset",
+                token=HF_READ_TOKEN,
+                force_download=True,
+            )
+            meta_p = hf_hub_download(
+                HF_LABEL_REPO,
+                f"snapshots/v{version}/meta.json",
+                repo_type="dataset",
+                token=HF_READ_TOKEN,
+                force_download=True,
+            )
+            meta = json.load(open(meta_p, "r", encoding="utf-8"))
+            if meta.get("fingerprint") != self.fingerprint:
+                raise RuntimeError("Embedding/model fingerprint mismatch")
+            feats = load_file(emb_p)["embeddings"]  # float32 CPU
+            self.text_features_cpu = feats.contiguous()
+            self.class_ids = [int(x["id"]) for x in meta.get("items", [])]
+            self.class_names = [x["name"] for x in meta.get("items", [])]
+            self.labels_version = int(meta.get("version", version))
+            self._to_device()
+            return True
+
+    def _load_snapshot_from_hub_latest(self) -> bool:
+        if not HF_LABEL_REPO:
+            return False
+        try:
+            latest_p = hf_hub_download(
+                HF_LABEL_REPO,
+                "snapshots/latest.json",
+                repo_type="dataset",
+                token=HF_READ_TOKEN,
+            )
+        except Exception:
+            return False
+        latest = json.load(open(latest_p, "r", encoding="utf-8"))
+        ver = int(latest.get("version", 0))
+        if ver <= 0:
+            return False
+        return self._load_snapshot_from_hub_version(ver)
+

 
 # """
@@ -213,4 +384,4 @@ class EndpointHandler:
 # )
 
 
-
+
requirements.txt CHANGED
@@ -1,2 +1,4 @@
 Pillow
-open_clip_torch
+open_clip_torch
+huggingface_hub>=0.23.0
+safetensors>=0.4.3
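For reference, each published snapshot lands in the label dataset repo as snapshots/v{N}/embeddings.safetensors plus snapshots/v{N}/meta.json, with snapshots/latest.json recording the newest version. A minimal offline reader of that layout, a sketch only: the repo id and token below are placeholders standing in for whatever HF_LABEL_REPO points at.

import json
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

REPO = "org/mobileclip-labels"   # placeholder; use your HF_LABEL_REPO value
TOKEN = None                     # or a read token for a private repo

# snapshots/latest.json holds {"version": N}
with open(hf_hub_download(REPO, "snapshots/latest.json", repo_type="dataset", token=TOKEN)) as f:
    ver = json.load(f)["version"]

with open(hf_hub_download(REPO, f"snapshots/v{ver}/meta.json", repo_type="dataset", token=TOKEN)) as f:
    meta = json.load(f)
emb = load_file(hf_hub_download(REPO, f"snapshots/v{ver}/embeddings.safetensors",
                                repo_type="dataset", token=TOKEN))["embeddings"]

# meta["items"] is aligned row-for-row with the float32 embedding matrix
assert emb.shape == (meta["count"], meta["dims"])
print(ver, len(meta["items"]), tuple(emb.shape))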