import math

import torch
from torch import nn


class LstmModel(nn.Module):
    # Model hyperparameters; expected keys: "d_model", "num_layers", "dropout",
    # and optionally "trainable_pe". Must be populated before instantiation.
    config = {}

    def __init__(self, positional_embedding):
        super().__init__()
        # LSTM backbone operating on (batch, seq_len, d_model) inputs.
        self.lstm_forward = nn.LSTM(
            input_size=self.config["d_model"],
            hidden_size=self.config["d_model"],
            num_layers=self.config["num_layers"],
            dropout=self.config["dropout"],
            bias=True,
            batch_first=True,
        )
        # Add a batch dimension: (seq_len, d_model) -> (1, seq_len, d_model).
        pe = positional_embedding[None, :, :]
        if self.config.get("trainable_pe"):
            self.pe = nn.Parameter(pe)
        else:  # fixed positional embedding
            self.register_buffer("pe", pe)

    def forward(self, output_shape, condition=None):
        # output_shape: (batch, seq_len, d_model); only the batch size is used
        # to tile the positional embedding across the batch.
        # condition must be a tensor of that same shape.
        assert condition is not None and condition.dim() == 3
        x, _ = self.lstm_forward(self.pe.repeat(output_shape[0], 1, 1) + condition)
        return x.contiguous()
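

# --- Usage sketch (not part of the original file): a minimal example assuming a
# sinusoidal positional embedding and the config keys read above. The dimensions
# and hyperparameter values below are illustrative, not taken from the source repo. ---
if __name__ == "__main__":
    d_model, seq_len, batch = 32, 16, 4

    # Standard sinusoidal positional embedding of shape (seq_len, d_model).
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    # The class reads its hyperparameters from the class-level config dict,
    # so it has to be populated before the model is constructed.
    LstmModel.config = {
        "d_model": d_model,
        "num_layers": 2,
        "dropout": 0.1,
        "trainable_pe": False,
    }
    model = LstmModel(pe)

    condition = torch.randn(batch, seq_len, d_model)  # (batch, seq_len, d_model)
    out = model((batch, seq_len, d_model), condition=condition)
    print(out.shape)  # torch.Size([4, 16, 32])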