import torch
import torch.nn.functional as F
from torch import nn

from networks.layers.basic import DropPath, GroupNorm1D, GNActDWConv2d, seq_to_2d, ScaleOffset, mask_out
from networks.layers.attention import silu, MultiheadAttention, MultiheadLocalAttentionV2, MultiheadLocalAttentionV3, GatedPropagation, LocalGatedPropagation


def _get_norm(indim, type='ln', groups=8):
    """Return a sequence normalization layer: GroupNorm1D ('gn') or LayerNorm ('ln')."""
    if type == 'gn':
        return GroupNorm1D(indim, groups)
    else:
        return nn.LayerNorm(indim)


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(
        f"activation should be relu/gelu/glu, not {activation}.")


class LongShortTermTransformer(nn.Module):
    """A stack of Long Short-Term Transformer blocks with optional intermediate
    and final normalization of the decoded features."""
    def __init__(self,
                 num_layers=2,
                 d_model=256,
                 self_nhead=8,
                 att_nhead=8,
                 dim_feedforward=1024,
                 emb_dropout=0.,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 droppath_scaling=False,
                 activation="gelu",
                 return_intermediate=False,
                 intermediate_norm=True,
                 final_norm=True,
                 block_version="v1"):
        super().__init__()
        self.intermediate_norm = intermediate_norm
        self.final_norm = final_norm
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

        self.emb_dropout = nn.Dropout(emb_dropout, True)
        self.mask_token = nn.Parameter(torch.randn([1, 1, d_model]))

        if block_version == "v1":
            block = LongShortTermTransformerBlock
        elif block_version == "v2":
            block = LongShortTermTransformerBlockV2
        elif block_version == "v3":
            block = LongShortTermTransformerBlockV3
        else:
            raise NotImplementedError

        layers = []
        for idx in range(num_layers):
            if droppath_scaling:
                if num_layers == 1:
                    droppath_rate = 0
                else:
                    droppath_rate = droppath * idx / (num_layers - 1)
            else:
                droppath_rate = droppath
            layers.append(
                block(d_model, self_nhead, att_nhead, dim_feedforward,
                      droppath_rate, lt_dropout, st_dropout, droppath_lst,
                      activation))
        self.layers = nn.ModuleList(layers)

        num_norms = num_layers - 1 if intermediate_norm else 0
        if final_norm:
            num_norms += 1
        self.decoder_norms = [
            _get_norm(d_model, type='ln') for _ in range(num_norms)
        ] if num_norms > 0 else None

        if self.decoder_norms is not None:
            self.decoder_norms = nn.ModuleList(self.decoder_norms)

    def forward(self,
                tgt,
                long_term_memories,
                short_term_memories,
                curr_id_emb=None,
                self_pos=None,
                size_2d=None):

        output = self.emb_dropout(tgt)

        intermediate = []
        intermediate_memories = []

        for idx, layer in enumerate(self.layers):
            output, memories = layer(output,
                                     long_term_memories[idx] if
                                     long_term_memories is not None else None,
                                     short_term_memories[idx] if
                                     short_term_memories is not None else None,
                                     curr_id_emb=curr_id_emb,
                                     self_pos=self_pos,
                                     size_2d=size_2d)

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_memories.append(memories)

        if self.decoder_norms is not None:
            if self.final_norm:
                output = self.decoder_norms[-1](output)

                if self.return_intermediate:
                    intermediate.pop()
                    intermediate.append(output)

            if self.intermediate_norm:
                for idx in range(len(intermediate) - 1):
                    intermediate[idx] = self.decoder_norms[idx](
                        intermediate[idx])

        if self.return_intermediate:
            return intermediate, intermediate_memories

        return output, memories
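
# Minimal usage sketch of the class above (illustrative only; it assumes the
# (HW, B, C) sequence layout implied by seq_to_2d and DropPath(batch_dim=1)):
#
#   lstt = LongShortTermTransformer(num_layers=2, d_model=256,
#                                   return_intermediate=True, block_version="v1")
#   h, w, bs = 30, 30, 1
#   tgt = torch.randn(h * w, bs, 256)      # flattened frame features
#   id_emb = torch.randn(h * w, bs, 256)   # ID embedding of a reference frame
#   outs, mems = lstt(tgt, None, None, curr_id_emb=id_emb, size_2d=(h, w))
#   # With return_intermediate=True, `outs` and `mems` hold one entry per layer.
#   # Each mems[i] is [[curr_K, curr_V], [global_K, global_V], [local_K, local_V]];
#   # the caller is expected to slice these into the long_term_memories /
#   # short_term_memories lists passed in for subsequent frames.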


class DualBranchGPM(nn.Module):
    """A stack of GatedPropagationModule layers that propagates a visual branch
    and an ID branch in parallel and returns their concatenation."""
    def __init__(self,
                 num_layers=2,
                 d_model=256,
                 self_nhead=8,
                 att_nhead=8,
                 dim_feedforward=1024,
                 emb_dropout=0.,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 droppath_scaling=False,
                 activation="gelu",
                 return_intermediate=False,
                 intermediate_norm=True,
                 final_norm=True):
        super().__init__()
        self.intermediate_norm = intermediate_norm
        self.final_norm = final_norm
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

        self.emb_dropout = nn.Dropout(emb_dropout, True)

        block = GatedPropagationModule

        layers = []
        for idx in range(num_layers):
            if droppath_scaling:
                if num_layers == 1:
                    droppath_rate = 0
                else:
                    droppath_rate = droppath * idx / (num_layers - 1)
            else:
                droppath_rate = droppath
            layers.append(
                block(d_model,
                      self_nhead,
                      att_nhead,
                      dim_feedforward,
                      droppath_rate,
                      lt_dropout,
                      st_dropout,
                      droppath_lst,
                      activation,
                      layer_idx=idx))
        self.layers = nn.ModuleList(layers)

        num_norms = num_layers - 1 if intermediate_norm else 0
        if final_norm:
            num_norms += 1
        self.decoder_norms = [
            _get_norm(d_model * 2, type='gn', groups=2)
            for _ in range(num_norms)
        ] if num_norms > 0 else None

        if self.decoder_norms is not None:
            self.decoder_norms = nn.ModuleList(self.decoder_norms)

    def forward(self,
                tgt,
                long_term_memories,
                short_term_memories,
                curr_id_emb=None,
                self_pos=None,
                size_2d=None):

        output = self.emb_dropout(tgt)

        intermediate = []
        intermediate_memories = []
        output_id = None

        for idx, layer in enumerate(self.layers):
            output, output_id, memories = layer(
                output,
                output_id,
                long_term_memories[idx]
                if long_term_memories is not None else None,
                short_term_memories[idx]
                if short_term_memories is not None else None,
                curr_id_emb=curr_id_emb,
                self_pos=self_pos,
                size_2d=size_2d)

            cat_output = torch.cat([output, output_id], dim=2)

            if self.return_intermediate:
                intermediate.append(cat_output)
                intermediate_memories.append(memories)

        if self.decoder_norms is not None:
            if self.final_norm:
                cat_output = self.decoder_norms[-1](cat_output)

                if self.return_intermediate:
                    intermediate.pop()
                    intermediate.append(cat_output)

            if self.intermediate_norm:
                for idx in range(len(intermediate) - 1):
                    intermediate[idx] = self.decoder_norms[idx](
                        intermediate[idx])

        if self.return_intermediate:
            return intermediate, intermediate_memories

        return cat_output, memories
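
# Minimal usage sketch of the class above (illustrative shapes; the same
# (HW, B, C) layout as before). The visual and ID branches are concatenated,
# so the output has 2 * d_model channels:
#
#   gpm = DualBranchGPM(num_layers=2, d_model=256, return_intermediate=True)
#   h, w, bs = 30, 30, 1
#   tgt = torch.randn(h * w, bs, 256)
#   id_emb = torch.randn(h * w, bs, 256)
#   outs, mems = gpm(tgt, None, None, curr_id_emb=id_emb, size_2d=(h, w))
#   # outs[-1].shape == (h * w, bs, 512): visual and ID branches concatenated.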


class LongShortTermTransformerBlock(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True):
        super().__init__()

        self.norm1 = _get_norm(d_model)
        self.linear_Q = nn.Linear(d_model, d_model)
        self.linear_V = nn.Linear(d_model, d_model)

        self.long_term_attn = MultiheadAttention(d_model,
                                                 att_nhead,
                                                 use_linear=False,
                                                 dropout=lt_dropout)

        if enable_corr:
            try:
                import spatial_correlation_sampler
                MultiheadLocalAttention = MultiheadLocalAttentionV2
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. For better efficiency, please install it.")
                MultiheadLocalAttention = MultiheadLocalAttentionV3
        else:
            MultiheadLocalAttention = MultiheadLocalAttentionV3
        self.short_term_attn = MultiheadLocalAttention(d_model,
                                                       att_nhead,
                                                       dilation=local_dilation,
                                                       use_linear=False,
                                                       dropout=st_dropout)
        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        self.norm2 = _get_norm(d_model)
        self.self_attn = MultiheadAttention(d_model, self_nhead)

        self.norm3 = _get_norm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = GNActDWConv2d(dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()

    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Self-attention over the current frame.
        _tgt = self.norm1(tgt)
        q = k = self.with_pos_embed(_tgt, self_pos)
        v = _tgt
        tgt2 = self.self_attn(q, k, v)[0]

        tgt = tgt + self.droppath(tgt2)

        # Long-term (global) and short-term (local) attention over memories.
        _tgt = self.norm2(tgt)

        curr_Q = self.linear_Q(_tgt)
        curr_K = curr_Q
        curr_V = _tgt

        local_Q = seq_to_2d(curr_Q, size_2d)

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory

        tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0]
        tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0]

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)

        # Feed-forward.
        _tgt = self.norm3(tgt)

        tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d))

        tgt = tgt + self.droppath(tgt2)

        return tgt, [[curr_K, curr_V], [global_K, global_V],
                     [local_K, local_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        # Fuse the ID embedding into the value branch before building memories.
        K = key
        V = self.linear_V(value + id_emb)
        return K, V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
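
# Sketch of driving a single block across frames (illustrative, assuming the
# (HW, B, C) layout used above):
#
#   blk = LongShortTermTransformerBlock(d_model=256, self_nhead=8, att_nhead=8)
#   h, w = 30, 30
#   ref = torch.randn(h * w, 1, 256)
#   id_emb = torch.randn(h * w, 1, 256)
#   _, mem = blk(ref, curr_id_emb=id_emb, size_2d=(h, w))
#   # mem == [[curr_K, curr_V], [global_K, global_V], [local_K, local_V]]
#   cur = torch.randn(h * w, 1, 256)
#   out, _ = blk(cur, long_term_memory=mem[1], short_term_memory=mem[2],
#                size_2d=(h, w))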


class LongShortTermTransformerBlockV2(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True):
        super().__init__()
        self.d_model = d_model
        self.att_nhead = att_nhead

        self.norm1 = _get_norm(d_model)
        self.self_attn = MultiheadAttention(d_model, self_nhead)

        self.norm2 = _get_norm(d_model)
        self.linear_QV = nn.Linear(d_model, 2 * d_model)
        self.linear_ID_KV = nn.Linear(d_model, d_model + att_nhead)

        self.long_term_attn = MultiheadAttention(d_model,
                                                 att_nhead,
                                                 use_linear=False,
                                                 dropout=lt_dropout)

        if enable_corr:
            try:
                import spatial_correlation_sampler
                MultiheadLocalAttention = MultiheadLocalAttentionV2
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. For better efficiency, please install it.")
                MultiheadLocalAttention = MultiheadLocalAttentionV3
        else:
            MultiheadLocalAttention = MultiheadLocalAttentionV3
        self.short_term_attn = MultiheadLocalAttention(d_model,
                                                       att_nhead,
                                                       dilation=local_dilation,
                                                       use_linear=False,
                                                       dropout=st_dropout)
        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        self.norm3 = _get_norm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = GNActDWConv2d(dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()

    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Self-attention over the current frame.
        _tgt = self.norm1(tgt)
        q = k = self.with_pos_embed(_tgt, self_pos)
        v = _tgt
        tgt2 = self.self_attn(q, k, v)[0]

        tgt = tgt + self.droppath(tgt2)

        # Long-term (global) and short-term (local) attention over memories.
        _tgt = self.norm2(tgt)

        curr_QV = self.linear_QV(_tgt)
        curr_QV = torch.split(curr_QV, self.d_model, dim=2)
        curr_Q = curr_K = curr_QV[0]
        curr_V = curr_QV[1]

        local_Q = seq_to_2d(curr_Q, size_2d)

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)

            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory

        tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0]
        tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0]

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)

        # Feed-forward.
        _tgt = self.norm3(tgt)

        tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d))

        tgt = tgt + self.droppath(tgt2)

        return tgt, [[curr_K, curr_V], [global_K, global_V],
                     [local_K, local_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        # The ID embedding yields per-head gates for the keys and an additive
        # offset for the values.
        ID_KV = self.linear_ID_KV(id_emb)
        ID_K, ID_V = torch.split(ID_KV, [self.att_nhead, self.d_model], dim=2)
        bs = key.size(1)
        K = key.view(-1, bs, self.att_nhead, self.d_model //
                     self.att_nhead) * (1 + torch.tanh(ID_K)).unsqueeze(-1)
        K = K.view(-1, bs, self.d_model)
        V = value + ID_V
        return K, V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
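
# Note on LongShortTermTransformerBlockV2.fuse_key_value_id (illustrative):
# linear_ID_KV maps the ID embedding to att_nhead per-head key gates plus a
# d_model value offset. Keys are scaled head-wise by (1 + tanh(ID_K)) and
# values are shifted additively, e.g. with att_nhead=8 and d_model=256:
#
#   key:   (HW, B, 256) -> view (HW, B, 8, 32) -> * (1 + tanh(ID_K)).unsqueeze(-1)
#   value: (HW, B, 256) -> + ID_V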


class GatedPropagationModule(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True,
                 max_local_dis=7,
                 layer_idx=0,
                 expand_ratio=2.):
        super().__init__()
        expand_d_model = int(d_model * expand_ratio)
        self.expand_d_model = expand_d_model
        self.d_model = d_model
        self.att_nhead = att_nhead

        d_att = d_model // 2 if att_nhead == 1 else d_model // att_nhead
        self.d_att = d_att
        self.layer_idx = layer_idx

        # Long-term and short-term gated propagation from memories.
        self.norm1 = _get_norm(d_model)
        self.linear_QV = nn.Linear(d_model, d_att * att_nhead + expand_d_model)
        self.linear_U = nn.Linear(d_model, expand_d_model)

        if layer_idx == 0:
            self.linear_ID_V = nn.Linear(d_model, expand_d_model)
        else:
            self.id_norm1 = _get_norm(d_model)
            self.linear_ID_V = nn.Linear(d_model * 2, expand_d_model)
            self.linear_ID_U = nn.Linear(d_model, expand_d_model)

        self.long_term_attn = GatedPropagation(d_qk=self.d_model,
                                               d_vu=self.d_model * 2,
                                               num_head=att_nhead,
                                               use_linear=False,
                                               dropout=lt_dropout,
                                               d_att=d_att,
                                               top_k=-1,
                                               expand_ratio=expand_ratio)

        if enable_corr:
            try:
                import spatial_correlation_sampler
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. For better efficiency, please install it.")
                enable_corr = False
        self.short_term_attn = LocalGatedPropagation(d_qk=self.d_model,
                                                     d_vu=self.d_model * 2,
                                                     num_head=att_nhead,
                                                     dilation=local_dilation,
                                                     use_linear=False,
                                                     enable_corr=enable_corr,
                                                     dropout=st_dropout,
                                                     d_att=d_att,
                                                     max_dis=max_local_dis,
                                                     expand_ratio=expand_ratio)

        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Self-propagation over the concatenated visual and ID branches.
        self.norm2 = _get_norm(d_model)
        self.id_norm2 = _get_norm(d_model)
        self.self_attn = GatedPropagation(d_model * 2,
                                          d_model * 2,
                                          self_nhead,
                                          d_att=d_att)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()

    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                tgt_id=None,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Long-term (global) and short-term (local) gated propagation.
        _tgt = self.norm1(tgt)

        curr_QV = self.linear_QV(_tgt)
        curr_QV = torch.split(
            curr_QV, [self.d_att * self.att_nhead, self.expand_d_model], dim=2)
        curr_Q = curr_K = curr_QV[0]
        local_Q = seq_to_2d(curr_Q, size_2d)
        curr_V = silu(curr_QV[1])
        curr_U = self.linear_U(_tgt)

        if tgt_id is None:
            tgt_id = 0
            cat_curr_U = torch.cat(
                [silu(curr_U), torch.ones_like(curr_U)], dim=-1)
            curr_ID_V = None
        else:
            _tgt_id = self.id_norm1(tgt_id)
            curr_ID_V = _tgt_id
            curr_ID_U = self.linear_ID_U(_tgt_id)
            cat_curr_U = silu(torch.cat([curr_U, curr_ID_U], dim=-1))

        if curr_id_emb is not None:
            global_K, global_V = curr_K, curr_V
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)

            _, global_ID_V = self.fuse_key_value_id(None, curr_ID_V,
                                                    curr_id_emb)
            local_ID_V = seq_to_2d(global_ID_V, size_2d)
        else:
            global_K, global_V, _, global_ID_V = long_term_memory
            local_K, local_V, _, local_ID_V = short_term_memory

        cat_global_V = torch.cat([global_V, global_ID_V], dim=-1)
        cat_local_V = torch.cat([local_V, local_ID_V], dim=1)

        cat_tgt2, _ = self.long_term_attn(curr_Q, global_K, cat_global_V,
                                          cat_curr_U, size_2d)
        cat_tgt3, _ = self.short_term_attn(local_Q, local_K, cat_local_V,
                                           cat_curr_U, size_2d)

        tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1)
        tgt3, tgt_id3 = torch.split(cat_tgt3, self.d_model, dim=-1)

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
            tgt_id = tgt_id + self.droppath(tgt_id2 + tgt_id3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)
            tgt_id = tgt_id + self.lst_dropout(tgt_id2 + tgt_id3)

        # Self-propagation over the concatenated visual and ID branches.
        _tgt = self.norm2(tgt)
        _tgt_id = self.id_norm2(tgt_id)
        q = k = v = u = torch.cat([_tgt, _tgt_id], dim=-1)

        cat_tgt2, _ = self.self_attn(q, k, v, u, size_2d)

        tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1)

        tgt = tgt + self.droppath(tgt2)
        tgt_id = tgt_id + self.droppath(tgt_id2)

        return tgt, tgt_id, [[curr_K, curr_V, None, curr_ID_V],
                             [global_K, global_V, None, global_ID_V],
                             [local_K, local_V, None, local_ID_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        ID_K = None
        if value is not None:
            ID_V = silu(self.linear_ID_V(torch.cat([value, id_emb], dim=2)))
        else:
            ID_V = silu(self.linear_ID_V(id_emb))
        return ID_K, ID_V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
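
# Sketch of driving a single GatedPropagationModule across frames (illustrative;
# same (HW, B, C) layout as above, and it assumes GatedPropagation /
# LocalGatedPropagation accept the shapes this module builds for them):
#
#   gp = GatedPropagationModule(d_model=256, self_nhead=8, att_nhead=8, layer_idx=0)
#   h, w = 30, 30
#   ref = torch.randn(h * w, 1, 256)
#   id_emb = torch.randn(h * w, 1, 256)
#   out, out_id, mem = gp(ref, None, curr_id_emb=id_emb, size_2d=(h, w))
#   # mem == [[curr_K, curr_V, None, curr_ID_V],
#   #         [global_K, global_V, None, global_ID_V],
#   #         [local_K, local_V, None, local_ID_V]]
#   cur = torch.randn(h * w, 1, 256)
#   out, out_id, _ = gp(cur, None, long_term_memory=mem[1],
#                       short_term_memory=mem[2], size_2d=(h, w))
#   # The ID branch (second positional argument) is only threaded between layers
#   # inside DualBranchGPM within one frame; the first layer always receives None.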