Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.


1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.


Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our Encoder,
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::

  import torch.nn as nn
  from fairseq import utils
  from fairseq.models import FairseqEncoder

  class SimpleLSTMEncoder(FairseqEncoder):

      def __init__(
          self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
      ):
          super().__init__(dictionary)
          self.args = args

          # Our encoder will embed the inputs before feeding them to the LSTM.
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)

          # We'll use a single-layer, unidirectional LSTM for simplicity.
          self.lstm = nn.LSTM(
              input_size=embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
              batch_first=True,
          )

      def forward(self, src_tokens, src_lengths):
          # The inputs to the ``forward()`` function are determined by the
          # Task, and in particular the ``'net_input'`` key in each
          # mini-batch. We discuss Tasks in the next tutorial, but for now just
          # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
          # has shape `(batch)`.

          # Note that the source is typically padded on the left. This can be
          # configured by adding the `--left-pad-source "False"` command-line
          # argument, but here we'll make the Encoder handle either kind of
          # padding by converting everything to be right-padded.
          if self.args.left_pad_source:
              # Convert left-padding to right-padding.
              src_tokens = utils.convert_padding_direction(
                  src_tokens,
                  padding_idx=self.dictionary.pad(),
                  left_to_right=True
              )

          # Embed the source.
          x = self.embed_tokens(src_tokens)

          # Apply dropout.
          x = self.dropout(x)

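          # Pack the sequence into a PackedSequence object to feed to the LSTM.
          # Recent PyTorch versions require the lengths to be on the CPU. By
          # default the batch must also be sorted by decreasing length, which
          # fairseq's translation datasets ensure by sorting each batch by
          # source length.
          x = nn.utils.rnn.pack_padded_sequence(
              x, src_lengths.cpu(), batch_first=True,
          )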

          # Get the output from the LSTM.
          _outputs, (final_hidden, _final_cell) = self.lstm(x)

          # Return the Encoder's output. This can be any object and will be
          # passed directly to the Decoder.
          return {
              # this will have shape `(bsz, hidden_dim)`
              'final_hidden': final_hidden.squeeze(0),
          }

      # Encoders are required to implement this method so that we can rearrange
      # the order of the batch elements during inference (e.g., beam search).
      def reorder_encoder_out(self, encoder_out, new_order):
          """
          Reorder encoder output according to `new_order`.

          Args:
              encoder_out: output from the ``forward()`` method
              new_order (LongTensor): desired order

          Returns:
              `encoder_out` rearranged according to `new_order`
          """
          final_hidden = encoder_out['final_hidden']
          return {
              'final_hidden': final_hidden.index_select(0, new_order),
          }
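
The padding conversion at the top of ``forward()`` can be illustrated without
tensors. The following is a plain-Python sketch over lists of token ids:
``left_to_right_padding()`` is a hypothetical helper that mimics, for lists,
what we expect ``utils.convert_padding_direction(..., left_to_right=True)`` to
do for tensors, and the padding index of 1 is an assumption made for the
example. It also checks that the resulting lengths are in decreasing order, as
``pack_padded_sequence()`` expects by default:

.. code-block:: python

  PAD = 1  # assumed padding index for this example


  def left_to_right_padding(batch, pad=PAD):
      """Hypothetical helper: move leading pad tokens to the end of each row."""
      converted = []
      for row in batch:
          # Count the pad tokens at the start of the row.
          n_pad = 0
          while n_pad < len(row) and row[n_pad] == pad:
              n_pad += 1
          converted.append(row[n_pad:] + [pad] * n_pad)
      return converted


  def lengths(batch, pad=PAD):
      """Number of non-pad tokens in each row."""
      return [sum(1 for tok in row if tok != pad) for row in batch]


  # Made-up token ids; the trailing 2 plays the role of end-of-sentence.
  left_padded = [
      [5, 6, 7, 2],
      [1, 8, 9, 2],
      [1, 1, 4, 2],
  ]
  right_padded = left_to_right_padding(left_padded)
  print(right_padded)  # [[5, 6, 7, 2], [8, 9, 2, 1], [4, 2, 1, 1]]

  src_lengths = lengths(right_padded)
  print(src_lengths)  # [4, 3, 2]

  # pack_padded_sequence() expects lengths sorted in decreasing order by default.
  assert src_lengths == sorted(src_lengths, reverse=True)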


Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word.
During training the previous word is taken from the reference translation
rather than from the Decoder's own predictions, which is called *teacher
forcing*. More specifically, we'll use a
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
to the size of the output vocabulary to predict each target word.
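
To see what the Decoder's input looks like, here is a plain-Python sketch of
the shift-right construction of *prev_output_tokens* described in the
Decoder's comments. It assumes the target ends with the end-of-sentence
symbol, which is moved to the front; ``make_prev_output_tokens()`` is a
hypothetical helper for illustration, not a fairseq function:

.. code-block:: python

  EOS = '</s>'  # stands in for dictionary.eos()


  def make_prev_output_tokens(target, eos=EOS):
      """Hypothetical helper: move the trailing EOS to the front."""
      assert target[-1] == eos, 'target is assumed to end with EOS'
      return [eos] + target[:-1]


  target = ['hello', 'world', EOS]
  prev_output_tokens = make_prev_output_tokens(target)
  print(prev_output_tokens)  # ['</s>', 'hello', 'world']

  # Teacher forcing: at step t the Decoder is fed prev_output_tokens[t] (the
  # reference word before position t) and is trained to predict target[t].
  for fed, predicted in zip(prev_output_tokens, target):
      print(fed, '->', predicted)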

::

  import torch
  import torch.nn as nn
  from fairseq.models import FairseqDecoder

  class SimpleLSTMDecoder(FairseqDecoder):

      def __init__(
          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
          dropout=0.1,
      ):
          super().__init__(dictionary)

          # Our decoder will embed the inputs before feeding them to the LSTM.
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)

          # We'll use a single-layer, unidirectional LSTM for simplicity.
          self.lstm = nn.LSTM(
              # For the first layer we'll concatenate the Encoder's final hidden
              # state with the embedded target tokens.
              input_size=encoder_hidden_dim + embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )

          # Define the output projection.
          self.output_projection = nn.Linear(hidden_dim, len(dictionary))

      # During training Decoders are expected to take the entire target sequence
      # (shifted right by one position) and produce logits over the vocabulary.
      # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
      # ``dictionary.eos()``, followed by the target sequence.
      def forward(self, prev_output_tokens, encoder_out):
          """
          Args:
              prev_output_tokens (LongTensor): previous decoder outputs of shape
                  `(batch, tgt_len)`, for teacher forcing
              encoder_out (dict): output from the encoder; only its
                  ``'final_hidden'`` entry, of shape `(batch, encoder_hidden_dim)`,
                  is used

          Returns:
              tuple:
                  - the last decoder layer's output of shape
                    `(batch, tgt_len, vocab)`
                  - ``None``, since this decoder computes no attention weights
          """
          bsz, tgt_len = prev_output_tokens.size()

          # Extract the final hidden state from the Encoder.
          final_encoder_hidden = encoder_out['final_hidden']

          # Embed the target sequence, which has been shifted right by one
          # position and now starts with the end-of-sentence symbol.
          x = self.embed_tokens(prev_output_tokens)

          # Apply dropout.
          x = self.dropout(x)

          # Concatenate the Encoder's final hidden state to *every* embedded
          # target token.
          x = torch.cat(
              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
              dim=2,
          )

          # Using PackedSequence objects in the Decoder is harder than in the
          # Encoder, since the targets are not sorted in descending length order,
          # which is a requirement of ``pack_padded_sequence()``. Instead we'll
          # feed nn.LSTM directly.
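          # The Encoder's final hidden state initializes the Decoder's hidden
          # state, which requires encoder_hidden_dim == hidden_dim; the cell
          # state starts at zero.
          initial_state = (
              final_encoder_hidden.unsqueeze(0),  # hidden
              torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
          )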
          output, _ = self.lstm(
              x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
              initial_state,
          )
          x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

          # Project the outputs to the size of the vocabulary.
          x = self.output_projection(x)

          # Return the logits and ``None`` for the attention weights
          return x, None
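
The ``torch.cat()`` and ``expand()`` step above repeats the Encoder's final
hidden state at every target position. The same bookkeeping on nested Python
lists (a sketch only, with made-up sizes) shows why the Decoder LSTM's
``input_size`` must be ``encoder_hidden_dim + embed_dim``:

.. code-block:: python

  def concat_encoder_state(embedded, final_hidden):
      """Append each sentence's encoder state to every one of its target embeddings.

      embedded:     bsz x tgt_len x embed_dim (nested lists)
      final_hidden: bsz x encoder_hidden_dim
      returns:      bsz x tgt_len x (embed_dim + encoder_hidden_dim)
      """
      return [
          [step + final_hidden[b] for step in sentence]
          for b, sentence in enumerate(embedded)
      ]


  bsz, tgt_len, embed_dim, encoder_hidden_dim = 2, 3, 4, 5
  embedded = [[[0.0] * embed_dim for _ in range(tgt_len)] for _ in range(bsz)]
  final_hidden = [[float(b)] * encoder_hidden_dim for b in range(bsz)]

  x = concat_encoder_state(embedded, final_hidden)
  print(len(x), len(x[0]), len(x[0][0]))  # 2 3 9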


2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model with
fairseq using the :func:`~fairseq.models.register_model` function decorator.
Once the model is registered we'll be able to use it with the existing
:ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

  from fairseq.models import FairseqEncoderDecoderModel, register_model

  # Note: the register_model "decorator" should immediately precede the
  # definition of the Model class.

  @register_model('simple_lstm')
  class SimpleLSTMModel(FairseqEncoderDecoderModel):

      @staticmethod
      def add_args(parser):
          # Models can override this method to add new command-line arguments.
          # Here we'll add some new command-line arguments to configure dropout
          # and the dimensionality of the embeddings and hidden states.
          parser.add_argument(
              '--encoder-embed-dim', type=int, metavar='N',
              help='dimensionality of the encoder embeddings',
          )
          parser.add_argument(
              '--encoder-hidden-dim', type=int, metavar='N',
              help='dimensionality of the encoder hidden state',
          )
          parser.add_argument(
              '--encoder-dropout', type=float, default=0.1,
              help='encoder dropout probability',
          )
          parser.add_argument(
              '--decoder-embed-dim', type=int, metavar='N',
              help='dimensionality of the decoder embeddings',
          )
          parser.add_argument(
              '--decoder-hidden-dim', type=int, metavar='N',
              help='dimensionality of the decoder hidden state',
          )
          parser.add_argument(
              '--decoder-dropout', type=float, default=0.1,
              help='decoder dropout probability',
          )

      @classmethod
      def build_model(cls, args, task):
          # Fairseq initializes models by calling the ``build_model()``
          # function. This provides more flexibility, since the returned model
          # instance can be of a different type than the one that was called.
          # In this case we'll just return a SimpleLSTMModel instance.

          # Initialize our Encoder and Decoder.
          encoder = SimpleLSTMEncoder(
              args=args,
              dictionary=task.source_dictionary,
              embed_dim=args.encoder_embed_dim,
              hidden_dim=args.encoder_hidden_dim,
              dropout=args.encoder_dropout,
          )
          decoder = SimpleLSTMDecoder(
              dictionary=task.target_dictionary,
              encoder_hidden_dim=args.encoder_hidden_dim,
              embed_dim=args.decoder_embed_dim,
              hidden_dim=args.decoder_hidden_dim,
              dropout=args.decoder_dropout,
          )
          model = SimpleLSTMModel(encoder, decoder)

          # Print the model architecture.
          print(model)

          return model

      # We could override the ``forward()`` if we wanted more control over how
      # the encoder and decoder interact, but it's not necessary for this
      # tutorial since we can inherit the default implementation provided by
      # the FairseqEncoderDecoderModel base class, which looks like:
      #
      # def forward(self, src_tokens, src_lengths, prev_output_tokens):
      #     encoder_out = self.encoder(src_tokens, src_lengths)
      #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
      #     return decoder_out
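
Conceptually, ``@register_model('simple_lstm')`` records the class under a name
so that it can be looked up later from the command line. The sketch below is a
simplified stand-in for that registry pattern, not fairseq's actual
implementation; ``MODEL_REGISTRY``, ``ToyModel`` and the duplicate-name check
are illustrative assumptions:

.. code-block:: python

  MODEL_REGISTRY = {}  # name -> model class (illustrative)


  def register_model(name):
      """Return a class decorator that records the class under *name*."""
      def register_model_cls(cls):
          if name in MODEL_REGISTRY:
              raise ValueError(f'Cannot register duplicate model ({name})')
          MODEL_REGISTRY[name] = cls
          return cls  # the class itself is returned unchanged
      return register_model_cls


  @register_model('toy_model')
  class ToyModel:
      @classmethod
      def build_model(cls, args, task):
          return cls()


  # Later, a training script can go from a name to a model instance.
  model = MODEL_REGISTRY['toy_model'].build_model(args=None, task=None)
  print(type(model).__name__)  # ToyModel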

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

  from fairseq.models import register_model_architecture

  # The first argument to ``register_model_architecture()`` should be the name
  # of the model we registered above (i.e., 'simple_lstm'). The function we
  # register here should take a single argument *args* and modify it in-place
  # to match the desired architecture.

  @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
  def tutorial_simple_lstm(args):
      # We use ``getattr()`` to prioritize arguments that are explicitly given
      # on the command-line, so that the defaults defined below are only used
      # when no other value has been specified.
      args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
      args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
      args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
      args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
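
The ``getattr()`` pattern only works if options that were not passed on the
command line are missing from *args*, rather than present with a value of
``None``. As far as we know, fairseq arranges this by giving model-specific
options a default of ``argparse.SUPPRESS``; the standalone sketch below
reproduces that assumption with a plain ``argparse.ArgumentParser`` and does
not use fairseq:

.. code-block:: python

  import argparse


  def tutorial_simple_lstm(args):
      # Same pattern as above: keep values that are already on *args*.
      args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
      args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
      args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
      args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)


  # With argument_default=argparse.SUPPRESS, options that are not given on the
  # command line are left off the namespace entirely.
  parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
  for option in ('--encoder-embed-dim', '--encoder-hidden-dim',
                 '--decoder-embed-dim', '--decoder-hidden-dim'):
      parser.add_argument(option, type=int)

  args = parser.parse_args(['--encoder-embed-dim', '512'])
  tutorial_simple_lstm(args)
  print(args.encoder_embed_dim, args.decoder_embed_dim)  # 512 256

  # Without SUPPRESS, a missing option is stored as None, so getattr() would
  # return None instead of falling back to the architecture default.
  plain = argparse.ArgumentParser()
  plain.add_argument('--decoder-embed-dim', type=int)
  ns = plain.parse_args([])
  print(getattr(ns, 'decoder_embed_dim', 256))  # None

An explicit ``default=`` passed to ``add_argument()`` (as on
``--encoder-dropout``) takes precedence over the parser-wide default, which is
why the dropout options need no entry in the architecture function.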


3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

  Make sure you've already preprocessed the data from the IWSLT example in the
  :file:`examples/translation/` directory.

.. code-block:: console

  > fairseq-train data-bin/iwslt14.tokenized.de-en \
    --arch tutorial_simple_lstm \
    --encoder-dropout 0.2 --decoder-dropout 0.2 \
    --optimizer adam --lr 0.005 --lr-shrink 0.5 \
    --max-tokens 12000
  (...)
  | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
  | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
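
The logged perplexities are consistent with the loss being reported in base 2,
i.e., ``ppl = 2 ** loss``. A quick check against the numbers above:

.. code-block:: python

  # Loss values copied from the training log above.
  train_loss = 4.027
  valid_loss = 4.74989

  # Perplexity is 2 raised to the (base-2) loss.
  print(f'{2 ** train_loss:.2f}')  # 16.30
  print(f'{2 ** valid_loss:.2f}')  # 26.91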

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
generate translations and compute our BLEU score over the test set:

.. code-block:: console

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
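
BLEU4 is the brevity penalty (BP) times the geometric mean of the four n-gram
precisions printed after it. Recomputing from the precisions in the log gives a
value close to, but not exactly, the reported 8.18, because the logged
precisions are rounded:

.. code-block:: python

  import math

  precisions = [38.8, 12.1, 4.7, 2.0]  # 1- to 4-gram precisions (%) from the log
  brevity_penalty = 1.0                # BP=1.000 in the log

  # Geometric mean of the precisions, computed in log space.
  bleu = brevity_penalty * math.exp(sum(math.log(p) for p in precisions) / 4)
  print(f'{bleu:.2f}')  # 8.15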


4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)``). We can make this significantly faster by instead caching the
previous hidden states.
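
The difference in cost can be made concrete by counting recurrent steps. In the
sketch below, ``step()`` is a hypothetical stand-in for one LSTM time step.
Generating ``n`` tokens by re-running the Decoder over the whole prefix takes
``n * (n + 1) / 2`` steps, while carrying the previous state forward takes only
``n``:

.. code-block:: python

  calls = 0


  def step(state, token):
      """Hypothetical stand-in for one LSTM time step; counts its invocations."""
      global calls
      calls += 1
      return state + token


  def generate_recompute(tokens):
      # Re-run the recurrence over the whole prefix for every new output token.
      for t in range(1, len(tokens) + 1):
          state = 0
          for tok in tokens[:t]:
              state = step(state, tok)


  def generate_cached(tokens):
      # Keep the previous state and advance it by one step per new token.
      state = 0
      for tok in tokens:
          state = step(state, tok)


  n = 20
  tokens = list(range(n))

  calls = 0
  generate_recompute(tokens)
  recompute_calls = calls

  calls = 0
  generate_cached(tokens)
  cached_calls = calls

  print(recompute_calls, cached_calls)  # 210 20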

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
special mode at inference time where the Model only receives a single timestep
of input, corresponding to the previously generated output token, and must
produce the next output incrementally. Thus the model must cache any long-term
state that is needed about the sequence, e.g., hidden states, convolutional
states, etc.

To implement incremental decoding we will modify our Decoder to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.
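
The essential contract is that decoding one token at a time with a cache must
give the same results as decoding the whole prefix at once. The following toy
sketch checks that property in plain Python, with a made-up linear recurrence
in place of the LSTM and an ordinary ``dict`` standing in for
*incremental_state*:

.. code-block:: python

  def step(h, x):
      """Made-up recurrence standing in for one LSTM step."""
      return 0.5 * h + x


  def decode_full(prefix):
      # Non-incremental mode: recompute every hidden state from scratch.
      h, outputs = 0.0, []
      for x in prefix:
          h = step(h, x)
          outputs.append(h)
      return outputs


  def decode_incremental(prefix, incremental_state):
      # Incremental mode: only the last token is new; earlier work is cached.
      h = incremental_state.get('prev_state', 0.0)
      h = step(h, prefix[-1])
      incremental_state['prev_state'] = h
      return [h]


  tokens = [1.0, 2.0, 3.0, 4.0]
  incremental_state = {}
  outputs = []
  for t in range(1, len(tokens) + 1):
      # The whole prefix is passed in, but only its last element is used.
      outputs += decode_incremental(tokens[:t], incremental_state)

  print(outputs)              # [1.0, 2.5, 4.25, 6.125]
  print(decode_full(tokens))  # [1.0, 2.5, 4.25, 6.125]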

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

  import torch
  import torch.nn as nn
  from fairseq import utils
  from fairseq.models import FairseqIncrementalDecoder

  class SimpleLSTMDecoder(FairseqIncrementalDecoder):

      def __init__(
          self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
          dropout=0.1,
      ):
          # This remains the same as before.
          super().__init__(dictionary)
          self.embed_tokens = nn.Embedding(
              num_embeddings=len(dictionary),
              embedding_dim=embed_dim,
              padding_idx=dictionary.pad(),
          )
          self.dropout = nn.Dropout(p=dropout)
          self.lstm = nn.LSTM(
              input_size=encoder_hidden_dim + embed_dim,
              hidden_size=hidden_dim,
              num_layers=1,
              bidirectional=False,
          )
          self.output_projection = nn.Linear(hidden_dim, len(dictionary))

      # We now take an additional kwarg (*incremental_state*) for caching the
      # previous hidden and cell states.
      def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
          if incremental_state is not None:
              # If the *incremental_state* argument is not ``None`` then we are
              # in incremental inference mode. While *prev_output_tokens* will
              # still contain the entire decoded prefix, we will only use the
              # last step and assume that the rest of the state is cached.
              prev_output_tokens = prev_output_tokens[:, -1:]

          # This remains the same as before.
          bsz, tgt_len = prev_output_tokens.size()
          final_encoder_hidden = encoder_out['final_hidden']
          x = self.embed_tokens(prev_output_tokens)
          x = self.dropout(x)
          x = torch.cat(
              [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
              dim=2,
          )

          # We will now check the cache and load the cached previous hidden and
          # cell states, if they exist. Otherwise we will initialize them from
          # the Encoder's final hidden state and zeros, respectively (as
          # before). We will use the ``utils.get_incremental_state()`` and
          # ``utils.set_incremental_state()`` helpers.
          initial_state = utils.get_incremental_state(
              self, incremental_state, 'prev_state',
          )
          if initial_state is None:
              # first time initialization, same as the original version
              initial_state = (
                  final_encoder_hidden.unsqueeze(0),  # hidden
                  torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
              )

          # Run the LSTM (a single time step when decoding incrementally).
          output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

          # Update the cache with the latest hidden and cell states.
          utils.set_incremental_state(
              self, incremental_state, 'prev_state', latest_state,
          )

          # This remains the same as before.
          x = output.transpose(0, 1)
          x = self.output_projection(x)
          return x, None

      # The ``FairseqIncrementalDecoder`` interface also requires implementing a
      # ``reorder_incremental_state()`` method, which is used during beam search
      # to select and reorder the incremental state.
      def reorder_incremental_state(self, incremental_state, new_order):
          # Load the cached state.
          prev_state = utils.get_incremental_state(
              self, incremental_state, 'prev_state',
          )

          # Reorder batches according to *new_order*.
          reordered_state = (
              prev_state[0].index_select(1, new_order),  # hidden
              prev_state[1].index_select(1, new_order),  # cell
          )

          # Update the cached state.
          utils.set_incremental_state(
              self, incremental_state, 'prev_state', reordered_state,
          )
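
During beam search, *new_order* says which existing hypothesis each new beam
slot continues from, and ``index_select()`` copies the cached state
accordingly. The hidden and cell tensors are reordered along dimension 1
because :class:`torch.nn.LSTM` states have shape
`(num_layers, batch, hidden)`, whereas ``reorder_encoder_out()`` uses dimension
0. Here is the same selection on a plain list of per-hypothesis states (a
sketch; the states and the order are made up):

.. code-block:: python

  def reorder(states, new_order):
      """List equivalent of ``index_select`` along the batch dimension."""
      return [states[i] for i in new_order]


  # One cached (hidden, cell) pair per hypothesis in the beam (made-up values).
  prev_state = [('h_a', 'c_a'), ('h_b', 'c_b'), ('h_c', 'c_c')]

  # Suppose hypothesis 2 produced two of the surviving candidates and
  # hypothesis 0 produced the third; hypothesis 1 was pruned.
  new_order = [2, 2, 0]

  print(reorder(prev_state, new_order))
  # [('h_c', 'c_c'), ('h_c', 'c_c'), ('h_a', 'c_a')]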

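Finally, we can rerun generation and observe the speedup. Translating the test
set now takes 5.5s instead of 17.3s (about 3x faster), while the BLEU score is
unchanged: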

.. code-block:: console

  # Before

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

  # After

  > fairseq-generate data-bin/iwslt14.tokenized.de-en \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 \
    --remove-bpe
  (...)
  | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
  | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)