BuboGPT / imagebind /models /image_bind.py
bingyikang
fix proj moel
72960a7
raw
history blame
24.3 kB
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from functools import partial
from types import SimpleNamespace
from typing import Union, Optional, Tuple, Dict, List
import torch
import torch.nn as nn
from torch import Tensor
from omegaconf import DictConfig
from imagebind.models.helper import (
EinOpsRearrange,
LearnableLogitScaling,
Normalize,
SelectElement,
SelectEOSAndProject,
)
from imagebind.models.multimodal_formers import SequenceGenericQFormer, disabled_train
from imagebind.models.multimodal_preprocessors import (
AudioPreprocessor,
IMUPreprocessor,
PadIm2Video,
PatchEmbedGeneric,
RGBDTPreprocessor,
SpatioTemporalPosEmbeddingHelper,
TextPreprocessor,
ThermalPreprocessor,
BlipPreprocessor,
)
from imagebind.models.multimodal_projectors import create_projectors
from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
# String identifiers for each modality. They are used as the keys of the
# per-modality ModuleDicts and of the input/output dicts in this file.
ModalityType = SimpleNamespace(
    **{
        modality_name.upper(): modality_name
        for modality_name in ("vision", "text", "audio", "thermal", "depth", "imu")
    }
)
class ImageBindJoiner(nn.Module):
    """Per-modality adapter mapping encoder features into an output space.

    For every modality entry in ``cfg`` the joiner builds, in order: an
    optional pre-LayerNorm, a pre-projector, a ``SequenceGenericQFormer`` and
    a post-projector whose last dimension is ``output_dim``.

    cfg:
        - share_key: Optional[str], when set, every input modality is routed
          through the modules built for this key
        - use_pre_ln: bool, defaults to False, apply a LayerNorm before the
          pre-projector
        - <modality_key>: DictConfig, the config for that modality:
            - feat_dim: int, defaults to 1024, the input dimension to the qformer
            - post_dims: tuple, defaults to (768,), layers for post-qformer projection
            - pre_dims: tuple, defaults to (), layers for pre-qformer projection
            - num_query_token: int, defaults to 32, the number of query tokens in qformer
            - freeze_qformer: bool, defaults to True, keep the qformer frozen or not
            - qformer_model: str, defaults to "", path to a qformer checkpoint, "" for not loading
            - freeze_post: bool, defaults to False, freeze the post-projector parameters
    """

    def __init__(self, cfg: "DictConfig", output_dim: int):
        super().__init__()
        # NOTE: ``cfg`` is edited in place: ``share_key`` / ``use_pre_ln`` are
        # popped and per-modality defaults are written back, so that every
        # remaining top-level entry is a modality config.
        # ``share_key`` is popped whenever present (even when null) so that a
        # ``share_key: null`` entry is not mistaken for a modality config.
        share_key = cfg.pop("share_key") if "share_key" in cfg else None
        self.share_key = share_key
        self.use_pre_ln = cfg.pop("use_pre_ln") if "use_pre_ln" in cfg else False
        self.share_joiner = isinstance(share_key, str)
        if self.share_joiner:
            assert share_key in cfg, "The modality key to share does not exist."

        for modality_cfg in cfg.values():
            modality_cfg.feat_dim = modality_cfg.get("feat_dim", 1024)
            modality_cfg.pre_dims = modality_cfg.get("pre_dims", ())
            modality_cfg.post_dims = modality_cfg.get("post_dims", (768,))
            modality_cfg.num_query_token = modality_cfg.get("num_query_token", 32)
            modality_cfg.freeze_qformer = modality_cfg.get("freeze_qformer", True)
            modality_cfg.qformer_model = modality_cfg.get("qformer_model", "")
            modality_cfg.freeze_post = modality_cfg.get("freeze_post", False)

        if self.use_pre_ln:
            self.modality_pre_lns = self._create_modality_pre_lns(cfg)
        self.modality_pre_projectors = self._create_modality_pre_projectors(cfg)
        self.modality_qformers = self._create_modality_qformers(cfg)
        self.modality_post_projectors = self._create_modality_post_projectors(cfg, output_dim)

    def _create_modality_pre_lns(self, cfg):
        """Build one LayerNorm of width ``feat_dim`` per modality.

        NOTE(review): the LayerNorm runs before the pre-projector, so this
        assumes the raw feature width equals ``feat_dim`` (true when
        ``pre_dims`` is empty) -- confirm for configs with ``pre_dims``.
        """
        return nn.ModuleDict(
            {
                modality: nn.LayerNorm(modality_cfg.feat_dim)
                for modality, modality_cfg in cfg.items()
            }
        )

    def _create_modality_pre_projectors(self, cfg):
        """Build the pre-qformer projector of each modality from ``pre_dims``."""
        projectors = {}
        for modality, modality_cfg in cfg.items():
            projectors[modality] = create_projectors(tuple(modality_cfg.pre_dims))
        return nn.ModuleDict(projectors)

    def _create_modality_post_projectors(self, cfg, output_dim):
        """Build post-qformer projectors ending in ``output_dim``; optionally frozen."""
        projectors = {}
        for modality, modality_cfg in cfg.items():
            projectors[modality] = create_projectors(tuple(modality_cfg.post_dims) + (output_dim,))
            if modality_cfg.freeze_post:
                for p in projectors[modality].parameters():
                    p.requires_grad = False
        return nn.ModuleDict(projectors)

    def _create_modality_qformers(self, cfg):
        """Build one sequence Q-Former per modality from its config."""
        modality_qformers = {}
        for modality, modality_cfg in cfg.items():
            modality_qformers[modality] = SequenceGenericQFormer(
                num_query_token=modality_cfg.num_query_token,
                freeze_qformer=modality_cfg.freeze_qformer,
                encoder_width=modality_cfg.feat_dim,
                q_former_model=modality_cfg.get("qformer_model", ""),
            )
        return nn.ModuleDict(modality_qformers)

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run each non-None input through its modality's joiner modules.

        ``None`` entries are skipped and produce no output key.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            if modality_value is None:
                continue
            # A shared joiner routes every modality through share_key's modules,
            # including the optional pre-LayerNorm.
            model_key = self.share_key if self.share_joiner else modality_key
            if self.use_pre_ln:
                modality_value = self.modality_pre_lns[model_key](modality_value)
            modality_value = self.modality_pre_projectors[model_key](modality_value)
            modality_value = self.modality_qformers[model_key](modality_value)
            modality_value = self.modality_post_projectors[model_key](modality_value)
            outputs[modality_key] = modality_value
        return outputs
