# Page header captured together with this file (not part of the program):
# Spaces: Runtime error | File size: 9,841 Bytes | 2aac0e2
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath
class VLFuse(torch.nn.Module):
    """Early fusion module for visual and language features.

    Holds one ``BiAttentionBlockForCheckpoint`` (``self.b_attn``) and runs the
    visual features and the language hidden states through it, so that each
    stream is updated from the other (text -> image and image -> text).

    Args:
        img_dim: Feature size of the visual tokens. ``None`` keeps the value
            set by ``init_configs`` (256).
        lang_dim: Feature size of the language tokens. ``None`` keeps the
            value set by ``init_configs`` (256).
        embed_dim: Internal attention width. ``None`` keeps the value set by
            ``init_configs`` (2048).
        n_head: Number of attention heads. ``None`` keeps the value set by
            ``init_configs`` (8).
    """

    def __init__(self, *, img_dim=None, lang_dim=None, embed_dim=None, n_head=None):
        super(VLFuse, self).__init__()
        # Defaults first (subclasses may override init_configs), then any
        # explicit per-instance overrides.
        self.init_configs()
        overrides = {
            "img_dim": img_dim,
            "lang_dim": lang_dim,
            "embed_dim": embed_dim,
            "n_head": n_head,
        }
        for attr_name, attr_value in overrides.items():
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

        # early fusion module
        # bi-direction (text->image, image->text)
        self.b_attn = BiAttentionBlockForCheckpoint(
            v_dim=self.img_dim,
            l_dim=self.lang_dim,
            embed_dim=self.embed_dim,
            num_heads=self.n_head,
            dropout=0.1,
            drop_path=.0,
            init_values=1.0 / 6,
        )

    def init_configs(self, ):
        """Set the default hyper-parameters as instance attributes."""
        # common params
        self.img_dim = 256
        self.max_query_len = 256  # stored only; not read anywhere in this class
        self.n_layers = 1  # stored only; not read anywhere in this class
        # mha params
        self.n_head = 8
        self.embed_dim = 2048  # 2048 by default
        self.lang_dim = 256

    def forward(self, x, task=None):
        """Fuse the visual and language features found in ``x``.

        Args:
            x: Mapping with a ``"visual"`` entry (visual features) and a
                ``"lang"`` entry, itself a mapping providing ``'hidden'``
                (language features) and ``'masks'`` (passed to the attention
                block as the language attention mask).
            task: Forwarded unchanged to ``self.b_attn``.

        Returns:
            dict: ``{"visual": fused_visual_features, "lang": lang_mapping}``
            where ``lang_mapping`` is the very same object as ``x["lang"]``.

        Note:
            ``x["lang"]['hidden']`` is overwritten in place with the fused
            language features; callers that still need the pre-fusion hidden
            states must keep their own reference before calling.
        """
        visual_features = x["visual"]
        language_dict_features = x["lang"]

        fused_visual_features, language_features = self.b_attn(
            visual_features,
            language_dict_features['hidden'],
            language_dict_features['masks'],
            task,
        )

        # In-place update kept on purpose for backward compatibility.
        language_dict_features['hidden'] = language_features
        return {"visual": fused_visual_features, "lang": language_dict_features}
