osanseviero commited on
Commit
fc67275
·
1 Parent(s): 07d2bce
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CODE_OF_CONDUCT.md +77 -0
  2. CONTRIBUTING.md +28 -0
  3. LICENSE +21 -0
  4. audio1.mp3 +0 -0
  5. audio2.mp3 +0 -0
  6. docs/Makefile +20 -0
  7. docs/_static/theme_overrides.css +9 -0
  8. docs/command_line_tools.rst +85 -0
  9. docs/conf.py +134 -0
  10. docs/criterions.rst +31 -0
  11. docs/data.rst +58 -0
  12. docs/docutils.conf +2 -0
  13. docs/fairseq.gif +0 -0
  14. docs/fairseq_logo.png +0 -0
  15. docs/getting_started.rst +216 -0
  16. docs/hydra_integration.md +284 -0
  17. docs/index.rst +49 -0
  18. docs/lr_scheduler.rst +34 -0
  19. docs/make.bat +36 -0
  20. docs/models.rst +104 -0
  21. docs/modules.rst +9 -0
  22. docs/optim.rst +38 -0
  23. docs/overview.rst +74 -0
  24. docs/requirements.txt +2 -0
  25. docs/tasks.rst +61 -0
  26. docs/tutorial_classifying_names.rst +415 -0
  27. docs/tutorial_simple_lstm.rst +518 -0
  28. examples/.gitignore +2 -0
  29. examples/__init__.py +9 -0
  30. examples/adaptive_span/README.md +90 -0
  31. examples/adaptive_span/__init__.py +19 -0
  32. examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
  33. examples/adaptive_span/adaptive_span_attention.py +160 -0
  34. examples/adaptive_span/adaptive_span_loss.py +106 -0
  35. examples/adaptive_span/adaptive_span_model.py +263 -0
  36. examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
  37. examples/adaptive_span/truncated_bptt_lm_task.py +1 -0
  38. examples/backtranslation/README.md +297 -0
  39. examples/backtranslation/deduplicate_lines.py +41 -0
  40. examples/backtranslation/extract_bt_data.py +72 -0
  41. examples/backtranslation/prepare-de-monolingual.sh +98 -0
  42. examples/backtranslation/prepare-wmt18en2de.sh +135 -0
  43. examples/backtranslation/sacrebleu.sh +37 -0
  44. examples/backtranslation/tokenized_bleu.sh +46 -0
  45. examples/bart/README.glue.md +99 -0
  46. examples/bart/README.md +228 -0
  47. examples/bart/README.summarization.md +102 -0
  48. examples/bart/summarize.py +100 -0
  49. examples/byte_level_bpe/README.md +88 -0
  50. examples/byte_level_bpe/get_bitext.py +254 -0
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <[email protected]>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
77
+
CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `master`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ ## License
26
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
+ you agree that your contributions will be licensed under the LICENSE file in
28
+ the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
audio1.mp3 ADDED
Binary file (221 kB). View file
 
audio2.mp3 ADDED
Binary file (268 kB). View file
 
docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line.
5
+ SPHINXOPTS =
6
+ SPHINXBUILD = python -msphinx
7
+ SPHINXPROJ = fairseq
8
+ SOURCEDIR = .
9
+ BUILDDIR = _build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/_static/theme_overrides.css ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ .wy-table-responsive table td kbd {
2
+ white-space: nowrap;
3
+ }
4
+ .wy-table-responsive table td {
5
+ white-space: normal !important;
6
+ }
7
+ .wy-table-responsive {
8
+ overflow: visible !important;
9
+ }
docs/command_line_tools.rst ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. _Command-line Tools:
2
+
3
+ Command-line Tools
4
+ ==================
5
+
6
+ Fairseq provides several command-line tools for training and evaluating models:
7
+
8
+ - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
9
+ - :ref:`fairseq-train`: Train a new model on one or multiple GPUs
10
+ - :ref:`fairseq-generate`: Translate pre-processed data with a trained model
11
+ - :ref:`fairseq-interactive`: Translate raw text with a trained model
12
+ - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
13
+ - :ref:`fairseq-eval-lm`: Language model evaluation
14
+
15
+
16
+ .. _fairseq-preprocess:
17
+
18
+ fairseq-preprocess
19
+ ~~~~~~~~~~~~~~~~~~
20
+ .. automodule:: fairseq_cli.preprocess
21
+
22
+ .. argparse::
23
+ :module: fairseq.options
24
+ :func: get_preprocessing_parser
25
+ :prog: fairseq-preprocess
26
+
27
+
28
+ .. _fairseq-train:
29
+
30
+ fairseq-train
31
+ ~~~~~~~~~~~~~
32
+ .. automodule:: fairseq_cli.train
33
+
34
+ .. argparse::
35
+ :module: fairseq.options
36
+ :func: get_training_parser
37
+ :prog: fairseq-train
38
+
39
+
40
+ .. _fairseq-generate:
41
+
42
+ fairseq-generate
43
+ ~~~~~~~~~~~~~~~~
44
+ .. automodule:: fairseq_cli.generate
45
+
46
+ .. argparse::
47
+ :module: fairseq.options
48
+ :func: get_generation_parser
49
+ :prog: fairseq-generate
50
+
51
+
52
+ .. _fairseq-interactive:
53
+
54
+ fairseq-interactive
55
+ ~~~~~~~~~~~~~~~~~~~
56
+ .. automodule:: fairseq_cli.interactive
57
+
58
+ .. argparse::
59
+ :module: fairseq.options
60
+ :func: get_interactive_generation_parser
61
+ :prog: fairseq-interactive
62
+
63
+
64
+ .. _fairseq-score:
65
+
66
+ fairseq-score
67
+ ~~~~~~~~~~~~~
68
+ .. automodule:: fairseq_cli.score
69
+
70
+ .. argparse::
71
+ :module: fairseq_cli.score
72
+ :func: get_parser
73
+ :prog: fairseq-score
74
+
75
+
76
+ .. _fairseq-eval-lm:
77
+
78
+ fairseq-eval-lm
79
+ ~~~~~~~~~~~~~~~
80
+ .. automodule:: fairseq_cli.eval_lm
81
+
82
+ .. argparse::
83
+ :module: fairseq.options
84
+ :func: get_eval_lm_parser
85
+ :prog: fairseq-eval-lm
docs/conf.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # fairseq documentation build configuration file, created by
5
+ # sphinx-quickstart on Fri Aug 17 21:45:30 2018.
6
+ #
7
+ # This file is execfile()d with the current directory set to its
8
+ # containing dir.
9
+ #
10
+ # Note that not all possible configuration values are present in this
11
+ # autogenerated file.
12
+ #
13
+ # All configuration values have a default; values that are commented out
14
+ # serve to show the default.
15
+
16
+ # If extensions (or modules to document with autodoc) are in another directory,
17
+ # add these directories to sys.path here. If the directory is relative to the
18
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
19
+
20
+ import os
21
+ import sys
22
+ from fairseq import __version__
23
+
24
+
25
+ # source code directory, relative to this file, for sphinx-autobuild
26
+ sys.path.insert(0, os.path.abspath(".."))
27
+
28
+ source_suffix = [".rst"]
29
+
30
+ # -- General configuration ------------------------------------------------
31
+
32
+ # If your documentation needs a minimal Sphinx version, state it here.
33
+ #
34
+ # needs_sphinx = '1.0'
35
+
36
+ # Add any Sphinx extension module names here, as strings. They can be
37
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
38
+ # ones.
39
+ extensions = [
40
+ "sphinx.ext.autodoc",
41
+ "sphinx.ext.intersphinx",
42
+ "sphinx.ext.viewcode",
43
+ "sphinx.ext.napoleon",
44
+ "sphinxarg.ext",
45
+ ]
46
+
47
+ # Add any paths that contain templates here, relative to this directory.
48
+ templates_path = ["_templates"]
49
+
50
+ # The master toctree document.
51
+ master_doc = "index"
52
+
53
+ # General information about the project.
54
+ project = "fairseq"
55
+ copyright = "Facebook AI Research (FAIR)"
56
+ author = "Facebook AI Research (FAIR)"
57
+
58
+ github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
59
+
60
+ # The version info for the project you're documenting, acts as replacement for
61
+ # |version| and |release|, also used in various other places throughout the
62
+ # built documents.
63
+ #
64
+ # The short X.Y version.
65
+ version = __version__
66
+ # The full version, including alpha/beta/rc tags.
67
+ release = __version__
68
+
69
+ # The language for content autogenerated by Sphinx. Refer to documentation
70
+ # for a list of supported languages.
71
+ #
72
+ # This is also used if you do content translation via gettext catalogs.
73
+ # Usually you set "language" from the command line for these cases.
74
+ language = None
75
+
76
+ # List of patterns, relative to source directory, that match files and
77
+ # directories to ignore when looking for source files.
78
+ # This patterns also effect to html_static_path and html_extra_path
79
+ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
80
+
81
+ # The name of the Pygments (syntax highlighting) style to use.
82
+ pygments_style = "sphinx"
83
+ highlight_language = "python"
84
+
85
+ # If true, `todo` and `todoList` produce output, else they produce nothing.
86
+ todo_include_todos = False
87
+
88
+
89
+ # -- Options for HTML output ----------------------------------------------
90
+
91
+ # The theme to use for HTML and HTML Help pages. See the documentation for
92
+ # a list of builtin themes.
93
+ #
94
+ html_theme = "sphinx_rtd_theme"
95
+
96
+ # Theme options are theme-specific and customize the look and feel of a theme
97
+ # further. For a list of options available for each theme, see the
98
+ # documentation.
99
+ #
100
+ # html_theme_options = {}
101
+
102
+ # Add any paths that contain custom static files (such as style sheets) here,
103
+ # relative to this directory. They are copied after the builtin static files,
104
+ # so a file named "default.css" will overwrite the builtin "default.css".
105
+ html_static_path = ["_static"]
106
+
107
+ html_context = {
108
+ "css_files": [
109
+ "_static/theme_overrides.css", # override wide tables in RTD theme
110
+ ],
111
+ }
112
+
113
+ # Custom sidebar templates, must be a dictionary that maps document names
114
+ # to template names.
115
+ #
116
+ # This is required for the alabaster theme
117
+ # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
118
+ # html_sidebars = {
119
+ # '**': [
120
+ # 'about.html',
121
+ # 'navigation.html',
122
+ # 'relations.html', # needs 'show_related': True theme option to display
123
+ # 'searchbox.html',
124
+ # 'donate.html',
125
+ # ]
126
+ # }
127
+
128
+
129
+ # Example configuration for intersphinx: refer to the Python standard library.
130
+ intersphinx_mapping = {
131
+ "numpy": ("http://docs.scipy.org/doc/numpy/", None),
132
+ "python": ("https://docs.python.org/", None),
133
+ "torch": ("https://pytorch.org/docs/master/", None),
134
+ }
docs/criterions.rst ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _Criterions:
5
+
6
+ Criterions
7
+ ==========
8
+
9
+ Criterions compute the loss function given the model and batch, roughly::
10
+
11
+ loss = criterion(model, batch)
12
+
13
+ .. automodule:: fairseq.criterions
14
+ :members:
15
+
16
+ .. autoclass:: fairseq.criterions.FairseqCriterion
17
+ :members:
18
+ :undoc-members:
19
+
20
+ .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21
+ :members:
22
+ :undoc-members:
23
+ .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24
+ :members:
25
+ :undoc-members:
26
+ .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27
+ :members:
28
+ :undoc-members:
29
+ .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30
+ :members:
31
+ :undoc-members:
docs/data.rst ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. module:: fairseq.data
5
+
6
+ Data Loading and Utilities
7
+ ==========================
8
+
9
+ .. _datasets:
10
+
11
+ Datasets
12
+ --------
13
+
14
+ **Datasets** define the data format and provide helpers for creating
15
+ mini-batches.
16
+
17
+ .. autoclass:: fairseq.data.FairseqDataset
18
+ :members:
19
+ .. autoclass:: fairseq.data.LanguagePairDataset
20
+ :members:
21
+ .. autoclass:: fairseq.data.MonolingualDataset
22
+ :members:
23
+
24
+ **Helper Datasets**
25
+
26
+ These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
27
+ provide additional functionality:
28
+
29
+ .. autoclass:: fairseq.data.BacktranslationDataset
30
+ :members:
31
+ .. autoclass:: fairseq.data.ConcatDataset
32
+ :members:
33
+ .. autoclass:: fairseq.data.ResamplingDataset
34
+ :members:
35
+ .. autoclass:: fairseq.data.RoundRobinZipDatasets
36
+ :members:
37
+ .. autoclass:: fairseq.data.TransformEosDataset
38
+ :members:
39
+
40
+
41
+ Dictionary
42
+ ----------
43
+
44
+ .. autoclass:: fairseq.data.Dictionary
45
+ :members:
46
+
47
+
48
+ Iterators
49
+ ---------
50
+
51
+ .. autoclass:: fairseq.data.CountingIterator
52
+ :members:
53
+ .. autoclass:: fairseq.data.EpochBatchIterator
54
+ :members:
55
+ .. autoclass:: fairseq.data.GroupedIterator
56
+ :members:
57
+ .. autoclass:: fairseq.data.ShardedIterator
58
+ :members:
docs/docutils.conf ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [writers]
2
+ option-limit=0
docs/fairseq.gif ADDED
docs/fairseq_logo.png ADDED
docs/getting_started.rst ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Evaluating Pre-trained Models
2
+ =============================
3
+
4
+ First, download a pre-trained model along with its vocabularies:
5
+
6
+ .. code-block:: console
7
+
8
+ > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
9
+
10
+ This model uses a `Byte Pair Encoding (BPE)
11
+ vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
12
+ the encoding to the source text before it can be translated. This can be
13
+ done with the
14
+ `apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
15
+ script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
16
+ used as a continuation marker and the original text can be easily
17
+ recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
18
+ flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
19
+ using ``tokenizer.perl`` from
20
+ `mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
21
+
22
+ Let's use :ref:`fairseq-interactive` to generate translations interactively.
23
+ Here, we use a beam size of 5 and preprocess the input with the Moses
24
+ tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
25
+ remove the BPE continuation markers and detokenize the output.
26
+
27
+ .. code-block:: console
28
+
29
+ > MODEL_DIR=wmt14.en-fr.fconv-py
30
+ > fairseq-interactive \
31
+ --path $MODEL_DIR/model.pt $MODEL_DIR \
32
+ --beam 5 --source-lang en --target-lang fr \
33
+ --tokenizer moses \
34
+ --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
35
+ | loading model(s) from wmt14.en-fr.fconv-py/model.pt
36
+ | [en] dictionary: 44206 types
37
+ | [fr] dictionary: 44463 types
38
+ | Type the input sentence and press return:
39
+ Why is it rare to discover new marine mammal species?
40
+ S-0 Why is it rare to discover new marine mam@@ mal species ?
41
+ H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
42
+ P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
43
+
44
+ This generation script produces three types of outputs: a line prefixed
45
+ with *O* is a copy of the original source sentence; *H* is the
46
+ hypothesis along with an average log-likelihood; and *P* is the
47
+ positional score per token position, including the
48
+ end-of-sentence marker which is omitted from the text.
49
+
50
+ Other types of output lines you might see are *D*, the detokenized hypothesis,
51
+ *T*, the reference target, *A*, alignment info, *E* the history of generation steps.
52
+
53
+ See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
54
+ full list of pre-trained models available.
55
+
56
+ Training a New Model
57
+ ====================
58
+
59
+ The following tutorial is for machine translation. For an example of how
60
+ to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
61
+ ``examples/`` directory.
62
+
63
+ Data Pre-processing
64
+ -------------------
65
+
66
+ Fairseq contains example pre-processing scripts for several translation
67
+ datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
68
+ 2014 (English-German). To pre-process and binarize the IWSLT dataset:
69
+
70
+ .. code-block:: console
71
+
72
+ > cd examples/translation/
73
+ > bash prepare-iwslt14.sh
74
+ > cd ../..
75
+ > TEXT=examples/translation/iwslt14.tokenized.de-en
76
+ > fairseq-preprocess --source-lang de --target-lang en \
77
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
78
+ --destdir data-bin/iwslt14.tokenized.de-en
79
+
80
+ This will write binarized data that can be used for model training to
81
+ ``data-bin/iwslt14.tokenized.de-en``.
82
+
83
+ Training
84
+ --------
85
+
86
+ Use :ref:`fairseq-train` to train a new model. Here a few example settings that work
87
+ well for the IWSLT 2014 dataset:
88
+
89
+ .. code-block:: console
90
+
91
+ > mkdir -p checkpoints/fconv
92
+ > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
93
+ --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
94
+ --arch fconv_iwslt_de_en --save-dir checkpoints/fconv
95
+
96
+ By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
97
+ ``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
98
+ change the number of GPU devices that will be used.
99
+
100
+ Also note that the batch size is specified in terms of the maximum
101
+ number of tokens per batch (``--max-tokens``). You may need to use a
102
+ smaller value depending on the available GPU memory on your system.
103
+
104
+ Generation
105
+ ----------
106
+
107
+ Once your model is trained, you can generate translations using
108
+ :ref:`fairseq-generate` **(for binarized data)** or
109
+ :ref:`fairseq-interactive` **(for raw text)**:
110
+
111
+ .. code-block:: console
112
+
113
+ > fairseq-generate data-bin/iwslt14.tokenized.de-en \
114
+ --path checkpoints/fconv/checkpoint_best.pt \
115
+ --batch-size 128 --beam 5
116
+ | [de] dictionary: 35475 types
117
+ | [en] dictionary: 24739 types
118
+ | data-bin/iwslt14.tokenized.de-en test 6750 examples
119
+ | model fconv
120
+ | loaded checkpoint trainings/fconv/checkpoint_best.pt
121
+ S-721 danke .
122
+ T-721 thank you .
123
+ ...
124
+
125
+ To generate translations with only a CPU, use the ``--cpu`` flag. BPE
126
+ continuation markers can be removed with the ``--remove-bpe`` flag.
127
+
128
+ Advanced Training Options
129
+ =========================
130
+
131
+ Large mini-batch training with delayed updates
132
+ ----------------------------------------------
133
+
134
+ The ``--update-freq`` option can be used to accumulate gradients from
135
+ multiple mini-batches and delay updating, creating a larger effective
136
+ batch size. Delayed updates can also improve training speed by reducing
137
+ inter-GPU communication costs and by saving idle time caused by variance
138
+ in workload across GPUs. See `Ott et al.
139
+ (2018) <https://arxiv.org/abs/1806.00187>`__ for more details.
140
+
141
+ To train on a single GPU with an effective batch size that is equivalent
142
+ to training on 8 GPUs:
143
+
144
+ .. code-block:: console
145
+
146
+ > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
147
+
148
+ Training with half precision floating point (FP16)
149
+ --------------------------------------------------
150
+
151
+ .. note::
152
+
153
+ FP16 training requires a Volta GPU and CUDA 9.1 or greater
154
+
155
+ Recent GPUs enable efficient half precision floating point computation,
156
+ e.g., using `Nvidia Tensor Cores
157
+ <https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
158
+ Fairseq supports FP16 training with the ``--fp16`` flag:
159
+
160
+ .. code-block:: console
161
+
162
+ > fairseq-train --fp16 (...)
163
+
164
+ Distributed training
165
+ --------------------
166
+
167
+ Distributed training in fairseq is implemented on top of ``torch.distributed``.
168
+ The easiest way to launch jobs is with the `torch.distributed.launch
169
+ <https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.
170
+
171
+ For example, to train a large English-German Transformer model on 2 nodes each
172
+ with 8 GPUs (in total 16 GPUs), run the following command on each node,
173
+ replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
174
+ sure to update ``--master_addr`` to the IP address of the first node:
175
+
176
+ .. code-block:: console
177
+
178
+ > python -m torch.distributed.launch --nproc_per_node=8 \
179
+ --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
180
+ --master_port=12345 \
181
+ $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
182
+ --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
183
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
184
+ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
185
+ --lr 0.0005 \
186
+ --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
187
+ --max-tokens 3584 \
188
+ --max-epoch 70 \
189
+ --fp16
190
+
191
+ On SLURM clusters, fairseq will automatically detect the number of nodes and
192
+ GPUs, but a port number must be provided:
193
+
194
+ .. code-block:: console
195
+
196
+ > salloc --gpus=16 --nodes 2 (...)
197
+ > srun fairseq-train --distributed-port 12345 (...).
198
+
199
+ Sharding very large datasets
200
+ ----------------------------
201
+
202
+ It can be challenging to train over very large datasets, particularly if your
203
+ machine does not have much system RAM. Most tasks in fairseq support training
204
+ over "sharded" datasets, in which the original dataset has been preprocessed
205
+ into non-overlapping chunks (or "shards").
206
+
207
+ For example, instead of preprocessing all your data into a single "data-bin"
208
+ directory, you can split the data and create "data-bin1", "data-bin2", etc.
209
+ Then you can adapt your training command like so:
210
+
211
+ .. code-block:: console
212
+
213
+ > fairseq-train data-bin1:data-bin2:data-bin3 (...)
214
+
215
+ Training will now iterate over each shard, one by one, with each shard
216
+ corresponding to an "epoch", thus reducing system memory usage.
docs/hydra_integration.md ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Hydra
2
+
3
+ [Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
4
+ framework that simplifies the development of research and other complex
5
+ applications. The key feature is the ability to dynamically create a
6
+ hierarchical configuration by composition and override it through config files
7
+ and the command line. The name Hydra comes from its ability to run multiple
8
+ similar jobs - much like a Hydra with multiple heads.
9
+
10
+ ## Motivation
11
+
12
+ Until recently, all components in fairseq were configured through a shared
13
+ `args` namespace that was created at application startup. Components declared
14
+ their own `add_args` method to update the argparse parser, hoping that the names
15
+ would not clash with arguments from other components. While this model works for
16
+ smaller applications, as fairseq grew and became integrated into other
17
+ applications, this became problematic. In order to determine how to configure
18
+ each component, one needed to a) examine what args were added by this component,
19
+ and b) read the code to figure out what shared arguments it is using that were
20
+ added in other places. Reproducing models involved sharing commands that often
21
+ contained dozens of command line switches.
22
+
23
+ The model described above is still supported by fairseq for backward
24
+ compatibility, but will be deprecated some time in the future.
25
+
26
+ New components in fairseq should now create a dataclass that encapsulates all
27
+ parameters required to configure this component. The dataclass is registered
28
+ along with the component, and fairseq takes care of constructing and providing
29
+ this configuration object to the component's constructor. Note that sharing
30
+ parameters can optionally still work, but one has to explicitly point to the
31
+ "source of truth" (see inheritance example below). These changes make components
32
+ in fairseq more independent and re-usable by other applications: all that is
33
+ needed to create a component is to initialize its dataclass and overwrite some
34
+ of the defaults.
35
+
36
+ While configuring fairseq through command line (using either the legacy argparse
37
+ based or the new Hydra based entry points) is still fully supported, you can now
38
+ take advantage of configuring fairseq completely or piece-by-piece through
39
+ hierarchical YAML configuration files. These files can also be shipped as
40
+ examples that others can use to run an identically configured job.
41
+
42
+ Additionally, Hydra has a rich and growing [library of
43
+ plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
44
+ provide functionality such as hyperparameter sweeping (including using bayesian
45
+ optimization through the [Ax](https://github.com/facebook/Ax) library), job
46
+ launching across various platforms, and more.
47
+
48
+ ## Creating or migrating components
49
+
50
+ In general, each new (or updated) component should provide a companion
51
+ [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are
52
+ typically located in the same file as the component and are passed as arguments
53
+ to the `register_*()` functions. Top-level configs that should be present in
54
+ every fairseq application are placed in the
55
+ [global](fairseq/dataclass/configs.py) config file and added to the
56
+ `FairseqConfig` object.
57
+
58
+ Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
59
+ classes are decorated with a `@dataclass` decorator, and typically inherit from
60
+ `FairseqDataclass` (which adds some functionality for backward compatibility).
61
+ Each field must have a type, and generally has metadata (such as a help string)
62
+ and a default value. Only primitive types or other config objects are allowed as
63
+ data types for each field.
64
+
65
+ #### Example:
66
+
67
+ ```python
68
+ from dataclasses import dataclass, field
69
+ from fairseq.dataclass import FairseqDataclass
70
+
71
+ @dataclass
72
+ class InteractiveConfig(FairseqDataclass):
73
+ buffer_size: int = field(
74
+ default=0,
75
+ metadata={
76
+ "help": "read this many sentences into a buffer before processing them"
77
+ },
78
+ )
79
+ input: str = field(
80
+ default="-",
81
+ metadata={"help": "file to read from; use - for stdin"},
82
+ )
83
+ ```
84
+
85
+ ### Inherting values
86
+
87
+ Some components require sharing a value. For example, a learning rate scheduler
88
+ and an optimizer may both need to know the initial learning rate value. One can
89
+ declare a field that, by default, will inherit its value from another config
90
+ node in the same hierarchy:
91
+
92
+ ```python
93
+ @dataclass
94
+ FairseqAdamConfig(FairseqDataclass):
95
+ ...
96
+ lr: List[float] = II("optimization.lr")
97
+ ...
98
+ ```
99
+
100
+ `II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
101
+ the value one can use in a YAML config file or through command line to achieve
102
+ the same effect. Note that this assumes that there is an "optimization" config
103
+ object in the root config and it has a field called "lr".
104
+
105
+ ### Tasks and Models
106
+
107
+ Creating Tasks and Models works same as before, except that legacy
108
+ implementations now inherit from `LegacyFairseq*` base classes, while new
109
+ components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
110
+ to the `register_*()` functions.
111
+
112
+ #### Task example:
113
+
114
+ ```python
115
+ @dataclass
116
+ class LanguageModelingConfig(FairseqDataclass):
117
+ data: Optional[str] = field(
118
+ default=None, metadata={"help": "path to data directory"}
119
+ )
120
+ ...
121
+
122
+ @register_task("language_modeling", dataclass=LanguageModelingConfig)
123
+ class LanguageModelingTask(FairseqTask):
124
+ ...
125
+ @classmethod
126
+ def setup_task(cls, cfg: LanguageModelingConfig):
127
+ ...
128
+ ```
129
+
130
+ #### Model example:
131
+
132
+ ```python
133
+ @dataclass
134
+ class TransformerLanguageModelConfig(FairseqDataclass):
135
+ activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
136
+ default="relu", metadata={"help": "activation function to use"}
137
+ )
138
+ dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
139
+ ...
140
+
141
+ @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
142
+ class TransformerLanguageModel(FairseqLanguageModel):
143
+ ...
144
+ @classmethod
145
+ def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
146
+ ...
147
+ ```
148
+
149
+ ### Other components
150
+
151
+ Other components work as before, but they now take their configuration dataclass
152
+ as the only constructor argument:
153
+
154
+ ```python
155
+ @dataclass
156
+ class MosesTokenizerConfig(FairseqDataclass):
157
+ source_lang: str = field(default="en", metadata={"help": "source language"})
158
+ ...
159
+
160
+ @register_tokenizer("moses", dataclass=MosesTokenizerConfig)
161
+ class MosesTokenizer(object):
162
+ def __init__(self, cfg: MosesTokenizerConfig):
163
+ ...
164
+ ```
165
+
166
+ Note that if you are adding a new registry for a new set of components, you need
167
+ to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
168
+
169
+ ```python
170
+ @dataclass
171
+ class FairseqConfig(object):
172
+ ...
173
+ my_new_registry: Any = None
174
+ ```
175
+
176
+ ## Training with `fairseq-hydra-train`
177
+
178
+ To fully take advantage of configuration flexibility offered by Hydra, you may
179
+ want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
180
+ tools such as `fairseq-train` will remain supported for the foreseeable future
181
+ but will be deprecated eventually.
182
+
183
+ On startup, Hydra will create a configuration object that contains a hierarchy
184
+ of all the necessary dataclasses populated with their default values in the
185
+ code. The default values are overwritten by values found in YAML files in
186
+ `fairseq/config` directory (which currently sets minimal defaults) and then
187
+ further overwritten by values provided through command line arguments.
188
+
189
+ Some of the most common use cases are shown below:
190
+
191
+ ### 1. Override default values through command line:
192
+
193
+ ```shell script
194
+ $ fairseq-hydra-train \
195
+ distributed_training.distributed_world_size=1 \
196
+ dataset.batch_size=2 \
197
+ task.data=data-bin \
198
+ model=transformer_lm/transformer_lm_gpt \
199
+ task=language_modeling \
200
+ optimization.max_update=5000
201
+ ```
202
+
203
+ Note that along with explicitly providing values for parameters such as
204
+ `dataset.batch_size`, this also tells Hydra to overlay configuration found in
205
+ `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
206
+ values in the dataclass. If you want to train a model without specifying a
207
+ particular architecture you can simply specify `model=transformer_lm`. This only
208
+ works for migrated tasks and models.
209
+
210
+ ### 2. Replace bundled configs with an external config:
211
+
212
+ ```shell script
213
+ $ fairseq-hydra-train \
214
+ --config-dir /path/to/external/configs \
215
+ --config-name wiki103
216
+ ```
217
+
218
+ where `/path/to/external/configs/wiki103.yaml` contains:
219
+
220
+ ```yaml
221
+ # @package _group_
222
+
223
+ model:
224
+ _name: transformer_lm
225
+ distributed_training:
226
+ distributed_world_size: 1
227
+ dataset:
228
+ batch_size: 2
229
+ task:
230
+ _name: language_modeling
231
+ data: /path/to/data
232
+ add_bos_token: false
233
+ max_target_positions: 1024
234
+ optimization:
235
+ max_update: 50000
236
+ lr: [ 0.25 ]
237
+ criterion: cross_entropy
238
+ optimizer: adam
239
+ lr_scheduler:
240
+ _name: cosine
241
+ ```
242
+
243
+ Note that here bundled configs from `fairseq/config` directory are not used,
244
+ however the defaults from each dataclass will still be used (unless overwritten
245
+ by your external config).
246
+
247
+ Additionally you can choose to break up your configs by creating a directory
248
+ structure in the same location as your main config file, with the names of the
249
+ top-level fields (such as "model", "dataset", etc), and placing config files
250
+ with meaningful names that would populate that specific section of your
251
+ top-level config file (for example, you might have
252
+ `model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
253
+ can then specify the correct configuration via command line, defaults in the
254
+ main config, or even launch all of them as a sweep (see Hydra documentation on
255
+ how to do this).
256
+
257
+ ### 3. Add an external config directory to Hydra search path:
258
+
259
+ This allows combining default configuration (including using any bundled config
260
+ files), while specifying your own config files for some parts of the
261
+ configuration.
262
+
263
+ ```shell script
264
+ $ fairseq-hydra-train \
265
+ distributed_training.distributed_world_size=1 \
266
+ dataset.batch_size=2 \
267
+ task.data=/path/to/data/ \
268
+ model=transformer_lm/2_layers \
269
+ task=language_modeling \
270
+ optimization.max_update=5000 \
271
+ --config-dir /path/to/external/configs
272
+ ```
273
+
274
+ where `/path/to/external/configs` has the following structure:
275
+ ```
276
+ .
277
+ +-- model
278
+ | +-- transformer_lm
279
+ | | +-- 2_layers.yaml
280
+ ```
281
+
282
+ and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
283
+ `decoder_layers` set to 2. You can add other configs to configure other
284
+ components as well.
docs/index.rst ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. fairseq documentation master file, created by
2
+ sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3
+ You can adapt this file completely to your liking, but it should at least
4
+ contain the root `toctree` directive.
5
+
6
+ :github_url: https://github.com/pytorch/fairseq
7
+
8
+
9
+ fairseq documentation
10
+ =====================
11
+
12
+ Fairseq is a sequence modeling toolkit written in `PyTorch
13
+ <http://pytorch.org/>`_ that allows researchers and developers to
14
+ train custom models for translation, summarization, language modeling and other
15
+ text generation tasks.
16
+
17
+ .. toctree::
18
+ :maxdepth: 1
19
+ :caption: Getting Started
20
+
21
+ getting_started
22
+ command_line_tools
23
+
24
+ .. toctree::
25
+ :maxdepth: 1
26
+ :caption: Extending Fairseq
27
+
28
+ overview
29
+ tutorial_simple_lstm
30
+ tutorial_classifying_names
31
+
32
+ .. toctree::
33
+ :maxdepth: 2
34
+ :caption: Library Reference
35
+
36
+ tasks
37
+ models
38
+ criterions
39
+ optim
40
+ lr_scheduler
41
+ data
42
+ modules
43
+
44
+
45
+ Indices and tables
46
+ ==================
47
+
48
+ * :ref:`genindex`
49
+ * :ref:`search`
docs/lr_scheduler.rst ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _Learning Rate Schedulers:
5
+
6
+ Learning Rate Schedulers
7
+ ========================
8
+
9
+ Learning Rate Schedulers update the learning rate over the course of training.
10
+ Learning rates can be updated after each update via :func:`step_update` or at
11
+ epoch boundaries via :func:`step`.
12
+
13
+ .. automodule:: fairseq.optim.lr_scheduler
14
+ :members:
15
+
16
+ .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17
+ :members:
18
+ :undoc-members:
19
+
20
+ .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21
+ :members:
22
+ :undoc-members:
23
+ .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24
+ :members:
25
+ :undoc-members:
26
+ .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27
+ :members:
28
+ :undoc-members:
29
+ .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30
+ :members:
31
+ :undoc-members:
32
+ .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33
+ :members:
34
+ :undoc-members:
docs/make.bat ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=python -msphinx
9
+ )
10
+ set SOURCEDIR=.
11
+ set BUILDDIR=_build
12
+ set SPHINXPROJ=fairseq
13
+
14
+ if "%1" == "" goto help
15
+
16
+ %SPHINXBUILD% >NUL 2>NUL
17
+ if errorlevel 9009 (
18
+ echo.
19
+ echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20
+ echo.then set the SPHINXBUILD environment variable to point to the full
21
+ echo.path of the 'sphinx-build' executable. Alternatively you may add the
22
+ echo.Sphinx directory to PATH.
23
+ echo.
24
+ echo.If you don't have Sphinx installed, grab it from
25
+ echo.http://sphinx-doc.org/
26
+ exit /b 1
27
+ )
28
+
29
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30
+ goto end
31
+
32
+ :help
33
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34
+
35
+ :end
36
+ popd
docs/models.rst ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. module:: fairseq.models
5
+
6
+ .. _Models:
7
+
8
+ Models
9
+ ======
10
+
11
+ A Model defines the neural network's ``forward()`` method and encapsulates all
12
+ of the learnable parameters in the network. Each model also provides a set of
13
+ named *architectures* that define the precise network configuration (e.g.,
14
+ embedding dimension, number of layers, etc.).
15
+
16
+ Both the model type and architecture are selected via the ``--arch``
17
+ command-line argument. Once selected, a model may expose additional command-line
18
+ arguments for further configuration.
19
+
20
+ .. note::
21
+
22
+ All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
23
+ :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
24
+ stand-alone Module in other PyTorch code.
25
+
26
+
27
+ Convolutional Neural Networks (CNN)
28
+ -----------------------------------
29
+
30
+ .. module:: fairseq.models.fconv
31
+ .. autoclass:: fairseq.models.fconv.FConvModel
32
+ :members:
33
+ .. autoclass:: fairseq.models.fconv.FConvEncoder
34
+ :members:
35
+ :undoc-members:
36
+ .. autoclass:: fairseq.models.fconv.FConvDecoder
37
+ :members:
38
+
39
+
40
+ Long Short-Term Memory (LSTM) networks
41
+ --------------------------------------
42
+
43
+ .. module:: fairseq.models.lstm
44
+ .. autoclass:: fairseq.models.lstm.LSTMModel
45
+ :members:
46
+ .. autoclass:: fairseq.models.lstm.LSTMEncoder
47
+ :members:
48
+ .. autoclass:: fairseq.models.lstm.LSTMDecoder
49
+ :members:
50
+
51
+
52
+ Transformer (self-attention) networks
53
+ -------------------------------------
54
+
55
+ .. module:: fairseq.models.transformer
56
+ .. autoclass:: fairseq.models.transformer.TransformerModel
57
+ :members:
58
+ .. autoclass:: fairseq.models.transformer.TransformerEncoder
59
+ :members:
60
+ .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
61
+ :members:
62
+ .. autoclass:: fairseq.models.transformer.TransformerDecoder
63
+ :members:
64
+ .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
65
+ :members:
66
+
67
+
68
+ Adding new models
69
+ -----------------
70
+
71
+ .. currentmodule:: fairseq.models
72
+ .. autofunction:: fairseq.models.register_model
73
+ .. autofunction:: fairseq.models.register_model_architecture
74
+ .. autoclass:: fairseq.models.BaseFairseqModel
75
+ :members:
76
+ :undoc-members:
77
+ .. autoclass:: fairseq.models.FairseqEncoderDecoderModel
78
+ :members:
79
+ :undoc-members:
80
+ .. autoclass:: fairseq.models.FairseqEncoderModel
81
+ :members:
82
+ :undoc-members:
83
+ .. autoclass:: fairseq.models.FairseqLanguageModel
84
+ :members:
85
+ :undoc-members:
86
+ .. autoclass:: fairseq.models.FairseqMultiModel
87
+ :members:
88
+ :undoc-members:
89
+ .. autoclass:: fairseq.models.FairseqEncoder
90
+ :members:
91
+ .. autoclass:: fairseq.models.CompositeEncoder
92
+ :members:
93
+ .. autoclass:: fairseq.models.FairseqDecoder
94
+ :members:
95
+
96
+
97
+ .. _Incremental decoding:
98
+
99
+ Incremental decoding
100
+ --------------------
101
+
102
+ .. autoclass:: fairseq.models.FairseqIncrementalDecoder
103
+ :members:
104
+ :undoc-members:
docs/modules.rst ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Modules
2
+ =======
3
+
4
+ Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5
+ be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6
+
7
+ .. automodule:: fairseq.modules
8
+ :members:
9
+ :undoc-members:
docs/optim.rst ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _optimizers:
5
+
6
+ Optimizers
7
+ ==========
8
+
9
+ Optimizers update the Model parameters based on the gradients.
10
+
11
+ .. automodule:: fairseq.optim
12
+ :members:
13
+
14
+ .. autoclass:: fairseq.optim.FairseqOptimizer
15
+ :members:
16
+ :undoc-members:
17
+
18
+ .. autoclass:: fairseq.optim.adadelta.Adadelta
19
+ :members:
20
+ :undoc-members:
21
+ .. autoclass:: fairseq.optim.adagrad.Adagrad
22
+ :members:
23
+ :undoc-members:
24
+ .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25
+ :members:
26
+ :undoc-members:
27
+ .. autoclass:: fairseq.optim.adam.FairseqAdam
28
+ :members:
29
+ :undoc-members:
30
+ .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31
+ :members:
32
+ :undoc-members:
33
+ .. autoclass:: fairseq.optim.nag.FairseqNAG
34
+ :members:
35
+ :undoc-members:
36
+ .. autoclass:: fairseq.optim.sgd.SGD
37
+ :members:
38
+ :undoc-members:
docs/overview.rst ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Overview
2
+ ========
3
+
4
+ Fairseq can be extended through user-supplied `plug-ins
5
+ <https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
6
+ plug-ins:
7
+
8
+ - :ref:`Models` define the neural network architecture and encapsulate all of the
9
+ learnable parameters.
10
+ - :ref:`Criterions` compute the loss function given the model outputs and targets.
11
+ - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
12
+ Datasets, initializing the Model/Criterion and calculating the loss.
13
+ - :ref:`Optimizers` update the Model parameters based on the gradients.
14
+ - :ref:`Learning Rate Schedulers` update the learning rate over the course of
15
+ training.
16
+
17
+ **Training Flow**
18
+
19
+ Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
20
+ fairseq implements the following high-level training flow::
21
+
22
+ for epoch in range(num_epochs):
23
+ itr = task.get_batch_iterator(task.dataset('train'))
24
+ for num_updates, batch in enumerate(itr):
25
+ task.train_step(batch, model, criterion, optimizer)
26
+ average_and_clip_gradients()
27
+ optimizer.step()
28
+ lr_scheduler.step_update(num_updates)
29
+ lr_scheduler.step(epoch)
30
+
31
+ where the default implementation for ``task.train_step`` is roughly::
32
+
33
+ def train_step(self, batch, model, criterion, optimizer, **unused):
34
+ loss = criterion(model, batch)
35
+ optimizer.backward(loss)
36
+ return loss
37
+
38
+ **Registering new plug-ins**
39
+
40
+ New plug-ins are *registered* through a set of ``@register`` function
41
+ decorators, for example::
42
+
43
+ @register_model('my_lstm')
44
+ class MyLSTM(FairseqEncoderDecoderModel):
45
+ (...)
46
+
47
+ Once registered, new plug-ins can be used with the existing :ref:`Command-line
48
+ Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
49
+ new plug-ins.
50
+
51
+ **Loading plug-ins from another directory**
52
+
53
+ New plug-ins can be defined in a custom module stored in the user system. In
54
+ order to import the module, and make the plugin available to *fairseq*, the
55
+ command line supports the ``--user-dir`` flag that can be used to specify a
56
+ custom location for additional modules to load into *fairseq*.
57
+
58
+ For example, assuming this directory tree::
59
+
60
+ /home/user/my-module/
61
+ └── __init__.py
62
+
63
+ with ``__init__.py``::
64
+
65
+ from fairseq.models import register_model_architecture
66
+ from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
67
+
68
+ @register_model_architecture('transformer', 'my_transformer')
69
+ def transformer_mmt_big(args):
70
+ transformer_vaswani_wmt_en_de_big(args)
71
+
72
+ it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::
73
+
74
+ fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
docs/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sphinx<2.0
2
+ sphinx-argparse
docs/tasks.rst ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. module:: fairseq.tasks
5
+
6
+ .. _Tasks:
7
+
8
+ Tasks
9
+ =====
10
+
11
+ Tasks store dictionaries and provide helpers for loading/iterating over
12
+ Datasets, initializing the Model/Criterion and calculating the loss.
13
+
14
+ Tasks can be selected via the ``--task`` command-line argument. Once selected, a
15
+ task may expose additional command-line arguments for further configuration.
16
+
17
+ Example usage::
18
+
19
+ # setup the task (e.g., load dictionaries)
20
+ task = fairseq.tasks.setup_task(args)
21
+
22
+ # build model and criterion
23
+ model = task.build_model(args)
24
+ criterion = task.build_criterion(args)
25
+
26
+ # load datasets
27
+ task.load_dataset('train')
28
+ task.load_dataset('valid')
29
+
30
+ # iterate over mini-batches of data
31
+ batch_itr = task.get_batch_iterator(
32
+ task.dataset('train'), max_tokens=4096,
33
+ )
34
+ for batch in batch_itr:
35
+ # compute the loss
36
+ loss, sample_size, logging_output = task.get_loss(
37
+ model, criterion, batch,
38
+ )
39
+ loss.backward()
40
+
41
+
42
+ Translation
43
+ -----------
44
+
45
+ .. autoclass:: fairseq.tasks.translation.TranslationTask
46
+
47
+ .. _language modeling:
48
+
49
+ Language Modeling
50
+ -----------------
51
+
52
+ .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
53
+
54
+
55
+ Adding new tasks
56
+ ----------------
57
+
58
+ .. autofunction:: fairseq.tasks.register_task
59
+ .. autoclass:: fairseq.tasks.FairseqTask
60
+ :members:
61
+ :undoc-members:
docs/tutorial_classifying_names.rst ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tutorial: Classifying Names with a Character-Level RNN
2
+ ======================================================
3
+
4
+ In this tutorial we will extend fairseq to support *classification* tasks. In
5
+ particular we will re-implement the PyTorch tutorial for `Classifying Names with
6
+ a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
7
+ in fairseq. It is recommended to quickly skim that tutorial before beginning
8
+ this one.
9
+
10
+ This tutorial covers:
11
+
12
+ 1. **Preprocessing the data** to create dictionaries.
13
+ 2. **Registering a new Model** that encodes an input sentence with a simple RNN
14
+ and predicts the output label.
15
+ 3. **Registering a new Task** that loads our dictionaries and dataset.
16
+ 4. **Training the Model** using the existing command-line tools.
17
+ 5. **Writing an evaluation script** that imports fairseq and allows us to
18
+ interactively evaluate our model on new inputs.
19
+
20
+
21
+ 1. Preprocessing the data
22
+ -------------------------
23
+
24
+ The original tutorial provides raw data, but we'll work with a modified version
25
+ of the data that is already tokenized into characters and split into separate
26
+ train, valid and test sets.
27
+
28
+ Download and extract the data from here:
29
+ `tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
30
+
31
+ Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
32
+ command-line tool to create the dictionaries. While this tool is primarily
33
+ intended for sequence-to-sequence problems, we're able to reuse it here by
34
+ treating the label as a "target" sequence of length 1. We'll also output the
35
+ preprocessed files in "raw" format using the ``--dataset-impl`` option to
36
+ enhance readability:
37
+
38
+ .. code-block:: console
39
+
40
+ > fairseq-preprocess \
41
+ --trainpref names/train --validpref names/valid --testpref names/test \
42
+ --source-lang input --target-lang label \
43
+ --destdir names-bin --dataset-impl raw
44
+
45
+ After running the above command you should see a new directory,
46
+ :file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
47
+
48
+
49
+ 2. Registering a new Model
50
+ --------------------------
51
+
52
+ Next we'll register a new model in fairseq that will encode an input sentence
53
+ with a simple RNN and predict the output label. Compared to the original PyTorch
54
+ tutorial, our version will also work with batches of data and GPU Tensors.
55
+
56
+ First let's copy the simple RNN module implemented in the `PyTorch tutorial
57
+ <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
58
+ Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
59
+ following contents::
60
+
61
+ import torch
62
+ import torch.nn as nn
63
+
64
+ class RNN(nn.Module):
65
+
66
+ def __init__(self, input_size, hidden_size, output_size):
67
+ super(RNN, self).__init__()
68
+
69
+ self.hidden_size = hidden_size
70
+
71
+ self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
72
+ self.i2o = nn.Linear(input_size + hidden_size, output_size)
73
+ self.softmax = nn.LogSoftmax(dim=1)
74
+
75
+ def forward(self, input, hidden):
76
+ combined = torch.cat((input, hidden), 1)
77
+ hidden = self.i2h(combined)
78
+ output = self.i2o(combined)
79
+ output = self.softmax(output)
80
+ return output, hidden
81
+
82
+ def initHidden(self):
83
+ return torch.zeros(1, self.hidden_size)
84
+
85
+ We must also *register* this model with fairseq using the
86
+ :func:`~fairseq.models.register_model` function decorator. Once the model is
87
+ registered we'll be able to use it with the existing :ref:`Command-line Tools`.
88
+
89
+ All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
90
+ interface, so we'll create a small wrapper class in the same file and register
91
+ it in fairseq with the name ``'rnn_classifier'``::
92
+
93
+ from fairseq.models import BaseFairseqModel, register_model
94
+
95
+ # Note: the register_model "decorator" should immediately precede the
96
+ # definition of the Model class.
97
+
98
+ @register_model('rnn_classifier')
99
+ class FairseqRNNClassifier(BaseFairseqModel):
100
+
101
+ @staticmethod
102
+ def add_args(parser):
103
+ # Models can override this method to add new command-line arguments.
104
+ # Here we'll add a new command-line argument to configure the
105
+ # dimensionality of the hidden state.
106
+ parser.add_argument(
107
+ '--hidden-dim', type=int, metavar='N',
108
+ help='dimensionality of the hidden state',
109
+ )
110
+
111
+ @classmethod
112
+ def build_model(cls, args, task):
113
+ # Fairseq initializes models by calling the ``build_model()``
114
+ # function. This provides more flexibility, since the returned model
115
+ # instance can be of a different type than the one that was called.
116
+ # In this case we'll just return a FairseqRNNClassifier instance.
117
+
118
+ # Initialize our RNN module
119
+ rnn = RNN(
120
+ # We'll define the Task in the next section, but for now just
121
+ # notice that the task holds the dictionaries for the "source"
122
+ # (i.e., the input sentence) and "target" (i.e., the label).
123
+ input_size=len(task.source_dictionary),
124
+ hidden_size=args.hidden_dim,
125
+ output_size=len(task.target_dictionary),
126
+ )
127
+
128
+ # Return the wrapped version of the module
129
+ return FairseqRNNClassifier(
130
+ rnn=rnn,
131
+ input_vocab=task.source_dictionary,
132
+ )
133
+
134
+ def __init__(self, rnn, input_vocab):
135
+ super(FairseqRNNClassifier, self).__init__()
136
+
137
+ self.rnn = rnn
138
+ self.input_vocab = input_vocab
139
+
140
+ # The RNN module in the tutorial expects one-hot inputs, so we can
141
+ # precompute the identity matrix to help convert from indices to
142
+ # one-hot vectors. We register it as a buffer so that it is moved to
143
+ # the GPU when ``cuda()`` is called.
144
+ self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))
145
+
146
+ def forward(self, src_tokens, src_lengths):
147
+ # The inputs to the ``forward()`` function are determined by the
148
+ # Task, and in particular the ``'net_input'`` key in each
149
+ # mini-batch. We'll define the Task in the next section, but for
150
+ # now just know that *src_tokens* has shape `(batch, src_len)` and
151
+ # *src_lengths* has shape `(batch)`.
152
+ bsz, max_src_len = src_tokens.size()
153
+
154
+ # Initialize the RNN hidden state. Compared to the original PyTorch
155
+ # tutorial we'll also handle batched inputs and work on the GPU.
156
+ hidden = self.rnn.initHidden()
157
+ hidden = hidden.repeat(bsz, 1) # expand for batched inputs
158
+ hidden = hidden.to(src_tokens.device) # move to GPU
159
+
160
+ for i in range(max_src_len):
161
+ # WARNING: The inputs have padding, so we should mask those
162
+ # elements here so that padding doesn't affect the results.
163
+ # This is left as an exercise for the reader. The padding symbol
164
+ # is given by ``self.input_vocab.pad()`` and the unpadded length
165
+ # of each input is given by *src_lengths*.
166
+
167
+ # One-hot encode a batch of input characters.
168
+ input = self.one_hot_inputs[src_tokens[:, i].long()]
169
+
170
+ # Feed the input to our RNN.
171
+ output, hidden = self.rnn(input, hidden)
172
+
173
+ # Return the final output state for making a prediction
174
+ return output
175
+
176
+ Finally let's define a *named architecture* with the configuration for our
177
+ model. This is done with the :func:`~fairseq.models.register_model_architecture`
178
+ function decorator. Thereafter this named architecture can be used with the
179
+ ``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::
180
+
181
+ from fairseq.models import register_model_architecture
182
+
183
+ # The first argument to ``register_model_architecture()`` should be the name
184
+ # of the model we registered above (i.e., 'rnn_classifier'). The function we
185
+ # register here should take a single argument *args* and modify it in-place
186
+ # to match the desired architecture.
187
+
188
+ @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
189
+ def pytorch_tutorial_rnn(args):
190
+ # We use ``getattr()`` to prioritize arguments that are explicitly given
191
+ # on the command-line, so that the defaults defined below are only used
192
+ # when no other value has been specified.
193
+ args.hidden_dim = getattr(args, 'hidden_dim', 128)
194
+
195
+
196
+ 3. Registering a new Task
197
+ -------------------------
198
+
199
+ Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
200
+ dictionaries and dataset. Tasks can also control how the data is batched into
201
+ mini-batches, but in this tutorial we'll reuse the batching provided by
202
+ :class:`fairseq.data.LanguagePairDataset`.
203
+
204
+ Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
205
+ following contents::
206
+
207
+ import os
208
+ import torch
209
+
210
+ from fairseq.data import Dictionary, LanguagePairDataset
211
+ from fairseq.tasks import FairseqTask, register_task
212
+
213
+
214
+ @register_task('simple_classification')
215
+ class SimpleClassificationTask(LegacyFairseqTask):
216
+
217
+ @staticmethod
218
+ def add_args(parser):
219
+ # Add some command-line arguments for specifying where the data is
220
+ # located and the maximum supported input length.
221
+ parser.add_argument('data', metavar='FILE',
222
+ help='file prefix for data')
223
+ parser.add_argument('--max-positions', default=1024, type=int,
224
+ help='max input length')
225
+
226
+ @classmethod
227
+ def setup_task(cls, args, **kwargs):
228
+ # Here we can perform any setup required for the task. This may include
229
+ # loading Dictionaries, initializing shared Embedding layers, etc.
230
+ # In this case we'll just load the Dictionaries.
231
+ input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
232
+ label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
233
+ print('| [input] dictionary: {} types'.format(len(input_vocab)))
234
+ print('| [label] dictionary: {} types'.format(len(label_vocab)))
235
+
236
+ return SimpleClassificationTask(args, input_vocab, label_vocab)
237
+
238
+ def __init__(self, args, input_vocab, label_vocab):
239
+ super().__init__(args)
240
+ self.input_vocab = input_vocab
241
+ self.label_vocab = label_vocab
242
+
243
+ def load_dataset(self, split, **kwargs):
244
+ """Load a given dataset split (e.g., train, valid, test)."""
245
+
246
+ prefix = os.path.join(self.args.data, '{}.input-label'.format(split))
247
+
248
+ # Read input sentences.
249
+ sentences, lengths = [], []
250
+ with open(prefix + '.input', encoding='utf-8') as file:
251
+ for line in file:
252
+ sentence = line.strip()
253
+
254
+ # Tokenize the sentence, splitting on spaces
255
+ tokens = self.input_vocab.encode_line(
256
+ sentence, add_if_not_exist=False,
257
+ )
258
+
259
+ sentences.append(tokens)
260
+ lengths.append(tokens.numel())
261
+
262
+ # Read labels.
263
+ labels = []
264
+ with open(prefix + '.label', encoding='utf-8') as file:
265
+ for line in file:
266
+ label = line.strip()
267
+ labels.append(
268
+ # Convert label to a numeric ID.
269
+ torch.LongTensor([self.label_vocab.add_symbol(label)])
270
+ )
271
+
272
+ assert len(sentences) == len(labels)
273
+ print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))
274
+
275
+ # We reuse LanguagePairDataset since classification can be modeled as a
276
+ # sequence-to-sequence task where the target sequence has length 1.
277
+ self.datasets[split] = LanguagePairDataset(
278
+ src=sentences,
279
+ src_sizes=lengths,
280
+ src_dict=self.input_vocab,
281
+ tgt=labels,
282
+ tgt_sizes=torch.ones(len(labels)), # targets have length 1
283
+ tgt_dict=self.label_vocab,
284
+ left_pad_source=False,
285
+ # Since our target is a single class label, there's no need for
286
+ # teacher forcing. If we set this to ``True`` then our Model's
287
+ # ``forward()`` method would receive an additional argument called
288
+ # *prev_output_tokens* that would contain a shifted version of the
289
+ # target sequence.
290
+ input_feeding=False,
291
+ )
292
+
293
+ def max_positions(self):
294
+ """Return the max input length allowed by the task."""
295
+ # The source should be less than *args.max_positions* and the "target"
296
+ # has max length 1.
297
+ return (self.args.max_positions, 1)
298
+
299
+ @property
300
+ def source_dictionary(self):
301
+ """Return the source :class:`~fairseq.data.Dictionary`."""
302
+ return self.input_vocab
303
+
304
+ @property
305
+ def target_dictionary(self):
306
+ """Return the target :class:`~fairseq.data.Dictionary`."""
307
+ return self.label_vocab
308
+
309
+ # We could override this method if we wanted more control over how batches
310
+ # are constructed, but it's not necessary for this tutorial since we can
311
+ # reuse the batching provided by LanguagePairDataset.
312
+ #
313
+ # def get_batch_iterator(
314
+ # self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
315
+ # ignore_invalid_inputs=False, required_batch_size_multiple=1,
316
+ # seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
317
+ # data_buffer_size=0, disable_iterator_cache=False,
318
+ # ):
319
+ # (...)
320
+
321
+
322
+ 4. Training the Model
323
+ ---------------------
324
+
325
+ Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
326
+ command-line tool for this, making sure to specify our new Task (``--task
327
+ simple_classification``) and Model architecture (``--arch
328
+ pytorch_tutorial_rnn``):
329
+
330
+ .. note::
331
+
332
+ You can also configure the dimensionality of the hidden state by passing the
333
+ ``--hidden-dim`` argument to :ref:`fairseq-train`.
334
+
335
+ .. code-block:: console
336
+
337
+ > fairseq-train names-bin \
338
+ --task simple_classification \
339
+ --arch pytorch_tutorial_rnn \
340
+ --optimizer adam --lr 0.001 --lr-shrink 0.5 \
341
+ --max-tokens 1000
342
+ (...)
343
+ | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
344
+ | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
345
+ | done training in 31.6 seconds
346
+
347
+ The model files should appear in the :file:`checkpoints/` directory.
348
+
349
+
350
+ 5. Writing an evaluation script
351
+ -------------------------------
352
+
353
+ Finally we can write a short script to evaluate our model on new inputs. Create
354
+ a new file named :file:`eval_classifier.py` with the following contents::
355
+
356
+ from fairseq import checkpoint_utils, data, options, tasks
357
+
358
+ # Parse command-line arguments for generation
359
+ parser = options.get_generation_parser(default_task='simple_classification')
360
+ args = options.parse_args_and_arch(parser)
361
+
362
+ # Setup task
363
+ task = tasks.setup_task(args)
364
+
365
+ # Load model
366
+ print('| loading model from {}'.format(args.path))
367
+ models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
368
+ model = models[0]
369
+
370
+ while True:
371
+ sentence = input('\nInput: ')
372
+
373
+ # Tokenize into characters
374
+ chars = ' '.join(list(sentence.strip()))
375
+ tokens = task.source_dictionary.encode_line(
376
+ chars, add_if_not_exist=False,
377
+ )
378
+
379
+ # Build mini-batch to feed to the model
380
+ batch = data.language_pair_dataset.collate(
381
+ samples=[{'id': -1, 'source': tokens}], # bsz = 1
382
+ pad_idx=task.source_dictionary.pad(),
383
+ eos_idx=task.source_dictionary.eos(),
384
+ left_pad_source=False,
385
+ input_feeding=False,
386
+ )
387
+
388
+ # Feed batch to the model and get predictions
389
+ preds = model(**batch['net_input'])
390
+
391
+ # Print top 3 predictions and their log-probabilities
392
+ top_scores, top_labels = preds[0].topk(k=3)
393
+ for score, label_idx in zip(top_scores, top_labels):
394
+ label_name = task.target_dictionary.string([label_idx])
395
+ print('({:.2f})\t{}'.format(score, label_name))
396
+
397
+ Now we can evaluate our model interactively. Note that we have included the
398
+ original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
399
+
400
+ .. code-block:: console
401
+
402
+ > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
403
+ | [input] dictionary: 64 types
404
+ | [label] dictionary: 24 types
405
+ | loading model from checkpoints/checkpoint_best.pt
406
+
407
+ Input: Satoshi
408
+ (-0.61) Japanese
409
+ (-1.20) Arabic
410
+ (-2.86) Italian
411
+
412
+ Input: Sinbad
413
+ (-0.30) Arabic
414
+ (-1.76) English
415
+ (-4.08) Russian
docs/tutorial_simple_lstm.rst ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tutorial: Simple LSTM
2
+ =====================
3
+
4
+ In this tutorial we will extend fairseq by adding a new
5
+ :class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
6
+ sentence with an LSTM and then passes the final hidden state to a second LSTM
7
+ that decodes the target sentence (without attention).
8
+
9
+ This tutorial covers:
10
+
11
+ 1. **Writing an Encoder and Decoder** to encode/decode the source/target
12
+ sentence, respectively.
13
+ 2. **Registering a new Model** so that it can be used with the existing
14
+ :ref:`Command-line tools`.
15
+ 3. **Training the Model** using the existing command-line tools.
16
+ 4. **Making generation faster** by modifying the Decoder to use
17
+ :ref:`Incremental decoding`.
18
+
19
+
20
+ 1. Building an Encoder and Decoder
21
+ ----------------------------------
22
+
23
+ In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
24
+ should implement the :class:`~fairseq.models.FairseqEncoder` interface and
25
+ Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
26
+ These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
27
+ and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
28
+ Modules.
29
+
30
+
31
+ Encoder
32
+ ~~~~~~~
33
+
34
+ Our Encoder will embed the tokens in the source sentence, feed them to a
35
+ :class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
36
+ save the following in a new file named :file:`fairseq/models/simple_lstm.py`::
37
+
38
+ import torch.nn as nn
39
+ from fairseq import utils
40
+ from fairseq.models import FairseqEncoder
41
+
42
+ class SimpleLSTMEncoder(FairseqEncoder):
43
+
44
+ def __init__(
45
+ self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
46
+ ):
47
+ super().__init__(dictionary)
48
+ self.args = args
49
+
50
+ # Our encoder will embed the inputs before feeding them to the LSTM.
51
+ self.embed_tokens = nn.Embedding(
52
+ num_embeddings=len(dictionary),
53
+ embedding_dim=embed_dim,
54
+ padding_idx=dictionary.pad(),
55
+ )
56
+ self.dropout = nn.Dropout(p=dropout)
57
+
58
+ # We'll use a single-layer, unidirectional LSTM for simplicity.
59
+ self.lstm = nn.LSTM(
60
+ input_size=embed_dim,
61
+ hidden_size=hidden_dim,
62
+ num_layers=1,
63
+ bidirectional=False,
64
+ batch_first=True,
65
+ )
66
+
67
+ def forward(self, src_tokens, src_lengths):
68
+ # The inputs to the ``forward()`` function are determined by the
69
+ # Task, and in particular the ``'net_input'`` key in each
70
+ # mini-batch. We discuss Tasks in the next tutorial, but for now just
71
+ # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
72
+ # has shape `(batch)`.
73
+
74
+ # Note that the source is typically padded on the left. This can be
75
+ # configured by adding the `--left-pad-source "False"` command-line
76
+ # argument, but here we'll make the Encoder handle either kind of
77
+ # padding by converting everything to be right-padded.
78
+ if self.args.left_pad_source:
79
+ # Convert left-padding to right-padding.
80
+ src_tokens = utils.convert_padding_direction(
81
+ src_tokens,
82
+ padding_idx=self.dictionary.pad(),
83
+ left_to_right=True
84
+ )
85
+
86
+ # Embed the source.
87
+ x = self.embed_tokens(src_tokens)
88
+
89
+ # Apply dropout.
90
+ x = self.dropout(x)
91
+
92
+ # Pack the sequence into a PackedSequence object to feed to the LSTM.
93
+ x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
94
+
95
+ # Get the output from the LSTM.
96
+ _outputs, (final_hidden, _final_cell) = self.lstm(x)
97
+
98
+ # Return the Encoder's output. This can be any object and will be
99
+ # passed directly to the Decoder.
100
+ return {
101
+ # this will have shape `(bsz, hidden_dim)`
102
+ 'final_hidden': final_hidden.squeeze(0),
103
+ }
104
+
105
+ # Encoders are required to implement this method so that we can rearrange
106
+ # the order of the batch elements during inference (e.g., beam search).
107
+ def reorder_encoder_out(self, encoder_out, new_order):
108
+ """
109
+ Reorder encoder output according to `new_order`.
110
+
111
+ Args:
112
+ encoder_out: output from the ``forward()`` method
113
+ new_order (LongTensor): desired order
114
+
115
+ Returns:
116
+ `encoder_out` rearranged according to `new_order`
117
+ """
118
+ final_hidden = encoder_out['final_hidden']
119
+ return {
120
+ 'final_hidden': final_hidden.index_select(0, new_order),
121
+ }
122
+
123
+
124
+ Decoder
125
+ ~~~~~~~
126
+
127
+ Our Decoder will predict the next word, conditioned on the Encoder's final
128
+ hidden state and an embedded representation of the previous target word -- which
129
+ is sometimes called *teacher forcing*. More specifically, we'll use a
130
+ :class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
131
+ to the size of the output vocabulary to predict each target word.
132
+
133
+ ::
134
+
135
+ import torch
136
+ from fairseq.models import FairseqDecoder
137
+
138
+ class SimpleLSTMDecoder(FairseqDecoder):
139
+
140
+ def __init__(
141
+ self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
142
+ dropout=0.1,
143
+ ):
144
+ super().__init__(dictionary)
145
+
146
+ # Our decoder will embed the inputs before feeding them to the LSTM.
147
+ self.embed_tokens = nn.Embedding(
148
+ num_embeddings=len(dictionary),
149
+ embedding_dim=embed_dim,
150
+ padding_idx=dictionary.pad(),
151
+ )
152
+ self.dropout = nn.Dropout(p=dropout)
153
+
154
+ # We'll use a single-layer, unidirectional LSTM for simplicity.
155
+ self.lstm = nn.LSTM(
156
+ # For the first layer we'll concatenate the Encoder's final hidden
157
+ # state with the embedded target tokens.
158
+ input_size=encoder_hidden_dim + embed_dim,
159
+ hidden_size=hidden_dim,
160
+ num_layers=1,
161
+ bidirectional=False,
162
+ )
163
+
164
+ # Define the output projection.
165
+ self.output_projection = nn.Linear(hidden_dim, len(dictionary))
166
+
167
+ # During training Decoders are expected to take the entire target sequence
168
+ # (shifted right by one position) and produce logits over the vocabulary.
169
+ # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
170
+ # ``dictionary.eos()``, followed by the target sequence.
171
+ def forward(self, prev_output_tokens, encoder_out):
172
+ """
173
+ Args:
174
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
175
+ `(batch, tgt_len)`, for teacher forcing
176
+ encoder_out (Tensor, optional): output from the encoder, used for
177
+ encoder-side attention
178
+
179
+ Returns:
180
+ tuple:
181
+ - the last decoder layer's output of shape
182
+ `(batch, tgt_len, vocab)`
183
+ - the last decoder layer's attention weights of shape
184
+ `(batch, tgt_len, src_len)`
185
+ """
186
+ bsz, tgt_len = prev_output_tokens.size()
187
+
188
+ # Extract the final hidden state from the Encoder.
189
+ final_encoder_hidden = encoder_out['final_hidden']
190
+
191
+ # Embed the target sequence, which has been shifted right by one
192
+ # position and now starts with the end-of-sentence symbol.
193
+ x = self.embed_tokens(prev_output_tokens)
194
+
195
+ # Apply dropout.
196
+ x = self.dropout(x)
197
+
198
+ # Concatenate the Encoder's final hidden state to *every* embedded
199
+ # target token.
200
+ x = torch.cat(
201
+ [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
202
+ dim=2,
203
+ )
204
+
205
+ # Using PackedSequence objects in the Decoder is harder than in the
206
+ # Encoder, since the targets are not sorted in descending length order,
207
+ # which is a requirement of ``pack_padded_sequence()``. Instead we'll
208
+ # feed nn.LSTM directly.
209
+ initial_state = (
210
+ final_encoder_hidden.unsqueeze(0), # hidden
211
+ torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
212
+ )
213
+ output, _ = self.lstm(
214
+ x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)`
215
+ initial_state,
216
+ )
217
+ x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)`
218
+
219
+ # Project the outputs to the size of the vocabulary.
220
+ x = self.output_projection(x)
221
+
222
+ # Return the logits and ``None`` for the attention weights
223
+ return x, None
224
+
225
+
226
+ 2. Registering the Model
227
+ ------------------------
228
+
229
+ Now that we've defined our Encoder and Decoder we must *register* our model with
230
+ fairseq using the :func:`~fairseq.models.register_model` function decorator.
231
+ Once the model is registered we'll be able to use it with the existing
232
+ :ref:`Command-line Tools`.
233
+
234
+ All registered models must implement the
235
+ :class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
236
+ models (i.e., any model with a single Encoder and Decoder), we can instead
237
+ implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
238
+
239
+ Create a small wrapper class in the same file and register it in fairseq with
240
+ the name ``'simple_lstm'``::
241
+
242
+ from fairseq.models import FairseqEncoderDecoderModel, register_model
243
+
244
+ # Note: the register_model "decorator" should immediately precede the
245
+ # definition of the Model class.
246
+
247
+ @register_model('simple_lstm')
248
+ class SimpleLSTMModel(FairseqEncoderDecoderModel):
249
+
250
+ @staticmethod
251
+ def add_args(parser):
252
+ # Models can override this method to add new command-line arguments.
253
+ # Here we'll add some new command-line arguments to configure dropout
254
+ # and the dimensionality of the embeddings and hidden states.
255
+ parser.add_argument(
256
+ '--encoder-embed-dim', type=int, metavar='N',
257
+ help='dimensionality of the encoder embeddings',
258
+ )
259
+ parser.add_argument(
260
+ '--encoder-hidden-dim', type=int, metavar='N',
261
+ help='dimensionality of the encoder hidden state',
262
+ )
263
+ parser.add_argument(
264
+ '--encoder-dropout', type=float, default=0.1,
265
+ help='encoder dropout probability',
266
+ )
267
+ parser.add_argument(
268
+ '--decoder-embed-dim', type=int, metavar='N',
269
+ help='dimensionality of the decoder embeddings',
270
+ )
271
+ parser.add_argument(
272
+ '--decoder-hidden-dim', type=int, metavar='N',
273
+ help='dimensionality of the decoder hidden state',
274
+ )
275
+ parser.add_argument(
276
+ '--decoder-dropout', type=float, default=0.1,
277
+ help='decoder dropout probability',
278
+ )
279
+
280
+ @classmethod
281
+ def build_model(cls, args, task):
282
+ # Fairseq initializes models by calling the ``build_model()``
283
+ # function. This provides more flexibility, since the returned model
284
+ # instance can be of a different type than the one that was called.
285
+ # In this case we'll just return a SimpleLSTMModel instance.
286
+
287
+ # Initialize our Encoder and Decoder.
288
+ encoder = SimpleLSTMEncoder(
289
+ args=args,
290
+ dictionary=task.source_dictionary,
291
+ embed_dim=args.encoder_embed_dim,
292
+ hidden_dim=args.encoder_hidden_dim,
293
+ dropout=args.encoder_dropout,
294
+ )
295
+ decoder = SimpleLSTMDecoder(
296
+ dictionary=task.target_dictionary,
297
+ encoder_hidden_dim=args.encoder_hidden_dim,
298
+ embed_dim=args.decoder_embed_dim,
299
+ hidden_dim=args.decoder_hidden_dim,
300
+ dropout=args.decoder_dropout,
301
+ )
302
+ model = SimpleLSTMModel(encoder, decoder)
303
+
304
+ # Print the model architecture.
305
+ print(model)
306
+
307
+ return model
308
+
309
+ # We could override the ``forward()`` if we wanted more control over how
310
+ # the encoder and decoder interact, but it's not necessary for this
311
+ # tutorial since we can inherit the default implementation provided by
312
+ # the FairseqEncoderDecoderModel base class, which looks like:
313
+ #
314
+ # def forward(self, src_tokens, src_lengths, prev_output_tokens):
315
+ # encoder_out = self.encoder(src_tokens, src_lengths)
316
+ # decoder_out = self.decoder(prev_output_tokens, encoder_out)
317
+ # return decoder_out
318
+
319
+ Finally let's define a *named architecture* with the configuration for our
320
+ model. This is done with the :func:`~fairseq.models.register_model_architecture`
321
+ function decorator. Thereafter this named architecture can be used with the
322
+ ``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::
323
+
324
+ from fairseq.models import register_model_architecture
325
+
326
+ # The first argument to ``register_model_architecture()`` should be the name
327
+ # of the model we registered above (i.e., 'simple_lstm'). The function we
328
+ # register here should take a single argument *args* and modify it in-place
329
+ # to match the desired architecture.
330
+
331
+ @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
332
+ def tutorial_simple_lstm(args):
333
+ # We use ``getattr()`` to prioritize arguments that are explicitly given
334
+ # on the command-line, so that the defaults defined below are only used
335
+ # when no other value has been specified.
336
+ args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
337
+ args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
338
+ args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
339
+ args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
340
+
341
+
342
+ 3. Training the Model
343
+ ---------------------
344
+
345
+ Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
346
+ command-line tool for this, making sure to specify our new Model architecture
347
+ (``--arch tutorial_simple_lstm``).
348
+
349
+ .. note::
350
+
351
+ Make sure you've already preprocessed the data from the IWSLT example in the
352
+ :file:`examples/translation/` directory.
353
+
354
+ .. code-block:: console
355
+
356
+ > fairseq-train data-bin/iwslt14.tokenized.de-en \
357
+ --arch tutorial_simple_lstm \
358
+ --encoder-dropout 0.2 --decoder-dropout 0.2 \
359
+ --optimizer adam --lr 0.005 --lr-shrink 0.5 \
360
+ --max-tokens 12000
361
+ (...)
362
+ | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
363
+ | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
364
+
365
+ The model files should appear in the :file:`checkpoints/` directory. While this
366
+ model architecture is not very good, we can use the :ref:`fairseq-generate` script to
367
+ generate translations and compute our BLEU score over the test set:
368
+
369
+ .. code-block:: console
370
+
371
+ > fairseq-generate data-bin/iwslt14.tokenized.de-en \
372
+ --path checkpoints/checkpoint_best.pt \
373
+ --beam 5 \
374
+ --remove-bpe
375
+ (...)
376
+ | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
377
+ | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
378
+
379
+
380
+ 4. Making generation faster
381
+ ---------------------------
382
+
383
+ While autoregressive generation from sequence-to-sequence models is inherently
384
+ slow, our implementation above is especially slow because it recomputes the
385
+ entire sequence of Decoder hidden states for every output token (i.e., it is
386
+ ``O(n^2)``). We can make this significantly faster by instead caching the
387
+ previous hidden states.
388
+
389
+ In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
390
+ special mode at inference time where the Model only receives a single timestep
391
+ of input corresponding to the immediately previous output token (for teacher
392
+ forcing) and must produce the next output incrementally. Thus the model must
393
+ cache any long-term state that is needed about the sequence, e.g., hidden
394
+ states, convolutional states, etc.
395
+
396
+ To implement incremental decoding we will modify our model to implement the
397
+ :class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
398
+ standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
399
+ decoder interface allows ``forward()`` methods to take an extra keyword argument
400
+ (*incremental_state*) that can be used to cache state across time-steps.
401
+
402
+ Let's replace our ``SimpleLSTMDecoder`` with an incremental one::
403
+
404
+ import torch
405
+ from fairseq.models import FairseqIncrementalDecoder
406
+
407
+ class SimpleLSTMDecoder(FairseqIncrementalDecoder):
408
+
409
+ def __init__(
410
+ self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
411
+ dropout=0.1,
412
+ ):
413
+ # This remains the same as before.
414
+ super().__init__(dictionary)
415
+ self.embed_tokens = nn.Embedding(
416
+ num_embeddings=len(dictionary),
417
+ embedding_dim=embed_dim,
418
+ padding_idx=dictionary.pad(),
419
+ )
420
+ self.dropout = nn.Dropout(p=dropout)
421
+ self.lstm = nn.LSTM(
422
+ input_size=encoder_hidden_dim + embed_dim,
423
+ hidden_size=hidden_dim,
424
+ num_layers=1,
425
+ bidirectional=False,
426
+ )
427
+ self.output_projection = nn.Linear(hidden_dim, len(dictionary))
428
+
429
+ # We now take an additional kwarg (*incremental_state*) for caching the
430
+ # previous hidden and cell states.
431
+ def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
432
+ if incremental_state is not None:
433
+ # If the *incremental_state* argument is not ``None`` then we are
434
+ # in incremental inference mode. While *prev_output_tokens* will
435
+ # still contain the entire decoded prefix, we will only use the
436
+ # last step and assume that the rest of the state is cached.
437
+ prev_output_tokens = prev_output_tokens[:, -1:]
438
+
439
+ # This remains the same as before.
440
+ bsz, tgt_len = prev_output_tokens.size()
441
+ final_encoder_hidden = encoder_out['final_hidden']
442
+ x = self.embed_tokens(prev_output_tokens)
443
+ x = self.dropout(x)
444
+ x = torch.cat(
445
+ [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
446
+ dim=2,
447
+ )
448
+
449
+ # We will now check the cache and load the cached previous hidden and
450
+ # cell states, if they exist, otherwise we will initialize them to
451
+ # zeros (as before). We will use the ``utils.get_incremental_state()``
452
+ # and ``utils.set_incremental_state()`` helpers.
453
+ initial_state = utils.get_incremental_state(
454
+ self, incremental_state, 'prev_state',
455
+ )
456
+ if initial_state is None:
457
+ # first time initialization, same as the original version
458
+ initial_state = (
459
+ final_encoder_hidden.unsqueeze(0), # hidden
460
+ torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
461
+ )
462
+
463
+ # Run one step of our LSTM.
464
+ output, latest_state = self.lstm(x.transpose(0, 1), initial_state)
465
+
466
+ # Update the cache with the latest hidden and cell states.
467
+ utils.set_incremental_state(
468
+ self, incremental_state, 'prev_state', latest_state,
469
+ )
470
+
471
+ # This remains the same as before
472
+ x = output.transpose(0, 1)
473
+ x = self.output_projection(x)
474
+ return x, None
475
+
476
+ # The ``FairseqIncrementalDecoder`` interface also requires implementing a
477
+ # ``reorder_incremental_state()`` method, which is used during beam search
478
+ # to select and reorder the incremental state.
479
+ def reorder_incremental_state(self, incremental_state, new_order):
480
+ # Load the cached state.
481
+ prev_state = utils.get_incremental_state(
482
+ self, incremental_state, 'prev_state',
483
+ )
484
+
485
+ # Reorder batches according to *new_order*.
486
+ reordered_state = (
487
+ prev_state[0].index_select(1, new_order), # hidden
488
+ prev_state[1].index_select(1, new_order), # cell
489
+ )
490
+
491
+ # Update the cached state.
492
+ utils.set_incremental_state(
493
+ self, incremental_state, 'prev_state', reordered_state,
494
+ )
495
+
496
+ Finally, we can rerun generation and observe the speedup:
497
+
498
+ .. code-block:: console
499
+
500
+ # Before
501
+
502
+ > fairseq-generate data-bin/iwslt14.tokenized.de-en \
503
+ --path checkpoints/checkpoint_best.pt \
504
+ --beam 5 \
505
+ --remove-bpe
506
+ (...)
507
+ | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
508
+ | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
509
+
510
+ # After
511
+
512
+ > fairseq-generate data-bin/iwslt14.tokenized.de-en \
513
+ --path checkpoints/checkpoint_best.pt \
514
+ --beam 5 \
515
+ --remove-bpe
516
+ (...)
517
+ | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
518
+ | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
examples/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ !*/*.sh
2
+ !*/*.md
examples/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from fairseq.version import __version__ # noqa
8
+ except ImportError:
9
+ pass
examples/adaptive_span/README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adaptive Span
2
+
3
+ Adaptive Span is a novel self-attention mechanism that can learn its optimal
4
+ attention span. This allows us to extend significantly the maximum context size
5
+ used in Transformer, while maintaining control over their memory footprint
6
+ and computational time. It uses the Truncated BPTT technique for training,
7
+ as in [transformerXL](https://github.com/pytorch/fairseq/blob/master/examples/truncated_bptt/README.md).
8
+
9
+ Adaptive Span was introduced by paper:
10
+ [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
11
+ which achieved state-of-the-art language modeling results at the time of publication.
12
+
13
+ We manage to reproduce their result in fairseq and keep most of the
14
+ [original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
15
+ You can refer to the their sweep file as well if any combination of hyperparameter is not clear.
16
+
17
+ ##### 0. Setup
18
+
19
+ First you need to process the Enwik8 dataset, we use the pre-tokenized dataset
20
+ from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
21
+ You can download the dataset, and then run:
22
+ ```bash
23
+ fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
24
+ --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
25
+ --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
26
+ ```
27
+
28
+ ##### 1. Train a Adaptive Span model on Enwik8
29
+
30
+ We will train a 12-layer Adaptive Span model following the [hyperparameters
31
+ used in the original
32
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
33
+
34
+ The following command assumes 4 GPUs, so that the total batch size is 64
35
+ sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
36
+ ```bash
37
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
38
+ --user-dir examples/adaptive_span \
39
+ --data ~/data/enwik8/data-bin/ \
40
+ --fp16 --fp16-no-flatten-grads --max-update 600000 \
41
+ --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
42
+ --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
43
+ --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
44
+ --validate-interval-updates 1000 \
45
+ --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
46
+ --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
47
+ --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
48
+ ```
49
+ This should land around 1.05 on validation, 1.03 on test. You can lower the
50
+ --aux-loss-scaler for better performance (longer span). It gives ~0.03 bpc
51
+ improvement to the transformerXL baseline here.
52
+ If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
53
+ and simulate training on 4 GPUs.
54
+ You can also reproduce the transformerXL result on enwik8 using this code base.
55
+ It should land around 1.06 on test,matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
56
+ You can try by
57
+ ```bash
58
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
59
+ --user-dir examples/truncated_bptt \
60
+ ~/data/enwik8/data-bin/ \
61
+ --task truncated_bptt_lm --fp16 --max-update 400000 \
62
+ --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
63
+ --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
64
+ --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
65
+ --lr-scheduler cosine --warmup-updates 0 \
66
+ --lr 0.0 --lr 0.00025 --batch-size 15 \
67
+ --update-freq 1 --seed 2 --log-format json --log-interval 25 \
68
+ --fp16
69
+ ```
70
+
71
+ ##### 2. Evaluate
72
+ For Adaptive Span:
73
+ ```bash
74
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
75
+ --user-dir examples/adaptive_span \
76
+ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
77
+ ```
78
+ For Transformer-XL evaluation:
79
+ ```bash
80
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
81
+ --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
82
+ --tokens-per-sample 80 \
83
+ --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
84
+ --gen-subset valid
85
+ ```
86
+
87
+ *Note:* During training the model saw 512 tokens of context
88
+ (``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
89
+ settings from [the original
90
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
examples/adaptive_span/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+ # automatically import any Python files in the current directory
10
+ cur_dir = os.path.dirname(__file__)
11
+ for file in os.listdir(cur_dir):
12
+ path = os.path.join(cur_dir, file)
13
+ if (
14
+ not file.startswith("_")
15
+ and not file.startswith(".")
16
+ and (file.endswith(".py") or os.path.isdir(path))
17
+ ):
18
+ mod_name = file[: file.find(".py")] if file.endswith(".py") else file
19
+ module = importlib.import_module(__name__ + "." + mod_name)
examples/adaptive_span/adagrad_with_grad_clip.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.optim import Adagrad
7
+
8
+ from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
9
+
10
+
11
+ @register_optimizer("adagrad_with_grad_clip")
12
+ class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
13
+ def __init__(self, args, params):
14
+ super().__init__(args)
15
+ self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
16
+
17
+ @staticmethod
18
+ def add_args(parser):
19
+ """Add optimizer-specific arguments to the parser."""
20
+ # fmt: off
21
+ parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22
+ help='weight decay')
23
+ parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
24
+ help='internal grad clip')
25
+ # fmt: on
26
+
27
+ @property
28
+ def optimizer_config(self):
29
+ """
30
+ Return a kwarg dictionary that will be used to override optimizer
31
+ args stored in checkpoints. This allows us to load a checkpoint and
32
+ resume training using a different set of optimizer args, e.g., with a
33
+ different learning rate.
34
+ """
35
+ return {
36
+ "lr": self.args.lr[0],
37
+ "weight_decay": self.args.weight_decay,
38
+ "grad_clip": self.args.adagrad_clip,
39
+ }
40
+
41
+ @property
42
+ def supports_flat_params(self):
43
+ return False
44
+
45
+
46
+ def _clip_grad(clr, grad, group_grad_clip):
47
+ if group_grad_clip > 0:
48
+ norm = grad.norm(2).item()
49
+ if norm > group_grad_clip:
50
+ clr *= group_grad_clip / (norm + 1e-10)
51
+ return clr
52
+
53
+
54
+ class AdagradWithGradClip(Adagrad):
55
+ """Adagrad algorithm with custom gradient clipping"""
56
+
57
+ def __init__(
58
+ self,
59
+ params,
60
+ lr=1e-2,
61
+ lr_decay=0,
62
+ weight_decay=0,
63
+ initial_accumulator_value=0,
64
+ grad_clip=0,
65
+ ):
66
+ Adagrad.__init__(
67
+ self,
68
+ params,
69
+ lr=lr,
70
+ lr_decay=lr_decay,
71
+ weight_decay=weight_decay,
72
+ initial_accumulator_value=initial_accumulator_value,
73
+ )
74
+ self.defaults["grad_clip"] = grad_clip
75
+ self.param_groups[0].setdefault("grad_clip", grad_clip)
76
+
77
+ def step(self, closure=None):
78
+ loss = None
79
+ if closure is not None:
80
+ loss = closure()
81
+
82
+ for group in self.param_groups:
83
+ for p in group["params"]:
84
+ if p.grad is None:
85
+ continue
86
+
87
+ grad = p.grad.data
88
+ state = self.state[p]
89
+
90
+ state["step"] += 1
91
+
92
+ if group["weight_decay"] != 0:
93
+ if p.grad.data.is_sparse:
94
+ raise RuntimeError(
95
+ "weight_decay option is "
96
+ "not compatible with sparse "
97
+ "gradients"
98
+ )
99
+ grad = grad.add(group["weight_decay"], p.data)
100
+
101
+ clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
102
+
103
+ # clip
104
+ clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
105
+
106
+ if grad.is_sparse:
107
+ # the update is non-linear so indices must be unique
108
+ grad = grad.coalesce()
109
+ grad_indices = grad._indices()
110
+ grad_values = grad._values()
111
+ size = grad.size()
112
+
113
+ def make_sparse(values):
114
+ constructor = grad.new
115
+ if grad_indices.dim() == 0 or values.dim() == 0:
116
+ return constructor().resize_as_(grad)
117
+ return constructor(grad_indices, values, size)
118
+
119
+ state["sum"].add_(make_sparse(grad_values.pow(2)))
120
+ std = state["sum"]._sparse_mask(grad)
121
+ std_values = std._values().sqrt_().add_(1e-10)
122
+ p.data.add_(-clr, make_sparse(grad_values / std_values))
123
+ else:
124
+ state["sum"].addcmul_(1, grad, grad)
125
+ std = state["sum"].sqrt().add_(1e-10)
126
+ p.data.addcdiv_(-clr, grad, std)
127
+
128
+ return loss
examples/adaptive_span/adaptive_span_attention.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class AdaptiveMask(nn.Module):
13
+ """Soft masking function for adaptive size.
14
+ It masks out the last K values of an input. The masking value
15
+ goes from 1 to 0 gradually, so K can be learned with
16
+ back-propagation.
17
+ Args:
18
+ max_size: maximum size (i.e. input dimension)
19
+ ramp_size: size of the ramp going from 0 to 1
20
+ init_val: initial size proportion not to be masked out
21
+ shape: learn multiple sizes independent of each other
22
+ """
23
+
24
+ def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
25
+ nn.Module.__init__(self)
26
+ self._max_size = max_size
27
+ self._ramp_size = ramp_size
28
+ self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
29
+ mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
30
+ self.register_buffer("mask_template", mask_template)
31
+
32
+ def forward(self, x):
33
+ mask = self.mask_template.float() + self.current_val.float() * self._max_size
34
+ mask = mask / self._ramp_size + 1
35
+ mask = mask.clamp(0, 1)
36
+ if x.size(-1) < self._max_size:
37
+ # the input could have been trimmed beforehand to save computation
38
+ mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
39
+ x = (x * mask).type_as(x)
40
+ return x
41
+
42
+ def get_current_max_size(self, include_ramp=True):
43
+ current_size = math.ceil(self.current_val.max().item() * self._max_size)
44
+ if include_ramp:
45
+ current_size += self._ramp_size
46
+ current_size = max(0, min(self._max_size, current_size))
47
+ return current_size
48
+
49
+ def get_current_avg_size(self, include_ramp=True):
50
+ current_size = math.ceil(
51
+ self.current_val.float().mean().item() * self._max_size
52
+ )
53
+ if include_ramp:
54
+ current_size += self._ramp_size
55
+ current_size = max(0, min(self._max_size, current_size))
56
+ return current_size
57
+
58
+ def clamp_param(self):
59
+ """this need to be called after each update"""
60
+ self.current_val.data.clamp_(0, 1)
61
+
62
+
63
+ class AdaptiveSpan(nn.Module):
64
+ """Adaptive attention span for Transformerself.
65
+ This module learns an attention span length from data for each
66
+ self-attention head.
67
+ Args:
68
+ attn_span: maximum attention span
69
+ adapt_span_loss: loss coefficient for the span length
70
+ adapt_span_ramp: length of the masking ramp
71
+ adapt_span_init: initial size ratio
72
+ adapt_span_cache: adapt cache size to reduce memory usage
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ attn_span,
78
+ adapt_span_ramp,
79
+ adapt_span_init,
80
+ n_head,
81
+ adapt_span_layer,
82
+ **kargs
83
+ ):
84
+ nn.Module.__init__(self)
85
+ self._max_span = attn_span
86
+ self._n_head = n_head
87
+ self._adapt_span_layer = adapt_span_layer
88
+ if self._adapt_span_layer:
89
+ self._mask = AdaptiveMask(
90
+ max_size=self._max_span,
91
+ ramp_size=adapt_span_ramp,
92
+ init_val=adapt_span_init,
93
+ )
94
+ else:
95
+ self._mask = AdaptiveMask(
96
+ max_size=self._max_span,
97
+ ramp_size=adapt_span_ramp,
98
+ init_val=adapt_span_init,
99
+ shape=(n_head, 1, 1),
100
+ )
101
+
102
+ def forward(self, attn, normalize=True):
103
+ """mask attention with the right span"""
104
+ # batch and head dimensions are merged together, so separate them first
105
+ self.clamp_param()
106
+ if self._adapt_span_layer:
107
+ attn = self._mask(attn)
108
+ else:
109
+ B = attn.size(0) # batch size
110
+ M = attn.size(1) # block size
111
+ attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
112
+ attn = self._mask(attn)
113
+ attn = attn.view(B, M, -1)
114
+ return attn
115
+
116
+ def get_trim_len(self):
117
+ """how much of memory can be trimmed to reduce computation"""
118
+ L = self._max_span
119
+ trim_len = min(L - 1, L - self._mask.get_current_max_size())
120
+ # too fine granularity might be bad for the memory management
121
+ trim_len = math.floor(trim_len / 64) * 64
122
+ return trim_len
123
+
124
+ def trim_memory(self, query, key, value, key_pe):
125
+ """trim out unnecessary memory beforehand to reduce computation"""
126
+ trim_len = self.get_trim_len()
127
+ cache_size = key.size(1) - query.size(1)
128
+ trim_len_cache = trim_len - (self._max_span - cache_size)
129
+ if trim_len_cache > 0:
130
+ key = key[:, trim_len_cache:, :]
131
+ value = value[:, trim_len_cache:, :]
132
+ elif trim_len_cache < 0:
133
+ # cache is too short! this happens when validation resumes
134
+ # after a lot of updates.
135
+ key = F.pad(key, [0, 0, -trim_len_cache, 0])
136
+ value = F.pad(value, [0, 0, -trim_len_cache, 0])
137
+ if trim_len > 0:
138
+ if key_pe is not None:
139
+ key_pe = key_pe[:, :, trim_len:]
140
+ return key, value, key_pe
141
+
142
+ def get_cache_size(self):
143
+ """determine how long the cache should be"""
144
+ trim_len = self.get_trim_len()
145
+ # give a buffer of 64 steps since a span might increase
146
+ # in future updates
147
+ return min(self._max_span, self._max_span - trim_len + 64)
148
+
149
+ def get_loss(self):
150
+ """a loss term for regularizing the span length"""
151
+ return self._max_span * self._mask.current_val.float().mean()
152
+
153
+ def get_current_max_span(self):
154
+ return self._mask.get_current_max_size()
155
+
156
+ def get_current_avg_span(self):
157
+ return self._mask.get_current_avg_size()
158
+
159
+ def clamp_param(self):
160
+ self._mask.clamp_param()
examples/adaptive_span/adaptive_span_loss.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import register_criterion
12
+ from fairseq.criterions.cross_entropy import CrossEntropyCriterion
13
+ from fairseq.dataclass import FairseqDataclass
14
+ from omegaconf import II
15
+
16
+
17
+ @dataclass
18
+ class AdaptiveSpanCriterionConfig(FairseqDataclass):
19
+ sentence_avg: bool = II("optimization.sentence_avg")
20
+
21
+
22
+ @register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
23
+ class AdaptiveSpanCriterion(CrossEntropyCriterion):
24
+ def __init__(self, task, sentence_avg):
25
+ super().__init__(task, sentence_avg)
26
+
27
+ def forward(self, model, sample, reduce=True):
28
+ """Compute the loss for the given sample.
29
+
30
+ Returns a tuple with three elements:
31
+ 1) the loss here is summed, different from the adaptive span code
32
+ 2) the sample size, which is used as the denominator for the gradient
33
+ 3) logging outputs to display while training
34
+ """
35
+ net_output = model(**sample["net_input"])
36
+ loss, aux_loss, avg_span, max_span = self.compute_loss(
37
+ model, net_output, sample, reduce=reduce
38
+ )
39
+ sample_size = (
40
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
41
+ )
42
+ loss /= sample_size
43
+ total_loss = loss + aux_loss
44
+ sample_size = 1
45
+
46
+ logging_output = {
47
+ "loss": loss.data,
48
+ "ntokens": sample["ntokens"],
49
+ "nsentences": sample["target"].size(0),
50
+ "sample_size": sample_size,
51
+ "total_loss": total_loss.data,
52
+ "avg_span": avg_span * sample_size,
53
+ "max_span": max_span * sample_size,
54
+ }
55
+ return total_loss, sample_size, logging_output
56
+
57
+ def compute_loss(self, model, net_output, sample, reduce=True):
58
+ loss, _ = super().compute_loss(model, net_output, sample, reduce)
59
+ aux_loss = model.get_aux_loss()
60
+ avg_span = model.get_current_avg_span()
61
+ max_span = model.get_current_max_span()
62
+ return loss, aux_loss, avg_span, max_span
63
+
64
+ @staticmethod
65
+ def reduce_metrics(logging_outputs) -> None:
66
+ """Aggregate logging outputs from data parallel training."""
67
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
68
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
69
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
70
+ total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
71
+ avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
72
+ max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
73
+
74
+ # we divide by log(2) to convert the loss from base e to base 2
75
+ metrics.log_scalar(
76
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
77
+ )
78
+ metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
79
+ metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
80
+ # total loss contains the L1 norm on adaptive-span
81
+ metrics.log_scalar(
82
+ "total_loss",
83
+ total_loss_sum / sample_size / math.log(2),
84
+ sample_size,
85
+ round=3,
86
+ )
87
+ if sample_size != ntokens:
88
+ metrics.log_scalar(
89
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
90
+ )
91
+ metrics.log_derived(
92
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
93
+ )
94
+ else:
95
+ metrics.log_derived(
96
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
97
+ )
98
+
99
+ @staticmethod
100
+ def logging_outputs_can_be_summed() -> bool:
101
+ """
102
+ Whether the logging outputs returned by `forward` can be summed
103
+ across workers prior to calling `reduce_metrics`. Setting this
104
+ to True will improves distributed training speed.
105
+ """
106
+ return True
examples/adaptive_span/adaptive_span_model.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from fairseq.modules.layer_norm import LayerNorm
14
+
15
+ from .adaptive_span_attention import AdaptiveSpan
16
+
17
+ # Size notations:
18
+ # B = batch_size, H = d_model, M = block_size, L = attn_span
19
+
20
+
21
+ def _skew(X, pad_value):
22
+ """shift every row 1 step to right"""
23
+ # X = B x M x L
24
+ B, M, L = X.size()
25
+ X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
26
+ X = X.view(B, -1) # B x ML+MM+M
27
+ X = X[:, :-M] # B x ML+MM
28
+ X = X.view(B, M, M + L) # B x M x L+M
29
+ return X
30
+
31
+
32
+ def _unskew(X):
33
+ """reverse _skew operation"""
34
+ # X = B x M x L+M
35
+ B, M, L = X.size()
36
+ L -= M
37
+ X = X.view(B, -1) # B x ML+MM
38
+ X = F.pad(X, (0, M)) # B x ML+MM+M
39
+ X = X.view(B, M, M + L + 1) # B x M x L+M+1
40
+ X = X[:, :, :L] # B x M x L
41
+ return X
42
+
43
+
44
+ class SeqAttention(nn.Module):
45
+ """Sequential self-attention layer.
46
+ Each token will attend to its previous fixed number of steps.
47
+ Note that attention doesn't include the current step itself.
48
+ """
49
+
50
+ def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
51
+ nn.Module.__init__(self)
52
+ self.dropout = nn.Dropout(dropout)
53
+ self.d_model = d_model # size of a single head
54
+ self.attn_span = attn_span
55
+ self.adaptive_span = AdaptiveSpan(
56
+ attn_span=attn_span,
57
+ n_head=n_head,
58
+ adapt_span_layer=adapt_span_layer,
59
+ **kargs
60
+ )
61
+
62
+ def forward(self, query, key, value, key_pe):
63
+ # query size = B x M x H
64
+ # key, value sizes = B x (M+L) x H
65
+
66
+ key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
67
+
68
+ # compute attention from context
69
+ # B x M (dest) x (M+L) (src)
70
+ attn_cont = torch.matmul(query, key.transpose(-1, -2))
71
+ attn_cont = _unskew(attn_cont) # B x M x L
72
+
73
+ # compute the effect of position embedding
74
+ attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
75
+ attn = attn_cont + attn_pos
76
+
77
+ attn = attn / math.sqrt(self.d_model) # B x M X L_pos
78
+
79
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
80
+
81
+ # trim attention lengths according to the learned span
82
+ attn = self.adaptive_span(attn)
83
+
84
+ attn = self.dropout(attn) # B x M X L_pos
85
+
86
+ attn_cont = _skew(attn, 0) # B x M X (L+M)
87
+ out = torch.matmul(attn_cont, value) # B x M x H
88
+ return out
89
+
90
+ def get_cache_size(self):
91
+ return self.adaptive_span.get_cache_size()
92
+
93
+
94
+ class MultiHeadSeqAttention(nn.Module):
95
+ def __init__(self, d_model, n_head, **kargs):
96
+ nn.Module.__init__(self)
97
+ assert d_model % n_head == 0
98
+ self.n_head = n_head
99
+ self.head_dim = d_model // n_head
100
+ self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
101
+ self.proj_query = nn.Linear(d_model, d_model, bias=False)
102
+ nn.init.xavier_normal_(self.proj_query.weight)
103
+ self.proj_out = nn.Linear(d_model, d_model, bias=False)
104
+ nn.init.xavier_normal_(self.proj_out.weight)
105
+ self.proj_val = nn.Linear(d_model, d_model, bias=False)
106
+ nn.init.xavier_normal_(self.proj_val.weight)
107
+ self.proj_key = nn.Linear(d_model, d_model, bias=False)
108
+ nn.init.xavier_normal_(self.proj_key.weight)
109
+
110
+ def head_reshape(self, x):
111
+ K = self.n_head
112
+ D = self.head_dim
113
+ x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
114
+ x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
115
+ x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
116
+ return x
117
+
118
+ def forward(self, query, key, value, key_pe):
119
+ B = query.size(0)
120
+ K = self.n_head
121
+ D = self.head_dim
122
+ M = query.size(1)
123
+
124
+ query = self.proj_query(query)
125
+ query = self.head_reshape(query)
126
+ value = self.proj_val(value)
127
+ value = self.head_reshape(value)
128
+ key = self.proj_key(key)
129
+ key = self.head_reshape(key)
130
+
131
+ out = self.attn(query, key, value, key_pe) # B_K x M x D
132
+ out = out.view(B, K, M, D) # B x K x M x D
133
+ out = out.transpose(1, 2).contiguous() # B x M x K x D
134
+ out = out.view(B, M, -1) # B x M x K_D
135
+ out = self.proj_out(out)
136
+ return out
137
+
138
+
139
+ class FeedForwardLayer(nn.Module):
140
+ def __init__(self, d_model, d_inner, dropout, **kargs):
141
+ nn.Module.__init__(self)
142
+ self.fc1 = nn.Linear(d_model, d_inner)
143
+ self.fc2 = nn.Linear(d_inner, d_model)
144
+ nn.init.xavier_uniform_(self.fc1.weight)
145
+ nn.init.xavier_uniform_(self.fc2.weight)
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ def forward(self, h):
149
+ h1 = F.relu(self.fc1(h))
150
+ h1 = self.dropout(h1)
151
+ h2 = self.fc2(h1)
152
+ return h2
153
+
154
+
155
+ class TransformerSeqLayer(nn.Module):
156
+ def __init__(self, d_model, **kargs):
157
+ nn.Module.__init__(self)
158
+ self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
159
+ self.norm1 = LayerNorm(d_model)
160
+ self.ff = FeedForwardLayer(d_model=d_model, **kargs)
161
+ self.norm2 = LayerNorm(d_model)
162
+
163
+ def forward(self, h, h_cache, key_pe):
164
+ # h = B x M x H
165
+ # h_cache = B x L x H
166
+ h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
167
+ attn_out = self.attn(h, h_all, h_all, key_pe)
168
+ h = self.norm1(h + attn_out) # B x M x H
169
+ if self.ff is not None:
170
+ ff_out = self.ff(h)
171
+ out = self.norm2(h + ff_out) # B x M x H
172
+ else:
173
+ out = h
174
+ return out
175
+
176
+ def get_cache_size(self):
177
+ return self.attn.attn.get_cache_size()
178
+
179
+
180
+ class TransformerSeq(nn.Module):
181
+ def __init__(
182
+ self,
183
+ vocab_size,
184
+ d_model,
185
+ n_head,
186
+ n_layer,
187
+ attn_span,
188
+ emb_dropout,
189
+ aux_loss_scaler,
190
+ adapt_span_layer,
191
+ **kargs
192
+ ):
193
+ nn.Module.__init__(self)
194
+ # token embeddings
195
+ self.in_emb = nn.Embedding(vocab_size, d_model)
196
+ nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
197
+ self.out_emb = nn.Linear(d_model, vocab_size)
198
+ self.aux_loss_scaler = aux_loss_scaler
199
+ if emb_dropout > 0:
200
+ self.emb_dropout = nn.Dropout(emb_dropout)
201
+ else:
202
+ self.emb_dropout = None
203
+ # position embeddings
204
+ self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
205
+
206
+ self.layers = nn.ModuleList()
207
+ self.layers.extend(
208
+ TransformerSeqLayer(
209
+ d_model=d_model,
210
+ n_head=n_head,
211
+ attn_span=attn_span,
212
+ adapt_span_layer=adapt_span_layer,
213
+ **kargs
214
+ )
215
+ for _ in range(n_layer)
216
+ )
217
+
218
+ def forward(self, x, h_cache, target=None):
219
+ # x size = B x M
220
+ block_size = x.size(1)
221
+ h = self.in_emb(x) # B x M x H
222
+ if self.emb_dropout is not None:
223
+ h = self.emb_dropout(h)
224
+
225
+ h_cache_next = []
226
+ for l, layer in enumerate(self.layers):
227
+ cache_size = layer.attn.attn.get_cache_size()
228
+ if cache_size > block_size:
229
+ h_cache_next_l = torch.cat(
230
+ [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
231
+ ).detach()
232
+ else:
233
+ h_cache_next_l = h[:, -cache_size:, :].detach()
234
+ h_cache_next.append(h_cache_next_l)
235
+ h = layer(h, h_cache[l], self.key_pe) # B x M x H
236
+
237
+ if self.emb_dropout is not None:
238
+ h = self.emb_dropout(h)
239
+
240
+ out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
241
+ dummy_loss = None
242
+
243
+ return out, h_cache_next, dummy_loss
244
+
245
+ def get_aux_loss(self):
246
+ loss = 0.0
247
+ for layer in self.layers:
248
+ loss += layer.attn.attn.adaptive_span.get_loss()
249
+ return self.aux_loss_scaler * loss
250
+
251
+ def get_current_max_span(self):
252
+ max_span = 0.0
253
+ for layer in self.layers:
254
+ max_span = max(
255
+ max_span, layer.attn.attn.adaptive_span.get_current_max_span()
256
+ )
257
+ return max_span
258
+
259
+ def get_current_avg_span(self):
260
+ avg_span = 0.0
261
+ for layer in self.layers:
262
+ avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
263
+ return avg_span / len(self.layers)
examples/adaptive_span/adaptive_span_model_wrapper.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.models import (
13
+ FairseqIncrementalDecoder,
14
+ FairseqLanguageModel,
15
+ register_model,
16
+ )
17
+ from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class AdaptiveSpanSmallConfig(FairseqDataclass):
25
+ # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
26
+ vocab_size: int = 50
27
+ d_model: int = 256
28
+ n_head: int = 4
29
+ d_inner: int = 1024
30
+ n_layer: int = 8
31
+ attn_span: int = 1024
32
+ dropout: float = 0.0
33
+ emb_dropout: float = 0.0
34
+ adapt_span_ramp: int = 32
35
+ adapt_span_init: float = 0.0
36
+ aux_loss_scaler: float = 0.000002
37
+ adapt_span_layer: bool = False
38
+
39
+
40
+ @register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
41
+ class AdaptiveSpanTransformer(FairseqLanguageModel):
42
+ @classmethod
43
+ def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
44
+ return cls(AdaptiveSpanDecoder(cfg, task))
45
+
46
+ def get_aux_loss(self):
47
+ return self.decoder.get_aux_loss()
48
+
49
+ def get_current_max_span(self):
50
+ return self.decoder.get_current_max_span()
51
+
52
+ def get_current_avg_span(self):
53
+ return self.decoder.get_current_avg_span()
54
+
55
+
56
+ class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
57
+ def __init__(self, cfg, task):
58
+
59
+ super().__init__(task.target_dictionary)
60
+
61
+ self.config = cfg
62
+ config = AdaptiveSpanSmallConfig(
63
+ vocab_size=len(task.target_dictionary),
64
+ d_model=cfg.d_model,
65
+ n_head=cfg.n_head,
66
+ d_inner=cfg.d_inner,
67
+ n_layer=cfg.n_layer,
68
+ attn_span=cfg.attn_span,
69
+ dropout=cfg.dropout,
70
+ emb_dropout=cfg.emb_dropout,
71
+ adapt_span_ramp=cfg.adapt_span_ramp,
72
+ adapt_span_init=cfg.adapt_span_init,
73
+ aux_loss_scaler=cfg.aux_loss_scaler,
74
+ adapt_span_layer=cfg.adapt_span_layer,
75
+ )
76
+ logger.info(config)
77
+ self.model = AdaptiveSpanTransformerModel(**config.__dict__)
78
+
79
+ self._mems = None
80
+
81
+ def forward(
82
+ self,
83
+ src_tokens,
84
+ incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
85
+ encoder_out=None,
86
+ ):
87
+ bsz = src_tokens.size(0)
88
+ if incremental_state is not None: # used during inference
89
+ mems = self.get_incremental_state("mems")
90
+ src_tokens = src_tokens[:, -1:] # only keep the most recent token
91
+ else:
92
+ mems = self._mems
93
+
94
+ if mems is None:
95
+ # first time init
96
+ mems = self.init_hid_cache(bsz)
97
+ output = self.model(x=src_tokens, h_cache=mems,)
98
+ if incremental_state is not None:
99
+ self.set_incremental_state(incremental_state, "mems", output[1])
100
+ else:
101
+ self._mems = output[1]
102
+ return (output[0],)
103
+
104
+ def max_positions(self):
105
+ return self.config.attn_span
106
+
107
+ def init_hid_cache(self, batch_sz):
108
+ hid = []
109
+ for layer in self.model.layers:
110
+ param = next(self.model.parameters())
111
+ h = torch.zeros(
112
+ batch_sz,
113
+ layer.get_cache_size(),
114
+ self.config.d_model,
115
+ dtype=param.dtype,
116
+ device=param.device,
117
+ )
118
+ hid.append(h)
119
+ return hid
120
+
121
+ def get_aux_loss(self):
122
+ return self.model.get_aux_loss()
123
+
124
+ def get_current_max_span(self):
125
+ return self.model.get_current_max_span()
126
+
127
+ def get_current_avg_span(self):
128
+ return self.model.get_current_avg_span()
129
+
130
+ def reorder_incremental_state(
131
+ self,
132
+ incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
133
+ new_order: torch.Tensor,
134
+ ):
135
+ """Reorder incremental state.
136
+
137
+ This will be called when the order of the input has changed from the
138
+ previous time step. A typical use case is beam search, where the input
139
+ order changes between time steps based on the selection of beams.
140
+ """
141
+ raise NotImplementedError("This is required for generation/beam search")
142
+ # mems = self.get_incremental_state(incremental_state, "mems")
143
+ # if mems is not None:
144
+ # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
145
+ # self.set_incremental_state(incremental_state, "mems", new_mems)
examples/adaptive_span/truncated_bptt_lm_task.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../truncated_bptt/truncated_bptt_lm_task.py
examples/backtranslation/README.md ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Understanding Back-Translation at Scale (Edunov et al., 2018)
2
+
3
+ This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
4
+
5
+ ## Pre-trained models
6
+
7
+ Model | Description | Dataset | Download
8
+ ---|---|---|---
9
+ `transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
10
+
11
+ ## Example usage (torch.hub)
12
+
13
+ We require a few additional Python dependencies for preprocessing:
14
+ ```bash
15
+ pip install subword_nmt sacremoses
16
+ ```
17
+
18
+ Then to generate translations from the full model ensemble:
19
+ ```python
20
+ import torch
21
+
22
+ # List available models
23
+ torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
24
+
25
+ # Load the WMT'18 En-De ensemble
26
+ en2de_ensemble = torch.hub.load(
27
+ 'pytorch/fairseq', 'transformer.wmt18.en-de',
28
+ checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
29
+ tokenizer='moses', bpe='subword_nmt')
30
+
31
+ # The ensemble contains 5 models
32
+ len(en2de_ensemble.models)
33
+ # 5
34
+
35
+ # Translate
36
+ en2de_ensemble.translate('Hello world!')
37
+ # 'Hallo Welt!'
38
+ ```
39
+
40
+ ## Training your own model (WMT'18 English-German)
41
+
42
+ The following instructions can be adapted to reproduce the models from the paper.
43
+
44
+
45
+ #### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
46
+
47
+ First download and preprocess the data:
48
+ ```bash
49
+ # Download and prepare the data
50
+ cd examples/backtranslation/
51
+ bash prepare-wmt18en2de.sh
52
+ cd ../..
53
+
54
+ # Binarize the data
55
+ TEXT=examples/backtranslation/wmt18_en_de
56
+ fairseq-preprocess \
57
+ --joined-dictionary \
58
+ --source-lang en --target-lang de \
59
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
60
+ --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
61
+ --workers 20
62
+
63
+ # Copy the BPE code into the data-bin directory for future use
64
+ cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
65
+ ```
66
+
67
+ (Optionally) Train a baseline model (English-German) using just the parallel data:
68
+ ```bash
69
+ CHECKPOINT_DIR=checkpoints_en_de_parallel
70
+ fairseq-train --fp16 \
71
+ data-bin/wmt18_en_de \
72
+ --source-lang en --target-lang de \
73
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
74
+ --dropout 0.3 --weight-decay 0.0 \
75
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
76
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
77
+ --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
78
+ --max-tokens 3584 --update-freq 16 \
79
+ --max-update 30000 \
80
+ --save-dir $CHECKPOINT_DIR
81
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
82
+ # different number of GPUs.
83
+ ```
84
+
85
+ Average the last 10 checkpoints:
86
+ ```bash
87
+ python scripts/average_checkpoints.py \
88
+ --inputs $CHECKPOINT_DIR \
89
+ --num-epoch-checkpoints 10 \
90
+ --output $CHECKPOINT_DIR/checkpoint.avg10.pt
91
+ ```
92
+
93
+ Evaluate BLEU:
94
+ ```bash
95
+ # tokenized BLEU on newstest2017:
96
+ bash examples/backtranslation/tokenized_bleu.sh \
97
+ wmt17 \
98
+ en-de \
99
+ data-bin/wmt18_en_de \
100
+ data-bin/wmt18_en_de/code \
101
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
102
+ # BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
103
+ # compare to 29.46 in Table 1, which is also for tokenized BLEU
104
+
105
+ # generally it's better to report (detokenized) sacrebleu though:
106
+ bash examples/backtranslation/sacrebleu.sh \
107
+ wmt17 \
108
+ en-de \
109
+ data-bin/wmt18_en_de \
110
+ data-bin/wmt18_en_de/code \
111
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
112
+ # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
113
+ ```
114
+
115
+
116
+ #### Step 2. Back-translate monolingual German data
117
+
118
+ Train a reverse model (German-English) to do the back-translation:
119
+ ```bash
120
+ CHECKPOINT_DIR=checkpoints_de_en_parallel
121
+ fairseq-train --fp16 \
122
+ data-bin/wmt18_en_de \
123
+ --source-lang de --target-lang en \
124
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
125
+ --dropout 0.3 --weight-decay 0.0 \
126
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
127
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
128
+ --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
129
+ --max-tokens 3584 --update-freq 16 \
130
+ --max-update 30000 \
131
+ --save-dir $CHECKPOINT_DIR
132
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
133
+ # different number of GPUs.
134
+ ```
135
+
136
+ Let's evaluate the back-translation (BT) model to make sure it is well trained:
137
+ ```bash
138
+ bash examples/backtranslation/sacrebleu.sh \
139
+ wmt17 \
140
+ de-en \
141
+ data-bin/wmt18_en_de \
142
+ data-bin/wmt18_en_de/code \
143
+ $CHECKPOINT_DIR/checkpoint_best.py
144
+ # BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
145
+ # compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
146
+ ```
147
+
148
+ Next prepare the monolingual data:
149
+ ```bash
150
+ # Download and prepare the monolingual data
151
+ # By default the script samples 25M monolingual sentences, which after
152
+ # deduplication should be just over 24M sentences. These are split into 25
153
+ # shards, each with 1M sentences (except for the last shard).
154
+ cd examples/backtranslation/
155
+ bash prepare-de-monolingual.sh
156
+ cd ../..
157
+
158
+ # Binarize each shard of the monolingual data
159
+ TEXT=examples/backtranslation/wmt18_de_mono
160
+ for SHARD in $(seq -f "%02g" 0 24); do \
161
+ fairseq-preprocess \
162
+ --only-source \
163
+ --source-lang de --target-lang en \
164
+ --joined-dictionary \
165
+ --srcdict data-bin/wmt18_en_de/dict.de.txt \
166
+ --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
167
+ --destdir data-bin/wmt18_de_mono/shard${SHARD} \
168
+ --workers 20; \
169
+ cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
170
+ done
171
+ ```
172
+
173
+ Now we're ready to perform back-translation over the monolingual data. The
174
+ following command generates via sampling, but it's possible to use greedy
175
+ decoding (`--beam 1`), beam search (`--beam 5`),
176
+ top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
177
+ ```bash
178
+ mkdir backtranslation_output
179
+ for SHARD in $(seq -f "%02g" 0 24); do \
180
+ fairseq-generate --fp16 \
181
+ data-bin/wmt18_de_mono/shard${SHARD} \
182
+ --path $CHECKPOINT_DIR/checkpoint_best.pt \
183
+ --skip-invalid-size-inputs-valid-test \
184
+ --max-tokens 4096 \
185
+ --sampling --beam 1 \
186
+ > backtranslation_output/sampling.shard${SHARD}.out; \
187
+ done
188
+ ```
189
+
190
+ After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
191
+ the back-translations and apply length ratio filters:
192
+ ```bash
193
+ python examples/backtranslation/extract_bt_data.py \
194
+ --minlen 1 --maxlen 250 --ratio 1.5 \
195
+ --output backtranslation_output/bt_data --srclang en --tgtlang de \
196
+ backtranslation_output/sampling.shard*.out
197
+
198
+ # Ensure lengths are the same:
199
+ # wc -l backtranslation_output/bt_data.{en,de}
200
+ # 21795614 backtranslation_output/bt_data.en
201
+ # 21795614 backtranslation_output/bt_data.de
202
+ # 43591228 total
203
+ ```
204
+
205
+ Binarize the filtered BT data and combine it with the parallel data:
206
+ ```bash
207
+ TEXT=backtranslation_output
208
+ fairseq-preprocess \
209
+ --source-lang en --target-lang de \
210
+ --joined-dictionary \
211
+ --srcdict data-bin/wmt18_en_de/dict.en.txt \
212
+ --trainpref $TEXT/bt_data \
213
+ --destdir data-bin/wmt18_en_de_bt \
214
+ --workers 20
215
+
216
+ # We want to train on the combined data, so we'll symlink the parallel + BT data
217
+ # in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
218
+ # and the BT data as "train1", so that fairseq will combine them automatically
219
+ # and so that we can use the `--upsample-primary` option to upsample the
220
+ # parallel data (if desired).
221
+ PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
222
+ BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
223
+ COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
224
+ mkdir -p $COMB_DATA
225
+ for LANG in en de; do \
226
+ ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
227
+ for EXT in bin idx; do \
228
+ ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
229
+ ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
230
+ ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
231
+ ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
232
+ done; \
233
+ done
234
+ ```
235
+
236
+
237
+ #### 3. Train an English-German model over the combined parallel + BT data
238
+
239
+ Finally we can train a model over the parallel + BT data:
240
+ ```bash
241
+ CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
242
+ fairseq-train --fp16 \
243
+ data-bin/wmt18_en_de_para_plus_bt \
244
+ --upsample-primary 16 \
245
+ --source-lang en --target-lang de \
246
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
247
+ --dropout 0.3 --weight-decay 0.0 \
248
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
249
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
250
+ --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
251
+ --max-tokens 3584 --update-freq 16 \
252
+ --max-update 100000 \
253
+ --save-dir $CHECKPOINT_DIR
254
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
255
+ # different number of GPUs.
256
+ ```
257
+
258
+ Average the last 10 checkpoints:
259
+ ```bash
260
+ python scripts/average_checkpoints.py \
261
+ --inputs $CHECKPOINT_DIR \
262
+ --num-epoch-checkpoints 10 \
263
+ --output $CHECKPOINT_DIR/checkpoint.avg10.pt
264
+ ```
265
+
266
+ Evaluate BLEU:
267
+ ```bash
268
+ # tokenized BLEU on newstest2017:
269
+ bash examples/backtranslation/tokenized_bleu.sh \
270
+ wmt17 \
271
+ en-de \
272
+ data-bin/wmt18_en_de \
273
+ data-bin/wmt18_en_de/code \
274
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
275
+ # BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
276
+ # compare to 32.35 in Table 1, which is also for tokenized BLEU
277
+
278
+ # generally it's better to report (detokenized) sacrebleu:
279
+ bash examples/backtranslation/sacrebleu.sh \
280
+ wmt17 \
281
+ en-de \
282
+ data-bin/wmt18_en_de \
283
+ data-bin/wmt18_en_de/code \
284
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
285
+ # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
286
+ ```
287
+
288
+
289
+ ## Citation
290
+ ```bibtex
291
+ @inproceedings{edunov2018backtranslation,
292
+ title = {Understanding Back-Translation at Scale},
293
+ author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
294
+ booktitle = {Conference of the Association for Computational Linguistics (ACL)},
295
+ year = 2018,
296
+ }
297
+ ```
examples/backtranslation/deduplicate_lines.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import fileinput
9
+ import hashlib
10
+ import sys
11
+ from multiprocessing import Pool
12
+
13
+
14
+ def get_hashes_and_lines(raw_line):
15
+ hash = hashlib.md5(raw_line).hexdigest()
16
+ return hash, raw_line
17
+
18
+
19
+ def main():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--workers", type=int, default=10)
22
+ parser.add_argument("files", nargs="*", help="input files")
23
+ args = parser.parse_args()
24
+
25
+ seen = set()
26
+ with fileinput.input(args.files, mode="rb") as h:
27
+ pool = Pool(args.workers)
28
+ results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
29
+ for i, (hash, raw_line) in enumerate(results):
30
+ if hash not in seen:
31
+ seen.add(hash)
32
+ sys.stdout.buffer.write(raw_line)
33
+ if i % 1000000 == 0:
34
+ print(i, file=sys.stderr, end="", flush=True)
35
+ elif i % 100000 == 0:
36
+ print(".", file=sys.stderr, end="", flush=True)
37
+ print(file=sys.stderr, flush=True)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
examples/backtranslation/extract_bt_data.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import fileinput
9
+
10
+ from tqdm import tqdm
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser(
15
+ description=(
16
+ "Extract back-translations from the stdout of fairseq-generate. "
17
+ "If there are multiply hypotheses for a source, we only keep the first one. "
18
+ )
19
+ )
20
+ parser.add_argument("--output", required=True, help="output prefix")
21
+ parser.add_argument(
22
+ "--srclang", required=True, help="source language (extracted from H-* lines)"
23
+ )
24
+ parser.add_argument(
25
+ "--tgtlang", required=True, help="target language (extracted from S-* lines)"
26
+ )
27
+ parser.add_argument("--minlen", type=int, help="min length filter")
28
+ parser.add_argument("--maxlen", type=int, help="max length filter")
29
+ parser.add_argument("--ratio", type=float, help="ratio filter")
30
+ parser.add_argument("files", nargs="*", help="input files")
31
+ args = parser.parse_args()
32
+
33
+ def validate(src, tgt):
34
+ srclen = len(src.split(" ")) if src != "" else 0
35
+ tgtlen = len(tgt.split(" ")) if tgt != "" else 0
36
+ if (
37
+ (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
38
+ or (
39
+ args.maxlen is not None
40
+ and (srclen > args.maxlen or tgtlen > args.maxlen)
41
+ )
42
+ or (
43
+ args.ratio is not None
44
+ and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
45
+ )
46
+ ):
47
+ return False
48
+ return True
49
+
50
+ def safe_index(toks, index, default):
51
+ try:
52
+ return toks[index]
53
+ except IndexError:
54
+ return default
55
+
56
+ with open(args.output + "." + args.srclang, "w") as src_h, open(
57
+ args.output + "." + args.tgtlang, "w"
58
+ ) as tgt_h:
59
+ for line in tqdm(fileinput.input(args.files)):
60
+ if line.startswith("S-"):
61
+ tgt = safe_index(line.rstrip().split("\t"), 1, "")
62
+ elif line.startswith("H-"):
63
+ if tgt is not None:
64
+ src = safe_index(line.rstrip().split("\t"), 2, "")
65
+ if validate(src, tgt):
66
+ print(src, file=src_h)
67
+ print(tgt, file=tgt_h)
68
+ tgt = None
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
examples/backtranslation/prepare-de-monolingual.sh ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ SCRIPTS=mosesdecoder/scripts
4
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
5
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
6
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
7
+ BPEROOT=subword-nmt/subword_nmt
8
+
9
+
10
+ BPE_CODE=wmt18_en_de/code
11
+ SUBSAMPLE_SIZE=25000000
12
+ LANG=de
13
+
14
+
15
+ OUTDIR=wmt18_${LANG}_mono
16
+ orig=orig
17
+ tmp=$OUTDIR/tmp
18
+ mkdir -p $OUTDIR $tmp
19
+
20
+
21
+ URLS=(
22
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
23
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
24
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
25
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
26
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
27
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
28
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
29
+ "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
30
+ "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
31
+ "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
32
+ "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
33
+ )
34
+ FILES=(
35
+ "news.2007.de.shuffled.gz"
36
+ "news.2008.de.shuffled.gz"
37
+ "news.2009.de.shuffled.gz"
38
+ "news.2010.de.shuffled.gz"
39
+ "news.2011.de.shuffled.gz"
40
+ "news.2012.de.shuffled.gz"
41
+ "news.2013.de.shuffled.gz"
42
+ "news.2014.de.shuffled.v2.gz"
43
+ "news.2015.de.shuffled.gz"
44
+ "news.2016.de.shuffled.gz"
45
+ "news.2017.de.shuffled.deduped.gz"
46
+ )
47
+
48
+
49
+ cd $orig
50
+ for ((i=0;i<${#URLS[@]};++i)); do
51
+ file=${FILES[i]}
52
+ if [ -f $file ]; then
53
+ echo "$file already exists, skipping download"
54
+ else
55
+ url=${URLS[i]}
56
+ wget "$url"
57
+ fi
58
+ done
59
+ cd ..
60
+
61
+
62
+ if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
63
+ echo "found monolingual sample, skipping shuffle/sample/tokenize"
64
+ else
65
+ gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
66
+ | shuf -n $SUBSAMPLE_SIZE \
67
+ | perl $NORM_PUNC $LANG \
68
+ | perl $REM_NON_PRINT_CHAR \
69
+ | perl $TOKENIZER -threads 8 -a -l $LANG \
70
+ > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
71
+ fi
72
+
73
+
74
+ if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
75
+ echo "found BPE monolingual sample, skipping BPE step"
76
+ else
77
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE \
78
+ < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
79
+ > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
80
+ fi
81
+
82
+
83
+ if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
84
+ echo "found deduplicated monolingual sample, skipping deduplication step"
85
+ else
86
+ python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
87
+ > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
88
+ fi
89
+
90
+
91
+ if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
92
+ echo "found sharded data, skipping sharding step"
93
+ else
94
+ split --lines 1000000 --numeric-suffixes \
95
+ --additional-suffix .${LANG} \
96
+ $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
97
+ $OUTDIR/bpe.monolingual.dedup.
98
+ fi
examples/backtranslation/prepare-wmt18en2de.sh ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3
+
4
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
5
+ git clone https://github.com/moses-smt/mosesdecoder.git
6
+
7
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8
+ git clone https://github.com/rsennrich/subword-nmt.git
9
+
10
+ SCRIPTS=mosesdecoder/scripts
11
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
12
+ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
13
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
14
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
15
+ BPEROOT=subword-nmt/subword_nmt
16
+ BPE_TOKENS=32000
17
+
18
+ URLS=(
19
+ "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
20
+ "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
21
+ "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
22
+ "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
23
+ "http://data.statmt.org/wmt17/translation-task/dev.tgz"
24
+ "http://statmt.org/wmt14/test-full.tgz"
25
+ )
26
+ FILES=(
27
+ "training-parallel-europarl-v7.tgz"
28
+ "training-parallel-commoncrawl.tgz"
29
+ "training-parallel-nc-v13.tgz"
30
+ "rapid2016.tgz"
31
+ "dev.tgz"
32
+ "test-full.tgz"
33
+ )
34
+ CORPORA=(
35
+ "training/europarl-v7.de-en"
36
+ "commoncrawl.de-en"
37
+ "training-parallel-nc-v13/news-commentary-v13.de-en"
38
+ "rapid2016.de-en"
39
+ )
40
+
41
+ if [ ! -d "$SCRIPTS" ]; then
42
+ echo "Please set SCRIPTS variable correctly to point to Moses scripts."
43
+ exit 1
44
+ fi
45
+
46
+ OUTDIR=wmt18_en_de
47
+
48
+ src=en
49
+ tgt=de
50
+ lang=en-de
51
+ prep=$OUTDIR
52
+ tmp=$prep/tmp
53
+ orig=orig
54
+
55
+ mkdir -p $orig $tmp $prep
56
+
57
+ cd $orig
58
+
59
+ for ((i=0;i<${#URLS[@]};++i)); do
60
+ file=${FILES[i]}
61
+ if [ -f $file ]; then
62
+ echo "$file already exists, skipping download"
63
+ else
64
+ url=${URLS[i]}
65
+ wget "$url"
66
+ if [ -f $file ]; then
67
+ echo "$url successfully downloaded."
68
+ else
69
+ echo "$url not successfully downloaded."
70
+ exit 1
71
+ fi
72
+ if [ ${file: -4} == ".tgz" ]; then
73
+ tar zxvf $file
74
+ elif [ ${file: -4} == ".tar" ]; then
75
+ tar xvf $file
76
+ fi
77
+ fi
78
+ done
79
+ cd ..
80
+
81
+ echo "pre-processing train data..."
82
+ for l in $src $tgt; do
83
+ rm $tmp/train.tags.$lang.tok.$l
84
+ for f in "${CORPORA[@]}"; do
85
+ cat $orig/$f.$l | \
86
+ perl $NORM_PUNC $l | \
87
+ perl $REM_NON_PRINT_CHAR | \
88
+ perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
89
+ done
90
+ done
91
+
92
+ echo "pre-processing test data..."
93
+ for l in $src $tgt; do
94
+ if [ "$l" == "$src" ]; then
95
+ t="src"
96
+ else
97
+ t="ref"
98
+ fi
99
+ grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
100
+ sed -e 's/<seg id="[0-9]*">\s*//g' | \
101
+ sed -e 's/\s*<\/seg>\s*//g' | \
102
+ sed -e "s/\’/\'/g" | \
103
+ perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
104
+ echo ""
105
+ done
106
+
107
+ echo "splitting train and valid..."
108
+ for l in $src $tgt; do
109
+ awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
110
+ awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
111
+ done
112
+
113
+ TRAIN=$tmp/train.de-en
114
+ BPE_CODE=$prep/code
115
+ rm -f $TRAIN
116
+ for l in $src $tgt; do
117
+ cat $tmp/train.$l >> $TRAIN
118
+ done
119
+
120
+ echo "learn_bpe.py on ${TRAIN}..."
121
+ python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
122
+
123
+ for L in $src $tgt; do
124
+ for f in train.$L valid.$L test.$L; do
125
+ echo "apply_bpe.py to ${f}..."
126
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
127
+ done
128
+ done
129
+
130
+ perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
131
+ perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
132
+
133
+ for L in $src $tgt; do
134
+ cp $tmp/bpe.test.$L $prep/test.$L
135
+ done
examples/backtranslation/sacrebleu.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ if [ $# -ne 5 ]; then
4
+ echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
+ exit
6
+ fi
7
+
8
+
9
+ DATASET=$1
10
+ LANGPAIR=$2
11
+ DATABIN=$3
12
+ BPECODE=$4
13
+ MODEL=$5
14
+
15
+ SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
+ TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
+
18
+
19
+ BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
+ if [ ! -e $BPEROOT ]; then
21
+ BPEROOT=subword-nmt/subword_nmt
22
+ if [ ! -e $BPEROOT ]; then
23
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
+ git clone https://github.com/rsennrich/subword-nmt.git
25
+ fi
26
+ fi
27
+
28
+
29
+ sacrebleu -t $DATASET -l $LANGPAIR --echo src \
30
+ | sacremoses tokenize -a -l $SRCLANG -q \
31
+ | python $BPEROOT/apply_bpe.py -c $BPECODE \
32
+ | fairseq-interactive $DATABIN --path $MODEL \
33
+ -s $SRCLANG -t $TGTLANG \
34
+ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
35
+ | grep ^H- | cut -f 3- \
36
+ | sacremoses detokenize -l $TGTLANG -q \
37
+ | sacrebleu -t $DATASET -l $LANGPAIR
examples/backtranslation/tokenized_bleu.sh ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ if [ $# -ne 5 ]; then
4
+ echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
+ exit
6
+ fi
7
+
8
+
9
+ DATASET=$1
10
+ LANGPAIR=$2
11
+ DATABIN=$3
12
+ BPECODE=$4
13
+ MODEL=$5
14
+
15
+ SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
+ TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
+
18
+
19
+ BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
+ if [ ! -e $BPEROOT ]; then
21
+ BPEROOT=subword-nmt/subword_nmt
22
+ if [ ! -e $BPEROOT ]; then
23
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
+ git clone https://github.com/rsennrich/subword-nmt.git
25
+ fi
26
+ fi
27
+
28
+
29
+ TMP_REF=$(mktemp)
30
+
31
+ sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
32
+ | sacremoses normalize -l $TGTLANG -q \
33
+ | sacremoses tokenize -a -l $TGTLANG -q \
34
+ > $TMP_REF
35
+
36
+ sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
37
+ | sacremoses normalize -l $SRCLANG -q \
38
+ | sacremoses tokenize -a -l $SRCLANG -q \
39
+ | python $BPEROOT/apply_bpe.py -c $BPECODE \
40
+ | fairseq-interactive $DATABIN --path $MODEL \
41
+ -s $SRCLANG -t $TGTLANG \
42
+ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
43
+ | grep ^H- | cut -f 3- \
44
+ | fairseq-score --ref $TMP_REF
45
+
46
+ rm -f $TMP_REF
examples/bart/README.glue.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BART on GLUE tasks
2
+
3
+ ### 1) Download the data from GLUE website (https://gluebenchmark.com/tasks) using following commands:
4
+ ```bash
5
+ wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
6
+ python download_glue_data.py --data_dir glue_data --tasks all
7
+ ```
8
+
9
+ ### 2) Preprocess GLUE task data (same as RoBERTa):
10
+ ```bash
11
+ ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
12
+ ```
13
+ `glue_task_name` is one of the following:
14
+ `{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
15
+ Use `ALL` for preprocessing all the glue tasks.
16
+
17
+ ### 3) Fine-tuning on GLUE task:
18
+ Example fine-tuning cmd for `RTE` task
19
+ ```bash
20
+ TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
21
+ WARMUP_UPDATES=61 # 6 percent of the number of updates
22
+ LR=1e-05 # Peak LR for polynomial LR scheduler.
23
+ NUM_CLASSES=2
24
+ MAX_SENTENCES=16 # Batch size.
25
+ BART_PATH=/path/to/bart/model.pt
26
+
27
+ CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
28
+ --restore-file $BART_PATH \
29
+ --batch-size $MAX_SENTENCES \
30
+ --max-tokens 4400 \
31
+ --task sentence_prediction \
32
+ --add-prev-output-tokens \
33
+ --layernorm-embedding \
34
+ --share-all-embeddings \
35
+ --share-decoder-input-output-embed \
36
+ --reset-optimizer --reset-dataloader --reset-meters \
37
+ --required-batch-size-multiple 1 \
38
+ --init-token 0 \
39
+ --arch bart_large \
40
+ --criterion sentence_prediction \
41
+ --num-classes $NUM_CLASSES \
42
+ --dropout 0.1 --attention-dropout 0.1 \
43
+ --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
44
+ --clip-norm 0.0 \
45
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
46
+ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
47
+ --max-epoch 10 \
48
+ --find-unused-parameters \
49
+ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
50
+ ```
51
+
52
+ For each of the GLUE task, you will need to use following cmd-line arguments:
53
+
54
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
55
+ ---|---|---|---|---|---|---|---|---
56
+ `--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
57
+ `--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
58
+ `bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
59
+ `--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
60
+ `--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
61
+
62
+ For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
63
+
64
+ **Note:**
65
+
66
+ a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.
67
+
68
+ b) Above cmd-args and hyperparams are tested on Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--batch-size`.
69
+
70
+ ### Inference on GLUE task
71
+ After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
72
+
73
+ ```python
74
+ from fairseq.models.bart import BARTModel
75
+
76
+ bart = BARTModel.from_pretrained(
77
+ 'checkpoints/',
78
+ checkpoint_file='checkpoint_best.pt',
79
+ data_name_or_path='RTE-bin'
80
+ )
81
+
82
+ label_fn = lambda label: bart.task.label_dictionary.string(
83
+ [label + bart.task.label_dictionary.nspecial]
84
+ )
85
+ ncorrect, nsamples = 0, 0
86
+ bart.cuda()
87
+ bart.eval()
88
+ with open('glue_data/RTE/dev.tsv') as fin:
89
+ fin.readline()
90
+ for index, line in enumerate(fin):
91
+ tokens = line.strip().split('\t')
92
+ sent1, sent2, target = tokens[1], tokens[2], tokens[3]
93
+ tokens = bart.encode(sent1, sent2)
94
+ prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
95
+ prediction_label = label_fn(prediction)
96
+ ncorrect += int(prediction_label == target)
97
+ nsamples += 1
98
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
99
+ ```
examples/bart/README.md ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
2
+
3
+ [https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461)
4
+
5
+ ## Introduction
6
+
7
+ BART is sequence-to-sequence model trained with denoising as pretraining objective. We show that this pretraining objective is more generic and show that we can match [RoBERTa](../roberta) results on SQuAD and GLUE and gain state-of-the-art results on summarization (XSum, CNN dataset), long form generative question answering (ELI5) and dialog response genration (ConvAI2). See the associated paper for more details.
8
+
9
+ ## Pre-trained models
10
+
11
+ Model | Description | # params | Download
12
+ ---|---|---|---
13
+ `bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
14
+ `bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
15
+ `bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
16
+ `bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
17
+ `bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)
18
+
19
+ ## Results
20
+
21
+ **[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
22
+ _(dev set, single model, single-task finetuning)_
23
+
24
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
25
+ ---|---|---|---|---|---|---|---|---
26
+ `roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
27
+ `bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
28
+
29
+ **[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
30
+ _(dev set, no additional data used)_
31
+
32
+ Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
33
+ ---|---|---
34
+ `roberta.large` | 88.9/94.6 | 86.5/89.4
35
+ `bart.large` | 88.8/94.6 | 86.1/89.2
36
+
37
+ **[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
38
+ _(test set, no additional data used)_
39
+
40
+ Model | R1 | R2 | RL
41
+ ---|---|---|---
42
+ `BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
43
+ `bart.large` | 44.16 | 21.28 | 40.90
44
+
45
+ ## Example usage
46
+
47
+ ##### Load BART from torch.hub (PyTorch >= 1.1):
48
+ ```python
49
+ import torch
50
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large')
51
+ bart.eval() # disable dropout (or leave in train mode to finetune)
52
+ ```
53
+
54
+ ##### Load BART (for PyTorch 1.0 or custom models):
55
+ ```python
56
+ # Download bart.large model
57
+ wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
58
+ tar -xzvf bart.large.tar.gz
59
+
60
+ # Load the model in fairseq
61
+ from fairseq.models.bart import BARTModel
62
+ bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
63
+ bart.eval() # disable dropout (or leave in train mode to finetune)
64
+ ```
65
+
66
+ ##### Apply Byte-Pair Encoding (BPE) to input text:
67
+ ```python
68
+ tokens = bart.encode('Hello world!')
69
+ assert tokens.tolist() == [0, 31414, 232, 328, 2]
70
+ bart.decode(tokens) # 'Hello world!'
71
+ ```
72
+
73
+ ##### Extract features from BART:
74
+ ```python
75
+ # Extract the last layer's features
76
+ last_layer_features = bart.extract_features(tokens)
77
+ assert last_layer_features.size() == torch.Size([1, 5, 1024])
78
+
79
+ # Extract all layer's features from decoder (layer 0 is the embedding layer)
80
+ all_layers = bart.extract_features(tokens, return_all_hiddens=True)
81
+ assert len(all_layers) == 13
82
+ assert torch.all(all_layers[-1] == last_layer_features)
83
+ ```
84
+
85
+ ##### Use BART for sentence-pair classification tasks:
86
+ ```python
87
+ # Download BART already finetuned for MNLI
88
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
89
+ bart.eval() # disable dropout for evaluation
90
+
91
+ # Encode a pair of sentences and make a prediction
92
+ tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
93
+ bart.predict('mnli', tokens).argmax() # 0: contradiction
94
+
95
+ # Encode another pair of sentences
96
+ tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
97
+ bart.predict('mnli', tokens).argmax() # 2: entailment
98
+ ```
99
+
100
+ ##### Register a new (randomly initialized) classification head:
101
+ ```python
102
+ bart.register_classification_head('new_task', num_classes=3)
103
+ logprobs = bart.predict('new_task', tokens)
104
+ ```
105
+
106
+ ##### Batched prediction:
107
+ ```python
108
+ import torch
109
+ from fairseq.data.data_utils import collate_tokens
110
+
111
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
112
+ bart.eval()
113
+
114
+ batch_of_pairs = [
115
+ ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
116
+ ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
117
+ ]
118
+
119
+ batch = collate_tokens(
120
+ [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
121
+ )
122
+
123
+ logprobs = bart.predict('mnli', batch)
124
+ print(logprobs.argmax(dim=1))
125
+ # tensor([0, 2])
126
+ ```
127
+
128
+ ##### Using the GPU:
129
+ ```python
130
+ bart.cuda()
131
+ bart.predict('new_task', tokens)
132
+ ```
133
+
134
+ #### Filling masks:
135
+
136
+ BART can be used to fill multiple `<mask>` tokens in the input.
137
+ ```python
138
+ bart = torch.hub.load('pytorch/fairseq', 'bart.base')
139
+ bart.eval()
140
+ bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10)
141
+ # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
142
+ ```
143
+
144
+ Note that by default we enforce the output length to match the input length.
145
+ This can be disabled by setting ``match_source_len=False``:
146
+ ```
147
+ bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False)
148
+ # [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
149
+ ```
150
+
151
+ Example code to fill masks for a batch of sentences using GPU
152
+ ```
153
+ bart.cuda()
154
+ bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10)
155
+ # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
156
+ ('The dog was asleep on the couch', tensor(-0.6796))]]
157
+ ```
158
+
159
+ #### Evaluating the `bart.large.mnli` model:
160
+
161
+ Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
162
+ ```python
163
+ label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
164
+ ncorrect, nsamples = 0, 0
165
+ bart.cuda()
166
+ bart.eval()
167
+ with open('glue_data/MNLI/dev_matched.tsv') as fin:
168
+ fin.readline()
169
+ for index, line in enumerate(fin):
170
+ tokens = line.strip().split('\t')
171
+ sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
172
+ tokens = bart.encode(sent1, sent2)
173
+ prediction = bart.predict('mnli', tokens).argmax().item()
174
+ prediction_label = label_map[prediction]
175
+ ncorrect += int(prediction_label == target)
176
+ nsamples += 1
177
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
178
+ # Expected output: 0.9010
179
+ ```
180
+
181
+ #### Evaluating the `bart.large.cnn` model:
182
+ - Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample.
183
+ - For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores
184
+ - `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
185
+ In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.
186
+
187
+ In `fairseq`, summaries can be generated using:
188
+
189
+ ```bash
190
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
191
+ python examples/bart/summarize.py \
192
+ --model-dir pytorch/fairseq \
193
+ --model-file bart.large.cnn \
194
+ --src cnn_dm/test.source \
195
+ --out cnn_dm/test.hypo
196
+ ```
197
+
198
+ For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
199
+
200
+ ```bash
201
+ export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
202
+
203
+ # Tokenize hypothesis and target files.
204
+ cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
205
+ cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
206
+ files2rouge test.hypo.tokenized test.hypo.target
207
+ # Expected output: (ROUGE-2 Average_F: 0.21238)
208
+ ```
209
+
210
+
211
+ ## Finetuning
212
+
213
+ - [Finetuning on GLUE](README.glue.md)
214
+ - [Finetuning on CNN-DM](README.summarization.md)
215
+
216
+ ## Citation
217
+
218
+ ```bibtex
219
+ @article{lewis2019bart,
220
+ title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
221
+ Language Generation, Translation, and Comprehension},
222
+ author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
223
+ Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
224
+ and Luke Zettlemoyer },
225
+ journal={arXiv preprint arXiv:1910.13461},
226
+ year = {2019},
227
+ }
228
+ ```
examples/bart/README.summarization.md ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BART on CNN-Dailymail summarization task
2
+
3
+ ### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
4
+
5
+ Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).
6
+
7
+ Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization datasets, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset), Please keep the raw dataset and make sure no tokenization nor BPE on the dataset.
8
+
9
+ ### 2) BPE preprocess:
10
+
11
+ ```bash
12
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
13
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
14
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
15
+
16
+ TASK=cnn_dm
17
+ for SPLIT in train val
18
+ do
19
+ for LANG in source target
20
+ do
21
+ python -m examples.roberta.multiprocessing_bpe_encoder \
22
+ --encoder-json encoder.json \
23
+ --vocab-bpe vocab.bpe \
24
+ --inputs "$TASK/$SPLIT.$LANG" \
25
+ --outputs "$TASK/$SPLIT.bpe.$LANG" \
26
+ --workers 60 \
27
+ --keep-empty;
28
+ done
29
+ done
30
+ ```
31
+
32
+ ### 3) Binarize dataset:
33
+ ```bash
34
+ fairseq-preprocess \
35
+ --source-lang "source" \
36
+ --target-lang "target" \
37
+ --trainpref "${TASK}/train.bpe" \
38
+ --validpref "${TASK}/val.bpe" \
39
+ --destdir "${TASK}-bin/" \
40
+ --workers 60 \
41
+ --srcdict dict.txt \
42
+ --tgtdict dict.txt;
43
+ ```
44
+
45
+ ### 4) Fine-tuning on CNN-DM summarization task:
46
+ Example fine-tuning CNN-DM
47
+ ```bash
48
+ TOTAL_NUM_UPDATES=20000
49
+ WARMUP_UPDATES=500
50
+ LR=3e-05
51
+ MAX_TOKENS=2048
52
+ UPDATE_FREQ=4
53
+ BART_PATH=/path/to/bart/model.pt
54
+
55
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
56
+ --restore-file $BART_PATH \
57
+ --max-tokens $MAX_TOKENS \
58
+ --task translation \
59
+ --source-lang source --target-lang target \
60
+ --truncate-source \
61
+ --layernorm-embedding \
62
+ --share-all-embeddings \
63
+ --share-decoder-input-output-embed \
64
+ --reset-optimizer --reset-dataloader --reset-meters \
65
+ --required-batch-size-multiple 1 \
66
+ --arch bart_large \
67
+ --criterion label_smoothed_cross_entropy \
68
+ --label-smoothing 0.1 \
69
+ --dropout 0.1 --attention-dropout 0.1 \
70
+ --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
71
+ --clip-norm 0.1 \
72
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
73
+ --fp16 --update-freq $UPDATE_FREQ \
74
+ --skip-invalid-size-inputs-valid-test \
75
+ --find-unused-parameters;
76
+ ```
77
+ Above is expected to run on `1` node with `8 32gb-V100`.
78
+ Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
79
+
80
+ Use TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2 for Xsum task
81
+
82
+ ### Inference for CNN-DM test data using above trained checkpoint.
83
+ After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using `eval_cnn.py`, for example
84
+
85
+ ```bash
86
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
87
+ python examples/bart/summarize.py \
88
+ --model-dir checkpoints \
89
+ --model-file checkpoint_best.pt \
90
+ --src cnn_dm/test.source \
91
+ --out cnn_dm/test.hypo
92
+ ```
93
+ For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10:
94
+ ```bash
95
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
96
+ python examples/bart/summarize.py \
97
+ --model-dir checkpoints \
98
+ --model-file checkpoint_best.pt \
99
+ --src cnn_dm/test.source \
100
+ --out cnn_dm/test.hypo \
101
+ --xsum-kwargs
102
+ ```
examples/bart/summarize.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq.models.bart import BARTModel
8
+ import argparse
9
+
10
+ XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
11
+ CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
12
+
13
+
14
+ @torch.no_grad()
15
+ def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
16
+ count = 1
17
+
18
+ # if n_obs is not None: bsz = min(bsz, n_obs)
19
+
20
+ with open(infile) as source, open(outfile, "w") as fout:
21
+ sline = source.readline().strip()
22
+ slines = [sline]
23
+ for sline in source:
24
+ if n_obs is not None and count > n_obs:
25
+ break
26
+ if count % bsz == 0:
27
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
28
+ for hypothesis in hypotheses_batch:
29
+ fout.write(hypothesis + "\n")
30
+ fout.flush()
31
+ slines = []
32
+
33
+ slines.append(sline.strip())
34
+ count += 1
35
+
36
+ if slines != []:
37
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
38
+ for hypothesis in hypotheses_batch:
39
+ fout.write(hypothesis + "\n")
40
+ fout.flush()
41
+
42
+
43
+ def main():
44
+ """
45
+ Usage::
46
+
47
+ python examples/bart/summarize.py \
48
+ --model-dir $HOME/bart.large.cnn \
49
+ --model-file model.pt \
50
+ --src $HOME/data-bin/cnn_dm/test.source
51
+ """
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument(
54
+ "--model-dir",
55
+ required=True,
56
+ type=str,
57
+ default="bart.large.cnn/",
58
+ help="path containing model file and src_dict.txt",
59
+ )
60
+ parser.add_argument(
61
+ "--model-file",
62
+ default="checkpoint_best.pt",
63
+ help="where in model_dir are weights saved",
64
+ )
65
+ parser.add_argument(
66
+ "--src", default="test.source", help="text to summarize", type=str
67
+ )
68
+ parser.add_argument(
69
+ "--out", default="test.hypo", help="where to save summaries", type=str
70
+ )
71
+ parser.add_argument("--bsz", default=32, help="where to save summaries", type=int)
72
+ parser.add_argument(
73
+ "--n", default=None, help="how many examples to summarize", type=int
74
+ )
75
+ parser.add_argument(
76
+ "--xsum-kwargs",
77
+ action="store_true",
78
+ default=False,
79
+ help="if true use XSUM_KWARGS else CNN_KWARGS",
80
+ )
81
+ args = parser.parse_args()
82
+ eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
83
+ if args.model_dir == "pytorch/fairseq":
84
+ bart = torch.hub.load("pytorch/fairseq", args.model_file)
85
+ else:
86
+ bart = BARTModel.from_pretrained(
87
+ args.model_dir,
88
+ checkpoint_file=args.model_file,
89
+ data_name_or_path=args.model_dir,
90
+ )
91
+ bart = bart.eval()
92
+ if torch.cuda.is_available():
93
+ bart = bart.cuda().half()
94
+ generate(
95
+ bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
96
+ )
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()
examples/byte_level_bpe/README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neural Machine Translation with Byte-Level Subwords
2
+
3
+ https://arxiv.org/abs/1909.03341
4
+
5
+ We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as
6
+ example.
7
+
8
+ ## Data
9
+ Get data and generate fairseq binary dataset:
10
+ ```bash
11
+ bash ./get_data.sh
12
+ ```
13
+
14
+ ## Model Training
15
+ Train Transformer model with Bi-GRU embedding contextualization (implemented in `gru_transformer.py`):
16
+ ```bash
17
+ # VOCAB=bytes
18
+ # VOCAB=chars
19
+ VOCAB=bbpe2048
20
+ # VOCAB=bpe2048
21
+ # VOCAB=bbpe4096
22
+ # VOCAB=bpe4096
23
+ # VOCAB=bpe16384
24
+ ```
25
+ ```bash
26
+ fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
27
+ --arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \
28
+ --optimizer adam --adam-betas '(0.9, 0.98)' \
29
+ --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
30
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
31
+ --log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \
32
+ --batch-size 100 --max-update 100000 --update-freq 2
33
+ ```
34
+
35
+ ## Generation
36
+ `fairseq-generate` requires bytes (BBPE) decoder to convert byte-level representation back to characters:
37
+ ```bash
38
+ # BPE=--bpe bytes
39
+ # BPE=--bpe characters
40
+ BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model
41
+ # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model
42
+ # BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model
43
+ # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model
44
+ # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model
45
+ ```
46
+
47
+ ```bash
48
+ fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
49
+ --source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/checkpoint_last.pt" \
50
+ --tokenizer moses --moses-target-lang en ${BPE}
51
+ ```
52
+ When using `fairseq-interactive`, bytes (BBPE) encoder/decoder is required to tokenize input data and detokenize model predictions:
53
+ ```bash
54
+ fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
55
+ --path "checkpoints/${VOCAB}/checkpoint_last.pt" --input data/test.fr --tokenizer moses --moses-source-lang fr \
56
+ --moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000
57
+ ```
58
+
59
+ ## Results
60
+ | Vocabulary | Model | BLEU |
61
+ |:-------------:|:-------------:|:-------------:|
62
+ | Joint BPE 16k ([Kudo, 2018](https://arxiv.org/abs/1804.10959)) | 512d LSTM 2+2 | 33.81 |
63
+ | Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) |
64
+ | Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) |
65
+ | Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) |
66
+ | Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) |
67
+ | Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) |
68
+ | Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) |
69
+ | Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) |
70
+
71
+
72
+ ## Citation
73
+ ```
74
+ @misc{wang2019neural,
75
+ title={Neural Machine Translation with Byte-Level Subwords},
76
+ author={Changhan Wang and Kyunghyun Cho and Jiatao Gu},
77
+ year={2019},
78
+ eprint={1909.03341},
79
+ archivePrefix={arXiv},
80
+ primaryClass={cs.CL}
81
+ }
82
+ ```
83
+
84
+
85
+ ## Contact
86
+ Changhan Wang ([[email protected]](mailto:[email protected])),
87
+ Kyunghyun Cho ([[email protected]](mailto:[email protected])),
88
+ Jiatao Gu ([[email protected]](mailto:[email protected]))
examples/byte_level_bpe/get_bitext.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import argparse
8
+ import os
9
+ import os.path as op
10
+ from collections import namedtuple
11
+ from multiprocessing import cpu_count
12
+ from typing import List, Optional
13
+
14
+ import sentencepiece as sp
15
+ from fairseq.data.encoders.byte_bpe import ByteBPE
16
+ from fairseq.data.encoders.byte_utils import byte_encode
17
+ from fairseq.data.encoders.bytes import Bytes
18
+ from fairseq.data.encoders.characters import Characters
19
+ from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
20
+ from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
21
+
22
+
23
+ SPLITS = ["train", "valid", "test"]
24
+
25
+
26
+ def _convert_xml(in_path: str, out_path: str):
27
+ with open(in_path) as f, open(out_path, "w") as f_o:
28
+ for s in f:
29
+ ss = s.strip()
30
+ if not ss.startswith("<seg"):
31
+ continue
32
+ ss = ss.replace("</seg>", "").split('">')
33
+ assert len(ss) == 2
34
+ f_o.write(ss[1].strip() + "\n")
35
+
36
+
37
+ def _convert_train(in_path: str, out_path: str):
38
+ with open(in_path) as f, open(out_path, "w") as f_o:
39
+ for s in f:
40
+ ss = s.strip()
41
+ if ss.startswith("<"):
42
+ continue
43
+ f_o.write(ss.strip() + "\n")
44
+
45
+
46
+ def _get_bytes(in_path: str, out_path: str):
47
+ with open(in_path) as f, open(out_path, "w") as f_o:
48
+ for s in f:
49
+ f_o.write(Bytes.encode(s.strip()) + "\n")
50
+
51
+
52
+ def _get_chars(in_path: str, out_path: str):
53
+ with open(in_path) as f, open(out_path, "w") as f_o:
54
+ for s in f:
55
+ f_o.write(Characters.encode(s.strip()) + "\n")
56
+
57
+
58
+ def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
59
+ Args = namedtuple(
60
+ "Args",
61
+ [
62
+ "moses_source_lang",
63
+ "moses_target_lang",
64
+ "moses_no_dash_splits",
65
+ "moses_no_escape",
66
+ ],
67
+ )
68
+ args = Args(
69
+ moses_source_lang=src,
70
+ moses_target_lang=tgt,
71
+ moses_no_dash_splits=False,
72
+ moses_no_escape=False,
73
+ )
74
+ pretokenizer = MosesTokenizer(args)
75
+ with open(in_path) as f, open(out_path, "w") as f_o:
76
+ for s in f:
77
+ f_o.write(pretokenizer.encode(s.strip()) + "\n")
78
+
79
+
80
+ def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
81
+ with open(out_path, "w") as f_o:
82
+ for lang in [src, tgt]:
83
+ with open(f"{in_path_prefix}.{lang}") as f:
84
+ for s in f:
85
+ f_o.write(byte_encode(s.strip()) + "\n")
86
+
87
+
88
+ def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
89
+ arguments = [
90
+ f"--input={in_path}",
91
+ f"--model_prefix={model_prefix}",
92
+ f"--model_type=bpe",
93
+ f"--vocab_size={vocab_size}",
94
+ "--character_coverage=1.0",
95
+ "--normalization_rule_name=identity",
96
+ f"--num_threads={cpu_count()}",
97
+ ]
98
+ sp.SentencePieceTrainer.Train(" ".join(arguments))
99
+
100
+
101
+ def _apply_bbpe(model_path: str, in_path: str, out_path: str):
102
+ Args = namedtuple("Args", ["sentencepiece_model_path"])
103
+ args = Args(sentencepiece_model_path=model_path)
104
+ tokenizer = ByteBPE(args)
105
+ with open(in_path) as f, open(out_path, "w") as f_o:
106
+ for s in f:
107
+ f_o.write(tokenizer.encode(s.strip()) + "\n")
108
+
109
+
110
+ def _apply_bpe(model_path: str, in_path: str, out_path: str):
111
+ Args = namedtuple("Args", ["sentencepiece_model"])
112
+ args = Args(sentencepiece_model=model_path)
113
+ tokenizer = SentencepieceBPE(args)
114
+ with open(in_path) as f, open(out_path, "w") as f_o:
115
+ for s in f:
116
+ f_o.write(tokenizer.encode(s.strip()) + "\n")
117
+
118
+
119
+ def _concat_files(in_paths: List[str], out_path: str):
120
+ with open(out_path, "w") as f_o:
121
+ for p in in_paths:
122
+ with open(p) as f:
123
+ for r in f:
124
+ f_o.write(r)
125
+
126
+
127
+ def preprocess_iwslt17(
128
+ root: str,
129
+ src: str,
130
+ tgt: str,
131
+ bpe_size: Optional[int],
132
+ need_chars: bool,
133
+ bbpe_size: Optional[int],
134
+ need_bytes: bool,
135
+ ):
136
+ # extract bitext
137
+ in_root = op.join(root, f"{src}-{tgt}")
138
+ for lang in [src, tgt]:
139
+ _convert_train(
140
+ op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
141
+ op.join(root, f"train.{lang}"),
142
+ )
143
+ _convert_xml(
144
+ op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
145
+ op.join(root, f"valid.{lang}"),
146
+ )
147
+ _convert_xml(
148
+ op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
149
+ op.join(root, f"test.{lang}"),
150
+ )
151
+ # pre-tokenize
152
+ for lang in [src, tgt]:
153
+ for split in SPLITS:
154
+ pretokenize(
155
+ op.join(root, f"{split}.{lang}"),
156
+ op.join(root, f"{split}.moses.{lang}"),
157
+ src,
158
+ tgt,
159
+ )
160
+ # tokenize with BPE vocabulary
161
+ if bpe_size is not None:
162
+ # learn vocabulary
163
+ concated_train_path = op.join(root, "train.all")
164
+ _concat_files(
165
+ [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
166
+ concated_train_path,
167
+ )
168
+ bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
169
+ _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
170
+ os.remove(concated_train_path)
171
+ # apply
172
+ for lang in [src, tgt]:
173
+ for split in SPLITS:
174
+ _apply_bpe(
175
+ bpe_model_prefix + ".model",
176
+ op.join(root, f"{split}.moses.{lang}"),
177
+ op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
178
+ )
179
+ # tokenize with bytes vocabulary
180
+ if need_bytes:
181
+ for lang in [src, tgt]:
182
+ for split in SPLITS:
183
+ _get_bytes(
184
+ op.join(root, f"{split}.moses.{lang}"),
185
+ op.join(root, f"{split}.moses.bytes.{lang}"),
186
+ )
187
+ # tokenize with characters vocabulary
188
+ if need_chars:
189
+ for lang in [src, tgt]:
190
+ for split in SPLITS:
191
+ _get_chars(
192
+ op.join(root, f"{split}.moses.{lang}"),
193
+ op.join(root, f"{split}.moses.chars.{lang}"),
194
+ )
195
+ # tokenize with byte-level BPE vocabulary
196
+ if bbpe_size is not None:
197
+ # learn vocabulary
198
+ bchar_path = op.join(root, "train.bchar")
199
+ _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
200
+ bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
201
+ _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
202
+ os.remove(bchar_path)
203
+ # apply
204
+ for lang in [src, tgt]:
205
+ for split in SPLITS:
206
+ _apply_bbpe(
207
+ bbpe_model_prefix + ".model",
208
+ op.join(root, f"{split}.moses.{lang}"),
209
+ op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
210
+ )
211
+
212
+
213
+ def main():
214
+ parser = argparse.ArgumentParser()
215
+ parser.add_argument("--root", type=str, default="data")
216
+ parser.add_argument(
217
+ "--bpe-vocab",
218
+ default=None,
219
+ type=int,
220
+ help="Generate tokenized bitext with BPE of size K."
221
+ "Default to None (disabled).",
222
+ )
223
+ parser.add_argument(
224
+ "--bbpe-vocab",
225
+ default=None,
226
+ type=int,
227
+ help="Generate tokenized bitext with BBPE of size K."
228
+ "Default to None (disabled).",
229
+ )
230
+ parser.add_argument(
231
+ "--byte-vocab",
232
+ action="store_true",
233
+ help="Generate tokenized bitext with bytes vocabulary",
234
+ )
235
+ parser.add_argument(
236
+ "--char-vocab",
237
+ action="store_true",
238
+ help="Generate tokenized bitext with chars vocabulary",
239
+ )
240
+ args = parser.parse_args()
241
+
242
+ preprocess_iwslt17(
243
+ args.root,
244
+ "fr",
245
+ "en",
246
+ args.bpe_vocab,
247
+ args.char_vocab,
248
+ args.bbpe_vocab,
249
+ args.byte_vocab,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()