| |
| |
| |
|
|
| |
| |
|
|
|
|
| import os |
| import urllib |
| from functools import partial |
| from types import SimpleNamespace |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .helpers import ( |
| EinOpsRearrange, |
| LearnableLogitScaling, |
| Normalize, |
| SelectElement, |
| SelectEOSAndProject, |
| ) |
| from .multimodal_preprocessors import ( |
| AudioPreprocessor, |
| IMUPreprocessor, |
| PadIm2Video, |
| PatchEmbedGeneric, |
| RGBDTPreprocessor, |
| SpatioTemporalPosEmbeddingHelper, |
| TextPreprocessor, |
| ThermalPreprocessor, |
| ) |
|
|
| from .transformer import MultiheadAttention, SimpleTransformer |
|
|
|
|
# String identifiers for the supported modalities. Each attribute is the
# upper-cased form of its value (e.g. ``ModalityType.VISION == "vision"``).
# ImageBindModel uses these strings as the keys of its per-modality
# ``nn.ModuleDict`` containers and of the ``inputs`` dict given to ``forward``.
ModalityType = SimpleNamespace(
    **{
        modality.upper(): modality
        for modality in ("vision", "text", "audio", "thermal", "depth", "imu")
    }
)
|
|
|
|
class ImageBindModel(nn.Module):
    """Multi-modal encoder built from four per-modality stages.

    For every modality in ``ModalityType`` the model owns a preprocessor, a
    transformer trunk, a projection head and a postprocessor, stored in four
    ``nn.ModuleDict`` containers keyed by the modality string.  ``forward``
    runs each supplied input through the four stages of its own modality.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Construct all per-modality sub-modules.

        Args:
            video_frames: temporal size placed in the vision preprocessor's
                ``img_size``.
            kernel_size: kernel size and stride of the vision ``Conv3d`` stem.
            audio_kernel_size: kernel size of the audio ``Conv2d`` stem.
            audio_stride: stride of the audio ``Conv2d`` stem.
            out_embed_dim: output width of every modality head's final
                ``nn.Linear``.
            *_embed_dim: token width of the given modality's stem, trunk and
                head input.
            *_num_blocks / *_num_heads: transformer depth and attention head
                count of the given modality's trunk.
            *_drop_path: drop-path rate forwarded to the given modality's
                trunk (the vision and text trunks always use 0.0).
            audio_num_mel_bins, audio_target_len: the two trailing entries of
                the audio preprocessor's ``img_size``.
            depth_kernel_size, thermal_kernel_size: kernel size and stride of
                the respective ``Conv2d`` stems.
            imu_kernel_size: accepted for interface compatibility only.
                NOTE(review): it is not forwarded to any builder; the IMU
                preprocessor is built with a hard-coded ``kernel_size=8``.
        """
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames=video_frames,
            vision_embed_dim=vision_embed_dim,
            kernel_size=kernel_size,
            text_embed_dim=text_embed_dim,
            audio_embed_dim=audio_embed_dim,
            audio_kernel_size=audio_kernel_size,
            audio_stride=audio_stride,
            audio_num_mel_bins=audio_num_mel_bins,
            audio_target_len=audio_target_len,
            depth_embed_dim=depth_embed_dim,
            depth_kernel_size=depth_kernel_size,
            thermal_embed_dim=thermal_embed_dim,
            thermal_kernel_size=thermal_kernel_size,
            imu_embed_dim=imu_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim=vision_embed_dim,
            vision_num_blocks=vision_num_blocks,
            vision_num_heads=vision_num_heads,
            text_embed_dim=text_embed_dim,
            text_num_blocks=text_num_blocks,
            text_num_heads=text_num_heads,
            audio_embed_dim=audio_embed_dim,
            audio_num_blocks=audio_num_blocks,
            audio_num_heads=audio_num_heads,
            audio_drop_path=audio_drop_path,
            depth_embed_dim=depth_embed_dim,
            depth_num_blocks=depth_num_blocks,
            depth_num_heads=depth_num_heads,
            depth_drop_path=depth_drop_path,
            thermal_embed_dim=thermal_embed_dim,
            thermal_num_blocks=thermal_num_blocks,
            thermal_num_heads=thermal_num_heads,
            thermal_drop_path=thermal_drop_path,
            imu_embed_dim=imu_embed_dim,
            imu_num_blocks=imu_num_blocks,
            imu_num_heads=imu_num_heads,
            imu_drop_path=imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim=out_embed_dim,
            vision_embed_dim=vision_embed_dim,
            text_embed_dim=text_embed_dim,
            audio_embed_dim=audio_embed_dim,
            depth_embed_dim=depth_embed_dim,
            thermal_embed_dim=thermal_embed_dim,
            imu_embed_dim=imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Build the input preprocessor of every modality.

        Each non-text modality gets a ``PatchEmbedGeneric`` stem wrapped in a
        modality-specific preprocessor with one CLS token and a learnable
        ``SpatioTemporalPosEmbeddingHelper``; text gets a ``TextPreprocessor``.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its
            preprocessor.
        """
        # --- vision -------------------------------------------------------
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                # ntimes is fixed at 2 here, independently of `video_frames`.
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        # --- text ---------------------------------------------------------
        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        # --- audio --------------------------------------------------------
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        # --- depth (reuses RGBDTPreprocessor with only a depth stem) ------
        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )
        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        # --- thermal ------------------------------------------------------
        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # --- IMU ----------------------------------------------------------
        # NOTE(review): in_features=48 and kernel_size=8 are hard-coded;
        # presumably 48 == 6 channels * kernel 8 -- confirm against
        # IMUPreprocessor before parameterizing.
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=48,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )
        imu_preprocessor = IMUPreprocessor(
            img_size=[6, 2000],
            num_cls_tokens=1,
            kernel_size=8,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        return nn.ModuleDict(
            {
                ModalityType.VISION: rgbt_preprocessor,
                ModalityType.TEXT: text_preprocessor,
                ModalityType.AUDIO: audio_preprocessor,
                ModalityType.DEPTH: depth_preprocessor,
                ModalityType.THERMAL: thermal_preprocessor,
                ModalityType.IMU: imu_preprocessor,
            }
        )

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build one ``SimpleTransformer`` trunk per modality.

        Only the vision trunk gets a pre-transformer ``LayerNorm``; vision and
        text use ``add_bias_kv=False`` and a drop-path rate of 0.0, while the
        other modalities use ``add_bias_kv=True`` and their own drop-path
        argument.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its trunk.
        """

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The trunk is wrapped in rearranges "b l d -> l b d" on the way
            # in and "l b d -> b l d" on the way out.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        # Dict literal entries are evaluated top to bottom, so trunks are
        # created in the same order as the keys appear.
        modality_trunks = {
            ModalityType.VISION: instantiate_trunk(
                vision_embed_dim,
                vision_num_blocks,
                vision_num_heads,
                pre_transformer_ln=True,
                add_bias_kv=False,
                drop_path=0.0,
            ),
            ModalityType.TEXT: instantiate_trunk(
                text_embed_dim,
                text_num_blocks,
                text_num_heads,
                pre_transformer_ln=False,
                add_bias_kv=False,
                drop_path=0.0,
            ),
            ModalityType.AUDIO: instantiate_trunk(
                audio_embed_dim,
                audio_num_blocks,
                audio_num_heads,
                pre_transformer_ln=False,
                add_bias_kv=True,
                drop_path=audio_drop_path,
            ),
            ModalityType.DEPTH: instantiate_trunk(
                depth_embed_dim,
                depth_num_blocks,
                depth_num_heads,
                pre_transformer_ln=False,
                add_bias_kv=True,
                drop_path=depth_drop_path,
            ),
            ModalityType.THERMAL: instantiate_trunk(
                thermal_embed_dim,
                thermal_num_blocks,
                thermal_num_heads,
                pre_transformer_ln=False,
                add_bias_kv=True,
                drop_path=thermal_drop_path,
            ),
            ModalityType.IMU: instantiate_trunk(
                imu_embed_dim,
                imu_num_blocks,
                imu_num_heads,
                pre_transformer_ln=False,
                add_bias_kv=True,
                drop_path=imu_drop_path,
            ),
        }

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Build the projection head of every modality.

        All heads end in a bias-free ``nn.Linear`` to ``out_embed_dim``
        preceded by a ``LayerNorm``.  Text wraps that projection in
        ``SelectEOSAndProject``; the other modalities insert
        ``SelectElement(index=0)`` after the norm, and IMU additionally
        applies ``Dropout(p=0.5)`` before the linear layer.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its head.
        """

        def build_head(embed_dim, dropout_p=None):
            # LayerNorm -> SelectElement(0) -> [Dropout] -> Linear, created in
            # that order so Sequential child indices match the layout above.
            layers = [
                nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6),
                SelectElement(index=0),
            ]
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            layers.append(nn.Linear(embed_dim, out_embed_dim, bias=False))
            return nn.Sequential(*layers)

        modality_heads = {}

        modality_heads[ModalityType.VISION] = build_head(vision_embed_dim)

        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = build_head(audio_embed_dim)
        modality_heads[ModalityType.DEPTH] = build_head(depth_embed_dim)
        modality_heads[ModalityType.THERMAL] = build_head(thermal_embed_dim)
        modality_heads[ModalityType.IMU] = build_head(imu_embed_dim, dropout_p=0.5)

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Build the output postprocessor of every modality.

        Vision is a bare ``Normalize(dim=-1)``.  Every other modality is a
        two-element ``nn.Sequential`` of ``Normalize(dim=-1)`` followed by
        ``LearnableLogitScaling``: learnable for text, and fixed
        (``learnable=False``) with a modality-specific ``logit_scale_init``
        for audio (20.0), depth (5.0), thermal (10.0) and IMU (5.0).

        Args:
            out_embed_dim: unused by this method; kept so the signature stays
                compatible with existing callers.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its
            postprocessor.
        """

        def normalize_then_fixed_scale(scale_init):
            return nn.Sequential(
                Normalize(dim=-1),
                LearnableLogitScaling(logit_scale_init=scale_init, learnable=False),
            )

        modality_postprocessors = {
            ModalityType.VISION: Normalize(dim=-1),
            ModalityType.TEXT: nn.Sequential(
                Normalize(dim=-1), LearnableLogitScaling(learnable=True)
            ),
            ModalityType.AUDIO: normalize_then_fixed_scale(20.0),
            ModalityType.DEPTH: normalize_then_fixed_scale(5.0),
            ModalityType.THERMAL: normalize_then_fixed_scale(10.0),
            ModalityType.IMU: normalize_then_fixed_scale(5.0),
        }
        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """Encode every supplied modality input.

        Args:
            inputs: mapping from modality key (a ``ModalityType`` value) to
                that modality's input, or ``None``.  Non-``None`` values must
                expose ``ndim``, ``shape`` and ``reshape`` (presumably
                ``torch.Tensor`` -- confirm against callers).

        Returns:
            dict with the same keys as ``inputs``.  A ``None`` input maps to
            ``None``; any other input maps to the output of that modality's
            preprocessor -> trunk -> head -> postprocessor chain.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # A missing modality is passed through untouched.  This check has
            # to come before any attribute access: reading `.ndim` on None
            # would raise AttributeError.
            if modality_value is None:
                outputs[modality_key] = None
                continue

            # Inputs with 5 or more dims are treated as (B, S, ...): the two
            # leading dims are folded into one for the encoder, then the
            # result is averaged over S below.  (Presumably S indexes clips
            # of one sample -- confirm against the data pipeline.)
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            # The preprocessor takes the input as a keyword named after the
            # modality and returns a dict with "trunk" and "head" kwargs.
            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )

            if modality_key in [ModalityType.AUDIO]:
                # Audio applies only element 0 of its postprocessor Sequential
                # (the Normalize), skipping the logit scaling that follows it.
                modality_value = self.modality_postprocessors[modality_key][0](
                    modality_value
                )
            else:
                modality_value = self.modality_postprocessors[modality_key](
                    modality_value
                )

            if reduce_list:
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs
|
|
|
|
def imagebind_huge(pretrained=False, store_path=r'.checkpoints'):
    """Build the "huge" ImageBindModel configuration, optionally with weights.

    Args:
        pretrained: when true, load weights from
            ``<store_path>/imagebind_huge.pth``; if that file does not exist
            it is first downloaded from dl.fbaipublicfiles.com (``store_path``
            is created when needed).
        store_path: directory holding the checkpoint file.

    Returns:
        Tuple ``(model, 1024)``: the constructed model and the output
        embedding width it was configured with.
    """
    out_embed_dim = 1024
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=out_embed_dim,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )

    if pretrained:
        # Build the checkpoint path once instead of re-formatting it at
        # every use.
        checkpoint_path = "{}/imagebind_huge.pth".format(store_path)

        if not os.path.exists(checkpoint_path):
            print("Downloading imagebind weights to {} ...".format(checkpoint_path))
            os.makedirs(store_path, exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                checkpoint_path,
                progress=True,
            )

        print("Loading imagebind weights from {} ...".format(checkpoint_path))
        # map_location="cpu" lets the checkpoint load on hosts without a GPU
        # even if it was saved with CUDA tensors; load_state_dict copies the
        # values into the model's own parameters either way.
        # SECURITY NOTE: torch.load unpickles the file; only use checkpoints
        # from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)

    return model, out_embed_dim
|
|