import math
import re
from functools import partial

from torch import nn
from timm.layers.norm_act import LayerNormAct2d
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
from torchvision.ops.misc import SqueezeExcitation as SELayer


class IdentityMap(nn.Module):
    # Pass-through projector: returns the visual features unchanged.
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class Minigpt(nn.Module):
    # Merge every 4 consecutive tokens by concatenation, then project.
    def __init__(self, config=None):
        super().__init__()
        # Input size is mm_hidden_size * 4; output size is hidden_size.
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.linear = nn.Linear(inc * 4, ouc)

    def forward(self, x):
        # x is the input tensor with shape [b, num_tokens, c]
        b, num_tokens, c = x.shape
        # The 4-token merge requires num_tokens to be divisible by 4
        if num_tokens % 4 != 0:
            raise ValueError("num_tokens must be divisible by 4")
        # Reshape to [b, num_tokens // 4, c * 4]: each group of 4 consecutive
        # tokens is concatenated along the channel dimension
        x = x.view(b, num_tokens // 4, c * 4)
        # Apply the linear transformation
        x = self.linear(x)
        return x


class Vanilla(nn.Module):
    # Like Minigpt, but interleaves the 4 merged tokens feature-by-feature
    # before the projection (see the worked example after this class).
    def __init__(self, config=None):
        super().__init__()
        # Input size is mm_hidden_size * 4; output size is hidden_size.
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.linear = nn.Linear(inc * 4, ouc)

    def forward(self, x):
        b, num_tokens, c = x.shape
        # The 4-token merge requires num_tokens to be divisible by 4
        if num_tokens % 4 != 0:
            raise ValueError("num_tokens must be divisible by 4")
        # First, reshape to [b, num_tokens // 4, 4, c]
        x = x.view(b, num_tokens // 4, 4, c)
        # Then, permute to [b, num_tokens // 4, c, 4] so the 4 token values of
        # each feature become adjacent
        x = x.permute(0, 1, 3, 2).contiguous()
        # Finally, flatten to [b, num_tokens // 4, c * 4]
        x = x.view(b, num_tokens // 4, c * 4)
        # Apply the linear transformation
        x = self.linear(x)
        return x
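

# Worked example (hypothetical values, for illustration): given tokens
# t0..t3, each with channel dim c, Minigpt merges them by plain concatenation,
#     [t0[0..c-1] | t1[0..c-1] | t2[0..c-1] | t3[0..c-1]],
# while Vanilla's permute interleaves the same numbers per feature,
#     [t0[0], t1[0], t2[0], t3[0], t0[1], t1[1], ...],
# so the two classes feed the shared Linear(inc * 4, ouc) in different orders.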


class LDPBlock(nn.Module):
    # Lightweight Downsample Projector Block: a token-wise MLP followed by two
    # MobileNetV3-style inverted-residual blocks, the second with stride 2, so
    # the token grid is halved per side (a 4x reduction in token count).
    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        layer_norm = partial(LayerNormAct2d, act_layer=None)
        se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
        # The leading nn.Identity() modules are no-op placeholders.
        self.mlp = nn.Sequential(nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc))
        self.mb_block = nn.Sequential(
            nn.Identity(),
            InvertedResidual(
                InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer
            ),
            InvertedResidual(
                InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer
            ),
        )

    def forward(self, x):
        b, num_tokens, c = x.shape
        # The tokens are assumed to form a square h x h grid
        h = int(math.sqrt(num_tokens))
        x = self.mlp(x)
        # [b, n, c] -> [b, c, h, h] for the convolutional blocks
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = self.mb_block(x)
        # [b, c, h/2, h/2] -> [b, n/4, c]
        x = x.flatten(2).permute(0, 2, 1)
        return x


class LDPNetProjector(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        self.model = LDPBlock(config)

    def forward(self, x):
        return self.model(x)


class SPP(nn.Module):
    # All SPP variants apply a 2x2 average pool over the token grid (a 4x
    # token reduction); they differ only in where the linear projections sit
    # relative to the pooling.
    def __init__(self, config=None, projector_type="v1"):
        super().__init__()
        self.projector_type = projector_type
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.linear_0 = nn.Linear(inc, inc)
        self.linear_1 = nn.Linear(inc, ouc)
        self.pooling = nn.AvgPool2d(kernel_size=2)
        self.linear_2 = nn.Linear(ouc, ouc)

    def forward(self, x):
        b, num_tokens, c = x.shape
        # The tokens are assumed to form a square h x h grid
        h = int(math.sqrt(num_tokens))
        if "v1" in self.projector_type:
            # project -> pool -> project
            x = self.linear_1(x)
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
            x = self.pooling(x)
            x = x.flatten(2).permute(0, 2, 1)
            x = self.linear_2(x)
        elif "v2" in self.projector_type:
            # project twice -> pool
            x = self.linear_1(x)
            x = self.linear_2(x)
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
            x = self.pooling(x)
            x = x.flatten(2).permute(0, 2, 1)
        elif "v3" in self.projector_type:
            # pre-project at input width -> pool -> project twice
            x = self.linear_0(x)
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
            x = self.pooling(x)
            x = x.flatten(2).permute(0, 2, 1)
            x = self.linear_1(x)
            x = self.linear_2(x)
        else:
            # Guard against silently returning unprojected features
            raise ValueError(f"Unknown SPP projector type: {self.projector_type}")
        return x


def build_vision_projector(config, delay_load=False, **kwargs):
    # Factory: selects a projector from config.mm_projector_type.
    # delay_load and **kwargs are accepted for interface compatibility but unused here.
    projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type.startswith("mlp"):
        # e.g. "mlp2x_gelu" -> 2 linear layers with a GELU between them
        mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
            return nn.Sequential(*modules)
    elif projector_type.startswith("spp"):
        return SPP(config, projector_type)
    elif projector_type == "ldp":
        return LDPNetProjector(config)
    elif projector_type == "vanilla":
        return Vanilla(config)
    elif projector_type == "minigpt":
        return Minigpt(config)
    elif projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")
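

# Usage sketch (hypothetical values, not part of the original app; left
# commented out so importing this module stays side-effect free). The config
# fields below match what the projectors read; 576 tokens = a 24 x 24 grid,
# which is both a perfect square and divisible by 4:
#
#   import torch
#   from types import SimpleNamespace
#
#   cfg = SimpleNamespace(mm_hidden_size=1024, hidden_size=2048,
#                         mm_projector_type="mlp2x_gelu")
#   feats = torch.randn(2, 576, cfg.mm_hidden_size)
#
#   build_vision_projector(cfg)(feats).shape  # [2, 576, 2048]
#
#   # Token-reducing variants map 576 tokens down to 144 (= 576 / 4):
#   SPP(cfg, "v1")(feats).shape               # [2, 144, 2048]
#   LDPNetProjector(cfg)(feats).shape         # [2, 144, 2048]
#   Vanilla(cfg)(feats).shape                 # [2, 144, 2048]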