AiOS / detrsmpl /models /necks /temporal_encoder.py
ttxskk
update
d7e58f0
raw
history blame
1.54 kB
from typing import Optional, Union
import torch.nn as nn
from mmcv.runner.base_module import BaseModule
class TemporalGRUEncoder(BaseModule):
    """Residual GRU encoder over a sequence of per-frame features.

    Used for VIBE. Adapted from https://github.com/mkocabas/VIBE.

    The input sequence is run through a unidirectional GRU, the GRU outputs
    are passed through ReLU and a linear layer that maps them back to
    ``input_size``, and the result is added to the original input
    (residual connection).

    Args:
        input_size (int, optional): dimension of input feature.
            Default: 2048.
        num_layers (int, optional): number of layers for GRU. Default: 1.
        hidden_size (int, optional): hidden size for GRU. Default: 2048.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 input_size: int = 2048,
                 num_layers: int = 1,
                 hidden_size: int = 2048,
                 init_cfg: Optional[Union[list, dict]] = None):
        super(TemporalGRUEncoder, self).__init__(init_cfg)
        self.input_size = input_size
        self.hidden_size = hidden_size
        # batch_first is left at its default (False), so the GRU consumes
        # time-major input; forward() permutes accordingly.
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=hidden_size,
                          bidirectional=False,
                          num_layers=num_layers)
        self.relu = nn.ReLU()
        # Projects GRU outputs back to input_size so they can be summed
        # with the input in the residual connection.
        self.linear = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """Encode a batch of feature sequences.

        Args:
            x (torch.Tensor): features of shape (N, T, input_size), where
                dim 1 is the axis the GRU iterates over.

        Returns:
            torch.Tensor: contiguous tensor of shape (N, T, input_size)
                holding ``x`` plus the projected GRU outputs.
        """
        batch_size, seq_len = x.shape[:2]
        # (N, T, C) -> (T, N, C): the GRU was built with batch_first=False.
        x = x.permute(1, 0, 2)
        hidden, _ = self.gru(x)
        # Flatten the time and batch axes for the linear projection.
        projected = self.linear(
            self.relu(hidden).view(-1, self.hidden_size))
        # Restore the (T, N, C) layout and add the residual input.
        out = projected.view(seq_len, batch_size, self.input_size) + x
        # Back to batch-major layout; permute returns a strided view, so
        # make the result contiguous for downstream consumers.
        return out.permute(1, 0, 2).contiguous()