suryabbrj and akhaliq (HF Staff) committed
Commit 0031b60 · 0 Parent(s)

Duplicate from akhaliq/CLIP_prefix_captioning

Co-authored-by: AK <akhaliq@users.noreply.huggingface.co>

Files changed (5)
  1. .gitattributes +27 -0
  2. README.md +38 -0
  3. app.py +273 -0
  4. requirements.txt +8 -0
  5. water.jpeg +0 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,38 @@
+ ---
+ title: CLIP_prefix_captioning
+ emoji: 💩
+ colorFrom: red
+ colorTo: indigo
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ duplicated_from: akhaliq/CLIP_prefix_captioning
+ ---
+
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio` or `streamlit`
+
+ `sdk_version`: _string_
+ Only applicable for `streamlit` SDK.
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+ Path is relative to the root of the repository.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,273 @@
+ import os
+ from huggingface_hub import hf_hub_download
+ conceptual_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights", filename="conceptual_weights.pt")
+ coco_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights", filename="coco_weights.pt")
+ import clip
+ import os
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ import sys
+ from typing import Tuple, List, Union, Optional
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+ from tqdm import tqdm, trange
+ import skimage.io as io
+ import PIL.Image
+ import gradio as gr
+
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+
+ D = torch.device
+ CPU = torch.device('cpu')
+
+
+ def get_device(device_id: int) -> D:
+     if not torch.cuda.is_available():
+         return CPU
+     device_id = min(torch.cuda.device_count() - 1, device_id)
+     return torch.device(f'cuda:{device_id}')
+
+
+ CUDA = get_device
+
+ class MLP(nn.Module):
+
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+
+ class ClipCaptionModel(nn.Module):
+
+     #@functools.lru_cache #FIXME
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         #print(embedding_text.size()) #torch.Size([5, 67, 768])
+         #print(prefix_projections.size()) #torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10:  # not enough memory
+             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+         else:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
+ #@title Caption prediction
+
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
+ def generate2(
+         model,
+         tokenizer,
+         tokens=None,
+         prompt=None,
+         embed=None,
+         entry_count=1,
+         entry_length=67,  # maximum number of words
+         top_p=0.8,
+         temperature=1.,
+         stop_token: str = '.',
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
+
+ is_gpu = False
+ device = CUDA(0) if is_gpu else "cpu"
+ clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ def inference(img, model_name):
+     prefix_length = 10
+
+     model = ClipCaptionModel(prefix_length)
+
+     if model_name == "COCO":
+         model_path = coco_weight
+     else:
+         model_path = conceptual_weight
+     model.load_state_dict(torch.load(model_path, map_location=CPU))
+     model = model.eval()
+     device = CUDA(0) if is_gpu else "cpu"
+     model = model.to(device)
+
+     use_beam_search = False
+     image = io.imread(img.name)
+     pil_image = PIL.Image.fromarray(image)
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     if use_beam_search:
+         generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+     else:
+         generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+     return generated_text_prefix
+
+ title = "CLIP prefix captioning"
+ description = "Gradio demo for CLIP prefix captioning: a simple image captioning model. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://github.com/rmokady/CLIP_prefix_caption' target='_blank'>Github Repo</a></p>"
+
+ examples = [['water.jpeg', "COCO"]]
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="file", label="Input"), gr.inputs.Radio(choices=["COCO", "Conceptual captions"], type="value", default="COCO", label="Model")],
+     gr.outputs.Textbox(label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     enable_queue=True,
+     examples=examples
+ ).launch(debug=True)
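
The essential inference path in app.py above is: CLIP encodes the input image into a 512-dimensional embedding, `clip_project` maps that embedding to a prefix of 10 GPT-2 token embeddings, and `generate2` decodes a caption from the prefix. Below is a minimal sketch of that flow outside Gradio; it assumes the names defined in app.py (`ClipCaptionModel`, `generate2`, `clip_model`, `preprocess`, `tokenizer`, `coco_weight`) are in scope, for example by pasting it at the end of app.py in place of the `gr.Interface(...).launch()` call.

```python
import torch
import PIL.Image

# Load the COCO-trained mapping network on CPU (mirrors the steps in inference() above).
prefix_length = 10
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load(coco_weight, map_location=torch.device("cpu")))
model = model.eval()

# Encode the image with CLIP, project it to a GPT-2 prefix, then decode a caption
# greedily over top-p-filtered logits (what generate2 implements).
image = preprocess(PIL.Image.open("water.jpeg")).unsqueeze(0)
with torch.no_grad():
    prefix = clip_model.encode_image(image).to(dtype=torch.float32)
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
caption = generate2(model, tokenizer, embed=prefix_embed)
print(caption)
```
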
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers
+ gdown
+ torch
+ numpy
+ tqdm
+ Pillow
+ scikit-image
+ git+https://github.com/openai/CLIP.git
water.jpeg ADDED