#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from functools import partial
from types import SimpleNamespace
from typing import Union, Optional, Tuple, Dict, List

import torch
import torch.nn as nn
from torch import Tensor
from omegaconf import DictConfig

from imagebind.models.helper import (
    EinOpsRearrange,
    LearnableLogitScaling,
    Normalize,
    SelectElement,
    SelectEOSAndProject,
)
from imagebind.models.multimodal_formers import SequenceGenericQFormer, disabled_train
from imagebind.models.multimodal_preprocessors import (
    AudioPreprocessor,
    IMUPreprocessor,
    PadIm2Video,
    PatchEmbedGeneric,
    RGBDTPreprocessor,
    SpatioTemporalPosEmbeddingHelper,
    TextPreprocessor,
    ThermalPreprocessor,
    BlipPreprocessor,
)
from imagebind.models.multimodal_projectors import create_projectors
from imagebind.models.transformer import MultiheadAttention, SimpleTransformer

ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
    AUDIO="audio",
    THERMAL="thermal",
    DEPTH="depth",
    IMU="imu",
)


class ImageBindJoiner(nn.Module):
    def __init__(self, cfg: DictConfig, output_dim: int):
        """
        cfg:
            - share_key: Optional[str]
            - modality_key: DictConfig, the modality cfg for the corresponding key
                - feat_dim: int, defaults to 1024, the input dimension to the qformer
                - post_dims: tuple, defaults to (768,), layers for post-qformer projection
                - pre_dims: tuple, defaults to (), layers for pre-qformer projection
                - num_query_token: int, defaults to 32, the number of query tokens in the qformer
                - freeze_qformer: bool, defaults to True, whether to keep the qformer frozen
                - qformer_model: str, defaults to "", path to a qformer checkpoint, "" for not loading
            - modality_key ...
        See the commented usage sketch after this class.
        """
        super().__init__()
        # vision_qformer_model is always ""
        # assert not (vision_qformer_frozen and vision_qformer_model == "")
        self.share_key = share_key = cfg.get('share_key', None)
        self.use_pre_ln = cfg.pop('use_pre_ln') if 'use_pre_ln' in cfg else False
        if share_key is not None and isinstance(share_key, str):
            self.share_joiner = True
            cfg.pop("share_key")
            assert share_key in cfg, "The modality key to share does not exist."
            # assert len(cfg) == 1, "Only one config is needed for shared joiner."
        else:
            self.share_joiner = False
        for modality_cfg in cfg.values():
            modality_cfg.pre_dims = modality_cfg.get("pre_dims", ())
            modality_cfg.post_dims = modality_cfg.get("post_dims", (768,))
            modality_cfg.num_query_token = modality_cfg.get("num_query_token", 32)
            modality_cfg.freeze_qformer = modality_cfg.get("freeze_qformer", True)
            modality_cfg.qformer_model = modality_cfg.get("qformer_model", "")
            modality_cfg.freeze_post = modality_cfg.get("freeze_post", False)
        if self.use_pre_ln:
            self.modality_pre_lns = self._create_modality_pre_lns(cfg)
        self.modality_pre_projectors = self._create_modality_pre_projectors(cfg)
        self.modality_qformers = self._create_modality_qformers(cfg)
        self.modality_post_projectors = self._create_modality_post_projectors(cfg, output_dim)

    def _create_modality_pre_lns(self, cfg):
        lns = {}
        for modality, modality_cfg in cfg.items():
            # Use the per-modality feat_dim, i.e. the same width fed to the qformer below.
            lns[modality] = nn.LayerNorm(modality_cfg.feat_dim)
        return nn.ModuleDict(lns)

    def _create_modality_pre_projectors(self, cfg):
        projectors = {}
        for modality, modality_cfg in cfg.items():
            projectors[modality] = create_projectors(tuple(modality_cfg.pre_dims))
        return nn.ModuleDict(projectors)

    def _create_modality_post_projectors(self, cfg, output_dim):
        projectors = {}
        for modality, modality_cfg in cfg.items():
            projectors[modality] = create_projectors(tuple(modality_cfg.post_dims) + (output_dim,))
            if modality_cfg.freeze_post:
                for p in projectors[modality].parameters():
                    p.requires_grad = False
        return nn.ModuleDict(projectors)

    def _create_modality_qformers(self, cfg):
        modality_qformers = {}
        for modality, modality_cfg in cfg.items():
            modality_qformers[modality] = SequenceGenericQFormer(
                num_query_token=modality_cfg.num_query_token,
                freeze_qformer=modality_cfg.freeze_qformer,
                encoder_width=modality_cfg.feat_dim,
                q_former_model=modality_cfg.get("qformer_model", ""),
            )
        return nn.ModuleDict(modality_qformers)

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        outputs = {}
        for modality_key, modality_value in inputs.items():
            model_key = self.share_key if self.share_joiner else modality_key
            if modality_value is not None:
                if self.use_pre_ln:
                    modality_value = self.modality_pre_lns[modality_key](modality_value)
                modality_value = self.modality_pre_projectors[model_key](modality_value)
                modality_value = self.modality_qformers[model_key](modality_value)
                modality_value = self.modality_post_projectors[model_key](modality_value)
                outputs[modality_key] = modality_value
        return outputs
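
# Minimal usage sketch for ImageBindJoiner. This is illustrative only: the modality keys,
# feature dims, sequence lengths, and output_dim below are assumptions, not values taken
# from this repo's configs.
#
#     from omegaconf import OmegaConf
#
#     joiner_cfg = OmegaConf.create({
#         "vision": {"feat_dim": 1280},   # other fields fall back to the defaults set in __init__
#         "audio": {"feat_dim": 768},
#     })
#     joiner = ImageBindJoiner(joiner_cfg, output_dim=4096)
#     # inputs: per-modality token sequences of shape (batch, seq_len, feat_dim)
#     feats = {"vision": torch.randn(2, 257, 1280), "audio": torch.randn(2, 229, 768)}
#     joined = joiner(feats)  # dict of per-modality query features, roughly (batch, num_query_token, output_dim)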


class ImageBindModel(nn.Module):
    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
        with_head=True,
    ):
        super().__init__()

        self.with_head = with_head
        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        # text_preprocessor = TextPreprocessor(
        #     context_length=77,
        #     vocab_size=49408,
        #     embed_dim=text_embed_dim,
        #     causal_masking=True,
        # )

        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        # depth_stem = PatchEmbedGeneric(
        #     [
        #         nn.Conv2d(
        #             kernel_size=depth_kernel_size,
        #             in_channels=1,
        #             out_channels=depth_embed_dim,
        #             stride=depth_kernel_size,
        #             bias=False,
        #         ),
        #     ],
        #     norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        # )
        #
        # depth_preprocessor = RGBDTPreprocessor(
        #     img_size=[1, 224, 224],
        #     num_cls_tokens=1,
        #     pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
        #     rgbt_stem=None,
        #     depth_stem=depth_stem,
        # )
        #
        # thermal_stem = PatchEmbedGeneric(
        #     [
        #         nn.Conv2d(
        #             kernel_size=thermal_kernel_size,
        #             in_channels=1,
        #             out_channels=thermal_embed_dim,
        #             stride=thermal_kernel_size,
        #             bias=False,
        #         ),
        #     ],
        #     norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        # )
        # thermal_preprocessor = ThermalPreprocessor(
        #     img_size=[1, 224, 224],
        #     num_cls_tokens=1,
        #     pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
        #     thermal_stem=thermal_stem,
        # )
        #
        # imu_stem = PatchEmbedGeneric(
        #     [
        #         nn.Linear(
        #             in_features=48,
        #             out_features=imu_embed_dim,
        #             bias=False,
        #         ),
        #     ],
        #     norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        # )
        #
        # imu_preprocessor = IMUPreprocessor(
        #     img_size=[6, 2000],
        #     num_cls_tokens=1,
        #     kernel_size=8,
        #     embed_dim=imu_embed_dim,
        #     pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
        #     imu_stem=imu_stem,
        # )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            # ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            # ModalityType.DEPTH: depth_preprocessor,
            # ModalityType.THERMAL: thermal_preprocessor,
            # ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        # modality_trunks[ModalityType.TEXT] = instantiate_trunk(
        #     text_embed_dim,
        #     text_num_blocks,
        #     text_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=False,
        #     drop_path=0.0,
        # )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        # modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
        #     depth_embed_dim,
        #     depth_num_blocks,
        #     depth_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=True,
        #     drop_path=depth_drop_path,
        # )
        # modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
        #     thermal_embed_dim,
        #     thermal_num_blocks,
        #     thermal_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=True,
        #     drop_path=thermal_drop_path,
        # )
        # modality_trunks[ModalityType.IMU] = instantiate_trunk(
        #     imu_embed_dim,
        #     imu_num_blocks,
        #     imu_num_heads,
        #     pre_transformer_ln=False,
        #     add_bias_kv=True,
        #     drop_path=imu_drop_path,
        # )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
        use_selection=False,
    ):
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0) if use_selection else nn.Identity(),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        # modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
        #     proj=nn.Sequential(
        #         nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
        #         nn.Linear(text_embed_dim, out_embed_dim, bias=False),
        #     )
        # )

        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0) if use_selection else nn.Identity(),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )

        # modality_heads[ModalityType.DEPTH] = nn.Sequential(
        #     nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
        #     SelectElement(index=0) if use_selection else nn.Identity(),
        #     nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        # )
        #
        # modality_heads[ModalityType.THERMAL] = nn.Sequential(
        #     nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
        #     SelectElement(index=0) if use_selection else nn.Identity(),
        #     nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        # )
        #
        # modality_heads[ModalityType.IMU] = nn.Sequential(
        #     nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
        #     SelectElement(index=0) if use_selection else nn.Identity(),
        #     nn.Dropout(p=0.5),
        #     nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        # )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        # modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
        #     Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        # )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        # modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
        #     Normalize(dim=-1),
        #     LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        # )
        # modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
        #     Normalize(dim=-1),
        #     LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        # )
        # modality_postprocessors[ModalityType.IMU] = nn.Sequential(
        #     Normalize(dim=-1),
        #     LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        # )

        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        outputs = {}
        for modality_key, modality_value in inputs.items():
            reduce_list = (
                modality_value.ndim >= 5
            )  # Audio and Video inputs consist of multiple clips
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            if modality_value is not None:
                modality_value = self.modality_preprocessors[modality_key](
                    **{modality_key: modality_value}
                )
                trunk_inputs = modality_value["trunk"]
                head_inputs = modality_value["head"]
                modality_value = self.modality_trunks[modality_key](**trunk_inputs)
                # NOTE: No heads are needed any more.
                if self.with_head:
                    modality_value = self.modality_heads[modality_key](
                        modality_value, **head_inputs
                    )
                    modality_value = self.modality_postprocessors[modality_key](
                        modality_value
                    )
                # NOTE: The reduction operation has been modified.
                if reduce_list:
                    modality_value = modality_value.reshape(B, S, *modality_value.shape[1:])
                    modality_value = modality_value.mean(dim=1)
                outputs[modality_key] = modality_value
        return outputs
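
# Illustrative forward-pass sketch for ImageBindModel. The input shapes below are
# assumptions matching the default preprocessors (224x224 images padded to 2 frames,
# 128x204 mel spectrograms split into clips), not shapes dictated by this file.
#
#     model = ImageBindModel(with_head=False)
#     inputs = {
#         ModalityType.VISION: torch.randn(2, 3, 224, 224),    # (B, C, H, W); PadIm2Video repeats frames
#         ModalityType.AUDIO: torch.randn(2, 3, 1, 128, 204),  # (B, clips, 1, mel_bins, target_len); ndim >= 5 -> clips averaged
#     }
#     with torch.no_grad():
#         feats = model(inputs)  # dict keyed by modality; token sequences here since with_head=False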


def imagebind_huge(pretrained=False, freeze_imagebind=False, with_head=True, use_blip_vision=False):
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
        with_head=with_head,
    )

    if pretrained:
        if not os.path.exists(".checkpoints/imagebind_huge.pth"):
            print(
                "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
            )
            os.makedirs(".checkpoints", exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                ".checkpoints/imagebind_huge.pth",
                progress=True,
            )

        model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"), strict=False)

    if use_blip_vision:
        from bubogpt.models.eva_vit import create_eva_vit_g
        visual_encoder = create_eva_vit_g(
            img_size=224, drop_path_rate=0., use_checkpoint=False, precision='fp16'
        )
        vision_ln = LayerNorm(visual_encoder.num_features)
        vision_ln.load_state_dict(load_ln_params())
        model.modality_preprocessors[ModalityType.VISION] = BlipPreprocessor()
        model.modality_trunks[ModalityType.VISION] = visual_encoder
        model.modality_postprocessors[ModalityType.VISION] = vision_ln

    if freeze_imagebind:
        for name, param in model.named_parameters():
            param.requires_grad = False
        model = model.eval()
        model.train = disabled_train

    return model
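
# Typical instantiation sketch (a hypothetical call site; whether to keep the heads or swap in
# the EVA-ViT vision branch depends on the downstream training config, not on this function):
#
#     imagebind = imagebind_huge(pretrained=True, freeze_imagebind=True, with_head=False)
#     # pretrained=True expects/downloads .checkpoints/imagebind_huge.pth; use_blip_vision=True
#     # additionally expects checkpoints/blip2_pretrained_flant5xxl.pth for the ln_vision weights.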


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def load_ln_params(path="checkpoints/blip2_pretrained_flant5xxl.pth"):
    state_dict = torch.load(path, map_location="cpu")["model"]
    params = type(state_dict)()
    params["weight"] = state_dict["ln_vision.weight"]
    params["bias"] = state_dict["ln_vision.bias"]
    return params


def replace_joiner_vision(joiner, q_former_model, proj_model):
    assert isinstance(joiner.modality_pre_projectors.vision, nn.Identity)
    joiner.modality_qformers[ModalityType.VISION].load_Qformer(q_former_model)
    if proj_model:
        state_dict = torch.load(proj_model, map_location="cpu")["model"]
        params = type(state_dict)()
        params["fc.weight"] = state_dict["llama_proj.weight"]
        params["fc.bias"] = state_dict["llama_proj.bias"]
        joiner.modality_post_projectors[ModalityType.VISION].load_state_dict(params, strict=False)
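
# Sketch of loading pretrained vision weights into the joiner. The checkpoint paths are
# placeholders (a BLIP-2 Q-Former checkpoint and a checkpoint containing llama_proj.* keys),
# not files shipped with this repo.
#
#     joiner = ImageBindJoiner(joiner_cfg, output_dim=4096)
#     replace_joiner_vision(
#         joiner,
#         q_former_model="checkpoints/blip2_pretrained_flant5xxl.pth",
#         proj_model="checkpoints/pretrained_vision_proj.pth",
#     )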