import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath


class VLFuse(nn.Module):
"""
Early Fusion Module
"""
def __init__(self, ):
super(VLFuse, self).__init__()
self.init_configs()
# early fusion module
# bi-direction (text->image, image->text)
        self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.img_dim,  # 256
                                                    l_dim=self.lang_dim,  # 256
                                                    embed_dim=self.embed_dim,  # 2048
                                                    num_heads=self.n_head,  # 8
                                                    dropout=0.1,
                                                    drop_path=0.0,
                                                    init_values=1.0 / 6,
                                                    )
def init_configs(self, ):
# common params
self.img_dim = 256
self.max_query_len = 256
        self.n_layers = 1
# mha params
self.n_head = 8
self.embed_dim = 2048 # 2048 by default
self.lang_dim = 256
def forward(self, x, task=None):
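        # x: dict with "visual" tokens of shape (bs, sum(H*W), img_dim) and a "lang" dict
        # holding 'hidden' (bs, seq_len, lang_dim) text features and 'masks' (bs, seq_len),
        # where 1 marks a valid token and 0 marks padding.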
visual_features = x["visual"]
language_dict_features = x["lang"]
fused_visual_features, language_features = self.b_attn(
visual_features, language_dict_features['hidden'], language_dict_features['masks'], task)
language_dict_features['hidden'] = language_features
fused_language_dict_features = language_dict_features
features_dict = {"visual": fused_visual_features,
"lang": fused_language_dict_features}
return features_dict


class BiMultiHeadAttention(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1):
super(BiMultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.v_dim = v_dim
self.l_dim = l_dim
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
self.scale = self.head_dim ** (-0.5)
self.dropout = dropout
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
        # numerical-stability options for fp16 / mixed-precision training
        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True
self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight)
self.v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.l_proj.weight)
self.l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_v_proj.weight)
self.values_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_l_proj.weight)
self.values_l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_v_proj.weight)
self.out_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_l_proj.weight)
self.out_l_proj.bias.data.fill_(0)
def forward(self, v, l, attention_mask_l=None):
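        # v: (bs, tgt_len, v_dim) visual tokens
        # l: (bs, src_len, l_dim) language tokens
        # attention_mask_l: (bs, src_len) language padding mask, 1 = valid token, 0 = padding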
        bsz, tgt_len, _ = v.size()  # last dim is v_dim, not embed_dim
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim) # (bs * 8, -1, embed_dim//8)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) # (bs * 8, seq_len_img, embed_dim//8)
key_states = key_states.view(*proj_shape) # (bs * 8, seq_len_text, embed_dim//8)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # (bs * 8, seq_len_img, seq_len_text)
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
# attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1)
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range
attn_weights_l = attn_weights_l.softmax(dim=-1)
# assert attention_mask_l.dtype == torch.int64
if attention_mask_l is not None:
assert (attention_mask_l.dim() == 2) # (bs, seq_len)
attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) # (bs, 1, 1, seq_len)
attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
)
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
raise ValueError(
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
)
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l


class BiAttentionBlockForCheckpoint(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
                 drop_path=0.0, init_values=1e-4):
"""
Inputs:
embed_dim - Dimensionality of input and attention feature vectors
num_heads - Number of heads to use in the Multi-Head Attention block
dropout - Amount of dropout to apply in the feed-forward network
"""
super(BiAttentionBlockForCheckpoint, self).__init__()
# pre layer norm
self.layer_norm_v = nn.LayerNorm(v_dim)
self.layer_norm_l = nn.LayerNorm(l_dim)
self.attn = BiMultiHeadAttention(v_dim=v_dim,
l_dim=l_dim,
embed_dim=embed_dim,
num_heads=num_heads,
dropout=dropout,
)
# add layer scale for training stability
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
def forward(self, v, l, attention_mask_l=None, task=None):
        # v: visual features, (bs, sigma(HW), v_dim)
        # l: language features, (bs, seq_len, l_dim)
v = self.layer_norm_v(v)
l = self.layer_norm_l(l)
delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l)
# v, l = v + delta_v, l + delta_l
v = v + self.drop_path(self.gamma_v * delta_v)
l = l + self.drop_path(self.gamma_l * delta_l)
return v, l
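

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module):
    # the tensor shapes below assume the defaults in VLFuse.init_configs
    # (img_dim = lang_dim = 256); the token counts are arbitrary.
    bs, num_visual_tokens, seq_len = 2, 100, 20
    fuser = VLFuse()
    dummy_inputs = {
        "visual": torch.randn(bs, num_visual_tokens, 256),
        "lang": {
            "hidden": torch.randn(bs, seq_len, 256),
            "masks": torch.ones(bs, seq_len, dtype=torch.long),  # 1 = valid token
        },
    }
    fused = fuser(dummy_inputs)
    print(fused["visual"].shape)          # torch.Size([2, 100, 256])
    print(fused["lang"]["hidden"].shape)  # torch.Size([2, 20, 256])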