class BiMultiHeadAttention(nn.Module):
    """Bidirectional multi-head cross-attention between two sequences.

    One similarity matrix between the ``v`` tokens (queries) and the ``l``
    tokens (keys) is computed per head and used in both directions:

    * softmax over the ``l`` axis weights the ``l`` values -> update for ``v``
    * softmax over the ``v`` axis (of the transposed matrix) weights the
      ``v`` values -> update for ``l``

    Args:
        v_dim: Feature size of the ``v`` sequence.
        l_dim: Feature size of the ``l`` sequence.
        embed_dim: Total attention width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to both attention probability
            tensors (active only while ``self.training`` is true).
    """

    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1):
        super(BiMultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
        self.scale = self.head_dim ** (-0.5)
        self.dropout = dropout

        # Query/key projections, value projections and output projections,
        # one of each per stream.
        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        # Numerical-stability switches read by forward().
        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dim into heads: (bsz, seq, embed) -> (bsz, heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        """Xavier-uniform weights and zero biases for all six projections."""
        # Order matters: it fixes how the RNG is consumed during init.
        projections = (
            self.v_proj,
            self.l_proj,
            self.values_v_proj,
            self.values_l_proj,
            self.out_v_proj,
            self.out_l_proj,
        )
        for proj in projections:
            nn.init.xavier_uniform_(proj.weight)
            proj.bias.data.fill_(0)

    def _clamp_scores(self, scores):
        """Clamp raw scores to [-50000, 50000] as enabled by the clamp flags."""
        # Do not widen the +-50000 range, data type half has quite limited range.
        if self.clamp_min_for_underflow:
            scores = torch.clamp(scores, min=-50000)
        if self.clamp_max_for_overflow:
            scores = torch.clamp(scores, max=50000)
        return scores

    def forward(self, v, l, attention_mask_l=None):
        """Compute the two cross-attention updates.

        Args:
            v: Tensor of size (bsz, tgt_len, v_dim).
            l: Tensor whose last dim is ``l_dim`` and whose batch size matches
                ``v``; its sequence length is referred to as ``src_len``.
            attention_mask_l: Optional 2-D mask over the ``l`` tokens,
                broadcastable to (bsz, src_len). Entries equal to 0 (or
                ``False``) are excluded from the ``v`` update. The mask is
                not applied to the ``l`` update.

        Returns:
            tuple: ``(attn_output_v, attn_output_l)`` with sizes
            (bsz, tgt_len, v_dim) and (bsz, src_len, l_dim).

        Raises:
            ValueError: If an intermediate tensor has an unexpected size.
        """
        bsz, tgt_len, _ = v.size()

        query_states = self.v_proj(v) * self.scale
        key_states = self._shape(self.l_proj(l), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(l), -1, bsz)

        # Fold heads into the batch axis so torch.bmm can be used.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        # (bsz * heads, tgt_len, src_len)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if self.stable_softmax_2d:
            attn_weights = attn_weights - attn_weights.max()
        attn_weights = self._clamp_scores(attn_weights)

        # l -> v direction: softmax over the v axis of the transposed scores,
        # shifted by the row maximum before exponentiation.
        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
        attn_weights_l = self._clamp_scores(attn_weights_l)
        attn_weights_l = attn_weights_l.softmax(dim=-1)

        if attention_mask_l is not None:
            assert (attention_mask_l.dim() == 2)  # (bs, seq_len)
            if attention_mask_l.dtype == torch.bool:
                # masked_fill on a bool tensor would cast -9e15 to True and
                # silently turn the mask into a no-op; use integers instead.
                attention_mask_l = attention_mask_l.to(torch.int64)
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)  # (bs, 1, 1, seq_len)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
                )
            # Non-zero mask entries are added as-is; masked entries get -9e15.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)

        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output_v` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
            )

        if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
            raise ValueError(
                f"`attn_output_l` should be of size {(bsz * self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
            )

        # Merge heads back: (bsz * heads, seq, head_dim) -> (bsz, seq, embed_dim)
        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l
class BiAttentionBlockForCheckpoint(nn.Module):
    """Pre-LayerNorm bidirectional attention block with scaled residuals.

    Both inputs are layer-normalised, passed through a
    ``BiMultiHeadAttention`` module, and each attention output is multiplied
    by a learnable per-channel scale (``gamma_v`` / ``gamma_l``) before being
    added back.
    """

    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
                 drop_path=.0, init_values=1e-4, ):
        """
        Inputs:
            v_dim - Feature size of the visual input (LayerNorm and gamma_v size)
            l_dim - Feature size of the language input (LayerNorm and gamma_l size)
            embed_dim - Attention width, forwarded to BiMultiHeadAttention
            num_heads - Number of heads, forwarded to BiMultiHeadAttention
            dropout - Dropout probability, forwarded to BiMultiHeadAttention
            drop_path - DropPath rate for the residual branches; values <= 0
                        select nn.Identity instead of DropPath
            init_values - Initial value of every entry of gamma_v and gamma_l
        """
        super(BiAttentionBlockForCheckpoint, self).__init__()

        # pre layer norm
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
        )

        # add layer scale for training stability
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.gamma_v = nn.Parameter(init_values * torch.ones(v_dim), requires_grad=True)
        self.gamma_l = nn.Parameter(init_values * torch.ones(l_dim), requires_grad=True)

    def forward(self, v, l, attention_mask_l=None, task=None):
        """Apply one round of bidirectional attention.

        Args:
            v: Visual features; last dim must be ``v_dim``.
            l: Language features; last dim must be ``l_dim``.
            attention_mask_l: Optional mask over the language tokens,
                forwarded to the attention module.
            task: Accepted for call compatibility; not used by this block.

        Returns:
            tuple: ``(v, l)`` updated features with the same sizes as the
            inputs.
        """
        # Note that the residual below is taken from the *normalised*
        # tensors, because v and l are rebound here.
        v = self.layer_norm_v(v)
        l = self.layer_norm_l(l)
        delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)
        # Layer-scaled residual update, with optional stochastic depth.
        v = v + self.drop_path(self.gamma_v * delta_v)
        l = l + self.drop_path(self.gamma_l * delta_l)
        return v, l
|