File size: 2,407 Bytes
1938217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from transformers.models.clip.modeling_clip import CLIPVisionModel
import os

if 'NORMALIZE_POOL' in os.environ:
    NORMALIZE_POOL = bool(int(os.environ['NORMALIZE_POOL']))
    print(f'NORMALIZE_POOL: {NORMALIZE_POOL}')
else:
    NORMALIZE_POOL = True


class PoolerProjector(nn.Module):
    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        self.hw = vision_cfg.image_size // vision_cfg.patch_size

        self.conv_pool = nn.Conv2d(
            config.mm_hidden_size, config.hidden_size,
            kernel_size=2, stride=2
        )

        self.proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        height = width = self.hw
        assert height * width == x.shape[1]
        x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
        x = self.conv_pool(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'pooler'}


class NormalizedDwPooler(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.predictor = nn.Sequential(
            nn.Linear(dim*2, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
    
    def forward(self, x, forward_type='2x'):
        B, H, W, C = x.shape

        if forward_type == '2x':
            new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
            pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
            fused_x = torch.cat([new_x, pooled_x], dim=-1)
        elif forward_type == '1x':
            new_x = x.reshape(B, H, W, 1, C)
            fused_x = torch.cat([new_x, new_x], dim=-1)
        elif forward_type == '4x':
            new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
            pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
            fused_x = torch.cat([new_x, pooled_x], dim=-1)
        
        score = self.predictor(fused_x)
        normalized_score = F.softmax(score, dim=-2)
        new_x = (new_x * normalized_score).sum(dim=-2)
        return new_x