import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.clip.modeling_clip import CLIPVisionModel

# Module-level flag read from the NORMALIZE_POOL environment variable
# (e.g. NORMALIZE_POOL=0 or NORMALIZE_POOL=1); defaults to True when unset.
if 'NORMALIZE_POOL' in os.environ:
    NORMALIZE_POOL = bool(int(os.environ['NORMALIZE_POOL']))
    print(f'NORMALIZE_POOL: {NORMALIZE_POOL}')
else:
    NORMALIZE_POOL = True

class PoolerProjector(nn.Module):
    """Projects vision-encoder tokens into the LLM embedding space while
    downsampling the token grid 2x2 with a strided convolution."""

    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        # Side length of the vision token grid (e.g. 336 // 14 = 24).
        self.hw = vision_cfg.image_size // vision_cfg.patch_size

        # 2x2 strided conv: pools neighboring tokens and maps the channel
        # dimension from the vision hidden size to the LLM hidden size.
        self.conv_pool = nn.Conv2d(
            config.mm_hidden_size, config.hidden_size,
            kernel_size=2, stride=2
        )

        self.proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        # x: (B, N, mm_hidden_size) with N == hw * hw vision tokens.
        height = width = self.hw
        assert height * width == x.shape[1]
        x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
        x = self.conv_pool(x)
        # Back to a token sequence: (B, (hw/2)^2, hidden_size).
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'pooler'}

class NormalizedDwPooler(nn.Module):
    """Depth-wise pooling with learned, softmax-normalized per-channel weights
    over each pooling window, instead of plain average pooling."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Predicts a per-position score from the concatenation of each token
        # and the mean of its pooling window.
        self.predictor = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x, forward_type='2x'):
        # x: (B, H, W, C); H and W must be divisible by the pooling factor.
        B, H, W, C = x.shape
        if forward_type == '2x':
            # Group tokens into 2x2 windows: (B, H/2, W/2, 4, C).
            new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
            pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
            fused_x = torch.cat([new_x, pooled_x], dim=-1)
        elif forward_type == '1x':
            # No pooling: each token forms its own window of size 1.
            new_x = x.reshape(B, H, W, 1, C)
            fused_x = torch.cat([new_x, new_x], dim=-1)
        elif forward_type == '4x':
            # Group tokens into 4x4 windows: (B, H/4, W/4, 16, C).
            new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
            pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
            fused_x = torch.cat([new_x, pooled_x], dim=-1)
        else:
            raise ValueError(f"Unsupported forward_type: {forward_type!r}")

        # Softmax over the window dimension yields normalized weights, then the
        # tokens are aggregated as a weighted sum within each window.
        score = self.predictor(fused_x)
        normalized_score = F.softmax(score, dim=-2)
        new_x = (new_x * normalized_score).sum(dim=-2)
        return new_x
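
# Illustrative usage sketch (not from the original file): exercises
# NormalizedDwPooler on a random (B, H, W, C) feature map and runs the
# PoolerProjector demo above. The sizes are assumptions for the example only.
if __name__ == "__main__":
    pooler = NormalizedDwPooler(dim=64)
    feat = torch.randn(2, 8, 8, 64)               # (B, H, W, C)
    print(pooler(feat, forward_type='1x').shape)  # torch.Size([2, 8, 8, 64])
    print(pooler(feat, forward_type='2x').shape)  # torch.Size([2, 4, 4, 64])
    print(pooler(feat, forward_type='4x').shape)  # torch.Size([2, 2, 2, 64])

    _demo_pooler_projector()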