class ImageBindModel(nn.Module):
    """ImageBind encoder with only the vision and audio modalities enabled.

    Each enabled modality runs through a preprocessor, a transformer trunk,
    an optional projection head (``with_head``) and a postprocessor. The
    text / depth / thermal / IMU branches of upstream ImageBind are disabled
    in this model; their hyper-parameters are still accepted so the
    constructor signature stays compatible.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
        with_head=True,
    ):
        super().__init__()
        self.with_head = with_head
        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )
        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )
        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )
        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Build the vision and audio tokenizers.

        Text/depth/thermal/IMU arguments are accepted but unused because those
        modalities are disabled in this model.
        """
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                # Images are repeated along time so they can share the video stem.
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )
        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
        }
        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build the vision and audio transformer trunks (other modalities disabled)."""

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The trunk works sequence-first: tokens are rearranged to
            # (l, b, d) before the blocks and back to (b, l, d) afterwards.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
        use_selection=False,
    ):
        """Build LayerNorm + linear projection heads for vision and audio.

        With ``use_selection=False`` (the default) no token is selected, so the
        heads project every token rather than only the first (CLS) one.
        """
        modality_heads = {}
        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0) if use_selection else nn.Identity(),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )
        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0) if use_selection else nn.Identity(),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )
        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Build the output postprocessors: L2 normalisation (+ fixed x20 scale for audio)."""
        modality_postprocessors = {}
        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Encode every non-None entry of ``inputs``.

        Inputs with 5 or more dimensions are treated as ``(B, S, ...)`` batches
        of S clips. The clips are folded into the batch for encoding, and the
        per-clip outputs are averaged over S afterwards. ``None`` entries are
        skipped and produce no output key.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Check for None first: the clip-count test below reads ``.ndim``.
            if modality_value is None:
                continue
            # Audio and video inputs consist of multiple clips.
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value = self.modality_trunks[modality_key](**trunk_inputs)

            # NOTE: No heads are needed any more.
            if self.with_head:
                modality_value = self.modality_heads[modality_key](
                    modality_value, **head_inputs
                )
            # Postprocessors run with or without heads; the BLIP vision path
            # installs its LayerNorm here.
            modality_value = self.modality_postprocessors[modality_key](
                modality_value
            )

            # NOTE: The reduction operation has been modified.
            if reduce_list:
                modality_value = modality_value.reshape(B, S, *modality_value.shape[1:])
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value
        return outputs
