import copy
import difflib
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, BertTokenizer, BertModel, Wav2Vec2Model, Wav2Vec2Config

from .motion_encoder import VQEncoderV6


def audio_to_time_aligned_text_features(inputs, processor, model, tokenizer, bert_model):
    """Transcribe audio with a wav2vec2 CTC model and align BERT token embeddings to its frames.

    Returns the transcription string, a (1, num_frames, bert_dim) tensor holding one BERT
    embedding per CTC frame, and the full BERT token embeddings of the transcription.
    """
    with torch.no_grad():
        logits = model(inputs.input_values).logits  # (batch, num_frames, vocab_size)

    # Greedy CTC decoding: one token id per frame, mapped back to characters.
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_ids_per_timestep = predicted_ids[0].cpu().numpy()
    vocab = processor.tokenizer.get_vocab()
    id_to_token = {v: k for k, v in vocab.items()}
    tokens_per_timestep = [id_to_token[token_id] for token_id in predicted_ids_per_timestep]

    # Full transcription and its BERT tokenization.
    transcription = processor.decode(predicted_ids[0])
    inputs_bert = tokenizer(transcription, return_tensors="pt")
    input_ids = inputs_bert["input_ids"][0]
    tokens_bert = tokenizer.convert_ids_to_tokens(input_ids)

    with torch.no_grad():
        outputs_bert = bert_model(**inputs_bert.to(inputs.input_values.device))
    all_token_embeddings = outputs_bert.last_hidden_state[0]

    # Build character streams on both sides: per-frame CTC characters and BERT wordpiece characters.
    per_timestep_chars = []
    per_timestep_char_indices = []
    for idx, t in enumerate(tokens_per_timestep):
        if t not in ("<pad>", "|"):
            per_timestep_chars.append(t.lower())
            per_timestep_char_indices.append(idx)
    bert_chars = []
    bert_char_indices = []
    for idx, token in enumerate(tokens_bert):
        if token in ("[CLS]", "[SEP]"):
            continue
        token_str = token.replace("##", "")
        for c in token_str:
            bert_chars.append(c)
            bert_char_indices.append(idx)

    # Character-level alignment between the two streams.
    s = difflib.SequenceMatcher(None, per_timestep_chars, bert_chars)
    opcodes = s.get_opcodes()
    per_timestep_to_bert_token_idx = {}
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == "equal":
            for k in range(i2 - i1):
                per_timestep_idx = per_timestep_char_indices[i1 + k]
                bert_token_idx = bert_char_indices[j1 + k]
                per_timestep_to_bert_token_idx[per_timestep_idx] = bert_token_idx

    # Gather one BERT embedding per CTC frame; unmatched frames get zeros for now.
    features_per_timestep = []
    check = []
    for i, per_token in enumerate(tokens_per_timestep):
        if i == 0:
            embedding = all_token_embeddings[0]
            check.append("cls")
        elif per_token in ("<pad>", "|"):
            embedding = torch.zeros(all_token_embeddings.shape[-1]).to(inputs.input_values.device)
            check.append(0)
        else:
            if i in per_timestep_to_bert_token_idx:
                bert_idx = per_timestep_to_bert_token_idx[i]
                embedding = all_token_embeddings[bert_idx]
                check.append(tokens_bert[bert_idx])
            else:
                embedding = torch.zeros(all_token_embeddings.shape[-1]).to(inputs.input_values.device)
                check.append(0)
        features_per_timestep.append(embedding)
    features_per_timestep = torch.stack(features_per_timestep)

    # Fill unmatched frames with the embedding of the nearest matched frame.
    updated_check = check.copy()
    for i in range(len(check)):
        if check[i] == 0:
            left = i - 1
            right = i + 1
            left_found = False
            right_found = False

            while left >= 0:
                if check[left] != 0:
                    left_found = True
                    break
                left -= 1

            while right < len(check):
                if check[right] != 0:
                    right_found = True
                    break
                right += 1

            if left_found and right_found:
                if (i - left) <= (right - i):
                    nearest = left
                else:
                    nearest = right
            elif left_found:
                nearest = left
            elif right_found:
                nearest = right
            else:
                continue

            updated_check[i] = updated_check[nearest]
            features_per_timestep[i] = features_per_timestep[nearest]

    features_per_timestep = features_per_timestep.unsqueeze(0)
    return transcription, features_per_timestep, all_token_embeddings


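# Example (sketch only, not executed): aligning BERT features to wav2vec2 frames.
# `wav` is assumed to be a 16 kHz mono waveform as a 1-D numpy array; the processor,
# CTC model, tokenizer and BERT model are the same objects JointEmbedding builds below.
#
#   inputs = audio_processor(wav, sampling_rate=16000, return_tensors="pt")
#   text, per_frame_feats, token_feats = audio_to_time_aligned_text_features(
#       inputs, audio_processor, asr, bert_tokenizer, bert_model
#   )
#   # per_frame_feats: (1, num_ctc_frames, 768) -- one BERT vector per wav2vec2 frame

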
class MLP(nn.Module):
    def __init__(self, in_dim, hidden_size, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_dim, hidden_size), nn.LeakyReLU(0.2, True), nn.Linear(hidden_size, out_dim))

    def forward(self, inputs):
        return self.mlp(inputs)


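# Sinusoidal positional encoding that repeats with a fixed period: a table of length
# `period` is built once and tiled to cover `max_seq_len`, so positions repeat every
# `period` frames instead of growing monotonically.

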
class PeriodicPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, period=20, max_seq_len=64):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


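# Plain scaled dot-product multi-head attention. Unlike nn.MultiheadAttention with its
# default settings, it returns the full per-head attention weights of shape
# (batch, num_heads, seq_len, seq_len); the code below uses those per-head weights to
# form attention-weighted feature summaries.

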
class CustomMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CustomMultiheadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        batch_size, seq_len, embed_dim = query.size()

        # Project and split into heads: (batch, heads, seq, head_dim).
        Q = self.query_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key_proj(key).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value_proj(value).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Merge the heads back and apply the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        output = self.out_proj(attn_output)
        return output, attn_weights


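# Helper to re-initialize a module in place: Xavier-uniform for matrix-shaped weights,
# a small normal init for vector-shaped weights, zeros for biases. Presumably used when
# a pretrained sub-network should be trained from scratch.

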
def reinitialize_weights(module):
    for submodule in module.modules():
        weight = getattr(submodule, "weight", None)
        if weight is not None and isinstance(weight, torch.Tensor) and weight.dim() >= 2:
            torch.nn.init.xavier_uniform_(weight)
            print("init")
        elif weight is not None and isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            print("init")
        bias = getattr(submodule, "bias", None)
        if bias is not None and isinstance(bias, torch.Tensor):
            torch.nn.init.zeros_(bias)


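# Motion encoder: a shallow VQEncoderV6 produces low-level features, which are detached,
# concatenated with the raw motion input and fed to a deeper VQEncoderV6 followed by a
# periodic positional encoding and a single Transformer encoder layer for the high level.

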
class WrapedMotionCNN(nn.Module):
    def __init__(self, args):
        super(WrapedMotionCNN, self).__init__()
        self.args = args
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.args.motion_f,
            nhead=8,
            dim_feedforward=self.args.hidden_size,
            dropout=0.1,
            batch_first=True,
        )

        # Low-level CNN encoder on the raw motion representation.
        args_top = copy.deepcopy(self.args)
        args_top.vae_layer = 3
        args_top.vae_length = self.args.motion_f
        args_top.vae_test_dim = self.args.motion_dim
        self.feature_extractor = VQEncoderV6(args_top)

        # High-level CNN encoder on [low-level features, raw motion].
        args_top = copy.deepcopy(self.args)
        args_top.vae_layer = 6
        args_top.vae_length = self.args.motion_f
        args_top.vae_test_dim = self.args.motion_dim + self.args.motion_f
        self.encoder_cnn = VQEncoderV6(args_top)

        self.pos_encoding = PeriodicPositionalEncoding(d_model=self.args.motion_f, period=20, max_seq_len=64, dropout=0.0)
        self.encoder_trans = nn.TransformerEncoder(encoder_layer, num_layers=1)

    def forward(
        self,
        inputs,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        low_level = self.feature_extractor(inputs)
        # Detach so the high-level branch does not backpropagate into the low-level encoder.
        hidden_states = self.encoder_cnn(torch.cat([low_level.detach(), inputs], dim=-1))
        hidden_states = self.pos_encoding(hidden_states)
        hidden_states = self.encoder_trans(hidden_states)
        return {"low_level": low_level, "high_level": hidden_states}


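# Audio encoder built from wav2vec2-base-960h: the convolutional feature extractor gives
# 512-d low-level features (roughly 50 frames per second for 16 kHz input), and a
# truncated Transformer (only the first layer is kept) plus a 768 -> 512 projection gives
# the high-level features.

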
class WrapedWav2Vec(nn.Module):
    def __init__(self):
        super(WrapedWav2Vec, self).__init__()
        # Load the pretrained model once and keep only the pieces that are needed.
        pretrained = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.feature_extractor = pretrained.feature_extractor
        self.feature_projection = pretrained.feature_projection
        self.encoder = pretrained.encoder
        self.encoder.layers = self.encoder.layers[:1]
        self.proj_down = nn.Linear(768, 512)

    def forward(
        self,
        inputs,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        finetune_audio_low = self.feature_extractor(inputs).transpose(1, 2)
        # Detach so the Transformer branch does not backpropagate into the CNN extractor.
        hidden_states, _ = self.feature_projection(finetune_audio_low.detach())
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]
        hidden_states = self.proj_down(hidden_states)
        return {"low_level": finetune_audio_low, "high_level": hidden_states}


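# Joint audio-motion embedding. Audio is encoded at a low level (wav2vec2 CNN features
# plus the finetuned branch above) and a high level (wav2vec2 Transformer features,
# time-aligned BERT features, and the finetuned branch); motion is encoded with
# WrapedMotionCNN. Both sides are mapped into a shared space and compared with
# contrastive (InfoNCE-style) losses. Note that `smplx_model`, `get_motion_reps` and the
# two loss functions are initialized to None here and are presumably attached by the
# surrounding training code before use.

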
class JointEmbedding(nn.Module):
    def __init__(self, args):
        super(JointEmbedding, self).__init__()
        self.args = args.model

        # Audio side: pretrained wav2vec2 processor/encoder, the finetuned wrapper, and ASR + BERT for text features.
        self.audio_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.config_wav2vec = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base-960h")
        self.audio_encoder_fintune = WrapedWav2Vec()

        self.asr = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bert_model = BertModel.from_pretrained("bert-base-uncased")

        self.audio_low_mapping = MLP(512 + 512, self.args.hidden_size, self.args.audio_f)
        self.audio_high_mapping = MLP(512 + 512 + 512, self.args.hidden_size, self.args.audio_f)
        self.audio_down_proj_2 = nn.Linear(768, 512)
        self.audio_down_proj_3 = nn.Linear(768, 512)
        self.audio_sa = CustomMultiheadAttention(
            embed_dim=self.args.audio_f,
            num_heads=8,
        )

        # Motion side.
        self.motion_encoder_fintune = WrapedMotionCNN(self.args)
        self.motion_low_mapping = MLP(self.args.motion_f, self.args.hidden_size, self.args.motion_f)
        self.motion_high_mapping = MLP(self.args.motion_f, self.args.hidden_size, self.args.motion_f)
        self.motion_sa = CustomMultiheadAttention(
            embed_dim=self.args.audio_f,
            num_heads=8,
        )

        self.down_sample = 2
        # The SMPL-X model, the motion-representation helper and the two loss functions are
        # expected to be assigned externally before the model is used.
        self.smplx_model = None
        self.get_motion_reps = None
        self.audio_to_time_aligned_text_features = audio_to_time_aligned_text_features
        self.low_temp = nn.Parameter(torch.tensor(0.07))
        self.low_level_loss_fn = None
        self.high_temp = nn.Parameter(torch.tensor(0.07))
        self.high_level_loss_fn = None

    def _reset_parameters(self):
        # Only meaningful if a `mask_embeddings` parameter has been added to the module.
        nn.init.normal_(self.mask_embeddings, 0, self.args.hidden_size**-0.5)

    def forward(self, in_audio=None, in_motion=None, cached_audio_low=None, cached_audio_high=None, cached_rep15d=None):
        # Motion branch: use cached rep15d features when available, otherwise compute them.
        if cached_rep15d is not None:
            in_motion = cached_rep15d[:, :: self.down_sample]
        else:
            in_motion = self.get_motion_reps(in_motion, self.smplx_model)["rep15d"][:, :: self.down_sample]

        motion_features = self.motion_encoder_fintune(in_motion)
        raw_motion_low = motion_features["low_level"]
        raw_motion_high = motion_features["high_level"]

        motion_low = self.motion_low_mapping(raw_motion_low)
        motion_high = self.motion_high_mapping(raw_motion_high)
        motion_high_att, motion_high_weight = self.motion_sa(motion_high, motion_high, motion_high)
        bs, n, c = motion_high.shape

        # Weight the per-head feature slices by the attention that the first frame pays to every frame.
        motion_high_att_before_sum = motion_high_weight[:, :, 0, :].unsqueeze(2) * motion_high.transpose(1, 2).view(bs, 8, c // 8, n)
        motion_high_att_before_sum = motion_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)

        # Upsample back to the pre-downsampling frame rate.
        motion_low = F.interpolate(motion_low.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True).transpose(1, 2)
        motion_high_att = F.interpolate(motion_high_att.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True).transpose(1, 2)
        motion_high_att_before_sum = F.interpolate(
            motion_high_att_before_sum.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True
        ).transpose(1, 2)
        motion_cls = motion_high_att[:, 0]

        # Audio branch: training relies on cached low/high-level audio features.
        if cached_audio_low is not None:
            raw_audio_low = cached_audio_low
            raw_audio_high = torch.cat(
                [self.audio_down_proj_2(cached_audio_high[:, :, :768]), self.audio_down_proj_3(cached_audio_high[:, :, 768:])], dim=-1
            )

            audio_list = [i.cpu().numpy() for i in in_audio]
            inputs = self.audio_processor(audio_list, sampling_rate=16000, return_tensors="pt", padding=True).to(in_audio.device)
            finetune_audio = self.audio_encoder_fintune(inputs.input_values)
            finetune_audio_low, finetune_audio_high = finetune_audio["low_level"], finetune_audio["high_level"]

            # Pad the finetuned features to match the cached feature lengths.
            diff = raw_audio_low.shape[1] - finetune_audio_low.shape[1]
            if diff > 0:
                finetune_audio_low = torch.cat([finetune_audio_low, finetune_audio_low[:, -diff:]], dim=1)
            diff = raw_audio_high.shape[1] - finetune_audio_high.shape[1]
            if diff > 0:
                finetune_audio_high = torch.cat([finetune_audio_high, finetune_audio_high[:, -diff:]], dim=1)
            raw_audio_low = torch.cat([raw_audio_low, finetune_audio_low], dim=-1)
        else:
            print("error! must have cached audio in training")

        # Resample audio features from the wav2vec2 frame rate to the target low/high frame rates.
        raw_audio_low = F.interpolate(raw_audio_low.transpose(1, 2), scale_factor=30 / 50, mode="linear", align_corners=True).transpose(1, 2)
        raw_audio_high = F.interpolate(raw_audio_high.transpose(1, 2), scale_factor=15 / 50, mode="linear", align_corners=True).transpose(1, 2)
        finetune_audio_high = F.interpolate(
            finetune_audio_high.transpose(1, 2), scale_factor=15 / 50, mode="linear", align_corners=True
        ).transpose(1, 2)

        audio_low = self.audio_low_mapping(raw_audio_low)
        raw_audio_high = torch.cat([finetune_audio_high, raw_audio_high], dim=-1)
        audio_high = self.audio_high_mapping(raw_audio_high)
        audio_high_att, audio_high_weight = self.audio_sa(audio_high, audio_high, audio_high)
        bs, n, c = audio_high.shape
        audio_high_att_before_sum = audio_high_weight[:, :, 0, :].unsqueeze(2) * audio_high.transpose(1, 2).view(bs, 8, c // 8, n)
        audio_high_att_before_sum = audio_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
        audio_high_att = F.interpolate(audio_high_att.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True).transpose(1, 2)
        audio_high_att_before_sum = F.interpolate(
            audio_high_att_before_sum.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True
        ).transpose(1, 2)
        audio_cls = audio_high_att[:, 0]

        # Contrastive losses between the two modalities.
        low_infonce, low_acc = self.low_level_loss_fn(audio_low, motion_low)
        high_infonce = self.high_level_loss_fn(audio_cls, motion_cls)
        return {
            "audio_low": audio_low,
            "audio_high": audio_high_att,
            "audio_cls": audio_cls,
            "audio_high_weight": audio_high_att_before_sum,
            "motion_low": motion_low,
            "motion_high": motion_high_att,
            "motion_cls": motion_cls,
            "motion_high_weight": motion_high_att_before_sum,
            "low_level_loss": [low_infonce, low_acc],
            "high_level_loss": high_infonce,
        }

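    # Audio feature path that recomputes, from raw audio, what forward() reads from the
    # cache (wav2vec2 CNN and Transformer features plus time-aligned BERT features), then
    # applies the same mappings and self-attention.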
    def get_audio_features(self, in_audio):
        audio_list = [i.cpu().numpy() for i in in_audio]
        inputs = self.audio_processor(audio_list, sampling_rate=16000, return_tensors="pt", padding=True).to(in_audio.device)
        raw_audio_low = self.audio_encoder.feature_extractor(inputs.input_values).transpose(1, 2)

        finetune_audio = self.audio_encoder_fintune(inputs.input_values)
        finetune_audio_low, finetune_audio_high = finetune_audio["low_level"], finetune_audio["high_level"]
        diff = raw_audio_low.shape[1] - finetune_audio_low.shape[1]
        if diff > 0:
            finetune_audio_low = torch.cat([finetune_audio_low, finetune_audio_low[:, -diff:]], dim=1)
        raw_audio_low = torch.cat([raw_audio_low, finetune_audio_low], dim=-1)

        raw_audio_high = self.audio_encoder(inputs.input_values).last_hidden_state

        diff = raw_audio_high.shape[1] - finetune_audio_high.shape[1]
        if diff > 0:
            finetune_audio_high = torch.cat([finetune_audio_high, finetune_audio_high[:, -diff:]], dim=1)

        # Concatenate time-aligned BERT features to the wav2vec2 hidden states, then project both halves down.
        _, bert_time_aligned_text, _ = audio_to_time_aligned_text_features(
            inputs, self.audio_processor, self.asr, self.bert_tokenizer, self.bert_model
        )
        raw_audio_high = torch.cat([raw_audio_high, bert_time_aligned_text], dim=2)
        raw_audio_high = torch.cat(
            [self.audio_down_proj_2(raw_audio_high[:, :, :768]), self.audio_down_proj_3(raw_audio_high[:, :, 768:])], dim=-1
        )

        # Resample to the target low/high frame rates.
        raw_audio_low = F.interpolate(raw_audio_low.transpose(1, 2), scale_factor=30 / 50, mode="linear", align_corners=True).transpose(1, 2)
        raw_audio_high = F.interpolate(raw_audio_high.transpose(1, 2), scale_factor=15 / 50, mode="linear", align_corners=True).transpose(1, 2)
        finetune_audio_high = F.interpolate(
            finetune_audio_high.transpose(1, 2), scale_factor=15 / 50, mode="linear", align_corners=True
        ).transpose(1, 2)

        # Pad so the low-level sequence has even length and the high-level sequences match its half length.
        if raw_audio_low.shape[1] % 2 == 1:
            raw_audio_low = torch.cat([raw_audio_low, raw_audio_low[:, -1:]], dim=1)
        diff = raw_audio_low[:, ::2].shape[1] - raw_audio_high.shape[1]
        if diff > 0:
            raw_audio_high = torch.cat([raw_audio_high, raw_audio_high[:, -diff:]], dim=1)
            finetune_audio_high = torch.cat([finetune_audio_high, finetune_audio_high[:, -diff:]], dim=1)

        audio_low = self.audio_low_mapping(raw_audio_low)
        raw_audio_high = torch.cat([finetune_audio_high, raw_audio_high], dim=-1)
        audio_high = self.audio_high_mapping(raw_audio_high)
        audio_high_att, audio_high_weight = self.audio_sa(audio_high, audio_high, audio_high)
        bs, n, c = audio_high.shape
        audio_high_att_before_sum = audio_high_weight[:, :, 0, :].unsqueeze(2) * audio_high.transpose(1, 2).view(bs, 8, c // 8, n)
        audio_high_att_before_sum = audio_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
        audio_high_att = F.interpolate(audio_high_att.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True).transpose(1, 2)
        audio_high_att_before_sum = F.interpolate(
            audio_high_att_before_sum.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True
        ).transpose(1, 2)
        audio_cls = audio_high_att[:, 0]
        return {
            "audio_low": audio_low,
            "audio_high": audio_high_att,
            "audio_cls": audio_cls,
            "audio_high_weight": audio_high_att_before_sum,
        }

    def get_motion_features(self, in_motion):
        original_length = in_motion.shape[1]

        in_motion = self.get_motion_reps(in_motion, self.smplx_model)["rep15d"][:, :: self.down_sample]
        motion_features = self.motion_encoder_fintune(in_motion)
        raw_motion_low = motion_features["low_level"]
        raw_motion_high = motion_features["high_level"]
        motion_low = self.motion_low_mapping(raw_motion_low)
        motion_high = self.motion_high_mapping(raw_motion_high)

        motion_high_att, motion_high_weight = self.motion_sa(motion_high, motion_high, motion_high)
        bs, n, c = motion_high.shape
        motion_high_att_before_sum = motion_high_weight[:, :, 0, :].unsqueeze(2) * motion_high.transpose(1, 2).view(bs, 8, c // 8, n)
        motion_high_att_before_sum = motion_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
        motion_low = F.interpolate(motion_low.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True).transpose(1, 2)
        motion_high_att = F.interpolate(motion_high_att.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True).transpose(1, 2)
        motion_high_att_before_sum = F.interpolate(
            motion_high_att_before_sum.transpose(1, 2), scale_factor=2, mode="linear", align_corners=True
        ).transpose(1, 2)

        # Trim the upsampled sequences back to the original number of frames.
        motion_low = motion_low[:, :original_length]
        motion_high_att = motion_high_att[:, :original_length]
        motion_high_att_before_sum = motion_high_att_before_sum[:, :original_length]

        motion_cls = motion_high_att[:, 0]
        return {
            "motion_low": motion_low,
            "motion_high": motion_high_att,
            "motion_cls": motion_cls,
            "motion_high_weight": motion_high_att_before_sum,
        }
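

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (an addition, not part of the original pipeline).
# It only exercises the components that need no pretrained weights or external
# config, so it can serve as a quick shape check. Because of the relative import
# above, run it as a module from the package root
# (e.g. `python -m <package>.<this_module>`; names are placeholders).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dummy = torch.randn(2, 32, 256)  # (batch, frames, channels); sizes are arbitrary
    ppe = PeriodicPositionalEncoding(d_model=256, period=20, max_seq_len=64, dropout=0.0)
    attn = CustomMultiheadAttention(embed_dim=256, num_heads=8)
    out, weights = attn(ppe(dummy), ppe(dummy), ppe(dummy))
    print(out.shape, weights.shape)  # expect (2, 32, 256) and (2, 8, 32, 32)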