import math
import re
from functools import partial

from torch import nn
from timm.layers.norm_act import LayerNormAct2d
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
from torchvision.ops.misc import SqueezeExcitation as SELayer

class IdentityMap(nn.Module):
    # Pass-through projector: returns the visual features unchanged.
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    def config(self):
        return {"mm_projector_type": "identity"}

class Minigpt(nn.Module):
    # Merges every 4 consecutive visual tokens by concatenating their features,
    # then projects to the language model's hidden size.
    def __init__(self, config=None):
        super().__init__()
        # 4 tokens are concatenated, so the linear layer maps inc * 4 -> ouc.
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.linear = nn.Linear(inc * 4, ouc)

    def forward(self, x):
        # x is the input tensor with shape [b, num_tokens, c]
        b, num_tokens, c = x.shape
        if num_tokens % 4 != 0:
            raise ValueError("num_tokens must be divisible by 4")
        # Reshape to [b, num_tokens // 4, c * 4]: each output token is the
        # concatenation of 4 consecutive input tokens.
        x = x.view(b, num_tokens // 4, c * 4)
        # Apply the linear projection
        x = self.linear(x)
        return x
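
# --- Illustrative shape check (not part of the original file) ----------------
# A minimal sketch of how Minigpt is used, assuming a config object exposing
# mm_hidden_size / hidden_size; the sizes and token count below are made up.
def _demo_minigpt():
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(mm_hidden_size=1024, hidden_size=4096)  # assumed sizes
    proj = Minigpt(cfg)
    x = torch.randn(2, 576, 1024)  # [batch, num_tokens, mm_hidden_size]
    y = proj(x)
    assert y.shape == (2, 144, 4096)  # 576 tokens merged 4-to-1 -> 144 tokens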

class Vanilla(nn.Module):
    # Like Minigpt, but interleaves the features of the 4 merged tokens
    # instead of concatenating the tokens whole.
    def __init__(self, config=None):
        super().__init__()
        # 4 tokens are merged, so the linear layer maps inc * 4 -> ouc.
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.linear = nn.Linear(inc * 4, ouc)

    def forward(self, x):
        b, num_tokens, c = x.shape
        if num_tokens % 4 != 0:
            raise ValueError("num_tokens must be divisible by 4")
        # First, reshape to [b, num_tokens // 4, 4, c]
        x = x.view(b, num_tokens // 4, 4, c)
        # Then, permute to interleave the tokens
        x = x.permute(0, 1, 3, 2).contiguous()
        # Finally, reshape to [b, num_tokens // 4, c * 4], interleaving the
        # features of the 4 tokens
        x = x.view(b, num_tokens // 4, c * 4)
        # Apply the linear projection
        x = self.linear(x)
        return x
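
# --- Illustrative contrast with Minigpt (not part of the original file) ------
# Both projectors map [b, n, c] -> [b, n/4, c*4], but Vanilla's permute orders
# the merged vector channel-first (t0c0, t1c0, t2c0, t3c0, t0c1, ...) rather
# than token-first. A tiny worked example with 4 tokens of 2 channels:
def _demo_vanilla_interleave():
    import torch

    x = torch.arange(8.0).view(1, 4, 2)  # t0=[0,1], t1=[2,3], t2=[4,5], t3=[6,7]
    merged = x.view(1, 1, 4, 2).permute(0, 1, 3, 2).reshape(1, 1, 8)
    # Token-first concatenation would give [0, 1, 2, 3, 4, 5, 6, 7];
    # the channel-first interleave gives:
    assert merged.flatten().tolist() == [0, 2, 4, 6, 1, 3, 5, 7]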

class LDPBlock(nn.Module):
    # Lightweight Downsample Projector block: an MLP followed by two
    # MobileNetV3-style inverted residuals, the second with stride 2.
    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        layer_norm = partial(LayerNormAct2d, act_layer=None)
        se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
        self.mlp = nn.Sequential(
            nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc)
        )
        self.mb_block = nn.Sequential(
            nn.Identity(),
            InvertedResidual(
                InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1),
                layer_norm,
                se_layer,
            ),
            InvertedResidual(
                InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1),
                layer_norm,
                se_layer,
            ),
        )

    def forward(self, x):
        b, num_tokens, c = x.shape
        # Tokens are assumed to form a square h x h grid.
        h = int(math.sqrt(num_tokens))
        x = self.mlp(x)
        # [b, n, c] -> [b, c, h, h] for the convolutional blocks
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = self.mb_block(x)
        # Back to token form: [b, c, h', h'] -> [b, h' * h', c]
        x = x.flatten(2).permute(0, 2, 1)
        return x
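
# --- Illustrative shape check (not part of the original file) ----------------
# Sketch of LDPBlock's downsampling, assuming 576 input tokens (a 24x24 grid)
# and made-up channel sizes: the stride-2 inverted residual halves each
# spatial side, so the token count drops 4x.
def _demo_ldp_block():
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(mm_hidden_size=1024, hidden_size=2048)  # assumed sizes
    block = LDPBlock(cfg)
    y = block(torch.randn(1, 576, 1024))
    assert y.shape == (1, 144, 2048)  # 24x24 grid -> 12x12 grid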

class LDPNetProjector(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        self.model = LDPBlock(config)

    def forward(self, x):
        return self.model(x)

class SPP(nn.Module):
    # Spatial pyramid pooling-style projector: 2x2 average pooling over the
    # token grid, with the linear layers placed differently per variant.
    def __init__(self, config=None, projector_type="v1"):
        super().__init__()
        self.projector_type = projector_type
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.linear_0 = nn.Linear(inc, inc)
        self.linear_1 = nn.Linear(inc, ouc)
        self.pooling = nn.AvgPool2d(kernel_size=2)
        self.linear_2 = nn.Linear(ouc, ouc)

    def forward(self, x):
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        if "v1" in self.projector_type:
            # project -> pool -> project
            x = self.linear_1(x)
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
            x = self.pooling(x)
            x = x.flatten(2).permute(0, 2, 1)
            x = self.linear_2(x)
        elif "v2" in self.projector_type:
            # project twice -> pool
            x = self.linear_1(x)
            x = self.linear_2(x)
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
            x = self.pooling(x)
            x = x.flatten(2).permute(0, 2, 1)
        elif "v3" in self.projector_type:
            # project at the input width -> pool -> project
            x = self.linear_0(x)
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
            x = self.pooling(x)
            x = x.flatten(2).permute(0, 2, 1)
            x = self.linear_1(x)
            x = self.linear_2(x)
        return x
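
# --- Illustrative shape check (not part of the original file) ----------------
# All three SPP variants quarter the token count via 2x2 average pooling; they
# differ only in where the linear projections sit relative to the pooling.
# Config sizes and the "spp_v*" type strings below are assumptions.
def _demo_spp():
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(mm_hidden_size=1024, hidden_size=2048)  # assumed sizes
    x = torch.randn(1, 576, 1024)
    for version in ("v1", "v2", "v3"):
        proj = SPP(cfg, projector_type=f"spp_{version}")
        assert proj(x).shape == (1, 144, 2048)  # 24x24 grid pooled to 12x12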

def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type.startswith("mlp"):
        mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
            return nn.Sequential(*modules)
    elif projector_type.startswith("spp"):
        return SPP(config, projector_type)
    elif projector_type == "ldp":
        return LDPNetProjector(config)
    elif projector_type == "vanilla":
        return Vanilla(config)
    elif projector_type == "minigpt":
        return Minigpt(config)
    elif projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")
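
# --- Usage sketch (not part of the original file) -----------------------------
# Building the default 2-layer GELU MLP projector; SimpleNamespace stands in
# for the real model config, and the sizes are assumptions.
def _demo_build_projector():
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        mm_hidden_size=1024, hidden_size=4096, mm_projector_type="mlp2x_gelu"
    )
    proj = build_vision_projector(cfg)
    y = proj(torch.randn(1, 576, 1024))
    assert y.shape == (1, 576, 4096)  # token count preserved, features projected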