def imagebind_huge(pretrained=False, freeze_imagebind=False, with_head=True, use_blip_vision=False):
    """Build the ImageBind-Huge encoder (vision + audio).

    Args:
        pretrained: load the official ImageBind-Huge weights, downloading them
            to ``.checkpoints/imagebind_huge.pth`` on first use.
        freeze_imagebind: disable gradients for every parameter, switch to eval
            mode and replace ``model.train`` with ``disabled_train``.
        with_head: forwarded to ``ImageBindModel``.
        use_blip_vision: replace the vision preprocessor/trunk/postprocessor
            with ``BlipPreprocessor``, an EVA ViT-g encoder and a LayerNorm
            loaded from a BLIP-2 checkpoint (see ``load_ln_params``).

    Returns:
        The constructed ``ImageBindModel``.
    """
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
        with_head=with_head,
    )

    if pretrained:
        ckpt_dir = ".checkpoints"
        ckpt_path = ".checkpoints/imagebind_huge.pth"
        if not os.path.exists(ckpt_path):
            print(f"Downloading imagebind weights to {ckpt_path} ...")
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                ckpt_path,
                progress=True,
            )
        # Load on CPU so a GPU-saved checkpoint also works on CPU-only hosts;
        # load_state_dict copies the tensors into the model's own parameters.
        # strict=False: this model omits the text/depth/thermal/IMU branches,
        # whose weights the upstream checkpoint presumably still contains.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

    if use_blip_vision:
        from bubogpt.models.eva_vit import create_eva_vit_g

        visual_encoder = create_eva_vit_g(
            img_size=224, drop_path_rate=0., use_checkpoint=False, precision='fp16'
        )
        vision_ln = LayerNorm(visual_encoder.num_features)
        vision_ln.load_state_dict(load_ln_params())
        model.modality_preprocessors[ModalityType.VISION] = BlipPreprocessor()
        model.modality_trunks[ModalityType.VISION] = visual_encoder
        model.modality_postprocessors[ModalityType.VISION] = vision_ln

    if freeze_imagebind:
        for param in model.parameters():
            param.requires_grad = False
        model = model.eval()
        # Replace train() (presumably with a no-op) so later .train() calls do
        # not switch the frozen encoder back to training mode -- confirm.
        model.train = disabled_train
    return model
class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32.

    The input is upcast to float32 before normalization, and the result is
    cast back to the input's original dtype. This lets fp16 activations pass
    through a float32 LayerNorm.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
def load_ln_params(path="checkpoints/blip2_pretrained_flant5xxl.pth"):
    """Extract the ``ln_vision`` LayerNorm weights from a BLIP-2 checkpoint.

    The checkpoint is loaded on CPU and must contain a ``"model"`` state dict.
    The result has keys ``weight`` and ``bias`` and is the same mapping type
    as that state dict, ready for ``LayerNorm.load_state_dict``.
    """
    checkpoint = torch.load(path, map_location="cpu")
    model_weights = checkpoint["model"]
    ln_params = type(model_weights)()
    for target_key, source_key in (("weight", "ln_vision.weight"), ("bias", "ln_vision.bias")):
        ln_params[target_key] = model_weights[source_key]
    return ln_params
def replace_joiner_vision(joiner, q_former_model, proj_model):
    """Load pretrained vision Q-Former and projector weights into a joiner.

    Args:
        joiner: an ``ImageBindJoiner`` whose vision pre-projector must be an
            ``nn.Identity`` (i.e. no pre-qformer projection).
        q_former_model: checkpoint passed to the vision Q-Former's
            ``load_Qformer``.
        proj_model: optional checkpoint path whose ``"model"`` dict contains
            ``llama_proj.weight``/``llama_proj.bias``; these are loaded as
            ``fc.weight``/``fc.bias`` of the vision post-projector. A falsy
            value skips the projector.

    Raises:
        RuntimeError: if the projector weights could not be mapped onto the
            vision post-projector (it has no ``fc.weight``/``fc.bias``). With
            ``strict=False`` those weights would otherwise be dropped silently.
    """
    vision = ModalityType.VISION
    assert isinstance(joiner.modality_pre_projectors[vision], nn.Identity), (
        "The vision pre-projector must be nn.Identity."
    )
    joiner.modality_qformers[vision].load_Qformer(q_former_model)
    if proj_model:
        state_dict = torch.load(proj_model, map_location="cpu")["model"]
        params = type(state_dict)()
        params["fc.weight"] = state_dict["llama_proj.weight"]
        params["fc.bias"] = state_dict["llama_proj.bias"]
        # strict=False tolerates projector parameters we do not provide
        # (missing keys), but keys we do provide must all be consumed.
        result = joiner.modality_post_projectors[vision].load_state_dict(params, strict=False)
        if result.unexpected_keys:
            raise RuntimeError(
                "Vision post-projector has no parameters matching "
                f"{result.unexpected_keys}; projector weights were not loaded."
            )