""" |
Long Short-Term Memory (LSTM) <link https://ieeexplore.ieee.org/abstract/document/6795963 link> is a kind of recurrent neural network that can capture both long-term and short-term dependencies in sequential data.
This document mainly includes: |
- A PyTorch implementation of LSTM.
- An example to test LSTM. |
Beginners can refer to <link https://zhuanlan.zhihu.com/p/32085405 link> to learn the basics of how LSTM works.
""" |
from typing import Optional, Union, Tuple, List, Dict |
import math |
import torch |
import torch.nn as nn |
from ding.torch_utils import build_normalization |
class LSTM(nn.Module):
    """
    **Overview:**
        Implementation of a stacked (multi-layer) LSTM with normalized gate pre-activations. In every layer, the \
        input projection ``x @ W_x`` and the recurrent projection ``h @ W_h`` are each passed through their own \
        normalization module (built by ``build_normalization(norm_type)``, layer norm by default) before being \
        summed with the bias and split into the gates.
    **Interfaces:**
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = 'LN',
            dropout: float = 0.
    ) -> None:
        """
        **Overview:**
            Create the per-layer weights, the bias table and the normalization modules, then initialize them.
        **Arguments:**
            - input_size (:obj:`int`): Feature size of each input time step.
            - hidden_size (:obj:`int`): Size of the hidden state ``h`` and of the cell state ``c``.
            - num_layers (:obj:`int`): Number of stacked LSTM layers.
            - norm_type (:obj:`Optional[str]`): Normalization type forwarded to ``build_normalization``.
            - dropout (:obj:`float`): Dropout probability applied to the output of every layer except the last; \
                ``0.`` disables dropout.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        norm_func = build_normalization(norm_type)
        # Two norms per layer: index 2*l normalizes the input projection, 2*l+1 the recurrent projection.
        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        # Layer l maps dims[l] -> hidden_size; the 4x width packs the i, f, o gates and the candidate z.
        dims = [input_size] + [hidden_size] * num_layers
        for l in range(num_layers):
            self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _before_forward(
            self,
            inputs: torch.Tensor,
            prev_state: Optional[List[Optional[Dict[str, torch.Tensor]]]],
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], List[torch.Tensor]]:
        """
        **Overview:**
            Convert ``prev_state`` into stacked ``(h, c)`` tensors of shape ``(num_layers, batch_size, hidden_size)``.
        **Arguments:**
            - inputs (:obj:`torch.Tensor`): Input sequence; only its batch size (dim 1), dtype and device are used.
            - prev_state (:obj:`Optional[List[Optional[Dict[str, torch.Tensor]]]]`): ``None`` to start every \
                sample from zeros, or a list with one entry per sample. Each entry is either a dict with keys \
                ``'h'`` and ``'c'`` (each of shape ``(num_layers, 1, hidden_size)``), or ``None`` to start that \
                sample from zeros.
        **Returns:**
            - state: The pair ``(h, c)``, each of shape ``(num_layers, batch_size, hidden_size)``.
        """
        batch_size = inputs.shape[1]
        if prev_state is None:
            zeros = torch.zeros(self.num_layers, batch_size, self.hidden_size, dtype=inputs.dtype, device=inputs.device)
            return zeros, zeros
        assert len(prev_state) == batch_size, \
            f'prev_state must have one entry per sample, got {len(prev_state)} entries for batch size {batch_size}'
        # Zero state for a single sample, substituted for ``None`` entries.
        sample_zeros = torch.zeros(self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device)
        h_list, c_list = [], []
        for prev in prev_state:
            if prev is None:
                h_list.append(sample_zeros)
                c_list.append(sample_zeros)
            else:
                # Look values up by key: relying on dict insertion order would silently swap h and c.
                h_list.append(prev['h'])
                c_list.append(prev['c'])
        # Per-sample states carry a batch axis of size 1 (dim 1); concatenate them back into a batch.
        return [torch.cat(h_list, dim=1), torch.cat(c_list, dim=1)]

    def _init(self) -> None:
        """
        **Overview:**
            Initialize every weight matrix and bias row from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).
        """
        gain = math.sqrt(1. / self.hidden_size)
        for l in range(self.num_layers):
            torch.nn.init.uniform_(self.wx[l], -gain, gain)
            torch.nn.init.uniform_(self.wh[l], -gain, gain)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias[l], -gain, gain)

    def forward(
            self,
            inputs: torch.Tensor,
            prev_state: Optional[List[Optional[Dict[str, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
        """
        **Overview:**
            Run the stacked LSTM over the whole input sequence, one time step at a time, layer by layer.
        **Arguments:**
            - inputs (:obj:`torch.Tensor`): Input sequence of shape ``(seq_len, batch_size, input_size)``.
            - prev_state (:obj:`Optional[List[Optional[Dict[str, torch.Tensor]]]]`): Initial state. ``None`` \
                means all zeros; otherwise a list with one entry per sample, each a dict with keys ``'h'`` and \
                ``'c'`` (shape ``(num_layers, 1, hidden_size)``) or ``None`` for a zero state.
        **Returns:**
            - output (:obj:`torch.Tensor`): Hidden states of the last layer, shape \
                ``(seq_len, batch_size, hidden_size)``.
            - next_state (:obj:`List[Dict[str, torch.Tensor]]`): One dict per sample with keys ``'h'`` and ``'c'``, \
                each of shape ``(num_layers, 1, hidden_size)``; it can be passed back as ``prev_state``.
        """
        seq_len = inputs.shape[0]
        H, C = self._before_forward(inputs, prev_state)
        x = inputs
        layer_states = []
        for l in range(self.num_layers):
            h, c = H[l], C[l]
            outputs = []
            for s in range(seq_len):
                gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])) + \
                    self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
                if self.bias is not None:
                    gate = gate + self.bias[l]
                # Split the 4 * hidden_size pre-activation into input, forget, output gates and candidate.
                i, f, o, z = torch.chunk(gate, 4, dim=1)
                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                o = torch.sigmoid(o)
                z = torch.tanh(z)
                c = f * c + i * z
                h = o * torch.tanh(c)
                outputs.append(h)
            # Keep this layer's final (h, c) as its next state.
            layer_states.append((h, c))
            # The whole output sequence of this layer becomes the input of the next one.
            x = torch.stack(outputs, dim=0)
            if self.use_dropout and l != self.num_layers - 1:
                x = self.dropout(x)
        return x, self._after_forward(layer_states)

    @staticmethod
    def _after_forward(layer_states: List[Tuple[torch.Tensor, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
        """
        **Overview:**
            Convert the final ``(h, c)`` of every layer into one state dict per sample.
        **Arguments:**
            - layer_states (:obj:`List[Tuple[torch.Tensor, torch.Tensor]]`): ``num_layers`` pairs ``(h, c)``, each \
                tensor of shape ``(batch_size, hidden_size)``.
        **Returns:**
            - state (:obj:`List[Dict[str, torch.Tensor]]`): ``batch_size`` dicts with keys ``'h'`` and ``'c'``, \
                each of shape ``(num_layers, 1, hidden_size)``.
        """
        h = torch.stack([pair[0] for pair in layer_states], dim=0)
        c = torch.stack([pair[1] for pair in layer_states], dim=0)
        # Slice with b:b + 1 (not b) to keep the batch axis, the format ``_before_forward`` concatenates on.
        return [{'h': h[:, b:b + 1], 'c': c[:, b:b + 1]} for b in range(h.shape[1])]
def pack_data(data: List[torch.Tensor], traj_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Pack variable-length sequences into a regular tensor and return it together with the corresponding mask.
        A sequence shorter than ``traj_len`` is zero-padded at the end (via ``null_padding``) and the padded steps
        are masked out. A sequence of at least ``traj_len`` steps is split into consecutive windows of ``traj_len``
        steps; if the last window would be shorter, it is replaced by the last ``traj_len`` steps of the sequence
        (so it overlaps the previous window). Every step of such windows is marked valid.
    Arguments:
        - data (:obj:`List[torch.Tensor]`): Non-empty list of sequences, each of shape (D_i, N) with the same N.
        - traj_len (:obj:`int`): Length of every packed trajectory.
    Returns:
        - tensor (:obj:`torch.Tensor`): shape (traj_len, B, N), same dtype as the input sequences
        - mask (:obj:`torch.Tensor`): dtype (torch.float32), shape (traj_len, B); 1 for real steps, 0 for padding
    """
    new_data = []
    mask = []
    for item in data:
        D, N = item.shape
        if D < traj_len:
            # Match the item's dtype and device so the concatenation below never mixes them.
            null_padding = item.new_zeros(traj_len - D, N)
            new_data.append(torch.cat([item, null_padding]))
            item_mask = torch.zeros(traj_len, dtype=torch.float32, device=item.device)
            item_mask[:D] = 1.
            mask.append(item_mask)
        else:
            for start in range(0, D, traj_len):
                window = item[start:start + traj_len]
                if window.shape[0] < traj_len:
                    # Take the tail instead of padding, so no step in this window is fake.
                    window = item[-traj_len:]
                new_data.append(window)
                mask.append(torch.ones(traj_len, dtype=torch.float32, device=item.device))
    # Trajectories are stacked along dim 1, giving the (time, batch, ...) layout.
    return torch.stack(new_data, dim=1), torch.stack(mask, dim=1)
def test_lstm():
    """
    Overview:
        Smoke test for ``LSTM``: pack variable-length sequences with ``pack_data``, feed them through the network
        one time step at a time while carrying the recurrent state, check output/state shapes, and check that a
        masked loss back-propagates into every parameter.
    """
    seq_len_list = [32, 49, 24, 78, 45]
    traj_len = 32
    N = 10
    hidden_size = 32
    num_layers = 2
    variable_len_data = [torch.rand(s, N) for s in seq_len_list]
    input_, mask = pack_data(variable_len_data, traj_len)
    assert isinstance(input_, torch.Tensor), type(input_)
    batch_size = input_.shape[1]
    # 32 -> 1, 49 -> 2, 24 -> 1 (padded), 78 -> 3, 45 -> 2 trajectories.
    assert batch_size == 9, "packed data must have 9 trajectories"
    lstm = LSTM(N, hidden_size=hidden_size, num_layers=num_layers, norm_type='LN', dropout=0.1)
    prev_state = None
    step_outputs = []
    for s in range(traj_len):
        input_step = input_[s:s + 1]
        output, prev_state = lstm(input_step, prev_state)
        assert output.shape == (1, batch_size, hidden_size)
        assert len(prev_state) == batch_size
        assert prev_state[0]['h'].shape == (num_layers, 1, hidden_size)
        assert prev_state[0]['c'].shape == (num_layers, 1, hidden_size)
        step_outputs.append(output)
    # Weight each step's output by that step's own mask entry, so padded steps do not contribute to the loss.
    outputs = torch.cat(step_outputs, dim=0)
    loss = (outputs * mask.unsqueeze(-1)).mean()
    loss.backward()
    for name, param in lstm.named_parameters():
        assert isinstance(param.grad, torch.Tensor), name
    print('finished')
if __name__ == '__main__':
    # Run the self-test when this file is executed directly as a script.
    test_lstm()