File size: 1,543 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Optional, Union

import torch.nn as nn
from mmcv.runner.base_module import BaseModule


class TemporalGRUEncoder(BaseModule):
    """TemporalEncoder used for VIBE. Adapted from
    https://github.com/mkocabas/VIBE.

    A unidirectional GRU runs over the temporal axis, its output goes through
    ReLU and a linear projection back to ``input_size``, and the result is
    added to the input as a residual.

    Args:
        input_size (int, optional): dimension of input feature. Default: 2048.
        num_layers (int, optional): number of layers for GRU. Default: 1.
        hidden_size (int, optional): hidden size for GRU. Default: 2048.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """
    def __init__(self,
                 input_size: Optional[int] = 2048,
                 num_layers: Optional[int] = 1,
                 hidden_size: Optional[int] = 2048,
                 init_cfg: Optional[Union[list, dict, None]] = None):
        super(TemporalGRUEncoder, self).__init__(init_cfg)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=hidden_size,
                          bidirectional=False,
                          num_layers=num_layers)
        self.relu = nn.ReLU()
        # Projects GRU features back to ``input_size`` so that the residual
        # addition in ``forward`` is valid for any ``hidden_size``.
        self.linear = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """Encode a batch of feature sequences.

        Args:
            x (torch.Tensor): input features of shape
                (N, T, input_size), i.e. batch first.

        Returns:
            torch.Tensor: contiguous tensor of shape (N, T, input_size)
                holding the input plus the projected GRU output.
        """
        N, T = x.shape[:2]
        # The GRU is built with the default ``batch_first=False``, so move
        # the temporal axis to the front: (N, T, C) -> (T, N, C).
        x = x.permute(1, 0, 2)
        y, _ = self.gru(x)
        # Flatten time and batch, apply ReLU + projection per frame.
        y = self.linear(self.relu(y).view(-1, self.hidden_size))
        # Restore (T, N, C) and add the residual connection.
        y = y.view(T, N, self.input_size) + x
        # Back to batch-first layout.
        y = y.permute(1, 0, 2).contiguous()
        return y