Update modeling_vila.py
modeling_vila.py  CHANGED  (+7 -4)
@@ -739,15 +739,18 @@ class VILAForCausalLM(VILAPretrainedModel):
         self.encoders[name].pool_sizes[0][0] = 4 * round_up_to_bucket(num_video_frames / 256)
 
         if num_video_frames > 512:
-            media_split = []
-            frames_split = 4
+            media_split, num_splits = [], []
             for video in media[name]:
-                media_split.extend(video.split(...))
+                video_split = video.split(512, dim=0)
+                media_split.extend(video_split)
+                num_splits.append(len(video_split))
+
             embeds_split = []
             for video in media_split:
                 embeds_split += self.encoders[name]([video], media_config[name])
+
             embeds_merged = [
-                torch.cat(embeds_split[i * frames_split: (i + 1) * frames_split], dim=0)
+                torch.cat(embeds_split[i * num_splits[i]: (i + 1) * num_splits[i]], dim=0)
                 for i in range(len(media[name]))
             ]
             embeds[name] = deque(embeds_merged)
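The change replaces the fixed four-way split with a length-based `video.split(512, dim=0)`: each long video is encoded in chunks of at most 512 frames, the per-video chunk counts are recorded in `num_splits`, and the chunk embeddings are concatenated back into one tensor per video. Below is a minimal, self-contained sketch of that split/encode/merge pattern; `encode_chunk` and `encode_videos` are hypothetical stand-ins for the encoder call in the diff, not VILA's actual API.

import torch
from collections import deque

CHUNK_FRAMES = 512  # per-chunk frame budget, mirroring the diff

def encode_chunk(chunk: torch.Tensor) -> torch.Tensor:
    # Stand-in for `self.encoders[name]([chunk], media_config[name])`;
    # identity keeps the sketch runnable, a real encoder pools/projects.
    return chunk

def encode_videos(videos: list[torch.Tensor]) -> deque:
    # 1) Split every video into chunks of at most CHUNK_FRAMES frames.
    media_split, num_splits = [], []
    for video in videos:
        video_split = video.split(CHUNK_FRAMES, dim=0)
        media_split.extend(video_split)
        num_splits.append(len(video_split))

    # 2) Encode one chunk at a time, bounding the encoder's peak input size.
    embeds_split = [encode_chunk(chunk) for chunk in media_split]

    # 3) Stitch chunk embeddings back into one tensor per video.
    #    Cumulative offsets also cover videos with differing chunk counts.
    offsets = [0]
    for n in num_splits:
        offsets.append(offsets[-1] + n)
    embeds_merged = [
        torch.cat(embeds_split[offsets[i]: offsets[i + 1]], dim=0)
        for i in range(len(videos))
    ]
    return deque(embeds_merged)

# Two videos of 1300 and 400 frames split into 3 + 1 chunks and re-merge cleanly.
videos = [torch.randn(1300, 8), torch.randn(400, 8)]
merged = encode_videos(videos)
assert [e.shape[0] for e in merged] == [1300, 400]

Note that the committed slice `embeds_split[i * num_splits[i]: (i + 1) * num_splits[i]]` lines up exactly when every video in the batch yields the same number of chunks; the cumulative offsets above are the general form for mixed-length batches.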