|
"""Sequential implementation of Recurrent Neural Network Language Model.""" |
|
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn
from typeguard import check_argument_types

from espnet2.lm.abs_model import AbsLM
|
|
|
|
|
class SequentialRNNLM(AbsLM):
    """Sequential RNNLM.

    The model is an embedding layer, a multi-layer recurrent network
    (LSTM, GRU or vanilla RNN) and a linear projection to vocabulary
    logits, with dropout applied to the embeddings and to the RNN outputs.

    See also:
        https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py

    """

    def __init__(
        self,
        vocab_size: int,
        unit: int = 650,
        nhid: Optional[int] = None,
        nlayers: int = 2,
        dropout_rate: float = 0.0,
        tie_weights: bool = False,
        rnn_type: str = "lstm",
        ignore_id: int = 0,
    ):
        """Build the embedding, recurrent and output layers.

        Args:
            vocab_size: Number of tokens; size of the embedding table and of
                the output projection.
            unit: Embedding dimension (input size of the RNN).
            nhid: Hidden size of the RNN. Defaults to ``unit`` when ``None``.
            nlayers: Number of stacked recurrent layers.
            dropout_rate: Dropout probability used on the embeddings, between
                RNN layers, and on the RNN outputs.
            tie_weights: If True, share the embedding matrix with the output
                projection. Requires ``nhid == unit``.
            rnn_type: One of "lstm", "gru", "rnn_tanh" or "rnn_relu"
                (case-insensitive).
            ignore_id: Token id passed to ``nn.Embedding`` as ``padding_idx``.

        Raises:
            ValueError: If ``rnn_type`` is not a supported option, or if
                ``tie_weights`` is set while ``nhid != unit``.

        """
        assert check_argument_types()
        super().__init__()

        ninp = unit
        if nhid is None:
            nhid = unit
        rnn_kind = rnn_type.upper()
        # Vanilla-RNN variants map onto nn.RNN's ``nonlinearity`` argument.
        nonlinearities = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}

        self.drop = nn.Dropout(dropout_rate)
        self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
        if rnn_kind in ("LSTM", "GRU"):
            rnn_class = getattr(nn, rnn_kind)
            self.rnn = rnn_class(
                ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
            )
        elif rnn_kind in nonlinearities:
            self.rnn = nn.RNN(
                ninp,
                nhid,
                nlayers,
                nonlinearity=nonlinearities[rnn_kind],
                dropout=dropout_rate,
                batch_first=True,
            )
        else:
            raise ValueError(
                f"An invalid option for `rnn_type` was supplied: {rnn_type!r}; "
                "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU'] "
                "(case-insensitive)"
            )
        self.decoder = nn.Linear(nhid, vocab_size)

        # Optional weight tying: the output projection reuses the embedding
        # matrix. decoder.weight is (vocab_size, nhid) and encoder.weight is
        # (vocab_size, ninp), so the two sizes must match.
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using tie_weights, nhid must be equal to unit "
                    f"(got nhid={nhid}, unit={ninp})"
                )
            self.decoder.weight = self.encoder.weight

        self.rnn_type = rnn_kind
        self.nhid = nhid
        self.nlayers = nlayers

    def forward(
        self,
        input: torch.Tensor,
        hidden: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Compute vocabulary logits for every position of the input.

        Args:
            input: Token ids, indexed as (batch, length) since the RNN is
                built with ``batch_first=True``.
            hidden: Recurrent state handed to the RNN as-is: a tensor for
                GRU/RNN, an ``(h, c)`` tuple for LSTM, or ``None`` to let the
                RNN start from its default initial state.

        Returns:
            Tuple of the logits with shape (batch, length, vocab_size) and
            the recurrent state returned by the RNN.

        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)

        # Flatten (batch, length) into one axis for the projection, then
        # restore the two leading axes on the logits.
        n_batch, length, n_feat = output.size(0), output.size(1), output.size(2)
        decoded = self.decoder(output.contiguous().view(n_batch * length, n_feat))
        return decoded.view(n_batch, length, decoded.size(1)), hidden

    def score(
        self,
        y: torch.Tensor,
        state: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Score new token.

        Only the last token of ``y`` is fed to the model; the earlier prefix
        is expected to be summarised by ``state``.

        Args:
            y: 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x: 2D encoder feature that generates ys (not used by this model).

        Returns:
            Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        # Feed the most recent token as a (batch=1, length=1) input.
        logits, new_state = self(y[-1].view(1, 1), state)
        logp = logits.log_softmax(dim=-1).view(-1)
        return logp, new_state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Not used by this model.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        is_lstm = isinstance(self.rnn, torch.nn.LSTM)

        # Merge the per-hypothesis states into the layout the RNN expects.
        # Only the first entry is inspected to detect the "no state yet" case.
        if states[0] is None:
            rnn_state = None
        elif is_lstm:
            # states: Batch x 2 x (Nlayers, Dim) -> 2 x (Nlayers, Batch, Dim)
            rnn_state = (
                torch.stack([h for h, _ in states], dim=1),
                torch.stack([c for _, c in states], dim=1),
            )
        else:
            # states: Batch x (Nlayers, Dim) -> (Nlayers, Batch, Dim)
            rnn_state = torch.stack(states, dim=1)

        # Only the last token of each prefix is fed to the model.
        logits, rnn_state = self(ys[:, -1:], rnn_state)
        # logits: (Batch, 1, Nvocab) -> (Batch, Nvocab)
        assert logits.size(1) == 1, logits.shape
        logp = logits.squeeze(1).log_softmax(dim=-1)

        # Split the batched state back into one entry per hypothesis.
        if is_lstm:
            # h, c: (Nlayers, Batch, Dim)
            h, c = rnn_state
            # next_states: Batch x 2 x (Nlayers, Dim)
            next_states = [(h[:, i], c[:, i]) for i in range(h.size(1))]
        else:
            # rnn_state: (Nlayers, Batch, Dim) -> Batch x (Nlayers, Dim)
            next_states = [rnn_state[:, i] for i in range(rnn_state.size(1))]

        return logp, next_states
|
|