import torch
from torch import nn
import math


class LstmModel(nn.Module):
    """LSTM that maps a positional embedding plus a conditioning sequence to an output sequence."""

    def __init__(self, config, positional_embedding):
        super().__init__()
        self.config = config
        self.lstm_forward = nn.LSTM(
            input_size=self.config["d_model"],
            hidden_size=self.config["d_model"],
            num_layers=self.config["num_layers"],
            dropout=self.config["dropout"],
            bias=True,
            batch_first=True,
        )
        # Prepend a batch dimension: (seq_len, d_model) -> (1, seq_len, d_model).
        pe = positional_embedding[None, :, :]
        if self.config.get("trainable_pe"):
            # Learn the positional embedding jointly with the LSTM weights.
            self.pe = nn.Parameter(pe)
        else:
            # Fixed positional embedding: a buffer follows the module's device/dtype
            # but receives no gradient updates.
            self.register_buffer("pe", pe)

    def forward(self, output_shape, condition=None):
        # condition is expected to be a (batch, seq_len, d_model) tensor.
        assert condition is not None and condition.dim() == 3
        # Broadcast the positional embedding over the batch, add the condition,
        # and run the summed sequence through the LSTM.
        x, _ = self.lstm_forward(self.pe.repeat(output_shape[0], 1, 1) + condition)
        return x.contiguous()
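

# --- Usage sketch ----------------------------------------------------------
# A minimal example of how LstmModel might be wired up. The sinusoidal
# construction of `positional_embedding` and the concrete config values below
# are illustrative assumptions, not requirements taken from this module; the
# model only needs a (seq_len, d_model) embedding and the config keys it reads
# ("d_model", "num_layers", "dropout", "trainable_pe").
if __name__ == "__main__":
    seq_len, d_model, batch_size = 16, 32, 4

    # Standard Transformer-style sinusoidal positional embedding (assumption).
    position = torch.arange(seq_len, dtype=torch.float32)[:, None]
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-math.log(10000.0) / d_model)
    )
    positional_embedding = torch.zeros(seq_len, d_model)
    positional_embedding[:, 0::2] = torch.sin(position * div_term)
    positional_embedding[:, 1::2] = torch.cos(position * div_term)

    # Hypothetical hyperparameters; only these keys are read by LstmModel.
    config = {"d_model": d_model, "num_layers": 2, "dropout": 0.1, "trainable_pe": False}

    model = LstmModel(config, positional_embedding)
    condition = torch.randn(batch_size, seq_len, d_model)
    output = model((batch_size, seq_len, d_model), condition=condition)
    print(output.shape)  # torch.Size([4, 16, 32])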