diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a7767b63a8d61b2622642ccc9012f06af5053e17 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 1999-2022 Alibaba Group Holding Ltd. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index c0283989b5a6f7739a0623e0443a6c9bfd5fe30f..2b21e08b5d68cca5d10f920fb548a1e8dc5fcb26 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- -title: OFA Open_Domain_VQA -emoji: 🔥 -colorFrom: red -colorTo: indigo +title: OFA-Open_Domain_VQA +emoji: 💩 +colorFrom: blue +colorTo: pink sdk: gradio app_file: app.py pinned: false diff --git a/airship.jpg b/airship.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0a751cf9dfaf6f4983dafccb4aeadcc79b208b7e Binary files /dev/null and b/airship.jpg differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ecbf5221c3a97250e839564fcd9e0d7581b45b --- /dev/null +++ b/app.py @@ -0,0 +1,153 @@ +import os + +os.system('git clone https://github.com/pytorch/fairseq.git; cd fairseq;' + 'pip install --use-feature=in-tree-build ./; cd ..') +os.system('ls -l') + +import torch +import numpy as np +import re +from fairseq import utils,tasks +from fairseq import checkpoint_utils +from fairseq import distributed_utils, options, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from utils.zero_shot_utils import zero_shot_step +from tasks.mm_tasks.vqa_gen import VqaGenTask +from models.ofa import OFAModel +from PIL import Image +from torchvision import transforms +import gradio as gr + +# Register VQA task +tasks.register_task('vqa_gen',VqaGenTask) +# turn on cuda if GPU is available +use_cuda = torch.cuda.is_available() +# use fp16 only when GPU is available +use_fp16 = False + +os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/ofa_large_384.pt; ' + 'mkdir -p checkpoints; mv ofa_large_384.pt checkpoints/ofa_large_384.pt') + +# specify some options for evaluation +parser = options.get_generation_parser() +input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", "--path=checkpoints/ofa_large_384.pt", "--bpe-dir=utils/BPE"] +args = options.parse_args_and_arch(parser, input_args) +cfg = convert_namespace_to_omegaconf(args) + +# Load pretrained ckpt & config +task = tasks.setup_task(cfg.task) +models, cfg = checkpoint_utils.load_model_ensemble( + utils.split_paths(cfg.common_eval.path), + task=task +) + +# Move models to GPU +for model in models: + model.eval() + if use_fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + +# Initialize generator +generator = task.build_generator(models, cfg.generation) + +# Image transform +from torchvision import transforms +mean = [0.5, 0.5, 0.5] +std = [0.5, 0.5, 0.5] + +patch_resize_transform = transforms.Compose([ + lambda image: image.convert("RGB"), + transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), +]) + +# Text preprocess +bos_item = torch.LongTensor([task.src_dict.bos()]) +eos_item = torch.LongTensor([task.src_dict.eos()]) +pad_idx = task.src_dict.pad() + +# Normalize the question +def pre_question(question, max_ques_words): + question = 
question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ') + question = re.sub( + r"\s{2,}", + ' ', + question, + ) + question = question.rstrip('\n') + question = question.strip(' ') + # truncate question + question_words = question.split(' ') + if len(question_words) > max_ques_words: + question = ' '.join(question_words[:max_ques_words]) + return question + +def encode_text(text, length=None, append_bos=False, append_eos=False): + s = task.tgt_dict.encode_line( + line=task.bpe.encode(text), + add_if_not_exist=False, + append_eos=False + ).long() + if length is not None: + s = s[:length] + if append_bos: + s = torch.cat([bos_item, s]) + if append_eos: + s = torch.cat([s, eos_item]) + return s + +# Construct input for open-domain VQA task +def construct_sample(image: Image, question: str): + patch_image = patch_resize_transform(image).unsqueeze(0) + patch_mask = torch.tensor([True]) + + question = pre_question(question, task.cfg.max_src_length) + question = question + '?' if not question.endswith('?') else question + src_text = encode_text(' {}'.format(question), append_bos=True, append_eos=True).unsqueeze(0) + + src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text]) + ref_dict = np.array([{'yes': 1.0}]) # just placeholder + sample = { + "id":np.array(['42']), + "net_input": { + "src_tokens": src_text, + "src_lengths": src_length, + "patch_images": patch_image, + "patch_masks": patch_mask, + }, + "ref_dict": ref_dict, + } + return sample + +# Function to turn FP32 to FP16 +def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + +# Function for image captioning +def open_domain_vqa(Image, Question): + sample = construct_sample(Image, Question) + sample = utils.move_to_cuda(sample) if use_cuda else sample + sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample + # Run eval step for open-domain VQA + with torch.no_grad(): + result, scores = zero_shot_step(task, generator, models, sample) + return result[0]['answer'] + + +title = "OFA-Open_Domain_VQA" +description = "Gradio Demo for OFA-Open_Domain_VQA. Upload your own image or click any one of the examples, and click " \ + "\"Submit\" and then wait for OFA's answer. " +article = "
" +examples = [['money_tree.png', 'what is grown on the plant?'], ['airship.jpg', 'what does the red-roofed building right to the big airship look like?'], ['sitting_man.png', 'what is the man sitting on?']] +io = gr.Interface(fn=open_domain_vqa, inputs=[gr.inputs.Image(type='pil'), "textbox"], outputs=gr.outputs.Textbox(label="Answer"), + title=title, description=description, article=article, examples=examples, + allow_flagging=False, allow_screenshot=False) +io.launch(cache_examples=True) \ No newline at end of file diff --git a/checkpoints.md b/checkpoints.md new file mode 100644 index 0000000000000000000000000000000000000000..e887af2f7dbc198661b6224d5eab1b7f2bdd5b77 --- /dev/null +++ b/checkpoints.md @@ -0,0 +1,13 @@ +# Checkpoints + +We provide links for you to download our checkpoints. We will release all the checkpoints including pretrained and finetuned models on different tasks. + +## Pretraining +* Pre-trained checkpoint (OFA-Large) + +## Finetuning + +* Finetuned checkpoint for Caption on COCO +* Finetuned checkpoint for RefCOCO +* Finetuned checkpoint for RefCOCO+ +* Finetuned checkpoint for RefCOCOg diff --git a/colab.md b/colab.md new file mode 100644 index 0000000000000000000000000000000000000000..460c5821bc1fe1662239a1ec4901459e251bf67c --- /dev/null +++ b/colab.md @@ -0,0 +1,8 @@ +# Colab Notebooks + +We provide Colab notebooks of different downstream task for you guys to enjoy OFA. See below. + +* Image Captioning: [![][colab]](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing) +* Referring Expression Comprehension: [![][colab]](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing) + +[colab]:+ +* **Convolutional Neural Networks (CNN)** + + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) + + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) + + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) + + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) + + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* **LightConv and DynamicConv models** + + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* **Long Short-Term Memory (LSTM) networks** + + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) +* **Transformer (self-attention) networks** + + Attention Is All You Need (Vaswani et al., 2017) + + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) + + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) + + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) + + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 
2019)](examples/translation_moe/README.md) + + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) + + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) + + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) + + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) + + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) + + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027) + + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084) +* **Non-autoregressive Transformers** + + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) + + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) + + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* **Finetuning** + + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md) + +
+ +* September 2020: [Added Linformer code](examples/linformer/README.md) +* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) +* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) +* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) +* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +* February 2020: [mBART model and code released](examples/mbart/README.md) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german) +* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +* November 2019: [CamemBERT model and code released](examples/camembert/README.md) +* November 2019: [BART model and code released](examples/bart/README.md) +* November 2019: [XLM-R models and code released](examples/xlmr/README.md) +* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +* August 2019: [WMT'19 models released](examples/wmt19/README.md) +* July 2019: fairseq relicensed under MIT license +* July 2019: [RoBERTa models and code released](examples/roberta/README.md) +* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) + +
+
+
+
 +FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP): +* while FSDP is fully compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.) +* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of these and other limitations. + +
+
+
+
+See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
+explanation of how FSDP works.
+
+
+ +``` +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} +2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} +2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} +2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} +2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} +2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"} +2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} +2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} +2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} +2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", 
"bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} +2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds +``` + +
+ +``` +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} +2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} +2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} +2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} +2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} +2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} +2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} +2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} +2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} +2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", 
"wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} +2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds +``` + +
+ self.score = score # float
+
+
+def coordinate_to_offset(row, col, ncols):
+ return int(row * ncols + col)
+
+
+def offset_to_row(offset, ncols):
+ return int(offset / ncols)
+
+
+def offset_to_col(offset, ncols):
+ return int(offset % ncols)
+
+
+def trimWhitespace(str):
+ return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))
+
+
+def str2toks(str):
+ pieces = trimWhitespace(str).split(" ")
+ toks = []
+ for p in pieces:
+ toks.append(Token(p, 0.0, 0.0))
+ return toks
+
+
+class EditDistance(object):
+ def __init__(self, time_mediated):
+ self.time_mediated_ = time_mediated
+ self.scores_ = np.nan # Eigen::Matrix
+ self.backtraces_ = (
+ np.nan
+ ) # Eigen::Matrix backtraces_;
+ self.confusion_pairs_ = {}
+
+ def cost(self, ref, hyp, code):
+ if self.time_mediated_:
+ if code == Code.match:
+ return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
+ elif code == Code.insertion:
+ return hyp.end - hyp.start
+ elif code == Code.deletion:
+ return ref.end - ref.start
+ else: # substitution
+ return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
+ else:
+ if code == Code.match:
+ return 0
+ elif code == Code.insertion or code == Code.deletion:
+ return 3
+ else: # substitution
+ return 4
+
+ def get_result(self, refs, hyps):
+ res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
+
+ num_rows, num_cols = self.scores_.shape
+ res.score = self.scores_[num_rows - 1, num_cols - 1]
+
+ curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
+
+ while curr_offset != 0:
+ curr_row = offset_to_row(curr_offset, num_cols)
+ curr_col = offset_to_col(curr_offset, num_cols)
+
+ prev_offset = self.backtraces_[curr_row, curr_col]
+
+ prev_row = offset_to_row(prev_offset, num_cols)
+ prev_col = offset_to_col(prev_offset, num_cols)
+
+ res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++
+ res.hyps.appendleft(curr_col - 1)
+ if curr_row - 1 == prev_row and curr_col == prev_col:
+ res.codes.appendleft(Code.deletion)
+ elif curr_row == prev_row and curr_col - 1 == prev_col:
+ res.codes.appendleft(Code.insertion)
+ else:
+ # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
+ ref_str = refs[res.refs[0]].label
+ hyp_str = hyps[res.hyps[0]].label
+
+ if ref_str == hyp_str:
+ res.codes.appendleft(Code.match)
+ else:
+ res.codes.appendleft(Code.substitution)
+
+ confusion_pair = "%s -> %s" % (ref_str, hyp_str)
+ if confusion_pair not in self.confusion_pairs_:
+ self.confusion_pairs_[confusion_pair] = 1
+ else:
+ self.confusion_pairs_[confusion_pair] += 1
+
+ curr_offset = prev_offset
+
+ return res
+
+ def align(self, refs, hyps):
+ if len(refs) == 0 and len(hyps) == 0:
+ return np.nan
+
+ # NOTE: we're not resetting the values in these matrices because every value
+ # will be overridden in the loop below. If this assumption doesn't hold,
+ # be sure to set all entries in self.scores_ and self.backtraces_ to 0.
+ self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
+ self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))
+
+ num_rows, num_cols = self.scores_.shape
+
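+ # Standard edit-distance DP: scores_[i, j] holds the best alignment cost of
+ # the first i reference tokens against the first j hypothesis tokens, and
+ # backtraces_[i, j] stores the flattened offset of the predecessor cell so
+ # that get_result() can walk the optimal alignment back from the last cell.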
+ for i in range(num_rows):
+ for j in range(num_cols):
+ if i == 0 and j == 0:
+ self.scores_[i, j] = 0.0
+ self.backtraces_[i, j] = 0
+ continue
+
+ if i == 0:
+ self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
+ None, hyps[j - 1], Code.insertion
+ )
+ self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
+ continue
+
+ if j == 0:
+ self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
+ refs[i - 1], None, Code.deletion
+ )
+ self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
+ continue
+
+ # Below here both i and j are greater than 0
+ ref = refs[i - 1]
+ hyp = hyps[j - 1]
+ best_score = self.scores_[i - 1, j - 1] + (
+ self.cost(ref, hyp, Code.match)
+ if (ref.label == hyp.label)
+ else self.cost(ref, hyp, Code.substitution)
+ )
+
+ prev_row = i - 1
+ prev_col = j - 1
+ ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
+ if ins < best_score:
+ best_score = ins
+ prev_row = i
+ prev_col = j - 1
+
+ delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
+ if delt < best_score:
+ best_score = delt
+ prev_row = i - 1
+ prev_col = j
+
+ self.scores_[i, j] = best_score
+ self.backtraces_[i, j] = coordinate_to_offset(
+ prev_row, prev_col, num_cols
+ )
+
+ return self.get_result(refs, hyps)
+
+
+class WERTransformer(object):
+ def __init__(self, hyp_str, ref_str, verbose=True):
+ self.ed_ = EditDistance(False)
+ self.id2oracle_errs_ = {}
+ self.utts_ = 0
+ self.words_ = 0
+ self.insertions_ = 0
+ self.deletions_ = 0
+ self.substitutions_ = 0
+
+ self.process(["dummy_str", hyp_str, ref_str])
+
+ if verbose:
+ print("'%s' vs '%s'" % (hyp_str, ref_str))
+ self.report_result()
+
+ def process(self, input): # std::vector<std::string>&& input
+ if len(input) < 3:
+ print(
+ "Input must be of the form ... , got ",
+ len(input),
+ " inputs:",
+ )
+ return None
+
+ # Align
+ # std::vector<Token> hyps;
+ # std::vector<Token> refs;
+
+ hyps = str2toks(input[-2])
+ refs = str2toks(input[-1])
+
+ alignment = self.ed_.align(refs, hyps)
+ if alignment is None:
+ print("Alignment is null")
+ return np.nan
+
+ # Tally errors
+ ins = 0
+ dels = 0
+ subs = 0
+ for code in alignment.codes:
+ if code == Code.substitution:
+ subs += 1
+ elif code == Code.insertion:
+ ins += 1
+ elif code == Code.deletion:
+ dels += 1
+
+ # Output
+ row = input
+ row.append(str(len(refs)))
+ row.append(str(ins))
+ row.append(str(dels))
+ row.append(str(subs))
+ # print(row)
+
+ # Accumulate
+ kIdIndex = 0
+ kNBestSep = "/"
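+ # IDs may carry an n-best rank after "/" (e.g. "utt1/2"); every hypothesis
+ # sharing the prefix before "/" updates the same oracle-error entry below.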
+
+ pieces = input[kIdIndex].split(kNBestSep)
+
+ if len(pieces) == 0:
+ print(
+ "Error splitting ",
+ input[kIdIndex],
+ " on '",
+ kNBestSep,
+ "', got empty list",
+ )
+ return np.nan
+
+ id = pieces[0]
+ if id not in self.id2oracle_errs_:
+ self.utts_ += 1
+ self.words_ += len(refs)
+ self.insertions_ += ins
+ self.deletions_ += dels
+ self.substitutions_ += subs
+ self.id2oracle_errs_[id] = [ins, dels, subs]
+ else:
+ curr_err = ins + dels + subs
+ prev_err = np.sum(self.id2oracle_errs_[id])
+ if curr_err < prev_err:
+ self.id2oracle_errs_[id] = [ins, dels, subs]
+
+ return 0
+
+ def report_result(self):
+ # print("---------- Summary ---------------")
+ if self.words_ == 0:
+ print("No words counted")
+ return
+
+ # 1-best
+ best_wer = (
+ 100.0
+ * (self.insertions_ + self.deletions_ + self.substitutions_)
+ / self.words_
+ )
+
+ print(
+ "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
+ "%0.2f%% dels, %0.2f%% subs)"
+ % (
+ best_wer,
+ self.utts_,
+ self.words_,
+ 100.0 * self.insertions_ / self.words_,
+ 100.0 * self.deletions_ / self.words_,
+ 100.0 * self.substitutions_ / self.words_,
+ )
+ )
+
+ def wer(self):
+ if self.words_ == 0:
+ wer = np.nan
+ else:
+ wer = (
+ 100.0
+ * (self.insertions_ + self.deletions_ + self.substitutions_)
+ / self.words_
+ )
+ return wer
+
+ def stats(self):
+ if self.words_ == 0:
+ stats = {}
+ else:
+ wer = (
+ 100.0
+ * (self.insertions_ + self.deletions_ + self.substitutions_)
+ / self.words_
+ )
+ stats = dict(
+ {
+ "wer": wer,
+ "utts": self.utts_,
+ "numwords": self.words_,
+ "ins": self.insertions_,
+ "dels": self.deletions_,
+ "subs": self.substitutions_,
+ "confusion_pairs": self.ed_.confusion_pairs_,
+ }
+ )
+ return stats
+
+
+def calc_wer(hyp_str, ref_str):
+ t = WERTransformer(hyp_str, ref_str, verbose=0)
+ return t.wer()
+
+
+def calc_wer_stats(hyp_str, ref_str):
+ t = WERTransformer(hyp_str, ref_str, verbose=0)
+ return t.stats()
+
+
+def get_wer_alignment_codes(hyp_str, ref_str):
+ """
+ INPUT: hypothesis string, reference string
+ OUTPUT: List of alignment codes (intermediate results from WER computation)
+ """
+ t = WERTransformer(hyp_str, ref_str, verbose=0)
+ return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
+
+
+def merge_counts(x, y):
+ # Merge two hashes which have 'counts' as their values
+ # This can be used for example to merge confusion pair counts
+ # conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
+ for k, v in y.items():
+ if k not in x:
+ x[k] = 0
+ x[k] += v
+ return x
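+
+
+# Illustrative usage sketch (not part of the original module): compare a
+# hypothesis transcript against a reference and inspect the error breakdown.
+#   wer = calc_wer("the cat sat", "the cat sat on the mat")
+#   stats = calc_wer_stats("the cat sat", "the cat sat on the mat")
+#   # stats holds "wer", "utts", "numwords", "ins", "dels", "subs", "confusion_pairs"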
diff --git a/fairseq/examples/speech_recognition/w2l_decoder.py b/fairseq/examples/speech_recognition/w2l_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf2d3524ee40bd0d08b6a9560047d96e49b6045
--- /dev/null
+++ b/fairseq/examples/speech_recognition/w2l_decoder.py
@@ -0,0 +1,486 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Flashlight decoders.
+"""
+
+import gc
+import itertools as it
+import os.path as osp
+from typing import List
+import warnings
+from collections import deque, namedtuple
+
+import numpy as np
+import torch
+from examples.speech_recognition.data.replabels import unpack_replabels
+from fairseq import tasks
+from fairseq.utils import apply_to_sample
+from omegaconf import open_dict
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+
+
+try:
+ from flashlight.lib.text.dictionary import create_word_dict, load_words
+ from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
+ from flashlight.lib.text.decoder import (
+ CriterionType,
+ LexiconDecoderOptions,
+ KenLM,
+ LM,
+ LMState,
+ SmearingMode,
+ Trie,
+ LexiconDecoder,
+ )
+except:
+ warnings.warn(
+ "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
+ )
+ LM = object
+ LMState = object
+
+
+class W2lDecoder(object):
+ def __init__(self, args, tgt_dict):
+ self.tgt_dict = tgt_dict
+ self.vocab_size = len(tgt_dict)
+ self.nbest = args.nbest
+
+ # criterion-specific init
+ self.criterion_type = CriterionType.CTC
+ self.blank = (
+ tgt_dict.index("<ctc_blank>")
+ if "<ctc_blank>" in tgt_dict.indices
+ else tgt_dict.bos()
+ )
+ if "<sep>" in tgt_dict.indices:
+ self.silence = tgt_dict.index("<sep>")
+ elif "|" in tgt_dict.indices:
+ self.silence = tgt_dict.index("|")
+ else:
+ self.silence = tgt_dict.eos()
+ self.asg_transitions = None
+
+ def generate(self, models, sample, **unused):
+ """Generate a batch of inferences."""
+ # model.forward normally channels prev_output_tokens into the decoder
+ # separately, but SequenceGenerator directly calls model.encoder
+ encoder_input = {
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
+ }
+ emissions = self.get_emissions(models, encoder_input)
+ return self.decode(emissions)
+
+ def get_emissions(self, models, encoder_input):
+ """Run encoder and normalize emissions"""
+ model = models[0]
+ encoder_out = model(**encoder_input)
+ if hasattr(model, "get_logits"):
+ emissions = model.get_logits(encoder_out) # no need to normalize emissions
+ else:
+ emissions = model.get_normalized_probs(encoder_out, log_probs=True)
+ return emissions.transpose(0, 1).float().cpu().contiguous()
+
+ def get_tokens(self, idxs):
+ """Normalize tokens by handling CTC blank, ASG replabels, etc."""
+ idxs = (g[0] for g in it.groupby(idxs))
+ idxs = filter(lambda x: x != self.blank, idxs)
+ return torch.LongTensor(list(idxs))
+
+
+class W2lViterbiDecoder(W2lDecoder):
+ def __init__(self, args, tgt_dict):
+ super().__init__(args, tgt_dict)
+
+ def decode(self, emissions):
+ B, T, N = emissions.size()
+ hypos = []
+ if self.asg_transitions is None:
+ transitions = torch.FloatTensor(N, N).zero_()
+ else:
+ transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
+ viterbi_path = torch.IntTensor(B, T)
+ workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
+ CpuViterbiPath.compute(
+ B,
+ T,
+ N,
+ get_data_ptr_as_bytes(emissions),
+ get_data_ptr_as_bytes(transitions),
+ get_data_ptr_as_bytes(viterbi_path),
+ get_data_ptr_as_bytes(workspace),
+ )
+ return [
+ [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
+ for b in range(B)
+ ]
+
+
+class W2lKenLMDecoder(W2lDecoder):
+ def __init__(self, args, tgt_dict):
+ super().__init__(args, tgt_dict)
+
+ self.unit_lm = getattr(args, "unit_lm", False)
+
+ if args.lexicon:
+ self.lexicon = load_words(args.lexicon)
+ self.word_dict = create_word_dict(self.lexicon)
+ self.unk_word = self.word_dict.get_index("<unk>")
+
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
+ self.trie = Trie(self.vocab_size, self.silence)
+
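+ # Build a prefix trie over the lexicon: every spelling (token sequence) of a
+ # word is inserted with that word's unigram LM score, and MAX smearing then
+ # propagates the best reachable score onto each prefix for beam pruning.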
+ start_state = self.lm.start(False)
+ for i, (word, spellings) in enumerate(self.lexicon.items()):
+ word_idx = self.word_dict.get_index(word)
+ _, score = self.lm.score(start_state, word_idx)
+ for spelling in spellings:
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
+ assert (
+ tgt_dict.unk() not in spelling_idxs
+ ), f"{spelling} {spelling_idxs}"
+ self.trie.insert(spelling_idxs, word_idx, score)
+ self.trie.smear(SmearingMode.MAX)
+
+ self.decoder_opts = LexiconDecoderOptions(
+ beam_size=args.beam,
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
+ beam_threshold=args.beam_threshold,
+ lm_weight=args.lm_weight,
+ word_score=args.word_score,
+ unk_score=args.unk_weight,
+ sil_score=args.sil_weight,
+ log_add=False,
+ criterion_type=self.criterion_type,
+ )
+
+ if self.asg_transitions is None:
+ N = 768
+ # self.asg_transitions = torch.FloatTensor(N, N).zero_()
+ self.asg_transitions = []
+
+ self.decoder = LexiconDecoder(
+ self.decoder_opts,
+ self.trie,
+ self.lm,
+ self.silence,
+ self.blank,
+ self.unk_word,
+ self.asg_transitions,
+ self.unit_lm,
+ )
+ else:
+ assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
+ from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
+
+ d = {w: [[w]] for w in tgt_dict.symbols}
+ self.word_dict = create_word_dict(d)
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
+ self.decoder_opts = LexiconFreeDecoderOptions(
+ beam_size=args.beam,
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
+ beam_threshold=args.beam_threshold,
+ lm_weight=args.lm_weight,
+ sil_score=args.sil_weight,
+ log_add=False,
+ criterion_type=self.criterion_type,
+ )
+ self.decoder = LexiconFreeDecoder(
+ self.decoder_opts, self.lm, self.silence, self.blank, []
+ )
+
+ def get_timesteps(self, token_idxs: List[int]) -> List[int]:
+ """Returns frame numbers corresponding to every non-blank token.
+
+ Parameters
+ ----------
+ token_idxs : List[int]
+ IDs of decoded tokens.
+
+ Returns
+ -------
+ List[int]
+ Frame numbers corresponding to every non-blank token.
+ """
+ timesteps = []
+ for i, token_idx in enumerate(token_idxs):
+ if token_idx == self.blank:
+ continue
+ if i == 0 or token_idx != token_idxs[i-1]:
+ timesteps.append(i)
+ return timesteps
+
+ def decode(self, emissions):
+ B, T, N = emissions.size()
+ hypos = []
+ for b in range(B):
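+ # emissions is a contiguous float32 tensor (see get_emissions), so each
+ # element is 4 bytes; offset the raw data pointer to batch element b's
+ # T x N block before handing it to the flashlight decoder.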
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
+ results = self.decoder.decode(emissions_ptr, T, N)
+
+ nbest_results = results[: self.nbest]
+ hypos.append(
+ [
+ {
+ "tokens": self.get_tokens(result.tokens),
+ "score": result.score,
+ "timesteps": self.get_timesteps(result.tokens),
+ "words": [
+ self.word_dict.get_entry(x) for x in result.words if x >= 0
+ ],
+ }
+ for result in nbest_results
+ ]
+ )
+ return hypos
+
+
+FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
+
+
+class FairseqLM(LM):
+ def __init__(self, dictionary, model):
+ LM.__init__(self)
+ self.dictionary = dictionary
+ self.model = model
+ self.unk = self.dictionary.unk()
+
+ self.save_incremental = False # this currently does not work properly
+ self.max_cache = 20_000
+
+ model.cuda()
+ model.eval()
+ model.make_generation_fast_()
+
+ self.states = {}
+ self.stateq = deque()
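+ # states maps an LMState to its cached prefix, incremental state and
+ # output probabilities; stateq records insertion order so trim_cache()
+ # can drop the cached tensors of the oldest states once more than
+ # max_cache states are held.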
+
+ def start(self, start_with_nothing):
+ state = LMState()
+ prefix = torch.LongTensor([[self.dictionary.eos()]])
+ incremental_state = {} if self.save_incremental else None
+ with torch.no_grad():
+ res = self.model(prefix.cuda(), incremental_state=incremental_state)
+ probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
+
+ if incremental_state is not None:
+ incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
+ self.states[state] = FairseqLMState(
+ prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
+ )
+ self.stateq.append(state)
+
+ return state
+
+ def score(self, state: LMState, token_index: int, no_cache: bool = False):
+ """
+ Evaluate language model based on the current lm state and new word
+ Parameters:
+ -----------
+ state: current lm state
+ token_index: index of the word
+ (this can be a lexicon index, in which case the LM should store the
+ mapping between lexicon and LM indices, or the LM's own index of the word)
+
+ Returns:
+ --------
+ (LMState, float): pair of (new state, score for the current word)
+ """
+ curr_state = self.states[state]
+
+ def trim_cache(targ_size):
+ while len(self.stateq) > targ_size:
+ rem_k = self.stateq.popleft()
+ rem_st = self.states[rem_k]
+ rem_st = FairseqLMState(rem_st.prefix, None, None)
+ self.states[rem_k] = rem_st
+
+ if curr_state.probs is None:
+ new_incremental_state = (
+ curr_state.incremental_state.copy()
+ if curr_state.incremental_state is not None
+ else None
+ )
+ with torch.no_grad():
+ if new_incremental_state is not None:
+ new_incremental_state = apply_to_sample(
+ lambda x: x.cuda(), new_incremental_state
+ )
+ elif self.save_incremental:
+ new_incremental_state = {}
+
+ res = self.model(
+ torch.from_numpy(curr_state.prefix).cuda(),
+ incremental_state=new_incremental_state,
+ )
+ probs = self.model.get_normalized_probs(
+ res, log_probs=True, sample=None
+ )
+
+ if new_incremental_state is not None:
+ new_incremental_state = apply_to_sample(
+ lambda x: x.cpu(), new_incremental_state
+ )
+
+ curr_state = FairseqLMState(
+ curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
+ )
+
+ if not no_cache:
+ self.states[state] = curr_state
+ self.stateq.append(state)
+
+ score = curr_state.probs[token_index].item()
+
+ trim_cache(self.max_cache)
+
+ outstate = state.child(token_index)
+ if outstate not in self.states and not no_cache:
+ prefix = np.concatenate(
+ [curr_state.prefix, torch.LongTensor([[token_index]])], -1
+ )
+ incr_state = curr_state.incremental_state
+
+ self.states[outstate] = FairseqLMState(prefix, incr_state, None)
+
+ if token_index == self.unk:
+ score = float("-inf")
+
+ return outstate, score
+
+ def finish(self, state: LMState):
+ """
+ Evaluate eos for language model based on the current lm state
+
+ Returns:
+ --------
+ (LMState, float): pair of (new state, score for the current word)
+ """
+ return self.score(state, self.dictionary.eos())
+
+ def empty_cache(self):
+ self.states = {}
+ self.stateq = deque()
+ gc.collect()
+
+
+class W2lFairseqLMDecoder(W2lDecoder):
+ def __init__(self, args, tgt_dict):
+ super().__init__(args, tgt_dict)
+
+ self.unit_lm = getattr(args, "unit_lm", False)
+
+ self.lexicon = load_words(args.lexicon) if args.lexicon else None
+ self.idx_to_wrd = {}
+
+ checkpoint = torch.load(args.kenlm_model, map_location="cpu")
+
+ if "cfg" in checkpoint and checkpoint["cfg"] is not None:
+ lm_args = checkpoint["cfg"]
+ else:
+ lm_args = convert_namespace_to_omegaconf(checkpoint["args"])
+
+ with open_dict(lm_args.task):
+ lm_args.task.data = osp.dirname(args.kenlm_model)
+
+ task = tasks.setup_task(lm_args.task)
+ model = task.build_model(lm_args.model)
+ model.load_state_dict(checkpoint["model"], strict=False)
+
+ self.trie = Trie(self.vocab_size, self.silence)
+
+ self.word_dict = task.dictionary
+ self.unk_word = self.word_dict.unk()
+ self.lm = FairseqLM(self.word_dict, model)
+
+ if self.lexicon:
+ start_state = self.lm.start(False)
+ for i, (word, spellings) in enumerate(self.lexicon.items()):
+ if self.unit_lm:
+ word_idx = i
+ self.idx_to_wrd[i] = word
+ score = 0
+ else:
+ word_idx = self.word_dict.index(word)
+ _, score = self.lm.score(start_state, word_idx, no_cache=True)
+
+ for spelling in spellings:
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
+ assert (
+ tgt_dict.unk() not in spelling_idxs
+ ), f"{spelling} {spelling_idxs}"
+ self.trie.insert(spelling_idxs, word_idx, score)
+ self.trie.smear(SmearingMode.MAX)
+
+ self.decoder_opts = LexiconDecoderOptions(
+ beam_size=args.beam,
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
+ beam_threshold=args.beam_threshold,
+ lm_weight=args.lm_weight,
+ word_score=args.word_score,
+ unk_score=args.unk_weight,
+ sil_score=args.sil_weight,
+ log_add=False,
+ criterion_type=self.criterion_type,
+ )
+
+ self.decoder = LexiconDecoder(
+ self.decoder_opts,
+ self.trie,
+ self.lm,
+ self.silence,
+ self.blank,
+ self.unk_word,
+ [],
+ self.unit_lm,
+ )
+ else:
+ assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
+ from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
+
+ d = {w: [[w]] for w in tgt_dict.symbols}
+ self.word_dict = create_word_dict(d)
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
+ self.decoder_opts = LexiconFreeDecoderOptions(
+ beam_size=args.beam,
+ beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))),
+ beam_threshold=args.beam_threshold,
+ lm_weight=args.lm_weight,
+ sil_score=args.sil_weight,
+ log_add=False,
+ criterion_type=self.criterion_type,
+ )
+ self.decoder = LexiconFreeDecoder(
+ self.decoder_opts, self.lm, self.silence, self.blank, []
+ )
+
+ def decode(self, emissions):
+ B, T, N = emissions.size()
+ hypos = []
+
+ def idx_to_word(idx):
+ if self.unit_lm:
+ return self.idx_to_wrd[idx]
+ else:
+ return self.word_dict[idx]
+
+ def make_hypo(result):
+ hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
+ if self.lexicon:
+ hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
+ return hypo
+
+ for b in range(B):
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
+ results = self.decoder.decode(emissions_ptr, T, N)
+
+ nbest_results = results[: self.nbest]
+ hypos.append([make_hypo(result) for result in nbest_results])
+ self.lm.empty_cache()
+
+ return hypos
diff --git a/fairseq/examples/speech_synthesis/README.md b/fairseq/examples/speech_synthesis/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a3ae54b857c43621c9fb67ee4b214584beec835
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/README.md
@@ -0,0 +1,16 @@
+Speech Synthesis (S^2)
+===
+
+Speech synthesis with fairseq.
+
+- Autoregressive and non-autoregressive models
+- Multi-speaker synthesis
+- Audio preprocessing
+- Automatic metrics
+- Similar data configuration as [S2T](../speech_to_text/README.md)
+
+
+## Examples
+- [Single-speaker synthesis on LJSpeech](docs/ljspeech_example.md)
+- [Multi-speaker synthesis on VCTK](docs/vctk_example.md)
+- [Multi-speaker synthesis on Common Voice](docs/common_voice_example.md)
diff --git a/fairseq/examples/speech_synthesis/__init__.py b/fairseq/examples/speech_synthesis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fairseq/examples/speech_synthesis/data_utils.py b/fairseq/examples/speech_synthesis/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43a4a90046fb9ee4944dc06ba377c1faade141d
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/data_utils.py
@@ -0,0 +1,320 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from pathlib import Path
+from typing import Optional, List, Dict
+import zipfile
+import tempfile
+from dataclasses import dataclass
+from itertools import groupby
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+
+from examples.speech_to_text.data_utils import load_tsv_to_dicts
+from fairseq.data.audio.audio_utils import TTSSpectrogram, TTSMelScale
+
+
+def trim_or_pad_to_target_length(
+ data_1d_or_2d: np.ndarray, target_length: int
+) -> np.ndarray:
+ assert len(data_1d_or_2d.shape) in {1, 2}
+ delta = data_1d_or_2d.shape[0] - target_length
+    if delta >= 0:  # trim if longer than the target
+        data_1d_or_2d = data_1d_or_2d[: target_length]
+    else:  # pad if shorter than the target
+ if len(data_1d_or_2d.shape) == 1:
+ data_1d_or_2d = np.concatenate(
+ [data_1d_or_2d, np.zeros(-delta)], axis=0
+ )
+ else:
+ data_1d_or_2d = np.concatenate(
+ [data_1d_or_2d, np.zeros((-delta, data_1d_or_2d.shape[1]))],
+ axis=0
+ )
+ return data_1d_or_2d
+
+
+def extract_logmel_spectrogram(
+ waveform: torch.Tensor, sample_rate: int,
+ output_path: Optional[Path] = None, win_length: int = 1024,
+ hop_length: int = 256, n_fft: int = 1024,
+ win_fn: callable = torch.hann_window, n_mels: int = 80,
+ f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
+ overwrite: bool = False, target_length: Optional[int] = None
+):
+ if output_path is not None and output_path.is_file() and not overwrite:
+ return
+
+ spectrogram_transform = TTSSpectrogram(
+ n_fft=n_fft, win_length=win_length, hop_length=hop_length,
+ window_fn=win_fn
+ )
+ mel_scale_transform = TTSMelScale(
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+ n_stft=n_fft // 2 + 1
+ )
+ spectrogram = spectrogram_transform(waveform)
+ mel_spec = mel_scale_transform(spectrogram)
+ logmel_spec = torch.clamp(mel_spec, min=eps).log()
+ assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
+ logmel_spec = logmel_spec.squeeze().t() # D x T -> T x D
+ if target_length is not None:
+        logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)
+
+ if output_path is not None:
+ np.save(output_path.as_posix(), logmel_spec)
+ else:
+ return logmel_spec
+
+
+def extract_pitch(
+ waveform: torch.Tensor, sample_rate: int,
+ output_path: Optional[Path] = None, hop_length: int = 256,
+ log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
+):
+ if output_path is not None and output_path.is_file():
+ return
+
+ try:
+ import pyworld
+ except ImportError:
+ raise ImportError("Please install PyWORLD: pip install pyworld")
+
+ _waveform = waveform.squeeze(0).double().numpy()
+ pitch, t = pyworld.dio(
+ _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
+ )
+ pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)
+
+ if phoneme_durations is not None:
+ pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
+ try:
+ from scipy.interpolate import interp1d
+ except ImportError:
+ raise ImportError("Please install SciPy: pip install scipy")
+ nonzero_ids = np.where(pitch != 0)[0]
+ interp_fn = interp1d(
+ nonzero_ids,
+ pitch[nonzero_ids],
+ fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
+ bounds_error=False,
+ )
+ pitch = interp_fn(np.arange(0, len(pitch)))
+ d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
+ pitch = np.array(
+ [
+ np.mean(pitch[d_cumsum[i-1]: d_cumsum[i]])
+ for i in range(1, len(d_cumsum))
+ ]
+ )
+ assert len(pitch) == len(phoneme_durations)
+
+ if log_scale:
+ pitch = np.log(pitch + 1)
+
+ if output_path is not None:
+ np.save(output_path.as_posix(), pitch)
+ else:
+ return pitch
+
+
+def extract_energy(
+ waveform: torch.Tensor, output_path: Optional[Path] = None,
+ hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
+ phoneme_durations: Optional[List[int]] = None
+):
+ if output_path is not None and output_path.is_file():
+ return
+
+ assert len(waveform.shape) == 2 and waveform.shape[0] == 1
+ waveform = waveform.view(1, 1, waveform.shape[1])
+ waveform = F.pad(
+ waveform.unsqueeze(1), [n_fft // 2, n_fft // 2, 0, 0],
+ mode="reflect"
+ )
+ waveform = waveform.squeeze(1)
+
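+    # build an explicit DFT basis and apply it via conv1d (an STFT) to get per-frame magnitudes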
+ fourier_basis = np.fft.fft(np.eye(n_fft))
+ cutoff = int((n_fft / 2 + 1))
+ fourier_basis = np.vstack(
+ [np.real(fourier_basis[:cutoff, :]),
+ np.imag(fourier_basis[:cutoff, :])]
+ )
+
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+ forward_transform = F.conv1d(
+ waveform, forward_basis, stride=hop_length, padding=0
+ )
+
+ real_part = forward_transform[:, :cutoff, :]
+ imag_part = forward_transform[:, cutoff:, :]
+ magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
+ energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()
+
+ if phoneme_durations is not None:
+ energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
+ d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
+ energy = np.array(
+ [
+ np.mean(energy[d_cumsum[i - 1]: d_cumsum[i]])
+ for i in range(1, len(d_cumsum))
+ ]
+ )
+ assert len(energy) == len(phoneme_durations)
+
+ if log_scale:
+ energy = np.log(energy + 1)
+
+ if output_path is not None:
+ np.save(output_path.as_posix(), energy)
+ else:
+ return energy
+
+
+def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
+ mean_x, mean_x2, n_frames = None, None, 0
+ feature_paths = feature_root.glob("*.npy")
+ for p in tqdm(feature_paths):
+ with open(p, 'rb') as f:
+ frames = np.load(f).squeeze()
+
+ n_frames += frames.shape[0]
+
+ cur_mean_x = frames.sum(axis=0)
+ if mean_x is None:
+ mean_x = cur_mean_x
+ else:
+ mean_x += cur_mean_x
+
+ cur_mean_x2 = (frames ** 2).sum(axis=0)
+ if mean_x2 is None:
+ mean_x2 = cur_mean_x2
+ else:
+ mean_x2 += cur_mean_x2
+
+ mean_x /= n_frames
+ mean_x2 /= n_frames
+ var_x = mean_x2 - mean_x ** 2
+ std_x = np.sqrt(np.maximum(var_x, 1e-10))
+
+ if output_path is not None:
+ with open(output_path, 'wb') as f:
+ np.savez(f, mean=mean_x, std=std_x)
+ else:
+ return {"mean": mean_x, "std": std_x}
+
+
+def ipa_phonemize(text, lang="en-us", use_g2p=False):
+ if use_g2p:
+ assert lang == "en-us", "g2pE phonemizer only works for en-us"
+ try:
+ from g2p_en import G2p
+ g2p = G2p()
+ return " ".join("|" if p == " " else p for p in g2p(text))
+ except ImportError:
+ raise ImportError(
+ "Please install phonemizer: pip install g2p_en"
+ )
+ else:
+ try:
+ from phonemizer import phonemize
+ from phonemizer.separator import Separator
+ return phonemize(
+ text, backend='espeak', language=lang,
+ separator=Separator(word="| ", phone=" ")
+ )
+ except ImportError:
+ raise ImportError(
+ "Please install phonemizer: pip install phonemizer"
+ )
+
+
+@dataclass
+class ForceAlignmentInfo(object):
+ tokens: List[str]
+ frame_durations: List[int]
+ start_sec: Optional[float]
+ end_sec: Optional[float]
+
+
+def get_mfa_alignment_by_sample_id(
+ textgrid_zip_path: str, sample_id: str, sample_rate: int,
+ hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
+) -> ForceAlignmentInfo:
+ try:
+ import tgt
+ except ImportError:
+ raise ImportError("Please install TextGridTools: pip install tgt")
+
+ filename = f"{sample_id}.TextGrid"
+ out_root = Path(tempfile.gettempdir())
+ tgt_path = out_root / filename
+ with zipfile.ZipFile(textgrid_zip_path) as f_zip:
+ f_zip.extract(filename, path=out_root)
+ textgrid = tgt.io.read_textgrid(tgt_path.as_posix())
+ os.remove(tgt_path)
+
+ phones, frame_durations = [], []
+ start_sec, end_sec, end_idx = 0, 0, 0
+ for t in textgrid.get_tier_by_name("phones")._objects:
+ s, e, p = t.start_time, t.end_time, t.text
+ # Trim leading silences
+ if len(phones) == 0:
+ if p in silence_phones:
+ continue
+ else:
+ start_sec = s
+ phones.append(p)
+ if p not in silence_phones:
+ end_sec = e
+ end_idx = len(phones)
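+        # frames per second at the given hop size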
+ r = sample_rate / hop_length
+ frame_durations.append(int(np.round(e * r) - np.round(s * r)))
+ # Trim tailing silences
+ phones = phones[:end_idx]
+ frame_durations = frame_durations[:end_idx]
+
+ return ForceAlignmentInfo(
+ tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
+ end_sec=end_sec
+ )
+
+
+def get_mfa_alignment(
+ textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
+ hop_length: int
+) -> Dict[str, ForceAlignmentInfo]:
+ return {
+ i: get_mfa_alignment_by_sample_id(
+ textgrid_zip_path, i, sample_rate, hop_length
+ ) for i in tqdm(sample_ids)
+ }
+
+
+def get_unit_alignment(
+ id_to_unit_tsv_path: str, sample_ids: List[str]
+) -> Dict[str, ForceAlignmentInfo]:
+ id_to_units = {
+ e["id"]: e["units"] for e in load_tsv_to_dicts(id_to_unit_tsv_path)
+ }
+ id_to_units = {i: id_to_units[i].split() for i in sample_ids}
+ id_to_units_collapsed = {
+ i: [uu for uu, _ in groupby(u)] for i, u in id_to_units.items()
+ }
+ id_to_durations = {
+ i: [len(list(g)) for _, g in groupby(u)] for i, u in id_to_units.items()
+ }
+
+ return {
+ i: ForceAlignmentInfo(
+ tokens=id_to_units_collapsed[i], frame_durations=id_to_durations[i],
+ start_sec=None, end_sec=None
+ )
+ for i in sample_ids
+ }
diff --git a/fairseq/examples/speech_synthesis/docs/common_voice_example.md b/fairseq/examples/speech_synthesis/docs/common_voice_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..40e841b284a7e34b458b286eb0bb60e33c0601da
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/docs/common_voice_example.md
@@ -0,0 +1,56 @@
+[[Back]](..)
+
+# Common Voice
+
+[Common Voice](https://commonvoice.mozilla.org/en/datasets) is a public domain speech corpus with 11.2K hours of read
+speech across 76 languages (as of version 7.0). We provide examples for building
+[Transformer](https://arxiv.org/abs/1809.08895) models on this dataset.
+
+
+## Data preparation
+[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path `${DATA_ROOT}/${LANG_ID}`.
+Create splits and generate audio manifests with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_common_voice_audio_manifest \
+ --data-root ${DATA_ROOT} \
+ --lang ${LANG_ID} \
+ --output-manifest-root ${AUDIO_MANIFEST_ROOT} --convert-to-wav
+```
+
+Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+ --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+ --output-root ${FEATURE_MANIFEST_ROOT} \
+ --ipa-vocab --lang ${LANG_ID}
+```
+where we use phoneme inputs (`--ipa-vocab`) as an example.
+
+To denoise audio and trim leading/trailing silence using signal processing based VAD, run
+```bash
+for SPLIT in dev test train; do
+ python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \
+ --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+ --output-dir ${PROCESSED_DATA_ROOT} \
+ --denoise --vad --vad-agg-level 2
+done
+```
+
+
+## Training
+(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).)
+
+
+## Inference
+(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).)
+
+## Automatic Evaluation
+(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).)
+
+## Results
+
+| Language | Speakers | --arch | Params | Test MCD | Model |
+|---|---|---|---|---|---|
+| English | 200 | tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/cv4_en200_transformer_phn.tar) |
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_synthesis/docs/ljspeech_example.md b/fairseq/examples/speech_synthesis/docs/ljspeech_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..90c524fac8ffdc1819ec9bb36928500320337603
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/docs/ljspeech_example.md
@@ -0,0 +1,138 @@
+[[Back]](..)
+
+# LJSpeech
+
+[LJSpeech](https://keithito.com/LJ-Speech-Dataset) is a public domain TTS
+corpus with around 24 hours of English speech sampled at 22.05kHz. We provide examples for building
+[Transformer](https://arxiv.org/abs/1809.08895) and [FastSpeech 2](https://arxiv.org/abs/2006.04558)
+models on this dataset.
+
+
+## Data preparation
+
+Download data, create splits and generate audio manifests with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_ljspeech_audio_manifest \
+ --output-data-root ${AUDIO_DATA_ROOT} \
+ --output-manifest-root ${AUDIO_MANIFEST_ROOT}
+```
+
+Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+ --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+ --output-root ${FEATURE_MANIFEST_ROOT} \
+ --ipa-vocab --use-g2p
+```
+where we use phoneme inputs (`--ipa-vocab --use-g2p`) as an example.
+
+FastSpeech 2 additionally requires frame durations, pitch and energy as auxiliary training targets.
+Add `--add-fastspeech-targets` to include these fields in the feature manifests. Frame durations come either from
+phoneme-level forced alignment or from frame-level pseudo-text unit sequences. They should be pre-computed and
+specified via one of the following (see the example command below):
+- `--textgrid-zip ${TEXT_GRID_ZIP_PATH}` for a ZIP file containing one
+  [TextGrid](https://www.fon.hum.uva.nl/praat/manual/TextGrid.html) file per sample with forced-alignment information.
+- `--id-to-units-tsv ${ID_TO_UNIT_TSV}` for a TSV file with two columns: the sample ID and its
+  space-delimited pseudo-text unit sequence.
+
+For your convenience, we provide pre-computed
+[force-alignment](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_mfa.zip) from
+[Montreal Forced Aligner](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) and
+[pseudo-text units](s3://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_hubert.tsv) from
+[HuBERT](https://github.com/pytorch/fairseq/tree/main/examples/hubert). You can also generate them yourself using
+different software or models.
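+
+As a concrete sketch (assuming the pre-computed alignment ZIP above is stored at `${TEXT_GRID_ZIP_PATH}`), the
+FastSpeech 2 feature manifest can then be generated with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+  --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+  --output-root ${FEATURE_MANIFEST_ROOT} \
+  --ipa-vocab --use-g2p \
+  --add-fastspeech-targets --textgrid-zip ${TEXT_GRID_ZIP_PATH}
+```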
+
+
+## Training
+#### Transformer
+```bash
+fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
+ --config-yaml config.yaml --train-subset train --valid-subset dev \
+ --num-workers 4 --max-tokens 30000 --max-update 200000 \
+ --task text_to_speech --criterion tacotron2 --arch tts_transformer \
+ --clip-norm 5.0 --n-frames-per-step 4 --bce-pos-weight 5.0 \
+ --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
+ --encoder-normalize-before --decoder-normalize-before \
+ --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+ --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
+```
+where `SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with a single GPU. Adjust it
+accordingly when training on more than 1 GPU.
+
+#### FastSpeech2
+```bash
+fairseq-train ${FEATURE_MANIFEST_ROOT} --save-dir ${SAVE_DIR} \
+ --config-yaml config.yaml --train-subset train --valid-subset dev \
+ --num-workers 4 --max-sentences 6 --max-update 200000 \
+ --task text_to_speech --criterion fastspeech2 --arch fastspeech2 \
+ --clip-norm 5.0 --n-frames-per-step 1 \
+ --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.1 \
+ --encoder-normalize-before --decoder-normalize-before \
+ --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+ --seed 1 --update-freq 8 --eval-inference --best-checkpoint-metric mcd_loss
+```
+
+
+## Inference
+Average the last 5 checkpoints, then generate spectrograms and waveforms for the test split using the default Griffin-Lim vocoder:
+```bash
+SPLIT=test
+CHECKPOINT_NAME=avg_last_5
+CHECKPOINT_PATH=${SAVE_DIR}/checkpoint_${CHECKPOINT_NAME}.pt
+python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \
+ --num-epoch-checkpoints 5 \
+ --output ${CHECKPOINT_PATH}
+
+python -m examples.speech_synthesis.generate_waveform ${FEATURE_MANIFEST_ROOT} \
+ --config-yaml config.yaml --gen-subset ${SPLIT} --task text_to_speech \
+ --path ${CHECKPOINT_PATH} --max-tokens 50000 --spec-bwd-max-iter 32 \
+ --dump-waveforms
+```
+which dumps files (waveform, feature, attention plot, etc.) to `${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT}`. To
+re-synthesize target waveforms for automatic evaluation, add `--dump-target`.
+
+## Automatic Evaluation
+To start with, generate the manifest for the synthesized speech, which the evaluation scripts take as input.
+```bash
+python -m examples.speech_synthesis.evaluation.get_eval_manifest \
+ --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
+ --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+ --output-path ${EVAL_OUTPUT_ROOT}/eval.tsv \
+ --vocoder griffin_lim --sample-rate 22050 --audio-format flac \
+ --use-resynthesized-target
+```
+Speech recognition (ASR) models usually operate at lower sample rates (e.g. 16kHz). For the WER/CER metric,
+you may need to resample the audio accordingly: add `--output-sample-rate 16000` to `generate_waveform.py` and
+`--sample-rate 16000` to `get_eval_manifest.py`.
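+
+As a sketch (reusing the variables above and assuming the waveforms were re-generated with
+`--output-sample-rate 16000`), the 16kHz manifest consumed by the ASR step below could be produced with
+```bash
+python -m examples.speech_synthesis.evaluation.get_eval_manifest \
+  --generation-root ${SAVE_DIR}/generate-${CHECKPOINT_NAME}-${SPLIT} \
+  --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+  --output-path ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv \
+  --vocoder griffin_lim --sample-rate 16000 --audio-format flac \
+  --use-resynthesized-target
+```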
+
+
+#### WER/CER metric
+We use a wav2vec 2.0 ASR model as an example. [Download](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec)
+the model checkpoint and dictionary, then compute WER/CER with
+```bash
+python -m examples.speech_synthesis.evaluation.eval_asr \
+ --audio-header syn --text-header text --err-unit char --split ${SPLIT} \
+ --w2v-ckpt ${WAV2VEC2_CHECKPOINT_PATH} --w2v-dict-dir ${WAV2VEC2_DICT_DIR} \
+ --raw-manifest ${EVAL_OUTPUT_ROOT}/eval_16khz.tsv --asr-dir ${EVAL_OUTPUT_ROOT}/asr
+```
+
+#### MCD/MSD metric
+```bash
+python -m examples.speech_synthesis.evaluation.eval_sp \
+ ${EVAL_OUTPUT_ROOT}/eval.tsv --mcd --msd
+```
+
+#### F0 metrics
+```bash
+python -m examples.speech_synthesis.evaluation.eval_f0 \
+ ${EVAL_OUTPUT_ROOT}/eval.tsv --gpe --vde --ffe
+```
+
+
+## Results
+
+| --arch | Params | Test MCD | Model |
+|---|---|---|---|
+| tts_transformer | 54M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_transformer_phn.tar) |
+| fastspeech2 | 41M | 3.8 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/ljspeech_fastspeech2_phn.tar) |
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_synthesis/docs/vctk_example.md b/fairseq/examples/speech_synthesis/docs/vctk_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ba78f3f73d6ea30f9de89150fbbc9dd5923b6fa
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/docs/vctk_example.md
@@ -0,0 +1,51 @@
+[[Back]](..)
+
+# VCTK
+
+[VCTK](https://datashare.ed.ac.uk/handle/10283/3443) is an open multi-speaker English speech corpus. We provide examples
+for building [Transformer](https://arxiv.org/abs/1809.08895) models on this dataset.
+
+
+## Data preparation
+Download data, create splits and generate audio manifests with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_vctk_audio_manifest \
+ --output-data-root ${AUDIO_DATA_ROOT} \
+ --output-manifest-root ${AUDIO_MANIFEST_ROOT}
+```
+
+Then, extract log-Mel spectrograms, generate feature manifest and create data configuration YAML with
+```bash
+python -m examples.speech_synthesis.preprocessing.get_feature_manifest \
+ --audio-manifest-root ${AUDIO_MANIFEST_ROOT} \
+ --output-root ${FEATURE_MANIFEST_ROOT} \
+ --ipa-vocab --use-g2p
+```
+where we use phoneme inputs (`--ipa-vocab --use-g2p`) as an example.
+
+To denoise audio and trim leading/trailing silence using signal processing based VAD, run
+```bash
+for SPLIT in dev test train; do
+ python -m examples.speech_synthesis.preprocessing.denoise_and_vad_audio \
+ --audio-manifest ${AUDIO_MANIFEST_ROOT}/${SPLIT}.audio.tsv \
+ --output-dir ${PROCESSED_DATA_ROOT} \
+ --denoise --vad --vad-agg-level 3
+done
+```
+
+## Training
+(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#transformer).)
+
+## Inference
+(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#inference).)
+
+## Automatic Evaluation
+(Please refer to [the LJSpeech example](../docs/ljspeech_example.md#automatic-evaluation).)
+
+## Results
+
+| --arch | Params | Test MCD | Model |
+|---|---|---|---|
+| tts_transformer | 54M | 3.4 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2/vctk_transformer_phn.tar) |
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_synthesis/evaluation/__init__.py b/fairseq/examples/speech_synthesis/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/evaluation/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fairseq/examples/speech_synthesis/evaluation/eval_asr.py b/fairseq/examples/speech_synthesis/evaluation/eval_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..005a11bfb34ca477ad9e133acd60f249e66cda47
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/evaluation/eval_asr.py
@@ -0,0 +1,128 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import editdistance
+import re
+import shutil
+import soundfile as sf
+import subprocess
+from pathlib import Path
+
+from examples.speech_to_text.data_utils import load_tsv_to_dicts
+
+
+def preprocess_text(text):
+ text = "|".join(re.sub(r"[^A-Z' ]", " ", text.upper()).split())
+ text = " ".join(text)
+ return text
+
+
+def prepare_w2v_data(
+ dict_dir, sample_rate, label, audio_paths, texts, split, data_dir
+):
+ data_dir.mkdir(parents=True, exist_ok=True)
+ shutil.copyfile(
+ dict_dir / f"dict.{label}.txt",
+ data_dir / f"dict.{label}.txt"
+ )
+ with open(data_dir / f"{split}.tsv", "w") as f:
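+        # the first line of a wav2vec 2.0 manifest is the audio root dir; "/" lets absolute paths be used as-is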
+ f.write("/\n")
+ for audio_path in audio_paths:
+ wav, sr = sf.read(audio_path)
+            assert sr == sample_rate, f"{sr} != {sample_rate}"
+ nsample = len(wav)
+ f.write(f"{audio_path}\t{nsample}\n")
+ with open(data_dir / f"{split}.{label}", "w") as f:
+ for text in texts:
+ text = preprocess_text(text)
+ f.write(f"{text}\n")
+
+
+def run_asr(asr_dir, split, w2v_ckpt, w2v_label, res_dir):
+ """
+ results will be saved at
+ {res_dir}/{ref,hypo}.word-{w2v_ckpt.filename}-{split}.txt
+ """
+ cmd = ["python", "-m", "examples.speech_recognition.infer"]
+ cmd += [str(asr_dir.resolve())]
+ cmd += ["--task", "audio_finetuning", "--nbest", "1", "--quiet"]
+ cmd += ["--w2l-decoder", "viterbi", "--criterion", "ctc"]
+ cmd += ["--post-process", "letter", "--max-tokens", "4000000"]
+ cmd += ["--path", str(w2v_ckpt.resolve()), "--labels", w2v_label]
+ cmd += ["--gen-subset", split, "--results-path", str(res_dir.resolve())]
+
+ print(f"running cmd:\n{' '.join(cmd)}")
+ subprocess.run(cmd, check=True)
+
+
+def compute_error_rate(hyp_wrd_path, ref_wrd_path, unit="word"):
+ """each line is " (None-)" """
+ tokenize_line = {
+ "word": lambda x: re.sub(r" \(.*\)$", "", x.rstrip()).split(),
+ "char": lambda x: list(re.sub(r" \(.*\)$", "", x.rstrip()))
+ }.get(unit)
+ if tokenize_line is None:
+ raise ValueError(f"{unit} not supported")
+
+ inds = [int(re.sub(r"\D*(\d*)\D*", r"\1", line))
+ for line in open(hyp_wrd_path)]
+ hyps = [tokenize_line(line) for line in open(hyp_wrd_path)]
+ refs = [tokenize_line(line) for line in open(ref_wrd_path)]
+ assert(len(hyps) == len(refs))
+ err_rates = [
+ editdistance.eval(hyp, ref) / len(ref) for hyp, ref in zip(hyps, refs)
+ ]
+ ind_to_err_rates = {i: e for i, e in zip(inds, err_rates)}
+ return ind_to_err_rates
+
+
+def main(args):
+ samples = load_tsv_to_dicts(args.raw_manifest)
+ ids = [
+ sample[args.id_header] if args.id_header else "" for sample in samples
+ ]
+ audio_paths = [sample[args.audio_header] for sample in samples]
+ texts = [sample[args.text_header] for sample in samples]
+
+ prepare_w2v_data(
+ args.w2v_dict_dir,
+ args.w2v_sample_rate,
+ args.w2v_label,
+ audio_paths,
+ texts,
+ args.split,
+ args.asr_dir
+ )
+ run_asr(args.asr_dir, args.split, args.w2v_ckpt, args.w2v_label, args.asr_dir)
+ ind_to_err_rates = compute_error_rate(
+ args.asr_dir / f"hypo.word-{args.w2v_ckpt.name}-{args.split}.txt",
+ args.asr_dir / f"ref.word-{args.w2v_ckpt.name}-{args.split}.txt",
+ args.err_unit,
+ )
+
+ uer_path = args.asr_dir / f"uer_{args.err_unit}.{args.split}.tsv"
+ with open(uer_path, "w") as f:
+ f.write("id\taudio\tuer\n")
+ for ind, (id_, audio_path) in enumerate(zip(ids, audio_paths)):
+ f.write(f"{id_}\t{audio_path}\t{ind_to_err_rates[ind]:.4f}\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--raw-manifest", required=True, type=Path)
+ parser.add_argument("--asr-dir", required=True, type=Path)
+ parser.add_argument("--id-header", default="id", type=str)
+ parser.add_argument("--audio-header", default="audio", type=str)
+ parser.add_argument("--text-header", default="src_text", type=str)
+ parser.add_argument("--split", default="raw", type=str)
+ parser.add_argument("--w2v-ckpt", required=True, type=Path)
+ parser.add_argument("--w2v-dict-dir", required=True, type=Path)
+ parser.add_argument("--w2v-sample-rate", default=16000, type=int)
+ parser.add_argument("--w2v-label", default="ltr", type=str)
+ parser.add_argument("--err-unit", default="word", type=str)
+ args = parser.parse_args()
+
+ main(args)
diff --git a/fairseq/examples/speech_synthesis/evaluation/eval_f0.py b/fairseq/examples/speech_synthesis/evaluation/eval_f0.py
new file mode 100644
index 0000000000000000000000000000000000000000..df721d683113b44957149cfc3cddaba36520a22c
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/evaluation/eval_f0.py
@@ -0,0 +1,266 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Signal processing-based evaluation using waveforms
+"""
+import numpy as np
+import os.path as op
+
+import torchaudio
+import tqdm
+from tabulate import tabulate
+
+from examples.speech_synthesis.utils import (
+ gross_pitch_error, voicing_decision_error, f0_frame_error
+)
+from examples.speech_synthesis.evaluation.eval_sp import load_eval_spec
+
+
+def difference_function(x, n, tau_max):
+ """
+    Compute the difference function of data x. This implementation uses the
+    NumPy FFT directly.
+
+ :param x: audio data
+ :param n: length of data
+ :param tau_max: integration window size
+ :return: difference function
+ :rtype: list
+ """
+
+ x = np.array(x, np.float64)
+ w = x.size
+ tau_max = min(tau_max, w)
+ x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum()))
+ size = w + tau_max
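+    # pad the FFT length up to the nearest "nice" small-prime-factor size >= size, for speed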
+ p2 = (size // 32).bit_length()
+ nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
+ size_pad = min(x * 2 ** p2 for x in nice_numbers if x * 2 ** p2 >= size)
+ fc = np.fft.rfft(x, size_pad)
+ conv = np.fft.irfft(fc * fc.conjugate())[:tau_max]
+ return x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max] - \
+ 2 * conv
+
+
+def cumulative_mean_normalized_difference_function(df, n):
+ """
+ Compute cumulative mean normalized difference function (CMND).
+
+ :param df: Difference function
+ :param n: length of data
+ :return: cumulative mean normalized difference function
+ :rtype: list
+ """
+
+ # scipy method
+ cmn_df = df[1:] * range(1, n) / np.cumsum(df[1:]).astype(float)
+ return np.insert(cmn_df, 0, 1)
+
+
+def get_pitch(cmdf, tau_min, tau_max, harmo_th=0.1):
+ """
+ Return fundamental period of a frame based on CMND function.
+
+ :param cmdf: Cumulative Mean Normalized Difference function
+ :param tau_min: minimum period for speech
+ :param tau_max: maximum period for speech
+ :param harmo_th: harmonicity threshold to determine if it is necessary to
+ compute pitch frequency
+    :return: fundamental period if there are values under the threshold, 0 otherwise
+ :rtype: float
+ """
+ tau = tau_min
+ while tau < tau_max:
+ if cmdf[tau] < harmo_th:
+ while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]:
+ tau += 1
+ return tau
+ tau += 1
+
+ return 0 # if unvoiced
+
+
+def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500,
+ harmo_thresh=0.1):
+ """
+
+ Compute the Yin Algorithm. Return fundamental frequency and harmonic rate.
+
+    https://github.com/NVIDIA/mellotron adaptation of
+ https://github.com/patriceguyot/Yin
+
+ :param sig: Audio signal (list of float)
+ :param sr: sampling rate (int)
+ :param w_len: size of the analysis window (samples)
+    :param w_step: size of the lag between two consecutive windows (samples)
+ :param f0_min: Minimum fundamental frequency that can be detected (hertz)
+ :param f0_max: Maximum fundamental frequency that can be detected (hertz)
+    :param harmo_thresh: Threshold of detection. The algorithm returns the
+    first minimum of the CMND function below this threshold.
+
+ :returns:
+
+ * pitches: list of fundamental frequencies,
+ * harmonic_rates: list of harmonic rate values for each fundamental
+ frequency value (= confidence value)
+    * argmins: minima of the Cumulative Mean Normalized Difference Function
+ * times: list of time of each estimation
+ :rtype: tuple
+ """
+
+ tau_min = int(sr / f0_max)
+ tau_max = int(sr / f0_min)
+
+ # time values for each analysis window
+ time_scale = range(0, len(sig) - w_len, w_step)
+ times = [t/float(sr) for t in time_scale]
+ frames = [sig[t:t + w_len] for t in time_scale]
+
+ pitches = [0.0] * len(time_scale)
+ harmonic_rates = [0.0] * len(time_scale)
+ argmins = [0.0] * len(time_scale)
+
+ for i, frame in enumerate(frames):
+ # Compute YIN
+ df = difference_function(frame, w_len, tau_max)
+ cm_df = cumulative_mean_normalized_difference_function(df, tau_max)
+ p = get_pitch(cm_df, tau_min, tau_max, harmo_thresh)
+
+ # Get results
+ if np.argmin(cm_df) > tau_min:
+ argmins[i] = float(sr / np.argmin(cm_df))
+ if p != 0: # A pitch was found
+ pitches[i] = float(sr / p)
+ harmonic_rates[i] = cm_df[p]
+ else: # No pitch, but we compute a value of the harmonic rate
+ harmonic_rates[i] = min(cm_df)
+
+ return pitches, harmonic_rates, argmins, times
+
+
+def extract_f0(samples):
+ f0_samples = []
+ for sample in tqdm.tqdm(samples):
+ if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]):
+ f0_samples.append(None)
+ continue
+
+ # assume single channel
+ yref, sr = torchaudio.load(sample["ref"])
+ ysyn, _sr = torchaudio.load(sample["syn"])
+ yref, ysyn = yref[0], ysyn[0]
+ assert sr == _sr, f"{sr} != {_sr}"
+
+ yref_f0 = compute_yin(yref, sr)
+ ysyn_f0 = compute_yin(ysyn, sr)
+
+ f0_samples += [
+ {
+ "ref": yref_f0,
+ "syn": ysyn_f0
+ }
+ ]
+
+ return f0_samples
+
+
+def eval_f0_error(samples, distortion_fn):
+ results = []
+ for sample in tqdm.tqdm(samples):
+ if sample is None:
+ results.append(None)
+ continue
+ # assume single channel
+ yref_f, _, _, yref_t = sample["ref"]
+ ysyn_f, _, _, ysyn_t = sample["syn"]
+
+ yref_f = np.array(yref_f)
+ yref_t = np.array(yref_t)
+ ysyn_f = np.array(ysyn_f)
+ ysyn_t = np.array(ysyn_t)
+
+ distortion = distortion_fn(yref_t, yref_f, ysyn_t, ysyn_f)
+ results.append((distortion.item(),
+ len(yref_f),
+ len(ysyn_f)
+ ))
+ return results
+
+
+def eval_gross_pitch_error(samples):
+ return eval_f0_error(samples, gross_pitch_error)
+
+
+def eval_voicing_decision_error(samples):
+ return eval_f0_error(samples, voicing_decision_error)
+
+
+def eval_f0_frame_error(samples):
+ return eval_f0_error(samples, f0_frame_error)
+
+
+def print_results(results, show_bin):
+ results = np.array(list(filter(lambda x: x is not None, results)))
+
+ np.set_printoptions(precision=3)
+
+ def _print_result(results):
+ res = {
+ "nutt": len(results),
+ "error": results[:, 0].mean(),
+ "std": results[:, 0].std(),
+ "dur_ref": int(results[:, 1].sum()),
+ "dur_syn": int(results[:, 2].sum()),
+ }
+ print(tabulate([res.values()], res.keys(), floatfmt=".4f"))
+
+ print(">>>> ALL")
+ _print_result(results)
+
+ if show_bin:
+ edges = [0, 200, 400, 600, 800, 1000, 2000, 4000]
+ for i in range(1, len(edges)):
+ mask = np.logical_and(results[:, 1] >= edges[i-1],
+ results[:, 1] < edges[i])
+ if not mask.any():
+ continue
+ bin_results = results[mask]
+ print(f">>>> ({edges[i-1]}, {edges[i]})")
+ _print_result(bin_results)
+
+
+def main(eval_f0, gpe, vde, ffe, show_bin):
+ samples = load_eval_spec(eval_f0)
+ if gpe or vde or ffe:
+ f0_samples = extract_f0(samples)
+
+ if gpe:
+ print("===== Evaluate Gross Pitch Error =====")
+ results = eval_gross_pitch_error(f0_samples)
+ print_results(results, show_bin)
+ if vde:
+ print("===== Evaluate Voicing Decision Error =====")
+ results = eval_voicing_decision_error(f0_samples)
+ print_results(results, show_bin)
+ if ffe:
+ print("===== Evaluate F0 Frame Error =====")
+ results = eval_f0_frame_error(f0_samples)
+ print_results(results, show_bin)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("eval_f0")
+ parser.add_argument("--gpe", action="store_true")
+ parser.add_argument("--vde", action="store_true")
+ parser.add_argument("--ffe", action="store_true")
+ parser.add_argument("--show-bin", action="store_true")
+ args = parser.parse_args()
+
+ main(args.eval_f0, args.gpe, args.vde, args.ffe, args.show_bin)
diff --git a/fairseq/examples/speech_synthesis/evaluation/eval_sp.py b/fairseq/examples/speech_synthesis/evaluation/eval_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..702c4980389624f788abc0b42cdf54757a52512f
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/evaluation/eval_sp.py
@@ -0,0 +1,131 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+"""
+Signal processing-based evaluation using waveforms
+"""
+
+import csv
+import numpy as np
+import os.path as op
+
+import torch
+import tqdm
+from tabulate import tabulate
+import torchaudio
+
+from examples.speech_synthesis.utils import batch_mel_spectral_distortion
+from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
+
+
+def load_eval_spec(path):
+ with open(path) as f:
+ reader = csv.DictReader(f, delimiter='\t')
+ samples = list(reader)
+ return samples
+
+
+def eval_distortion(samples, distortion_fn, device="cuda"):
+ nmiss = 0
+ results = []
+ for sample in tqdm.tqdm(samples):
+ if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]):
+ nmiss += 1
+ results.append(None)
+ continue
+ # assume single channel
+ yref, sr = torchaudio.load(sample["ref"])
+ ysyn, _sr = torchaudio.load(sample["syn"])
+ yref, ysyn = yref[0].to(device), ysyn[0].to(device)
+ assert sr == _sr, f"{sr} != {_sr}"
+
+ distortion, extra = distortion_fn([yref], [ysyn], sr, None)[0]
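+        # pathmap is the binary DTW alignment between reference frames (rows) and synthesized frames (columns)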
+ _, _, _, _, _, pathmap = extra
+ nins = torch.sum(pathmap.sum(dim=1) - 1) # extra frames in syn
+ ndel = torch.sum(pathmap.sum(dim=0) - 1) # missing frames from syn
+ results.append(
+ (distortion.item(), # path distortion
+ pathmap.size(0), # yref num frames
+ pathmap.size(1), # ysyn num frames
+ pathmap.sum().item(), # path length
+ nins.item(), # insertion
+ ndel.item(), # deletion
+ )
+ )
+ return results
+
+
+def eval_mel_cepstral_distortion(samples, device="cuda"):
+ return eval_distortion(samples, batch_mel_cepstral_distortion, device)
+
+
+def eval_mel_spectral_distortion(samples, device="cuda"):
+ return eval_distortion(samples, batch_mel_spectral_distortion, device)
+
+
+def print_results(results, show_bin):
+ results = np.array(list(filter(lambda x: x is not None, results)))
+
+ np.set_printoptions(precision=3)
+
+ def _print_result(results):
+ dist, dur_ref, dur_syn, dur_ali, nins, ndel = results.sum(axis=0)
+ res = {
+ "nutt": len(results),
+ "dist": dist,
+ "dur_ref": int(dur_ref),
+ "dur_syn": int(dur_syn),
+ "dur_ali": int(dur_ali),
+ "dist_per_ref_frm": dist/dur_ref,
+ "dist_per_syn_frm": dist/dur_syn,
+ "dist_per_ali_frm": dist/dur_ali,
+ "ins": nins/dur_ref,
+ "del": ndel/dur_ref,
+ }
+ print(tabulate(
+ [res.values()],
+ res.keys(),
+ floatfmt=".4f"
+ ))
+
+ print(">>>> ALL")
+ _print_result(results)
+
+ if show_bin:
+ edges = [0, 200, 400, 600, 800, 1000, 2000, 4000]
+ for i in range(1, len(edges)):
+ mask = np.logical_and(results[:, 1] >= edges[i-1],
+ results[:, 1] < edges[i])
+ if not mask.any():
+ continue
+ bin_results = results[mask]
+ print(f">>>> ({edges[i-1]}, {edges[i]})")
+ _print_result(bin_results)
+
+
+def main(eval_spec, mcd, msd, show_bin):
+ samples = load_eval_spec(eval_spec)
+ device = "cpu"
+ if mcd:
+ print("===== Evaluate Mean Cepstral Distortion =====")
+ results = eval_mel_cepstral_distortion(samples, device)
+ print_results(results, show_bin)
+ if msd:
+ print("===== Evaluate Mean Spectral Distortion =====")
+ results = eval_mel_spectral_distortion(samples, device)
+ print_results(results, show_bin)
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("eval_spec")
+ parser.add_argument("--mcd", action="store_true")
+ parser.add_argument("--msd", action="store_true")
+ parser.add_argument("--show-bin", action="store_true")
+ args = parser.parse_args()
+
+ main(args.eval_spec, args.mcd, args.msd, args.show_bin)
diff --git a/fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py b/fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py
new file mode 100644
index 0000000000000000000000000000000000000000..a28cd607a096844438f6a3ba6b007d94d67d1bc8
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py
@@ -0,0 +1,58 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import csv
+from pathlib import Path
+
+
+def main(args):
+ """
+    Generate an evaluation manifest with columns: `id`, `syn`, `ref`, `text`, `speaker`.
+ """
+ in_root = Path(args.generation_root).resolve()
+ ext = args.audio_format
+ with open(args.audio_manifest) as f, open(args.output_path, "w") as f_out:
+ reader = csv.DictReader(
+ f, delimiter="\t", quotechar=None, doublequote=False,
+ lineterminator="\n", quoting=csv.QUOTE_NONE
+ )
+ header = ["id", "syn", "ref", "text", "speaker"]
+ f_out.write("\t".join(header) + "\n")
+ for row in reader:
+ dir_name = f"{ext}_{args.sample_rate}hz_{args.vocoder}"
+ id_ = row["id"]
+ syn = (in_root / dir_name / f"{id_}.{ext}").as_posix()
+ ref = row["audio"]
+ if args.use_resynthesized_target:
+ ref = (in_root / f"{dir_name}_tgt" / f"{id_}.{ext}").as_posix()
+ sample = [id_, syn, ref, row["tgt_text"], row["speaker"]]
+ f_out.write("\t".join(sample) + "\n")
+ print(f"wrote evaluation file to {args.output_path}")
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--generation-root", help="output directory for generate_waveform.py"
+ )
+ parser.add_argument(
+ "--audio-manifest",
+ help="used to determine the original utterance ID and text"
+ )
+ parser.add_argument(
+ "--output-path", help="path to output evaluation spec file"
+ )
+ parser.add_argument(
+ "--use-resynthesized-target", action="store_true",
+ help="use resynthesized reference instead of the original audio"
+ )
+ parser.add_argument("--vocoder", type=str, default="griffin_lim")
+ parser.add_argument("--sample-rate", type=int, default=22_050)
+ parser.add_argument("--audio-format", type=str, default="wav")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/fairseq/examples/speech_synthesis/generate_waveform.py b/fairseq/examples/speech_synthesis/generate_waveform.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfc2ef8eb3d91366caf7609d75aa1795ab0ed8f9
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/generate_waveform.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import matplotlib.pyplot as plt
+import numpy as np
+from pathlib import Path
+import soundfile as sf
+import sys
+import torch
+import torchaudio
+
+from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq.logging import progress_bar
+from fairseq.tasks.text_to_speech import plot_tts_output
+from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def make_parser():
+ parser = options.get_speech_generation_parser()
+ parser.add_argument("--dump-features", action="store_true")
+ parser.add_argument("--dump-waveforms", action="store_true")
+ parser.add_argument("--dump-attentions", action="store_true")
+ parser.add_argument("--dump-eos-probs", action="store_true")
+ parser.add_argument("--dump-plots", action="store_true")
+ parser.add_argument("--dump-target", action="store_true")
+ parser.add_argument("--output-sample-rate", default=22050, type=int)
+ parser.add_argument("--teacher-forcing", action="store_true")
+ parser.add_argument(
+ "--audio-format", type=str, default="wav", choices=["wav", "flac"]
+ )
+ return parser
+
+
+def postprocess_results(
+ dataset: TextToSpeechDataset, sample, hypos, resample_fn, dump_target
+):
+ def to_np(x):
+ return None if x is None else x.detach().cpu().numpy()
+
+ sample_ids = [dataset.ids[i] for i in sample["id"].tolist()]
+ texts = sample["src_texts"]
+ attns = [to_np(hypo["attn"]) for hypo in hypos]
+ eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos]
+ feat_preds = [to_np(hypo["feature"]) for hypo in hypos]
+ wave_preds = [to_np(resample_fn(h["waveform"])) for h in hypos]
+ if dump_target:
+ feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos]
+ wave_targs = [to_np(resample_fn(h["targ_waveform"])) for h in hypos]
+ else:
+ feat_targs = [None for _ in hypos]
+ wave_targs = [None for _ in hypos]
+
+ return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds,
+ feat_targs, wave_targs)
+
+
+def dump_result(
+ is_na_model,
+ args,
+ vocoder,
+ sample_id,
+ text,
+ attn,
+ eos_prob,
+ feat_pred,
+ wave_pred,
+ feat_targ,
+ wave_targ,
+):
+ sample_rate = args.output_sample_rate
+ out_root = Path(args.results_path)
+ if args.dump_features:
+ feat_dir = out_root / "feat"
+ feat_dir.mkdir(exist_ok=True, parents=True)
+ np.save(feat_dir / f"{sample_id}.npy", feat_pred)
+ if args.dump_target:
+ feat_tgt_dir = out_root / "feat_tgt"
+ feat_tgt_dir.mkdir(exist_ok=True, parents=True)
+ np.save(feat_tgt_dir / f"{sample_id}.npy", feat_targ)
+ if args.dump_attentions:
+ attn_dir = out_root / "attn"
+ attn_dir.mkdir(exist_ok=True, parents=True)
+        np.save(attn_dir / f"{sample_id}.npy", attn)
+ if args.dump_eos_probs and not is_na_model:
+ eos_dir = out_root / "eos"
+ eos_dir.mkdir(exist_ok=True, parents=True)
+ np.save(eos_dir / f"{sample_id}.npy", eos_prob)
+
+ if args.dump_plots:
+ images = [feat_pred.T] if is_na_model else [feat_pred.T, attn]
+ names = ["output"] if is_na_model else ["output", "alignment"]
+ if feat_targ is not None:
+ images = [feat_targ.T] + images
+ names = [f"target (idx={sample_id})"] + names
+ if is_na_model:
+ plot_tts_output(images, names, attn, "alignment", suptitle=text)
+ else:
+ plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text)
+ plot_dir = out_root / "plot"
+ plot_dir.mkdir(exist_ok=True, parents=True)
+ plt.savefig(plot_dir / f"{sample_id}.png")
+ plt.close()
+
+ if args.dump_waveforms:
+ ext = args.audio_format
+ if wave_pred is not None:
+ wav_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}"
+ wav_dir.mkdir(exist_ok=True, parents=True)
+ sf.write(wav_dir / f"{sample_id}.{ext}", wave_pred, sample_rate)
+ if args.dump_target and wave_targ is not None:
+ wav_tgt_dir = out_root / f"{ext}_{sample_rate}hz_{vocoder}_tgt"
+ wav_tgt_dir.mkdir(exist_ok=True, parents=True)
+ sf.write(wav_tgt_dir / f"{sample_id}.{ext}", wave_targ, sample_rate)
+
+
+def main(args):
+ assert(args.dump_features or args.dump_waveforms or args.dump_attentions
+ or args.dump_eos_probs or args.dump_plots)
+ if args.max_tokens is None and args.batch_size is None:
+ args.max_tokens = 8000
+ logger.info(args)
+
+ use_cuda = torch.cuda.is_available() and not args.cpu
+ task = tasks.setup_task(args)
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+ [args.path],
+ task=task,
+ )
+ model = models[0].cuda() if use_cuda else models[0]
+ # use the original n_frames_per_step
+ task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step
+ task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
+
+ data_cfg = task.data_cfg
+ sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050)
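+    # resample output waveforms only when the requested output rate differs from the feature sample rate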
+ resample_fn = {
+ False: lambda x: x,
+ True: lambda x: torchaudio.sox_effects.apply_effects_tensor(
+ x.detach().cpu().unsqueeze(0), sample_rate,
+ [['rate', str(args.output_sample_rate)]]
+ )[0].squeeze(0)
+ }.get(args.output_sample_rate != sample_rate)
+ if args.output_sample_rate != sample_rate:
+ logger.info(f"resampling to {args.output_sample_rate}Hz")
+
+ generator = task.build_generator([model], args)
+ itr = task.get_batch_iterator(
+ dataset=task.dataset(args.gen_subset),
+ max_tokens=args.max_tokens,
+ max_sentences=args.batch_size,
+ max_positions=(sys.maxsize, sys.maxsize),
+ ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
+ required_batch_size_multiple=args.required_batch_size_multiple,
+ num_shards=args.num_shards,
+ shard_id=args.shard_id,
+ num_workers=args.num_workers,
+ data_buffer_size=args.data_buffer_size,
+ ).next_epoch_itr(shuffle=False)
+
+ Path(args.results_path).mkdir(exist_ok=True, parents=True)
+ is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False)
+ dataset = task.dataset(args.gen_subset)
+ vocoder = task.args.vocoder
+ with progress_bar.build_progress_bar(args, itr) as t:
+ for sample in t:
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
+ hypos = generator.generate(model, sample, has_targ=args.dump_target)
+ for result in postprocess_results(
+ dataset, sample, hypos, resample_fn, args.dump_target
+ ):
+ dump_result(is_na_model, args, vocoder, *result)
+
+
+def cli_main():
+ parser = make_parser()
+ args = options.parse_args_and_arch(parser)
+ main(args)
+
+
+if __name__ == "__main__":
+ cli_main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py b/fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e13b38a5d3fb44dd3969e6afcb8f202274ee3b7
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py
@@ -0,0 +1,204 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+import csv
+import tempfile
+from collections import defaultdict
+from pathlib import Path
+
+import torchaudio
+try:
+ import webrtcvad
+except ImportError:
+ raise ImportError("Please install py-webrtcvad: pip install webrtcvad")
+import pandas as pd
+from tqdm import tqdm
+
+from examples.speech_synthesis.preprocessing.denoiser.pretrained import master64
+import examples.speech_synthesis.preprocessing.denoiser.utils as utils
+from examples.speech_synthesis.preprocessing.vad import (
+ frame_generator, vad_collector, read_wave, write_wave, FS_MS, THRESHOLD,
+ SCALE
+)
+from examples.speech_to_text.data_utils import save_df_to_tsv
+
+
+log = logging.getLogger(__name__)
+
+PATHS = ["after_denoise", "after_vad"]
+MIN_T = 0.05
+
+
+def generate_tmp_filename(extension="txt"):
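+    # note: relies on private tempfile helpers to build a unique path without creating the file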
+ return tempfile._get_default_tempdir() + "/" + \
+ next(tempfile._get_candidate_names()) + "." + extension
+
+
+def convert_sr(inpath, sr, output_path=None):
+ if not output_path:
+ output_path = generate_tmp_filename("wav")
+ cmd = f"sox {inpath} -r {sr} {output_path}"
+ os.system(cmd)
+ return output_path
+
+
+def apply_vad(vad, inpath):
+ audio, sample_rate = read_wave(inpath)
+ frames = frame_generator(FS_MS, audio, sample_rate)
+ frames = list(frames)
+ segments = vad_collector(sample_rate, FS_MS, 300, vad, frames)
+ merge_segments = list()
+ timestamp_start = 0.0
+ timestamp_end = 0.0
+ # removing start, end, and long sequences of sils
+ for i, segment in enumerate(segments):
+ merge_segments.append(segment[0])
+ if i and timestamp_start:
+ sil_duration = segment[1] - timestamp_end
+ if sil_duration > THRESHOLD:
+ merge_segments.append(int(THRESHOLD / SCALE) * (b'\x00'))
+ else:
+ merge_segments.append(int((sil_duration / SCALE)) * (b'\x00'))
+ timestamp_start = segment[1]
+ timestamp_end = segment[2]
+ segment = b''.join(merge_segments)
+ return segment, sample_rate
+
+
+def write(wav, filename, sr=16_000):
+    # normalize only when the peak exceeds 1, to prevent clipping
+ wav = wav / max(wav.abs().max().item(), 1)
+ torchaudio.save(filename, wav.cpu(), sr, encoding="PCM_S",
+ bits_per_sample=16)
+
+
+def process(args):
+    # make sure either denoising or VAD was requested
+ if not args.denoise and not args.vad:
+ log.error("No denoise or vad is requested.")
+ return
+
+ log.info("Creating out directories...")
+ if args.denoise:
+ out_denoise = Path(args.output_dir).absolute().joinpath(PATHS[0])
+ out_denoise.mkdir(parents=True, exist_ok=True)
+ if args.vad:
+ out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1])
+ out_vad.mkdir(parents=True, exist_ok=True)
+
+ log.info("Loading pre-trained speech enhancement model...")
+ model = master64().to(args.device)
+
+ log.info("Building the VAD model...")
+ vad = webrtcvad.Vad(int(args.vad_agg_level))
+
+ # preparing the output dict
+ output_dict = defaultdict(list)
+
+ log.info(f"Parsing input manifest: {args.audio_manifest}")
+ with open(args.audio_manifest, "r") as f:
+ manifest_dict = csv.DictReader(f, delimiter="\t")
+ for row in tqdm(manifest_dict):
+ filename = str(row["audio"])
+
+ final_output = filename
+ keep_sample = True
+ n_frames = row["n_frames"]
+ snr = -1
+ if args.denoise:
+ output_path_denoise = out_denoise.joinpath(Path(filename).name)
+                # convert to 16kHz in case the input uses a different sample rate
+ tmp_path = convert_sr(final_output, 16000)
+
+ # loading audio file and generating the enhanced version
+ out, sr = torchaudio.load(tmp_path)
+ out = out.to(args.device)
+ estimate = model(out)
+ estimate = (1 - args.dry_wet) * estimate + args.dry_wet * out
+ write(estimate[0], str(output_path_denoise), sr)
+
+ snr = utils.cal_snr(out, estimate)
+ snr = snr.cpu().detach().numpy()[0][0]
+ final_output = str(output_path_denoise)
+
+ if args.vad:
+ output_path_vad = out_vad.joinpath(Path(filename).name)
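+                # webrtcvad expects 8/16/32/48 kHz audio; resample up to the nearest supported rate if needed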
+ sr = torchaudio.info(final_output).sample_rate
+ if sr in [16000, 32000, 48000]:
+ tmp_path = final_output
+ elif sr < 16000:
+ tmp_path = convert_sr(final_output, 16000)
+ elif sr < 32000:
+ tmp_path = convert_sr(final_output, 32000)
+ else:
+ tmp_path = convert_sr(final_output, 48000)
+ # apply VAD
+ segment, sample_rate = apply_vad(vad, tmp_path)
+ if len(segment) < sample_rate * MIN_T:
+ keep_sample = False
+ print((
+ f"WARNING: skip {filename} because it is too short "
+ f"after VAD ({len(segment) / sample_rate} < {MIN_T})"
+ ))
+ else:
+ if sample_rate != sr:
+ tmp_path = generate_tmp_filename("wav")
+ write_wave(tmp_path, segment, sample_rate)
+ convert_sr(tmp_path, sr,
+ output_path=str(output_path_vad))
+ else:
+ write_wave(str(output_path_vad), segment, sample_rate)
+ final_output = str(output_path_vad)
+ segment, _ = torchaudio.load(final_output)
+ n_frames = segment.size(1)
+
+ if keep_sample:
+ output_dict["id"].append(row["id"])
+ output_dict["audio"].append(final_output)
+ output_dict["n_frames"].append(n_frames)
+ output_dict["tgt_text"].append(row["tgt_text"])
+ output_dict["speaker"].append(row["speaker"])
+ output_dict["src_text"].append(row["src_text"])
+ output_dict["snr"].append(snr)
+
+ out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name
+ log.info(f"Saving manifest to {out_tsv_path.as_posix()}")
+ save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--audio-manifest", "-i", required=True,
+ type=str, help="path to the input manifest.")
+ parser.add_argument(
+ "--output-dir", "-o", required=True, type=str,
+ help="path to the output dir. it will contain files after denoising and"
+ " vad"
+ )
+ parser.add_argument("--vad-agg-level", "-a", type=int, default=2,
+ help="the aggresive level of the vad [0-3].")
+ parser.add_argument(
+ "--dry-wet", "-dw", type=float, default=0.01,
+ help="the level of linear interpolation between noisy and enhanced "
+ "files."
+ )
+ parser.add_argument(
+ "--device", "-d", type=str, default="cpu",
+ help="the device to be used for the speech enhancement model: "
+ "cpu | cuda."
+ )
+ parser.add_argument("--denoise", action="store_true",
+ help="apply a denoising")
+ parser.add_argument("--vad", action="store_true", help="apply a VAD")
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f70e73d6a37d32e05b6cf0e87f42e13c467cd52
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py
@@ -0,0 +1,473 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# author: adefossez
+
+import math
+import time
+
+import torch as th
+from torch import nn
+from torch.nn import functional as F
+
+from .resample import downsample2, upsample2
+from .utils import capture_init
+
+
+class BLSTM(nn.Module):
+ def __init__(self, dim, layers=2, bi=True):
+ super().__init__()
+ klass = nn.LSTM
+ self.lstm = klass(
+ bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim
+ )
+ self.linear = None
+ if bi:
+ self.linear = nn.Linear(2 * dim, dim)
+
+ def forward(self, x, hidden=None):
+ x, hidden = self.lstm(x, hidden)
+ if self.linear:
+ x = self.linear(x)
+ return x, hidden
+
+
+def rescale_conv(conv, reference):
+ std = conv.weight.std().detach()
+ scale = (std / reference)**0.5
+ conv.weight.data /= scale
+ if conv.bias is not None:
+ conv.bias.data /= scale
+
+
+def rescale_module(module, reference):
+ for sub in module.modules():
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
+ rescale_conv(sub, reference)
+
+
+class Demucs(nn.Module):
+ """
+ Demucs speech enhancement model.
+ Args:
+ - chin (int): number of input channels.
+ - chout (int): number of output channels.
+ - hidden (int): number of initial hidden channels.
+ - depth (int): number of layers.
+ - kernel_size (int): kernel size for each layer.
+ - stride (int): stride for each layer.
+ - causal (bool): if false, uses BiLSTM instead of LSTM.
+ - resample (int): amount of resampling to apply to the input/output.
+ Can be one of 1, 2 or 4.
+ - growth (float): number of channels is multiplied by this for every layer.
+ - max_hidden (int): maximum number of channels. Can be useful to
+ control the size/speed of the model.
+ - normalize (bool): if true, normalize the input.
+ - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
+ - rescale (float): controls custom weight initialization.
+ See https://arxiv.org/abs/1911.13254.
+ - floor (float): stability flooring when normalizing.
+
+ """
+ @capture_init
+ def __init__(self,
+ chin=1,
+ chout=1,
+ hidden=48,
+ depth=5,
+ kernel_size=8,
+ stride=4,
+ causal=True,
+ resample=4,
+ growth=2,
+ max_hidden=10_000,
+ normalize=True,
+ glu=True,
+ rescale=0.1,
+ floor=1e-3):
+
+ super().__init__()
+ if resample not in [1, 2, 4]:
+ raise ValueError("Resample should be 1, 2 or 4.")
+
+ self.chin = chin
+ self.chout = chout
+ self.hidden = hidden
+ self.depth = depth
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.causal = causal
+ self.floor = floor
+ self.resample = resample
+ self.normalize = normalize
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+ activation = nn.GLU(1) if glu else nn.ReLU()
+ ch_scale = 2 if glu else 1
+
+ for index in range(depth):
+ encode = []
+ encode += [
+ nn.Conv1d(chin, hidden, kernel_size, stride),
+ nn.ReLU(),
+ nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
+ ]
+ self.encoder.append(nn.Sequential(*encode))
+
+ decode = []
+ decode += [
+ nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
+ nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
+ ]
+ if index > 0:
+ decode.append(nn.ReLU())
+ self.decoder.insert(0, nn.Sequential(*decode))
+ chout = hidden
+ chin = hidden
+ hidden = min(int(growth * hidden), max_hidden)
+
+ self.lstm = BLSTM(chin, bi=not causal)
+ if rescale:
+ rescale_module(self, reference=rescale)
+
+ def valid_length(self, length):
+ """
+ Return the nearest valid length to use with the model so that
+ there are no time steps left over in the convolutions, i.e. for all
+ layers, (input length - kernel_size) % stride == 0.
+
+ If the mixture has a valid length, the estimated sources
+ will have exactly the same length.
+ """
+ length = math.ceil(length * self.resample)
+ for _ in range(self.depth):
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
+ length = max(length, 1)
+ for _ in range(self.depth):
+ length = (length - 1) * self.stride + self.kernel_size
+ length = int(math.ceil(length / self.resample))
+ return int(length)
+
+ @property
+ def total_stride(self):
+ return self.stride ** self.depth // self.resample
+
+ def forward(self, mix):
+ if mix.dim() == 2:
+ mix = mix.unsqueeze(1)
+
+ if self.normalize:
+ mono = mix.mean(dim=1, keepdim=True)
+ std = mono.std(dim=-1, keepdim=True)
+ mix = mix / (self.floor + std)
+ else:
+ std = 1
+ length = mix.shape[-1]
+ x = mix
+ x = F.pad(x, (0, self.valid_length(length) - length))
+ if self.resample == 2:
+ x = upsample2(x)
+ elif self.resample == 4:
+ x = upsample2(x)
+ x = upsample2(x)
+ skips = []
+ for encode in self.encoder:
+ x = encode(x)
+ skips.append(x)
+ x = x.permute(2, 0, 1)
+ x, _ = self.lstm(x)
+ x = x.permute(1, 2, 0)
+ for decode in self.decoder:
+ skip = skips.pop(-1)
+ x = x + skip[..., :x.shape[-1]]
+ x = decode(x)
+ if self.resample == 2:
+ x = downsample2(x)
+ elif self.resample == 4:
+ x = downsample2(x)
+ x = downsample2(x)
+
+ x = x[..., :length]
+ return std * x
+
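+
+# A minimal offline-enhancement sketch using the class above. The hidden
+# size, the dry/wet value and the assumption of a 16 kHz mono input are
+# illustrative; in the preprocessing pipeline a pretrained checkpoint is
+# loaded via `pretrained.get_model` instead of the random weights used here.
+def _demucs_denoise_example(noisy, dry=0.05):
+    """Denoise a (channels, time) float tensor and mix back some dry signal."""
+    model = Demucs(hidden=48).eval()
+    with th.no_grad():
+        enhanced = model(noisy[None])[0]  # add, then drop, the batch dim
+    # linear dry/wet interpolation, as done by the denoise/VAD script
+    return dry * noisy + (1 - dry) * enhanced
+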
+
+def fast_conv(conv, x):
+ """
+ Faster convolution evaluation if either the kernel size is 1
+ or the input length equals the kernel size (a single output step).
+ """
+ batch, chin, length = x.shape
+ chout, chin, kernel = conv.weight.shape
+ assert batch == 1
+ if kernel == 1:
+ x = x.view(chin, length)
+ out = th.addmm(conv.bias.view(-1, 1),
+ conv.weight.view(chout, chin), x)
+ elif length == kernel:
+ x = x.view(chin * kernel, 1)
+ out = th.addmm(conv.bias.view(-1, 1),
+ conv.weight.view(chout, chin * kernel), x)
+ else:
+ out = conv(x)
+ return out.view(batch, chout, -1)
+
+
+class DemucsStreamer:
+ """
+ Streaming implementation for Demucs. It can be fed any amount of audio at
+ a time; you will get back as much processed audio as is available at that
+ point.
+
+ Args:
+ - demucs (Demucs): Demucs model.
+ - dry (float): amount of dry (i.e. unprocessed input) signal to keep.
+ 0 is maximum noise removal, 1 just returns the input signal. Small
+ values > 0 help limit distortions.
+ - num_frames (int): number of frames to process at once. Higher values
+ will increase overall latency but improve the real time factor.
+ - resample_lookahead (int): extra lookahead used for the resampling.
+ - resample_buffer (int): size of the buffer of previous inputs/outputs
+ kept for resampling.
+ """
+ def __init__(self, demucs,
+ dry=0,
+ num_frames=1,
+ resample_lookahead=64,
+ resample_buffer=256):
+ device = next(iter(demucs.parameters())).device
+ self.demucs = demucs
+ self.lstm_state = None
+ self.conv_state = None
+ self.dry = dry
+ self.resample_lookahead = resample_lookahead
+ resample_buffer = min(demucs.total_stride, resample_buffer)
+ self.resample_buffer = resample_buffer
+ self.frame_length = demucs.valid_length(1) + \
+ demucs.total_stride * (num_frames - 1)
+ self.total_length = self.frame_length + self.resample_lookahead
+ self.stride = demucs.total_stride * num_frames
+ self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device)
+ self.resample_out = th.zeros(
+ demucs.chin, resample_buffer, device=device
+ )
+
+ self.frames = 0
+ self.total_time = 0
+ self.variance = 0
+ self.pending = th.zeros(demucs.chin, 0, device=device)
+
+ bias = demucs.decoder[0][2].bias
+ weight = demucs.decoder[0][2].weight
+ chin, chout, kernel = weight.shape
+ self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
+ self._weight = weight.permute(1, 2, 0).contiguous()
+
+ def reset_time_per_frame(self):
+ self.total_time = 0
+ self.frames = 0
+
+ @property
+ def time_per_frame(self):
+ return self.total_time / self.frames
+
+ def flush(self):
+ """
+ Flush the remaining audio by padding it with zeros. Call this
+ when you have no more input and want to get back the last chunk of audio.
+ """
+ pending_length = self.pending.shape[1]
+ padding = th.zeros(
+ self.demucs.chin, self.total_length, device=self.pending.device
+ )
+ out = self.feed(padding)
+ return out[:, :pending_length]
+
+ def feed(self, wav):
+ """
+ Apply the model to `wav` using true real-time evaluation.
+ Normalization is done online, as is the resampling.
+ """
+ begin = time.time()
+ demucs = self.demucs
+ resample_buffer = self.resample_buffer
+ stride = self.stride
+ resample = demucs.resample
+
+ if wav.dim() != 2:
+ raise ValueError("input wav should be two dimensional.")
+ chin, _ = wav.shape
+ if chin != demucs.chin:
+ raise ValueError(f"Expected {demucs.chin} channels, got {chin}")
+
+ self.pending = th.cat([self.pending, wav], dim=1)
+ outs = []
+ while self.pending.shape[1] >= self.total_length:
+ self.frames += 1
+ frame = self.pending[:, :self.total_length]
+ dry_signal = frame[:, :stride]
+ if demucs.normalize:
+ mono = frame.mean(0)
+ variance = (mono**2).mean()
+ self.variance = variance / self.frames + \
+ (1 - 1 / self.frames) * self.variance
+ frame = frame / (demucs.floor + math.sqrt(self.variance))
+ frame = th.cat([self.resample_in, frame], dim=-1)
+ self.resample_in[:] = frame[:, stride - resample_buffer:stride]
+
+ if resample == 4:
+ frame = upsample2(upsample2(frame))
+ elif resample == 2:
+ frame = upsample2(frame)
+ # remove the prepended resampling buffer
+ frame = frame[:, resample * resample_buffer:]
+ # remove extra samples after window
+ frame = frame[:, :resample * self.frame_length]
+
+ out, extra = self._separate_frame(frame)
+ padded_out = th.cat([self.resample_out, out, extra], 1)
+ self.resample_out[:] = out[:, -resample_buffer:]
+ if resample == 4:
+ out = downsample2(downsample2(padded_out))
+ elif resample == 2:
+ out = downsample2(padded_out)
+ else:
+ out = padded_out
+
+ out = out[:, resample_buffer // resample:]
+ out = out[:, :stride]
+
+ if demucs.normalize:
+ out *= math.sqrt(self.variance)
+ out = self.dry * dry_signal + (1 - self.dry) * out
+ outs.append(out)
+ self.pending = self.pending[:, stride:]
+
+ self.total_time += time.time() - begin
+ if outs:
+ out = th.cat(outs, 1)
+ else:
+ out = th.zeros(chin, 0, device=wav.device)
+ return out
+
+ def _separate_frame(self, frame):
+ demucs = self.demucs
+ skips = []
+ next_state = []
+ first = self.conv_state is None
+ stride = self.stride * demucs.resample
+ x = frame[None]
+ for idx, encode in enumerate(demucs.encoder):
+ stride //= demucs.stride
+ length = x.shape[2]
+ if idx == demucs.depth - 1:
+ # This is slightly faster for the last conv
+ x = fast_conv(encode[0], x)
+ x = encode[1](x)
+ x = fast_conv(encode[2], x)
+ x = encode[3](x)
+ else:
+ if not first:
+ prev = self.conv_state.pop(0)
+ prev = prev[..., stride:]
+ tgt = (length - demucs.kernel_size) // demucs.stride + 1
+ missing = tgt - prev.shape[-1]
+ offset = length - demucs.kernel_size - \
+ demucs.stride * (missing - 1)
+ x = x[..., offset:]
+ x = encode[1](encode[0](x))
+ x = fast_conv(encode[2], x)
+ x = encode[3](x)
+ if not first:
+ x = th.cat([prev, x], -1)
+ next_state.append(x)
+ skips.append(x)
+
+ x = x.permute(2, 0, 1)
+ x, self.lstm_state = demucs.lstm(x, self.lstm_state)
+ x = x.permute(1, 2, 0)
+ # In the following, x contains only correct samples, i.e. the ones
+ # for which each time position is covered by two windows of the upper
+ # layer. extra contains extra samples to the right, and is used only as
+ # a better padding for the online resampling.
+ extra = None
+ for idx, decode in enumerate(demucs.decoder):
+ skip = skips.pop(-1)
+ x += skip[..., :x.shape[-1]]
+ x = fast_conv(decode[0], x)
+ x = decode[1](x)
+
+ if extra is not None:
+ skip = skip[..., x.shape[-1]:]
+ extra += skip[..., :extra.shape[-1]]
+ extra = decode[2](decode[1](decode[0](extra)))
+ x = decode[2](x)
+ next_state.append(
+ x[..., -demucs.stride:] - decode[2].bias.view(-1, 1)
+ )
+ if extra is None:
+ extra = x[..., -demucs.stride:]
+ else:
+ extra[..., :demucs.stride] += next_state[-1]
+ x = x[..., :-demucs.stride]
+
+ if not first:
+ prev = self.conv_state.pop(0)
+ x[..., :demucs.stride] += prev
+ if idx != demucs.depth - 1:
+ x = decode[3](x)
+ extra = decode[3](extra)
+ self.conv_state = next_state
+ return x[0], extra[0]
+
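+
+# A minimal streaming sketch built on the class above: feed fixed-size chunks
+# and collect whatever audio the streamer can return at each step. The chunk
+# size and dry value are illustrative assumptions; `test()` below benchmarks
+# the same API against the offline model.
+def _streamer_example(demucs, wav, chunk=1024):
+    """Run `wav` (channels, time) through a DemucsStreamer chunk by chunk."""
+    streamer = DemucsStreamer(demucs, dry=0.05)
+    outs = []
+    with th.no_grad():
+        for start in range(0, wav.shape[1], chunk):
+            outs.append(streamer.feed(wav[:, start:start + chunk]))
+        outs.append(streamer.flush())  # pad with zeros and emit the tail
+    return th.cat(outs, 1)
+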
+
+def test():
+ import argparse
+ parser = argparse.ArgumentParser(
+ "denoiser.demucs",
+ description="Benchmark the streaming Demucs implementation, as well as "
+ "checking the delta with the offline implementation.")
+ parser.add_argument("--depth", default=5, type=int)
+ parser.add_argument("--resample", default=4, type=int)
+ parser.add_argument("--hidden", default=48, type=int)
+ parser.add_argument("--sample_rate", default=16000, type=float)
+ parser.add_argument("--device", default="cpu")
+ parser.add_argument("-t", "--num_threads", type=int)
+ parser.add_argument("-f", "--num_frames", type=int, default=1)
+ args = parser.parse_args()
+ if args.num_threads:
+ th.set_num_threads(args.num_threads)
+ sr = args.sample_rate
+ sr_ms = sr / 1000
+ demucs = Demucs(
+ depth=args.depth, hidden=args.hidden, resample=args.resample
+ ).to(args.device)
+ x = th.randn(1, int(sr * 4)).to(args.device)
+ out = demucs(x[None])[0]
+ streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
+ out_rt = []
+ frame_size = streamer.total_length
+ with th.no_grad():
+ while x.shape[1] > 0:
+ out_rt.append(streamer.feed(x[:, :frame_size]))
+ x = x[:, frame_size:]
+ frame_size = streamer.demucs.total_stride
+ out_rt.append(streamer.flush())
+ out_rt = th.cat(out_rt, 1)
+ model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20
+ initial_lag = streamer.total_length / sr_ms
+ tpf = 1000 * streamer.time_per_frame
+ print(f"model size: {model_size:.1f}MB, ", end='')
+ print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}")
+ print(f"initial lag: {initial_lag:.1f}ms, ", end='')
+ print(f"stride: {streamer.stride * args.num_frames / sr_ms:.1f}ms")
+ print(f"time per frame: {tpf:.1f}ms, ", end='')
+ rtf = (1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)
+ print(f"RTF: {rtf:.2f}")
+ print(f"Total lag with computation: {initial_lag + tpf:.1f}ms")
+
+
+if __name__ == "__main__":
+ test()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fa846075b6872cdcc0baebca0b9acbb9ffcd287
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# author: adefossez
+
+import logging
+
+import torch.hub
+
+from .demucs import Demucs
+from .utils import deserialize_model
+
+logger = logging.getLogger(__name__)
+ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
+DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th"
+DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th"
+MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th"
+
+
+def _demucs(pretrained, url, **kwargs):
+ model = Demucs(**kwargs)
+ if pretrained:
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
+ model.load_state_dict(state_dict)
+ return model
+
+
+def dns48(pretrained=True):
+ return _demucs(pretrained, DNS_48_URL, hidden=48)
+
+
+def dns64(pretrained=True):
+ return _demucs(pretrained, DNS_64_URL, hidden=64)
+
+
+def master64(pretrained=True):
+ return _demucs(pretrained, MASTER_64_URL, hidden=64)
+
+
+def add_model_flags(parser):
+ group = parser.add_mutually_exclusive_group(required=False)
+ group.add_argument(
+ "-m", "--model_path", help="Path to local trained model."
+ )
+ group.add_argument(
+ "--dns48", action="store_true",
+ help="Use pre-trained real time H=48 model trained on DNS."
+ )
+ group.add_argument(
+ "--dns64", action="store_true",
+ help="Use pre-trained real time H=64 model trained on DNS."
+ )
+ group.add_argument(
+ "--master64", action="store_true",
+ help="Use pre-trained real time H=64 model trained on DNS and Valentini."
+ )
+
+
+def get_model(args):
+ """
+ Load local model package or torchhub pre-trained model.
+ """
+ if args.model_path:
+ logger.info("Loading model from %s", args.model_path)
+ pkg = torch.load(args.model_path)
+ model = deserialize_model(pkg)
+ elif args.dns64:
+ logger.info("Loading pre-trained real time H=64 model trained on DNS.")
+ model = dns64()
+ elif args.master64:
+ logger.info(
+ "Loading pre-trained real time H=64 model trained on DNS and Valentini."
+ )
+ model = master64()
+ else:
+ logger.info("Loading pre-trained real time H=48 model trained on DNS.")
+ model = dns48()
+ logger.debug(model)
+ return model
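+
+
+# A minimal loading sketch for this module: wire `add_model_flags` into an
+# argparse parser and resolve the selected checkpoint with `get_model`. The
+# parser name and the `--master64` choice are illustrative assumptions.
+def _load_model_example():
+    import argparse
+    parser = argparse.ArgumentParser("denoiser.pretrained example")
+    add_model_flags(parser)
+    args = parser.parse_args(["--master64"])  # DNS + Valentini checkpoint
+    return get_model(args).eval()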
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..1222addc424d4f898d602009e4032907241aadfe
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# author: adefossez
+
+import math
+
+import torch as th
+from torch.nn import functional as F
+
+
+def sinc(t):
+ """sinc.
+
+ :param t: the input tensor
+ """
+ return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype),
+ th.sin(t) / t)
+
+
+def kernel_upsample2(zeros=56):
+ """kernel_upsample2.
+
+ """
+ win = th.hann_window(4 * zeros + 1, periodic=False)
+ winodd = win[1::2]
+ t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
+ t *= math.pi
+ kernel = (sinc(t) * winodd).view(1, 1, -1)
+ return kernel
+
+
+def upsample2(x, zeros=56):
+ """
+ Upsampling the input by 2 using sinc interpolation.
+ Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
+ ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
+ Vol. 9. IEEE, 1984.
+ """
+ *other, time = x.shape
+ kernel = kernel_upsample2(zeros).to(x)
+ out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(
+ *other, time
+ )
+ y = th.stack([x, out], dim=-1)
+ return y.view(*other, -1)
+
+
+def kernel_downsample2(zeros=56):
+ """kernel_downsample2.
+
+ """
+ win = th.hann_window(4 * zeros + 1, periodic=False)
+ winodd = win[1::2]
+ t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
+ t.mul_(math.pi)
+ kernel = (sinc(t) * winodd).view(1, 1, -1)
+ return kernel
+
+
+def downsample2(x, zeros=56):
+ """
+ Downsampling the input by 2 using sinc interpolation.
+ Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
+ ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
+ Vol. 9. IEEE, 1984.
+ """
+ if x.shape[-1] % 2 != 0:
+ x = F.pad(x, (0, 1))
+ xeven = x[..., ::2]
+ xodd = x[..., 1::2]
+ *other, time = xodd.shape
+ kernel = kernel_downsample2(zeros).to(x)
+ out = xeven + F.conv1d(
+ xodd.view(-1, 1, time), kernel, padding=zeros
+ )[..., :-1].view(*other, time)
+ return out.view(*other, -1).mul(0.5)
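+
+
+# A minimal round-trip sketch: upsample by 2 and come straight back down.
+# The windowed-sinc filters make the round trip close to, but not exactly,
+# the identity; the helper below just reports the worst-case deviation.
+def _resample_roundtrip_example(x):
+    """Return the max absolute round-trip error for a (..., time) tensor."""
+    y = downsample2(upsample2(x))
+    return (y[..., :x.shape[-1]] - x).abs().max()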
diff --git a/fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py b/fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..734d047f1bb8e3aa98c88e152eee7f91fea3d814
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py
@@ -0,0 +1,176 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# author: adefossez
+
+import functools
+import logging
+from contextlib import contextmanager
+import inspect
+import time
+
+logger = logging.getLogger(__name__)
+
+EPS = 1e-8
+
+
+def capture_init(init):
+ """capture_init.
+
+ Decorate `__init__` with this, and you can then
+ recover the *args and **kwargs passed to it in `self._init_args_kwargs`
+ """
+ @functools.wraps(init)
+ def __init__(self, *args, **kwargs):
+ self._init_args_kwargs = (args, kwargs)
+ init(self, *args, **kwargs)
+
+ return __init__
+
+
+def deserialize_model(package, strict=False):
+ """deserialize_model.
+
+ """
+ klass = package['class']
+ if strict:
+ model = klass(*package['args'], **package['kwargs'])
+ else:
+ sig = inspect.signature(klass)
+ kw = package['kwargs']
+ for key in list(kw):
+ if key not in sig.parameters:
+ logger.warning("Dropping inexistant parameter %s", key)
+ del kw[key]
+ model = klass(*package['args'], **kw)
+ model.load_state_dict(package['state'])
+ return model
+
+
+def copy_state(state):
+ return {k: v.cpu().clone() for k, v in state.items()}
+
+
+def serialize_model(model):
+ args, kwargs = model._init_args_kwargs
+ state = copy_state(model.state_dict())
+ return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}
+
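+
+# A minimal sketch of the save/load round trip enabled by `capture_init`:
+# the decorated constructor records its arguments so `serialize_model` can
+# rebuild the object later. The toy module and its local torch import are
+# assumptions made only for this example.
+def _serialization_example():
+    from torch import nn  # local import: the helpers above do not need torch
+
+    class Toy(nn.Module):
+        @capture_init
+        def __init__(self, dim=4):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    package = serialize_model(Toy(dim=8))
+    return deserialize_model(package)  # rebuilt with dim=8, weights restored
+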
+
+@contextmanager
+def swap_state(model, state):
+ """
+ Context manager that swaps the state of a model, e.g:
+
+ # model is in old state
+ with swap_state(model, new_state):
+ # model in new state
+ # model back to old state
+ """
+ old_state = copy_state(model.state_dict())
+ model.load_state_dict(state)
+ try:
+ yield
+ finally:
+ model.load_state_dict(old_state)
+
+
+def pull_metric(history, name):
+ out = []
+ for metrics in history:
+ if name in metrics:
+ out.append(metrics[name])
+ return out
+
+
+class LogProgress:
+ """
+ Sort of like tqdm, but using log lines rather than real-time updates.
+ Args:
+ - logger: logger obtained from `logging.getLogger`,
+ - iterable: iterable object to wrap
+ - updates (int): number of lines that will be printed, e.g.
+ if `updates=5`, log every 1/5th of the total length.
+ - total (int): length of the iterable, in case it does not support
+ `len`.
+ - name (str): prefix to use in the log.
+ - level: logging level (like `logging.INFO`).
+ """
+ def __init__(self,
+ logger,
+ iterable,
+ updates=5,
+ total=None,
+ name="LogProgress",
+ level=logging.INFO):
+ self.iterable = iterable
+ self.total = total or len(iterable)
+ self.updates = updates
+ self.name = name
+ self.logger = logger
+ self.level = level
+
+ def update(self, **infos):
+ self._infos = infos
+
+ def __iter__(self):
+ self._iterator = iter(self.iterable)
+ self._index = -1
+ self._infos = {}
+ self._begin = time.time()
+ return self
+
+ def __next__(self):
+ self._index += 1
+ try:
+ value = next(self._iterator)
+ except StopIteration:
+ raise
+ else:
+ return value
+ finally:
+ log_every = max(1, self.total // self.updates)
+ # logging is delayed by one iteration so the metrics from `update` are included
+ if self._index >= 1 and self._index % log_every == 0:
+ self._log()
+
+ def _log(self):
+ self._speed = (1 + self._index) / (time.time() - self._begin)
+ infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
+ if self._speed < 1e-4:
+ speed = "oo sec/it"
+ elif self._speed < 0.1:
+ speed = f"{1/self._speed:.1f} sec/it"
+ else:
+ speed = f"{self._speed:.1f} it/sec"
+ out = f"{self.name} | {self._index}/{self.total} | {speed}"
+ if infos:
+ out += " | " + infos
+ self.logger.log(self.level, out)
+
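+
+# A minimal usage sketch for LogProgress (names and values are illustrative):
+#
+#     logp = LogProgress(logger, range(1000), updates=4, name="Enhance")
+#     for step in logp:
+#         logp.update(loss=f"{step / 1000:.3f}")
+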
+
+def colorize(text, color):
+ """
+ Display text with some ANSI color in the terminal.
+ """
+ code = f"\033[{color}m"
+ restore = "\033[0m"
+ return "".join([code, text, restore])
+
+
+def bold(text):
+ """
+ Display text in bold in the terminal.
+ """
+ return colorize(text, "1")
+
+
+def cal_snr(lbl, est):
+ import torch
+ y = 10.0 * torch.log10(
+ torch.sum(lbl**2, dim=-1) / (torch.sum((est-lbl)**2, dim=-1) + EPS) +
+ EPS
+ )
+ return y
diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30254604311a488a1d4959f941051890ed32b2e
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py
@@ -0,0 +1,140 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+from pathlib import Path
+from collections import defaultdict
+from typing import List, Dict, Tuple
+
+import pandas as pd
+import numpy as np
+import torchaudio
+from tqdm import tqdm
+
+from examples.speech_to_text.data_utils import load_df_from_tsv, save_df_to_tsv
+
+
+log = logging.getLogger(__name__)
+
+SPLITS = ["train", "dev", "test"]
+
+
+def get_top_n(
+ root: Path, n_speakers: int = 10, min_n_tokens: int = 5
+) -> pd.DataFrame:
+ df = load_df_from_tsv(root / "validated.tsv")
+ df["n_tokens"] = [len(s.split()) for s in df["sentence"]]
+ df = df[df["n_tokens"] >= min_n_tokens]
+ df["n_frames"] = [
+ torchaudio.info((root / "clips" / p).as_posix()).num_frames
+ for p in tqdm(df["path"])
+ ]
+ df["id"] = [Path(p).stem for p in df["path"]]
+ total_duration_ms = df.groupby("client_id")["n_frames"].agg(["sum"])
+ total_duration_ms = total_duration_ms.sort_values("sum", ascending=False)
+
+ top_n_total_duration_ms = total_duration_ms.head(n_speakers)
+ top_n_client_ids = set(top_n_total_duration_ms.index.tolist())
+ df_top_n = df[df["client_id"].isin(top_n_client_ids)]
+ return df_top_n
+
+
+def get_splits(
+ df, train_split_ratio=0.99, speaker_in_all_splits=False, rand_seed=0
+) -> Tuple[Dict[str, str], List[str]]:
+ np.random.seed(rand_seed)
+ dev_split_ratio = (1. - train_split_ratio) / 3
+ grouped = list(df.groupby("client_id"))
+ id_to_split = {}
+ for _, cur_df in tqdm(grouped):
+ cur_n_examples = len(cur_df)
+ if speaker_in_all_splits and cur_n_examples < 3:
+ continue
+ cur_n_train = int(cur_n_examples * train_split_ratio)
+ cur_n_dev = int(cur_n_examples * dev_split_ratio)
+ cur_n_test = cur_n_examples - cur_n_dev - cur_n_train
+ if speaker_in_all_splits and cur_n_dev * cur_n_test == 0:
+ cur_n_dev, cur_n_test = 1, 1
+ cur_n_train = cur_n_examples - cur_n_dev - cur_n_test
+ cur_indices = cur_df.index.tolist()
+ cur_shuffled_indices = np.random.permutation(cur_n_examples)
+ cur_shuffled_indices = [cur_indices[i] for i in cur_shuffled_indices]
+ cur_indices_by_split = {
+ "train": cur_shuffled_indices[:cur_n_train],
+ "dev": cur_shuffled_indices[cur_n_train: cur_n_train + cur_n_dev],
+ "test": cur_shuffled_indices[cur_n_train + cur_n_dev:]
+ }
+ for split in SPLITS:
+ for i in cur_indices_by_split[split]:
+ id_ = df["id"].loc[i]
+ id_to_split[id_] = split
+ return id_to_split, sorted(df["client_id"].unique())
+
+
+def convert_to_wav(root: Path, filenames: List[str], target_sr=16_000):
+ out_root = root / "wav"
+ out_root.mkdir(exist_ok=True, parents=True)
+ print("Converting to WAV...")
+ for n in tqdm(filenames):
+ in_path = (root / "clips" / n).as_posix()
+ waveform, sr = torchaudio.load(in_path)
+ converted, converted_sr = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, sr, [["rate", str(target_sr)], ["channels", "1"]]
+ )
+ out_path = (out_root / Path(n).with_suffix(".wav").name).as_posix()
+ torchaudio.save(out_path, converted, converted_sr, encoding="PCM_S",
+ bits_per_sample=16)
+
+
+def process(args):
+ data_root = Path(args.data_root).absolute() / args.lang
+
+ # Generate TSV manifest
+ print("Generating manifest...")
+
+ df_top_n = get_top_n(data_root)
+ id_to_split, speakers = get_splits(df_top_n)
+
+ if args.convert_to_wav:
+ convert_to_wav(data_root, df_top_n["path"].tolist())
+
+ manifest_by_split = {split: defaultdict(list) for split in SPLITS}
+ for sample in tqdm(df_top_n.to_dict(orient="index").values()):
+ sample_id = sample["id"]
+ split = id_to_split[sample_id]
+ manifest_by_split[split]["id"].append(sample_id)
+ if args.convert_to_wav:
+ audio_path = data_root / "wav" / f"{sample_id}.wav"
+ else:
+ audio_path = data_root / "clips" / f"{sample_id}.mp3"
+ manifest_by_split[split]["audio"].append(audio_path.as_posix())
+ manifest_by_split[split]["n_frames"].append(sample["n_frames"])
+ manifest_by_split[split]["tgt_text"].append(sample["sentence"])
+ manifest_by_split[split]["speaker"].append(sample["client_id"])
+ manifest_by_split[split]["src_text"].append(sample["sentence"])
+
+ output_root = Path(args.output_manifest_root).absolute()
+ output_root.mkdir(parents=True, exist_ok=True)
+ for split in SPLITS:
+ save_df_to_tsv(
+ pd.DataFrame.from_dict(manifest_by_split[split]),
+ output_root / f"{split}.audio.tsv"
+ )
+
+
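+# Example invocation (paths and the language code are illustrative; the data
+# root is expected to hold the standard Common Voice layout per language):
+#
+#   python get_common_voice_audio_manifest.py \
+#       --data-root /path/to/common_voice --lang en \
+#       --output-manifest-root /path/to/manifests --convert-to-wav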
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-root", "-d", required=True, type=str)
+ parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
+ parser.add_argument("--lang", "-l", required=True, type=str)
+ parser.add_argument("--convert-to-wav", action="store_true")
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py
new file mode 100644
index 0000000000000000000000000000000000000000..516f2cc469af9b417126dea1988698adac41d8ab
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py
@@ -0,0 +1,233 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+from pathlib import Path
+import shutil
+from tempfile import NamedTemporaryFile
+from collections import Counter, defaultdict
+
+import pandas as pd
+import torchaudio
+from tqdm import tqdm
+
+from fairseq.data.audio.audio_utils import convert_waveform
+from examples.speech_to_text.data_utils import (
+ create_zip,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ load_tsv_to_dicts,
+ save_df_to_tsv
+)
+from examples.speech_synthesis.data_utils import (
+ extract_logmel_spectrogram, extract_pitch, extract_energy, get_global_cmvn,
+ ipa_phonemize, get_mfa_alignment, get_unit_alignment
+)
+
+
+log = logging.getLogger(__name__)
+
+
+def process(args):
+ assert "train" in args.splits
+ out_root = Path(args.output_root).absolute()
+ out_root.mkdir(exist_ok=True)
+
+ print("Fetching data...")
+ audio_manifest_root = Path(args.audio_manifest_root).absolute()
+ samples = []
+ for s in args.splits:
+ for e in load_tsv_to_dicts(audio_manifest_root / f"{s}.audio.tsv"):
+ e["split"] = s
+ samples.append(e)
+ sample_ids = [s["id"] for s in samples]
+
+ # Get alignment info
+ id_to_alignment = None
+ if args.textgrid_zip is not None:
+ assert args.id_to_units_tsv is None
+ id_to_alignment = get_mfa_alignment(
+ args.textgrid_zip, sample_ids, args.sample_rate, args.hop_length
+ )
+ elif args.id_to_units_tsv is not None:
+ # assume identical hop length on the unit sequence
+ id_to_alignment = get_unit_alignment(args.id_to_units_tsv, sample_ids)
+
+ # Extract features and pack features into ZIP
+ feature_name = "logmelspec80"
+ zip_path = out_root / f"{feature_name}.zip"
+ pitch_zip_path = out_root / "pitch.zip"
+ energy_zip_path = out_root / "energy.zip"
+ gcmvn_npz_path = out_root / "gcmvn_stats.npz"
+ if zip_path.exists() and gcmvn_npz_path.exists():
+ print(f"{zip_path} and {gcmvn_npz_path} exist.")
+ else:
+ feature_root = out_root / feature_name
+ feature_root.mkdir(exist_ok=True)
+ pitch_root = out_root / "pitch"
+ energy_root = out_root / "energy"
+ if args.add_fastspeech_targets:
+ pitch_root.mkdir(exist_ok=True)
+ energy_root.mkdir(exist_ok=True)
+ print("Extracting Mel spectrogram features...")
+ for sample in tqdm(samples):
+ waveform, sample_rate = torchaudio.load(sample["audio"])
+ waveform, sample_rate = convert_waveform(
+ waveform, sample_rate, normalize_volume=args.normalize_volume,
+ to_sample_rate=args.sample_rate
+ )
+ sample_id = sample["id"]
+ target_length = None
+ if id_to_alignment is not None:
+ a = id_to_alignment[sample_id]
+ target_length = sum(a.frame_durations)
+ if a.start_sec is not None and a.end_sec is not None:
+ start_frame = int(a.start_sec * sample_rate)
+ end_frame = int(a.end_sec * sample_rate)
+ waveform = waveform[:, start_frame: end_frame]
+ extract_logmel_spectrogram(
+ waveform, sample_rate, feature_root / f"{sample_id}.npy",
+ win_length=args.win_length, hop_length=args.hop_length,
+ n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min,
+ f_max=args.f_max, target_length=target_length
+ )
+ if args.add_fastspeech_targets:
+ assert id_to_alignment is not None
+ extract_pitch(
+ waveform, sample_rate, pitch_root / f"{sample_id}.npy",
+ hop_length=args.hop_length, log_scale=True,
+ phoneme_durations=id_to_alignment[sample_id].frame_durations
+ )
+ extract_energy(
+ waveform, energy_root / f"{sample_id}.npy",
+ hop_length=args.hop_length, n_fft=args.n_fft,
+ log_scale=True,
+ phoneme_durations=id_to_alignment[sample_id].frame_durations
+ )
+ print("ZIPing features...")
+ create_zip(feature_root, zip_path)
+ get_global_cmvn(feature_root, gcmvn_npz_path)
+ shutil.rmtree(feature_root)
+ if args.add_fastspeech_targets:
+ create_zip(pitch_root, pitch_zip_path)
+ shutil.rmtree(pitch_root)
+ create_zip(energy_root, energy_zip_path)
+ shutil.rmtree(energy_root)
+
+ print("Fetching ZIP manifest...")
+ audio_paths, audio_lengths = get_zip_manifest(zip_path)
+ pitch_paths, pitch_lengths, energy_paths, energy_lengths = [None] * 4
+ if args.add_fastspeech_targets:
+ pitch_paths, pitch_lengths = get_zip_manifest(pitch_zip_path)
+ energy_paths, energy_lengths = get_zip_manifest(energy_zip_path)
+ # Generate TSV manifest
+ print("Generating manifest...")
+ manifest_by_split = {split: defaultdict(list) for split in args.splits}
+ for sample in tqdm(samples):
+ sample_id, split = sample["id"], sample["split"]
+ normalized_utt = sample["tgt_text"]
+ if id_to_alignment is not None:
+ normalized_utt = " ".join(id_to_alignment[sample_id].tokens)
+ elif args.ipa_vocab:
+ normalized_utt = ipa_phonemize(
+ normalized_utt, lang=args.lang, use_g2p=args.use_g2p
+ )
+ manifest_by_split[split]["id"].append(sample_id)
+ manifest_by_split[split]["audio"].append(audio_paths[sample_id])
+ manifest_by_split[split]["n_frames"].append(audio_lengths[sample_id])
+ manifest_by_split[split]["tgt_text"].append(normalized_utt)
+ manifest_by_split[split]["speaker"].append(sample["speaker"])
+ manifest_by_split[split]["src_text"].append(sample["src_text"])
+ if args.add_fastspeech_targets:
+ assert id_to_alignment is not None
+ duration = " ".join(
+ str(d) for d in id_to_alignment[sample_id].frame_durations
+ )
+ manifest_by_split[split]["duration"].append(duration)
+ manifest_by_split[split]["pitch"].append(pitch_paths[sample_id])
+ manifest_by_split[split]["energy"].append(energy_paths[sample_id])
+ for split in args.splits:
+ save_df_to_tsv(
+ pd.DataFrame.from_dict(manifest_by_split[split]),
+ out_root / f"{split}.tsv"
+ )
+ # Generate vocab
+ vocab_name, spm_filename = None, None
+ if id_to_alignment is not None or args.ipa_vocab:
+ vocab = Counter()
+ for t in manifest_by_split["train"]["tgt_text"]:
+ vocab.update(t.split(" "))
+ vocab_name = "vocab.txt"
+ with open(out_root / vocab_name, "w") as f:
+ for s, c in vocab.most_common():
+ f.write(f"{s} {c}\n")
+ else:
+ spm_filename_prefix = "spm_char"
+ spm_filename = f"{spm_filename_prefix}.model"
+ with NamedTemporaryFile(mode="w") as f:
+ for t in manifest_by_split["train"]["tgt_text"]:
+ f.write(t + "\n")
+ f.flush() # needed to ensure gen_vocab sees dumped text
+ gen_vocab(Path(f.name), out_root / spm_filename_prefix, "char")
+ # Generate speaker list
+ speakers = sorted({sample["speaker"] for sample in samples})
+ speakers_path = out_root / "speakers.txt"
+ with open(speakers_path, "w") as f:
+ for speaker in speakers:
+ f.write(f"{speaker}\n")
+ # Generate config YAML
+ win_len_t = args.win_length / args.sample_rate
+ hop_len_t = args.hop_length / args.sample_rate
+ extra = {
+ "sample_rate": args.sample_rate,
+ "features": {
+ "type": "spectrogram+melscale+log",
+ "eps": 1e-2, "n_mels": args.n_mels, "n_fft": args.n_fft,
+ "window_fn": "hann", "win_length": args.win_length,
+ "hop_length": args.hop_length, "sample_rate": args.sample_rate,
+ "win_len_t": win_len_t, "hop_len_t": hop_len_t,
+ "f_min": args.f_min, "f_max": args.f_max,
+ "n_stft": args.n_fft // 2 + 1
+ }
+ }
+ if len(speakers) > 1:
+ extra["speaker_set_filename"] = "speakers.txt"
+ gen_config_yaml(
+ out_root, spm_filename=spm_filename, vocab_name=vocab_name,
+ audio_root=out_root.as_posix(), input_channels=None,
+ input_feat_per_channel=None, specaugment_policy=None,
+ cmvn_type="global", gcmvn_path=gcmvn_npz_path, extra=extra
+ )
+
+
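+# Example invocation (paths are illustrative; --add-fastspeech-targets needs
+# alignments, supplied here via --textgrid-zip):
+#
+#   python get_feature_manifest.py \
+#       --audio-manifest-root /path/to/manifests --output-root /path/to/features \
+#       --ipa-vocab --use-g2p \
+#       --textgrid-zip /path/to/textgrids.zip --add-fastspeech-targets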
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--audio-manifest-root", "-m", required=True, type=str)
+ parser.add_argument("--output-root", "-o", required=True, type=str)
+ parser.add_argument("--splits", "-s", type=str, nargs="+",
+ default=["train", "dev", "test"])
+ parser.add_argument("--ipa-vocab", action="store_true")
+ parser.add_argument("--use-g2p", action="store_true")
+ parser.add_argument("--lang", type=str, default="en-us")
+ parser.add_argument("--win-length", type=int, default=1024)
+ parser.add_argument("--hop-length", type=int, default=256)
+ parser.add_argument("--n-fft", type=int, default=1024)
+ parser.add_argument("--n-mels", type=int, default=80)
+ parser.add_argument("--f-min", type=int, default=20)
+ parser.add_argument("--f-max", type=int, default=8000)
+ parser.add_argument("--sample-rate", type=int, default=22050)
+ parser.add_argument("--normalize-volume", "-n", action="store_true")
+ parser.add_argument("--textgrid-zip", type=str, default=None)
+ parser.add_argument("--id-to-units-tsv", type=str, default=None)
+ parser.add_argument("--add-fastspeech-targets", action="store_true")
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec1fb7521b8a9b821d28bcaaaedb034f6e95e0b
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+from pathlib import Path
+from collections import defaultdict
+
+import pandas as pd
+from torchaudio.datasets import LJSPEECH
+from tqdm import tqdm
+
+from examples.speech_to_text.data_utils import save_df_to_tsv
+
+
+log = logging.getLogger(__name__)
+
+SPLITS = ["train", "dev", "test"]
+
+
+def process(args):
+ out_root = Path(args.output_data_root).absolute()
+ out_root.mkdir(parents=True, exist_ok=True)
+
+ # Generate TSV manifest
+ print("Generating manifest...")
+ # following FastSpeech's splits
+ dataset = LJSPEECH(out_root.as_posix(), download=True)
+ id_to_split = {}
+ for x in dataset._flist:
+ id_ = x[0]
+ speaker = id_.split("-")[0]
+ id_to_split[id_] = {
+ "LJ001": "test", "LJ002": "test", "LJ003": "dev"
+ }.get(speaker, "train")
+ manifest_by_split = {split: defaultdict(list) for split in SPLITS}
+ progress = tqdm(enumerate(dataset), total=len(dataset))
+ for i, (waveform, _, utt, normalized_utt) in progress:
+ sample_id = dataset._flist[i][0]
+ split = id_to_split[sample_id]
+ manifest_by_split[split]["id"].append(sample_id)
+ audio_path = f"{dataset._path}/{sample_id}.wav"
+ manifest_by_split[split]["audio"].append(audio_path)
+ manifest_by_split[split]["n_frames"].append(len(waveform[0]))
+ manifest_by_split[split]["tgt_text"].append(normalized_utt)
+ manifest_by_split[split]["speaker"].append("ljspeech")
+ manifest_by_split[split]["src_text"].append(utt)
+
+ manifest_root = Path(args.output_manifest_root).absolute()
+ manifest_root.mkdir(parents=True, exist_ok=True)
+ for split in SPLITS:
+ save_df_to_tsv(
+ pd.DataFrame.from_dict(manifest_by_split[split]),
+ manifest_root / f"{split}.audio.tsv"
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output-data-root", "-d", required=True, type=str)
+ parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py b/fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e3e4c5cd7aef15dae0b41b0ec7b33e17f66597f
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py
@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import argparse
+from collections import defaultdict
+from itertools import chain
+from pathlib import Path
+
+import numpy as np
+import torchaudio
+import torchaudio.sox_effects as ta_sox
+import yaml
+from tqdm import tqdm
+
+from examples.speech_to_text.data_utils import load_tsv_to_dicts
+from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder
+
+
+def extract_embedding(audio_path, embedder):
+ wav, sr = torchaudio.load(audio_path) # 2D
+ if sr != embedder.RATE:
+ wav, sr = ta_sox.apply_effects_tensor(
+ wav, sr, [["rate", str(embedder.RATE)]]
+ )
+ try:
+ emb = embedder([wav[0].cuda().float()]).cpu().numpy()
+ except RuntimeError:
+ emb = None
+ return emb
+
+
+def process(args):
+ print("Fetching data...")
+ raw_manifest_root = Path(args.raw_manifest_root).absolute()
+ samples = [load_tsv_to_dicts(raw_manifest_root / (s + ".tsv"))
+ for s in args.splits]
+ samples = list(chain(*samples))
+ with open(args.config, "r") as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+ with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f:
+ speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
+
+ embedder = SpkrEmbedder(args.ckpt).cuda()
+ speaker_to_cnt = defaultdict(float)
+ speaker_to_emb = defaultdict(float)
+ for sample in tqdm(samples, desc="extract emb"):
+ emb = extract_embedding(sample["audio"], embedder)
+ if emb is not None:
+ speaker_to_cnt[sample["speaker"]] += 1
+ speaker_to_emb[sample["speaker"]] += emb
+ if len(speaker_to_emb) != len(speaker_to_id):
+ missed = set(speaker_to_id) - set(speaker_to_emb.keys())
+ print(
+ f"WARNING: missing embeddings for {len(missed)} speaker:\n{missed}"
+ )
+ speaker_emb_mat = np.zeros((len(speaker_to_id), len(emb)), float)
+ for speaker in speaker_to_emb:
+ idx = speaker_to_id[speaker]
+ emb = speaker_to_emb[speaker]
+ cnt = speaker_to_cnt[speaker]
+ speaker_emb_mat[idx, :] = emb / cnt
+ speaker_emb_name = "speaker_emb.npy"
+ speaker_emb_path = f"{config['audio_root']}/{speaker_emb_name}"
+ np.save(speaker_emb_path, speaker_emb_mat)
+ config["speaker_emb_filename"] = speaker_emb_name
+
+ with open(args.new_config, "w") as f:
+ yaml.dump(config, f)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--raw-manifest-root", "-m", required=True, type=str)
+ parser.add_argument("--splits", "-s", type=str, nargs="+",
+ default=["train"])
+ parser.add_argument("--config", "-c", required=True, type=str)
+ parser.add_argument("--new-config", "-n", required=True, type=str)
+ parser.add_argument("--ckpt", required=True, type=str,
+ help="speaker embedder checkpoint")
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py b/fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py
new file mode 100644
index 0000000000000000000000000000000000000000..7afa40fcd195465a225c9f251734e84fe6b3c7ef
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import numpy as np
+import re
+from pathlib import Path
+from collections import defaultdict
+
+import pandas as pd
+from torchaudio.datasets import VCTK
+from tqdm import tqdm
+
+from examples.speech_to_text.data_utils import save_df_to_tsv
+
+
+log = logging.getLogger(__name__)
+
+SPLITS = ["train", "dev", "test"]
+
+
+def normalize_text(text):
+ return re.sub(r"[^a-zA-Z.?!,'\- ]", '', text)
+
+
+def process(args):
+ out_root = Path(args.output_data_root).absolute()
+ out_root.mkdir(parents=True, exist_ok=True)
+
+ # Generate TSV manifest
+ print("Generating manifest...")
+ dataset = VCTK(out_root.as_posix(), download=False)
+ ids = list(dataset._walker)
+ np.random.seed(args.seed)
+ np.random.shuffle(ids)
+ n_train = len(ids) - args.n_dev - args.n_test
+ _split = ["train"] * n_train + ["dev"] * args.n_dev + ["test"] * args.n_test
+ id_to_split = dict(zip(ids, _split))
+ manifest_by_split = {split: defaultdict(list) for split in SPLITS}
+ progress = tqdm(enumerate(dataset), total=len(dataset))
+ for i, (waveform, _, text, speaker_id, _) in progress:
+ sample_id = dataset._walker[i]
+ _split = id_to_split[sample_id]
+ audio_dir = Path(dataset._path) / dataset._folder_audio / speaker_id
+ audio_path = audio_dir / f"{sample_id}.wav"
+ text = normalize_text(text)
+ manifest_by_split[_split]["id"].append(sample_id)
+ manifest_by_split[_split]["audio"].append(audio_path.as_posix())
+ manifest_by_split[_split]["n_frames"].append(len(waveform[0]))
+ manifest_by_split[_split]["tgt_text"].append(text)
+ manifest_by_split[_split]["speaker"].append(speaker_id)
+ manifest_by_split[_split]["src_text"].append(text)
+
+ manifest_root = Path(args.output_manifest_root).absolute()
+ manifest_root.mkdir(parents=True, exist_ok=True)
+ for _split in SPLITS:
+ save_df_to_tsv(
+ pd.DataFrame.from_dict(manifest_by_split[_split]),
+ manifest_root / f"{_split}.audio.tsv"
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output-data-root", "-d", required=True, type=str)
+ parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
+ parser.add_argument("--n-dev", default=50, type=int)
+ parser.add_argument("--n-test", default=100, type=int)
+ parser.add_argument("--seed", "-s", default=1234, type=int)
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b178676ba322ef613df42977cb498101f841b09
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import librosa
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+import torchaudio
+
+
+EMBEDDER_PARAMS = {
+ 'num_mels': 40,
+ 'n_fft': 512,
+ 'emb_dim': 256,
+ 'lstm_hidden': 768,
+ 'lstm_layers': 3,
+ 'window': 80,
+ 'stride': 40,
+}
+
+
+def set_requires_grad(nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary
+ computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+
+class LinearNorm(nn.Module):
+ def __init__(self, hp):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"])
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class SpeechEmbedder(nn.Module):
+ def __init__(self, hp):
+ super(SpeechEmbedder, self).__init__()
+ self.lstm = nn.LSTM(hp["num_mels"],
+ hp["lstm_hidden"],
+ num_layers=hp["lstm_layers"],
+ batch_first=True)
+ self.proj = LinearNorm(hp)
+ self.hp = hp
+
+ def forward(self, mel):
+ # (num_mels, T) -> (num_mels, T', window)
+ mels = mel.unfold(1, self.hp["window"], self.hp["stride"])
+ mels = mels.permute(1, 2, 0) # (T', window, num_mels)
+ x, _ = self.lstm(mels) # (T', window, lstm_hidden)
+ x = x[:, -1, :] # (T', lstm_hidden), use last frame only
+ x = self.proj(x) # (T', emb_dim)
+ x = x / torch.norm(x, p=2, dim=1, keepdim=True) # (T', emb_dim)
+
+ x = x.mean(dim=0)
+ if x.norm(p=2) != 0:
+ x = x / x.norm(p=2)
+ return x
+
+
+class SpkrEmbedder(nn.Module):
+ RATE = 16000
+
+ def __init__(
+ self,
+ embedder_path,
+ embedder_params=EMBEDDER_PARAMS,
+ rate=16000,
+ hop_length=160,
+ win_length=400,
+ pad=False,
+ ):
+ super(SpkrEmbedder, self).__init__()
+ embedder_pt = torch.load(embedder_path, map_location="cpu")
+ self.embedder = SpeechEmbedder(embedder_params)
+ self.embedder.load_state_dict(embedder_pt)
+ self.embedder.eval()
+ set_requires_grad(self.embedder, requires_grad=False)
+ self.embedder_params = embedder_params
+
+ self.register_buffer('mel_basis', torch.from_numpy(
+ librosa.filters.mel(
+ sr=self.RATE,
+ n_fft=self.embedder_params["n_fft"],
+ n_mels=self.embedder_params["num_mels"])
+ )
+ )
+
+ self.resample = None
+ if rate != self.RATE:
+ self.resample = torchaudio.transforms.Resample(rate, self.RATE)
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.pad = pad
+
+ def get_mel(self, y):
+ if self.pad and y.shape[-1] < 14000:
+ y = F.pad(y, (0, 14000 - y.shape[-1]))
+
+ window = torch.hann_window(self.win_length).to(y)
+ y = torch.stft(y, n_fft=self.embedder_params["n_fft"],
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=window)
+ magnitudes = torch.norm(y, dim=-1, p=2) ** 2
+ mel = torch.log10(self.mel_basis @ magnitudes + 1e-6)
+ return mel
+
+ def forward(self, inputs):
+ dvecs = []
+ for wav in inputs:
+ mel = self.get_mel(wav)
+ if mel.dim() == 3:
+ mel = mel.squeeze(0)
+ dvecs += [self.embedder(mel)]
+ dvecs = torch.stack(dvecs)
+
+ dvec = torch.mean(dvecs, dim=0)
+ dvec = dvec / torch.norm(dvec)
+
+ return dvec
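+
+
+# A minimal usage sketch: embed one utterance with a trained checkpoint. The
+# checkpoint path is a placeholder and the input is assumed to be a 1-D
+# 16 kHz waveform tensor; `get_speaker_embedding.py` shows the full flow,
+# including resampling and per-speaker averaging.
+def _spkr_embedder_example(ckpt_path, wav):
+    """Return a unit-norm d-vector for a single utterance."""
+    embedder = SpkrEmbedder(ckpt_path).eval()
+    with torch.no_grad():
+        return embedder([wav.float()])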
diff --git a/fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py b/fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cf121081fbde2f5085ed380f0841649d143a4be
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py
@@ -0,0 +1,192 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import collections
+import contextlib
+import wave
+
+try:
+ import webrtcvad
+except ImportError:
+ raise ImportError("Please install py-webrtcvad: pip install webrtcvad")
+import argparse
+import os
+import logging
+from tqdm import tqdm
+
+AUDIO_SUFFIX = '.wav'
+FS_MS = 30
+SCALE = 6e-5
+THRESHOLD = 0.3
+
+
+def read_wave(path):
+ """Reads a .wav file.
+ Takes the path, and returns (PCM audio data, sample rate).
+ """
+ with contextlib.closing(wave.open(path, 'rb')) as wf:
+ num_channels = wf.getnchannels()
+ assert num_channels == 1
+ sample_width = wf.getsampwidth()
+ assert sample_width == 2
+ sample_rate = wf.getframerate()
+ assert sample_rate in (8000, 16000, 32000, 48000)
+ pcm_data = wf.readframes(wf.getnframes())
+ return pcm_data, sample_rate
+
+
+def write_wave(path, audio, sample_rate):
+ """Writes a .wav file.
+ Takes path, PCM audio data, and sample rate.
+ """
+ with contextlib.closing(wave.open(path, 'wb')) as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate)
+ wf.writeframes(audio)
+
+
+class Frame(object):
+ """Represents a "frame" of audio data."""
+ def __init__(self, bytes, timestamp, duration):
+ self.bytes = bytes
+ self.timestamp = timestamp
+ self.duration = duration
+
+
+def frame_generator(frame_duration_ms, audio, sample_rate):
+ """Generates audio frames from PCM audio data.
+ Takes the desired frame duration in milliseconds, the PCM data, and
+ the sample rate.
+ Yields Frames of the requested duration.
+ """
+ n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
+ offset = 0
+ timestamp = 0.0
+ duration = (float(n) / sample_rate) / 2.0
+ while offset + n < len(audio):
+ yield Frame(audio[offset:offset + n], timestamp, duration)
+ timestamp += duration
+ offset += n
+
+
+def vad_collector(sample_rate, frame_duration_ms,
+ padding_duration_ms, vad, frames):
+ """Filters out non-voiced audio frames.
+ Given a webrtcvad.Vad and a source of audio frames, yields only
+ the voiced audio.
+ Uses a padded, sliding window algorithm over the audio frames.
+ When more than 90% of the frames in the window are voiced (as
+ reported by the VAD), the collector triggers and begins yielding
+ audio frames. Then the collector waits until 90% of the frames in
+ the window are unvoiced to detrigger.
+ The window is padded at the front and back to provide a small
+ amount of silence or the beginnings/endings of speech around the
+ voiced frames.
+ Arguments:
+ sample_rate - The audio sample rate, in Hz.
+ frame_duration_ms - The frame duration in milliseconds.
+ padding_duration_ms - The amount to pad the window, in milliseconds.
+ vad - An instance of webrtcvad.Vad.
+ frames - a source of audio frames (sequence or generator).
+ Returns: A generator that yields PCM audio data.
+ """
+ num_padding_frames = int(padding_duration_ms / frame_duration_ms)
+ # We use a deque for our sliding window/ring buffer.
+ ring_buffer = collections.deque(maxlen=num_padding_frames)
+ # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
+ # NOTTRIGGERED state.
+ triggered = False
+
+ voiced_frames = []
+ for frame in frames:
+ is_speech = vad.is_speech(frame.bytes, sample_rate)
+
+ # sys.stdout.write('1' if is_speech else '0')
+ if not triggered:
+ ring_buffer.append((frame, is_speech))
+ num_voiced = len([f for f, speech in ring_buffer if speech])
+ # If we're NOTTRIGGERED and more than 90% of the frames in
+ # the ring buffer are voiced frames, then enter the
+ # TRIGGERED state.
+ if num_voiced > 0.9 * ring_buffer.maxlen:
+ triggered = True
+ # We want to yield all the audio we see from now until
+ # we are NOTTRIGGERED, but we have to start with the
+ # audio that's already in the ring buffer.
+ for f, _ in ring_buffer:
+ voiced_frames.append(f)
+ ring_buffer.clear()
+ else:
+ # We're in the TRIGGERED state, so collect the audio data
+ # and add it to the ring buffer.
+ voiced_frames.append(frame)
+ ring_buffer.append((frame, is_speech))
+ num_unvoiced = len([f for f, speech in ring_buffer if not speech])
+ # If more than 90% of the frames in the ring buffer are
+ # unvoiced, then enter NOTTRIGGERED and yield whatever
+ # audio we've collected.
+ if num_unvoiced > 0.9 * ring_buffer.maxlen:
+ triggered = False
+ yield [b''.join([f.bytes for f in voiced_frames]),
+ voiced_frames[0].timestamp, voiced_frames[-1].timestamp]
+ ring_buffer.clear()
+ voiced_frames = []
+ # If we have any leftover voiced audio when we run out of input,
+ # yield it.
+ if voiced_frames:
+ yield [b''.join([f.bytes for f in voiced_frames]),
+ voiced_frames[0].timestamp, voiced_frames[-1].timestamp]
+
+
+def main(args):
+ # create output folder
+ try:
+ cmd = f"mkdir -p {args.out_path}"
+ os.system(cmd)
+ except Exception:
+ logging.error("Can not create output folder")
+ exit(-1)
+
+ # build vad object
+ vad = webrtcvad.Vad(int(args.agg))
+ # iterating over wavs in dir
+ for file in tqdm(os.listdir(args.in_path)):
+ if file.endswith(AUDIO_SUFFIX):
+ audio_inpath = os.path.join(args.in_path, file)
+ audio_outpath = os.path.join(args.out_path, file)
+ audio, sample_rate = read_wave(audio_inpath)
+ frames = frame_generator(FS_MS, audio, sample_rate)
+ frames = list(frames)
+ segments = vad_collector(sample_rate, FS_MS, 300, vad, frames)
+ merge_segments = list()
+ timestamp_start = 0.0
+ timestamp_end = 0.0
+ # drop leading/trailing silence and cap long silent gaps between segments
+ for i, segment in enumerate(segments):
+ merge_segments.append(segment[0])
+ if i and timestamp_start:
+ sil_duration = segment[1] - timestamp_end
+ if sil_duration > THRESHOLD:
+ merge_segments.append(int(THRESHOLD / SCALE)*(b'\x00'))
+ else:
+ merge_segments.append(int((sil_duration / SCALE))*(b'\x00'))
+ timestamp_start = segment[1]
+ timestamp_end = segment[2]
+ segment = b''.join(merge_segments)
+ write_wave(audio_outpath, segment, sample_rate)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Apply VAD to a folder of audio files.')
+ parser.add_argument('in_path', type=str, help='Path to the input files')
+ parser.add_argument('out_path', type=str,
+ help='Path to save the processed files')
+ parser.add_argument('--agg', type=int, default=3,
+ help='The level of aggressiveness of the VAD: [0-3]')
+ args = parser.parse_args()
+
+ main(args)
diff --git a/fairseq/examples/speech_synthesis/utils.py b/fairseq/examples/speech_synthesis/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c7b03733d2290d3834d2c68a16034198daa1e69
--- /dev/null
+++ b/fairseq/examples/speech_synthesis/utils.py
@@ -0,0 +1,101 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from scipy.interpolate import interp1d
+import torchaudio
+
+from fairseq.tasks.text_to_speech import (
+ batch_compute_distortion, compute_rms_dist
+)
+
+
+def batch_mel_spectral_distortion(
+ y1, y2, sr, normalize_type="path", mel_fn=None
+):
+ """
+ https://arxiv.org/pdf/2011.03568.pdf
+
+ Same as Mel Cepstral Distortion, but computed on log-mel spectrograms.
+ """
+ if mel_fn is None or mel_fn.sample_rate != sr:
+ mel_fn = torchaudio.transforms.MelSpectrogram(
+ sr, n_fft=int(0.05 * sr), win_length=int(0.05 * sr),
+ hop_length=int(0.0125 * sr), f_min=20, n_mels=80,
+ window_fn=torch.hann_window
+ ).to(y1[0].device)
+ offset = 1e-6
+ return batch_compute_distortion(
+ y1, y2, sr, lambda y: torch.log(mel_fn(y) + offset).transpose(-1, -2),
+ compute_rms_dist, normalize_type
+ )
+
+
+# This code is based on
+# "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py"
+def _same_t_in_true_and_est(func):
+ def new_func(true_t, true_f, est_t, est_f):
+ assert type(true_t) is np.ndarray
+ assert type(true_f) is np.ndarray
+ assert type(est_t) is np.ndarray
+ assert type(est_f) is np.ndarray
+
+ interpolated_f = interp1d(
+ est_t, est_f, bounds_error=False, kind='nearest', fill_value=0
+ )(true_t)
+ return func(true_t, true_f, true_t, interpolated_f)
+
+ return new_func
+
+
+@_same_t_in_true_and_est
+def gross_pitch_error(true_t, true_f, est_t, est_f):
+ """The relative frequency in percent of pitch estimates that are
+ outside a threshold around the true pitch. Only frames that are
+ considered pitched by both the ground truth and the estimator (if
+ applicable) are considered.
+ """
+
+ correct_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
+ gross_pitch_error_frames = _gross_pitch_error_frames(
+ true_t, true_f, est_t, est_f
+ )
+ return np.sum(gross_pitch_error_frames) / np.sum(correct_frames)
+
+
+def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8):
+ voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
+ true_f_p_eps = [x + eps for x in true_f]
+ pitch_error_frames = np.abs(est_f / true_f_p_eps - 1) > 0.2
+ return voiced_frames & pitch_error_frames
+
+
+def _true_voiced_frames(true_t, true_f, est_t, est_f):
+ return (est_f != 0) & (true_f != 0)
+
+
+def _voicing_decision_error_frames(true_t, true_f, est_t, est_f):
+ return (est_f != 0) != (true_f != 0)
+
+
+@_same_t_in_true_and_est
+def f0_frame_error(true_t, true_f, est_t, est_f):
+ gross_pitch_error_frames = _gross_pitch_error_frames(
+ true_t, true_f, est_t, est_f
+ )
+ voicing_decision_error_frames = _voicing_decision_error_frames(
+ true_t, true_f, est_t, est_f
+ )
+ return (np.sum(gross_pitch_error_frames) +
+ np.sum(voicing_decision_error_frames)) / (len(true_t))
+
+
+@_same_t_in_true_and_est
+def voicing_decision_error(true_t, true_f, est_t, est_f):
+ voicing_decision_error_frames = _voicing_decision_error_frames(
+ true_t, true_f, est_t, est_f
+ )
+ return np.sum(voicing_decision_error_frames) / (len(true_t))
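+
+
+# A minimal, self-contained usage sketch for the pitch metrics above. The toy
+# arrays are hypothetical and only illustrate the expected inputs: frame
+# timestamps plus per-frame F0 values, with 0 marking unvoiced frames.
+if __name__ == "__main__":
+    t = np.arange(5, dtype=float) * 0.01  # frame timestamps in seconds
+    true_f0 = np.array([0.0, 100.0, 100.0, 100.0, 0.0])
+    est_f0 = np.array([0.0, 100.0, 130.0, 101.0, 0.0])
+    # frame 2 deviates by 30% (> 20%), so GPE = 1/3 of the voiced frames
+    print("GPE:", gross_pitch_error(t, true_f0, t, est_f0))
+    print("VDE:", voicing_decision_error(t, true_f0, t, est_f0))
+    print("FFE:", f0_frame_error(t, true_f0, t, est_f0))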
diff --git a/fairseq/examples/speech_text_joint_to_text/README.md b/fairseq/examples/speech_text_joint_to_text/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e071d241e0e02b35d3aac777ac09b4ef3be9119f
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/README.md
@@ -0,0 +1,46 @@
+# Joint Speech Text training in Fairseq
+An extension of the Fairseq S2T project in which the speech-to-text task is enhanced by a co-trained text-to-text mapping task. More details about Fairseq S2T can be found [here](../speech_to_text/README.md)
+
+## Examples
+Examples of joint speech-text training in fairseq:
+- [English-to-German MuST-C model](docs/ende-mustc.md)
+- [IWSLT 2021 Multilingual Speech Translation](docs/iwslt2021.md)
+
+## Citation
+Please cite as:
+```
+@inproceedings{Tang2021AGM,
+ title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks},
+ author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel},
+ booktitle={ICASSP},
+ year={2021}
+}
+
+@inproceedings{Tang2021IST,
+ title = {Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task},
+ author = {Yun Tang and Juan Pino and Xian Li and Changhan Wang and Dmitriy Genzel},
+ booktitle = {ACL},
+ year = {2021},
+}
+
+@inproceedings{Tang2021FST,
+ title = {FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task},
+ author = {Yun Tang and Hongyu Gong and Xian Li and Changhan Wang and Juan Pino and Holger Schwenk and Naman Goyal},
+ booktitle = {IWSLT},
+ year = {2021},
+}
+
+@inproceedings{wang2020fairseqs2t,
+ title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
+ author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
+ booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
+ year = {2020},
+}
+
+@inproceedings{ott2019fairseq,
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+ year = {2019},
+}
+```
diff --git a/fairseq/examples/speech_text_joint_to_text/__init__.py b/fairseq/examples/speech_text_joint_to_text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..239d2e69f9a235095dee1ea7b3a94164a77273f5
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import tasks, criterions, models # noqa
diff --git a/fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list b/fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list
new file mode 100644
index 0000000000000000000000000000000000000000..02eeac4e009f77b765004272f59a1618214da18d
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list
@@ -0,0 +1,49 @@
+"(Applause) NOISE
+"(Laughter) VOICE
+"(Laughter)" VOICE
+(Applause) NOISE
+(Applause). NOISE
+(Audience) VOICE
+(Audio) NOISE
+(Beat) NOISE
+(Beatboxing) VOICE
+(Beep) NOISE
+(Beeps) NOISE
+(Cheering) VOICE
+(Cheers) VOICE
+(Claps) NOISE
+(Clicking) NOISE
+(Clunk) NOISE
+(Coughs) NOISE
+(Drums) NOISE
+(Explosion) NOISE
+(Gasps) VOICE
+(Guitar) NOISE
+(Honk) NOISE
+(Laugher) VOICE
+(Laughing) VOICE
+(Laughs) VOICE
+(Laughter) VOICE
+(Laughter). VOICE
+(Laughter)... VOICE
+(Mumbling) VOICE
+(Music) NOISE
+(Noise) NOISE
+(Recording) VOICE
+(Ringing) NOISE
+(Shouts) VOICE
+(Sigh) VOICE
+(Sighs) VOICE
+(Silence) NOISE
+(Singing) VOICE
+(Sings) VOICE
+(Spanish) VOICE
+(Static) NOISE
+(Tones) NOISE
+(Trumpet) NOISE
+(Video) NOISE
+(Video): NOISE
+(Voice-over) NOISE
+(Whistle) NOISE
+(Whistling) NOISE
+(video): NOISE
diff --git a/fairseq/examples/speech_text_joint_to_text/criterions/__init__.py b/fairseq/examples/speech_text_joint_to_text/criterions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7faae73119321af0b34fe8e26499a2ef5577291a
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/criterions/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+import os
+
+
+for file in os.listdir(os.path.dirname(__file__)):
+ if file.endswith(".py") and not file.startswith("_"):
+ criterion_name = file[: file.find(".py")]
+ importlib.import_module(
+ "examples.speech_text_joint_to_text.criterions." + criterion_name
+ )
diff --git a/fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py b/fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d356e5a10241716b58a5bc04a9d204a72553ff8
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py
@@ -0,0 +1,223 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import math
+
+import torch
+import torch.nn.functional as F
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
+from fairseq import metrics, utils
+
+
+@register_criterion("guided_label_smoothed_cross_entropy_with_accuracy")
+class GuidedCrossEntAccCriterion(FairseqCriterion):
+ def __init__(
+ self,
+ task,
+ sentence_avg,
+ guide_alpha,
+ text_input_cost_ratio,
+ label_smoothing,
+ disable_text_guide_update_num=0,
+ attentive_cost_regularization=0,
+ ):
+ """
+ guide_alpha: alpha to inteplate nll and kd loss
+ text_input_cost_ratio: loss ratio for text only input data
+ label_smoothing: label smoothing ratio
+ disable_text_guide_update_num: only use nll loss for the first N updates
+ attentive_cost_regularization: ratio fo attentive cost
+ """
+ super().__init__(task)
+ self.alpha = guide_alpha
+ self.attn_beta = attentive_cost_regularization
+ self.sentence_avg = sentence_avg
+ self.eps = label_smoothing
+ self.text_input_cost_ratio = text_input_cost_ratio
+ self.disable_update_num = disable_text_guide_update_num
+ assert self.alpha >= 0 and self.alpha <= 1.0
+
+ @staticmethod
+ def add_args(parser):
+ """Add criterion-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
+                            help='epsilon for label smoothing, 0 means no label smoothing')
+        parser.add_argument('--guide-alpha', default=0., type=float, metavar='D',
+                            help='alpha used to merge the KD cost (from text to speech input) with the CE loss')
+        parser.add_argument('--disable-text-guide-update-num', default=0, type=int, metavar='D',
+                            help='disable the guided target from text for the first N updates')
+        parser.add_argument("--attentive-cost-regularization", default=0.0, type=float, metavar='D',
+                            help="use encoder attentive loss regularization with cost ratio D")
+        parser.add_argument("--attentive-cost-without-normalize", action='store_true',
+                            help="do not normalize during attentive cost computation")
+        # fmt: on
+
+ def forward(self, model, sample, reduce=True):
+ reduction = 'sum' if reduce else 'none'
+ net_input = sample["net_input"]
+ net_output = model(**net_input)
+ attn_cost = None
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
+        is_dual_input = net_input['src_tokens'] is not None and net_input.get('src_txt_tokens') is not None
+ target = model.get_targets(sample, net_output)
+ src_token_num = 0
+ if is_dual_input:
+ # lprobs_spch from speech encoder and lprobs_text from text encoder
+ lprobs_spch, lprobs_text = torch.chunk(lprobs, 2)
+ lprobs_spch.batch_first = lprobs.batch_first
+ lprobs_text.batch_first = lprobs.batch_first
+
+ speech_loss, speech_nll_loss, speech_correct, speech_total = \
+ self.guide_loss_and_acc(model, lprobs_spch, lprobs_text, target, reduce=(reduction == 'sum'))
+ text_loss, text_nll_loss, text_correct, text_total = self.compute_loss_and_acc(model, lprobs_text, target, reduction=reduction)
+ loss = (speech_loss + text_loss)
+ nll_loss = (speech_nll_loss + text_nll_loss)
+ correct = speech_correct + text_correct
+ total = speech_total + text_total
+
+ attn_cost = net_output[1].get('attn_cost')
+ if attn_cost is not None:
+ # attn_cost is batch_first and padding tokens have been masked already
+ src_token_num = attn_cost.ne(0).sum()
+ attn_cost = attn_cost.sum()
+ loss = loss + attn_cost * self.attn_beta
+ else:
+ attn_cost = 0
+ else:
+ loss, nll_loss, correct, total = self.compute_loss_and_acc(model, lprobs, target, reduction=reduction)
+ if sample["net_input"]['src_tokens'] is None: # text input only
+ loss = loss * self.text_input_cost_ratio
+ speech_loss = None
+ speech_nll_loss = None
+
+ sample_size, logging_output = self.get_logging_output(
+ sample, loss, nll_loss, correct, total, src_token_num, speech_loss, speech_nll_loss, attn_cost, is_dual_input
+ )
+ return loss, sample_size, logging_output
+
+ def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
+ if not lprobs.batch_first:
+ lprobs = lprobs.transpose(0, 1)
+ lprobs = lprobs.view(-1, lprobs.size(-1)) # -> (B x T) x C
+ target = target.view(-1)
+ loss, nll_loss = label_smoothed_nll_loss(
+ lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'),
+ )
+
+ mask = target.ne(self.padding_idx)
+ correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
+ total = torch.sum(mask)
+ return loss, nll_loss, correct, total
+
+ def guide_loss_and_acc(self, model, lprobs, lprobs_teacher, target, reduce=True):
+ """ lprobs_teacher is used as guide for lprobs """
+ if self.alpha == 0.0 or model.num_updates < self.disable_update_num:
+ return self.compute_loss_and_acc(model, lprobs, target, reduction=('sum' if reduce else 'none'))
+ if not lprobs.batch_first:
+ lprobs = lprobs.transpose(0, 1)
+ lprobs_teacher = lprobs_teacher.transpose(0, 1)
+
+ lprobs = lprobs.view(-1, lprobs.size(-1)).float() # -> (B x T) x C
+ lprobs_teacher = lprobs_teacher.view(-1, lprobs_teacher.size(-1)).float() # -> (B x T) x C
+ target = target.view(-1)
+ loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum' if reduce else 'none')
+ nll_loss = loss
+ probs_teacher = lprobs_teacher.exp().masked_fill_(target.unsqueeze(-1).eq(self.padding_idx), 0)
+ probs_teacher = probs_teacher.detach()
+ guide_loss = -(probs_teacher*lprobs).sum() if reduce else -(probs_teacher*lprobs).sum(-1, keepdim=True)
+ loss = self.alpha*guide_loss + (1.0 - self.alpha)*loss
+
+ mask = target.ne(self.padding_idx)
+ correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
+ total = torch.sum(mask)
+ return loss, nll_loss, correct, total
+
+ def get_logging_output(
+ self,
+ sample,
+ loss,
+ nll_loss,
+ correct,
+ total,
+ src_token_num=0,
+ speech_loss=None,
+ speech_nll_loss=None,
+ attn_cost=None,
+ is_dual_input=False,
+ ):
+
+ sample_size = (
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+ )
+ mul_size = 2 if is_dual_input else 1
+
+ logging_output = {
+ "loss": utils.item(loss.data), # * sample['ntokens'],
+ "nll_loss": utils.item(nll_loss.data), # * sample['ntokens'],
+ "ntokens": sample["ntokens"]*mul_size,
+ "nsentences": sample["target"].size(0)*mul_size,
+ "sample_size": sample_size*mul_size,
+ "correct": utils.item(correct.data),
+ "total": utils.item(total.data),
+ "src_token_num": utils.item(src_token_num.data) if src_token_num > 0 else 0,
+ "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
+ }
+
+ if speech_loss is not None:
+ logging_output["speech_loss"] = utils.item(speech_loss.data)
+ logging_output["speech_nll_loss"] = utils.item(speech_nll_loss.data)
+ logging_output["sample_size_speech_cost"] = sample_size
+ logging_output["speech_attn_loss"] = attn_cost
+
+ return sample_size*mul_size, logging_output
+
+ @staticmethod
+ def aggregate_logging_outputs(logging_outputs):
+ """Aggregate logging outputs from data parallel training."""
+ correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
+ total_sum = sum(log.get("total", 0) for log in logging_outputs)
+ src_token_sum = sum(log.get("src_token_num", 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+ nframes = sum(log.get("nframes", 0) for log in logging_outputs)
+ speech_loss_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
+ speech_nll_loss_sum = sum(log.get("speech_nll_loss", 0) for log in logging_outputs)
+ speech_attn_loss_sum = sum(log.get("speech_attn_loss", 0) for log in logging_outputs)
+ sample_size_speech = sum(log.get("sample_size_speech_cost", 0) for log in logging_outputs)
+
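+        # losses below are converted from nats to bits (divided by ln 2) for logging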
+ agg_output = {
+ "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
+ "nll_loss": nll_loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
+ # if args.sentence_avg, then sample_size is nsentences, and loss
+ # is per-sentence loss; else sample_size is ntokens, and the loss
+ # becomes per-output token loss
+ "speech_loss": speech_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
+ "speech_nll_loss": speech_nll_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
+ "speech_attn_loss": speech_attn_loss_sum / src_token_sum / math.log(2) if src_token_sum > 0 else 0.0,
+ "ntokens": ntokens,
+ "nsentences": nsentences,
+ "nframes": nframes,
+ "sample_size": sample_size,
+ "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
+ "correct": correct_sum,
+ "total": total_sum,
+ "src_token_num": src_token_sum,
+ # total is the number of validate tokens
+ }
+ return agg_output
+
+ @classmethod
+ def reduce_metrics(cls, logging_outputs):
+ """Aggregate logging outputs from data parallel training."""
+ agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
+ for k, v in agg_logging_outputs.items():
+ if k in {'nsentences', 'ntokens', 'sample_size'}:
+ continue
+ metrics.log_scalar(k, v, round=3)
diff --git a/fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md b/fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md
new file mode 100644
index 0000000000000000000000000000000000000000..2897c4e27b053d4fd65b37fb7e586679dffed1ba
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md
@@ -0,0 +1,112 @@
+[[Back]](..)
+
+# Joint Speech Text Training for the MuST-C English to German Speech Translation task
+
+Joint Training Baseline: based on the paper ["A general multi-task learning framework to leverage text data for speech to text tasks"](https://arxiv.org/pdf/2010.11338.pdf)
+
+Enhanced Joint Training: the joint training is enhanced with pre-trained models, cross attentive regularization and online knowledge distillation, based on the paper ["Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"](https://research.fb.com/publications/improving-speech-translation-by-understanding-and-learning-from-the-auxiliary-text-translation-task)
+
+## Prepare Data
+#### Download files
+- SentencePiece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/spm.model)
+- Dictionary [dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/dict.txt)
+- Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/config.yaml)
+#### Prepare MuST-C data set
+- [Please follow the data preparation in the S2T example](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mustc_example.md)
+- Generate the phoneme representation of the source text (src_text):
+```bash
+ python examples/speech_text_joint_to_text/scripts/g2p_encode.py \
+ --lower-case --do-filter --use-word-start --no-punc \
+ --reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \
+ --data-path ${must_c_en_de_src_text} \
+ --out-path ${must_c_en_de_src_text_pho}
+```
+- Update the TSV data with the src_text generated above and save it to $MANIFEST_ROOT (a sketch is shown after this list)
+- Prepare phoneme dictionary and save to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt)
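+
+A minimal sketch of the TSV update step above, assuming the original manifest has no src_text column yet and the g2p output contains one phonemized line per manifest row (file names here are placeholders):
+```python
+import csv
+
+tsv_in = "train_st.tsv"            # S2T manifest from the previous step
+pho_in = "train_src_text_pho.txt"  # g2p_encode.py output, one line per utterance
+tsv_out = "train_st_pho.tsv"
+
+with open(tsv_in) as f_tsv, open(pho_in) as f_pho, open(tsv_out, "w", newline="") as f_out:
+    reader = csv.DictReader(f_tsv, delimiter="\t")
+    writer = csv.DictWriter(f_out, fieldnames=reader.fieldnames + ["src_text"], delimiter="\t")
+    writer.writeheader()
+    for row, pho in zip(reader, f_pho):
+        row["src_text"] = pho.strip()  # phonemized source transcript
+        writer.writerow(row)
+```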
+#### Prepare WMT text data
+- [Download WMT data](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh)
+- Convert the source text (English) into its phoneme representation as above
+- Generate the binary parallel files for training (as in the translation example) and save the data in $parallel_text_data
+
+## Training
+The model is trained on 8 V100 GPUs.
+
+#### Download pretrained models
+- [pretrain_encoder](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_asr_transformer_m.pt)
+- [pretrain_nmt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_mt.pt)
+
+#### Training scripts
+- Jointly trained model from scratch
+```bash
+python train.py ${MANIFEST_ROOT} \
+ --save-dir ${save_dir} \
+ --num-workers 8 \
+ --task speech_text_joint_to_text \
+ --arch dualinputs2ttransformer_s \
+ --user-dir examples/speech_text_joint_to_text \
+ --max-epoch 100 --update-mix-data \
+ --optimizer adam --lr-scheduler inverse_sqrt \
+ --lr 0.001 --update-freq 4 --clip-norm 10.0 \
+ --criterion guided_label_smoothed_cross_entropy_with_accuracy \
+ --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
+ --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
+ --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
+ --dropout 0.1 --warmup-updates 20000 \
+ --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
+ --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
+ --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
+ --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
+ --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
+ --keep-last-epochs 10
+```
+- Jointly trained model with good initialization, cross attentive loss and online knowledge distillation
+```bash
+python train.py ${MANIFEST_ROOT} \
+ --save-dir ${save_dir} \
+ --num-workers 8 \
+ --task speech_text_joint_to_text \
+ --arch dualinputs2ttransformer_m \
+ --user-dir examples/speech_text_joint_to_text \
+ --max-epoch 100 --update-mix-data \
+ --optimizer adam --lr-scheduler inverse_sqrt \
+ --lr 0.002 --update-freq 4 --clip-norm 10.0 \
+ --criterion guided_label_smoothed_cross_entropy_with_accuracy \
+ --guide-alpha 0.8 --disable-text-guide-update-num 5000 \
+ --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
+ --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
+ --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
+ --dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
+ --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
+ --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
+ --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
+ --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
+ --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
+ --load-pretrain-speech-encoder ${pretrain_encoder} \
+ --load-pretrain-decoder ${pretrain_nmt} \
+ --load-pretrain-text-encoder-last ${pretrain_nmt} \
+ --keep-last-epochs 10
+```
+
+## Evaluation
+```bash
+python ./fairseq_cli/generate.py \
+ ${MANIFEST_ROOT} \
+ --task speech_text_joint_to_text \
+ --max-tokens 25000 \
+ --nbest 1 \
+ --results-path ${infer_results} \
+ --batch-size 512 \
+ --path ${model} \
+ --gen-subset tst-COMMON \
+ --config-yaml config_spm.yaml \
+ --scoring sacrebleu \
+ --beam 5 --lenpen 1.0 \
+ --user-dir examples/speech_text_joint_to_text \
+ --load-speech-only
+```
+
+## Results (Joint training with initialization + CAR + online KD)
+|Direction|En-De | En-Es | En-Fr |
+|---|---|---|---|
+|BLEU|27.4| 31.2 | 37.6 |
+|checkpoint | [link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_ave_10.pt) |[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_es/checkpoint_ave_10.pt)|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_fr/checkpoint_ave_10.pt)|
diff --git a/fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md b/fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md
new file mode 100644
index 0000000000000000000000000000000000000000..920ff271c2e178c7a4ca3c7c8ce57a2f28653969
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md
@@ -0,0 +1,76 @@
+[[Back]](..)
+
+# Joint Speech Text Training for the 2021 IWSLT multilingual speech translation
+
+This directory contains the code from the paper ["FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task"](https://arxiv.org/pdf/2107.06959.pdf).
+
+## Prepare Data
+#### Download files
+- SentencePiece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/spm.model)
+- Dictionary [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/dict.txt)
+- Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/config.yaml)
+
+#### Prepare
+- [Please follow the data preparation in speech-to-text](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mtedx_example.md)
+
+
+
+## Training
+
+#### Download pretrained models
+- [Pretrained mbart model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/mbart.pt)
+- [Pretrained w2v model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/xlsr_53_56k.pt)
+
+
+#### Training scripts
+
+```bash
+python train.py ${MANIFEST_ROOT} \
+ --save-dir ${save_dir} \
+ --user-dir examples/speech_text_joint_to_text \
+ --train-subset train_es_en_tedx,train_es_es_tedx,train_fr_en_tedx,train_fr_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_en_tedx,train_pt_pt_tedx \
+ --valid-subset valid_es_en_tedx,valid_es_es_tedx,valid_es_fr_tedx,valid_es_it_tedx,valid_es_pt_tedx,valid_fr_en_tedx,valid_fr_es_tedx,valid_fr_fr_tedx,valid_fr_pt_tedx,valid_it_en_tedx,valid_it_es_tedx,valid_it_it_tedx,valid_pt_en_tedx,valid_pt_es_tedx,valid_pt_pt_tedx \
+ --config-yaml config.yaml --ddp-backend no_c10d \
+ --num-workers 2 --task speech_text_joint_to_text \
+ --criterion guided_label_smoothed_cross_entropy_with_accuracy \
+ --label-smoothing 0.3 --guide-alpha 0.8 \
+ --disable-text-guide-update-num 5000 --arch dualinputxmtransformer_base \
+ --max-tokens 500000 --max-sentences 3 --max-tokens-valid 800000 \
+ --max-source-positions 800000 --enc-grad-mult 2.0 \
+ --attentive-cost-regularization 0.02 --optimizer adam \
+ --clip-norm 1.0 --log-format simple --log-interval 200 \
+ --keep-last-epochs 5 --seed 1 \
+ --w2v-path ${w2v_path} \
+ --load-pretrained-mbart-from ${mbart_path} \
+ --max-update 1000000 --update-freq 4 \
+ --skip-invalid-size-inputs-valid-test \
+ --skip-encoder-projection --save-interval 1 \
+ --attention-dropout 0.3 --mbart-dropout 0.3 \
+ --finetune-w2v-params all --finetune-mbart-decoder-params all \
+ --finetune-mbart-encoder-params all --stack-w2v-mbart-encoder \
+ --drop-w2v-layers 12 --normalize \
+ --lr 5e-05 --lr-scheduler inverse_sqrt --warmup-updates 5000
+```
+
+## Evaluation
+```bash
+python ./fairseq_cli/generate.py \
+ ${MANIFEST_ROOT} \
+ --task speech_text_joint_to_text \
+ --user-dir ./examples/speech_text_joint_to_text \
+ --load-speech-only --gen-subset test_es_en_tedx \
+ --path ${model} \
+ --max-source-positions 800000 \
+ --skip-invalid-size-inputs-valid-test \
+ --config-yaml config.yaml \
+ --infer-target-lang en \
+ --max-tokens 800000 \
+ --beam 5 \
+ --results-path ${RESULTS_DIR} \
+ --scoring sacrebleu
+```
+The trained model can be downloaded [here](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/checkpoint17.pt).
+
+|direction|es_en|fr_en|pt_en|it_en|fr_es|pt_es|it_es|es_es|fr_fr|pt_pt|it_it|
+|---|---|---|---|---|---|---|---|---|---|---|---|
+|BLEU|31.62|36.93|35.07|27.12|38.87|35.57|34.13|74.59|74.64|70.84|69.76|
diff --git a/fairseq/examples/speech_text_joint_to_text/models/__init__.py b/fairseq/examples/speech_text_joint_to_text/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a394c7e4f25bfef8603596ca3629e65ca7b0d8b
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/models/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+import os
+
+for file in os.listdir(os.path.dirname(__file__)):
+ if file.endswith(".py") and not file.startswith("_"):
+ model_name = file[: file.find(".py")]
+ importlib.import_module(
+ "examples.speech_text_joint_to_text.models." + model_name
+ )
diff --git a/fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py b/fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7970a3c71401b4835ba09158ea06134418afa065
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py
@@ -0,0 +1,1090 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+from fairseq import checkpoint_utils
+from fairseq import utils
+from fairseq.models import (
+ FairseqEncoder,
+ FairseqDecoder,
+ FairseqEncoderDecoderModel,
+ register_model,
+ register_model_architecture,
+)
+from fairseq.models.fairseq_encoder import EncoderOut
+from fairseq.models.speech_to_text import (
+ TransformerDecoder,
+ S2TTransformerEncoder,
+)
+from fairseq.models.transformer import TransformerEncoder
+from fairseq.modules import (
+ TransformerEncoderLayer,
+ GradMultiply,
+ LayerNorm,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SpeechEoSEncoder(FairseqEncoder):
+ def __init__(self, encoder, eos_num, feat_dim, adapter_type="None", adapter_dim=0):
+ super().__init__(None)
+ self.encoder = encoder
+        self.eos_num = eos_num  # number of EOS frames appended (equal to the conv subsampling rate)
+ self.eos_emb = (
+ nn.Parameter(torch.zeros(1, feat_dim), requires_grad=True)
+ if eos_num > 0
+ else None
+ )
+ self.adapter = self.add_adapter(adapter_type, adapter_dim)
+
+ def add_adapter(self, adapter_type, adapter_dim):
+ def _make_identity(linear, eps=1e-5):
+ assert isinstance(linear, nn.Linear)
+ linear.weight.data.mul_(eps)
+ linear.weight.data.fill_diagonal_(1.0)
+ if linear.bias is not None:
+ linear.bias.data.mul_(eps)
+
+ adapter = None
+ if adapter_type == "Linear":
+ assert adapter_dim > 0
+ adapter = nn.Sequential(
+ nn.Linear(adapter_dim, adapter_dim), LayerNorm(adapter_dim)
+ )
+ # initialize the adapter as identity matrix first
+ _make_identity(adapter[0])
+
+ elif adapter_type == "MLP":
+ assert adapter_dim > 0
+ # assume the model is pre-norm model
+ adapter = nn.Sequential(
+ nn.Linear(adapter_dim, 2 * adapter_dim),
+ nn.ReLU(),
+ nn.Linear(2 * adapter_dim, adapter_dim),
+ LayerNorm(adapter_dim),
+ )
+ _make_identity(adapter[0])
+ _make_identity(adapter[2])
+ return adapter
+
+ def add_eos(self, src_tokens, src_lengths):
+ bsz, max_seq_len, fdim = src_tokens.size()
+ if self.eos_num > 0:
+ src_token_eos = torch.zeros(
+ [bsz, max_seq_len + self.eos_num, fdim],
+ dtype=src_tokens.dtype,
+ device=src_tokens.device,
+ )
+ src_token_eos[:, :max_seq_len] = src_tokens
+ for bi in range(bsz):
+ src_token_eos[bi][
+ src_lengths[bi] : src_lengths[bi] + self.eos_num
+ ] = self.eos_emb.expand(self.eos_num, fdim)
+ src_lengths = src_lengths + self.eos_num
+ src_tokens = src_token_eos
+ return src_tokens, src_lengths
+
+ def apply_adapter(self, enc_out):
+ if self.adapter is None:
+ return enc_out
+ rst = self.adapter(enc_out.encoder_out)
+ if enc_out.encoder_padding_mask is not None:
+ rst.masked_fill_(
+ enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0
+ )
+ return EncoderOut(
+ encoder_out=rst,
+ encoder_padding_mask=enc_out.encoder_padding_mask,
+ encoder_embedding=enc_out.encoder_embedding,
+ encoder_states=enc_out.encoder_states,
+ src_tokens=enc_out.src_tokens,
+ src_lengths=enc_out.src_lengths,
+ )
+
+ def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
+ """
+ src_tokens: padded tensor (B, T, C * feat)
+ src_lengths: tensor of original lengths of input utterances (B,)
+ """
+ src_tokens, src_lengths = self.add_eos(src_tokens, src_lengths)
+ enc_out = self.encoder(src_tokens, src_lengths, return_all_hiddens)
+ enc_out = self.apply_adapter(enc_out)
+ return enc_out
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ return self.encoder.reorder_encoder_out(encoder_out, new_order)
+
+
+class DualInputEncoder(FairseqEncoder):
+ def __init__(
+ self,
+ args,
+ spch_encoder,
+ text_encoder,
+ dictionary,
+ cross_attentive_loss_before_last_layer=-1,
+ ):
+ super().__init__(dictionary)
+
+ self.spch_encoder = spch_encoder
+ self.text_encoder = text_encoder
+ self.enc_grad_mult = args.enc_grad_mult
+ self.cross_attentive_loss_before_last_layer = (
+ cross_attentive_loss_before_last_layer
+ )
+        self.use_cross_attentive_loss = (
+            cross_attentive_loss_before_last_layer > -1
+        )
+ self.enc2_along_grad_mult = args.enc2_along_grad_mult
+
+ @classmethod
+ def set_shared_layer(cls, share_level, src_layer, tgt_layer):
+ """
+        Share parameters from tgt_layer into src_layer.
+        share_level:
+            0: share everything (src_layer is replaced by tgt_layer)
+            1: share all sub-modules, but keep separate layer objects
+            2: share weights, but not biases or layer norms
+ """
+ if share_level == 0:
+ return tgt_layer
+ if isinstance(src_layer, nn.Linear):
+ return tgt_layer
+ if isinstance(src_layer, TransformerEncoderLayer):
+ assert src_layer.embed_dim == tgt_layer.embed_dim
+ assert src_layer.normalize_before == tgt_layer.normalize_before
+ if share_level == 1:
+ src_layer.fc1 = tgt_layer.fc1
+ src_layer.fc2 = tgt_layer.fc2
+ src_layer.self_attn = tgt_layer.self_attn
+ src_layer.final_layer_norm = tgt_layer.final_layer_norm
+ src_layer.self_attn_layer_norm = tgt_layer.self_attn_layer_norm
+ src_layer.layernorm_embedding = tgt_layer.layernorm_embedding
+ else:
+ src_layer.fc1.weight = tgt_layer.fc1.weight
+ src_layer.fc2.weight = tgt_layer.fc2.weight
+ src_layer.self_attn.k_proj.weight = tgt_layer.self_attn.k_proj.weight
+ src_layer.self_attn.v_proj.weight = tgt_layer.self_attn.v_proj.weight
+ src_layer.self_attn.q_proj.weight = tgt_layer.self_attn.q_proj.weight
+ src_layer.self_attn.out_proj.weight = (
+ tgt_layer.self_attn.out_proj.weight
+ )
+ else:
+ if share_level == 1:
+ return tgt_layer
+ return src_layer
+
+ @classmethod
+ def build_spch_encoder(cls, args):
+ cfg = {
+ "input_feat_per_channel": args.input_feat_per_channel,
+ "input_channels": args.input_channels,
+ "conv_kernel_sizes": args.conv_kernel_sizes,
+ "conv_channels": args.conv_channels,
+ "encoder_embed_dim": args.encoder_embed_dim,
+ "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
+ "encoder_layers": args.speech_encoder_layers,
+ "encoder_layerdrop": args.encoder_layerdrop,
+ "encoder_attention_heads": args.encoder_attention_heads,
+ "max_source_positions": args.max_source_positions,
+ "dropout": args.dropout,
+ "encoder_normalize_before": args.encoder_normalize_before,
+ "activation_dropout": args.activation_dropout,
+ "attention_dropout": args.attention_dropout,
+ "activation_fn": args.activation_fn,
+ "layernorm_embedding": args.layernorm_embedding,
+ "no_token_positional_embeddings": args.no_token_positional_embeddings,
+ "no_scale_embedding": args.no_scale_embedding,
+ "quant_noise_pq": args.quant_noise_pq,
+ "encoder_freezing_updates": 0,
+ }
+ model_args = namedtuple("args", cfg.keys())(*cfg.values())
+ spch_encoder = S2TTransformerEncoder(model_args)
+ if args.add_speech_eos:
+ spch_encoder = SpeechEoSEncoder(
+ spch_encoder,
+ 2 * len(args.conv_kernel_sizes.split(",")),
+ args.input_feat_per_channel,
+ adapter_type=getattr(args, "speech_encoder_adapter_type", "None"),
+ adapter_dim=args.encoder_embed_dim,
+ )
+ return spch_encoder
+
+ @classmethod
+ def build_text_encoder(cls, args, src_dictionary, spch_encoder):
+ if args.encoder_shared_layers > 0:
+ mx_shared_layers = (
+ args.speech_encoder_layers
+ if args.speech_encoder_layers < args.text_encoder_layers
+ else args.text_encoder_layers
+ )
+ args.encoder_shared_layers = (
+ args.encoder_shared_layers
+ if args.encoder_shared_layers <= mx_shared_layers
+ else mx_shared_layers
+ )
+ cfg = {
+ "encoder_embed_dim": args.encoder_text_embed_dim,
+ "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
+ "encoder_layers": args.text_encoder_layers,
+ "encoder_layerdrop": args.encoder_layerdrop,
+ "encoder_attention_heads": args.encoder_attention_heads,
+ "encoder_learned_pos": args.encoder_learned_pos,
+ "max_source_positions": args.max_source_positions,
+ "dropout": args.dropout,
+ "encoder_normalize_before": args.encoder_normalize_before,
+ "activation_dropout": args.activation_dropout,
+ "attention_dropout": args.attention_dropout,
+ "activation_fn": args.activation_fn,
+ "adaptive_input": args.adaptive_input,
+ "no_token_positional_embeddings": args.no_token_positional_embeddings,
+ "no_scale_embedding": args.no_scale_embedding,
+ "quant_noise_pq": args.quant_noise_pq,
+ }
+ model_args = namedtuple("args", cfg.keys())(*cfg.values())
+ enc_emb = nn.Embedding(
+ len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
+ )
+ text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
+ if args.add_speech_eos:
+ spch_encoder = spch_encoder.encoder
+ if args.encoder_shared_layers > 0:
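+            # Tie the top `encoder_shared_layers` layers (and the final layer norm)
+            # of the text encoder to the corresponding top layers of the speech
+            # encoder, at the level selected by --encoder-shared-layer-level.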
+ text_encoder.layer_norm = cls.set_shared_layer(
+ args.encoder_shared_layer_level,
+ text_encoder.layer_norm,
+ spch_encoder.layer_norm,
+ )
+ for i, ly in enumerate(
+ spch_encoder.transformer_layers[-args.encoder_shared_layers :]
+ ):
+ ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
+ assert isinstance(text_encoder.layers[ly_id], type(ly))
+ text_encoder.layers[ly_id] = cls.set_shared_layer(
+ args.encoder_shared_layer_level,
+ text_encoder.layers[ly_id],
+ ly,
+ )
+ return text_encoder
+
+ def mult_rst_grad(self, rst, ratio):
+ assert isinstance(rst, dict) # instead of EncoderOut
+ assert len(rst["encoder_out"]) == 1
+ rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio)
+ return rst
+
+ def process_attentive_loss_states(self, rst, interstates):
+ assert isinstance(rst, dict) # instead of EncoderOut
+ rst["encoder_states"] = interstates
+ return rst
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths=None,
+ src_txt_tokens=None,
+ src_txt_lengths=None,
+ **kwargs
+ ):
+ """
+ Args:
+ src_tokens: padded tensor (B, T, C * feat)
+ src_lengths: tensor of original lengths of input utterances (speech) (B,)
+ src_txt_tokens: padded tensor (B, T)
+ src_txt_lengths: tensor of original lengths of input utterances (text) (B,)
+ """
+ # src_tokens only: inference
+ # src_tokens, src_lengths: speech only training
+ # src_txt_tokens, src_txt_lengths: text only training
+ # all valid: speech + text training
+
+ if src_tokens is None and src_txt_tokens is None:
+ raise ValueError(
+ "src_tokens and src_txt_tokens cannot be None at the same time"
+ )
+ ret1 = None
+ ret2 = None
+ return_all_hiddens = False
+ if src_tokens is not None:
+ if (
+ self.use_cross_attentive_loss and src_txt_tokens is not None
+ ): # remove self.training so we can get attn score during validation step
+ return_all_hiddens = True
+ ret1 = self.spch_encoder(
+ src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
+ )
+
+ if self.use_cross_attentive_loss and src_txt_tokens is not None:
+ assert self.cross_attentive_loss_before_last_layer < len(
+ ret1["encoder_states"]
+ )
+ ret1 = self.process_attentive_loss_states(
+ ret1,
+ ret1["encoder_states"][
+ -self.cross_attentive_loss_before_last_layer - 1
+ ],
+ )
+
+ if src_txt_tokens is not None:
+ ret2 = self.text_encoder(
+ src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens
+ )
+ if return_all_hiddens:
+ if self.cross_attentive_loss_before_last_layer == len(
+ self.text_encoder.layers
+ ):
+ text_embedding, _ = self.text_encoder.forward_embedding(
+ src_txt_tokens
+ )
+ text_embedding = text_embedding.transpose(0, 1)
+ ret2 = self.process_attentive_loss_states(ret2, text_embedding)
+ else:
+ assert self.cross_attentive_loss_before_last_layer < len(
+ self.text_encoder.layers
+ )
+ ret2 = self.process_attentive_loss_states(
+ ret2,
+ ret2["encoder_states"][
+ -self.cross_attentive_loss_before_last_layer - 1
+ ],
+ )
+
+ def merge_output(rst1, rst2):
+ if rst1 is None:
+ if not (self.enc2_along_grad_mult == 1.0 or self.training):
+ rst2 = self.mult_rst_grad(rst2, self.enc2_along_grad_mult)
+ return rst2
+ if rst2 is None:
+ return rst1
+ if self.enc_grad_mult != 1.0 and self.training:
+ rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult)
+ rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult)
+ rst = (rst1, rst2)
+ return rst
+
+ return merge_output(ret1, ret2)
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ assert self.training is False # used for inference only
+ return self.spch_encoder.reorder_encoder_out(encoder_out, new_order)
+
+
+# TransformerMultiInputDecoder: take one or two encoder inputs
+class TransformerMultiInputDecoder(FairseqDecoder):
+ def __init__(
+ self,
+ dictionary,
+ spch_decoder,
+ text_decoder,
+ compute_cross_attentive_loss=False,
+ cross_attentive_loss_with_norm=True,
+ cross_attentive_loss_reverse=False,
+ ):
+
+ super().__init__(dictionary)
+ self.spch_decoder = spch_decoder
+ self.text_decoder = text_decoder
+ self.compute_cross_attentive_loss = compute_cross_attentive_loss
+ self.cross_attentive_loss_with_norm = cross_attentive_loss_with_norm
+ self.cross_attentive_loss_reverse = cross_attentive_loss_reverse
+
+ @classmethod
+ def share_spchdecoder(cls, task_args, text_decoder, spch_decoder):
+ if task_args.decoder_shared_layer_level == 0:
+ return text_decoder
+ assert text_decoder.embed_tokens == spch_decoder.embed_tokens
+ spch_decoder.project_in_dim = text_decoder.project_in_dim
+ spch_decoder.embed_positions = text_decoder.embed_positions
+ spch_decoder.layernorm_embedding = text_decoder.layernorm_embedding
+ spch_decoder.project_out_dim = text_decoder.project_out_dim
+ spch_decoder.adaptive_softmax = text_decoder.adaptive_softmax
+ if task_args.decoder_shared_layer_level == 1:
+ spch_decoder.output_projection = text_decoder.output_projection
+ spch_decoder.layer_norm = text_decoder.layer_norm
+ else: # 2
+ spch_decoder.output_projection.weight = (
+ text_decoder.output_projection.weight
+ )
+ for i, ly in enumerate(text_decoder.layers):
+ sly = spch_decoder.layers[i]
+ sly.self_attn = ly.self_attn
+ sly.self_attn_layer_norm = ly.self_attn_layer_norm
+ # sly.encoder_attn = ly.encoder_attn
+ if (
+ task_args.decoder_shared_layer_level == 1
+ ): # share everything, but under different models
+ sly.encoder_attn = ly.encoder_attn
+ sly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
+ sly.fc1 = ly.fc1
+ sly.fc2 = ly.fc2
+ sly.final_layer_norm = ly.final_layer_norm
+            else:  # decoder_shared_layer_level == 2: separate encoder_attn_layer_norm and bias
+ sly.encoder_attn.k_proj.weight = ly.encoder_attn.k_proj.weight
+ sly.encoder_attn.v_proj.weight = ly.encoder_attn.v_proj.weight
+ sly.encoder_attn.q_proj.weight = ly.encoder_attn.q_proj.weight
+ sly.encoder_attn.out_proj.weight = ly.encoder_attn.out_proj.weight
+ sly.fc1.weight = ly.fc1.weight
+ sly.fc2.weight = ly.fc2.weight
+
+ return spch_decoder
+
+ def cross_attentive_loss(
+ self, teacher_states, student_states, teacher_masking, student_masking, eps=1e-6
+ ):
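+        # Cross-attentive regularization (sketch of what is computed below):
+        # reconstruct the teacher states once from the student states and once from
+        # the (detached) teacher states, each with similarity-weighted attention,
+        # and penalize the L2 distance between the two reconstructions over
+        # non-padded teacher positions.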
+ x = teacher_states.transpose(0, 1) # from T X B X D to B X T X D
+ y = student_states.transpose(0, 1)
+ if self.cross_attentive_loss_with_norm:
+ x = x / (x.norm(dim=2, keepdim=True) + eps)
+ y = y / (y.norm(dim=2, keepdim=True) + eps)
+ dim = x.size(-1)
+ # lengths: batch X seqLen
+        sim_scores_xy = torch.bmm(x, y.transpose(1, 2))  # batch X lenx X leny
+ if y.dtype == torch.float16:
+ sim_scores_xy = sim_scores_xy.float()
+ y = y.float()
+ x = x.float()
+ if teacher_masking != []:
+ assert len(teacher_masking) == 1
+ sim_scores_xy = sim_scores_xy.masked_fill(
+ teacher_masking[0].unsqueeze(-1), float("-inf")
+ )
+ if student_masking != []:
+ sim_scores_xy = sim_scores_xy.masked_fill(
+ student_masking[0].unsqueeze(1), float("-inf")
+ )
+ # do masking
+ y_weights = utils.softmax(sim_scores_xy, dim=-1)
+ if teacher_masking != []:
+ y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
+ x_reconstruct_from_y = torch.bmm(y_weights, y)
+
+        sim_scores_xx = torch.bmm(x, x.transpose(1, 2))  # batch X lenx X lenx
+ x_weights = utils.softmax(sim_scores_xx, dim=-1)
+ if teacher_masking != []:
+ x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
+
+ # no gradient for teacher state
+ x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
+ cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
+ if teacher_masking != []:
+ cost = cost.masked_fill(teacher_masking[0], 0)
+
+ if not self.cross_attentive_loss_with_norm:
+ cost = cost / dim
+ return cost
+
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out,
+ incremental_state=None,
+ has_txt_input=False,
+ **kwargs
+ ):
+ """
+ Args:
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
+                `(batch, tgt_len)`, for input feeding/teacher forcing. If there are
+                two or more inputs during training, they share the same prev_output_tokens
+            encoder_out (tuple[Tensor]): output from the encoder, used for
+                encoder-side attention. It is a tuple if there is more than one input,
+                and a single encoder output otherwise
+            incremental_state ([dict]): dictionary used for storing state during
+                :ref:`Incremental decoding`. It is only valid for inference with a
+                single input
+ Returns:
+ tuple:
+ - the last decoder layer's output of shape `(batch, tgt_len,
+ vocab)`. If there are N inputs, batch will be N bigger than a single input
+ - the last decoder layer's attention weights of shape `(batch,
+ tgt_len, src_len)`
+ """
+ assert not isinstance(encoder_out, EncoderOut)
+        if isinstance(encoder_out, tuple):  # training with multiple inputs
+ rst = []
+ assert len(encoder_out) == 2
+ for i, eo in enumerate(encoder_out):
+ assert incremental_state is None
+ if i == 0:
+ rst.append(
+ self.spch_decoder(prev_output_tokens, eo, incremental_state)
+ )
+ else:
+ rst.append(
+ self.text_decoder(prev_output_tokens, eo, incremental_state)
+ )
+ dec_out = torch.cat([r[0] for r in rst], dim=0)
+ attn_cost = None
+ if self.compute_cross_attentive_loss:
+ assert isinstance(encoder_out[0], dict)
+ if self.cross_attentive_loss_reverse:
+ attn_cost = self.cross_attentive_loss(
+ teacher_states=encoder_out[1]["encoder_states"], # text_states
+ student_states=encoder_out[0]["encoder_states"], # spch_states
+ teacher_masking=encoder_out[1]["encoder_padding_mask"],
+ student_masking=encoder_out[0]["encoder_padding_mask"],
+ )
+ else:
+ attn_cost = self.cross_attentive_loss(
+ teacher_states=encoder_out[0]["encoder_states"], # spch_states
+ student_states=encoder_out[1]["encoder_states"], # text_states
+ teacher_masking=encoder_out[0]["encoder_padding_mask"],
+ student_masking=encoder_out[1]["encoder_padding_mask"],
+ )
+
+ return (dec_out, {"attn_cost": attn_cost})
+ else: # inference or training with one input
+ if has_txt_input:
+ return self.text_decoder(
+ prev_output_tokens, encoder_out, incremental_state
+ )
+ return self.spch_decoder(prev_output_tokens, encoder_out, incremental_state)
+
+
+# Note:
+# dual input transformer:
+# encoder: S2TTransformerEncoder for speech + TransformerEncoder for text
+# decoder: TransformerDecoder for text
+@register_model("dual_input_s2t_transformer")
+class DualInputS2TTransformerModel(FairseqEncoderDecoderModel):
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+ self.num_updates = 0
+
+ def max_positions(self):
+ return None # it is provided in task
+
+ @staticmethod
+ def add_args(parser):
+ """Add model-specific arguments to the parser."""
+ # encoder 1: S2TTransformerEncoder for speech
+ parser.add_argument(
+ "--conv-kernel-sizes",
+ type=str,
+ metavar="N",
+ help="kernel sizes of Conv1d subsampling layers",
+ )
+ parser.add_argument(
+ "--conv-channels",
+ type=int,
+ metavar="N",
+ help="# of channels in Conv1d subsampling layers",
+ )
+ parser.add_argument(
+ "--enc-output-dim",
+ type=int,
+ metavar="N",
+ help="""
+ encoder output dimension, can be None. If specified, projecting the
+ transformer output to the specified dimension""",
+ )
+ # standard Transformer
+ parser.add_argument(
+ "--activation-fn",
+ type=str,
+ default="relu",
+ choices=utils.get_available_activation_fns(),
+ help="activation function to use",
+ )
+ parser.add_argument(
+ "--dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for attention weights",
+ )
+ parser.add_argument(
+ "--activation-dropout",
+ "--relu-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability after activation in FFN.",
+ )
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-text-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder text embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num encoder attention heads",
+ )
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
+ )
+ parser.add_argument(
+ "--decoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num decoder attention heads",
+ )
+ parser.add_argument(
+ "--layernorm-embedding",
+ action="store_true",
+ help="add layernorm to embedding",
+ )
+ parser.add_argument(
+ "--no-scale-embedding",
+ action="store_true",
+ help="if True, dont scale embeddings",
+ )
+ # non-standard transformer parameters
+ parser.add_argument(
+ "--speech-encoder-layers",
+ type=int,
+ metavar="N",
+ help="num speech encoder layers",
+ )
+ parser.add_argument(
+ "--text-encoder-layers",
+ type=int,
+ metavar="N",
+ help="num text encoder layers",
+ )
+ parser.add_argument(
+ "--encoder-shared-layers",
+ type=int,
+ metavar="N",
+ help="num shared encoder layers",
+ )
+ parser.add_argument(
+ "--encoder-shared-layer-level",
+ type=int,
+ metavar="N",
+ default=0,
+ choices=[0, 1, 2],
+ help="share layer level 0: all share 1: all share with separate model 2: share weight but not bias and layernorm",
+ )
+
+ parser.add_argument(
+ "--decoder-shared-layer-level",
+ default=0,
+ choices=[0, 1, 2],
+ type=int,
+ metavar="N",
+ help="0: share everything; 1: share everything with different model 2: no share layer_norm and bias",
+ )
+ ###
+ parser.add_argument(
+ "--text-input-cost-ratio",
+ type=float,
+ default=1.0,
+ metavar="V",
+ help="text input cost ratio relative to speech input cost",
+ )
+ parser.add_argument(
+ "--init-scale",
+ type=float,
+ default=1.0,
+ metavar="V",
+ help="scale the initial weight by given factor",
+ )
+ parser.add_argument(
+ "--enc-grad-mult",
+ type=float,
+ metavar="V",
+ default=1.0,
+ help="multiply enc1 and enc2 gradient by V",
+ )
+ parser.add_argument(
+ "--enc2-along-grad-mult",
+ type=float,
+ metavar="V",
+ default=1.0,
+ help="multiply enc2 gradient by V if only enc2 is used",
+ )
+ parser.add_argument(
+ "--load-pretrain-encoder",
+ type=str,
+ default="",
+ metavar="EXPR",
+ help=""" path to the pretrained encoder """,
+ )
+ parser.add_argument(
+ "--load-pretrain-speech-encoder",
+ type=str,
+ default="",
+ metavar="EXPR",
+ help=""" path to the pretrained speech encoder """,
+ )
+ parser.add_argument(
+ "--load-pretrain-text-encoder",
+ type=str,
+ default="",
+ metavar="EXPR",
+ help=""" path to the pretrained text encoder """,
+ )
+ parser.add_argument(
+ "--load-pretrain-text-encoder-last",
+ type=str,
+ default="",
+ metavar="EXPR",
+ help=""" path to the pretrained text encoder """,
+ )
+ parser.add_argument(
+ "--load-pretrain-decoder",
+ type=str,
+ metavar="EXPR",
+ default="",
+ help=""" path to the pretrained encoder """,
+ )
+ parser.add_argument(
+ "--add-speech-eos",
+ action="store_true",
+ help="add eos token at the end of input feature",
+ )
+ parser.add_argument(
+ "--speech-encoder-adapter-type",
+ type=str,
+ metavar="EXPR",
+ default="None",
+ choices=["None", "Linear", "MLP"],
+ help="add speech encoder adapter",
+ )
+
+ @classmethod
+ def build_encoder(cls, args, task):
+ spch_encoder = DualInputEncoder.build_spch_encoder(args)
+ text_encoder = DualInputEncoder.build_text_encoder(
+ args, task.src_dict, spch_encoder
+ )
+ cross_attentive_loss_before_last_layer = (
+ 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
+ )
+ encoder = DualInputEncoder(
+ args,
+ spch_encoder,
+ text_encoder,
+ task.src_dict,
+ cross_attentive_loss_before_last_layer,
+ )
+ if args.init_scale != 1.0:
+ with torch.no_grad():
+ for param in encoder.parameters():
+ param.data.mul_(args.init_scale)
+ if args.load_pretrain_text_encoder != "":
+ checkpoint_utils.load_pretrained_component_from_model(
+ text_encoder, args.load_pretrain_text_encoder
+ )
+ if args.load_pretrain_speech_encoder != "":
+ if hasattr(spch_encoder, "encoder"):
+ checkpoint_utils.load_pretrained_component_from_model(
+ spch_encoder.encoder, args.load_pretrain_speech_encoder
+ )
+ else:
+ checkpoint_utils.load_pretrained_component_from_model(
+ spch_encoder, args.load_pretrain_speech_encoder
+ )
+ if (
+ args.load_pretrain_text_encoder_last != ""
+        ):  # if the encoder is shared, the speech encoder parameters would otherwise
+            # be used; this provides a chance to load a pre-trained MT encoder instead
+ checkpoint_utils.load_pretrained_component_from_model(
+ text_encoder, args.load_pretrain_text_encoder_last
+ )
+
+ if args.load_pretrain_encoder != "":
+ checkpoint_utils.load_pretrained_component_from_model(
+ encoder, args.load_pretrain_encoder
+ )
+ return encoder
+
+ @classmethod
+ def build_decoder(cls, args, task):
+ dec_cfg = {
+ "decoder_layerdrop": args.decoder_layerdrop,
+ "share_decoder_input_output_embed": args.share_decoder_input_output_embed,
+ "decoder_embed_dim": args.decoder_embed_dim,
+ "max_target_positions": args.max_target_positions,
+ "dropout": args.dropout,
+ "encoder_learned_pos": args.encoder_learned_pos,
+ "decoder_learned_pos": args.decoder_learned_pos,
+ "layernorm_embedding": args.layernorm_embedding,
+ "decoder_normalize_before": args.decoder_normalize_before,
+ "activation_dropout": args.activation_dropout,
+ "attention_dropout": args.attention_dropout,
+ "decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
+ "decoder_layers": args.decoder_layers,
+ "decoder_attention_heads": args.decoder_attention_heads,
+ "decoder_output_dim": args.decoder_embed_dim,
+ "no_scale_embedding": args.no_scale_embedding,
+ "adaptive_input": args.adaptive_input,
+ "quant_noise_pq": args.quant_noise_pq,
+ "adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
+ "tie_adaptive_weights": args.tie_adaptive_weights,
+ "no_token_positional_embeddings": args.no_token_positional_embeddings,
+ }
+ dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
+ dec_emb = nn.Embedding(
+ len(task.target_dictionary),
+ args.decoder_embed_dim,
+ task.target_dictionary.pad(),
+ )
+        compute_cross_attentive_loss = (
+            getattr(args, "attentive_cost_regularization", 0.0) > 0.0
+        )
+ cross_attentive_loss_without_norm = getattr(
+ args, "attentive_cost_without_normalize", False
+ )
+ cross_attentive_loss_reverse = (
+ False # getattr(args, "attentive_cost_reverse", False)
+ )
+
+ text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
+ spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
+ spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
+ args, text_decoder, spch_decoder
+ )
+ decoder = TransformerMultiInputDecoder(
+ dictionary=task.target_dictionary,
+ spch_decoder=spch_decoder,
+ text_decoder=text_decoder,
+ compute_cross_attentive_loss=compute_cross_attentive_loss,
+            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
+ cross_attentive_loss_reverse=cross_attentive_loss_reverse,
+ )
+ if args.init_scale != 1.0:
+ with torch.no_grad():
+ for param in decoder.parameters():
+ param.data.mul_(args.init_scale)
+ if args.load_pretrain_decoder != "":
+ try:
+ checkpoint_utils.load_pretrained_component_from_model(
+ decoder, args.load_pretrain_decoder
+ )
+ except RuntimeError:
+ checkpoint_utils.load_pretrained_component_from_model(
+ decoder.text_decoder, args.load_pretrain_decoder
+ )
+ if args.decoder_shared_layer_level > 0:
+ checkpoint_utils.load_pretrained_component_from_model(
+ decoder.spch_decoder, args.load_pretrain_decoder
+ )
+
+ return decoder
+
+ @classmethod
+ def build_model(cls, args, task):
+ """Build a new model instance."""
+ # make sure that all args are properly defaulted
+ # (in case there are any new ones)
+ dualinputs2ttransformer_base(args)
+
+ encoder = cls.build_encoder(args, task)
+ decoder = cls.build_decoder(args, task)
+ return cls(encoder, decoder)
+
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
+ # net_output['encoder_out'] is a (B, T, D) tensor
+ lprobs = super().get_normalized_probs(net_output, log_probs, sample)
+ lprobs.batch_first = True
+ return lprobs
+
+ def set_num_updates(self, num_updates):
+ """Set the number of parameters updates."""
+ super().set_num_updates(num_updates)
+ self.num_updates = num_updates
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths,
+ prev_output_tokens,
+ use_encoder_outputs=False,
+ src_txt_tokens=None,
+ src_txt_lengths=None,
+ mode="sup_speech",
+ **kwargs
+ ):
+ """
+ Run the forward pass for an encoder-decoder model.
+
+ First feed a batch of source tokens through the encoder. Then, feed the
+ encoder output and previous decoder outputs (i.e., teacher forcing) to
+ the decoder to produce the next outputs::
+
+ encoder_out = self.encoder(src_tokens, src_lengths)
+ return self.decoder(prev_output_tokens, encoder_out)
+
+ Args:
+ src_tokens (LongTensor): tokens in the source language of shape
+ `(batch, src_len)`
+ src_lengths (LongTensor): source sentence lengths of shape `(batch)`
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
+ `(batch, tgt_len)`, for teacher forcing
+ mode = 'sup_speech' or 'text'
+
+ Returns:
+ tuple:
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
+ - a dictionary with any model-specific outputs
+ """
+ if mode == "text":
+ assert src_txt_tokens is None
+ src_txt_tokens = src_tokens
+ src_txt_lengths = src_lengths
+ src_tokens = None
+ src_lengths = None
+ encoder_out = self.encoder(
+ src_tokens,
+ src_lengths=src_lengths,
+ src_txt_tokens=src_txt_tokens,
+ src_txt_lengths=src_txt_lengths,
+ **kwargs
+ )
+        has_txt_input = src_txt_tokens is not None
+ decoder_out = self.decoder(
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ has_txt_input=has_txt_input,
+ **kwargs
+ )
+ if use_encoder_outputs:
+ return decoder_out, encoder_out
+ return decoder_out
+
+
+@register_model_architecture(
+ "dual_input_s2t_transformer", "dualinputs2ttransformer_base"
+)
+def dualinputs2ttransformer_base(args):
+ args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0)
+ # Convolutional subsampler
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
+ args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
+ args.conv_channels = getattr(args, "conv_channels", 1024)
+ # Transformer
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_text_embed_dim = getattr(
+ args, "encoder_text_embed_dim", args.encoder_embed_dim
+ )
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
+ args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
+ args.encoder_shared_layers = getattr(args, "encoder_shared_layers", 0)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+
+ args.add_speech_eos = getattr(args, "add_speech_eos", False)
+
+
+@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_s")
+def dualinputs2ttransformer_s(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 7)
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 7)
+ args.decoder_layers = getattr(args, "decoder_layers", 7)
+ dualinputs2ttransformer_base(args)
+
+
+@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_m")
+def dualinputs2ttransformer_m(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.dropout = getattr(args, "dropout", 0.15)
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ dualinputs2ttransformer_base(args)
+
+
+@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_b")
+def dualinputs2ttransformer_b(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
+ args.dropout = getattr(args, "dropout", 0.15)
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ dualinputs2ttransformer_base(args)
+
+
+@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_l")
+def dualinputs2ttransformer_l(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.dropout = getattr(args, "dropout", 0.2)
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ dualinputs2ttransformer_base(args)
diff --git a/fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py b/fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..50683e6d7c8c0db5b8f019e5f7f5fb8c6dfd9f66
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py
@@ -0,0 +1,585 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import torch.nn as nn
+from fairseq import checkpoint_utils
+from fairseq import utils
+from fairseq.data.data_utils import lengths_to_padding_mask
+from fairseq.models import (
+ register_model,
+ register_model_architecture,
+ FairseqEncoder,
+)
+from fairseq.models.speech_to_text import XMTransformerModel, Wav2VecEncoderWithAdaptor
+from fairseq.models.speech_to_text.xm_transformer import (
+ set_default_adaptor_args,
+ set_default_w2v_encoder_args,
+)
+from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
+from fairseq.models.wav2vec import TransformerSentenceEncoderLayer
+from fairseq.utils import safe_hasattr
+
+from .s2t_dualinputtransformer import (
+ DualInputS2TTransformerModel,
+ TransformerMultiInputDecoder,
+ DualInputEncoder,
+)
+
+
+class TransformerSentenceEncoderLayerStd(TransformerSentenceEncoderLayer):
+ def __init__(self, sent_enc_layer):
+ super(TransformerSentenceEncoderLayer, self).__init__()
+ self.embedding_dim = sent_enc_layer.embedding_dim
+ self.dropout = sent_enc_layer.dropout
+ self.activation_dropout = sent_enc_layer.activation_dropout
+
+ # Initialize blocks
+ self.activation_fn = sent_enc_layer.activation_fn
+ self.self_attn = sent_enc_layer.self_attn
+
+ self.dropout1 = sent_enc_layer.dropout1
+ self.dropout2 = sent_enc_layer.dropout2
+ self.dropout3 = sent_enc_layer.dropout3
+
+ self.layer_norm_first = sent_enc_layer.layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = sent_enc_layer.self_attn_layer_norm
+ self.fc1 = sent_enc_layer.fc1
+ self.fc2 = sent_enc_layer.fc2
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = sent_enc_layer.final_layer_norm
+
+ def forward(
+ self,
+ x,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ need_weights=None,
+ att_args=None,
+ ):
+ x, attn = super().forward(
+ x, self_attn_mask, self_attn_padding_mask, need_weights, att_args
+ )
+ return x
+
+
+# TODO retire SharedEncoder
+class SharedEncoder(FairseqEncoder):
+ def __init__(self, wav2vec_enc, mbart_enc, adaptor, shared_layers):
+ super().__init__(None)
+ self.w2v_encoder = wav2vec_enc
+ self.shared_layers = self.w2v_encoder.w2v_model.encoder.layers[-shared_layers:]
+ self.w2v_encoder.w2v_model.encoder.layers = (
+ self.w2v_encoder.w2v_model.encoder.layers[:-shared_layers]
+ )
+ self.adaptor = adaptor
+ if self.shared_layers[-1].layer_norm_first:
+ self.final_layer_norm = mbart_enc.layer_norm
+ else:
+ mbart_enc.layer_norm = None
+ self.final_layer_norm = None
+ shared_layer_from = len(mbart_enc.layers) - shared_layers
+ if shared_layer_from < 0:
+ shared_layer_from = 0
+ for layer_id, layer in enumerate(self.shared_layers):
+ mbart_enc.layers[
+ shared_layer_from + layer_id
+ ] = TransformerSentenceEncoderLayerStd(layer)
+
+ def forward(self, src_tokens, src_lengths=None, **kwargs):
+ padding_mask = lengths_to_padding_mask(src_lengths)
+ if not padding_mask.any():
+ padding_mask = None
+
+ out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
+ x = out["encoder_out"]
+ enc_padding_mask = None
+ if out["encoder_padding_mask"] is not None:
+ enc_padding_mask = out["encoder_padding_mask"].transpose(
+ 0, 1
+ ) # T X B --> B X T
+
+ x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
+ for layer in self.shared_layers:
+ x, _ = layer(x, enc_padding_mask)
+ if self.final_layer_norm is not None:
+ x = self.final_layer_norm(x)
+
+ return {
+ "encoder_out": [x], # T x B x C
+ "encoder_padding_mask": [enc_padding_mask]
+ if enc_padding_mask is not None
+ else [], # B x T
+ "encoder_embedding": [], # B x T x C
+ "encoder_states": [], # List[T x B x C]
+ "src_tokens": [],
+ "src_lengths": [],
+ }
+
+
+class StackedWav2VecEncoderWithAdaptor(FairseqEncoder):
+ def __init__(
+ self,
+ wav2vec_enc,
+ mbart_enc_layers,
+ mbart_layer_norm,
+ adaptor,
+ drop_w2v_layers=0,
+ ):
+ super().__init__(None)
+ self.w2v_encoder = wav2vec_enc
+ self.adaptor = adaptor
+ self.mbart_encoder_layers = mbart_enc_layers
+ self.final_layer_norm = mbart_layer_norm
+ if drop_w2v_layers > 0:
+ self.w2v_encoder.w2v_model.encoder.layers = (
+ self.w2v_encoder.w2v_model.encoder.layers[:-drop_w2v_layers]
+ )
+
+ def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
+ padding_mask = lengths_to_padding_mask(src_lengths)
+ if not padding_mask.any():
+ padding_mask = None
+
+ out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
+ x = out["encoder_out"]
+ enc_padding_mask = None
+ if out["encoder_padding_mask"] is not None:
+ enc_padding_mask = out["encoder_padding_mask"].transpose(
+ 0, 1
+ ) # T X B --> B X T
+
+ x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
+ encoder_states = []
+ for layer in self.mbart_encoder_layers:
+ x = layer(x, enc_padding_mask)
+ if return_all_hiddens:
+ encoder_states.append(x)
+ if self.final_layer_norm is not None:
+ x = self.final_layer_norm(x)
+
+ return {
+ "encoder_out": [x], # T x B x C
+ "encoder_padding_mask": [enc_padding_mask]
+ if enc_padding_mask is not None
+ else [], # B x T
+ "encoder_embedding": [], # B x T x C
+ "encoder_states": encoder_states, # List[T x B x C]
+ "src_tokens": [],
+ "src_lengths": [],
+ }
+
+ def reorder_encoder_out(self, encoder_out, new_order):
+ new_encoder_out = (
+ []
+ if len(encoder_out["encoder_out"]) == 0
+ else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
+ )
+
+ new_encoder_padding_mask = (
+ []
+ if len(encoder_out["encoder_padding_mask"]) == 0
+ else [
+ x.index_select(0, new_order)
+ for x in encoder_out["encoder_padding_mask"]
+ ]
+ )
+
+ new_encoder_embedding = (
+ []
+ if len(encoder_out["encoder_embedding"]) == 0
+ else [
+ x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
+ ]
+ )
+
+ encoder_states = encoder_out["encoder_states"]
+ if len(encoder_states) > 0:
+ for idx, state in enumerate(encoder_states):
+ encoder_states[idx] = state.index_select(1, new_order)
+
+ return {
+ "encoder_out": new_encoder_out, # T x B x C
+ "encoder_padding_mask": new_encoder_padding_mask, # B x T
+ "encoder_embedding": new_encoder_embedding, # B x T x C
+ "encoder_states": encoder_states, # List[T x B x C]
+ "src_tokens": [], # B x T
+ "src_lengths": [], # B x 1
+ }
+
+
+# Note:
+# dual input transformer:
+# encoder: wav2vec for speech + mbart encoder for text
+# decoder: mbart decoder for text
+@register_model("dual_input_xm_transformer")
+class DualInputXMTransformerModel(DualInputS2TTransformerModel):
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+
+ @staticmethod
+ def add_args(parser):
+ """Add model-specific arguments to the parser."""
+ # wav2vec encoder
+ Wav2VecEncoderWithAdaptor.add_args(parser)
+ # add_decoder_args(parser)
+ # mbart Transformer
+ parser.add_argument(
+ "--activation-fn",
+ type=str,
+ default="relu",
+ choices=utils.get_available_activation_fns(),
+ help="activation function to use",
+ )
+
+ parser.add_argument(
+ "--mbart-dropout", type=float, metavar="D", help="dropout probability"
+ )
+ parser.add_argument(
+ "--mbart-attention-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability for attention weights",
+ )
+ parser.add_argument(
+ "--mbart-activation-dropout",
+ type=float,
+ metavar="D",
+ help="dropout probability after activation in FFN.",
+ )
+
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
+ )
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num encoder attention heads",
+ )
+ parser.add_argument(
+ "--encoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each encoder block",
+ )
+
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-ffn-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension for FFN",
+ )
+ parser.add_argument(
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
+ )
+ parser.add_argument(
+ "--decoder-attention-heads",
+ type=int,
+ metavar="N",
+ help="num decoder attention heads",
+ )
+ parser.add_argument(
+ "--decoder-normalize-before",
+ action="store_true",
+ help="apply layernorm before each decoder block",
+ )
+ parser.add_argument(
+ "--layernorm-embedding",
+ action="store_true",
+ help="add layernorm to embedding",
+ )
+ parser.add_argument(
+ "--no-scale-embedding",
+ action="store_true",
+ help="if True, dont scale embeddings",
+ )
+ parser.add_argument(
+ "--load-pretrained-mbart-from",
+ type=str,
+ metavar="STR",
+ help="model to take text encoder decoder weights from (for initialization)",
+ )
+ # parser.add_argument("--finetune-w2v-params", type=str, metavar="STR",
+ # help="comma-separated param strings to finetune.")
+ parser.add_argument(
+ "--finetune-mbart-decoder-params",
+ type=str,
+ metavar="STR",
+ help="comma-separated param strings to finetune.",
+ )
+ parser.add_argument(
+ "--finetune-mbart-encoder-params",
+ type=str,
+ metavar="STR",
+ help="comma-separated param strings to finetune.",
+ )
+ parser.add_argument(
+ "--skip-encoder-projection",
+ action="store_true",
+ help="skip the projection layer in encoder",
+ )
+
+ parser.add_argument(
+ "--enc-grad-mult",
+ type=float,
+ metavar="V",
+ default=1.0,
+ help="multiply enc1 and enc2 gradient by V",
+ )
+ parser.add_argument(
+ "--enc2-along-grad-mult",
+ type=float,
+ metavar="V",
+ default=1.0,
+ help="multiply enc2 gradient by V if only enc2 is used",
+ )
+ parser.add_argument(
+ "--text-input-cost-ratio",
+ type=float,
+ default=1.0,
+ metavar="V",
+ help="text input cost ratio relative to speech input cost",
+ )
+ parser.add_argument(
+ "--stack-w2v-mbart-encoder",
+ action="store_true",
+ help="stack w2v and mbart encoder",
+ )
+ parser.add_argument(
+ "--stack-w2v-mbart-nonorm-encoder",
+ action="store_true",
+ help="stack w2v and mbart encoder",
+ )
+ parser.add_argument(
+ "--no-final-norm-decoder", action="store_true", help="no layer norm"
+ )
+ parser.add_argument(
+ "--drop-w2v-layers",
+ type=int,
+ default=0,
+ metavar="N",
+ help="drop w2v encoder layers",
+ )
+
+ parser.add_argument(
+ "--share-w2v-text-encoder",
+ action="store_true",
+ help="share w2v encoder layers with text encoder",
+ )
+ parser.add_argument(
+ "--shared-w2v-layers",
+ type=int,
+ default=0,
+ metavar="N",
+ help="shared encoder layers from w2v encoder",
+ )
+
+ @classmethod
+ def build_encoder(cls, args, task):
+ _args = copy.deepcopy(args)
+ _args.dropout = args.mbart_dropout
+ _args.attention_dropout = args.mbart_attention_dropout
+ _args.activation_dropout = args.mbart_activation_dropout
+ _args.max_source_positions = 1024
+ enc_emb = nn.Embedding(
+ len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
+ )
+ text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
+ spch_encoder = Wav2VecEncoderWithAdaptor(args)
+ if getattr(args, "load_pretrained_mbart_from", None):
+ text_encoder = checkpoint_utils.load_pretrained_component_from_model(
+ component=text_encoder, checkpoint=args.load_pretrained_mbart_from
+ )
+ if getattr(args, "stack_w2v_mbart_encoder", False):
+ assert getattr(args, "share_w2v_text_encoder", False) is False
+ spch_encoder = StackedWav2VecEncoderWithAdaptor(
+ spch_encoder.w2v_encoder,
+ text_encoder.layers,
+ text_encoder.layer_norm,
+ spch_encoder.adaptor,
+ args.drop_w2v_layers,
+ )
+ elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
+ text_encoder.layer_norm = None
+ spch_encoder = StackedWav2VecEncoderWithAdaptor(
+ spch_encoder.w2v_encoder,
+ text_encoder.layers,
+ text_encoder.layer_norm,
+ spch_encoder.adaptor,
+ args.drop_w2v_layers,
+ )
+ elif getattr(args, "share_w2v_text_encoder", False):
+ spch_encoder = SharedEncoder(
+ spch_encoder.w2v_encoder,
+ text_encoder,
+ spch_encoder.adaptor,
+ args.shared_w2v_layers,
+ )
+
+ for k, p in spch_encoder.named_parameters():
+ # Freeze pretrained models by default
+ if safe_hasattr(
+ args, "finetune_w2v_params"
+ ) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k):
+ p.requires_grad = True
+ else:
+ p.requires_grad = False
+ for k, p in text_encoder.named_parameters():
+ # Freeze pretrained models by default
+ if safe_hasattr(
+ args, "finetune_mbart_encoder_params"
+ ) and XMTransformerModel.finetune_params(
+ args.finetune_mbart_encoder_params, k
+ ):
+ p.requires_grad = True
+ else:
+ p.requires_grad = False
+ cross_attentive_loss_before_last_layer = (
+ 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
+ )
+ encoder = DualInputEncoder(
+ args,
+ spch_encoder,
+ text_encoder,
+ task.src_dict,
+ cross_attentive_loss_before_last_layer,
+ )
+ return encoder
+
+ @classmethod
+ def build_decoder(cls, args, task):
+ _args = copy.deepcopy(args)
+ _args.dropout = args.mbart_dropout
+ _args.attention_dropout = args.mbart_attention_dropout
+ _args.activation_dropout = args.mbart_activation_dropout
+ _args.max_target_positions = 1024
+ dec_emb = nn.Embedding(
+ len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
+ )
+ decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
+ if getattr(args, "load_pretrained_mbart_from", None):
+ decoder = checkpoint_utils.load_pretrained_component_from_model(
+ component=decoder, checkpoint=args.load_pretrained_mbart_from
+ )
+ if getattr(args, "no_final_norm_decoder", False):
+ decoder.layer_norm = None
+ for k, p in decoder.named_parameters():
+ # Freeze pretrained models by default
+ if safe_hasattr(
+ args, "finetune_mbart_decoder_params"
+ ) and XMTransformerModel.finetune_params(
+ args.finetune_mbart_decoder_params, k
+ ):
+ p.requires_grad = True
+ else:
+ p.requires_grad = False
+
+        compute_cross_attentive_loss = (
+            getattr(args, "attentive_cost_regularization", 0.0) > 0.0
+        )
+ cross_attentive_loss_without_norm = getattr(
+ args, "attentive_cost_without_normalize", False
+ )
+ cross_attentive_loss_reverse = (
+ False # getattr(args, "attentive_cost_reverse", False)
+ )
+ decoder = TransformerMultiInputDecoder(
+ dictionary=task.target_dictionary,
+ spch_decoder=decoder,
+ text_decoder=decoder,
+ compute_cross_attentive_loss=compute_cross_attentive_loss,
+            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
+ cross_attentive_loss_reverse=cross_attentive_loss_reverse,
+ )
+ return decoder
+
+ @classmethod
+ def build_model(cls, args, task):
+ """Build a new model instance."""
+ # make sure that all args are properly defaulted
+ # (in case there are any new ones)
+ dualinputxmtransformer_base(args)
+
+ encoder = cls.build_encoder(args, task)
+ decoder = cls.build_decoder(args, task)
+ return cls(encoder, decoder)
+
+
+@register_model_architecture("dual_input_xm_transformer", "dualinputxmtransformer_base")
+def dualinputxmtransformer_base(args):
+ # wav2vec encoder
+ set_default_w2v_encoder_args(args)
+ set_default_adaptor_args(args)
+
+ # mbart model
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(
+ args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
+ )
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
+
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
+
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+
+ args.mbart_attention_dropout = getattr(args, "mbart_attention_dropout", 0.0)
+ args.mbart_activation_dropout = getattr(args, "mbart_activation_dropout", 0.0)
+ args.mbart_dropout = getattr(args, "mbart_dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", True
+ )
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
diff --git a/fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py b/fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..9db779396f492e3f71b08d7b895beb81d8e46bc9
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import itertools
+import logging
+import re
+import time
+
+from g2p_en import G2p
+
+logger = logging.getLogger(__name__)
+
+FAIL_SENT = "FAILED_SENTENCE"
+
+
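+# Example invocation (a sketch; the file paths are placeholders, the flags are defined in parse() below):
+#   python g2p_encode.py --data-path train.en --out-path train.pho \
+#       --lower-case --do-filter --use-word-start --no-punc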
+def parse():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-path", type=str, required=True)
+ parser.add_argument("--out-path", type=str, required=True)
+ parser.add_argument("--lower-case", action="store_true")
+ parser.add_argument("--do-filter", action="store_true")
+ parser.add_argument("--use-word-start", action="store_true")
+ parser.add_argument("--dup-vowel", default=1, type=int)
+ parser.add_argument("--dup-consonant", default=1, type=int)
+ parser.add_argument("--no-punc", action="store_true")
+ parser.add_argument("--reserve-word", type=str, default="")
+ parser.add_argument(
+ "--reserve-first-column",
+ action="store_true",
+ help="first column is sentence id",
+ )
+ ###
+ parser.add_argument("--parallel-process-num", default=1, type=int)
+ parser.add_argument("--logdir", default="")
+ args = parser.parse_args()
+ return args
+
+
+def process_sent(sent, g2p, res_wrds, args):
+ sents = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds)
+ pho_seqs = [do_g2p(g2p, s, res_wrds, i == 0) for i, s in enumerate(sents)]
+ pho_seq = (
+ [FAIL_SENT]
+ if [FAIL_SENT] in pho_seqs
+ else list(itertools.chain.from_iterable(pho_seqs))
+ )
+ if args.no_punc:
+ pho_seq = remove_punc(pho_seq)
+ if args.dup_vowel > 1 or args.dup_consonant > 1:
+ pho_seq = dup_pho(pho_seq, args.dup_vowel, args.dup_consonant)
+ if args.use_word_start:
+ pho_seq = add_word_start(pho_seq)
+ return " ".join(pho_seq)
+
+
+def remove_punc(sent):
+ ns = []
+ regex = re.compile("[^a-zA-Z0-9 ]")
+ for p in sent:
+ if (not regex.search(p)) or p == FAIL_SENT:
+ if p == " " and (len(ns) == 0 or ns[-1] == " "):
+ continue
+ ns.append(p)
+ return ns
+
+
+def do_g2p(g2p, sent, res_wrds, is_first_sent):
+ if sent in res_wrds:
+ pho_seq = [res_wrds[sent]]
+ else:
+ pho_seq = g2p(sent)
+ if not is_first_sent:
+ pho_seq = [" "] + pho_seq # add space to separate
+ return pho_seq
+
+
+def pre_process_sent(sent, do_filter, lower_case, res_wrds):
+ if do_filter:
+ sent = re.sub("-", " ", sent)
+ sent = re.sub("—", " ", sent)
+ if len(res_wrds) > 0:
+ wrds = sent.split()
+ wrds = ["SPLIT_ME " + w + " SPLIT_ME" if w in res_wrds else w for w in wrds]
+ sents = [x.strip() for x in " ".join(wrds).split("SPLIT_ME") if x.strip() != ""]
+ else:
+ sents = [sent]
+ if lower_case:
+ sents = [s.lower() if s not in res_wrds else s for s in sents]
+ return sents
+
+
+def dup_pho(sent, dup_v_num, dup_c_num):
+ """
+    duplicate phonemes as defined in cmudict
+ http://www.speech.cs.cmu.edu/cgi-bin/cmudict
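+
+    e.g. dup_pho(["HH", "AH0"], dup_v_num=2, dup_c_num=1) returns
+    ["HH", "AH0", "AH0-1P"]: stressed vowels (ending in a digit) are duplicated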
+ """
+ if dup_v_num == 1 and dup_c_num == 1:
+ return sent
+ ns = []
+ for p in sent:
+ ns.append(p)
+ if re.search(r"\d$", p):
+ for i in range(1, dup_v_num):
+ ns.append(f"{p}-{i}P")
+ elif re.search(r"\w", p):
+ for i in range(1, dup_c_num):
+ ns.append(f"{p}-{i}P")
+ return ns
+
+
+def add_word_start(sent):
+ ns = []
+ do_add = True
+ ws = "▁"
+ for p in sent:
+ if do_add:
+ p = ws + p
+ do_add = False
+ if p == " ":
+ do_add = True
+ else:
+ ns.append(p)
+ return ns
+
+
+def load_reserve_word(reserve_word):
+ if reserve_word == "":
+ return []
+ with open(reserve_word, "r") as fp:
+ res_wrds = [x.strip().split() for x in fp.readlines() if x.strip() != ""]
+    assert all(len(x) == 2 for x in res_wrds)
+ res_wrds = dict(res_wrds)
+ return res_wrds
+
+
+def process_sents(sents, args):
+ g2p = G2p()
+ out_sents = []
+ res_wrds = load_reserve_word(args.reserve_word)
+ for sent in sents:
+ col1 = ""
+ if args.reserve_first_column:
+ col1, sent = sent.split(None, 1)
+ sent = process_sent(sent, g2p, res_wrds, args)
+ if args.reserve_first_column and col1 != "":
+ sent = f"{col1} {sent}"
+ out_sents.append(sent)
+ return out_sents
+
+
+def main():
+ args = parse()
+ out_sents = []
+ with open(args.data_path, "r") as fp:
+ sent_list = [x.strip() for x in fp.readlines()]
+ if args.parallel_process_num > 1:
+ try:
+ import submitit
+ except ImportError:
+            logger.warning(
+                "submitit not found; falling back to a single job to process the data"
+            )
+ submitit = None
+
+ if args.parallel_process_num == 1 or submitit is None:
+ out_sents = process_sents(sent_list, args)
+ else:
+ # process sentences with parallel computation
+ lsize = len(sent_list) // args.parallel_process_num + 1
+ executor = submitit.AutoExecutor(folder=args.logdir)
+ executor.update_parameters(timeout_min=1000, cpus_per_task=4)
+ jobs = []
+ for i in range(args.parallel_process_num):
+ job = executor.submit(
+ process_sents, sent_list[lsize * i : lsize * (i + 1)], args
+ )
+ jobs.append(job)
+ is_running = True
+ while is_running:
+ time.sleep(5)
+ is_running = sum([job.done() for job in jobs]) < len(jobs)
+ out_sents = list(itertools.chain.from_iterable([job.result() for job in jobs]))
+ with open(args.out_path, "w") as fp:
+ fp.write("\n".join(out_sents) + "\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_text_joint_to_text/tasks/__init__.py b/fairseq/examples/speech_text_joint_to_text/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d878278475fb24cf6b97d66d784e657567f5aa80
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/tasks/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+import os
+
+for file in os.listdir(os.path.dirname(__file__)):
+ if file.endswith(".py") and not file.startswith("_"):
+ task_name = file[: file.find(".py")]
+ importlib.import_module("examples.speech_text_joint_to_text.tasks." + task_name)
diff --git a/fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py b/fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b3966d2d6b103f3dc2ff170c12ab9663875684
--- /dev/null
+++ b/fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
@@ -0,0 +1,372 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+import os
+from argparse import Namespace
+from pathlib import Path
+
+import torch
+from fairseq.data import (
+ encoders,
+ Dictionary,
+ ResamplingDataset,
+ TransformEosLangPairDataset,
+ ConcatDataset,
+)
+from fairseq.data.iterators import GroupedEpochBatchIterator
+from fairseq.data.audio.multi_modality_dataset import (
+ MultiModalityDataset,
+ LangPairMaskDataset,
+ ModalityDatasetItem,
+)
+from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset, SpeechToTextDatasetCreator
+from fairseq.data.audio.speech_to_text_joint_dataset import (
+ S2TJointDataConfig,
+ SpeechToTextJointDatasetCreator,
+)
+from fairseq.tasks import register_task
+from fairseq.tasks.speech_to_text import SpeechToTextTask
+from fairseq.tasks.translation import load_langpair_dataset
+
+logger = logging.getLogger(__name__)
+LANG_TAG_TEMPLATE = "<lang:{}>"
+
+
+@register_task("speech_text_joint_to_text")
+class SpeechTextJointToTextTask(SpeechToTextTask):
+ """
+    Task for joint speech-to-text and text-to-text training.
+ """
+
+ @classmethod
+ def add_args(cls, parser):
+ """Add task-specific arguments to the parser."""
+ super(SpeechTextJointToTextTask, cls).add_args(parser)
+ ###
+ parser.add_argument(
+ "--parallel-text-data",
+ default="",
+ help="path to parallel text data directory",
+ )
+ parser.add_argument(
+ "--max-tokens-text",
+ type=int,
+ metavar="N",
+ help="maximum tokens for encoder text input ",
+ )
+ parser.add_argument(
+ "--max-positions-text",
+ type=int,
+ metavar="N",
+ default=400,
+ help="maximum tokens for per encoder text input ",
+ )
+ parser.add_argument(
+ "--langpairs",
+ default=None,
+ metavar="S",
+ help='language pairs for text training, separated with ","',
+ )
+ parser.add_argument(
+ "--speech-sample-ratio",
+ default=1,
+ type=float,
+ metavar="N",
+ help="Multiple Ratio for speech dataset with transcripts ",
+ )
+ parser.add_argument(
+ "--text-sample-ratio",
+ default=1,
+ type=float,
+ metavar="N",
+ help="Multiple Ratio for text set ",
+ )
+ parser.add_argument(
+ "--update-mix-data",
+ action="store_true",
+ help="use mixed data in one update when update-freq > 1",
+ )
+ parser.add_argument(
+ "--load-speech-only",
+ action="store_true",
+ help="load speech data only",
+ )
+ parser.add_argument(
+ "--mask-text-ratio",
+ type=float,
+ metavar="V",
+ default=0.0,
+ help="mask V source tokens for text only mode",
+ )
+ parser.add_argument(
+ "--mask-text-type",
+ default="random",
+ choices=["random", "tail"],
+ help="mask text typed",
+ )
+ parser.add_argument(
+ "--noise-token",
+ default="",
+ help="noise token for masking src text tokens if mask-text-ratio > 0",
+ )
+ parser.add_argument(
+ "--infer-target-lang",
+ default="",
+ metavar="S",
+ help="target language for inference",
+ )
+
+ def __init__(self, args, src_dict, tgt_dict, infer_tgt_lang_id=None):
+ super().__init__(args, tgt_dict)
+ self.src_dict = src_dict
+ self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
+ assert self.tgt_dict.pad() == self.src_dict.pad()
+ assert self.tgt_dict.eos() == self.src_dict.eos()
+ self.speech_only = args.load_speech_only
+ self._infer_tgt_lang_id = infer_tgt_lang_id
+
+ @classmethod
+ def setup_task(cls, args, **kwargs):
+ """Setup the task (e.g., load dictionaries)."""
+ data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
+ tgt_dict_path = Path(args.data) / data_cfg.vocab_filename
+ src_dict_path = Path(args.data) / data_cfg.src_vocab_filename
+ if (not os.path.isfile(src_dict_path)) or (not os.path.isfile(tgt_dict_path)):
+ raise FileNotFoundError("Dict not found: {}".format(args.data))
+ src_dict = Dictionary.load(src_dict_path.as_posix())
+ tgt_dict = Dictionary.load(tgt_dict_path.as_posix())
+
+ print("| src dictionary: {} types".format(len(src_dict)))
+ print("| tgt dictionary: {} types".format(len(tgt_dict)))
+
+ if args.parallel_text_data != "":
+ if not os.path.isabs(args.parallel_text_data):
+ args.parallel_text_data = os.path.join(
+ args.data, args.parallel_text_data
+ )
+
+ if args.langpairs is None:
+ raise Exception(
+ "Could not infer language pair, please provide it explicitly"
+ )
+ infer_tgt_lang_id = None
+ if args.infer_target_lang != "" and data_cfg.prepend_tgt_lang_tag_no_change:
+ tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
+ args.infer_target_lang
+ )
+ infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
+ assert infer_tgt_lang_id != tgt_dict.unk()
+ return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)
+
+ def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0):
+ lang_pairs = []
+ text_dataset = None
+ split = "train"
+ for lp in self.args.langpairs.split(","):
+ src, tgt = lp.split("-")
+ text_dataset = load_langpair_dataset(
+ self.args.parallel_text_data,
+ split,
+ src,
+ self.src_dict,
+ tgt,
+ self.tgt_dict,
+ combine=True,
+ dataset_impl=None,
+ upsample_primary=1,
+ left_pad_source=False,
+ left_pad_target=False,
+ max_source_positions=self.args.max_positions_text,
+ max_target_positions=self.args.max_target_positions,
+ load_alignments=False,
+ truncate_source=False,
+ )
+ if prepend_tgt_lang_tag:
+ # TODO
+ text_dataset = TransformEosLangPairDataset(
+ text_dataset,
+ src_eos=self.src_dict.eos(),
+ tgt_bos=self.tgt_dict.eos(), # 'prev_output_tokens' starts with eos
+ new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
+ )
+ lang_pairs.append(text_dataset)
+ if len(lang_pairs) > 1:
+ if sampling_alpha != 1.0:
+ size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
+ self.args.langpairs.split(","),
+ [len(s) for s in lang_pairs],
+ alpha=sampling_alpha,
+ )
+ lang_pairs = [
+ ResamplingDataset(
+ d, size_ratio=r, epoch=epoch, replace=(r >= 1.0)
+ )
+ for d, r in zip(lang_pairs, size_ratios)
+ ]
+ return ConcatDataset(lang_pairs)
+ return text_dataset
+
+ def inference_step(
+ self, generator, models, sample, prefix_tokens=None, constraints=None
+ ):
+ with torch.no_grad():
+ return generator.generate(
+ models,
+ sample,
+ prefix_tokens=prefix_tokens,
+ constraints=constraints,
+ bos_token=self._infer_tgt_lang_id,
+ )
+
+ def build_src_tokenizer(self, args):
+ logger.info(f"src-pre-tokenizer: {self.data_cfg.src_pre_tokenizer}")
+ return encoders.build_tokenizer(Namespace(**self.data_cfg.src_pre_tokenizer))
+
+ def build_src_bpe(self, args):
+ logger.info(f"tokenizer: {self.data_cfg.src_bpe_tokenizer}")
+ return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))
+
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+ """Load a given dataset split.
+
+ Args:
+ split (str): name of the split (e.g., train, valid, test)
+ """
+ is_train_split = split.startswith("train")
+ pre_tokenizer = self.build_tokenizer(self.args)
+ bpe_tokenizer = self.build_bpe(self.args)
+ src_pre_tokenizer = self.build_src_tokenizer(self.args)
+ src_bpe_tokenizer = self.build_src_bpe(self.args)
+ ast_dataset = SpeechToTextJointDatasetCreator.from_tsv(
+ self.args.data,
+ self.data_cfg,
+ split,
+ self.tgt_dict,
+ src_dict=None if self.speech_only else self.src_dict,
+ pre_tokenizer=pre_tokenizer,
+ bpe_tokenizer=bpe_tokenizer,
+ src_pre_tokenizer=src_pre_tokenizer,
+ src_bpe_tokenizer=src_bpe_tokenizer,
+ is_train_split=is_train_split,
+ epoch=epoch,
+ seed=self.args.seed,
+ )
+ noise_token_id = -1
+ text_dataset = None
+ if self.args.parallel_text_data != "" and is_train_split:
+ text_dataset = self.load_langpair_dataset(
+ self.data_cfg.prepend_tgt_lang_tag_no_change,
+ 1.0,
+ epoch=epoch,
+ )
+ if self.args.mask_text_ratio > 0:
+ # add mask
+ noise_token_id = (
+ self.src_dict.unk()
+ if self.args.noise_token == ""
+ else self.src_dict.index(self.args.noise_token)
+ )
+ text_dataset = LangPairMaskDataset(
+ text_dataset,
+ src_bos=self.src_dict.bos(),
+ src_eos=self.src_dict.eos(),
+ noise_id=noise_token_id,
+ mask_ratio=self.args.mask_text_ratio,
+ mask_type=self.args.mask_text_type,
+ )
+
+ if text_dataset is not None:
+ mdsets = [
+ ModalityDatasetItem(
+ "sup_speech",
+ ast_dataset,
+ (self.args.max_source_positions, self.args.max_target_positions),
+ self.args.max_tokens,
+ self.args.batch_size,
+ ),
+ ModalityDatasetItem(
+ "text",
+ text_dataset,
+ (self.args.max_positions_text, self.args.max_target_positions),
+ self.args.max_tokens_text
+ if self.args.max_tokens_text is not None
+ else self.args.max_tokens,
+ self.args.batch_size,
+ ),
+ ]
+ ast_dataset = MultiModalityDataset(mdsets)
+ self.datasets[split] = ast_dataset
+
+ @property
+ def target_dictionary(self):
+ """Return the :class:`~fairseq.data.Dictionary` for the language
+ model."""
+ return self.tgt_dict
+
+ @property
+ def source_dictionary(self):
+ """Return the source :class:`~fairseq.data.Dictionary` (if applicable
+ for this task)."""
+ return None if self.speech_only else self.src_dict
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=0,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ ):
+
+ if not isinstance(dataset, MultiModalityDataset):
+ return super(SpeechTextJointToTextTask, self).get_batch_iterator(
+ dataset,
+ max_tokens,
+ max_sentences,
+ max_positions,
+ ignore_invalid_inputs,
+ required_batch_size_multiple,
+ seed,
+ num_shards,
+ shard_id,
+ num_workers,
+ epoch,
+ data_buffer_size,
+ disable_iterator_cache,
+ )
+
+ mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
+ assert len(dataset.datasets) == 2
+
+ # initialize the dataset with the correct starting epoch
+ dataset.set_epoch(epoch)
+
+ batch_samplers = dataset.get_batch_samplers(
+ mult_ratio, required_batch_size_multiple, seed
+ )
+
+ # return a reusable, sharded iterator
+ epoch_iter = GroupedEpochBatchIterator(
+ dataset=dataset,
+ collate_fn=dataset.collater,
+ batch_samplers=batch_samplers,
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=num_workers,
+ epoch=epoch,
+ mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq),
+ buffer_size=data_buffer_size,
+ )
+ self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
+ return epoch_iter
diff --git a/fairseq/examples/speech_to_text/README.md b/fairseq/examples/speech_to_text/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f639d300d342f8de1392c98bfc44ec8690188539
--- /dev/null
+++ b/fairseq/examples/speech_to_text/README.md
@@ -0,0 +1,77 @@
+# Speech-to-Text (S2T) Modeling
+
+[https://www.aclweb.org/anthology/2020.aacl-demo.6](https://www.aclweb.org/anthology/2020.aacl-demo.6.pdf)
+
+Speech recognition (ASR) and speech-to-text translation (ST) with fairseq.
+
+## Data Preparation
+S2T modeling data consists of source speech features, target text and other optional information
+(source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
+to store this information. Each data field is represented by a column in the TSV file.
+
+Unlike text token embeddings, speech features (e.g. log mel-scale filter banks) are usually fixed
+during model training and can be pre-computed. The manifest file contains the path to
+either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
+features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
+into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
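+
+An illustrative manifest row is shown below (tab-separated); the column names `id`, `audio`, `n_frames`,
+`tgt_text` and `speaker` follow the fairseq S2T convention, while the values here are made up:
+```
+id	audio	n_frames	tgt_text	speaker
+dev_0001	fbank80.zip:2048:153600	480	how are you	spk_0001
+```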
+
+Fairseq S2T also employs a YAML file for data related configurations: tokenizer type and dictionary path
+for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
+temperature-based resampling, etc.
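+
+A minimal `config.yaml` might look like the following sketch (field names mirror the
+`S2TDataConfigWriter` defaults in `data_utils.py`; the vocabulary/model filenames are placeholders):
+```yaml
+vocab_filename: spm_unigram8000.txt
+bpe_tokenizer:
+  bpe: sentencepiece
+  sentencepiece_model: spm_unigram8000.model
+input_channels: 1
+input_feat_per_channel: 80
+specaugment:
+  freq_mask_F: 27
+  freq_mask_N: 1
+  time_mask_N: 1
+  time_mask_T: 100
+  time_mask_p: 1.0
+  time_wrap_W: 0
+transforms:
+  '*': [utterance_cmvn]
+  _train: [utterance_cmvn, specaugment]
+```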
+
+## Model Training
+Fairseq S2T uses the unified `fairseq-train` interface for model training. It requires arguments `--task speech_to_text`,
+ `--arch <model architecture>` and `--config-yaml <config YAML filename>`.
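+
+For example, a training run on pre-processed data might look like the following sketch (paths,
+architecture and hyperparameters are placeholders; see the dataset-specific docs below for tested
+settings):
+```bash
+fairseq-train ${DATA_ROOT} \
+  --task speech_to_text --arch s2t_transformer_s --config-yaml config.yaml \
+  --train-subset train --valid-subset dev --save-dir ${SAVE_DIR} \
+  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+  --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
+  --max-tokens 40000 --max-update 300000 --num-workers 4
+```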
+
+## Inference & Evaluation
+Fairseq S2T uses the unified `fairseq-generate`/`fairseq-interactive` interface for inference and evaluation. It
+requires arguments `--task speech_to_text` and `--config-yaml <config YAML filename>`. The interactive console takes
+audio paths (one per line) as inputs.
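+
+A typical scoring run might look like the following sketch (checkpoint path, subset name and scoring
+metric are placeholders):
+```bash
+fairseq-generate ${DATA_ROOT} \
+  --task speech_to_text --config-yaml config.yaml --gen-subset test \
+  --path ${SAVE_DIR}/checkpoint_best.pt \
+  --max-tokens 50000 --beam 5 --scoring wer
+```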
+
+
+## Examples
+- [Speech Recognition (ASR) on LibriSpeech](docs/librispeech_example.md)
+
+- [Speech-to-Text Translation (ST) on MuST-C](docs/mustc_example.md)
+
+- [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md)
+
+- [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md)
+- [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md)
+
+## Updates
+- 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples:
+ [ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding)
+ and [ST (CoVoST 2)](docs/covost_example.md#interactive-decoding).
+- 01/08/2021: Several fixes for S2T Transformer model, inference-time de-tokenization, scorer configuration and data
+ preparation scripts. We also add pre-trained models to the examples and revise the instructions.
+ Breaking changes: the data preparation scripts now extract filterbank features without CMVN. CMVN is instead applied
+ on-the-fly (defined in the config YAML).
+
+## What's Next
+- We are migrating the old fairseq [ASR example](../speech_recognition) into this S2T framework and
+ merging the features from both sides.
+- The following papers also base their experiments on fairseq S2T. We are adding more examples for replication.
+ - [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474)
+ - [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124)
+ - [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490)
+ - [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320)
+ - [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515)
+
+## Citation
+Please cite as:
+```
+@inproceedings{wang2020fairseqs2t,
+ title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
+ author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
+ booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
+ year = {2020},
+}
+
+@inproceedings{ott2019fairseq,
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+ year = {2019},
+}
+```
diff --git a/fairseq/examples/speech_to_text/data_utils.py b/fairseq/examples/speech_to_text/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..41afac0bf8f6d70e06bee1a34e220ab396ec247d
--- /dev/null
+++ b/fairseq/examples/speech_to_text/data_utils.py
@@ -0,0 +1,382 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import csv
+from pathlib import Path
+import zipfile
+from functools import reduce
+from multiprocessing import cpu_count
+from typing import Any, Dict, List, Optional, Union
+import io
+
+import numpy as np
+import pandas as pd
+import sentencepiece as sp
+from fairseq.data.audio.audio_utils import (
+ convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data,
+ is_sf_audio_data
+)
+import torch
+import soundfile as sf
+from tqdm import tqdm
+
+
+UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
+BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
+EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
+PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
+
+
+def gen_vocab(
+ input_path: Path, output_path_prefix: Path, model_type="bpe",
+ vocab_size=1000, special_symbols: Optional[List[str]] = None
+):
+ # Train SentencePiece Model
+ arguments = [
+ f"--input={input_path.as_posix()}",
+ f"--model_prefix={output_path_prefix.as_posix()}",
+ f"--model_type={model_type}",
+ f"--vocab_size={vocab_size}",
+ "--character_coverage=1.0",
+ f"--num_threads={cpu_count()}",
+ f"--unk_id={UNK_TOKEN_ID}",
+ f"--bos_id={BOS_TOKEN_ID}",
+ f"--eos_id={EOS_TOKEN_ID}",
+ f"--pad_id={PAD_TOKEN_ID}",
+ ]
+ if special_symbols is not None:
+ _special_symbols = ",".join(special_symbols)
+ arguments.append(f"--user_defined_symbols={_special_symbols}")
+ sp.SentencePieceTrainer.Train(" ".join(arguments))
+ # Export fairseq dictionary
+ spm = sp.SentencePieceProcessor()
+ spm.Load(output_path_prefix.as_posix() + ".model")
+ vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
+ assert (
+ vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
+ and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
+ and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
+ and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
+ )
+ vocab = {
+ i: s
+ for i, s in vocab.items()
+ if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
+ }
+ with open(output_path_prefix.as_posix() + ".txt", "w") as f_out:
+ for _, s in sorted(vocab.items(), key=lambda x: x[0]):
+ f_out.write(f"{s} 1\n")
+
+
+def extract_fbank_features(
+ waveform: torch.FloatTensor,
+ sample_rate: int,
+ output_path: Optional[Path] = None,
+ n_mel_bins: int = 80,
+ overwrite: bool = False,
+):
+ if output_path is not None and output_path.is_file() and not overwrite:
+ return
+
+ _waveform = convert_waveform(waveform, sample_rate, to_mono=True)
+ # Kaldi compliance: 16-bit signed integers
+ _waveform = _waveform * (2 ** 15)
+ _waveform = _waveform.numpy()
+
+ features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
+ if features is None:
+ features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
+ if features is None:
+ raise ImportError(
+ "Please install pyKaldi or torchaudio to enable fbank feature extraction"
+ )
+
+ if output_path is not None:
+ np.save(output_path.as_posix(), features)
+ return features
+
+
+def create_zip(data_root: Path, zip_path: Path):
+ paths = list(data_root.glob("*.npy"))
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
+ for path in tqdm(paths):
+ f.write(path, arcname=path.name)
+
+
+def get_zip_manifest(
+ zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
+):
+ _zip_path = Path.joinpath(zip_root or Path(""), zip_path)
+ with zipfile.ZipFile(_zip_path, mode="r") as f:
+ info = f.infolist()
+ paths, lengths = {}, {}
+ for i in tqdm(info):
+ utt_id = Path(i.filename).stem
+ offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
+ paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"
+ with open(_zip_path, "rb") as f:
+ f.seek(offset)
+ byte_data = f.read(file_size)
+ assert len(byte_data) > 1
+ if is_audio:
+ assert is_sf_audio_data(byte_data), i
+ else:
+ assert is_npy_data(byte_data), i
+ byte_data_fp = io.BytesIO(byte_data)
+ if is_audio:
+ lengths[utt_id] = sf.info(byte_data_fp).frames
+ else:
+ lengths[utt_id] = np.load(byte_data_fp).shape[0]
+ return paths, lengths
+
+
+def gen_config_yaml(
+ manifest_root: Path,
+ spm_filename: Optional[str] = None,
+ vocab_name: Optional[str] = None,
+ yaml_filename: str = "config.yaml",
+ specaugment_policy: Optional[str] = "lb",
+ prepend_tgt_lang_tag: bool = False,
+ sampling_alpha: Optional[float] = None,
+ input_channels: Optional[int] = 1,
+ input_feat_per_channel: Optional[int] = 80,
+ audio_root: str = "",
+ cmvn_type: str = "utterance",
+ gcmvn_path: Optional[Path] = None,
+ extra=None
+):
+ manifest_root = manifest_root.absolute()
+ writer = S2TDataConfigWriter(manifest_root / yaml_filename)
+ assert spm_filename is not None or vocab_name is not None
+ vocab_name = spm_filename.replace(".model", ".txt") if vocab_name is None \
+ else vocab_name
+ writer.set_vocab_filename(vocab_name)
+ if input_channels is not None:
+ writer.set_input_channels(input_channels)
+ if input_feat_per_channel is not None:
+ writer.set_input_feat_per_channel(input_feat_per_channel)
+ specaugment_setters = {
+ "lb": writer.set_specaugment_lb_policy,
+ "ld": writer.set_specaugment_ld_policy,
+ "sm": writer.set_specaugment_sm_policy,
+ "ss": writer.set_specaugment_ss_policy,
+ }
+ specaugment_setter = specaugment_setters.get(specaugment_policy, None)
+ if specaugment_setter is not None:
+ specaugment_setter()
+ if spm_filename is not None:
+ writer.set_bpe_tokenizer(
+ {
+ "bpe": "sentencepiece",
+ "sentencepiece_model": (manifest_root / spm_filename).as_posix(),
+ }
+ )
+ if prepend_tgt_lang_tag:
+ writer.set_prepend_tgt_lang_tag(True)
+ if sampling_alpha is not None:
+ writer.set_sampling_alpha(sampling_alpha)
+
+ if cmvn_type not in ["global", "utterance"]:
+ raise NotImplementedError
+
+ if specaugment_policy is not None:
+ writer.set_feature_transforms(
+ "_train", [f"{cmvn_type}_cmvn", "specaugment"]
+ )
+ writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
+
+ if cmvn_type == "global":
+ if gcmvn_path is None:
+ raise ValueError("Please provide path of global cmvn file.")
+ else:
+ writer.set_global_cmvn(gcmvn_path.as_posix())
+
+ if len(audio_root) > 0:
+ writer.set_audio_root(audio_root)
+
+ if extra is not None:
+ writer.set_extra(extra)
+ writer.flush()
+
+
+def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
+ _path = path if isinstance(path, str) else path.as_posix()
+ return pd.read_csv(
+ _path,
+ sep="\t",
+ header=0,
+ encoding="utf-8",
+ escapechar="\\",
+ quoting=csv.QUOTE_NONE,
+ na_filter=False,
+ )
+
+
+def save_df_to_tsv(dataframe, path: Union[str, Path]):
+ _path = path if isinstance(path, str) else path.as_posix()
+ dataframe.to_csv(
+ _path,
+ sep="\t",
+ header=True,
+ index=False,
+ encoding="utf-8",
+ escapechar="\\",
+ quoting=csv.QUOTE_NONE,
+ )
+
+
+def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
+ with open(path, "r") as f:
+ reader = csv.DictReader(
+ f,
+ delimiter="\t",
+ quotechar=None,
+ doublequote=False,
+ lineterminator="\n",
+ quoting=csv.QUOTE_NONE,
+ )
+ rows = [dict(e) for e in reader]
+ return rows
+
+
+def filter_manifest_df(
+ df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
+):
+ filters = {
+ "no speech": df["audio"] == "",
+ f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
+ "empty sentence": df["tgt_text"] == "",
+ }
+ if is_train_split:
+ filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
+ if extra_filters is not None:
+ filters.update(extra_filters)
+ invalid = reduce(lambda x, y: x | y, filters.values())
+ valid = ~invalid
+ print(
+ "| "
+ + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ + f", total {invalid.sum()} filtered, {valid.sum()} remained."
+ )
+ return df[valid]
+
+
+def cal_gcmvn_stats(features_list):
+ features = np.concatenate(features_list)
+ square_sums = (features ** 2).sum(axis=0)
+ mean = features.mean(axis=0)
+ features = np.subtract(features, mean)
+ var = square_sums / features.shape[0] - mean ** 2
+ std = np.sqrt(np.maximum(var, 1e-8))
+ return {"mean": mean.astype("float32"), "std": std.astype("float32")}
+
+
+class S2TDataConfigWriter(object):
+ DEFAULT_VOCAB_FILENAME = "dict.txt"
+ DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
+ DEFAULT_INPUT_CHANNELS = 1
+
+ def __init__(self, yaml_path: Path):
+ try:
+ import yaml
+ except ImportError:
+ print("Please install PyYAML for S2T data config YAML files")
+ self.yaml = yaml
+ self.yaml_path = yaml_path
+ self.config = {}
+
+ def flush(self):
+ with open(self.yaml_path, "w") as f:
+ self.yaml.dump(self.config, f)
+
+ def set_audio_root(self, audio_root=""):
+ self.config["audio_root"] = audio_root
+
+ def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
+ self.config["vocab_filename"] = vocab_filename
+
+ def set_specaugment(
+ self,
+ time_wrap_w: int,
+ freq_mask_n: int,
+ freq_mask_f: int,
+ time_mask_n: int,
+ time_mask_t: int,
+ time_mask_p: float,
+ ):
+ self.config["specaugment"] = {
+ "time_wrap_W": time_wrap_w,
+ "freq_mask_N": freq_mask_n,
+ "freq_mask_F": freq_mask_f,
+ "time_mask_N": time_mask_n,
+ "time_mask_T": time_mask_t,
+ "time_mask_p": time_mask_p,
+ }
+
+ def set_specaugment_lb_policy(self):
+ self.set_specaugment(
+ time_wrap_w=0,
+ freq_mask_n=1,
+ freq_mask_f=27,
+ time_mask_n=1,
+ time_mask_t=100,
+ time_mask_p=1.0,
+ )
+
+ def set_specaugment_ld_policy(self):
+ self.set_specaugment(
+ time_wrap_w=0,
+ freq_mask_n=2,
+ freq_mask_f=27,
+ time_mask_n=2,
+ time_mask_t=100,
+ time_mask_p=1.0,
+ )
+
+ def set_specaugment_sm_policy(self):
+ self.set_specaugment(
+ time_wrap_w=0,
+ freq_mask_n=2,
+ freq_mask_f=15,
+ time_mask_n=2,
+ time_mask_t=70,
+ time_mask_p=0.2,
+ )
+
+ def set_specaugment_ss_policy(self):
+ self.set_specaugment(
+ time_wrap_w=0,
+ freq_mask_n=2,
+ freq_mask_f=27,
+ time_mask_n=2,
+ time_mask_t=70,
+ time_mask_p=0.2,
+ )
+
+ def set_input_channels(self, input_channels: int = 1):
+ self.config["input_channels"] = input_channels
+
+ def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
+ self.config["input_feat_per_channel"] = input_feat_per_channel
+
+ def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
+ self.config["bpe_tokenizer"] = bpe_tokenizer
+
+ def set_global_cmvn(self, stats_npz_path: str):
+ self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}
+
+ def set_feature_transforms(self, split: str, transforms: List[str]):
+ if "transforms" not in self.config:
+ self.config["transforms"] = {}
+ self.config["transforms"][split] = transforms
+
+ def set_prepend_tgt_lang_tag(self, flag: bool = True):
+ self.config["prepend_tgt_lang_tag"] = flag
+
+ def set_sampling_alpha(self, sampling_alpha: float = 1.0):
+ self.config["sampling_alpha"] = sampling_alpha
+
+ def set_extra(self, data):
+ self.config.update(data)
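+
+
+# Illustrative sketch (not used by the prep scripts): writing a minimal S2T data
+# config with the writer above. The vocab filename is a placeholder; the
+# prep_*_data.py scripts drive this through gen_config_yaml() instead.
+def _example_write_config(manifest_root: Path) -> None:
+    writer = S2TDataConfigWriter(manifest_root / "config_st.yaml")
+    writer.set_audio_root(manifest_root.as_posix())
+    writer.set_vocab_filename("spm_unigram8000_st.txt")
+    writer.set_input_channels(1)
+    writer.set_input_feat_per_channel(80)
+    writer.set_specaugment_lb_policy()
+    writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"])
+    writer.set_feature_transforms("*", ["utterance_cmvn"])
+    writer.flush()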
diff --git a/fairseq/examples/speech_to_text/docs/covost_example.md b/fairseq/examples/speech_to_text/docs/covost_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..16447f041e4751f79d9f7848b33ef2ff943d63c2
--- /dev/null
+++ b/fairseq/examples/speech_to_text/docs/covost_example.md
@@ -0,0 +1,102 @@
+[[Back]](..)
+
+# S2T Example: ST on CoVoST
+We replicate the experiments in
+[CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310).
+
+## Data Preparation
+[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path
+`${COVOST_ROOT}/${SOURCE_LANG_ID}`, then preprocess it with
+```bash
+# additional Python packages for S2T data processing/model training
+pip install pandas torchaudio sentencepiece
+
+# En ASR
+python examples/speech_to_text/prep_covost_data.py \
+ --data-root ${COVOST_ROOT} --vocab-type char --src-lang en
+# ST
+python examples/speech_to_text/prep_covost_data.py \
+ --data-root ${COVOST_ROOT} --vocab-type char \
+ --src-lang fr --tgt-lang en
+```
+The generated files (manifest, features, vocabulary and data configuration) will be added to
+`${COVOST_ROOT}/${SOURCE_LANG_ID}`.
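+
+If you want to sanity-check the outputs, the manifests are plain TSV files and can be loaded with the helpers in `examples/speech_to_text/data_utils.py`. A minimal sketch (the root path below is a placeholder; the Fr-En manifest name follows from the commands above):
+```python
+from pathlib import Path
+
+from examples.speech_to_text.data_utils import load_df_from_tsv
+
+covost_root = Path("/path/to/covost")  # ${COVOST_ROOT}, placeholder
+df = load_df_from_tsv(covost_root / "fr" / "train_st_fr_en.tsv")
+print(df.columns.tolist())  # id, audio, n_frames, tgt_text, speaker
+print(df.head())
+```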
+
+Download our vocabulary files if you want to use our pre-trained models:
+- ASR: [En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_vocab_char.zip)
+- ST: [Fr-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_vocab_char.zip), [De-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_vocab_char.zip), [Es-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_vocab_char.zip), [Ca-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_vocab_char.zip), [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip), [En-Ca](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_vocab_char.zip), [En-Fa](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_vocab_char.zip), [En-Et](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_vocab_char.zip)
+
+## ASR
+#### Training
+We train an En ASR model for encoder pre-training of all ST models:
+```bash
+fairseq-train ${COVOST_ROOT}/en \
+ --config-yaml config_asr_en.yaml --train-subset train_asr_en --valid-subset dev_asr_en \
+ --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 50000 --max-update 60000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --report-accuracy --arch s2t_transformer_s --dropout 0.15 --optimizer adam --lr 2e-3 \
+ --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
+```
+where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
+You may want to update it accordingly when using more than 1 GPU (e.g. halve it on 2 GPUs, keeping the product of GPU count and `--update-freq` at 8).
+
+#### Inference & Evaluation
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${COVOST_ROOT}/en \
+ --config-yaml config_asr_en.yaml --gen-subset test_asr_en --task speech_to_text \
+ --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
+ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
+```
+#### Results
+| --arch | Params | En | Model |
+|---|---|---|---|
+| s2t_transformer_s | 31M | 25.6 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_transformer_s.pt) |
+
+## ST
+#### Training
+Fr-En as an example:
+```bash
+fairseq-train ${COVOST_ROOT}/fr \
+ --config-yaml config_st_fr_en.yaml --train-subset train_st_fr_en --valid-subset dev_st_fr_en \
+  --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-update 30000 --max-tokens 40000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
+ --arch s2t_transformer_s --encoder-freezing-updates 1000 --optimizer adam --lr 2e-3 \
+ --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
+ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
+```
+where `ST_SAVE_DIR` is the checkpoint root path (use `--max-tokens 50000` for the En-* directions). The ST encoder is
+pre-trained by En ASR for faster training and better performance: `--load-pretrained-encoder-from <ASR checkpoint path>`.
+We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
+
+#### Inference & Evaluation
+Average the last 10 checkpoints and evaluate on test split:
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${COVOST_ROOT}/fr \
+ --config-yaml config_st_fr_en.yaml --gen-subset test_st_fr_en --task speech_to_text \
+ --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 50000 --beam 5 --scoring sacrebleu
+```
+
+## Interactive Decoding
+Launch the interactive console via
+```bash
+fairseq-interactive ${COVOST_ROOT}/fr --config-yaml config_st_fr_en.yaml \
+  --task speech_to_text --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 50000 --beam 5
+```
+Type in WAV/FLAC/OGG audio paths (one per line) after the prompt.
+
+#### Results
+| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model |
+|---|---|---|---|---|---|---|---|---|---|---|
+| s2t_transformer_s | 31M | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [19.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.6](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [12.9](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [12.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) |
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_to_text/docs/librispeech_example.md b/fairseq/examples/speech_to_text/docs/librispeech_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..4040fda9426027537036ba987d087a43e734bfd9
--- /dev/null
+++ b/fairseq/examples/speech_to_text/docs/librispeech_example.md
@@ -0,0 +1,69 @@
+[[Back]](..)
+
+# S2T Example: Speech Recognition (ASR) on LibriSpeech
+[LibriSpeech](https://www.danielpovey.com/files/2015_icassp_librispeech.pdf) is a de facto standard English ASR
+benchmark. We provide competitive
+vanilla [Transformer](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) baselines.
+
+## Data preparation
+Download and preprocess LibriSpeech data with
+```bash
+# additional Python packages for S2T data processing/model training
+pip install pandas torchaudio sentencepiece
+
+python examples/speech_to_text/prep_librispeech_data.py \
+ --output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000
+```
+where `LS_ROOT` is the root path for downloaded data as well as generated files (manifest, features, vocabulary and
+data configuration).
+
+[Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_vocab_unigram10000.zip) our vocabulary files
+if you want to use our pre-trained models.
+
+## Training
+```bash
+fairseq-train ${LS_ROOT} --save-dir ${SAVE_DIR} \
+ --config-yaml config.yaml --train-subset train-clean-100,train-clean-360,train-other-500 --valid-subset dev-clean,dev-other \
+ --num-workers 4 --max-tokens 40000 --max-update 300000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
+ --arch s2t_transformer_s --share-decoder-input-output-embed \
+ --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
+ --clip-norm 10.0 --seed 1 --update-freq 8
+```
+where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example.
+For better performance, you may switch to `s2t_transformer_m` (71M, with `--lr 1e-3`) or `s2t_transformer_l`
+(268M, with `--lr 5e-4`). We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly
+when using more than 1 GPU.
+
+## Inference & Evaluation
+Average the last 10 checkpoints and evaluate on the 4 splits
+(`dev-clean`, `dev-other`, `test-clean` and `test-other`):
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py --inputs ${SAVE_DIR} \
+ --num-epoch-checkpoints 10 \
+ --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}"
+for SUBSET in dev-clean dev-other test-clean test-other; do
+ fairseq-generate ${LS_ROOT} --config-yaml config.yaml --gen-subset ${SUBSET} \
+ --task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 50000 --beam 5 --scoring wer
+done
+```
+
+## Interactive Decoding
+Launch the interactive console via
+```bash
+fairseq-interactive ${LS_ROOT} --config-yaml config.yaml --task speech_to_text \
+ --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5
+```
+Type in WAV/FLAC/OGG audio paths (one per line) after the prompt.
+
+## Results
+
+| --arch | Params | dev-clean | dev-other | test-clean | test-other | Model |
+|---|---|---|---|---|---|---|
+| s2t_transformer_s | 30M | 3.8 | 8.9 | 4.4 | 9.0 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_s.pt) |
+| s2t_transformer_m | 71M | 3.2 | 8.0 | 3.4 | 7.9 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_m.pt) |
+| s2t_transformer_l | 268M | 3.0 | 7.5 | 3.2 | 7.5 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/librispeech_transformer_l.pt) |
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_to_text/docs/mtedx_example.md b/fairseq/examples/speech_to_text/docs/mtedx_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..25b4556affbf5bc141b103095d15fffef6225c0e
--- /dev/null
+++ b/fairseq/examples/speech_to_text/docs/mtedx_example.md
@@ -0,0 +1,200 @@
+[[Back]](..)
+
+# S2T Example: Speech Translation (ST) on Multilingual TEDx
+
+[Multilingual TEDx](https://arxiv.org/abs/2102.01757) is a multilingual corpus for speech recognition and
+speech translation. The data is derived from TEDx talks in 8 source languages
+with translations to a subset of 5 target languages.
+
+## Data Preparation
+[Download](http://openslr.org/100/) and unpack Multilingual TEDx data to a path
+`${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with
+```bash
+# additional Python packages for S2T data processing/model training
+pip install pandas torchaudio soundfile sentencepiece
+
+# Generate TSV manifests, features, vocabulary
+# and configuration for each language
+python examples/speech_to_text/prep_mtedx_data.py \
+ --data-root ${MTEDX_ROOT} --task asr \
+ --vocab-type unigram --vocab-size 1000
+python examples/speech_to_text/prep_mtedx_data.py \
+ --data-root ${MTEDX_ROOT} --task st \
+ --vocab-type unigram --vocab-size 1000
+
+# Add vocabulary and configuration for joint data
+# (based on the manifests and features generated above)
+python examples/speech_to_text/prep_mtedx_data.py \
+ --data-root ${MTEDX_ROOT} --task asr --joint \
+ --vocab-type unigram --vocab-size 8000
+python examples/speech_to_text/prep_mtedx_data.py \
+ --data-root ${MTEDX_ROOT} --task st --joint \
+ --vocab-type unigram --vocab-size 8000
+```
+The generated files (manifest, features, vocabulary and data configuration) will be added to
+`${MTEDX_ROOT}/${LANG_PAIR}` (per-language data) and `${MTEDX_ROOT}` (joint data).
+
+
+## ASR
+#### Training
+Spanish as an example:
+```bash
+fairseq-train ${MTEDX_ROOT}/es-es \
+ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset valid_asr \
+ --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
+ --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
+ --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \
+ --skip-invalid-size-inputs-valid-test \
+ --keep-last-epochs 10 --update-freq 8 --patience 10
+```
+For joint model (using ASR data from all 8 languages):
+```bash
+fairseq-train ${MTEDX_ROOT} \
+ --config-yaml config_asr.yaml \
+ --train-subset train_es-es_asr,train_fr-fr_asr,train_pt-pt_asr,train_it-it_asr,train_ru-ru_asr,train_el-el_asr,train_ar-ar_asr,train_de-de_asr \
+ --valid-subset valid_es-es_asr,valid_fr-fr_asr,valid_pt-pt_asr,valid_it-it_asr,valid_ru-ru_asr,valid_el-el_asr,valid_ar-ar_asr,valid_de-de_asr \
+ --save-dir ${MULTILINGUAL_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
+ --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
+ --skip-invalid-size-inputs-valid-test \
+ --keep-last-epochs 10 --update-freq 8 --patience 10 \
+ --ignore-prefix-size 1
+```
+where `MULTILINGUAL_ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs
+with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
+For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from
+the training loss via `--ignore-prefix-size 1` (see the sketch below).
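+
+As a toy illustration (not fairseq internals; the token IDs below are made up), `--ignore-prefix-size 1` conceptually slices the language-tag position off the targets (and the corresponding model outputs) before the label-smoothed cross-entropy is computed:
+```python
+import torch
+
+# Toy target batch: position 0 holds the target-language ID token.
+tgt = torch.tensor([
+    [5, 101, 102, 103],   # 5 = <lang:es> (made-up ID)
+    [6, 201, 202, 203],   # 6 = <lang:fr> (made-up ID)
+])
+ignore_prefix_size = 1
+
+tgt_for_loss = tgt[:, ignore_prefix_size:]  # loss is computed on these positions only
+print(tgt_for_loss.tolist())  # [[101, 102, 103], [201, 202, 203]]
+```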
+
+#### Inference & Evaluation
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+
+fairseq-generate ${MTEDX_ROOT}/es-es \
+ --config-yaml config_asr.yaml --gen-subset test --task speech_to_text \
+ --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
+ --skip-invalid-size-inputs-valid-test \
+ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe
+
+# For models trained on joint data
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${MULTILINGUAL_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+
+for LANG in es fr pt it ru el ar de; do
+ fairseq-generate ${MTEDX_ROOT} \
+ --config-yaml config_asr.yaml --gen-subset test_${LANG}-${LANG}_asr --task speech_to_text \
+ --prefix-size 1 --path ${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 40000 --beam 5 \
+ --skip-invalid-size-inputs-valid-test \
+ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe
+done
+```
+#### Results
+| Data | --arch | Params | Es | Fr | Pt | It | Ru | El | Ar | De |
+|--------------|--------------------|--------|------|------|------|------|------|-------|-------|-------|
+| Monolingual | s2t_transformer_xs | 10M | 46.4 | 45.6 | 54.8 | 48.0 | 74.7 | 109.5 | 104.4 | 111.1 |
+
+
+## ST
+#### Training
+Es-En as an example:
+```bash
+fairseq-train ${MTEDX_ROOT}/es-en \
+ --config-yaml config_st.yaml --train-subset train_st --valid-subset valid_st \
+ --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
+ --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
+ --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \
+ --skip-invalid-size-inputs-valid-test \
+ --keep-last-epochs 10 --update-freq 8 --patience 10
+```
+For multilingual model (all 12 directions):
+```bash
+fairseq-train ${MTEDX_ROOT} \
+ --config-yaml config_st.yaml \
+ --train-subset train_el-en_st,train_es-en_st,train_es-fr_st,train_es-it_st,train_es-pt_st,train_fr-en_st,train_fr-es_st,train_fr-pt_st,train_it-en_st,train_it-es_st,train_pt-en_st,train_pt-es_st,train_ru-en_st \
+ --valid-subset valid_el-en_st,valid_es-en_st,valid_es-fr_st,valid_es-it_st,valid_es-pt_st,valid_fr-en_st,valid_fr-es_st,valid_fr-pt_st,valid_it-en_st,valid_it-es_st,valid_pt-en_st,valid_pt-es_st,valid_ru-en_st \
+ --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
+ --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \
+ --skip-invalid-size-inputs-valid-test \
+ --keep-last-epochs 10 --update-freq 8 --patience 10 \
+ --ignore-prefix-size 1 \
+ --load-pretrained-encoder-from ${PRETRAINED_ENCODER}
+```
+where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR
+for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set
+`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
+For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from
+the training loss via `--ignore-prefix-size 1`.
+
+#### Inference & Evaluation
+Average the last 10 checkpoints and evaluate on the `test` split:
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+
+fairseq-generate ${MTEDX_ROOT}/es-en \
+ --config-yaml config_st.yaml --gen-subset test --task speech_to_text \
+ --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 50000 --beam 5 --scoring sacrebleu --remove-bpe
+
+# For multilingual models
+python scripts/average_checkpoints.py \
+ --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+
+for LANGPAIR in es-en es-fr es-pt fr-en fr-es fr-pt pt-en pt-es it-en it-es ru-en el-en; do
+ fairseq-generate ${MTEDX_ROOT} \
+ --config-yaml config_st.yaml --gen-subset test_${LANGPAIR}_st --task speech_to_text \
+ --prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 40000 --beam 5 \
+ --skip-invalid-size-inputs-valid-test \
+ --scoring sacrebleu --remove-bpe
+done
+```
+For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`.
+
+#### Results
+| Data | --arch | Params | Es-En | Es-Pt | Es-Fr | Fr-En | Fr-Es | Fr-Pt | Pt-En | Pt-Es | It-En | It-Es | Ru-En | El-En |
+|--------------|--------------------|-----|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
+| Bilingual | s2t_transformer_xs | 10M | 7.0 | 12.2 | 1.7 | 8.9 | 10.6 | 7.9 | 8.1 | 8.7 | 6.4 | 1.0 | 0.7 | 0.6 |
+| Multilingual | s2t_transformer_s | 31M | 12.3 | 17.4 | 6.1 | 12.0 | 13.6 | 13.2 | 12.0 | 13.7 | 10.7 | 13.1 | 0.6 | 0.8 |
+
+
+## Citation
+Please cite as:
+```
+@misc{salesky2021mtedx,
+ title={Multilingual TEDx Corpus for Speech Recognition and Translation},
+ author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post},
+ year={2021},
+}
+
+@inproceedings{wang2020fairseqs2t,
+ title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
+ author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
+ booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
+ year = {2020},
+}
+
+@inproceedings{ott2019fairseq,
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+ year = {2019},
+}
+```
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_to_text/docs/mustc_example.md b/fairseq/examples/speech_to_text/docs/mustc_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..c95ef3e15660107c3384f87c1680f005044e7f3b
--- /dev/null
+++ b/fairseq/examples/speech_to_text/docs/mustc_example.md
@@ -0,0 +1,155 @@
+[[Back]](..)
+
+# S2T Example: Speech Translation (ST) on MuST-C
+
+[MuST-C](https://www.aclweb.org/anthology/N19-1202) is a multilingual speech-to-text translation corpus with
+8-language translations on English TED talks. We match the state-of-the-art performance in
+[ESPNet-ST](https://arxiv.org/pdf/2004.10234.pdf) with a simpler model training pipeline.
+
+## Data Preparation
+[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path
+`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with
+```bash
+# additional Python packages for S2T data processing/model training
+pip install pandas torchaudio soundfile sentencepiece
+
+# Generate TSV manifests, features, vocabulary
+# and configuration for each language
+python examples/speech_to_text/prep_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --task asr \
+ --vocab-type unigram --vocab-size 5000
+python examples/speech_to_text/prep_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --task st \
+ --vocab-type unigram --vocab-size 8000
+
+# Add vocabulary and configuration for joint data
+# (based on the manifests and features generated above)
+python examples/speech_to_text/prep_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --task asr --joint \
+ --vocab-type unigram --vocab-size 10000
+python examples/speech_to_text/prep_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --task st --joint \
+ --vocab-type unigram --vocab-size 10000
+```
+The generated files (manifest, features, vocabulary and data configuration) will be added to
+`${MUSTC_ROOT}/en-${TARGET_LANG_ID}` (per-language data) and `${MUSTC_ROOT}` (joint data).
+
+Download our vocabulary files if you want to use our pre-trained models:
+- ASR: [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_vocab_unigram5000.zip), [En-Nl](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_asr_vocab_unigram5000.zip), [En-Es](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_vocab_unigram5000.zip), [En-Fr](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_asr_vocab_unigram5000.zip), [En-It](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_asr_vocab_unigram5000.zip), [En-Pt](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_asr_vocab_unigram5000.zip), [En-Ro](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_asr_vocab_unigram5000.zip), [En-Ru](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_asr_vocab_unigram5000.zip), [Joint](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_vocab_unigram10000.zip)
+- ST: [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_vocab_unigram8000.zip), [En-Nl](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_st_vocab_unigram8000.zip), [En-Es](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_st_vocab_unigram8000.zip), [En-Fr](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_st_vocab_unigram8000.zip), [En-It](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_st_vocab_unigram8000.zip), [En-Pt](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_st_vocab_unigram8000.zip), [En-Ro](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_st_vocab_unigram8000.zip), [En-Ru](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_st_vocab_unigram8000.zip), [Multilingual](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_vocab_unigram10000.zip)
+
+## ASR
+#### Training
+En-De as an example:
+```bash
+fairseq-train ${MUSTC_ROOT}/en-de \
+ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \
+ --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
+ --arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
+```
+For joint model (using ASR data from all 8 directions):
+```bash
+fairseq-train ${MUSTC_ROOT} \
+ --config-yaml config_asr.yaml \
+ --train-subset train_de_asr,train_nl_asr,train_es_asr,train_fr_asr,train_it_asr,train_pt_asr,train_ro_asr,train_ru_asr \
+ --valid-subset dev_de_asr,dev_nl_asr,dev_es_asr,dev_fr_asr,dev_it_asr,dev_pt_asr,dev_ro_asr,dev_ru_asr \
+ --save-dir ${JOINT_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
+ --arch s2t_transformer_s --optimizer adam --lr 1e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
+```
+where `ASR_SAVE_DIR` (`JOINT_ASR_SAVE_DIR`) is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs
+with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
+
+#### Inference & Evaluation
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${MUSTC_ROOT}/en-de \
+ --config-yaml config_asr.yaml --gen-subset tst-COMMON_asr --task speech_to_text \
+ --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
+ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
+
+# For models trained on joint data
+python scripts/average_checkpoints.py \
+ --inputs ${JOINT_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+for LANG in de nl es fr it pt ro ru; do
+ fairseq-generate ${MUSTC_ROOT} \
+ --config-yaml config_asr.yaml --gen-subset tst-COMMON_${LANG}_asr --task speech_to_text \
+ --path ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
+ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
+done
+```
+#### Results
+| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | Model |
+|---|---|---|---|---|---|---|---|---|---|---|---|
+| Single | s2t_transformer_s | 31M | [18.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_transformer_s.pt) | [17.6](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_asr_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_transformer_s.pt) | [17.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_asr_transformer_s.pt) | [17.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_asr_transformer_s.pt) | [19.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_asr_transformer_s.pt) | [18.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_asr_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_asr_transformer_s.pt) | (<-Download) |
+| Joint | s2t_transformer_m | 76M | 16.8 | 16.7 | 16.9 | 16.9 | 17.0 | 17.4 | 17.0 | 16.9 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_transformer_m.pt) |
+
+## ST
+#### Training
+En-De as an example:
+```bash
+fairseq-train ${MUSTC_ROOT}/en-de \
+ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
+ --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
+ --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
+ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
+```
+For multilingual model (all 8 directions):
+```bash
+fairseq-train ${MUSTC_ROOT} \
+ --config-yaml config_st.yaml \
+ --train-subset train_de_st,train_nl_st,train_es_st,train_fr_st,train_it_st,train_pt_st,train_ro_st,train_ru_st \
+ --valid-subset dev_de_st,dev_nl_st,dev_es_st,dev_fr_st,dev_it_st,dev_pt_st,dev_ro_st,dev_ru_st \
+ --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
+ --arch s2t_transformer_s --ignore-prefix-size 1 --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
+ --load-pretrained-encoder-from ${JOINT_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
+```
+where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR
+for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set
+`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
+For multilingual models, we prepend the target language ID token as the target BOS, which should be excluded from
+the training loss via `--ignore-prefix-size 1`.
+
+#### Inference & Evaluation
+Average the last 10 checkpoints and evaluate on the `tst-COMMON` split:
+```bash
+CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
+python scripts/average_checkpoints.py \
+ --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+fairseq-generate ${MUSTC_ROOT}/en-de \
+ --config-yaml config_st.yaml --gen-subset tst-COMMON_st --task speech_to_text \
+ --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 50000 --beam 5 --scoring sacrebleu
+
+# For multilingual models
+python scripts/average_checkpoints.py \
+ --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \
+ --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
+for LANG in de nl es fr it pt ro ru; do
+ fairseq-generate ${MUSTC_ROOT} \
+ --config-yaml config_st.yaml --gen-subset tst-COMMON_${LANG}_st --task speech_to_text \
+ --prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --max-tokens 50000 --beam 5 --scoring sacrebleu
+done
+```
+For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`.
+
+#### Results
+| Data | --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru | Model |
+|---|---|---|---|---|---|---|---|---|---|---|---|
+| Bilingual | s2t_transformer_s | 31M | [22.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_transformer_s.pt) | [27.3](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_nl_st_transformer_s.pt) | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_st_transformer_s.pt) | [32.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_fr_st_transformer_s.pt) | [22.7](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_it_st_transformer_s.pt) | [28.1](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_pt_st_transformer_s.pt) | [21.9](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ro_st_transformer_s.pt) | [15.3](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_ru_st_transformer_s.pt) | (<-Download) |
+| Multilingual | s2t_transformer_m | 76M | 24.5 | 28.6 | 28.2 | 34.9 | 24.6 | 31.1 | 23.8 | 16.0 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_transformer_m.pt) |
+
+[[Back]](..)
diff --git a/fairseq/examples/speech_to_text/docs/simulst_mustc_example.md b/fairseq/examples/speech_to_text/docs/simulst_mustc_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..f3b5a413a27bbe2700da3f418460aa0a7c41abdd
--- /dev/null
+++ b/fairseq/examples/speech_to_text/docs/simulst_mustc_example.md
@@ -0,0 +1,190 @@
+# Simultaneous Speech Translation (SimulST) on MuST-C
+
+This is a tutorial on training and evaluating a Transformer *wait-k* simultaneous model on the MuST-C English-German dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf).
+
+[MuST-C](https://www.aclweb.org/anthology/N19-1202) is a multilingual speech-to-text translation corpus with 8-language translations on English TED talks.
+
+## Data Preparation
+This section introduces the data preparation for training and evaluation.
+If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference--evaluation).
+
+[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path
+`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with
+```bash
+# Additional Python packages for S2T data processing/model training
+pip install pandas torchaudio sentencepiece
+
+# Generate TSV manifests, features, vocabulary,
+# global cepstral mean and variance statistics,
+# and configuration for each language
+cd fairseq
+
+python examples/speech_to_text/prep_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --task asr \
+ --vocab-type unigram --vocab-size 10000 \
+ --cmvn-type global
+
+python examples/speech_to_text/prep_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --task st \
+ --vocab-type unigram --vocab-size 10000 \
+ --cmvn-type global
+```
+
+## ASR Pretraining
+We need a pretrained offline ASR model; assume its save directory is `${ASR_SAVE_DIR}`.
+The following command, like the subsequent training commands in this tutorial, assumes training on 1 GPU (you can also train on 8 GPUs and remove the `--update-freq 8` option).
+```
+fairseq-train ${MUSTC_ROOT}/en-de \
+ --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \
+ --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \
+ --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
+ --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
+```
+A pretrained ASR checkpoint can be downloaded [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1_en_de_pretrained_asr).
+
+## Simultaneous Speech Translation Training
+
+### Wait-K with fixed pre-decision module
+Fixed pre-decision means that the model applies the simultaneous policy only at the boundaries of fixed-size chunks.
+Here is an example with a fixed pre-decision ratio of 7 (a simultaneous decision is made every 7 encoder states) and
+a wait-3 policy model; a toy sketch of the decision rule follows the training command below. Assuming the save directory is `${ST_SAVE_DIR}`:
+```bash
+ fairseq-train ${MUSTC_ROOT}/en-de \
+ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
+ --save-dir ${ST_SAVE_DIR} --num-workers 8 \
+ --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \
+ --criterion label_smoothed_cross_entropy \
+ --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \
+ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/checkpoint_best.pt \
+ --task speech_to_text \
+ --arch convtransformer_simul_trans_espnet \
+ --simul-type waitk_fixed_pre_decision \
+ --waitk-lagging 3 \
+ --fixed-pre-decision-ratio 7 \
+ --update-freq 8
+
+```
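+
+To make the policy concrete, here is a toy sketch of the read/write rule implied by wait-k with a fixed pre-decision ratio (an illustration, not the actual SimulEval agent, which also handles end-of-source and other details): decisions are made only at boundaries of `--fixed-pre-decision-ratio` encoder states, and the model writes once the number of chunks read is at least `--waitk-lagging` ahead of the number of target tokens already emitted.
+```python
+def waitk_action(n_encoder_states: int, n_emitted_tokens: int,
+                 waitk_lagging: int = 3, pre_decision_ratio: int = 7) -> str:
+    """Toy wait-k decision rule with fixed pre-decision."""
+    n_chunks_read = n_encoder_states // pre_decision_ratio
+    if n_chunks_read - n_emitted_tokens >= waitk_lagging:
+        return "WRITE"  # emit the next target token
+    return "READ"       # request more speech
+
+
+# With k=3 and ratio 7, the first token is written only after 21 encoder states:
+print(waitk_action(n_encoder_states=14, n_emitted_tokens=0))  # READ
+print(waitk_action(n_encoder_states=21, n_emitted_tokens=0))  # WRITE
+```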
+### Monotonic multihead attention with fixed pre-decision module
+```
+ fairseq-train ${MUSTC_ROOT}/en-de \
+ --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
+ --save-dir ${ST_SAVE_DIR} --num-workers 8 \
+ --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \
+ --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \
+ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --task speech_to_text \
+ --criterion latency_augmented_label_smoothed_cross_entropy \
+ --latency-weight-avg 0.1 \
+ --arch convtransformer_simul_trans_espnet \
+ --simul-type infinite_lookback_fixed_pre_decision \
+ --fixed-pre-decision-ratio 7 \
+ --update-freq 8
+```
+## Inference & Evaluation
+[SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation.
+Install it and launch the evaluation with the following commands.
+
+```
+git clone https://github.com/facebookresearch/SimulEval.git
+cd SimulEval
+pip install -e .
+
+simuleval \
+    --agent ${FAIRSEQ}/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py \
+    --source ${SRC_LIST_OF_AUDIO} \
+    --target ${TGT_FILE} \
+ --data-bin ${MUSTC_ROOT}/en-de \
+ --config config_st.yaml \
+ --model-path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
+ --output ${OUTPUT} \
+ --scores
+```
+
+The source file `${SRC_LIST_OF_AUDIO}` is a list of paths to audio files. Assuming your audio files are stored at `/home/user/data`,
+it should look like this:
+
+```bash
+/home/user/data/audio-1.wav
+/home/user/data/audio-2.wav
+```
+
+Each line of the target file `${TGT_FILE}` is the translation of the corresponding audio input:
+```bash
+Translation_1
+Translation_2
+```
+The evaluation runs on the original MuST-C segmentation.
+The following command will generate the wav list and text file for an evaluation set `${SPLIT}` (chosen from `dev`, `tst-COMMON` and `tst-HE`) in MuST-C and write them to `${EVAL_DATA}`.
+```bash
+python ${FAIRSEQ}/examples/speech_to_text/seg_mustc_data.py \
+ --data-root ${MUSTC_ROOT} --lang de \
+ --split ${SPLIT} --task st \
+ --output ${EVAL_DATA}
+```
+
+`--data-bin` and `--config` should be the same as in the previous section if you prepared the data from scratch.
+If you only need to run evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). It contains
+- `spm_unigram10000_st.model`: a SentencePiece model binary.
+- `spm_unigram10000_st.txt`: the dictionary file generated by the SentencePiece model.
+- `gcmvn.npz`: the binary with global cepstral mean and variance statistics.
+- `config_st.yaml`: the config YAML file, shown below.
+If you use the downloaded data directory, you will need to set absolute paths for `sentencepiece_model` and `stats_npz_path` (a small helper sketch follows the YAML).
+```yaml
+bpe_tokenizer:
+ bpe: sentencepiece
+ sentencepiece_model: ABS_PATH_TO_SENTENCEPIECE_MODEL
+global_cmvn:
+ stats_npz_path: ABS_PATH_TO_GCMVN_FILE
+input_channels: 1
+input_feat_per_channel: 80
+sampling_alpha: 1.0
+specaugment:
+ freq_mask_F: 27
+ freq_mask_N: 1
+ time_mask_N: 1
+ time_mask_T: 100
+ time_mask_p: 1.0
+ time_wrap_W: 0
+transforms:
+ '*':
+ - global_cmvn
+ _train:
+ - global_cmvn
+ - specaugment
+vocab_filename: spm_unigram10000_st.txt
+```
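+
+For reference, here is a minimal sketch (assuming PyYAML and that the downloaded archive was unpacked to a local directory) that fills in the two absolute paths:
+```python
+from pathlib import Path
+
+import yaml
+
+data_bin = Path("/path/to/must_c_v1.0_en_de_databin")  # hypothetical unpack location
+cfg_path = data_bin / "config_st.yaml"
+with open(cfg_path) as f:
+    cfg = yaml.safe_load(f)
+cfg["bpe_tokenizer"]["sentencepiece_model"] = str(data_bin / "spm_unigram10000_st.model")
+cfg["global_cmvn"]["stats_npz_path"] = str(data_bin / "gcmvn.npz")
+with open(cfg_path, "w") as f:
+    yaml.dump(cfg, f)
+```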
+
+Note that once `--data-bin` is set, `--config` is the base name of the config YAML file, not the full path.
+
+Set `--model-path` to the model checkpoint.
+A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms.
+
+The result of this model on `tst-COMMON` is:
+```bash
+{
+ "Quality": {
+ "BLEU": 13.94974229366959
+ },
+ "Latency": {
+ "AL": 1751.8031870037803,
+ "AL_CA": 2338.5911762796536,
+ "AP": 0.7931395378788959,
+ "AP_CA": 0.9405103863210942,
+ "DAL": 1987.7811616943081,
+ "DAL_CA": 2425.2751560926167
+ }
+}
+```
+
+If the `--output ${OUTPUT}` option is used, the detailed log and scores will be stored under the `${OUTPUT}` directory.
+
+
+Quality is measured by detokenized BLEU, so make sure that the predicted words sent to the server are detokenized.
+
+The latency metrics are
+* Average Proportion
+* Average Lagging
+* Differentiable Average Lagging
+
+As with BLEU, these are computed on detokenized text.
diff --git a/fairseq/examples/speech_to_text/prep_covost_data.py b/fairseq/examples/speech_to_text/prep_covost_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..411e9b55152ea4a8e345e8c2d18431958c4f4c07
--- /dev/null
+++ b/fairseq/examples/speech_to_text/prep_covost_data.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+from pathlib import Path
+import shutil
+from tempfile import NamedTemporaryFile
+from typing import Optional, Tuple
+
+import pandas as pd
+import torchaudio
+from examples.speech_to_text.data_utils import (
+ create_zip,
+ extract_fbank_features,
+ filter_manifest_df,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ load_df_from_tsv,
+ save_df_to_tsv,
+)
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import download_url, extract_archive
+from tqdm import tqdm
+
+
+log = logging.getLogger(__name__)
+
+
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+
+
+class CoVoST(Dataset):
+ """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
+
+    Args:
+        root (str): root path to the dataset and generated manifests/features
+        split (str): dataset split, one of "train", "dev" or "test"
+        source_language (str): source (audio) language
+        target_language (str, optional): target (text) language,
+            None for no translation (default: None)
+        version (int, optional): CoVoST version (default: 2)
+    """
+
+ COVOST_URL_TEMPLATE = (
+ "https://dl.fbaipublicfiles.com/covost/"
+ "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
+ )
+
+ VERSIONS = {2}
+ SPLITS = ["train", "dev", "test"]
+
+ XX_EN_LANGUAGES = {
+ 1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
+ 2: [
+ "fr",
+ "de",
+ "es",
+ "ca",
+ "it",
+ "ru",
+ "zh-CN",
+ "pt",
+ "fa",
+ "et",
+ "mn",
+ "nl",
+ "tr",
+ "ar",
+ "sv-SE",
+ "lv",
+ "sl",
+ "ta",
+ "ja",
+ "id",
+ "cy",
+ ],
+ }
+ EN_XX_LANGUAGES = {
+ 1: [],
+ 2: [
+ "de",
+ "tr",
+ "fa",
+ "sv-SE",
+ "mn",
+ "zh-CN",
+ "cy",
+ "ca",
+ "sl",
+ "et",
+ "id",
+ "ar",
+ "ta",
+ "lv",
+ "ja",
+ ],
+ }
+
+ def __init__(
+ self,
+ root: str,
+ split: str,
+ source_language: str,
+ target_language: Optional[str] = None,
+ version: int = 2,
+ ) -> None:
+ assert version in self.VERSIONS and split in self.SPLITS
+ assert source_language is not None
+ self.no_translation = target_language is None
+ if not self.no_translation:
+ assert "en" in {source_language, target_language}
+ if source_language == "en":
+ assert target_language in self.EN_XX_LANGUAGES[version]
+ else:
+ assert source_language in self.XX_EN_LANGUAGES[version]
+ else:
+ # Hack here so that we can get "split" column from CoVoST TSV.
+ # Note that we use CoVoST train split for ASR which is an extension
+ # to Common Voice train split.
+ target_language = "de" if source_language == "en" else "en"
+
+ self.root: Path = Path(root)
+
+ cv_tsv_path = self.root / "validated.tsv"
+ assert cv_tsv_path.is_file()
+
+ covost_url = self.COVOST_URL_TEMPLATE.format(
+ src_lang=source_language, tgt_lang=target_language
+ )
+ covost_archive = self.root / Path(covost_url).name
+ if not covost_archive.is_file():
+ download_url(covost_url, self.root.as_posix(), hash_value=None)
+ extract_archive(covost_archive.as_posix())
+
+ cv_tsv = load_df_from_tsv(cv_tsv_path)
+ covost_tsv = load_df_from_tsv(
+ self.root / Path(covost_url).name.replace(".tar.gz", "")
+ )
+ df = pd.merge(
+ left=cv_tsv[["path", "sentence", "client_id"]],
+ right=covost_tsv[["path", "translation", "split"]],
+ how="inner",
+ on="path",
+ )
+ if split == "train":
+ df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
+ else:
+ df = df[df["split"] == split]
+ data = df.to_dict(orient="index").items()
+ data = [v for k, v in sorted(data, key=lambda x: x[0])]
+ self.data = []
+ for e in data:
+ try:
+ path = self.root / "clips" / e["path"]
+ _ = torchaudio.info(path.as_posix())
+ self.data.append(e)
+ except RuntimeError:
+ pass
+
+ def __getitem__(
+ self, n: int
+    ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
+ """Load the n-th sample from the dataset.
+
+ Args:
+ n (int): The index of the sample to be loaded
+
+ Returns:
+ tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
+ sample_id)``
+ """
+ data = self.data[n]
+ path = self.root / "clips" / data["path"]
+ waveform, sample_rate = torchaudio.load(path)
+ sentence = data["sentence"]
+ translation = None if self.no_translation else data["translation"]
+ speaker_id = data["client_id"]
+ _id = data["path"].replace(".mp3", "")
+ return waveform, sample_rate, sentence, translation, speaker_id, _id
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+def process(args):
+ root = Path(args.data_root).absolute() / args.src_lang
+ if not root.is_dir():
+ raise NotADirectoryError(f"{root} does not exist")
+ # Extract features
+ feature_root = root / "fbank80"
+ feature_root.mkdir(exist_ok=True)
+ for split in CoVoST.SPLITS:
+ print(f"Fetching split {split}...")
+ dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
+ print("Extracting log mel filter bank features...")
+ for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+ extract_fbank_features(
+ waveform, sample_rate, feature_root / f"{utt_id}.npy"
+ )
+ # Pack features into ZIP
+ zip_path = root / "fbank80.zip"
+ print("ZIPing features...")
+ create_zip(feature_root, zip_path)
+ print("Fetching ZIP manifest...")
+ audio_paths, audio_lengths = get_zip_manifest(zip_path)
+ # Generate TSV manifest
+ print("Generating manifest...")
+ train_text = []
+ task = f"asr_{args.src_lang}"
+ if args.tgt_lang is not None:
+ task = f"st_{args.src_lang}_{args.tgt_lang}"
+ for split in CoVoST.SPLITS:
+ manifest = {c: [] for c in MANIFEST_COLUMNS}
+ dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
+ for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
+ manifest["id"].append(utt_id)
+ manifest["audio"].append(audio_paths[utt_id])
+ manifest["n_frames"].append(audio_lengths[utt_id])
+ manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
+ manifest["speaker"].append(speaker_id)
+ is_train_split = split.startswith("train")
+ if is_train_split:
+ train_text.extend(manifest["tgt_text"])
+ df = pd.DataFrame.from_dict(manifest)
+ df = filter_manifest_df(df, is_train_split=is_train_split)
+ save_df_to_tsv(df, root / f"{split}_{task}.tsv")
+ # Generate vocab
+ vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
+ with NamedTemporaryFile(mode="w") as f:
+ for t in train_text:
+ f.write(t + "\n")
+ gen_vocab(
+ Path(f.name),
+ root / spm_filename_prefix,
+ args.vocab_type,
+ args.vocab_size
+ )
+ # Generate config YAML
+ gen_config_yaml(
+ root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{task}.yaml",
+ specaugment_policy="lb",
+ )
+ # Clean up
+ shutil.rmtree(feature_root)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data-root", "-d", required=True, type=str,
+        help="data root with one sub-folder per source language"
+ )
+ parser.add_argument(
+ "--vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+    )
+ parser.add_argument("--vocab-size", default=1000, type=int)
+ parser.add_argument("--src-lang", "-s", required=True, type=str)
+ parser.add_argument("--tgt-lang", "-t", type=str)
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_to_text/prep_librispeech_data.py b/fairseq/examples/speech_to_text/prep_librispeech_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..f379fa7bf195f48ad6b2ed3dbd93a5fbeb7abf79
--- /dev/null
+++ b/fairseq/examples/speech_to_text/prep_librispeech_data.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+from pathlib import Path
+import shutil
+from tempfile import NamedTemporaryFile
+
+import pandas as pd
+from examples.speech_to_text.data_utils import (
+ create_zip,
+ extract_fbank_features,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ save_df_to_tsv,
+)
+from torchaudio.datasets import LIBRISPEECH
+from tqdm import tqdm
+
+
+log = logging.getLogger(__name__)
+
+SPLITS = [
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+ "dev-clean",
+ "dev-other",
+ "test-clean",
+ "test-other",
+]
+
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+
+
+def process(args):
+ out_root = Path(args.output_root).absolute()
+ out_root.mkdir(exist_ok=True)
+ # Extract features
+ feature_root = out_root / "fbank80"
+ feature_root.mkdir(exist_ok=True)
+ for split in SPLITS:
+ print(f"Fetching split {split}...")
+ dataset = LIBRISPEECH(out_root.as_posix(), url=split, download=True)
+ print("Extracting log mel filter bank features...")
+ for wav, sample_rate, _, spk_id, chapter_no, utt_no in tqdm(dataset):
+ sample_id = f"{spk_id}-{chapter_no}-{utt_no}"
+ extract_fbank_features(
+ wav, sample_rate, feature_root / f"{sample_id}.npy"
+ )
+ # Pack features into ZIP
+ zip_path = out_root / "fbank80.zip"
+ print("ZIPing features...")
+ create_zip(feature_root, zip_path)
+ print("Fetching ZIP manifest...")
+ audio_paths, audio_lengths = get_zip_manifest(zip_path)
+ # Generate TSV manifest
+ print("Generating manifest...")
+ train_text = []
+ for split in SPLITS:
+ manifest = {c: [] for c in MANIFEST_COLUMNS}
+ dataset = LIBRISPEECH(out_root.as_posix(), url=split)
+ for _, _, utt, spk_id, chapter_no, utt_no in tqdm(dataset):
+ sample_id = f"{spk_id}-{chapter_no}-{utt_no}"
+ manifest["id"].append(sample_id)
+ manifest["audio"].append(audio_paths[sample_id])
+ manifest["n_frames"].append(audio_lengths[sample_id])
+ manifest["tgt_text"].append(utt.lower())
+ manifest["speaker"].append(spk_id)
+ save_df_to_tsv(
+ pd.DataFrame.from_dict(manifest), out_root / f"{split}.tsv"
+ )
+ if split.startswith("train"):
+ train_text.extend(manifest["tgt_text"])
+ # Generate vocab
+ vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
+ with NamedTemporaryFile(mode="w") as f:
+ for t in train_text:
+ f.write(t + "\n")
+ gen_vocab(
+ Path(f.name),
+ out_root / spm_filename_prefix,
+ args.vocab_type,
+ args.vocab_size,
+ )
+ # Generate config YAML
+ gen_config_yaml(
+ out_root,
+ spm_filename=spm_filename_prefix + ".model",
+ specaugment_policy="ld"
+ )
+ # Clean up
+ shutil.rmtree(feature_root)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output-root", "-o", required=True, type=str)
+ parser.add_argument(
+ "--vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+    )
+ parser.add_argument("--vocab-size", default=10000, type=int)
+ args = parser.parse_args()
+
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_to_text/prep_mtedx_data.py b/fairseq/examples/speech_to_text/prep_mtedx_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfd6317631f56b7fd1e31da98f29f79681ba972
--- /dev/null
+++ b/fairseq/examples/speech_to_text/prep_mtedx_data.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+import shutil
+from itertools import groupby
+from tempfile import NamedTemporaryFile
+from typing import Tuple
+
+import pandas as pd
+import soundfile as sf
+from examples.speech_to_text.data_utils import (
+ create_zip,
+ extract_fbank_features,
+ filter_manifest_df,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ load_df_from_tsv,
+ save_df_to_tsv,
+)
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from fairseq.data.audio.audio_utils import get_waveform, convert_waveform
+
+
+log = logging.getLogger(__name__)
+
+
+MANIFEST_COLUMNS = [
+ "id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"
+]
+
+
+class mTEDx(Dataset):
+ """
+ Create a Dataset for Multilingual TEDx.
+ Each item is a tuple of the form: waveform, sample_rate, source utterance,
+ target utterance, speaker_id, utterance_id
+ """
+
+ SPLITS = ["train", "valid", "test"]
+ LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar",
+ "de-de", "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es",
+ "fr-pt", "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"]
+
+ def __init__(self, root: str, lang: str, split: str) -> None:
+ assert split in self.SPLITS and lang in self.LANGPAIRS
+ _root = Path(root) / f"{lang}" / "data" / split
+ wav_root, txt_root = _root / "wav", _root / "txt"
+ assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
+ # Load audio segments
+ try:
+ import yaml
+ except ImportError:
+            raise ImportError(
+                "Please install PyYAML to load the Multilingual TEDx YAML files"
+            )
+ with open(txt_root / f"{split}.yaml") as f:
+ segments = yaml.load(f, Loader=yaml.BaseLoader)
+ # Load source and target utterances
+ src, tgt = lang.split("-")
+ for _lang in [src, tgt]:
+ with open(txt_root / f"{split}.{_lang}") as f:
+ utterances = [r.strip() for r in f]
+ assert len(segments) == len(utterances)
+ for i, u in enumerate(utterances):
+ segments[i][_lang] = u
+ # Gather info
+ self.data = []
+ for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
+ wav_filename = wav_filename.replace(".wav", ".flac")
+ wav_path = wav_root / wav_filename
+ sample_rate = sf.info(wav_path.as_posix()).samplerate
+ seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
+ for i, segment in enumerate(seg_group):
+ offset = int(float(segment["offset"]) * sample_rate)
+ n_frames = int(float(segment["duration"]) * sample_rate)
+ _id = f"{wav_path.stem}_{i}"
+ self.data.append(
+ (
+ wav_path.as_posix(),
+ offset,
+ n_frames,
+ sample_rate,
+ segment[src],
+ segment[tgt],
+ segment["speaker_id"],
+ tgt,
+ _id,
+ )
+ )
+
+ def __getitem__(
+ self, n: int
+ ) -> Tuple[torch.Tensor, int, str, str, str, str, str]:
+ wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, \
+ utt_id = self.data[n]
+ waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
+ waveform = torch.from_numpy(waveform)
+ return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+def process(args):
+ root = Path(args.data_root).absolute()
+ for lang in mTEDx.LANGPAIRS:
+ cur_root = root / f"{lang}"
+ if not cur_root.is_dir():
+ print(f"{cur_root.as_posix()} does not exist. Skipped.")
+ continue
+ # Extract features
+ audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
+ audio_root.mkdir(exist_ok=True)
+ for split in mTEDx.SPLITS:
+ print(f"Fetching split {split}...")
+ dataset = mTEDx(root.as_posix(), lang, split)
+ if args.use_audio_input:
+ print("Converting audios...")
+ for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
+ tgt_sample_rate = 16_000
+ _wavform, _ = convert_waveform(
+ waveform, sample_rate, to_mono=True,
+ to_sample_rate=tgt_sample_rate
+ )
+ sf.write(
+ (audio_root / f"{utt_id}.flac").as_posix(),
+ _wavform.numpy(), tgt_sample_rate
+ )
+ else:
+ print("Extracting log mel filter bank features...")
+ for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
+ extract_fbank_features(
+ waveform, sample_rate, audio_root / f"{utt_id}.npy"
+ )
+ # Pack features into ZIP
+ zip_path = cur_root / f"{audio_root.name}.zip"
+ print("ZIPing audios/features...")
+ create_zip(audio_root, zip_path)
+ print("Fetching ZIP manifest...")
+ audio_paths, audio_lengths = get_zip_manifest(zip_path)
+ # Generate TSV manifest
+ print("Generating manifest...")
+ train_text = []
+ for split in mTEDx.SPLITS:
+ is_train_split = split.startswith("train")
+ manifest = {c: [] for c in MANIFEST_COLUMNS}
+ ds = mTEDx(args.data_root, lang, split)
+ for _, _, src_utt, tgt_utt, spk_id, tgt_lang, utt_id in tqdm(ds):
+ manifest["id"].append(utt_id)
+ manifest["audio"].append(audio_paths[utt_id])
+ manifest["n_frames"].append(audio_lengths[utt_id])
+ manifest["tgt_text"].append(
+ src_utt if args.task == "asr" else tgt_utt
+ )
+ manifest["speaker"].append(spk_id)
+ manifest["tgt_lang"].append(tgt_lang)
+ if is_train_split:
+ train_text.extend(manifest["tgt_text"])
+ df = pd.DataFrame.from_dict(manifest)
+ df = filter_manifest_df(df, is_train_split=is_train_split)
+ save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
+ # Generate vocab
+ v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+ with NamedTemporaryFile(mode="w") as f:
+ for t in train_text:
+ f.write(t + "\n")
+ gen_vocab(
+ Path(f.name),
+ cur_root / spm_filename_prefix,
+ args.vocab_type,
+ args.vocab_size,
+ )
+ # Generate config YAML
+ if args.use_audio_input:
+ gen_config_yaml(
+ cur_root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{args.task}.yaml",
+ specaugment_policy=None,
+ extra={"use_audio_input": True}
+ )
+ else:
+ gen_config_yaml(
+ cur_root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{args.task}.yaml",
+ specaugment_policy="lb",
+ )
+ # Clean up
+ shutil.rmtree(audio_root)
+
+
+def process_joint(args):
+ cur_root = Path(args.data_root)
+ assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \
+ "do not have downloaded data available for all languages"
+ # Generate vocab
+ vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
+ with NamedTemporaryFile(mode="w") as f:
+ for lang in mTEDx.LANGPAIRS:
+ tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv"
+ df = load_df_from_tsv(tsv_path)
+ for t in df["tgt_text"]:
+ f.write(t + "\n")
+ special_symbols = None
+ if args.joint:
+ # Add tgt_lang tags to dict
+ special_symbols = list(
+ {f'<lang:{lang.split("-")[-1]}>' for lang in mTEDx.LANGPAIRS}
+ )
+ gen_vocab(
+ Path(f.name),
+ cur_root / spm_filename_prefix,
+ args.vocab_type,
+ args.vocab_size,
+ special_symbols=special_symbols
+ )
+ # Generate config YAML
+ gen_config_yaml(
+ cur_root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{args.task}.yaml",
+ specaugment_policy="ld",
+ prepend_tgt_lang_tag=(args.joint),
+ )
+ # Make symbolic links to manifests
+ for lang in mTEDx.LANGPAIRS:
+ for split in mTEDx.SPLITS:
+ src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv"
+ desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
+ if not desc_path.is_symlink():
+ os.symlink(src_path, desc_path)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-root", "-d", required=True, type=str)
+ parser.add_argument(
+ "--vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+ )
+ parser.add_argument("--vocab-size", default=8000, type=int)
+ parser.add_argument("--task", type=str, choices=["asr", "st"])
+ parser.add_argument("--joint", action="store_true", help="")
+ parser.add_argument("--use-audio-input", action="store_true")
+ args = parser.parse_args()
+
+ if args.joint:
+ process_joint(args)
+ else:
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_to_text/prep_mustc_data.py b/fairseq/examples/speech_to_text/prep_mustc_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0d3fcbd9437999f86d5a39e3d18ba9669f5894
--- /dev/null
+++ b/fairseq/examples/speech_to_text/prep_mustc_data.py
@@ -0,0 +1,291 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
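+#
+# Example usage (illustrative paths; see the argparse options defined below):
+#   python examples/speech_to_text/prep_mustc_data.py \
+#       --data-root ${MUSTC_ROOT} --task asr --vocab-type unigram --vocab-size 8000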
+
+import argparse
+import logging
+import os
+from pathlib import Path
+import shutil
+from itertools import groupby
+from tempfile import NamedTemporaryFile
+from typing import Tuple
+
+import numpy as np
+import pandas as pd
+import soundfile as sf
+from examples.speech_to_text.data_utils import (
+ create_zip,
+ extract_fbank_features,
+ filter_manifest_df,
+ gen_config_yaml,
+ gen_vocab,
+ get_zip_manifest,
+ load_df_from_tsv,
+ save_df_to_tsv,
+ cal_gcmvn_stats,
+)
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from fairseq.data.audio.audio_utils import get_waveform, convert_waveform
+
+
+log = logging.getLogger(__name__)
+
+
+MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+
+
+class MUSTC(Dataset):
+ """
+ Create a Dataset for MuST-C. Each item is a tuple of the form:
+ waveform, sample_rate, source utterance, target utterance, speaker_id,
+ utterance_id
+ """
+
+ SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
+ LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
+
+ def __init__(self, root: str, lang: str, split: str) -> None:
+ assert split in self.SPLITS and lang in self.LANGUAGES
+ _root = Path(root) / f"en-{lang}" / "data" / split
+ wav_root, txt_root = _root / "wav", _root / "txt"
+ assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
+ # Load audio segments
+ try:
+ import yaml
+ except ImportError:
+ print("Please install PyYAML to load the MuST-C YAML files")
+ with open(txt_root / f"{split}.yaml") as f:
+ segments = yaml.load(f, Loader=yaml.BaseLoader)
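+ # Each segment entry carries "wav", "offset", "duration" and "speaker_id" fields (used below).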
+ # Load source and target utterances
+ for _lang in ["en", lang]:
+ with open(txt_root / f"{split}.{_lang}") as f:
+ utterances = [r.strip() for r in f]
+ assert len(segments) == len(utterances)
+ for i, u in enumerate(utterances):
+ segments[i][_lang] = u
+ # Gather info
+ self.data = []
+ for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
+ wav_path = wav_root / wav_filename
+ sample_rate = sf.info(wav_path.as_posix()).samplerate
+ seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
+ for i, segment in enumerate(seg_group):
+ offset = int(float(segment["offset"]) * sample_rate)
+ n_frames = int(float(segment["duration"]) * sample_rate)
+ _id = f"{wav_path.stem}_{i}"
+ self.data.append(
+ (
+ wav_path.as_posix(),
+ offset,
+ n_frames,
+ sample_rate,
+ segment["en"],
+ segment[lang],
+ segment["speaker_id"],
+ _id,
+ )
+ )
+
+ def __getitem__(
+ self, n: int
+ ) -> Tuple[torch.Tensor, int, str, str, str, str]:
+ wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, \
+ utt_id = self.data[n]
+ waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
+ waveform = torch.from_numpy(waveform)
+ return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+def process(args):
+ root = Path(args.data_root).absolute()
+ for lang in MUSTC.LANGUAGES:
+ cur_root = root / f"en-{lang}"
+ if not cur_root.is_dir():
+ print(f"{cur_root.as_posix()} does not exist. Skipped.")
+ continue
+ # Extract features
+ audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
+ audio_root.mkdir(exist_ok=True)
+
+ for split in MUSTC.SPLITS:
+ print(f"Fetching split {split}...")
+ dataset = MUSTC(root.as_posix(), lang, split)
+ if args.use_audio_input:
+ print("Converting audios...")
+ for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+ tgt_sample_rate = 16_000
+ _wavform, _ = convert_waveform(
+ waveform, sample_rate, to_mono=True,
+ to_sample_rate=tgt_sample_rate
+ )
+ sf.write(
+ (audio_root / f"{utt_id}.flac").as_posix(),
+ _wavform.numpy(), tgt_sample_rate
+ )
+ else:
+ print("Extracting log mel filter bank features...")
+ gcmvn_feature_list = []
+ if split == 'train' and args.cmvn_type == "global":
+ print("And estimating cepstral mean and variance stats...")
+
+ for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+ features = extract_fbank_features(
+ waveform, sample_rate, audio_root / f"{utt_id}.npy"
+ )
+ if split == 'train' and args.cmvn_type == "global":
+ if len(gcmvn_feature_list) < args.gcmvn_max_num:
+ gcmvn_feature_list.append(features)
+
+ if split == 'train' and args.cmvn_type == "global":
+ # Estimate and save global cmvn stats
+ stats = cal_gcmvn_stats(gcmvn_feature_list)
+ with open(cur_root / "gcmvn.npz", "wb") as f:
+ np.savez(f, mean=stats["mean"], std=stats["std"])
+
+ # Pack features into ZIP
+ zip_path = cur_root / f"{audio_root.name}.zip"
+ print("ZIPing audios/features...")
+ create_zip(audio_root, zip_path)
+ print("Fetching ZIP manifest...")
+ audio_paths, audio_lengths = get_zip_manifest(zip_path)
+ # Generate TSV manifest
+ print("Generating manifest...")
+ train_text = []
+ for split in MUSTC.SPLITS:
+ is_train_split = split.startswith("train")
+ manifest = {c: [] for c in MANIFEST_COLUMNS}
+ dataset = MUSTC(args.data_root, lang, split)
+ for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
+ manifest["id"].append(utt_id)
+ manifest["audio"].append(audio_paths[utt_id])
+ manifest["n_frames"].append(audio_lengths[utt_id])
+ manifest["tgt_text"].append(
+ src_utt if args.task == "asr" else tgt_utt
+ )
+ manifest["speaker"].append(speaker_id)
+ if is_train_split:
+ train_text.extend(manifest["tgt_text"])
+ df = pd.DataFrame.from_dict(manifest)
+ df = filter_manifest_df(df, is_train_split=is_train_split)
+ save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
+ # Generate vocab
+ v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
+ with NamedTemporaryFile(mode="w") as f:
+ for t in train_text:
+ f.write(t + "\n")
+ gen_vocab(
+ Path(f.name),
+ cur_root / spm_filename_prefix,
+ args.vocab_type,
+ args.vocab_size,
+ )
+ # Generate config YAML
+ if args.use_audio_input:
+ gen_config_yaml(
+ cur_root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{args.task}.yaml",
+ specaugment_policy=None,
+ extra={"use_audio_input": True}
+ )
+ else:
+ gen_config_yaml(
+ cur_root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{args.task}.yaml",
+ specaugment_policy="lb",
+ cmvn_type=args.cmvn_type,
+ gcmvn_path=(
+ cur_root / "gcmvn.npz" if args.cmvn_type == "global"
+ else None
+ ),
+ )
+ # Clean up
+ shutil.rmtree(audio_root)
+
+
+def process_joint(args):
+ cur_root = Path(args.data_root)
+ assert all(
+ (cur_root / f"en-{lang}").is_dir() for lang in MUSTC.LANGUAGES
+ ), "do not have downloaded data available for all 8 languages"
+ # Generate vocab
+ vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+ spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
+ with NamedTemporaryFile(mode="w") as f:
+ for lang in MUSTC.LANGUAGES:
+ tsv_path = cur_root / f"en-{lang}" / f"train_{args.task}.tsv"
+ df = load_df_from_tsv(tsv_path)
+ for t in df["tgt_text"]:
+ f.write(t + "\n")
+ special_symbols = None
+ if args.task == 'st':
+ special_symbols = [f'<lang:{lang}>' for lang in MUSTC.LANGUAGES]
+ gen_vocab(
+ Path(f.name),
+ cur_root / spm_filename_prefix,
+ args.vocab_type,
+ args.vocab_size,
+ special_symbols=special_symbols
+ )
+ # Generate config YAML
+ gen_config_yaml(
+ cur_root,
+ spm_filename=spm_filename_prefix + ".model",
+ yaml_filename=f"config_{args.task}.yaml",
+ specaugment_policy="ld",
+ prepend_tgt_lang_tag=(args.task == "st"),
+ )
+ # Make symbolic links to manifests
+ for lang in MUSTC.LANGUAGES:
+ for split in MUSTC.SPLITS:
+ src_path = cur_root / f"en-{lang}" / f"{split}_{args.task}.tsv"
+ desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
+ if not desc_path.is_symlink():
+ os.symlink(src_path, desc_path)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-root", "-d", required=True, type=str)
+ parser.add_argument(
+ "--vocab-type",
+ default="unigram",
+ required=True,
+ type=str,
+ choices=["bpe", "unigram", "char"],
+ )
+ parser.add_argument("--vocab-size", default=8000, type=int)
+ parser.add_argument("--task", type=str, choices=["asr", "st"])
+ parser.add_argument("--joint", action="store_true", help="")
+ parser.add_argument(
+ "--cmvn-type", default="utterance",
+ choices=["global", "utterance"],
+ help="The type of cepstral mean and variance normalization"
+ )
+ parser.add_argument(
+ "--gcmvn-max-num", default=150000, type=int,
+ help="Maximum number of sentences to use to estimate global mean and "
+ "variance"
+ )
+ parser.add_argument("--use-audio-input", action="store_true")
+ args = parser.parse_args()
+
+ if args.joint:
+ process_joint(args)
+ else:
+ process(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fairseq/examples/speech_to_text/seg_mustc_data.py b/fairseq/examples/speech_to_text/seg_mustc_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee665d6399729afe17d790d872eff34de124900
--- /dev/null
+++ b/fairseq/examples/speech_to_text/seg_mustc_data.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
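+#
+# Example usage (illustrative paths):
+#   python examples/speech_to_text/seg_mustc_data.py \
+#       --data-root ${MUSTC_ROOT} --task st --lang de --split tst-COMMON --output ${SEG_OUT}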
+
+import argparse
+import logging
+from pathlib import Path
+import soundfile as sf
+from examples.speech_to_text.prep_mustc_data import (
+ MUSTC
+)
+
+from tqdm import tqdm
+
+log = logging.getLogger(__name__)
+
+
+def main(args):
+ root = Path(args.data_root).absolute()
+ lang = args.lang
+ split = args.split
+
+ cur_root = root / f"en-{lang}"
+ assert cur_root.is_dir(), (
+ f"{cur_root.as_posix()} does not exist. Skipped."
+ )
+
+ dataset = MUSTC(root.as_posix(), lang, split)
+ output = Path(args.output).absolute()
+ output.mkdir(exist_ok=True)
+ f_text = open(output / f"{split}.{lang}", "w")
+ f_wav_list = open(output / f"{split}.wav_list", "w")
+ for waveform, sample_rate, _, text, _, utt_id in tqdm(dataset):
+ sf.write(
+ output / f"{utt_id}.wav",
+ waveform.squeeze(0).numpy(),
+ samplerate=int(sample_rate)
+ )
+ f_text.write(text + "\n")
+ f_wav_list.write(str(output / f"{utt_id}.wav") + "\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-root", "-d", required=True, type=str)
+ parser.add_argument("--task", required=True, type=str, choices=["asr", "st"])
+ parser.add_argument("--lang", required=True, type=str)
+ parser.add_argument("--output", required=True, type=str)
+ parser.add_argument("--split", required=True, choices=MUSTC.SPLITS)
+ args = parser.parse_args()
+
+ main(args)
diff --git a/fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..61617a1739ce196abba1e9a6f9ad9e9f4b37b9c1
--- /dev/null
+++ b/fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py
@@ -0,0 +1,363 @@
+import math
+import os
+import json
+import numpy as np
+import torch
+import torchaudio.compliance.kaldi as kaldi
+import yaml
+from fairseq import checkpoint_utils, tasks
+from fairseq.file_io import PathManager
+
+try:
+ from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
+ from simuleval.agents import SpeechAgent
+ from simuleval.states import ListEntry, SpeechStates
+except ImportError:
+ print("Please install simuleval 'pip install simuleval'")
+
+SHIFT_SIZE = 10
+WINDOW_SIZE = 25
+SAMPLE_RATE = 16000
+FEATURE_DIM = 80
+BOW_PREFIX = "\u2581"
+
+
+class OnlineFeatureExtractor:
+ """
+ Extract speech features on the fly.
+ """
+
+ def __init__(self, args):
+ self.shift_size = args.shift_size
+ self.window_size = args.window_size
+ assert self.window_size >= self.shift_size
+
+ self.sample_rate = args.sample_rate
+ self.feature_dim = args.feature_dim
+ self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
+ self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
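+ # With the defaults (10 ms shift, 25 ms window, 16 kHz audio) this gives
+ # 160 samples per shift and 400 samples per window.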
+ self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
+ self.previous_residual_samples = []
+ self.global_cmvn = args.global_cmvn
+
+ def clear_cache(self):
+ self.previous_residual_samples = []
+
+ def __call__(self, new_samples):
+ samples = self.previous_residual_samples + new_samples
+ if len(samples) < self.num_samples_per_window:
+ self.previous_residual_samples = samples
+ return
+
+ # num_frames is the number of frames from the new segment
+ num_frames = math.floor(
+ (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size))
+ / self.num_samples_per_shift
+ )
+
+ # the number of frames used for feature extraction
+ # including some part of the previous segment
+ effective_num_samples = int(
+ num_frames * self.len_ms_to_samples(self.shift_size)
+ + self.len_ms_to_samples(self.window_size - self.shift_size)
+ )
+
+ input_samples = samples[:effective_num_samples]
+ self.previous_residual_samples = samples[
+ num_frames * self.num_samples_per_shift:
+ ]
+
+ torch.manual_seed(1)
+ output = kaldi.fbank(
+ torch.FloatTensor(input_samples).unsqueeze(0),
+ num_mel_bins=self.feature_dim,
+ frame_length=self.window_size,
+ frame_shift=self.shift_size,
+ ).numpy()
+
+ output = self.transform(output)
+
+ return torch.from_numpy(output)
+
+ def transform(self, input):
+ if self.global_cmvn is None:
+ return input
+
+ mean = self.global_cmvn["mean"]
+ std = self.global_cmvn["std"]
+
+ x = np.subtract(input, mean)
+ x = np.divide(x, std)
+ return x
+
+
+class TensorListEntry(ListEntry):
+ """
+ Data structure to store a list of tensor.
+ """
+
+ def append(self, value):
+
+ if len(self.value) == 0:
+ self.value = value
+ return
+
+ self.value = torch.cat([self.value] + [value], dim=0)
+
+ def info(self):
+ return {
+ "type": str(self.new_value_type),
+ "length": self.__len__(),
+ "value": "" if type(self.value) is list else self.value.size(),
+ }
+
+
+class FairseqSimulSTAgent(SpeechAgent):
+
+ speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size
+
+ def __init__(self, args):
+ super().__init__(args)
+
+ self.eos = DEFAULT_EOS
+
+ self.gpu = getattr(args, "gpu", False)
+
+ self.args = args
+
+ self.load_model_vocab(args)
+
+ if getattr(
+ self.model.decoder.layers[0].encoder_attn,
+ 'pre_decision_ratio',
+ None
+ ) is not None:
+ self.speech_segment_size *= (
+ self.model.decoder.layers[0].encoder_attn.pre_decision_ratio
+ )
+
+ args.global_cmvn = None
+ if args.config:
+ with open(os.path.join(args.data_bin, args.config), "r") as f:
+ config = yaml.load(f, Loader=yaml.BaseLoader)
+
+ if "global_cmvn" in config:
+ args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])
+
+ if args.global_stats:
+ with PathManager.open(args.global_stats, "r") as f:
+ global_cmvn = json.loads(f.read())
+ self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]}
+
+ self.feature_extractor = OnlineFeatureExtractor(args)
+
+ self.max_len = args.max_len
+
+ self.force_finish = args.force_finish
+
+ torch.set_grad_enabled(False)
+
+ def build_states(self, args, client, sentence_id):
+ # Initialize states here, for example add customized entry to states
+ # This function will be called at beginning of every new sentence
+ states = SpeechStates(args, client, sentence_id, self)
+ self.initialize_states(states)
+ return states
+
+ def to_device(self, tensor):
+ if self.gpu:
+ return tensor.cuda()
+ else:
+ return tensor.cpu()
+
+ @staticmethod
+ def add_args(parser):
+ # fmt: off
+ parser.add_argument('--model-path', type=str, required=True,
+ help='path to your pretrained model.')
+ parser.add_argument("--data-bin", type=str, required=True,
+ help="Path of data binary")
+ parser.add_argument("--config", type=str, default=None,
+ help="Path to config yaml file")
+ parser.add_argument("--global-stats", type=str, default=None,
+ help="Path to json file containing cmvn stats")
+ parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
+ help="Subword splitter type for target text")
+ parser.add_argument("--tgt-splitter-path", type=str, default=None,
+ help="Subword splitter model path for target text")
+ parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation",
+ help="User directory for simultaneous translation")
+ parser.add_argument("--max-len", type=int, default=200,
+ help="Max length of translation")
+ parser.add_argument("--force-finish", default=False, action="store_true",
+ help="Force the model to finish the hypothsis if the source is not finished")
+ parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE,
+ help="Shift size of feature extraction window.")
+ parser.add_argument("--window-size", type=int, default=WINDOW_SIZE,
+ help="Window size of feature extraction window.")
+ parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE,
+ help="Sample rate")
+ parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM,
+ help="Acoustic feature dimension.")
+
+ # fmt: on
+ return parser
+
+ def load_model_vocab(self, args):
+
+ filename = args.model_path
+ if not os.path.exists(filename):
+ raise IOError("Model file not found: {}".format(filename))
+
+ state = checkpoint_utils.load_checkpoint_to_cpu(filename)
+
+ task_args = state["cfg"]["task"]
+ task_args.data = args.data_bin
+
+ if args.config is not None:
+ task_args.config_yaml = args.config
+
+ task = tasks.setup_task(task_args)
+
+ # build model for ensemble
+ state["cfg"]["model"].load_pretrained_encoder_from = None
+ state["cfg"]["model"].load_pretrained_decoder_from = None
+ self.model = task.build_model(state["cfg"]["model"])
+ self.model.load_state_dict(state["model"], strict=True)
+ self.model.eval()
+ self.model.share_memory()
+
+ if self.gpu:
+ self.model.cuda()
+
+ # Set dictionary
+ self.dict = {}
+ self.dict["tgt"] = task.target_dictionary
+
+ def initialize_states(self, states):
+ self.feature_extractor.clear_cache()
+ states.units.source = TensorListEntry()
+ states.units.target = ListEntry()
+ states.incremental_states = dict()
+
+ def segment_to_units(self, segment, states):
+ # Convert speech samples to features
+ features = self.feature_extractor(segment)
+ if features is not None:
+ return [features]
+ else:
+ return []
+
+ def units_to_segment(self, units, states):
+ # Merge sub word to full word.
+ if self.model.decoder.dictionary.eos() == units[0]:
+ return DEFAULT_EOS
+
+ segment = []
+ if None in units.value:
+ units.value.remove(None)
+
+ for index in units:
+ if index is None:
+ units.pop()
+ token = self.model.decoder.dictionary.string([index])
+ if token.startswith(BOW_PREFIX):
+ if len(segment) == 0:
+ segment += [token.replace(BOW_PREFIX, "")]
+ else:
+ for j in range(len(segment)):
+ units.pop()
+
+ string_to_return = ["".join(segment)]
+
+ if self.model.decoder.dictionary.eos() == units[0]:
+ string_to_return += [DEFAULT_EOS]
+
+ return string_to_return
+ else:
+ segment += [token.replace(BOW_PREFIX, "")]
+
+ if (
+ len(units) > 0
+ and self.model.decoder.dictionary.eos() == units[-1]
+ or len(states.units.target) > self.max_len
+ ):
+ tokens = [self.model.decoder.dictionary.string([unit]) for unit in units]
+ return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS]
+
+ return None
+
+ def update_model_encoder(self, states):
+ if len(states.units.source) == 0:
+ return
+ src_indices = self.to_device(
+ states.units.source.value.unsqueeze(0)
+ )
+ src_lengths = self.to_device(
+ torch.LongTensor([states.units.source.value.size(0)])
+ )
+
+ states.encoder_states = self.model.encoder(src_indices, src_lengths)
+ torch.cuda.empty_cache()
+
+ def update_states_read(self, states):
+ # Happens after a read action.
+ self.update_model_encoder(states)
+
+ def policy(self, states):
+ if not getattr(states, "encoder_states", None):
+ return READ_ACTION
+
+ tgt_indices = self.to_device(
+ torch.LongTensor(
+ [self.model.decoder.dictionary.eos()]
+ + [x for x in states.units.target.value if x is not None]
+ ).unsqueeze(0)
+ )
+
+ states.incremental_states["steps"] = {
+ "src": states.encoder_states["encoder_out"][0].size(0),
+ "tgt": 1 + len(states.units.target),
+ }
+
+ states.incremental_states["online"] = {"only": torch.tensor(not states.finish_read())}
+
+ x, outputs = self.model.decoder.forward(
+ prev_output_tokens=tgt_indices,
+ encoder_out=states.encoder_states,
+ incremental_state=states.incremental_states,
+ )
+
+ states.decoder_out = x
+
+ states.decoder_out_extra = outputs
+
+ torch.cuda.empty_cache()
+
+ if outputs.action == 0:
+ return READ_ACTION
+ else:
+ return WRITE_ACTION
+
+ def predict(self, states):
+ decoder_states = states.decoder_out
+
+ lprobs = self.model.get_normalized_probs(
+ [decoder_states[:, -1:]], log_probs=True
+ )
+
+ index = lprobs.argmax(dim=-1)
+
+ index = index[0, 0].item()
+
+ if (
+ self.force_finish
+ and index == self.model.decoder.dictionary.eos()
+ and not states.finish_read()
+ ):
+ # If we want to force-finish the translation
+ # (i.e. not stop before the source is fully read), return None
+ # self.model.decoder.clear_cache(states.incremental_states)
+ index = None
+
+ return index
diff --git a/fairseq/examples/stories/README.md b/fairseq/examples/stories/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..588941eddc5f0280f5254affd40ef49de874c885
--- /dev/null
+++ b/fairseq/examples/stories/README.md
@@ -0,0 +1,66 @@
+# Hierarchical Neural Story Generation (Fan et al., 2018)
+
+The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.
+
+## Pre-trained models
+
+Description | Dataset | Model | Test set(s)
+---|---|---|---
+Stories with Convolutional Model ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
+
+We provide sample stories generated by the [convolutional seq2seq model](https://dl.fbaipublicfiles.com/fairseq/data/seq2seq_stories.txt) and [fusion model](https://dl.fbaipublicfiles.com/fairseq/data/fusion_stories.txt) from [Fan et al., 2018](https://arxiv.org/abs/1805.04833). The corresponding prompts for the fusion model can be found [here](https://dl.fbaipublicfiles.com/fairseq/data/fusion_prompts.txt). Note that the files contain unknown-word (unk) tokens, as we modeled a small full vocabulary (no BPE or pre-training). We did not use these unk-containing prompts for human evaluation.
+
+## Dataset
+
+The dataset can be downloaded like this:
+
+```bash
+cd examples/stories
+curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf -
+```
+
+and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token.
+
+## Example usage
+
+First we will preprocess the dataset. Note that the dataset release is the full data, but the paper models the first 1000 words of each story. Here is example code that trims the dataset to the first 1000 words of each story:
+```python
+data = ["train", "test", "valid"]
+for name in data:
+ with open(name + ".wp_target") as f:
+ stories = f.readlines()
+ stories = [" ".join(i.split()[0:1000]) for i in stories]
+ with open(name + ".wp_target", "w") as o:
+ for line in stories:
+ o.write(line.strip() + "\n")
+```
+
+Once we've trimmed the data we can binarize it and train our model:
+```bash
+# Binarize the dataset:
+export TEXT=examples/stories/writingPrompts
+fairseq-preprocess --source-lang wp_source --target-lang wp_target \
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
+ --destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10
+
+# Train the model:
+fairseq-train data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --optimizer nag --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False
+
+# Train a fusion model:
+# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
+
+# Generate:
+# Note: to load the pretrained model at generation time, you need to pass in a model-override argument to tell the fusion model where you have placed the pretrained checkpoint. By default, it will load the exact path of the fusion model's pretrained model from training time. You should use model-override if you have moved the pretrained model (or are using our provided models). If you are generating from a non-fusion model, the model-override argument is not necessary.
+
+fairseq-generate data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
+```
+
+## Citation
+```bibtex
+@inproceedings{fan2018hierarchical,
+ title = {Hierarchical Neural Story Generation},
+ author = {Fan, Angela and Lewis, Mike and Dauphin, Yann},
+ booktitle = {Conference of the Association for Computational Linguistics (ACL)},
+ year = 2018,
+}
+```
diff --git a/fairseq/examples/textless_nlp/gslm/README.md b/fairseq/examples/textless_nlp/gslm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a76ffd57c066c20af94aa3fca24c18e2ba4c3dd
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/README.md
@@ -0,0 +1,21 @@
+# Generative Spoken Language Modeling
+
+* [Paper](https://arxiv.org/abs/2102.01192)
+* [Demo](https://speechbot.github.io/gslm/index.html)
+
+We build and evaluate generative speech2speech systems using [Log Mel Filterbank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/main/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec). Our system is composed of three components, namely, *speech2unit*, *ulm* and *unit2speech*. We describe the models and usage of each component in their respective sub-directories. See the links below.
+
+## Speech to Unit Model (speech2unit)
+Speech to unit model is used for quantizing raw speech into learned discrete speech units. [More details](speech2unit)
+
+## Unit Language Model (ulm)
+Unit Language Model is a generative language model trained on discrete speech units. [More details](ulm)
+
+## Unit to Speech Model (unit2speech)
+Unit to speech model is used for synthesizing speech from discrete speech units. [More details](unit2speech)
+
+## Metrics
+We show how to compute ASR based metrics as well as zero-shot metrics proposed in our paper [here](metrics).
+
+## Tools
+We share two tools: one to resynthesize a given spoken utterance, and one to generate novel spoken language given a spoken prompt. [More details](tools)
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/README.md b/fairseq/examples/textless_nlp/gslm/metrics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0a63e2f0d844ce157f9502c82738aac2a0de3f0c
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/README.md
@@ -0,0 +1,10 @@
+# GSLM Metrics
+
+## ASR Metrics
+The suite of metrics here uses an ASR model to transcribe the synthesized speech into text, and then applies text-based metrics. We also use the word error rate of the ASR transcription itself as one of the metrics. [More details](asr_metrics)
+
+## ABX Metrics
+We use [ABX](https://www.semanticscholar.org/paper/ABX-Discriminability-Measures-and-Applications-Schatz/13d3537228f728c1063cc83743cb118bba3367a0) to evaluate how well-separated phonetic categories are with quantized representations. [More details](abx_metrics)
+
+## sWUGGY and sBLIMP
+We refer to [ZeroSpeech challenge](https://www.zerospeech.com/2021/track_s.html#scoring-based-metrics) for details on the sWUGGY and sBLIMP metrics.
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/README.md b/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..aa2560f0453403fb5846c387848c78b037c79cb2
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/README.md
@@ -0,0 +1,77 @@
+# ABX-based evaluation
+
+ABX is used to evaluate the quality of the obtained discrete units.
+
+The life cycle of the ABX-based evaluation for the Speech-to-Unit model contains the following steps:
+1. Training an acoustic model (or using an existing acoustic model) ([description](./../..))
+2. Performing quantization of speech by learning a K-means clustering model ([description](./../..))
+3. Computing discrete features for ABX computation using the learned clusters
+4. Computing the ABX score over the discrete features, taking advantage of [libri-light's ABX evaluation script][ll-abx]
+
+Here we assume that you have already gone through the first two steps and focus solely on extracting features and computing ABX scores.
+
+## Libri-light setup
+
+Follow [libri-light's instructions][ll-instructions] for installation and [ABX evaluation setup][ll-abx] (including the download of the data items required for ABX computation).
+
+## Computing ABX
+
+### Dumping quantized features
+
+The first step for the ABX computation is to dump the quantized representations corresponding to the test files.
+
+```shell
+TYPE="hubert"
+LAYER=6
+CKPT_PATH=""
+KM_MODEL_PATH=""
+
+SUBSET="dev-clean"
+MANIFEST=""
+DATA_DIR="/$SUBSET"
+
+PYTHONPATH=. python examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py \
+ --feature_type $TYPE \
+ --kmeans_model_path $KM_MODEL_PATH \
+ --checkpoint_path $CKPT_PATH \
+ --layer $LAYER \
+ --manifest_path $MANIFEST \
+ --out_dir_path $DATA_DIR \
+ --extension ".flac"
+```
+
+Again, the manifest file follows the same structure as elsewhere in the codebase.
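+
+For reference, here is a minimal sketch (in Python, with a hypothetical manifest name) of the layout assumed here: the first line is the audio root directory, and every following line is a tab-separated relative path and frame count.
+```python
+with open("dev-clean.tsv") as f:
+    root = f.readline().strip()
+    for line in f:
+        rel_path, n_frames = line.strip().split("\t")
+        print(root, rel_path, int(n_frames))
+        break
+```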
+
+### Compute ABX with Libri-light
+
+Use libri-light's `eval_ABX.py` script (within the appropriate environment) as follows:
+
+```shell
+LIBRILIGHT_ROOT=""
+
+SUBSET="dev-clean"
+DATA_DIR="/$SUBSET"
+ITEM_FILE_PATH="$LIBRILIGHT_ROOT/eval/ABX_data/$SUBSET.item"
+OUT_DIR="/$SUBSET"
+
+FILE_EXTENSION=".npy"
+FEATURE_SIZE=0.02 # depends on the model used
+
+PYTHONPATH=$LIBRILIGHT_ROOT \
+ python $LIBRILIGHT_ROOT/eval/eval_ABX.py \
+ $DATA_DIR \
+ $ITEM_FILE_PATH \
+ --file_extension $FILE_EXTENSION \
+ --feature_size $FEATURE_SIZE \
+ --out $OUT_DIR \
+ --mode "all"
+```
+
+Note that `FEATURE_SIZE` will depend on the model type you are using to extract the acoustic features:
+* For HuBERT and Wav2Vec2.0, use `FEATURE_SIZE=0.02`
+* For CPC and Log Mel, use `FEATURE_SIZE=0.01`
+
+If you have a GPU available, make sure you add the `--cuda` flag for faster computation.
+
+[ll-instructions]: https://github.com/facebookresearch/libri-light
+[ll-abx]: https://github.com/facebookresearch/libri-light/tree/master/eval#abx
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py b/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..41cf558970608fa5a9241e91e59ba214b609dc73
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+
+import joblib
+import numpy as np
+
+from examples.textless_nlp.gslm.speech2unit.clustering.utils import get_audio_files
+from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_features
+
+def get_logger():
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+ logging.basicConfig(format=log_format, level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ return logger
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Quantize using K-means clustering over acoustic features."
+ )
+ parser.add_argument(
+ "--feature_type",
+ type=str,
+ choices=["logmel", "hubert", "w2v2", "cpc"],
+ default=None,
+ required=True,
+ help="Acoustic feature type",
+ )
+ parser.add_argument(
+ "--kmeans_model_path",
+ type=str,
+ required=True,
+ help="K-means model file path to use for inference",
+ )
+ parser.add_argument(
+ "--manifest_path",
+ type=str,
+ default=None,
+ help="Manifest file containing the root dir and file names",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ help="Pretrained model checkpoint",
+ )
+ parser.add_argument(
+ "--layer",
+ type=int,
+ help="The layer of the pretrained model to extract features from",
+ default=-1,
+ )
+ parser.add_argument(
+ "--out_dir_path",
+ required=True,
+ type=str,
+ help="File path of quantized output.",
+ )
+ parser.add_argument(
+ "--extension", type=str, default=".flac", help="Features file path"
+ )
+ return parser
+
+
+def one_hot(feat, n_clusters):
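+ # Map integer cluster ids to rows of an identity matrix, i.e. one-hot vectors of size n_clusters.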
+ return np.eye(n_clusters)[feat]
+
+def main(args, logger):
+ # Feature extraction
+ logger.info(f"Extracting {args.feature_type} acoustic features...")
+ features_batch = get_features(
+ feature_type=args.feature_type,
+ checkpoint_path=args.checkpoint_path,
+ layer=args.layer,
+ manifest_path=args.manifest_path,
+ sample_pct=1.0,
+ flatten=False,
+ )
+ logger.info(f"Features extracted for {len(features_batch)} utterances.\n")
+ logger.info(f"Dimensionality of representation = {features_batch[0].shape[1]}")
+
+ logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
+ kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
+ kmeans_model.verbose = False
+
+ _, fnames, _ = get_audio_files(args.manifest_path)
+
+ os.makedirs(args.out_dir_path, exist_ok=True)
+ logger.info(f"Writing quantized features to {args.out_dir_path}")
+ for i, feats in enumerate(features_batch):
+ pred = kmeans_model.predict(feats)
+ emb = one_hot(pred, kmeans_model.n_clusters)
+ base_fname = os.path.basename(fnames[i])
+ # Strip the extension as a suffix (str.rstrip would strip a character set).
+ if base_fname.endswith(args.extension):
+ base_fname = base_fname[: -len(args.extension)]
+ output_path = os.path.join(args.out_dir_path, f"{base_fname}.npy")
+ with open(output_path, "wb") as f:
+ np.save(f, emb)
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ logger = get_logger()
+ logger.info(args)
+ main(args, logger)
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..90741f42b0b070f2a91b63c8badb817c6aa24230
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md
@@ -0,0 +1,87 @@
+# ASR-based evaluation
+
+Overall, the life cycle of the ASR-based evaluation for a ULM contains the following steps:
+ 1. Training a ULM and sampling from it [[description]](./../../ulm)
+ 2. Running UTS on the sampled unit sequences [[description]](./../../unit2speech)
+ 3. Pre-processing for the ASR (down-sampling to 16 kHz, aligning the length of the generated audio with the ground-truth utterances)
+ 4. Running ASR
+ 5. Calculating the post-ASR evaluation metrics
+
+Here we assume that you have already gone through the first two steps and focus on the rest.
+
+## Preprocessing
+### Down-sampling to 16KHz
+The bulk conversion can be done by running
+```bash
+ python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py $UTS_OUTPUT $UTS_OUTPUT_DOWNSAMPLE
+ ```
+ where `$UTS_OUTPUT` specifies the directory with the generated audio and `$UTS_OUTPUT_DOWNSAMPLE` is the directory where the downsampled audio will be saved.
+
+ ### Matching by length
+This step is somewhat optional. However, if you want to compare the fluency and diversity of a generated speech utterance to that of the ground-truth speech with the same prefix, it is a good idea to force them to be of the same length.
+```bash
+python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py \
+ --samples_dir=$UTS_OUTPUT_DOWNSAMPLE --out_dir=$UTS_OUTPUT_DOWNSAMPLE_CUT \
+ --prompts_description=data/ground_truth_continuation_dev.json
+```
+
+Here `ground_truth_continuation_dev.json` is a json file with ground-truth text from LibriSpeech dev-clean, associated with some meta-data (assuming the evaluation is done on dev-clean). This file can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_dev.json). A similar file for the test-clean is [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_test.json). These files are used for the evaluation and contain texts for audio sequences that are at least 6s long.
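+
+Based on how the metric scripts below read this file, each entry maps a sequence name to a list whose first element is a duration in seconds (at least 6.0) and whose second element is the ground-truth continuation text. A minimal inspection sketch (using the file downloaded above):
+```python
+import json
+
+with open("data/ground_truth_continuation_dev.json") as fin:
+    gt = json.load(fin)
+
+name, entry = next(iter(gt.items()))
+print(name, float(entry[0]), entry[1])
+```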
+
+## Running ASR
+We use a pre-trained wav2vec model to run the ASR step. We first need to prepare manifest files which, roughly, tell the ASR system which files we want to transcribe. You can find more details and download the `960h_scratch.pt` checkpoint
+[[here]](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md). To run ASR, you will also need to
+install KenLM, the Flashlight decoder, and the KenLM 4-gram English language model.
+
+```bash
+ python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py \
+ $UTS_OUTPUT_DOWNSAMPLE_CUT --valid-percent 0.0 --dest $MANIFEST_DIR --ext wav
+```
+where `$UTS_OUTPUT_DOWNSAMPLE_CUT` specifies the directory with the preprocessed UTS outputs and `$MANIFEST_DIR` is the output directory.
+
+We will be running an out-of-the-box evaluation script which requires ground-truth transcripts to measure quality metrics. We are only
+interested in the transcripts themselves (and we have no ground-truth references for what our ULM generated!), hence we will just generate
+some dummy transcripts instead:
+```bash
+cp $FAIRSEQ_ROOT/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt $MANIFEST_DIR
+python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dummy_asr_data.py --tsv=$MANIFEST_DIR/train.tsv \
+ --output-dir=$MANIFEST_DIR
+```
+
+Now we are ready for running ASR:
+```
+mkdir -p asr
+python $FAIRSEQ_ROOT/examples/speech_recognition/infer.py \
+ $MANIFEST_DIR \
+ --task audio_pretraining --nbest 1 --path 960h_scratch.pt \
+ --gen-subset=train --results-path $PATH_TO_ASR_OUTPUT \
+ --w2l-decoder kenlm --lm-model 4-gram.bin \
+ --lexicon librispeech/lexicon_ltr.lst --word-score -1 \
+ --sil-weight 0 --lm-weight 2 --criterion ctc --labels ltr --max-tokens 300000 --remove-bpe letter
+```
+where `lexicon_ltr.lst` is the LibriSpeech lexicon (it can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/lexicon_ltr.lst)) and `$PATH_TO_ASR_OUTPUT` is the output directory.
+
+## Evaluation metrics
+We run evaluation on the 1_000 shortest sequences that are at least 6s long. To filter those from the ASR transcript, we additionally provide each metric script with the paths to the manifest and `ground_truth_continuation_*` files.
+
+### Perplexity (PPX)
+To get a PPX metric estimate on an ASR transcript, you need to run the following command:
+```bash
+python ppx.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \
+ --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
+```
+where `--cut-tail` tells the script to ignore the last token on each line (ASR puts the sequence ID there).
+
+### Self- and Auto-BLEU
+```bash
+python self_bleu.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \
+ --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
+```
+
+### Continuation-BLEU
+```bash
+python continuation_eval.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt \
+ --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json
+```
+
+### AUC
+Based on the metrics calculated above, we can estimate the AUC of the perplexity/diversity trade-off. We provide an illustration in a [Colab notebook](https://colab.research.google.com/drive/1pVPfOVax_PU3MkYdHRSsa-SI8GBUldNt?usp=sharing).
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b92a341dcd1b82035af72b8a6b4edc65783ecc
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py
@@ -0,0 +1,99 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from collections import defaultdict
+import numpy as np
+from misc.bleu_utils import sentence_bleu
+import json
+import warnings
+
+
+def get_args():
+ import argparse
+
+ parser = argparse.ArgumentParser("Tool to calculate Continuation-BLEU2")
+ parser.add_argument('--asr-transcript', type=str,
+ help='Path to the transcript file.')
+ parser.add_argument('--prompts-description', type=str,
+ help='Path to the ground-truth continuation')
+ parser.add_argument('--manifest', type=str, required=True)
+ parser.add_argument('--take-shortest', type=int, default=1000)
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+ # NLTK produces warnings
+ warnings.filterwarnings("ignore")
+
+ args = get_args()
+
+ with open(args.prompts_description, 'r') as fin:
+ original_continuations = json.loads(fin.read())
+
+ sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
+ assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds
+
+ sequence2length.sort(key=lambda x: x[1])
+ to_take = set(v[0] for v in sequence2length[:args.take_shortest])
+
+ with open(args.manifest, 'r') as fin:
+ fin.readline()
+
+ linenum2file = dict([
+ (i, l.split("__")[0]) for (i, l) in enumerate(fin)
+ ])
+
+ max_files = max(linenum2file.keys())
+ continuations = defaultdict(list)
+
+ mean_length_after = 0
+ n_examples = 0
+
+ with open(args.asr_transcript, 'r') as fin:
+ for line in fin:
+ n_examples += 1
+ line = line.split()
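+ # The last token of each hypothesis line wraps the utterance index,
+ # e.g. "(...-<index>)"; strip the closing parenthesis and keep the number.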
+ sequence_id = int(line[-1].split('-')[1][:-1])
+
+ assert sequence_id <= max_files
+
+ sequence_name = linenum2file[sequence_id]
+
+ continuations[sequence_name].append(line[:-1])
+ mean_length_after += len(line)
+
+ mean_length_after /= n_examples
+ print(f'Mean length of continuations, in words: {mean_length_after}')
+ metric_values = []
+
+ mean_ground_truth_words = 0
+ n_examples = 0
+ n_candidates = 0
+
+ for k, candidates in continuations.items():
+ if k not in to_take:
+ continue
+
+ n_examples += 1
+
+ ground_truth = original_continuations[k][1].split()
+ n_candidates += len(candidates)
+ bleu = sentence_bleu(candidates, ground_truth, weights=(
+ 0.5, 0.5), no_length_penalty=True, averaging_mode="geometric")
+ mean_ground_truth_words += len(ground_truth)
+
+ metric_values.append(bleu)
+
+ n = len(metric_values)
+ print(
+ f'Median BLEU over {n} examples: {np.median(metric_values)} +- {np.std(metric_values) / np.sqrt(n)}')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..75cc5272d367c4f3be98d698b512a529bdb2e4f5
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py
@@ -0,0 +1,166 @@
+"""
+
+TODO: this code is taken from the Apache-2.0 licensed NLTK: make sure we do this properly!
+
+
+Copied over from nltk.translate.bleu_score. This code has two major changes:
+ - it allows turning off the length/brevity penalty, which makes no sense for self-BLEU,
+ - it allows using the arithmetic instead of the geometric mean
+"""
+
+import math
+import sys
+from fractions import Fraction
+import warnings
+from collections import Counter
+from nltk.translate.bleu_score import modified_precision, closest_ref_length, brevity_penalty, SmoothingFunction
+
+
+def corpus_bleu(
+ list_of_references,
+ hypotheses,
+ weights=(0.25, 0.25, 0.25, 0.25),
+ smoothing_function=None,
+ auto_reweigh=False,
+ averaging_mode="geometric",
+ no_length_penalty=False
+):
+ """
+ Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
+ the hypotheses and their respective references.
+
+ Instead of averaging the sentence-level BLEU scores (i.e. macro-average
+ precision), the original BLEU metric (Papineni et al. 2002) accounts for
+ the micro-average precision (i.e. summing the numerators and denominators
+ for each hypothesis-reference(s) pairs before the division).
+
+ >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
+ ... 'ensures', 'that', 'the', 'military', 'always',
+ ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
+ >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
+ ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
+ ... 'heed', 'Party', 'commands']
+ >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
+ ... 'guarantees', 'the', 'military', 'forces', 'always',
+ ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
+ >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
+ ... 'army', 'always', 'to', 'heed', 'the', 'directions',
+ ... 'of', 'the', 'party']
+
+ >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
+ ... 'interested', 'in', 'world', 'history']
+ >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
+ ... 'because', 'he', 'read', 'the', 'book']
+
+ >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
+ >>> hypotheses = [hyp1, hyp2]
+ >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
+ 0.5920...
+
+ The example below shows that corpus_bleu() is different from averaging
+ sentence_bleu() for hypotheses
+
+ >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
+ >>> score2 = sentence_bleu([ref2a], hyp2)
+ >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
+ 0.6223...
+
+ :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses
+ :type list_of_references: list(list(list(str)))
+ :param hypotheses: a list of hypothesis sentences
+ :type hypotheses: list(list(str))
+ :param weights: weights for unigrams, bigrams, trigrams and so on
+ :type weights: list(float)
+ :param smoothing_function:
+ :type smoothing_function: SmoothingFunction
+ :param auto_reweigh: Option to re-normalize the weights uniformly.
+ :type auto_reweigh: bool
+ :return: The corpus-level BLEU score.
+ :rtype: float
+ """
+ # Before proceeding to compute BLEU, perform sanity checks.
+
+ p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
+ p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref.
+ hyp_lengths, ref_lengths = 0, 0
+
+ assert len(list_of_references) == len(hypotheses), (
+ "The number of hypotheses and their reference(s) should be the " "same "
+ )
+
+ # Iterate through each hypothesis and their corresponding references.
+ for references, hypothesis in zip(list_of_references, hypotheses):
+ # For each order of ngram, calculate the numerator and
+ # denominator for the corpus-level modified precision.
+ for i, _ in enumerate(weights, start=1):
+ p_i = modified_precision(references, hypothesis, i)
+ p_numerators[i] += p_i.numerator
+ p_denominators[i] += p_i.denominator
+
+ # Calculate the hypothesis length and the closest reference length.
+ # Adds them to the corpus-level hypothesis and reference counts.
+ hyp_len = len(hypothesis)
+ hyp_lengths += hyp_len
+ ref_lengths += closest_ref_length(references, hyp_len)
+
+ # Calculate corpus-level brevity penalty.
+ if no_length_penalty and averaging_mode == 'geometric':
+ bp = 1.0
+ elif no_length_penalty and averaging_mode == 'arithmetic':
+ bp = 0.0
+ else:
+ assert not no_length_penalty
+ assert averaging_mode != 'arithmetic', 'Not sure how to apply the length penalty in arithmetic averaging mode'
+ bp = brevity_penalty(ref_lengths, hyp_lengths)
+
+ # Uniformly re-weighting based on maximum hypothesis lengths if largest
+ # order of n-grams < 4 and weights is set at default.
+ if auto_reweigh:
+ if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
+ weights = (1 / hyp_lengths,) * hyp_lengths
+
+ # Collects the various precision values for the different ngram orders.
+ p_n = [
+ Fraction(p_numerators[i], p_denominators[i], _normalize=False)
+ for i, _ in enumerate(weights, start=1)
+ ]
+
+ # Returns 0 if there's no matching n-grams
+ # We only need to check for p_numerators[1] == 0, since if there's
+ # no unigrams, there won't be any higher order ngrams.
+ if p_numerators[1] == 0:
+ return 0
+
+ # If there's no smoothing function, use method0 from the SmoothingFunction class.
+ if not smoothing_function:
+ smoothing_function = SmoothingFunction().method0
+ # Smoothen the modified precision.
+ # Note: smoothing_function() may convert values into floats;
+ # it tries to retain the Fraction object as much as the
+ # smoothing method allows.
+ p_n = smoothing_function(
+ p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
+ )
+
+ if averaging_mode == "geometric":
+ s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
+ s = bp * math.exp(math.fsum(s))
+ elif averaging_mode == "arithmetic":
+ s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
+ s = math.fsum(s)
+
+ return s
+
+
+def sentence_bleu(
+ references,
+ hypothesis,
+ weights=(0.25, 0.25, 0.25, 0.25),
+ smoothing_function=None,
+ auto_reweigh=False,
+ averaging_mode="geometric",
+ no_length_penalty=False
+):
+ return corpus_bleu(
+ [references], [hypothesis], weights, smoothing_function, auto_reweigh, averaging_mode, no_length_penalty
+ )
\ No newline at end of file
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b7e1e968564b84c47049c5cc69c9d6b8fafe0e9
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py
@@ -0,0 +1,69 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torchaudio
+import argparse
+import json
+import pathlib
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ "Assuring generated audio have the same length as ground-truth audio")
+ parser.add_argument('--samples_dir', required=True, type=str)
+ parser.add_argument('--out_dir', required=True, type=str)
+ parser.add_argument('--prompts_description', required=True, type=str)
+ return parser.parse_args()
+
+
+def cut(src, tgt, l):
+ x, sr = torchaudio.load(str(src))
+ assert sr == 16_000
+
+ x = x.squeeze()
+ target_frames = int(l * sr)
+
+ flag = 0
+ if target_frames <= x.size(0):
+ x = x[:target_frames]
+ flag = 1
+ else:
+ flag = 0
+ torchaudio.save(str(tgt), x.unsqueeze(0), sr)
+ return flag
+
+
+def main():
+ args = get_args()
+ tgt_dir = pathlib.Path(args.out_dir)
+ tgt_dir.mkdir(exist_ok=True, parents=True)
+
+ total_files, sufficiently_long = 0, 0
+
+ with open(args.prompts_description, 'r') as f:
+ description = json.loads(f.read())
+
+ for src_f in pathlib.Path(args.samples_dir).glob('*.wav'):
+ name_prompt = src_f.with_suffix('').name.split('__')[0]
+
+ assert name_prompt in description, f'Cannot find {name_prompt}!'
+
+ target_length = description[name_prompt][0]
+ tgt_f = tgt_dir / (src_f.name)
+
+ is_long_enough = cut(src_f, tgt_f, target_length)
+ sufficiently_long += is_long_enough
+ if not is_long_enough:
+ print(f'{src_f} is not long enough')
+
+ total_files += 1
+
+ print(
+ f'Total files: {total_files}; sufficiently long: {sufficiently_long}')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt
new file mode 100644
index 0000000000000000000000000000000000000000..69929e1666c8182148d83ef4332e4c677bb90e5a
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt
@@ -0,0 +1,28 @@
+| 94802
+E 51860
+T 38431
+A 33152
+O 31495
+N 28855
+I 28794
+H 27187
+S 26071
+R 23546
+D 18289
+L 16308
+U 12400
+M 10685
+W 10317
+C 9844
+F 9062
+G 8924
+Y 8226
+P 6890
+B 6339
+V 3936
+K 3456
+' 1023
+X 636
+J 598
+Q 437
+Z 213
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a40e4d359bdcae6d64f53ba06d8a533aec01ac
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import numpy as np
+import warnings
+
+
+def get_target_sequences(manifest, ground_truth, to_take=1000):
+ import json
+ import pathlib
+
+ with open(ground_truth, 'r') as fin:
+ original_continuations = json.loads(fin.read())
+
+ sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
+ assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds
+
+ sequence2length.sort(key=lambda x: x[1])
+ to_take_sequences = set(v[0] for v in sequence2length[:to_take])
+ to_take_ids = []
+
+ with open(manifest, 'r') as f:
+ f.readline()
+
+ for i, line in enumerate(f.readlines()):
+ seq_id = line.split()[0]
+ seq_id = pathlib.Path(seq_id).name.split('__')[0]
+
+ if seq_id in to_take_sequences:
+ to_take_ids.append(i)
+
+ print(f'Took {len(to_take_ids)} ids')
+ return set(to_take_ids)
+
+
+def get_args():
+ import argparse
+
+ parser = argparse.ArgumentParser("Evaluate PPX metric of a transcript.")
+ parser.add_argument('--asr-transcript', type=str,
+ help='Path to the transcript file.')
+ parser.add_argument('--cut-id', action='store_true',
+ help='Whether to cut the first token (typically a seq id)')
+ parser.add_argument('--cut-tail', action='store_true',
+ help='Whether to cut the last token (typically a speaker id)')
+
+ parser.add_argument('--manifest', type=str, default=None)
+ parser.add_argument('--prompts-description', type=str, default=None)
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = get_args()
+
+ lm = torch.hub.load(
+ 'pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
+
+ lm.eval().cuda() # disable dropout
+
+ if args.manifest is None and args.prompts_description is None:
+ target_ids = None
+ else:
+ target_ids = get_target_sequences(
+ args.manifest, args.prompts_description)
+
+ with open(args.asr_transcript, 'r') as fin:
+ lines = fin.readlines()
+
+ if target_ids is not None:
+ filtered = []
+ for line in lines:
+ line_id = line.split()[-1]
+ line_id = int(line_id.split('-')[1][:-1])
+ if line_id in target_ids:
+ filtered.append(line)
+ lines = filtered
+ else:
+ pass
+
+ if args.cut_id:
+ lines = [' '.join(x.split()[1:]) for x in lines]
+ if args.cut_tail:
+ lines = [' '.join(x.split()[:-1]) for x in lines]
+ lines = [x.strip().lower() for x in lines]
+
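+ # Note: fairseq's lm.score() returns per-token log-probabilities in
+ # 'positional_scores'; get_logprob negates their mean, so the "logprob"
+ # statistics reported below are negative mean log-probabilities and
+ # np.exp of them gives per-token perplexities.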
+ def get_logprob(sent): return \
+ lm.score(sent)['positional_scores'].mean().neg().item()
+
+ logprobs = [get_logprob(l) for l in lines]
+
+ filtered = [x for x in logprobs if not np.isnan(x)]
+ if len(filtered) != len(logprobs):
+ warnings.warn("NaNs detected!")
+ logprobs = filtered
+
+ perplexities = [np.exp(l) for l in logprobs]
+
+ for name, stats in [('logprob', logprobs), ('perplexity', perplexities)]:
+ mean = np.mean(stats)
+ sem = np.std(stats) / np.sqrt(len(stats))
+
+ median = np.median(stats)
+ interval = list(np.percentile(stats, [10, 90]))
+
+ mean, sem, median, percentile10, percentile90 = [
+ round(x, 2) for x in [mean, sem, median] + interval]
+
+ print(name)
+ print(f"\tMean {mean} +- {sem}")
+ print(
+ f"\tMedian {median}, 90% confidence interval {percentile10}...{percentile90}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py
new file mode 100644
index 0000000000000000000000000000000000000000..062bb82f669f63a537b6ee8df4d42d292eb2575e
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py
@@ -0,0 +1,201 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import nltk
+from misc.bleu_utils import sentence_bleu
+import warnings
+
+
+def get_target_sequences(manifest, ground_truth, to_take=1000):
+ import json
+ import pathlib
+
+ with open(ground_truth, 'r') as fin:
+ original_continuations = json.loads(fin.read())
+
+ sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
+ assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds
+
+ sequence2length.sort(key=lambda x: x[1])
+ to_take_sequences = set(v[0] for v in sequence2length[:to_take])
+ to_take_ids = []
+
+ with open(manifest, 'r') as f:
+ f.readline()
+
+ for i, line in enumerate(f.readlines()):
+ seq_id = line.split()[0]
+ seq_id = pathlib.Path(seq_id).name.split('__')[0]
+
+ if seq_id in to_take_sequences:
+ to_take_ids.append(i)
+
+ print(f'Took {len(to_take_ids)} ids')
+ return set(to_take_ids)
+
+
+def get_args():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--asr-transcript', type=str,
+ help='Path to the transcript file.')
+
+ parser.add_argument('--manifest', required=True)
+ parser.add_argument('--prompts-description', required=True)
+
+ parser.add_argument('--cut-id', action='store_true',
+ help='Whether to cut the first token (typically a seq id)')
+ parser.add_argument('--cut-tail', action='store_true',
+ help='Whether to cut the last token (typically a speaker id)')
+ parser.add_argument('--debug', action='store_true')
+
+ args = parser.parse_args()
+
+ return args
+
+
+def get_self_bleu(utterances, averaging_mode, weights):
+ self_bleu = []
+
+ for i in range(len(utterances)):
+ hypo = utterances[i]
+ rest = utterances[:i] + utterances[i+1:]
+
+ self_bleu.append(sentence_bleu(rest, hypo, weights,
+ no_length_penalty=True, averaging_mode=averaging_mode))
+
+ return self_bleu
+
+
+def get_self_bleu2_arithmetic(utterances):
+ weights = (0.5, 0.5) # equal weight for unigrams and bigrams
+ return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
+
+
+def get_self_bleu2_geometric(utterances):
+ weights = (0.5, 0.5)
+ return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
+
+
+def get_auto_bleu2_arithmetic(utterances):
+ weights = (0.5, 0.5)
+ return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
+
+
+def get_auto_bleu2_geometric(utterances):
+ weights = (0.5, 0.5)
+ return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
+
+
+def get_auto_bleu3_geometric(utterances):
+ weights = (1./3, 1./3, 1./3)
+ return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
+
+
+def get_auto_bleu3_arithmetic(utterances):
+ weights = (1./3, 1./3, 1./3)
+ return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
+
+
+def get_self_bleu3_arithmetic(utterances):
+ weights = (1./3, 1./3, 1./3)
+ return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
+
+
+def get_self_bleu3_geometric(utterances):
+ weights = (1./3, 1./3, 1./3)
+ return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
+
+
+def auto_bleu(sentence, weights, mean_mode='arithmetic'):
+ if len(sentence) <= 1:
+ return 0
+
+ N = len(weights)
+
+ bleu_n = np.zeros([N])
+ for n in range(N):
+ targ_ngrams = list(nltk.ngrams(sentence, n+1))
+ for p in range(len(targ_ngrams)):
+ left = sentence[:p]
+ right = sentence[(p+n+1):]
+ rest_ngrams = list(nltk.ngrams(left, n+1)) + \
+ list(nltk.ngrams(right, n+1))
+ # compute the nb of matching ngrams
+ bleu_n[n] += targ_ngrams[p] in rest_ngrams
+ bleu_n[n] /= len(targ_ngrams) # average them to get a proportion
+
+ weights = np.array(weights)
+ if mean_mode == 'arithmetic':
+ return (bleu_n * weights).sum()
+ elif mean_mode == 'geometric':
+ return (bleu_n ** weights).prod()
+ else:
+ raise ValueError(f'Unknown aggregation mode {mean_mode}')
+
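+# Worked example for auto_bleu (illustrative values, not used by the script):
+# for the token sequence ['a', 'b', 'a', 'b'] with weights (0.5, 0.5), every
+# unigram also occurs elsewhere in the sequence (proportion 4/4 = 1.0), while
+# only two of the three bigrams do (the middle ('b', 'a') is unique), giving
+# a proportion of 2/3. Hence:
+# auto_bleu(['a', 'b', 'a', 'b'], (0.5, 0.5)) ~= 0.83 (arithmetic mean)
+# auto_bleu(['a', 'b', 'a', 'b'], (0.5, 0.5), mean_mode='geometric') ~= 0.82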
+
+def main():
+ from multiprocessing import Pool
+
+ args = get_args()
+ target_ids = get_target_sequences(args.manifest, args.prompts_description)
+
+ with open(args.asr_transcript, 'r') as fin:
+ lines = fin.readlines()
+
+ terms = [x.strip().split() for x in lines]
+ filtered = []
+ for term in terms:
+ line_id = int(term[-1].split('-')[1][:-1])
+ if line_id in target_ids:
+ filtered.append(term)
+ terms = filtered
+
+ if args.cut_id:
+ terms = [x[1:] for x in terms]
+ if args.cut_tail:
+ terms = [x[:-1] for x in terms]
+
+ if args.debug:
+ terms = terms[:10]
+
+ tasks = [
+ ('Self-BLEU2-arithmetic', get_self_bleu2_arithmetic),
+ ('Self-BLEU2-geometric', get_self_bleu2_geometric),
+ ('Auto-BLEU2-arithmetic', get_auto_bleu2_arithmetic),
+ ('Auto-BLEU2-geometric', get_auto_bleu2_geometric),
+
+ ('Self-BLEU3-arithmetic', get_self_bleu3_arithmetic),
+ ('Self-BLEU3-geometric', get_self_bleu3_geometric),
+ ('Auto-BLEU3-arithmetic', get_auto_bleu3_arithmetic),
+ ('Auto-BLEU3-geometric', get_auto_bleu3_geometric),
+ ]
+
+ n_processes = min(16, len(tasks))
+ with Pool(n_processes) as pool:
+ metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
+
+ for (metric_name, _), metric in zip(tasks, metrics):
+ metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
+
+ metric, sem = [
+ round(100 * x, 2) for x in [metric, sem]
+ ]
+
+ print(f'{metric_name} {metric} +- {sem}')
+
+
+def run_f(task_params):
+ f, terms = task_params
+ return f(terms)
+
+
+if __name__ == '__main__':
+ # NLTK produces warnings
+ warnings.filterwarnings("ignore")
+
+ main()
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/README.md b/fairseq/examples/textless_nlp/gslm/speech2unit/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1a3d131ec165f12e37906420fc2c284a7223bda2
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/README.md
@@ -0,0 +1,71 @@
+# Speech to Unit Model (speech2unit)
+
+## Acoustic Model
+For quantizing speech, we learn K-means clustering over acoustic representations, using either Log Mel Filterbank features or pretrained acoustic representation models. To use the pretrained models, please download them from their respective locations linked below.
+* [Modified CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/cpc_big_ll6kh_top_ctc.pt)
+* [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt)
+* [Wav2Vec 2.0-Base](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt)
+
+## Quantization Model
+You can download pretrained quantization (K-means) models from the list below.
+
+K-Means Model | Download Link
+|-|-
+Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km50/km.bin)
+Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km100/km.bin)
+Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km200/km.bin)
+Log Mel Filterbank + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km500/km.bin)
+Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km50/km.bin)
+Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km100/km.bin)
+Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km200/km.bin)
+Modified CPC + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km500/km.bin)
+HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km50/km.bin)
+HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin)
+HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km200/km.bin)
+HuBERT Base + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km500/km.bin)
+wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km50/km.bin)
+wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km100/km.bin)
+wav2vec 2.0 Large + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km200/km.bin)
+wav2vec 2.0 Large + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km500/km.bin)
+
+### Quantization
+For quantizing speech with a given acoustic representation, please follow the steps below.
+1. Learn K-means clustering model
+```
+N_CLUSTERS=<number_of_clusters_for_k-means>
+TYPE=<one_of_logmel/cpc/hubert/w2v2>
+CKPT_PATH=<path_of_pretrained_acoustic_model>
+LAYER=<layer_of_acoustic_model_to_extract_features_from>
+MANIFEST=<tab_separated_manifest_of_audio_files_for_training_k-means>
+KM_MODEL_PATH=<output_path_of_the_kmeans_model>
+
+PYTHONPATH=. python examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py \
+ --num_clusters $N_CLUSTERS \
+ --feature_type $TYPE \
+ --checkpoint_path $CKPT_PATH \
+ --layer $LAYER \
+ --manifest_path $MANIFEST \
+ --out_kmeans_model_path $KM_MODEL_PATH
+```
+2. Quantize using the learned clusters
+```
+MANIFEST=<tab_separated_manifest_of_audio_files_to_quantize>
+OUT_QUANTIZED_FILE=<output_quantized_audio_file_path>
+
+python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \
+ --feature_type $TYPE \
+ --kmeans_model_path $KM_MODEL_PATH \
+ --acoustic_model_path $CKPT_PATH \
+ --layer $LAYER \
+ --manifest_path $MANIFEST \
+ --out_quantized_file_path $OUT_QUANTIZED_FILE \
+ --extension ".flac"
+```
+
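+The two steps above can also be run programmatically. The sketch below is a minimal, illustrative example that mirrors what `quantize_with_kmeans.py` does internally; the feature type, checkpoint, manifest, K-means model path, and layer index are placeholders you would replace with your own:
+```python
+import joblib
+from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_features
+
+# Placeholder settings -- substitute your own checkpoint, manifest, layer and model paths.
+feats_per_utterance = get_features(
+ feature_type="hubert",
+ checkpoint_path="hubert_base_ls960.pt",
+ layer=6,
+ manifest_path="manifest.tsv",
+ sample_pct=1.0,
+ flatten=False,
+)
+kmeans_model = joblib.load(open("km.bin", "rb"))
+# Each output line of the CLI has the form <file_basename>|<space-separated unit ids>.
+units = [" ".join(str(c) for c in kmeans_model.predict(f)) for f in feats_per_utterance]
+```
+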
+Note: the manifest file lists the paths and lengths (in frames) of the input audio files. Its format is as follows:
+```
+<path_of_root_directory_containing_audio_files>
+<relative_path_of_audio_file_1>\t<number_of_frames_1>
+<relative_path_of_audio_file_2>\t<number_of_frames_2>
+...
+```
\ No newline at end of file
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/__init__.py b/fairseq/examples/textless_nlp/gslm/speech2unit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cf844a95a075ee9ad318dc11dd71537d1ef6a5b
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py
@@ -0,0 +1,212 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+import time
+
+import numpy as np
+from sklearn.cluster import MiniBatchKMeans
+
+import joblib
+from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
+ get_and_dump_features,
+ get_features,
+)
+
+
+def get_logger():
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+ logging.basicConfig(format=log_format, level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ return logger
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Learn K-means clustering over acoustic features."
+ )
+
+ # Features arguments
+ parser.add_argument(
+ "--in_features_path", type=str, default=None, help="Features file path"
+ )
+ parser.add_argument(
+ "--feature_type",
+ type=str,
+ choices=["logmel", "hubert", "w2v2", "cpc"],
+ default=None,
+ help="Acoustic feature type",
+ )
+ parser.add_argument(
+ "--manifest_path",
+ type=str,
+ default=None,
+ help="Manifest file containing the root dir and file names",
+ )
+ parser.add_argument(
+ "--out_features_path",
+ type=str,
+ default=None,
+ help="Features file path to write to",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ help="Pretrained acoustic model checkpoint",
+ )
+ parser.add_argument(
+ "--layer",
+ type=int,
+ help="The layer of the pretrained model to extract features from",
+ default=-1,
+ )
+ parser.add_argument(
+ "--sample_pct",
+ type=float,
+ help="Percent data to use for K-means training",
+ default=0.1,
+ )
+
+ # K-means arguments
+ parser.add_argument(
+ "--num_clusters", type=int, help="Nubmer of clusters", default=50
+ )
+ parser.add_argument("--init", default="k-means++")
+ parser.add_argument(
+ "--max_iter",
+ type=int,
+ help="Maximum number of iterations for K-means training",
+ default=150,
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ help="Batch size for K-means training",
+ default=10000,
+ )
+ parser.add_argument("--tol", default=0.0, type=float)
+ parser.add_argument("--max_no_improvement", default=100, type=int)
+ parser.add_argument("--n_init", default=20, type=int)
+ parser.add_argument("--reassignment_ratio", default=0.5, type=float)
+ parser.add_argument(
+ "--out_kmeans_model_path",
+ type=str,
+ required=True,
+ help="Path to save K-means model",
+ )
+
+ # Leftovers
+ parser.add_argument(
+ "--seed",
+ type=int,
+ help="Random seed to use for K-means training",
+ default=1369,
+ )
+
+ return parser
+
+
+def get_kmeans_model(
+ n_clusters,
+ init,
+ max_iter,
+ batch_size,
+ tol,
+ max_no_improvement,
+ n_init,
+ reassignment_ratio,
+ random_state,
+):
+ return MiniBatchKMeans(
+ n_clusters=n_clusters,
+ init=init,
+ max_iter=max_iter,
+ batch_size=batch_size,
+ tol=tol,
+ max_no_improvement=max_no_improvement,
+ n_init=n_init,
+ reassignment_ratio=reassignment_ratio,
+ random_state=random_state,
+ verbose=1,
+ compute_labels=True,
+ init_size=None,
+ )
+
+
+def train_kmeans(kmeans_model, features_batch):
+ start_time = time.time()
+ kmeans_model.fit(features_batch)
+ time_taken = round((time.time() - start_time) / 60, 2)  # minutes
+ return kmeans_model, time_taken
+
+
+def main(args, logger):
+ # Features loading/extraction for K-means
+ if args.in_features_path:
+ # Feature loading
+ logger.info(f"Loading features from {args.in_features_path}...")
+ features_batch = np.load(args.in_features_path, allow_pickle=True)
+ else:
+ # Feature extraction
+ logger.info(f"Extracting {args.feature_type} acoustic features...")
+ features_batch = (
+ get_features(
+ feature_type=args.feature_type,
+ checkpoint_path=args.checkpoint_path,
+ layer=args.layer,
+ manifest_path=args.manifest_path,
+ sample_pct=args.sample_pct,
+ flatten=True,
+ )
+ if not args.out_features_path
+ else get_and_dump_features(
+ feature_type=args.feature_type,
+ checkpoint_path=args.checkpoint_path,
+ layer=args.layer,
+ manifest_path=args.manifest_path,
+ sample_pct=args.sample_pct,
+ flatten=True,
+ out_features_path=args.out_features_path,
+ )
+ )
+ if args.out_features_path:
+ logger.info(
+ f"Saved extracted features at {args.out_features_path}"
+ )
+ logger.info(f"Features shape = {features_batch.shape}\n")
+
+ # Learn and save K-means model
+ kmeans_model = get_kmeans_model(
+ n_clusters=args.num_clusters,
+ init=args.init,
+ max_iter=args.max_iter,
+ batch_size=args.batch_size,
+ tol=args.tol,
+ max_no_improvement=args.max_no_improvement,
+ n_init=args.n_init,
+ reassignment_ratio=args.reassignment_ratio,
+ random_state=args.seed,
+ )
+ logger.info("Starting k-means training...")
+ kmeans_model, time_taken = train_kmeans(
+ kmeans_model=kmeans_model, features_batch=features_batch
+ )
+ logger.info(f"...done k-means training in {time_taken} minutes")
+ inertia = -kmeans_model.score(features_batch) / len(features_batch)
+ logger.info(f"Total intertia: {round(inertia, 2)}\n")
+
+ logger.info(f"Saving k-means model to {args.out_kmeans_model_path}")
+ os.makedirs(os.path.dirname(args.out_kmeans_model_path), exist_ok=True)
+ joblib.dump(kmeans_model, open(args.out_kmeans_model_path, "wb"))
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ logger = get_logger()
+ logger.info(args)
+ main(args, logger)
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..031567c6d85d16b5236053abf008b7cabccb4673
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py
@@ -0,0 +1,91 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+
+from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
+ get_and_dump_features,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Compute and dump log mel fbank features."
+ )
+ parser.add_argument(
+ "--feature_type",
+ type=str,
+ choices=["logmel", "hubert", "w2v2", "cpc"],
+ default=None,
+ help="Acoustic feature type",
+ )
+ parser.add_argument(
+ "--manifest_path",
+ type=str,
+ default=None,
+ help="Manifest file containing the root dir and file names",
+ )
+ parser.add_argument(
+ "--out_features_path",
+ type=str,
+ default=None,
+ help="Features file path to write to",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ help="Pretrained acoustic model checkpoint",
+ )
+ parser.add_argument(
+ "--layer",
+ type=int,
+ help="The layer of the pretrained model to extract features from",
+ default=-1,
+ )
+ parser.add_argument(
+ "--sample_pct",
+ type=float,
+ help="Percent data to use for K-means training",
+ default=0.1,
+ )
+ return parser
+
+
+def get_logger():
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+ logging.basicConfig(format=log_format, level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ return logger
+
+
+if __name__ == "__main__":
+ """
+ Example command:
+ python ~/speechbot/clustering/dump_logmelfank_feats.py \
+ --manifest_path /checkpoint/kushall/data/LJSpeech-1.1/asr_input_wavs_16k/train.tsv
+ --out_features_path /checkpoint/kushall/experiments/speechbot/logmelfbank/features/ljspeech/train.npy
+ """
+ parser = get_parser()
+ args = parser.parse_args()
+ logger = get_logger()
+ logger.info(args)
+
+ logger.info(f"Extracting {args.feature_type} acoustic features...")
+ get_and_dump_features(
+ feature_type=args.feature_type,
+ checkpoint_path=args.checkpoint_path,
+ layer=args.layer,
+ manifest_path=args.manifest_path,
+ sample_pct=args.sample_pct,
+ flatten=True,
+ out_features_path=args.out_features_path,
+ )
+ logger.info(f"Saved extracted features at {args.out_features_path}")
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c87445d810cd790f887d1a135287a334cbdf223
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+
+import numpy as np
+
+import joblib
+from examples.textless_nlp.gslm.speech2unit.clustering.utils import (
+ get_audio_files,
+)
+from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
+ get_features,
+)
+
+
+def get_logger():
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+ logging.basicConfig(format=log_format, level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ return logger
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Quantize using K-means clustering over acoustic features."
+ )
+ parser.add_argument(
+ "--feature_type",
+ type=str,
+ choices=["logmel", "hubert", "w2v2", "cpc"],
+ default=None,
+ required=True,
+ help="Acoustic feature type",
+ )
+ parser.add_argument(
+ "--acoustic_model_path",
+ type=str,
+ help="Pretrained acoustic model checkpoint"
+ )
+ parser.add_argument(
+ "--layer",
+ type=int,
+ help="The layer of the pretrained model to extract features from",
+ default=-1,
+ )
+ parser.add_argument(
+ "--kmeans_model_path",
+ type=str,
+ required=True,
+ help="K-means model file path to use for inference",
+ )
+ parser.add_argument(
+ "--features_path",
+ type=str,
+ default=None,
+ help="Features file path. You don't need to enter acoustic model details if you have dumped features",
+ )
+ parser.add_argument(
+ "--manifest_path",
+ type=str,
+ default=None,
+ help="Manifest file containing the root dir and file names",
+ )
+ parser.add_argument(
+ "--out_quantized_file_path",
+ required=True,
+ type=str,
+ help="File path of quantized output.",
+ )
+ parser.add_argument(
+ "--extension", type=str, default=".flac", help="Features file path"
+ )
+ return parser
+
+
+def main(args, logger):
+ # Feature extraction
+ if args.features_path is not None:
+ logger.info(f"Loading acoustic features from {args.features_path}...")
+ features_batch = np.load(args.features_path)
+ else:
+ logger.info(f"Extracting {args.feature_type} acoustic features...")
+ features_batch = get_features(
+ feature_type=args.feature_type,
+ checkpoint_path=args.acoustic_model_path,
+ layer=args.layer,
+ manifest_path=args.manifest_path,
+ sample_pct=1.0,
+ flatten=False,
+ )
+ logger.info(
+ f"Features extracted for {len(features_batch)} utterances.\n"
+ )
+ logger.info(
+ f"Dimensionality of representation = {features_batch[0].shape[1]}"
+ )
+
+ # K-means model
+ logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
+ kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
+ kmeans_model.verbose = False
+
+ _, fnames, _ = get_audio_files(args.manifest_path)
+
+ os.makedirs(os.path.dirname(args.out_quantized_file_path), exist_ok=True)
+ print(f"Writing quantized predictions to {args.out_quantized_file_path}")
+ with open(args.out_quantized_file_path, "w") as fout:
+ for i, feats in enumerate(features_batch):
+ pred = kmeans_model.predict(feats)
+ pred_str = " ".join(str(p) for p in pred)
+ # str.rstrip strips a character set, not a suffix, so remove the extension explicitly.
+ base_fname = os.path.basename(fnames[i])
+ if base_fname.endswith(args.extension):
+ base_fname = base_fname[: -len(args.extension)]
+ fout.write(f"{base_fname}|{pred_str}\n")
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ logger = get_logger()
+ logger.info(args)
+ main(args, logger)
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf08d1fe4b470477b724aa8d770d91c0cac35a0e
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+
+
+def get_audio_files(manifest_path: str) -> Tuple[str, List[str], List[int]]:
+ fnames, sizes = [], []
+ with open(manifest_path, "r") as f:
+ root_dir = f.readline().strip()
+ for line in f:
+ items = line.strip().split("\t")
+ assert (
+ len(items) == 2
+ ), f"File must have two columns separated by tab. Got {line}"
+ fnames.append(items[0])
+ sizes.append(int(items[1]))
+ return root_dir, fnames, sizes
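+
+# Example of the expected manifest layout (hypothetical paths):
+# /data/audio_root
+# speaker1/utt1.flac<TAB>49600
+# speaker1/utt2.flac<TAB>51200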
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c613f52d3c3de43a048849a231a9a34e2a883486
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py
@@ -0,0 +1,192 @@
+import soundfile as sf
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CpcFeatureReader:
+ """
+ Wrapper class to run inference on CPC model.
+ Helps extract features for a given audio file.
+ """
+
+ def __init__(
+ self,
+ checkpoint_path,
+ layer,
+ use_encoder_layer=False,
+ norm_features=False,
+ sample_rate=16000,
+ max_chunk=64000,
+ ):
+ self.model = load_cpc_model(checkpoint_path, layer).eval().cuda()
+ self.sample_rate = sample_rate
+ self.max_chunk = max_chunk
+ self.norm_features = norm_features
+ self.use_encoder_layer = use_encoder_layer
+
+ def read_audio(self, path, ref_len=None):
+ wav, sr = sf.read(path)
+ if wav.ndim == 2:
+ wav = wav.mean(-1)
+ assert wav.ndim == 1, wav.ndim
+ assert sr == self.sample_rate, sr
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
+ print(f"ref {ref_len} != read {len(wav)} ({path})")
+ return wav
+
+ def get_feats(self, file_path, ref_len=None):
+ x = self.read_audio(file_path, ref_len)
+ # Inspired from CPC_audio feature_loader.py
+ with torch.no_grad():
+ x = torch.from_numpy(x).float().cuda()
+ x = x.view(1, 1, -1)
+ size = x.size(2)
+ feat = []
+ start = 0
+ while start < size:
+ if start + self.max_chunk > size:
+ break
+ x_chunk = x[..., start : start + self.max_chunk]
+ feat_chunk = self.model.extract_features(
+ source=x_chunk,
+ get_encoded=self.use_encoder_layer,
+ norm_output=self.norm_features,
+ )
+ feat.append(feat_chunk)
+ start += self.max_chunk
+
+ if start < size:
+ x_chunk = x[:, -self.max_chunk :]
+ feat_chunk = self.model.extract_features(
+ source=x_chunk,
+ get_encoded=self.use_encoder_layer,
+ norm_output=self.norm_features,
+ )
+ df = x_chunk.size(2) // feat_chunk.size(1)
+ delta = (size - start) // df
+ feat.append(feat_chunk[:, -delta:])
+ return torch.cat(feat, 1).squeeze(0)
+
+
+def load_cpc_model(checkpoint_path, layer=None):
+ state_dict = torch.load(checkpoint_path)
+ weights = state_dict["weights"]
+ config = state_dict["config"]
+ if layer is not None:
+ config["nLevelsGRU"] = layer
+
+ encoder = CPCEncoder(config["hiddenEncoder"])
+ ar_net = CPCAR(
+ config["hiddenEncoder"], config["hiddenGar"], False, config["nLevelsGRU"]
+ )
+
+ model = CPCModel(encoder, ar_net)
+ model.load_state_dict(weights, strict=False)
+ model.config = config
+
+ return model
+
+
+class ChannelNorm(nn.Module):
+ def __init__(self, num_features, epsilon=1e-05, affine=True):
+ super(ChannelNorm, self).__init__()
+ if affine:
+ self.weight = nn.parameter.Parameter(torch.Tensor(1, num_features, 1))
+ self.bias = nn.parameter.Parameter(torch.Tensor(1, num_features, 1))
+ else:
+ self.weight = None
+ self.bias = None
+ self.epsilon = epsilon
+ self.p = 0
+ self.affine = affine
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.affine:
+ torch.nn.init.ones_(self.weight)
+ torch.nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ cum_mean = x.mean(dim=1, keepdim=True)
+ cum_var = x.var(dim=1, keepdim=True)
+ x = (x - cum_mean) * torch.rsqrt(cum_var + self.epsilon)
+ if self.weight is not None:
+ x = x * self.weight + self.bias
+ return x
+
+
+class CPCEncoder(nn.Module):
+ def __init__(self, hidden_dim=512):
+ super(CPCEncoder, self).__init__()
+ self.conv0 = nn.Conv1d(1, hidden_dim, 10, stride=5, padding=3)
+ self.batchNorm0 = ChannelNorm(hidden_dim)
+ self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, 8, stride=4, padding=2)
+ self.batchNorm1 = ChannelNorm(hidden_dim)
+ self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
+ self.batchNorm2 = ChannelNorm(hidden_dim)
+ self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
+ self.batchNorm3 = ChannelNorm(hidden_dim)
+ self.conv4 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1)
+ self.batchNorm4 = ChannelNorm(hidden_dim)
+ self.DOWNSAMPLING = 160
+
+ def get_output_dim(self):
+ return self.conv4.out_channels
+
+ def forward(self, x):
+ x = F.relu(self.batchNorm0(self.conv0(x)))
+ x = F.relu(self.batchNorm1(self.conv1(x)))
+ x = F.relu(self.batchNorm2(self.conv2(x)))
+ x = F.relu(self.batchNorm3(self.conv3(x)))
+ x = F.relu(self.batchNorm4(self.conv4(x)))
+ return x
+
+
+class CPCAR(nn.Module):
+ def __init__(self, dim_encoded, dim_output, keep_hidden, num_layers):
+ super(CPCAR, self).__init__()
+ self.baseNet = nn.LSTM(
+ dim_encoded, dim_output, num_layers=num_layers, batch_first=True
+ )
+ self.hidden = None
+ self.keep_hidden = keep_hidden
+
+ def get_output_dim(self):
+ return self.baseNet.hidden_size
+
+ def forward(self, x):
+ try:
+ self.baseNet.flatten_parameters()
+ except RuntimeError:
+ pass
+ x, h = self.baseNet(x, self.hidden)
+ if self.keep_hidden:
+ if isinstance(h, tuple):
+ self.hidden = tuple(x.detach() for x in h)
+ else:
+ self.hidden = h.detach()
+ return x
+
+
+class CPCModel(nn.Module):
+ def __init__(self, encoder, ar_net):
+ super(CPCModel, self).__init__()
+ self.gEncoder = encoder
+ self.gAR = ar_net
+ self.config = None
+
+ def forward(self, x, label):
+ encoded = self.gEncoder(x).permute(0, 2, 1)
+ cpc_feature = self.gAR(encoded)
+ return cpc_feature, encoded, label
+
+ def extract_features(self, source, get_encoded=False, norm_output=False):
+ cpc_feature, encoded, _ = self.forward(source, None)
+ if get_encoded:
+ cpc_feature = encoded
+ if norm_output:
+ mean = cpc_feature.mean(dim=1, keepdim=True)
+ var = cpc_feature.var(dim=1, keepdim=True)
+ cpc_feature = (cpc_feature - mean) / torch.sqrt(var + 1e-08)
+ return cpc_feature
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..09442206e19abf854f2f02754ec7c6f8bc564200
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py
@@ -0,0 +1,59 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import fairseq
+import soundfile as sf
+import torch.nn.functional as F
+
+
+class HubertFeatureReader:
+ """
+ Wrapper class to run inference on HuBERT model.
+ Helps extract features for a given audio file.
+ """
+
+ def __init__(self, checkpoint_path, layer, max_chunk=1600000):
+ (
+ model,
+ cfg,
+ task,
+ ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+ [checkpoint_path]
+ )
+ self.model = model[0].eval().cuda()
+ self.task = task
+ self.layer = layer
+ self.max_chunk = max_chunk
+
+ def read_audio(self, path, ref_len=None):
+ wav, sr = sf.read(path)
+ if wav.ndim == 2:
+ wav = wav.mean(-1)
+ assert wav.ndim == 1, wav.ndim
+ assert sr == self.task.cfg.sample_rate, sr
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
+ print(f"ref {ref_len} != read {len(wav)} ({path})")
+ return wav
+
+ def get_feats(self, file_path, ref_len=None):
+ x = self.read_audio(file_path, ref_len)
+ with torch.no_grad():
+ x = torch.from_numpy(x).float().cuda()
+ if self.task.cfg.normalize:
+ x = F.layer_norm(x, x.shape)
+ x = x.view(1, -1)
+
+ feat = []
+ for start in range(0, x.size(1), self.max_chunk):
+ x_chunk = x[:, start: start + self.max_chunk]
+ feat_chunk, _ = self.model.extract_features(
+ source=x_chunk,
+ padding_mask=None,
+ mask=False,
+ output_layer=self.layer,
+ )
+ feat.append(feat_chunk)
+ return torch.cat(feat, 1).squeeze(0)
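+
+# Illustrative usage (checkpoint path and layer index are placeholders):
+# reader = HubertFeatureReader("hubert_base_ls960.pt", layer=6)
+# feats = reader.get_feats("utt1.flac")  # -> (num_frames, hidden_dim) CUDA tensor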
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..106f50247622deca688b223f1ad63275d5b65e58
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import soundfile as sf
+import torch
+import torchaudio.compliance.kaldi as kaldi
+
+
+class LogMelFeatureReader:
+ """
+ Wrapper class to compute log-Mel filterbank features.
+ Helps extract features for a given audio file.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.num_mel_bins = kwargs.get("num_mel_bins", 80)
+ self.frame_length = kwargs.get("frame_length", 25.0)
+
+ def get_feats(self, file_path):
+ wav, sr = sf.read(file_path)
+ feats = torch.from_numpy(wav).float()
+ feats = kaldi.fbank(
+ feats.unsqueeze(0),
+ num_mel_bins=self.num_mel_bins,
+ frame_length=self.frame_length,
+ sample_frequency=sr,
+ )
+ return feats
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aaddf6421ab7fa417af508005671a0ed821c701
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py
@@ -0,0 +1,126 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import gc
+import os
+import random
+import shutil
+import numpy as np
+
+import torch
+import tqdm
+from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import (
+ CpcFeatureReader,
+)
+from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
+ HubertFeatureReader,
+)
+from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import (
+ LogMelFeatureReader,
+)
+from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import (
+ Wav2VecFeatureReader,
+)
+
+
+def get_feature_reader(feature_type):
+ if feature_type == "logmel":
+ return LogMelFeatureReader
+ elif feature_type == "hubert":
+ return HubertFeatureReader
+ elif feature_type == "w2v2":
+ return Wav2VecFeatureReader
+ elif feature_type == "cpc":
+ return CpcFeatureReader
+ else:
+ raise NotImplementedError(f"{feature_type} is not supported.")
+
+
+def get_feature_iterator(
+ feature_type, checkpoint_path, layer, manifest_path, sample_pct
+):
+ feature_reader_cls = get_feature_reader(feature_type)
+ with open(manifest_path, "r") as fp:
+ lines = fp.read().split("\n")
+ root = lines.pop(0).strip()
+ file_path_list = [
+ os.path.join(root, line.split("\t")[0])
+ for line in lines
+ if len(line) > 0
+ ]
+ if sample_pct < 1.0:
+ file_path_list = random.sample(
+ file_path_list, int(sample_pct * len(file_path_list))
+ )
+ num_files = len(file_path_list)
+ reader = feature_reader_cls(
+ checkpoint_path=checkpoint_path, layer=layer
+ )
+
+ def iterate():
+ for file_path in file_path_list:
+ feats = reader.get_feats(file_path)
+ yield feats.cpu().numpy()
+
+ return iterate, num_files
+
+
+def get_features(
+ feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten
+):
+ generator, num_files = get_feature_iterator(
+ feature_type=feature_type,
+ checkpoint_path=checkpoint_path,
+ layer=layer,
+ manifest_path=manifest_path,
+ sample_pct=sample_pct,
+ )
+ iterator = generator()
+
+ features_list = []
+ for features in tqdm.tqdm(iterator, total=num_files):
+ features_list.append(features)
+
+ # Explicit clean up
+ del iterator
+ del generator
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if flatten:
+ return np.concatenate(features_list)
+
+ return features_list
+
+
+def get_and_dump_features(
+ feature_type,
+ checkpoint_path,
+ layer,
+ manifest_path,
+ sample_pct,
+ flatten,
+ out_features_path,
+):
+ # Feature extraction
+ features_batch = get_features(
+ feature_type=feature_type,
+ checkpoint_path=checkpoint_path,
+ layer=layer,
+ manifest_path=manifest_path,
+ sample_pct=sample_pct,
+ flatten=flatten,
+ )
+
+ # Save features
+ out_dir_path = os.path.dirname(out_features_path)
+ os.makedirs(out_dir_path, exist_ok=True)
+ shutil.copyfile(
+ manifest_path,
+ os.path.join(out_dir_path, os.path.basename(manifest_path)),
+ )
+ np.save(out_features_path, features_batch)
+
+ return features_batch
diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b878321e445093f187e7af5310622a6ac456c30d
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import fairseq
+import soundfile as sf
+
+
+class Wav2VecFeatureReader:
+ """
+ Wrapper class to run inference on Wav2Vec 2.0 model.
+ Helps extract features for a given audio file.
+ """
+
+ def __init__(self, checkpoint_path, layer):
+ state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(
+ checkpoint_path
+ )
+
+ w2v_args = state["args"]
+ self.task = fairseq.tasks.setup_task(w2v_args)
+ model = self.task.build_model(w2v_args)
+ model.load_state_dict(state["model"], strict=True)
+ model.eval()
+ model.cuda()
+ self.model = model
+ self.layer = layer
+
+ def read_audio(self, fname):
+ wav, sr = sf.read(fname)
+ if wav.ndim == 2:
+ wav = wav.mean(-1)
+ assert wav.ndim == 1, wav.ndim
+ assert sr == self.task.cfg.sample_rate, sr
+ return wav
+
+ def get_feats(self, file_path):
+ x = self.read_audio(file_path)
+ with torch.no_grad():
+ source = torch.from_numpy(x).view(1, -1).float().cuda()
+ res = self.model(
+ source=source, mask=False, features_only=True, layer=self.layer
+ )
+ return res["layer_results"][self.layer][0].squeeze(1)
diff --git a/fairseq/examples/textless_nlp/gslm/tools/README.md b/fairseq/examples/textless_nlp/gslm/tools/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..61fcbbded80023f75eaec4b69ddfbbe4cc252e5b
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/tools/README.md
@@ -0,0 +1,22 @@
+# GSLM Tools
+
+## Resynthesis
+The command-line tool below takes an audio file as input and produces the resynthesized audio. It implements the unsupervised resynthesis method described in the paper and is invoked as follows.
+```
+FAIRSEQ_ROOT=<path_to_the_fairseq_repository_root>
+TYPE=<one_of_logmel/cpc/hubert/w2v2>
+ACOUSTIC_MODEL_PATH=<path_of_pretrained_acoustic_model>
+LAYER=<layer_of_acoustic_model_to_extract_features_from>
+KM_MODEL_PATH=<path_of_pretrained_kmeans_model>
+TTS_MODEL_PATH=<path_of_pretrained_unit2speech_model>
+WAVEGLOW_PATH=<path_of_pretrained_waveglow_model>
+
+PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/tools/resynthesize_speech.py \
+ --feature_type $TYPE \
+ --acoustic_model_path $ACOUSTIC_MODEL_PATH \
+ --layer $LAYER \
+ --kmeans_model_path $KM_MODEL_PATH \
+ --tts_model_path $TTS_MODEL_PATH \
+ --waveglow_path $WAVEGLOW_PATH \
+ --max_decoder_steps 2000
+```
\ No newline at end of file
diff --git a/fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py b/fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b6215d372035284f115b6eec0712a324246b67a
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py
@@ -0,0 +1,138 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import gc
+import logging
+
+import joblib
+import soundfile as sf
+import torch
+from examples.textless_nlp.gslm.speech2unit.pretrained.utils import (
+ get_feature_reader,
+)
+from examples.textless_nlp.gslm.unit2speech.tts_data import (
+ TacotronInputDataset,
+)
+from examples.textless_nlp.gslm.unit2speech.utils import (
+ load_tacotron,
+ load_waveglow,
+ synthesize_audio,
+)
+
+
+def get_logger():
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+ logging.basicConfig(format=log_format, level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ return logger
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="GSLM speech resynthesis tool."
+ )
+ parser.add_argument(
+ "--feature_type",
+ type=str,
+ choices=["logmel", "hubert", "w2v2", "cpc"],
+ default=None,
+ required=True,
+ help="Acoustic feature type",
+ )
+ parser.add_argument(
+ "--acoustic_model_path",
+ type=str,
+ help="Pretrained acoustic model checkpoint",
+ )
+ parser.add_argument(
+ "--layer", type=int, help="Layer of acoustic model"
+ )
+ parser.add_argument(
+ "--kmeans_model_path",
+ type=str,
+ required=True,
+ help="K-means model file path to use for inference",
+ )
+ parser.add_argument(
+ "--tts_model_path",
+ type=str,
+ help="TTS model file path to use for inference",
+ )
+ parser.add_argument(
+ "--waveglow_path",
+ type=str,
+ help="Waveglow (vocoder) model file path to use for inference",
+ )
+ parser.add_argument("--max_decoder_steps", type=int, default=2000)
+ parser.add_argument("--denoiser_strength", type=float, default=0.1)
+ return parser
+
+
+################################################
+def main(args, logger):
+ # Acoustic Model
+ logger.info(f"Loading acoustic model from {args.tts_model_path}...")
+ feature_reader_cls = get_feature_reader(args.feature_type)
+ reader = feature_reader_cls(
+ checkpoint_path=args.acoustic_model_path, layer=args.layer
+ )
+
+ # K-means Model
+ logger.info(f"Loading K-means model from {args.kmeans_model_path} ...")
+ kmeans_model = joblib.load(open(args.kmeans_model_path, "rb"))
+ kmeans_model.verbose = False
+
+ # TTS Model
+ logger.info(f"Loading TTS model from {args.tts_model_path}...")
+ tacotron_model, sample_rate, hparams = load_tacotron(
+ tacotron_model_path=args.tts_model_path,
+ max_decoder_steps=args.max_decoder_steps,
+ )
+
+ # Waveglow Model
+ logger.info(f"Loading Waveglow model from {args.waveglow_path}...")
+ waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path)
+
+ # Dataset
+ tts_dataset = TacotronInputDataset(hparams)
+
+ iters = 0
+ while True:
+ in_file_path = input(
+ "Input: Enter the full file path of audio file...\n"
+ )
+ out_file_path = input(
+ "Output: Enter the full file path of audio file...\n"
+ )
+ feats = reader.get_feats(in_file_path).cpu().numpy()
+ iters += 1
+ if iters == 1000:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ quantized_units = kmeans_model.predict(feats)
+ quantized_units_str = " ".join(map(str, quantized_units))
+
+ tts_input = tts_dataset.get_tensor(quantized_units_str)
+ mel, aud, aud_dn, has_eos = synthesize_audio(
+ tacotron_model,
+ waveglow,
+ denoiser,
+ tts_input.unsqueeze(0),
+ strength=args.denoiser_strength,
+ )
+ sf.write(
+ f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate
+ )
+ logger.info("Resynthesis done!\n")
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ logger = get_logger()
+ logger.info(args)
+ main(args, logger)
diff --git a/fairseq/examples/textless_nlp/gslm/ulm/README.md b/fairseq/examples/textless_nlp/gslm/ulm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..01459121cebefc61fdc2eae201462aa78d699111
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/ulm/README.md
@@ -0,0 +1,72 @@
+# Unit Language Model (ULM)
+
+Here you can find links to the pre-trained ULMs and instructions on training new models using fairseq. At the end of the page, we also share how to run sampling for those models and provide pointers to the transcribed prompts we used.
+
+## Pre-trained models
+
+Using the links below, you can download pre-trained models for various unit types and vocabulary sizes:
+
+| | 50 | 100 | 200
+|-|-|-|-
+| LogMel Filterbank | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km50/logmel50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km100/logmel100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km200/logmel200_lm.tgz)
+| Modified CPC | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km50/cpc50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km100/cpc100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km200/cpc200_lm.tgz)
+| HuBERT | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km50/hubert50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km100/hubert100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km200/hubert200_lm.tgz)
+| Wav2Vec 2.0 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km50/w2v2_50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km100/w2v2_100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km200/w2v2_200_lm.tgz)
+
+
+## Preprocessing data
+Assuming that unit-transcribed train, valid, and test sets are located in `data/train.txt`, `data/valid.txt`, and `data/test.txt`, respectively,
+we run the following command to get a preprocessed version of the dataset in `data-bin`:
+
+```bash
+fairseq-preprocess --only-source \
+ --trainpref data/train.txt --validpref data/valid.txt --testpref data/test.txt \
+ --destdir data-bin/ --workers 40
+```
+As a result, the `data-bin` directory should appear.
+
+## Fitting a Unit Language Model (ULM)
+As an ULM, we train a standard fairseq Transformer LM. Assuming 8 GPUs are used for training, a good starting point for ULM training is:
+```bash
+ fairseq-train data-bin/ \
+ --task=language_modeling \
+ --arch=transformer_lm_big \
+ --share-decoder-input-output-embed \
+ --dropout=0.1 \
+ --attention-dropout=0.1 \
+ --optimizer=adam \
+ --adam-betas='(0.9, 0.98)' \
+ --clip-norm=1.0 \
+ --lr=0.0005 \
+ --lr-scheduler=inverse_sqrt \
+ --warmup-updates=4000 \
+ --warmup-init-lr=1e-07 \
+ --tokens-per-sample=3072 \
+ --update-freq=16 \
+ --max-tokens=4096 \
+ --num-workers=4 \
+ --skip-invalid-size-inputs-valid-test \
+ --max-update=500000 \
+ --log-interval=10 \
+ --seed=100501 \
+ --fp16 \
+ --sample-break-mode=eos
+```
+This command will train a Transformer-large model (12 layers). You can train other standard LM models provided by fairseq, e.g. specify `--arch=transformer_lm` to train a smaller (6-layer) Transformer model. When training with a different number of GPUs, it is a good idea to adjust the `update-freq` parameter accordingly. To save GPU memory at the expense of additional computation, it can be useful to enable activation checkpointing with `--checkpoint-activations`.
+
+## Sampling from an ULM
+Once an ULM has been trained, we can use it to generate new utterances. Suppose the prompts are given in a file named `prompts.txt` (an example of the prompt file format is shown below). Then we can sample continuations by running the following command:
+
+```bash
+ python sample.py data-bin/ \
+ --path=checkpoints/checkpoint_best.pt --task=language_modeling --sampling --temperature=0.7 \
+ --seed=1 --prompts=prompts.txt --output=samples.txt --max-len-a=0 --max-len-b=500 \
+ --prefix-size=-1 --batch-size=16 --fp16 --samples-per-prompt=10
+```
+Here, `--prefix-size` controls the number of tokens used to prime the ULM. When set to a positive value, the sampling script takes the first `prefix-size` tokens of each prompt; with `0` it runs unconditional sampling, and with `-1` the entire prompt is used.
+`--samples-per-prompt` specifies how many utterances are generated for every prompt, which is useful when multiple continuations per prompt are needed. In this command, `--max-len-a` and `--max-len-b` control the number of generated tokens.
+
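+As parsed by `sample.py`, each line of the prompts file is expected to contain a sequence id and a unit transcript separated by `|`, and each line of the output file carries the sequence id suffixed with the sample index. A hypothetical two-line prompts file could look like this (the ids and unit values are purely illustrative):
+```
+prompt_0|21 4 4 57 12 8 8
+prompt_1|9 9 33 70 70 2 41
+```
+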
+When using a pretrained model from above, `data-bin` should point to the unpacked directory (with `dict.txt` file).
+
+At evaluation time, to generate prompts, we used utterances from LibriSpeech dev-clean and test-clean that are longer than 6s. We took the first 3s of each utterance as a prompt. Unit transcripts of those prompts can be downloaded here: [[dev]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/dev_prompts.tgz) [[test]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/test_prompts.tgz)
+
diff --git a/fairseq/examples/textless_nlp/gslm/ulm/sample.py b/fairseq/examples/textless_nlp/gslm/ulm/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..77302a6894cacf07588cf34fb1e695dc519d7df5
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/ulm/sample.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Sample from a trained LM; hacked fairseq-interactive
+"""
+from collections import namedtuple
+import os
+import ast
+import numpy as np
+
+from fairseq import checkpoint_utils, options, tasks, utils
+
+import tqdm
+
+Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
+Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
+
+
+def make_batches(lines, args, task, max_positions):
+ tokens = [
+ task.source_dictionary.encode_line(
+ src_str, add_if_not_exist=False
+ ).long()
+ for src_str in lines
+ ]
+ lengths = [t.numel() for t in tokens]
+ itr = task.get_batch_iterator(
+ dataset=task.build_dataset_for_inference(tokens, lengths),
+ max_tokens=args.dataset.max_tokens,
+ max_sentences=args.dataset.batch_size,
+ max_positions=max_positions,
+ ignore_invalid_inputs=args.dataset.skip_invalid_size_inputs_valid_test
+ ).next_epoch_itr(shuffle=False)
+ for batch in itr:
+ yield Batch(
+ ids=batch['id'],
+ src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
+ )
+
+
+def main(args):
+ arg_prompts = args.prompts
+ arg_output = args.output
+ arg_debug = args.debug
+ arg_sample_size = args.samples_per_prompt
+
+ try:
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+ args = convert_namespace_to_omegaconf(args)
+ except Exception:
+ # Fall back to the plain argparse namespace if conversion is unavailable.
+ pass
+
+ # if args.max_tokens is None and args.max_sentences is None:
+ if args.common.seed is not None:
+ np.random.seed(args.common.seed)
+ utils.set_torch_seed(args.common.seed)
+
+ if args.generation.sampling:
+ args.generation.nbest = args.generation.beam = arg_sample_size
+
+ task = tasks.setup_task(args.task)
+
+ overrides = ast.literal_eval(args.common_eval.model_overrides)
+
+ models, _model_args = checkpoint_utils.load_model_ensemble(
+ args.common_eval.path.split(os.pathsep),
+ arg_overrides=overrides,
+ task=task,
+ suffix=getattr(args, "checkpoint_suffix", ""),
+ )
+
+ # Set dictionaries
+ src_dict = task.source_dictionary
+ tgt_dict = task.target_dictionary
+
+ # Optimize ensemble for generation
+ for model in models:
+ model.prepare_for_inference_(args)
+ model.cuda()
+
+ # Load alignment dictionary for unknown word replacement
+ # (None if no unknown word replacement, empty if no path to align dictionary)
+ align_dict = utils.load_align_dict(args.generation.replace_unk)
+
+ max_positions = utils.resolve_max_positions(
+ task.max_positions(),
+ *[model.max_positions() for model in models]
+ )
+
+ output_file = open(arg_output, 'w')
+
+ with open(arg_prompts, 'r') as fin:
+ lines = fin.readlines()
+
+ split = [x.split('|', 1) for x in lines]
+ seq_id = [x[0] for x in split]
+ prompts = [x[1] for x in split]
+
+ if args.generation.prefix_size >= 0:
+ prompts = [' '.join(l.split()[:args.generation.prefix_size])
+ for l in prompts]
+
+ if arg_debug:
+ prompts = prompts[:10]
+
+ generator = task.build_generator(models, args.generation)
+
+ start_id = 0
+ pbar = tqdm.tqdm(total=len(prompts))
+ for batch in make_batches(prompts, args, task, max_positions):
+ src_tokens = batch.src_tokens
+ src_lengths = batch.src_lengths
+ src_tokens = src_tokens.cuda()
+ src_lengths = src_lengths.cuda()
+
+ sample = {
+ 'net_input': {
+ 'src_tokens': src_tokens,
+ 'src_lengths': src_lengths,
+ },
+ }
+
+ results = []
+ translations = task.inference_step(generator, models, sample)
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
+ src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
+ results.append((i + start_id, src_tokens_i, hypos))
+
+ # sort output to match input order
+ for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
+ if src_dict is not None:
+ src_str = src_dict.string(
+ src_tokens, args.common_eval.post_process)
+
+ # Process top predictions
+ for hypo_id, hypo in enumerate(hypos):
+ _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
+ hypo_tokens=hypo['tokens'].int().cpu(),
+ src_str=src_str,
+ alignment=hypo['alignment'],
+ align_dict=align_dict,
+ tgt_dict=tgt_dict,
+ remove_bpe=args.common_eval.post_process,
+ )
+
+ detok_hypo_str = hypo_str
+ utterance = detok_hypo_str
+ print(f'{seq_id[id]}__{hypo_id}|{utterance}', file=output_file)
+ pbar.update(1)
+ start_id += len(results)
+
+    output_file.close()
+
+
+def cli_main():
+ parser = options.get_interactive_generation_parser()
+ parser.add_argument('--prompts', type=str, default=None, required=True)
+ parser.add_argument('--output', type=str, default=None, required=True)
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--samples-per-prompt', type=int, default=1)
+
+ args = options.parse_args_and_arch(parser)
+
+ np.random.seed(args.seed)
+ utils.set_torch_seed(args.seed)
+
+ main(args)
+
+
+if __name__ == '__main__':
+ cli_main()
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/README.md b/fairseq/examples/textless_nlp/gslm/unit2speech/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..57104230655c7c517d25904e634c53b6159ee60f
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/README.md
@@ -0,0 +1,42 @@
+# Unit to Speech Model (unit2speech)
+
+The unit2speech model is a modified Tacotron2 model that learns to synthesize speech from discrete speech units. All models are trained on quantized [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
+
+Upstream Units | Download Link
+|-|-
+Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/tts_checkpoint_best.pt)
+Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/tts_checkpoint_best.pt)
+Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/tts_checkpoint_best.pt)
+Log Mel Filterbank + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km500/tts_checkpoint_best.pt)
+Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/tts_checkpoint_best.pt)
+Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/tts_checkpoint_best.pt)
+Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/tts_checkpoint_best.pt)
+Modified CPC + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km500/tts_checkpoint_best.pt)
+HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/tts_checkpoint_best.pt)
+HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/tts_checkpoint_best.pt)
+HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/tts_checkpoint_best.pt)
+HuBERT Base + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km500/tts_checkpoint_best.pt)
+wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/tts_checkpoint_best.pt)
+wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/tts_checkpoint_best.pt)
+wav2vec 2.0 Large + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/tts_checkpoint_best.pt)
+wav2vec 2.0 Large + KM500 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km500/tts_checkpoint_best.pt)
+
+## Run inference using a unit2speech model
+* Install librosa, unidecode and inflect using `pip install librosa unidecode inflect`
+* Download [Waveglow checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/gslm/waveglow_256channels_new.pt). This is the vocoder.
+
+Sample command to run inference using trained unit2speech models. Please note that the quantized audio to be synthesized must use the same units as the ones the unit2speech model was trained with.
+```
+FAIRSEQ_ROOT=
+TTS_MODEL_PATH=
+QUANTIZED_UNIT_PATH=
+OUT_DIR=
+WAVEGLOW_PATH=
+
+PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py \
+ --tts_model_path $TTS_MODEL_PATH \
+ --quantized_unit_path $QUANTIZED_UNIT_PATH \
+ --out_audio_dir $OUT_DIR \
+ --waveglow_path $WAVEGLOW_PATH \
+ --max_decoder_steps 2000
+```
\ No newline at end of file
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py b/fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py
new file mode 100644
index 0000000000000000000000000000000000000000..2be848fceae65e3bd5747a2c98106b0215c6a039
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py
@@ -0,0 +1,56 @@
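+"""Convert a directory tree of audio files to mono 16 kHz wav with sox,
+mirroring the input directory structure in the output directory."""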
+import os
+import shlex
+import subprocess
+import progressbar
+from time import time
+from pathlib import Path
+
+def find_all_files(path_dir, extension):
+ out = []
+ for root, dirs, filenames in os.walk(path_dir):
+ for f in filenames:
+ if f.endswith(extension):
+ out.append(((str(Path(f).stem)), os.path.join(root, f)))
+ return out
+
+def convert16k(inputfile, outputfile16k):
+    # quote the paths so that files with spaces or shell metacharacters are handled correctly
+    command = 'sox -c 1 -b 16 {} -t wav {} rate 16k'.format(
+        shlex.quote(inputfile), shlex.quote(outputfile16k))
+    subprocess.call(shlex.split(command))
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Convert to wav 16k audio using sox.')
+ parser.add_argument('input_dir', type=str,
+ help='Path to the input dir.')
+ parser.add_argument('output_dir', type=str,
+ help='Path to the output dir.')
+    parser.add_argument('--extension', type=str, default='wav',
+                        help='Audio file extension in the input. Default: wav')
+ args = parser.parse_args()
+
+ # Find all sequences
+ print(f"Finding all audio files with extension '{args.extension}' from {args.input_dir}...")
+ audio_files = find_all_files(args.input_dir, args.extension)
+ print(f"Done! Found {len(audio_files)} files.")
+
+ # Convert to relative path
+ audio_files = [os.path.relpath(file[-1], start=args.input_dir) for file in audio_files]
+
+ # Create all the directories needed
+ rel_dirs_set = set([os.path.dirname(file) for file in audio_files])
+ for rel_dir in rel_dirs_set:
+ Path(os.path.join(args.output_dir, rel_dir)).mkdir(parents=True, exist_ok=True)
+
+    # Convert the audio files to 16 kHz wav
+ print("Converting the audio to wav files...")
+ bar = progressbar.ProgressBar(maxval=len(audio_files))
+ bar.start()
+ start_time = time()
+ for index, file in enumerate(audio_files):
+ bar.update(index)
+ input_file = os.path.join(args.input_dir, file)
+ output_file = os.path.join(args.output_dir, os.path.splitext(file)[0]+".wav")
+ convert16k(input_file, output_file)
+ bar.finish()
+ print(f"...done {len(audio_files)} files in {time()-start_time} seconds.")
\ No newline at end of file
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/glow.py b/fairseq/examples/textless_nlp/gslm/unit2speech/glow.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a7696403d505afdf0f1606f8220801b0f46152f
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/glow.py
@@ -0,0 +1,311 @@
+# *****************************************************************************
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the
+# names of its contributors may be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
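+# WaveGlow flow-based neural vocoder (from NVIDIA's WaveGlow release); used in this
+# example to turn mel-spectrograms predicted by the unit2speech Tacotron2 model into waveforms.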
+import copy
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a+input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+class WaveGlowLoss(torch.nn.Module):
+ def __init__(self, sigma=1.0):
+ super(WaveGlowLoss, self).__init__()
+ self.sigma = sigma
+
+ def forward(self, model_output):
+ z, log_s_list, log_det_W_list = model_output
+ for i, log_s in enumerate(log_s_list):
+ if i == 0:
+ log_s_total = torch.sum(log_s)
+ log_det_W_total = log_det_W_list[i]
+ else:
+ log_s_total = log_s_total + torch.sum(log_s)
+ log_det_W_total += log_det_W_list[i]
+
+ loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
+ return loss/(z.size(0)*z.size(1)*z.size(2))
+
+
+class Invertible1x1Conv(torch.nn.Module):
+ """
+    The layer outputs both the convolution and the log determinant
+    of its weight matrix. If reverse=True it applies the convolution with the
+    inverse of the weight matrix instead.
+ """
+ def __init__(self, c):
+ super(Invertible1x1Conv, self).__init__()
+ self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
+ bias=False)
+
+ # Sample a random orthonormal matrix to initialize weights
+ W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
+
+ # Ensure determinant is 1.0 not -1.0
+ if torch.det(W) < 0:
+ W[:,0] = -1*W[:,0]
+ W = W.view(c, c, 1)
+ self.conv.weight.data = W
+
+ def forward(self, z, reverse=False):
+ # shape
+ batch_size, group_size, n_of_groups = z.size()
+
+ W = self.conv.weight.squeeze()
+
+ if reverse:
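+            # compute and cache the inverse of the 1x1 conv weight on the first reverse call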
+ if not hasattr(self, 'W_inverse'):
+ # Reverse computation
+ W_inverse = W.float().inverse()
+ W_inverse = Variable(W_inverse[..., None])
+ if z.type() == 'torch.cuda.HalfTensor':
+ W_inverse = W_inverse.half()
+ self.W_inverse = W_inverse
+ z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
+ return z
+ else:
+ # Forward computation
+ log_det_W = batch_size * n_of_groups * torch.logdet(W)
+ z = self.conv(z)
+ return z, log_det_W
+
+
+class WN(torch.nn.Module):
+ """
+    This is the WaveNet-like layer for the affine coupling. The primary difference
+    from WaveNet is that the convolutions need not be causal. There is also no
+    dilation size reset; the dilation only doubles on each layer.
+ """
+ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
+ kernel_size):
+ super(WN, self).__init__()
+ assert(kernel_size % 2 == 1)
+ assert(n_channels % 2 == 0)
+ self.n_layers = n_layers
+ self.n_channels = n_channels
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+
+ start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
+ start = torch.nn.utils.weight_norm(start, name='weight')
+ self.start = start
+
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+
+ cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = 2 ** i
+ padding = int((kernel_size*dilation - dilation)/2)
+ in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2*n_channels
+ else:
+ res_skip_channels = n_channels
+ res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, forward_input):
+ audio, spect = forward_input
+ audio = self.start(audio)
+ output = torch.zeros_like(audio)
+ n_channels_tensor = torch.IntTensor([self.n_channels])
+
+ spect = self.cond_layer(spect)
+
+ for i in range(self.n_layers):
+ spect_offset = i*2*self.n_channels
+ acts = fused_add_tanh_sigmoid_multiply(
+ self.in_layers[i](audio),
+ spect[:,spect_offset:spect_offset+2*self.n_channels,:],
+ n_channels_tensor)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ audio = audio + res_skip_acts[:,:self.n_channels,:]
+ output = output + res_skip_acts[:,self.n_channels:,:]
+ else:
+ output = output + res_skip_acts
+
+ return self.end(output)
+
+
+class WaveGlow(torch.nn.Module):
+ def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
+ n_early_size, WN_config):
+ super(WaveGlow, self).__init__()
+
+ self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
+ n_mel_channels,
+ 1024, stride=256)
+ assert(n_group % 2 == 0)
+ self.n_flows = n_flows
+ self.n_group = n_group
+ self.n_early_every = n_early_every
+ self.n_early_size = n_early_size
+ self.WN = torch.nn.ModuleList()
+ self.convinv = torch.nn.ModuleList()
+
+ n_half = int(n_group/2)
+
+ # Set up layers with the right sizes based on how many dimensions
+ # have been output already
+ n_remaining_channels = n_group
+ for k in range(n_flows):
+ if k % self.n_early_every == 0 and k > 0:
+ n_half = n_half - int(self.n_early_size/2)
+ n_remaining_channels = n_remaining_channels - self.n_early_size
+ self.convinv.append(Invertible1x1Conv(n_remaining_channels))
+ self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
+ self.n_remaining_channels = n_remaining_channels # Useful during inference
+
+ def forward(self, forward_input):
+ """
+ forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
+ forward_input[1] = audio: batch x time
+ """
+ spect, audio = forward_input
+
+ # Upsample spectrogram to size of audio
+ spect = self.upsample(spect)
+ assert(spect.size(2) >= audio.size(1))
+ if spect.size(2) > audio.size(1):
+ spect = spect[:, :, :audio.size(1)]
+
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+ spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
+
+ audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
+ output_audio = []
+ log_s_list = []
+ log_det_W_list = []
+
+ for k in range(self.n_flows):
+ if k % self.n_early_every == 0 and k > 0:
+ output_audio.append(audio[:,:self.n_early_size,:])
+ audio = audio[:,self.n_early_size:,:]
+
+ audio, log_det_W = self.convinv[k](audio)
+ log_det_W_list.append(log_det_W)
+
+ n_half = int(audio.size(1)/2)
+ audio_0 = audio[:,:n_half,:]
+ audio_1 = audio[:,n_half:,:]
+
+ output = self.WN[k]((audio_0, spect))
+ log_s = output[:, n_half:, :]
+ b = output[:, :n_half, :]
+ audio_1 = torch.exp(log_s)*audio_1 + b
+ log_s_list.append(log_s)
+
+ audio = torch.cat([audio_0, audio_1],1)
+
+ output_audio.append(audio)
+ return torch.cat(output_audio,1), log_s_list, log_det_W_list
+
+ def infer(self, spect, sigma=1.0):
+ spect = self.upsample(spect)
+ # trim conv artifacts. maybe pad spec to kernel multiple
+ time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+ spect = spect[:, :, :-time_cutoff]
+
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+ spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
+
+ if spect.type() == 'torch.cuda.HalfTensor':
+ audio = torch.cuda.HalfTensor(spect.size(0),
+ self.n_remaining_channels,
+ spect.size(2)).normal_()
+ else:
+ audio = torch.cuda.FloatTensor(spect.size(0),
+ self.n_remaining_channels,
+ spect.size(2)).normal_()
+
+ audio = torch.autograd.Variable(sigma*audio)
+
+ for k in reversed(range(self.n_flows)):
+ n_half = int(audio.size(1)/2)
+ audio_0 = audio[:,:n_half,:]
+ audio_1 = audio[:,n_half:,:]
+
+ output = self.WN[k]((audio_0, spect))
+
+ s = output[:, n_half:, :]
+ b = output[:, :n_half, :]
+ audio_1 = (audio_1 - b)/torch.exp(s)
+ audio = torch.cat([audio_0, audio_1],1)
+
+ audio = self.convinv[k](audio, reverse=True)
+
+ if k % self.n_early_every == 0 and k > 0:
+ if spect.type() == 'torch.cuda.HalfTensor':
+ z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+ else:
+ z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+ audio = torch.cat((sigma*z, audio),1)
+
+ audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
+ return audio
+
+ @staticmethod
+ def remove_weightnorm(model):
+ waveglow = model
+ for WN in waveglow.WN:
+ WN.start = torch.nn.utils.remove_weight_norm(WN.start)
+ WN.in_layers = remove(WN.in_layers)
+ WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
+ WN.res_skip_layers = remove(WN.res_skip_layers)
+ return waveglow
+
+
+def remove(conv_list):
+ new_conv_list = torch.nn.ModuleList()
+ for old_conv in conv_list:
+ old_conv = torch.nn.utils.remove_weight_norm(old_conv)
+ new_conv_list.append(old_conv)
+ return new_conv_list
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py b/fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a287a4e97c66acbd36897b25f2ece5494005f03
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py
@@ -0,0 +1,27 @@
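+# Helper that launches one copy of the given training script per available GPU,
+# appending --n_gpus, --group_name and --rank arguments; output of ranks > 0 is
+# redirected to per-GPU log files under the given log directory (last argument).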
+import os
+import time
+import torch
+import sys
+import subprocess
+
+argslist = list(sys.argv)[1:]
+log_dir = argslist[-1]
+num_gpus = torch.cuda.device_count()
+argslist.append('--n_gpus={}'.format(num_gpus))
+workers = []
+job_id = time.strftime("%Y_%m_%d-%H%M%S")
+argslist.append("--group_name=group_{}".format(job_id))
+
+print("GPU log directory is {}".format(log_dir))
+os.makedirs(log_dir, exist_ok=True)
+for i in range(num_gpus):
+ argslist.append('--rank={}'.format(i))
+ stdout = None if i == 0 else open("{}/{}_GPU_{}.log".format(log_dir, job_id, i),
+ "w")
+ print(argslist)
+ p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)
+ workers.append(p)
+ argslist = argslist[:-1]
+
+for p in workers:
+ p.wait()
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py b/fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py
new file mode 100644
index 0000000000000000000000000000000000000000..f226d5f50514ecb5ee3b4f1031df750609a56112
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py
@@ -0,0 +1,97 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+
+import soundfile as sf
+from examples.textless_nlp.gslm.unit2speech.tts_data import (
+ TacotronInputDataset,
+)
+from examples.textless_nlp.gslm.unit2speech.utils import (
+ load_quantized_audio_from_file,
+ load_tacotron,
+ load_waveglow,
+ synthesize_audio,
+)
+
+
+def get_logger():
+ log_format = "[%(asctime)s] [%(levelname)s]: %(message)s"
+ logging.basicConfig(format=log_format, level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ return logger
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Wav2Vec 2.0 speech generator."
+ )
+ parser.add_argument(
+ "--quantized_unit_path",
+ type=str,
+ help="K-means model file path to use for inference",
+ )
+ parser.add_argument(
+ "--tts_model_path",
+ type=str,
+ help="TTS model file path to use for inference",
+ )
+ parser.add_argument(
+ "--waveglow_path",
+ type=str,
+ help="Path to the waveglow checkpoint (vocoder).",
+ )
+ parser.add_argument("--max_decoder_steps", type=int, default=2000)
+ parser.add_argument("--denoiser_strength", type=float, default=0.1)
+ parser.add_argument(
+ "--out_audio_dir",
+ type=str,
+ help="Output directory to dump audio files",
+ )
+
+ return parser
+
+
+def main(args, logger):
+ # Load quantized audio
+ logger.info(f"Loading quantized audio from {args.quantized_unit_path}...")
+ names_batch, quantized_units_batch = load_quantized_audio_from_file(
+ file_path=args.quantized_unit_path
+ )
+
+ logger.info(f"Loading TTS model from {args.tts_model_path}...")
+ tacotron_model, sample_rate, hparams = load_tacotron(
+ tacotron_model_path=args.tts_model_path,
+ max_decoder_steps=args.max_decoder_steps,
+ )
+
+ logger.info(f"Loading Waveglow model from {args.waveglow_path}...")
+ waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path)
+
+ tts_dataset = TacotronInputDataset(hparams)
+ for name, quantized_units in zip(names_batch, quantized_units_batch):
+ quantized_units_str = " ".join(map(str, quantized_units))
+ tts_input = tts_dataset.get_tensor(quantized_units_str)
+ mel, aud, aud_dn, has_eos = synthesize_audio(
+ tacotron_model,
+ waveglow,
+ denoiser,
+ tts_input.unsqueeze(0),
+ strength=args.denoiser_strength,
+ )
+ out_file_path = os.path.join(args.out_audio_dir, f"{name}.wav")
+        sf.write(
+            out_file_path, aud_dn[0].cpu().float().numpy(), sample_rate
+        )
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ logger = get_logger()
+ logger.info(args)
+ main(args, logger)
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5af7f723eb8047bc58db2f85234aea161fbc659
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py
@@ -0,0 +1,93 @@
+import torch
+import numpy as np
+from scipy.signal import get_window
+import librosa.util as librosa_util
+
+
+def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+ n_fft=800, dtype=np.float32, norm=None):
+ """
+ # from librosa 0.6
+ Compute the sum-square envelope of a window function at a given hop length.
+
+ This is used to estimate modulation effects induced by windowing
+ observations in short-time fourier transforms.
+
+ Parameters
+ ----------
+ window : string, tuple, number, callable, or list-like
+ Window specification, as in `get_window`
+
+ n_frames : int > 0
+ The number of analysis frames
+
+ hop_length : int > 0
+ The number of samples to advance between frames
+
+ win_length : [optional]
+ The length of the window function. By default, this matches `n_fft`.
+
+ n_fft : int > 0
+ The length of each analysis frame.
+
+ dtype : np.dtype
+ The data type of the output
+
+ Returns
+ -------
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+ The sum-squared envelope of the window function
+ """
+ if win_length is None:
+ win_length = n_fft
+
+ n = n_fft + hop_length * (n_frames - 1)
+ x = np.zeros(n, dtype=dtype)
+
+ # Compute the squared window at the desired length
+ win_sq = get_window(window, win_length, fftbins=True)
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
+
+ # Fill the envelope
+ for i in range(n_frames):
+ sample = i * hop_length
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+ return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+ """
+ PARAMS
+ ------
+ magnitudes: spectrogram magnitudes
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+ """
+
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+ angles = angles.astype(np.float32)
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+ for i in range(n_iters):
+ _, angles = stft_fn.transform(signal)
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+ return signal
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2e35c1a8cc4c628c5d05802677142c9a2122d2b
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py
@@ -0,0 +1,90 @@
+""" from https://github.com/keithito/tacotron """
+
+'''
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+ 1. "english_cleaners" for English text
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+ the symbols in symbols.py to match your data).
+'''
+
+import re
+from unidecode import unidecode
+from .numbers import normalize_numbers
+
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r'\s+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('mrs', 'misess'),
+ ('mr', 'mister'),
+ ('dr', 'doctor'),
+ ('st', 'saint'),
+ ('co', 'company'),
+ ('jr', 'junior'),
+ ('maj', 'major'),
+ ('gen', 'general'),
+ ('drs', 'doctors'),
+ ('rev', 'reverend'),
+ ('lt', 'lieutenant'),
+ ('hon', 'honorable'),
+ ('sgt', 'sergeant'),
+ ('capt', 'captain'),
+ ('esq', 'esquire'),
+ ('ltd', 'limited'),
+ ('col', 'colonel'),
+ ('ft', 'fort'),
+]]
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def expand_numbers(text):
+ return normalize_numbers(text)
+
+
+def lowercase(text):
+ return text.lower()
+
+
+def collapse_whitespace(text):
+ return re.sub(_whitespace_re, ' ', text)
+
+
+def convert_to_ascii(text):
+ return unidecode(text)
+
+
+def basic_cleaners(text):
+ '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
+
+
+def transliteration_cleaners(text):
+ '''Pipeline for non-English text that transliterates to ASCII.'''
+ text = convert_to_ascii(text)
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
+
+
+def english_cleaners(text):
+ '''Pipeline for English text, including number and abbreviation expansion.'''
+ text = convert_to_ascii(text)
+ text = lowercase(text)
+ text = expand_numbers(text)
+ text = expand_abbreviations(text)
+ text = collapse_whitespace(text)
+ return text
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py
new file mode 100644
index 0000000000000000000000000000000000000000..62bfef745c30a56f7b6605d9e3becfbc40edb50d
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py
@@ -0,0 +1,65 @@
+""" from https://github.com/keithito/tacotron """
+
+import re
+
+
+valid_symbols = [
+ 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
+ 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
+ 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
+ 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
+ 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
+ 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
+ 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+]
+
+_valid_symbol_set = set(valid_symbols)
+
+
+class CMUDict:
+ '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
+ def __init__(self, file_or_path, keep_ambiguous=True):
+ if isinstance(file_or_path, str):
+ with open(file_or_path, encoding='latin-1') as f:
+ entries = _parse_cmudict(f)
+ else:
+ entries = _parse_cmudict(file_or_path)
+ if not keep_ambiguous:
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
+ self._entries = entries
+
+
+ def __len__(self):
+ return len(self._entries)
+
+
+ def lookup(self, word):
+ '''Returns list of ARPAbet pronunciations of the given word.'''
+ return self._entries.get(word.upper())
+
+
+
+_alt_re = re.compile(r'\([0-9]+\)')
+
+
+def _parse_cmudict(file):
+ cmudict = {}
+ for line in file:
+ if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
+ parts = line.split(' ')
+ word = re.sub(_alt_re, '', parts[0])
+ pronunciation = _get_pronunciation(parts[1])
+ if pronunciation:
+ if word in cmudict:
+ cmudict[word].append(pronunciation)
+ else:
+ cmudict[word] = [pronunciation]
+ return cmudict
+
+
+def _get_pronunciation(s):
+ parts = s.strip().split(' ')
+ for part in parts:
+ if part not in _valid_symbol_set:
+ return None
+ return ' '.join(parts)
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f10d557ff5a4fff03b94f81543bd58cf1a66bc8f
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py
@@ -0,0 +1,103 @@
+import torch
+from librosa.filters import mel as librosa_mel_fn
+from .audio_processing import dynamic_range_compression
+from .audio_processing import dynamic_range_decompression
+from .stft import STFT
+from .utils import get_mask_from_lengths
+
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert(kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+class GlobalAvgPool(torch.nn.Module):
+ def __init__(self):
+ super(GlobalAvgPool, self).__init__()
+
+ def forward(self, x, lengths=None):
+ """Average pooling across time steps (dim=1) with optionally lengths.
+ Args:
+ x: torch.Tensor of shape (N, T, ...)
+ lengths: None or torch.Tensor of shape (N,)
+ dim: dimension to pool
+ """
+ if lengths is None:
+ return x.mean(dim=1, keepdim=False)
+ else:
+ mask = get_mask_from_lengths(lengths).type(x.type()).to(x.device)
+ mask_shape = list(mask.size()) + [1 for _ in range(x.ndimension()-2)]
+ mask = mask.reshape(*mask_shape)
+ numer = (x * mask).sum(dim=1, keepdim=False)
+ denom = mask.sum(dim=1, keepdim=False)
+ return numer / denom
+
+
+class TacotronSTFT(torch.nn.Module):
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
+ n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
+ mel_fmax=8000.0):
+ super(TacotronSTFT, self).__init__()
+ self.n_mel_channels = n_mel_channels
+ self.sampling_rate = sampling_rate
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
+ mel_basis = librosa_mel_fn(
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
+ mel_basis = torch.from_numpy(mel_basis).float()
+ self.register_buffer('mel_basis', mel_basis)
+
+ def spectral_normalize(self, magnitudes):
+ output = dynamic_range_compression(magnitudes)
+ return output
+
+ def spectral_de_normalize(self, magnitudes):
+ output = dynamic_range_decompression(magnitudes)
+ return output
+
+ def mel_spectrogram(self, y):
+ """Computes mel-spectrograms from a batch of waves
+ PARAMS
+ ------
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+ RETURNS
+ -------
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+ """
+ assert(torch.min(y.data) >= -1)
+ assert(torch.max(y.data) <= 1)
+
+ magnitudes, phases = self.stft_fn.transform(y)
+ magnitudes = magnitudes.data
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
+ mel_output = self.spectral_normalize(mel_output)
+ return mel_output
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf132b150a7cc1c125c1190b5fd8f43edaae685
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py
@@ -0,0 +1,669 @@
+from math import sqrt
+import torch
+import torch.distributions as distr
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+from .layers import ConvNorm, LinearNorm, GlobalAvgPool
+from .utils import to_gpu, get_mask_from_lengths
+
+
+class LocationLayer(nn.Module):
+ def __init__(self, attention_n_filters, attention_kernel_size,
+ attention_dim):
+ super(LocationLayer, self).__init__()
+ padding = int((attention_kernel_size - 1) / 2)
+ self.location_conv = ConvNorm(2, attention_n_filters,
+ kernel_size=attention_kernel_size,
+ padding=padding, bias=False, stride=1,
+ dilation=1)
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
+ bias=False, w_init_gain='tanh')
+
+ def forward(self, attention_weights_cat):
+ processed_attention = self.location_conv(attention_weights_cat)
+ processed_attention = processed_attention.transpose(1, 2)
+ processed_attention = self.location_dense(processed_attention)
+ return processed_attention
+
+
+class Attention(nn.Module):
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+ attention_location_n_filters, attention_location_kernel_size):
+ super(Attention, self).__init__()
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
+ bias=False, w_init_gain='tanh')
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
+ w_init_gain='tanh')
+ self.v = LinearNorm(attention_dim, 1, bias=False)
+ self.location_layer = LocationLayer(attention_location_n_filters,
+ attention_location_kernel_size,
+ attention_dim)
+ self.score_mask_value = -float("inf")
+
+ def get_alignment_energies(self, query, processed_memory,
+ attention_weights_cat):
+ """
+ PARAMS
+ ------
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
+
+ RETURNS
+ -------
+ alignment (batch, max_time)
+ """
+
+ processed_query = self.query_layer(query.unsqueeze(1))
+ processed_attention_weights = self.location_layer(attention_weights_cat)
+ energies = self.v(torch.tanh(
+ processed_query + processed_attention_weights + processed_memory))
+
+ energies = energies.squeeze(-1)
+ return energies
+
+ def forward(self, attention_hidden_state, memory, processed_memory,
+ attention_weights_cat, mask):
+ """
+ PARAMS
+ ------
+ attention_hidden_state: attention rnn last output
+ memory: encoder outputs
+ processed_memory: processed encoder outputs
+        attention_weights_cat: previous and cumulative attention weights
+ mask: binary mask for padded data
+ """
+ alignment = self.get_alignment_energies(
+ attention_hidden_state, processed_memory, attention_weights_cat)
+
+ if mask is not None:
+ alignment.data.masked_fill_(mask, self.score_mask_value)
+
+ attention_weights = F.softmax(alignment, dim=1)
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+ attention_context = attention_context.squeeze(1)
+
+ return attention_context, attention_weights
+
+
+class Prenet(nn.Module):
+ def __init__(self, in_dim, sizes):
+ super(Prenet, self).__init__()
+ in_sizes = [in_dim] + sizes[:-1]
+ self.layers = nn.ModuleList(
+ [LinearNorm(in_size, out_size, bias=False)
+ for (in_size, out_size) in zip(in_sizes, sizes)])
+
+ def forward(self, x):
+ for linear in self.layers:
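+            # dropout is intentionally kept active at inference (training=True), as in the original Tacotron2 prenet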
+ x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
+ return x
+
+
+class Postnet(nn.Module):
+ """Postnet
+    - Five 1-d convolutions with 512 channels and kernel size 5
+ """
+
+ def __init__(self, hparams):
+ super(Postnet, self).__init__()
+ self.convolutions = nn.ModuleList()
+
+ self.convolutions.append(
+ nn.Sequential(
+ ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
+ kernel_size=hparams.postnet_kernel_size, stride=1,
+ padding=int((hparams.postnet_kernel_size - 1) / 2),
+ dilation=1, w_init_gain='tanh'),
+ nn.BatchNorm1d(hparams.postnet_embedding_dim))
+ )
+
+ for i in range(1, hparams.postnet_n_convolutions - 1):
+ self.convolutions.append(
+ nn.Sequential(
+ ConvNorm(hparams.postnet_embedding_dim,
+ hparams.postnet_embedding_dim,
+ kernel_size=hparams.postnet_kernel_size, stride=1,
+ padding=int((hparams.postnet_kernel_size - 1) / 2),
+ dilation=1, w_init_gain='tanh'),
+ nn.BatchNorm1d(hparams.postnet_embedding_dim))
+ )
+
+ self.convolutions.append(
+ nn.Sequential(
+ ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
+ kernel_size=hparams.postnet_kernel_size, stride=1,
+ padding=int((hparams.postnet_kernel_size - 1) / 2),
+ dilation=1, w_init_gain='linear'),
+ nn.BatchNorm1d(hparams.n_mel_channels))
+ )
+
+ def forward(self, x):
+ for i in range(len(self.convolutions) - 1):
+ x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
+ x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
+
+ return x
+
+
+class Encoder(nn.Module):
+ """Encoder module:
+ - Three 1-d convolution banks
+ - Bidirectional LSTM
+ """
+ def __init__(self, hparams):
+ super(Encoder, self).__init__()
+
+ convolutions = []
+ for _ in range(hparams.encoder_n_convolutions):
+ conv_layer = nn.Sequential(
+ ConvNorm(hparams.encoder_embedding_dim,
+ hparams.encoder_embedding_dim,
+ kernel_size=hparams.encoder_kernel_size, stride=1,
+ padding=int((hparams.encoder_kernel_size - 1) / 2),
+ dilation=1, w_init_gain='relu'),
+ nn.BatchNorm1d(hparams.encoder_embedding_dim))
+ convolutions.append(conv_layer)
+ self.convolutions = nn.ModuleList(convolutions)
+
+ self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
+ int(hparams.encoder_embedding_dim / 2), 1,
+ batch_first=True, bidirectional=True)
+
+ def forward(self, x, input_lengths):
+ for conv in self.convolutions:
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
+
+ x = x.transpose(1, 2)
+
+        # pytorch tensors are not reversible, hence the conversion to numpy
+ input_lengths = input_lengths.cpu().numpy()
+ x = nn.utils.rnn.pack_padded_sequence(
+ x, input_lengths, batch_first=True)
+
+ self.lstm.flatten_parameters()
+ outputs, _ = self.lstm(x)
+
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(
+ outputs, batch_first=True)
+
+ return outputs
+
+ def inference(self, x):
+ for conv in self.convolutions:
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
+
+ x = x.transpose(1, 2)
+
+ self.lstm.flatten_parameters()
+ outputs, _ = self.lstm(x)
+
+ return outputs
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, hparams):
+ super(AudioEncoder, self).__init__()
+
+ assert hparams.lat_dim > 0
+
+ convolutions = []
+ inp_dim = hparams.n_mel_channels
+ for _ in range(hparams.lat_n_convolutions):
+ conv_layer = nn.Sequential(
+ ConvNorm(inp_dim, hparams.lat_n_filters,
+ kernel_size=hparams.lat_kernel_size, stride=1,
+ padding=int((hparams.lat_kernel_size - 1) / 2),
+ dilation=1, w_init_gain='tanh'),
+ nn.BatchNorm1d(hparams.lat_n_filters))
+ inp_dim = hparams.lat_n_filters
+ convolutions.append(conv_layer)
+ self.convolutions = nn.ModuleList(convolutions)
+
+ self.lstm = nn.LSTM(hparams.lat_n_filters,
+ int(hparams.lat_n_filters / 2),
+ hparams.lat_n_blstms, batch_first=True,
+ bidirectional=True)
+ self.pool = GlobalAvgPool()
+
+ self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
+ self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
+ self.lat_dim = hparams.lat_dim
+
+ def forward(self, x, lengths):
+ """
+ Args:
+ x (torch.Tensor): (B, F, T)
+ """
+
+ for conv in self.convolutions:
+            x = F.dropout(torch.tanh(conv(x)), 0.5, self.training)
+
+ x = x.transpose(1, 2) # (B, T, D)
+
+ # x may not be sorted by length. Sort->process->unsort
+ max_len = x.size(1)
+ assert max_len == torch.max(lengths).item()
+
+ lengths, perm_idx = lengths.sort(0, descending=True)
+ x = x[perm_idx]
+ x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
+
+ self.lstm.flatten_parameters()
+ outputs, _ = self.lstm(x)
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+ _, unperm_idx = perm_idx.sort(0)
+ outputs = outputs[unperm_idx] # (B, T, D)
+        lengths = lengths[unperm_idx] # (B,)
+
+ outputs = self.pool(outputs, lengths) # (B, D)
+
+ mu = self.mu_proj(outputs)
+ logvar = self.logvar_proj(outputs)
+ z = distr.Normal(mu, logvar).rsample()
+ return z, mu, logvar
+
+
+class Decoder(nn.Module):
+ def __init__(self, hparams):
+ super(Decoder, self).__init__()
+ self.n_mel_channels = hparams.n_mel_channels
+ self.n_frames_per_step = hparams.n_frames_per_step
+ self.encoder_embedding_dim = hparams.encoder_embedding_dim
+ self.obs_dim = hparams.obs_dim
+ self.lat_dim = hparams.lat_dim
+ self.attention_rnn_dim = hparams.attention_rnn_dim
+ self.decoder_rnn_dim = hparams.decoder_rnn_dim
+ self.prenet_dim = hparams.prenet_dim
+ self.max_decoder_steps = hparams.max_decoder_steps
+ self.gate_threshold = hparams.gate_threshold
+ self.p_attention_dropout = hparams.p_attention_dropout
+ self.p_decoder_dropout = hparams.p_decoder_dropout
+
+ self.prenet = Prenet(
+ hparams.n_mel_channels * hparams.n_frames_per_step,
+ [hparams.prenet_dim, hparams.prenet_dim])
+
+ self.attention_rnn = nn.LSTMCell(
+ hparams.prenet_dim + hparams.encoder_embedding_dim,
+ hparams.attention_rnn_dim)
+
+ self.attention_layer = Attention(
+ hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
+ hparams.attention_dim, hparams.attention_location_n_filters,
+ hparams.attention_location_kernel_size)
+
+ encoder_tot_dim = (hparams.encoder_embedding_dim + \
+ hparams.lat_dim + hparams.obs_dim)
+ self.decoder_rnn = nn.LSTMCell(
+ hparams.attention_rnn_dim + encoder_tot_dim,
+ hparams.decoder_rnn_dim, 1)
+
+ self.linear_projection = LinearNorm(
+ hparams.decoder_rnn_dim + encoder_tot_dim,
+ hparams.n_mel_channels * hparams.n_frames_per_step)
+
+ self.gate_layer = LinearNorm(
+ hparams.decoder_rnn_dim + encoder_tot_dim, 1,
+ bias=True, w_init_gain='sigmoid')
+
+ def get_go_frame(self, memory):
+ """ Gets all zeros frames to use as first decoder input
+ PARAMS
+ ------
+ memory: decoder outputs
+
+ RETURNS
+ -------
+ decoder_input: all zeros frames
+ """
+ B = memory.size(0)
+ decoder_input = Variable(memory.data.new(
+ B, self.n_mel_channels * self.n_frames_per_step).zero_())
+ return decoder_input
+
+ def initialize_decoder_states(self, memory, obs_and_lat, mask):
+ """ Initializes attention rnn states, decoder rnn states, attention
+ weights, attention cumulative weights, attention context, stores memory
+ and stores processed memory
+ PARAMS
+ ------
+ memory: Encoder outputs
+ obs_and_lat: Observed and latent attribute embeddings
+ mask: Mask for padded data if training, expects None for inference
+ """
+ B = memory.size(0)
+ MAX_TIME = memory.size(1)
+
+ self.attention_hidden = Variable(memory.data.new(
+ B, self.attention_rnn_dim).zero_())
+ self.attention_cell = Variable(memory.data.new(
+ B, self.attention_rnn_dim).zero_())
+
+ self.decoder_hidden = Variable(memory.data.new(
+ B, self.decoder_rnn_dim).zero_())
+ self.decoder_cell = Variable(memory.data.new(
+ B, self.decoder_rnn_dim).zero_())
+
+ self.attention_weights = Variable(memory.data.new(
+ B, MAX_TIME).zero_())
+ self.attention_weights_cum = Variable(memory.data.new(
+ B, MAX_TIME).zero_())
+ self.attention_context = Variable(memory.data.new(
+ B, self.encoder_embedding_dim).zero_())
+
+ self.memory = memory
+ self.processed_memory = self.attention_layer.memory_layer(memory)
+ self.obs_and_lat = obs_and_lat
+ self.mask = mask
+
+ def parse_decoder_inputs(self, decoder_inputs):
+ """ Prepares decoder inputs, i.e. mel outputs
+ PARAMS
+ ------
+ decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
+
+ RETURNS
+ -------
+ inputs: processed decoder inputs
+
+ """
+ # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
+ decoder_inputs = decoder_inputs.transpose(1, 2)
+ decoder_inputs = decoder_inputs.view(
+ decoder_inputs.size(0),
+ int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
+ # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
+ decoder_inputs = decoder_inputs.transpose(0, 1)
+ return decoder_inputs
+
+ def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
+ """ Prepares decoder outputs for output
+ PARAMS
+ ------
+ mel_outputs:
+ gate_outputs: gate output energies
+ alignments:
+
+ RETURNS
+ -------
+ mel_outputs:
+        gate_outputs: gate output energies
+ alignments:
+ """
+ # (T_out, B) -> (B, T_out)
+ alignments = torch.stack(alignments).transpose(0, 1)
+ # (T_out, B) -> (B, T_out)
+ gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
+ gate_outputs = gate_outputs.contiguous()
+ # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
+ mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
+ # decouple frames per step
+ mel_outputs = mel_outputs.view(
+ mel_outputs.size(0), -1, self.n_mel_channels)
+ # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
+ mel_outputs = mel_outputs.transpose(1, 2)
+
+ return mel_outputs, gate_outputs, alignments
+
+ def decode(self, decoder_input):
+ """ Decoder step using stored states, attention and memory
+ PARAMS
+ ------
+ decoder_input: previous mel output
+
+ RETURNS
+ -------
+ mel_output:
+ gate_output: gate output energies
+ attention_weights:
+ """
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
+ self.attention_hidden, self.attention_cell = self.attention_rnn(
+ cell_input, (self.attention_hidden, self.attention_cell))
+ self.attention_hidden = F.dropout(
+ self.attention_hidden, self.p_attention_dropout, self.training)
+
+ attention_weights_cat = torch.cat(
+ (self.attention_weights.unsqueeze(1),
+ self.attention_weights_cum.unsqueeze(1)), dim=1)
+ self.attention_context, self.attention_weights = self.attention_layer(
+ self.attention_hidden, self.memory, self.processed_memory,
+ attention_weights_cat, self.mask)
+
+ self.attention_weights_cum += self.attention_weights
+ decoder_input = torch.cat(
+ (self.attention_hidden, self.attention_context), -1)
+ if self.obs_and_lat is not None:
+ decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1)
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
+ decoder_input, (self.decoder_hidden, self.decoder_cell))
+ self.decoder_hidden = F.dropout(
+ self.decoder_hidden, self.p_decoder_dropout, self.training)
+
+ decoder_hidden_attention_context = torch.cat(
+ (self.decoder_hidden, self.attention_context), dim=1)
+ if self.obs_and_lat is not None:
+ decoder_hidden_attention_context = torch.cat(
+ (decoder_hidden_attention_context, self.obs_and_lat), dim=1)
+ decoder_output = self.linear_projection(
+ decoder_hidden_attention_context)
+
+ gate_prediction = self.gate_layer(decoder_hidden_attention_context)
+ return decoder_output, gate_prediction, self.attention_weights
+
+ def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths):
+ """ Decoder forward pass for training
+ PARAMS
+ ------
+ memory: Encoder outputs
+ obs_and_lat: Observed and latent attribute embeddings
+ decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
+ memory_lengths: Encoder output lengths for attention masking.
+
+ RETURNS
+ -------
+ mel_outputs: mel outputs from the decoder
+ gate_outputs: gate outputs from the decoder
+ alignments: sequence of attention weights from the decoder
+ """
+
+ decoder_input = self.get_go_frame(memory).unsqueeze(0)
+ decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
+ decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
+ decoder_inputs = self.prenet(decoder_inputs)
+
+ self.initialize_decoder_states(
+ memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths))
+
+ mel_outputs, gate_outputs, alignments = [], [], []
+ while len(mel_outputs) < decoder_inputs.size(0) - 1:
+ decoder_input = decoder_inputs[len(mel_outputs)]
+ mel_output, gate_output, attention_weights = self.decode(
+ decoder_input)
+ mel_outputs += [mel_output.squeeze(1)]
+ gate_outputs += [gate_output.squeeze()]
+ alignments += [attention_weights]
+
+ mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
+ mel_outputs, gate_outputs, alignments)
+
+ return mel_outputs, gate_outputs, alignments
+
+ def inference(self, memory, obs_and_lat, ret_has_eos=False):
+ """ Decoder inference
+ PARAMS
+ ------
+ memory: Encoder outputs
+ obs_and_lat: Observed and latent attribute embeddings
+
+ RETURNS
+ -------
+ mel_outputs: mel outputs from the decoder
+ gate_outputs: gate outputs from the decoder
+ alignments: sequence of attention weights from the decoder
+ """
+ decoder_input = self.get_go_frame(memory)
+
+ self.initialize_decoder_states(memory, obs_and_lat, mask=None)
+
+ mel_outputs, gate_outputs, alignments = [], [], []
+ has_eos = False
+ while True:
+ decoder_input = self.prenet(decoder_input)
+ mel_output, gate_output, alignment = self.decode(decoder_input)
+
+ mel_outputs += [mel_output.squeeze(1)]
+ gate_outputs += [gate_output]
+ alignments += [alignment]
+
+ if torch.sigmoid(gate_output.data) > self.gate_threshold:
+ has_eos = True
+ break
+ elif len(mel_outputs) == self.max_decoder_steps:
+ # print("Warning! Reached max decoder steps")
+ break
+
+ decoder_input = mel_output
+
+ mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
+ mel_outputs, gate_outputs, alignments)
+
+ if ret_has_eos:
+ return mel_outputs, gate_outputs, alignments, has_eos
+ else:
+ return mel_outputs, gate_outputs, alignments
+
+
+class Tacotron2(nn.Module):
+ def __init__(self, hparams):
+ super(Tacotron2, self).__init__()
+ self.mask_padding = hparams.mask_padding
+ self.fp16_run = hparams.fp16_run
+ self.n_mel_channels = hparams.n_mel_channels
+ self.n_frames_per_step = hparams.n_frames_per_step
+
+ # initialize text encoder embedding
+ self.embedding = nn.Embedding(
+ hparams.n_symbols, hparams.symbols_embedding_dim)
+ std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
+ val = sqrt(3.0) * std # uniform bounds for std
+ self.embedding.weight.data.uniform_(-val, val)
+
+ # initialize observed attribute embedding
+ self.obs_embedding = None
+ if hparams.obs_dim > 0:
+ self.obs_embedding = nn.Embedding(
+ hparams.obs_n_class, hparams.obs_dim)
+ std = sqrt(2.0 / (hparams.obs_n_class + hparams.obs_dim))
+ val = sqrt(3.0) * std # uniform bounds for std
+ self.obs_embedding.weight.data.uniform_(-val, val)
+
+ self.encoder = Encoder(hparams)
+ self.decoder = Decoder(hparams)
+ self.postnet = Postnet(hparams)
+
+ self.lat_encoder = None
+ if hparams.lat_dim > 0:
+ self.lat_encoder = AudioEncoder(hparams)
+
+ def parse_batch(self, batch):
+ (text_padded, input_lengths, obs_labels,
+ mel_padded, gate_padded, output_lengths) = batch
+ text_padded = to_gpu(text_padded).long()
+ input_lengths = to_gpu(input_lengths).long()
+ obs_labels = to_gpu(obs_labels).long()
+ max_len = torch.max(input_lengths.data).item()
+ mel_padded = to_gpu(mel_padded).float()
+ gate_padded = to_gpu(gate_padded).float()
+ output_lengths = to_gpu(output_lengths).long()
+
+ return (
+ (text_padded, input_lengths, obs_labels,
+ mel_padded, max_len, output_lengths),
+ (mel_padded, gate_padded))
+
+ def parse_output(self, outputs, output_lengths=None):
+ if self.mask_padding and output_lengths is not None:
+ mask = ~get_mask_from_lengths(output_lengths)
+ mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
+ mask = mask.permute(1, 0, 2)
+
+ outputs[0].data.masked_fill_(mask, 0.0)
+ outputs[1].data.masked_fill_(mask, 0.0)
+ outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
+
+ return outputs
+
+ def forward(self, inputs):
+ (text_inputs, text_lengths, obs_labels,
+ mels, max_len, output_lengths) = inputs
+ text_lengths, output_lengths = text_lengths.data, output_lengths.data
+
+ embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
+
+ encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+
+ obs = None
+ if self.obs_embedding is not None:
+ obs = self.obs_embedding(obs_labels)
+
+ lat, lat_mu, lat_logvar = None, None, None
+ if self.lat_encoder is not None:
+ (lat, lat_mu, lat_logvar) = self.lat_encoder(mels, output_lengths)
+
+ obs_and_lat = [x for x in [obs, lat] if x is not None]
+ if bool(obs_and_lat):
+ obs_and_lat = torch.cat(obs_and_lat, dim=-1)
+ else:
+ obs_and_lat = None
+
+ mel_outputs, gate_outputs, alignments = self.decoder(
+ encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths)
+
+ mel_outputs_postnet = self.postnet(mel_outputs)
+ mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+
+ return self.parse_output(
+ [mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
+ lat_mu, lat_logvar],
+ output_lengths)
+
+ def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False):
+ embedded_inputs = self.embedding(inputs).transpose(1, 2)
+ encoder_outputs = self.encoder.inference(embedded_inputs)
+
+ if obs_labels is None:
+ obs_labels = torch.LongTensor(len(inputs))
+ obs_labels = obs_labels.to(inputs.device).zero_()
+
+ obs = None
+ if self.obs_embedding is not None:
+ obs = self.obs_embedding(obs_labels)
+
+ if self.lat_encoder is not None:
+ if lat is None:
+ lat = torch.FloatTensor(len(inputs), self.lat_encoder.lat_dim)
+ lat = lat.to(inputs.device).zero_().type(encoder_outputs.type())
+
+ obs_and_lat = [x for x in [obs, lat] if x is not None]
+ if bool(obs_and_lat):
+ obs_and_lat = torch.cat(obs_and_lat, dim=-1)
+ else:
+ obs_and_lat = None
+
+ mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference(
+ encoder_outputs, obs_and_lat, ret_has_eos=True)
+
+ mel_outputs_postnet = self.postnet(mel_outputs)
+ mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+
+ outputs = self.parse_output(
+ [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
+
+ if ret_has_eos:
+ return outputs + [has_eos]
+ else:
+ return outputs
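+
+
+# Inference sketch (assumes an `hparams` namespace matching the constructor and a
+# LongTensor `inp` of shape [1, T] with unit/text IDs on the GPU):
+#   model = Tacotron2(hparams).cuda().eval()
+#   mel, mel_postnet, gate, alignments, has_eos = model.inference(inp, ret_has_eos=True)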
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d5f7fa818a45ecf132627d240afac653e148070
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py
@@ -0,0 +1,71 @@
+""" from https://github.com/keithito/tacotron """
+
+import inflect
+import re
+
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+_number_re = re.compile(r'[0-9]+')
+
+
+def _remove_commas(m):
+ return m.group(1).replace(',', '')
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace('.', ' point ')
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split('.')
+ if len(parts) > 2:
+ return match + ' dollars' # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ return '%s %s' % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s' % (cents, cent_unit)
+ else:
+ return 'zero dollars'
+
+
+def _expand_ordinal(m):
+ return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return 'two thousand'
+ elif num > 2000 and num < 2010:
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
+ elif num % 100 == 0:
+ return _inflect.number_to_words(num // 100) + ' hundred'
+ else:
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+ else:
+ return _inflect.number_to_words(num, andword='')
+
+
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r'\1 pounds', text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
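+
+
+# Rough examples of the normalization above (a sketch; exact inflect wording can differ):
+#   normalize_numbers('$3.50') -> 'three dollars, fifty cents'
+#   normalize_numbers('15th')  -> 'fifteenth'
+#   normalize_numbers('1995')  -> 'nineteen ninety-five'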
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py
new file mode 100644
index 0000000000000000000000000000000000000000..63fcd431e2d7746b696aaa0d4172bc04ffb88efa
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py
@@ -0,0 +1,141 @@
+"""
+BSD 3-Clause License
+
+Copyright (c) 2017, Prem Seetharaman
+All rights reserved.
+
+* Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+ list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch.autograd import Variable
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from .audio_processing import window_sumsquare
+
+
+class STFT(torch.nn.Module):
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+ def __init__(self, filter_length=800, hop_length=200, win_length=800,
+ window='hann'):
+ super(STFT, self).__init__()
+ self.filter_length = filter_length
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.window = window
+ self.forward_transform = None
+ scale = self.filter_length / self.hop_length
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+ cutoff = int((self.filter_length / 2 + 1))
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+ np.imag(fourier_basis[:cutoff, :])])
+
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+ inverse_basis = torch.FloatTensor(
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+
+ if window is not None:
+ assert(filter_length >= win_length)
+ # get window and zero center pad it to filter_length
+ fft_window = get_window(window, win_length, fftbins=True)
+ fft_window = pad_center(fft_window, filter_length)
+ fft_window = torch.from_numpy(fft_window).float()
+
+ # window the bases
+ forward_basis *= fft_window
+ inverse_basis *= fft_window
+
+ self.register_buffer('forward_basis', forward_basis.float())
+ self.register_buffer('inverse_basis', inverse_basis.float())
+
+ def transform(self, input_data):
+ num_batches = input_data.size(0)
+ num_samples = input_data.size(1)
+
+ self.num_samples = num_samples
+
+ # similar to librosa, reflect-pad the input
+ input_data = input_data.view(num_batches, 1, num_samples)
+ input_data = F.pad(
+ input_data.unsqueeze(1),
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+ mode='reflect')
+ input_data = input_data.squeeze(1)
+
+ forward_transform = F.conv1d(
+ input_data,
+ Variable(self.forward_basis, requires_grad=False),
+ stride=self.hop_length,
+ padding=0)
+
+ cutoff = int((self.filter_length / 2) + 1)
+ real_part = forward_transform[:, :cutoff, :]
+ imag_part = forward_transform[:, cutoff:, :]
+
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
+ phase = torch.autograd.Variable(
+ torch.atan2(imag_part.data, real_part.data))
+
+ return magnitude, phase
+
+ def inverse(self, magnitude, phase):
+ recombine_magnitude_phase = torch.cat(
+ [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+
+ inverse_transform = F.conv_transpose1d(
+ recombine_magnitude_phase,
+ Variable(self.inverse_basis, requires_grad=False),
+ stride=self.hop_length,
+ padding=0)
+
+ if self.window is not None:
+ window_sum = window_sumsquare(
+ self.window, magnitude.size(-1), hop_length=self.hop_length,
+ win_length=self.win_length, n_fft=self.filter_length,
+ dtype=np.float32)
+ # remove modulation effects
+ approx_nonzero_indices = torch.from_numpy(
+ np.where(window_sum > tiny(window_sum))[0])
+ window_sum = torch.autograd.Variable(
+ torch.from_numpy(window_sum), requires_grad=False)
+ window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+ # scale by hop ratio
+ inverse_transform *= float(self.filter_length) / self.hop_length
+
+ inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+ inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
+
+ return inverse_transform
+
+ def forward(self, input_data):
+ self.magnitude, self.phase = self.transform(input_data)
+ reconstruction = self.inverse(self.magnitude, self.phase)
+ return reconstruction
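+
+
+# Usage sketch: round-trip a batch of waveforms of shape [batch, num_samples]:
+#   stft = STFT(filter_length=1024, hop_length=256, win_length=1024, window='hann')
+#   magnitude, phase = stft.transform(audio)
+#   reconstruction = stft.inverse(magnitude, phase)  # equivalent to stft(audio)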
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0d70fdad92ba4f554d971710b60f2f9e8d9298
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py
@@ -0,0 +1,18 @@
+""" from https://github.com/keithito/tacotron """
+
+'''
+Defines the set of symbols used in text input to the model.
+
+The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
+'''
+from . import cmudict
+
+_pad = '_'
+_punctuation = '!\'(),.:;? '
+_special = '-'
+_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+
+# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
+_arpabet = ['@' + s for s in cmudict.valid_symbols]
+
+# Export all symbols:
+symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
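+
+# For reference, the table starts with the pad and special characters,
+# e.g. symbols[:5] == ['_', '-', '!', "'", '('], and ends with the '@'-prefixed ARPAbet symbols.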
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e2ca498bf67ad226af5de796b9f44afa65198d
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py
@@ -0,0 +1,107 @@
+""" from https://github.com/keithito/tacotron """
+import numpy as np
+import re
+from . import cleaners
+from .symbols import symbols
+
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+# Regular expression matching text enclosed in curly braces:
+_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+
+# Special symbols
+SOS_TOK = '<s>'
+EOS_TOK = '</s>'
+
+def text_to_sequence(text, cleaner_names):
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+ Args:
+ text: string to convert to a sequence
+ cleaner_names: names of the cleaner functions to run the text through
+
+ Returns:
+ List of integers corresponding to the symbols in the text
+ '''
+ sequence = []
+
+ # Check for curly braces and treat their contents as ARPAbet:
+ while len(text):
+ m = _curly_re.match(text)
+ if not m:
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
+ break
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
+ sequence += _arpabet_to_sequence(m.group(2))
+ text = m.group(3)
+
+ return sequence
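+
+# Example sketch:
+#   text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.', ['english_cleaners'])
+# returns the IDs of the cleaned characters plus the IDs of the '@'-prefixed ARPAbet
+# symbols taken from the curly braces.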
+
+
+def sample_code_chunk(code, size):
+ assert(size > 0 and size <= len(code))
+ start = np.random.randint(len(code) - size + 1)
+ end = start + size
+ return code[start:end], start, end
+
+
+def code_to_sequence(code, code_dict, collapse_code):
+ if collapse_code:
+ prev_c = None
+ sequence = []
+ for c in code:
+ if c in code_dict and c != prev_c:
+ sequence.append(code_dict[c])
+ prev_c = c
+ else:
+ sequence = [code_dict[c] for c in code if c in code_dict]
+ if len(sequence) < 0.95 * len(code):
+ print('WARNING: over 5% of codes are OOV')
+
+ return sequence
+
+
+def sequence_to_text(sequence):
+ '''Converts a sequence of IDs back to a string'''
+ result = ''
+ for symbol_id in sequence:
+ if symbol_id in _id_to_symbol:
+ s = _id_to_symbol[symbol_id]
+ # Enclose ARPAbet back in curly braces:
+ if len(s) > 1 and s[0] == '@':
+ s = '{%s}' % s[1:]
+ result += s
+ return result.replace('}{', ' ')
+
+
+def sequence_to_code(sequence, code_dict):
+ '''Analogous to sequence_to_text'''
+ id_to_code = {i: c for c, i in code_dict.items()}
+ return ' '.join([id_to_code[i] for i in sequence])
+
+
+def _clean_text(text, cleaner_names):
+ for name in cleaner_names:
+ cleaner = getattr(cleaners, name)
+ if not cleaner:
+ raise Exception('Unknown cleaner: %s' % name)
+ text = cleaner(text)
+ return text
+
+
+def _symbols_to_sequence(symbols):
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
+
+
+def _arpabet_to_sequence(text):
+ return _symbols_to_sequence(['@' + s for s in text.split()])
+
+
+def _should_keep_symbol(s):
+ return s in _symbol_to_id and s != '_' and s != '~'
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a426d2223ce75ffae6cee2131770556c5949bc
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
@@ -0,0 +1,167 @@
+import collections
+import io
+import json
+import librosa
+import numpy as np
+import soundfile as sf
+import time
+import torch
+from scipy.io.wavfile import read
+from .text import SOS_TOK, EOS_TOK
+
+
+def get_mask_from_lengths(lengths):
+ max_len = torch.max(lengths).item()
+ ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
+ mask = (ids < lengths.unsqueeze(1))
+ return mask
+
+
+def load_wav_to_torch(full_path, sr=None):
+ data, sr = librosa.load(full_path, sr=sr)
+ data = np.clip(data, -1, 1) # potentially out of [-1, 1] due to resampling
+ data = data * 32768.0 # match values loaded by scipy
+ return torch.FloatTensor(data.astype(np.float32)), sr
+
+
+def read_binary_audio(bin_data, tar_sr=None):
+ """
+ read binary audio (`bytes` or `uint8` `numpy.ndarray`) to a `float32`
+ `torch.FloatTensor`
+
+ RETURNS:
+ data (torch.FloatTensor) : audio of shape (n,) or (2, n)
+ tar_sr (int) : sample rate
+ """
+ data, ori_sr = sf.read(io.BytesIO(bin_data), dtype='float32')
+ data = data.T
+ if (tar_sr is not None) and (ori_sr != tar_sr):
+ data = librosa.resample(data, ori_sr, tar_sr)
+ else:
+ tar_sr = ori_sr
+ data = np.clip(data, -1, 1)
+ data = data * 32768.0
+ return torch.FloatTensor(data.astype(np.float32)), tar_sr
+
+
+def load_filepaths_and_text(filename):
+ with open(filename, encoding='utf-8') as f:
+ data = [json.loads(line.rstrip()) for line in f]
+ return data
+
+
+def to_gpu(x):
+ x = x.contiguous()
+
+ if torch.cuda.is_available():
+ x = x.cuda(non_blocking=True)
+ return torch.autograd.Variable(x)
+
+
+def load_code_dict(path, add_sos=False, add_eos=False):
+ if not path:
+ return {}
+
+ with open(path, 'r') as f:
+ codes = ['_'] + [line.rstrip() for line in f] # '_' for pad
+ code_dict = {c: i for i, c in enumerate(codes)}
+
+ if add_sos:
+ code_dict[SOS_TOK] = len(code_dict)
+ if add_eos:
+ code_dict[EOS_TOK] = len(code_dict)
+ assert(set(code_dict.values()) == set(range(len(code_dict))))
+
+ return code_dict
+
+
+def load_obs_label_dict(path):
+ if not path:
+ return {}
+ with open(path, 'r') as f:
+ obs_labels = [line.rstrip() for line in f]
+ return {c: i for i, c in enumerate(obs_labels)}
+
+
+# A simple timer class inspired from `tnt.TimeMeter`
+class CudaTimer:
+ def __init__(self, keys):
+ self.keys = keys
+ self.reset()
+
+ def start(self, key):
+ s = torch.cuda.Event(enable_timing=True)
+ s.record()
+ self.start_events[key].append(s)
+ return self
+
+ def stop(self, key):
+ e = torch.cuda.Event(enable_timing=True)
+ e.record()
+ self.end_events[key].append(e)
+ return self
+
+ def reset(self):
+ self.start_events = collections.defaultdict(list)
+ self.end_events = collections.defaultdict(list)
+ self.running_times = collections.defaultdict(float)
+ self.n = collections.defaultdict(int)
+ return self
+
+ def value(self):
+ self._synchronize()
+ return {k: self.running_times[k] / self.n[k] for k in self.keys}
+
+ def _synchronize(self):
+ torch.cuda.synchronize()
+ for k in self.keys:
+ starts = self.start_events[k]
+ ends = self.end_events[k]
+ if len(starts) == 0:
+ raise ValueError("Trying to divide by zero in TimeMeter")
+ if len(ends) != len(starts):
+ raise ValueError("Call stop before checking value!")
+ time = 0
+ for start, end in zip(starts, ends):
+ time += start.elapsed_time(end)
+ self.running_times[k] += time * 1e-3
+ self.n[k] += len(starts)
+ self.start_events = collections.defaultdict(list)
+ self.end_events = collections.defaultdict(list)
+
+
+# Used to measure the time taken for multiple events
+class Timer:
+ def __init__(self, keys):
+ self.keys = keys
+ self.n = {}
+ self.running_time = {}
+ self.total_time = {}
+ self.reset()
+
+ def start(self, key):
+ self.running_time[key] = time.time()
+ return self
+
+ def stop(self, key):
+ self.total_time[key] += time.time() - self.running_time[key]  # accumulate across start/stop pairs
+ self.n[key] += 1
+ self.running_time[key] = None
+ return self
+
+ def reset(self):
+ for k in self.keys:
+ self.total_time[k] = 0
+ self.running_time[k] = None
+ self.n[k] = 0
+ return self
+
+ def value(self):
+ vals = {}
+ for k in self.keys:
+ if self.n[k] == 0:
+ raise ValueError("Trying to divide by zero in TimeMeter")
+ else:
+ vals[k] = self.total_time[k] / self.n[k]
+ return vals
+
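+
+# Usage sketch for the timers above (CudaTimer additionally requires a CUDA device;
+# `run_step()` is a hypothetical workload):
+#   timer = Timer(['forward'])
+#   timer.start('forward'); run_step(); timer.stop('forward')
+#   timer.value()  # -> {'forward': mean seconds per start/stop pair}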
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a6585e8b6901a059445ff54ca20ea87751bbb11
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
@@ -0,0 +1,40 @@
+# import sys
+# sys.path.append('tacotron2')
+import torch
+from .layers import STFT
+
+
+class Denoiser(torch.nn.Module):
+ """ Removes model bias from audio produced with waveglow """
+
+ def __init__(self, waveglow, filter_length=1024, n_overlap=4,
+ win_length=1024, mode='zeros'):
+ super(Denoiser, self).__init__()
+ self.stft = STFT(filter_length=filter_length,
+ hop_length=int(filter_length/n_overlap),
+ win_length=win_length).cuda()
+ if mode == 'zeros':
+ mel_input = torch.zeros(
+ (1, 80, 88),
+ dtype=waveglow.upsample.weight.dtype,
+ device=waveglow.upsample.weight.device)
+ elif mode == 'normal':
+ mel_input = torch.randn(
+ (1, 80, 88),
+ dtype=waveglow.upsample.weight.dtype,
+ device=waveglow.upsample.weight.device)
+ else:
+ raise Exception("Mode {} if not supported".format(mode))
+
+ with torch.no_grad():
+ bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
+ bias_spec, _ = self.stft.transform(bias_audio)
+
+ self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
+
+ def forward(self, audio, strength=0.1):
+ audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
+ audio_spec_denoised = audio_spec - self.bias_spec * strength
+ audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
+ audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
+ return audio_denoised
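+
+
+# Usage sketch (assumes a loaded WaveGlow model `waveglow` and a generated waveform `audio`):
+#   denoiser = Denoiser(waveglow)
+#   audio_clean = denoiser(audio, strength=0.01)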
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb0f7c360d749fd9d489b40b04dae8652b095098
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py
@@ -0,0 +1,52 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import numpy as np
+from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
+ EOS_TOK,
+ SOS_TOK,
+ code_to_sequence,
+ text_to_sequence,
+)
+from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import (
+ load_code_dict,
+)
+
+
+class TacotronInputDataset:
+ def __init__(self, hparams, append_str=""):
+ self.is_text = getattr(hparams, "text_or_code", "text") == "text"
+ if not self.is_text:
+ self.code_dict = load_code_dict(hparams.code_dict)
+ self.code_key = hparams.code_key
+ self.add_sos = hparams.add_sos
+ self.add_eos = hparams.add_eos
+ self.collapse_code = hparams.collapse_code
+ self.append_str = append_str
+
+ def process_code(self, inp_str):
+ inp_toks = inp_str.split()
+ if self.add_sos:
+ inp_toks = [SOS_TOK] + inp_toks
+ if self.add_eos:
+ inp_toks = inp_toks + [EOS_TOK]
+ return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)
+
+ def process_text(self, inp_str):
+ return text_to_sequence(inp_str, ["english_cleaners"])
+
+ def get_tensor(self, inp_str):
+ # uid, txt, inp_str = self._get_data(idx)
+ inp_str = inp_str + self.append_str
+ if self.is_text:
+ inp_toks = self.process_text(inp_str)
+ else:
+ inp_toks = self.process_code(inp_str)
+ return torch.from_numpy(np.array(inp_toks)).long()
+
+ def __len__(self):
+ return len(self.data)
diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/utils.py b/fairseq/examples/textless_nlp/gslm/unit2speech/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aced08d38301b98b19e2df7d19f1c61150107bc
--- /dev/null
+++ b/fairseq/examples/textless_nlp/gslm/unit2speech/utils.py
@@ -0,0 +1,55 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from examples.textless_nlp.gslm.unit2speech.tacotron2.model import Tacotron2
+from examples.textless_nlp.gslm.unit2speech.tacotron2.waveglow_denoiser import (
+ Denoiser,
+)
+
+
+def load_quantized_audio_from_file(file_path):
+ base_fname_batch, quantized_units_batch = [], []
+ with open(file_path) as f:
+ for line in f:
+ base_fname, quantized_units_str = line.rstrip().split("|")
+ quantized_units = [int(q) for q in quantized_units_str.split(" ")]
+ base_fname_batch.append(base_fname)
+ quantized_units_batch.append(quantized_units)
+ return base_fname_batch, quantized_units_batch
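+
+# Input format sketch for load_quantized_audio_from_file: one utterance per line,
+# "<base_filename>|<space-separated unit ids>", e.g. "utt_0001|12 12 5 73 9" (hypothetical).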
+
+
+def synthesize_audio(model, waveglow, denoiser, inp, lab=None, strength=0.0):
+ assert inp.size(0) == 1
+ inp = inp.cuda()
+ if lab is not None:
+ lab = torch.LongTensor(1).cuda().fill_(lab)
+
+ with torch.no_grad():
+ _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True)
+ aud = waveglow.infer(mel, sigma=0.666)
+ aud_dn = denoiser(aud, strength=strength).squeeze(1)
+ return mel, aud, aud_dn, has_eos
+
+
+def load_tacotron(tacotron_model_path, max_decoder_steps):
+ ckpt_dict = torch.load(tacotron_model_path)
+ hparams = ckpt_dict["hparams"]
+ hparams.max_decoder_steps = max_decoder_steps
+ sr = hparams.sampling_rate
+ model = Tacotron2(hparams)
+ model.load_state_dict(ckpt_dict["model_dict"])
+ model = model.cuda().eval().half()
+ return model, sr, hparams
+
+
+def load_waveglow(waveglow_path):
+ waveglow = torch.load(waveglow_path)["model"]
+ waveglow = waveglow.cuda().eval().half()
+ for k in waveglow.convinv:
+ k.float()
+ denoiser = Denoiser(waveglow)
+ return waveglow, denoiser
diff --git a/fairseq/examples/translation/README.md b/fairseq/examples/translation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2941f5eb8482dab61dca5eca27a71abd7ee5bf5c
--- /dev/null
+++ b/fairseq/examples/translation/README.md
@@ -0,0 +1,301 @@
+# Neural Machine Translation
+
+This README contains instructions for [using pretrained translation models](#example-usage-torchhub)
+as well as [training new models](#training-a-new-model).
+
+## Pre-trained models
+
+Model | Description | Dataset | Download
+---|---|---|---
+`conv.wmt14.en-fr` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
+`conv.wmt14.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
+`conv.wmt17.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
+`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
+`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
+`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
+`transformer.wmt19.en-de` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-German](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz)
+`transformer.wmt19.de-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 German-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz)
+`transformer.wmt19.en-ru` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-Russian](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz)
+`transformer.wmt19.ru-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 Russian-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz)
+
+## Example usage (torch.hub)
+
+We require a few additional Python dependencies for preprocessing:
+```bash
+pip install fastBPE sacremoses subword_nmt
+```
+
+Interactive translation via PyTorch Hub:
+```python
+import torch
+
+# List available models
+torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]
+
+# Load a transformer trained on WMT'16 En-De
+# Note: WMT'19 models use fastBPE instead of subword_nmt, see instructions below
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de',
+ tokenizer='moses', bpe='subword_nmt')
+en2de.eval() # disable dropout
+
+# The underlying model is available under the *models* attribute
+import fairseq  # importable once the hub model has been loaded
+assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
+
+# Move model to GPU for faster translation
+en2de.cuda()
+
+# Translate a sentence
+en2de.translate('Hello world!')
+# 'Hallo Welt!'
+
+# Batched translation
+en2de.translate(['Hello world!', 'The cat sat on the mat.'])
+# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
+```
+
+Loading custom models:
+```python
+from fairseq.models.transformer import TransformerModel
+zh2en = TransformerModel.from_pretrained(
+ '/path/to/checkpoints',
+ checkpoint_file='checkpoint_best.pt',
+ data_name_or_path='data-bin/wmt17_zh_en_full',
+ bpe='subword_nmt',
+ bpe_codes='data-bin/wmt17_zh_en_full/zh.code'
+)
+zh2en.translate('你好 世界')
+# 'Hello World'
+```
+
+If you are using one of the `transformer.wmt19` models, you will need to set the `bpe`
+argument to `'fastbpe'` and (optionally) load the 4-model ensemble:
+```python
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
+ checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
+ tokenizer='moses', bpe='fastbpe')
+en2de.eval() # disable dropout
+```
+
+## Example usage (CLI tools)
+
+Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
+```bash
+mkdir -p data-bin
+curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
+curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
+fairseq-generate data-bin/wmt14.en-fr.newstest2014 \
+ --path data-bin/wmt14.en-fr.fconv-py/model.pt \
+ --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
+# ...
+# | Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
+# | Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
+
+# Compute BLEU score
+grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
+grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
+fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
+# BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
+```
+
+## Training a new model
+
+### IWSLT'14 German to English (Transformer)
+
+The following instructions can be used to train a Transformer model on the [IWSLT'14 German to English dataset](http://workshop2014.iwslt.org/downloads/proceeding.pdf).
+
+First download and preprocess the data:
+```bash
+# Download and prepare the data
+cd examples/translation/
+bash prepare-iwslt14.sh
+cd ../..
+
+# Preprocess/binarize the data
+TEXT=examples/translation/iwslt14.tokenized.de-en
+fairseq-preprocess --source-lang de --target-lang en \
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
+ --destdir data-bin/iwslt14.tokenized.de-en \
+ --workers 20
+```
+
+Next we'll train a Transformer translation model over this data:
+```bash
+CUDA_VISIBLE_DEVICES=0 fairseq-train \
+ data-bin/iwslt14.tokenized.de-en \
+ --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
+ --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
+ --dropout 0.3 --weight-decay 0.0001 \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --max-tokens 4096 \
+ --eval-bleu \
+ --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
+ --eval-bleu-detok moses \
+ --eval-bleu-remove-bpe \
+ --eval-bleu-print-samples \
+ --best-checkpoint-metric bleu --maximize-best-checkpoint-metric
+```
+
+Finally we can evaluate our trained model:
+```bash
+fairseq-generate data-bin/iwslt14.tokenized.de-en \
+ --path checkpoints/checkpoint_best.pt \
+ --batch-size 128 --beam 5 --remove-bpe
+```
+
+### WMT'14 English to German (Convolutional)
+
+The following instructions can be used to train a Convolutional translation model on the WMT English to German dataset.
+See the [Scaling NMT README](../scaling_nmt/README.md) for instructions to train a Transformer translation model on this data.
+
+The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
+By default it will produce a dataset that was modeled after [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with additional news-commentary-v12 data from WMT'17.
+
+To use only data available in WMT'14 or to replicate results obtained in the original [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option.
+
+```bash
+# Download and prepare the data
+cd examples/translation/
+# WMT'17 data:
+bash prepare-wmt14en2de.sh
+# or to use WMT'14 data:
+# bash prepare-wmt14en2de.sh --icml17
+cd ../..
+
+# Binarize the dataset
+TEXT=examples/translation/wmt17_en_de
+fairseq-preprocess \
+ --source-lang en --target-lang de \
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
+ --destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0 \
+ --workers 20
+
+# Train the model
+mkdir -p checkpoints/fconv_wmt_en_de
+fairseq-train \
+ data-bin/wmt17_en_de \
+ --arch fconv_wmt_en_de \
+ --dropout 0.2 \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --optimizer nag --clip-norm 0.1 \
+ --lr 0.5 --lr-scheduler fixed --force-anneal 50 \
+ --max-tokens 4000 \
+ --save-dir checkpoints/fconv_wmt_en_de
+
+# Evaluate
+fairseq-generate data-bin/wmt17_en_de \
+ --path checkpoints/fconv_wmt_en_de/checkpoint_best.pt \
+ --beam 5 --remove-bpe
+```
+
+### WMT'14 English to French
+```bash
+# Download and prepare the data
+cd examples/translation/
+bash prepare-wmt14en2fr.sh
+cd ../..
+
+# Binarize the dataset
+TEXT=examples/translation/wmt14_en_fr
+fairseq-preprocess \
+ --source-lang en --target-lang fr \
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
+ --destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \
+ --workers 60
+
+# Train the model
+mkdir -p checkpoints/fconv_wmt_en_fr
+fairseq-train \
+ data-bin/wmt14_en_fr \
+ --arch fconv_wmt_en_fr \
+ --dropout 0.1 \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --optimizer nag --clip-norm 0.1 \
+ --lr 0.5 --lr-scheduler fixed --force-anneal 50 \
+ --max-tokens 3000 \
+ --save-dir checkpoints/fconv_wmt_en_fr
+
+# Evaluate
+fairseq-generate \
+ data-bin/fconv_wmt_en_fr \
+ --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt \
+ --beam 5 --remove-bpe
+```
+
+## Multilingual Translation
+
+We also support training multilingual translation models. In this example we'll
+train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.
+
+Note that we use slightly different preprocessing here than for the IWSLT'14
+En-De data above. In particular we learn a joint BPE code for all three
+languages and use fairseq-interactive and sacrebleu for scoring the test set.
+
+```bash
+# First install sacrebleu and sentencepiece
+pip install sacrebleu sentencepiece
+
+# Then download and preprocess the data
+cd examples/translation/
+bash prepare-iwslt17-multilingual.sh
+cd ../..
+
+# Binarize the de-en dataset
+TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
+fairseq-preprocess --source-lang de --target-lang en \
+ --trainpref $TEXT/train.bpe.de-en \
+ --validpref $TEXT/valid0.bpe.de-en,$TEXT/valid1.bpe.de-en,$TEXT/valid2.bpe.de-en,$TEXT/valid3.bpe.de-en,$TEXT/valid4.bpe.de-en,$TEXT/valid5.bpe.de-en \
+ --destdir data-bin/iwslt17.de_fr.en.bpe16k \
+ --workers 10
+
+# Binarize the fr-en dataset
+# NOTE: it's important to reuse the en dictionary from the previous step
+fairseq-preprocess --source-lang fr --target-lang en \
+ --trainpref $TEXT/train.bpe.fr-en \
+ --validpref $TEXT/valid0.bpe.fr-en,$TEXT/valid1.bpe.fr-en,$TEXT/valid2.bpe.fr-en,$TEXT/valid3.bpe.fr-en,$TEXT/valid4.bpe.fr-en,$TEXT/valid5.bpe.fr-en \
+ --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
+ --destdir data-bin/iwslt17.de_fr.en.bpe16k \
+ --workers 10
+
+# Train a multilingual transformer model
+# NOTE: the command below assumes 1 GPU, but accumulates gradients from
+# 8 fwd/bwd passes to simulate training on 8 GPUs
+mkdir -p checkpoints/multilingual_transformer
+CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
+ --max-epoch 50 \
+ --ddp-backend=legacy_ddp \
+ --task multilingual_translation --lang-pairs de-en,fr-en \
+ --arch multilingual_transformer_iwslt_de_en \
+ --share-decoders --share-decoder-input-output-embed \
+ --optimizer adam --adam-betas '(0.9, 0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --warmup-updates 4000 --warmup-init-lr '1e-07' \
+ --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
+ --dropout 0.3 --weight-decay 0.0001 \
+ --save-dir checkpoints/multilingual_transformer \
+ --max-tokens 4000 \
+ --update-freq 8
+
+# Generate and score the test set with sacrebleu
+SRC=de
+sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
+ | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
+ > iwslt17.test.${SRC}-en.${SRC}.bpe
+cat iwslt17.test.${SRC}-en.${SRC}.bpe \
+ | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
+ --task multilingual_translation --lang-pairs de-en,fr-en \
+ --source-lang ${SRC} --target-lang en \
+ --path checkpoints/multilingual_transformer/checkpoint_best.pt \
+ --buffer-size 2000 --batch-size 128 \
+ --beam 5 --remove-bpe=sentencepiece \
+ > iwslt17.test.${SRC}-en.en.sys
+grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
+ | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
+```
+
+##### Argument format during inference
+
+During inference you must specify a single `--source-lang` and
+`--target-lang`, which indicate the inference language direction.
+`--lang-pairs`, `--encoder-langtok` and `--decoder-langtok` have to be set to
+the same values as during training.
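+
+As a minimal sketch, scoring the fr-en direction of the multilingual model trained
+above reuses the data and checkpoint paths from earlier in this README:
+```bash
+fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
+ --task multilingual_translation --lang-pairs de-en,fr-en \
+ --source-lang fr --target-lang en \
+ --path checkpoints/multilingual_transformer/checkpoint_best.pt \
+ --beam 5 --remove-bpe=sentencepiece
+```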
diff --git a/fairseq/examples/translation/prepare-iwslt14.sh b/fairseq/examples/translation/prepare-iwslt14.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2fb6643fbccb58701dcbb77d91430e68a821ba38
--- /dev/null
+++ b/fairseq/examples/translation/prepare-iwslt14.sh
@@ -0,0 +1,115 @@
+#!/usr/bin/env bash
+#
+# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
+
+echo 'Cloning Moses github repository (for tokenization scripts)...'
+git clone https://github.com/moses-smt/mosesdecoder.git
+
+echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
+git clone https://github.com/rsennrich/subword-nmt.git
+
+SCRIPTS=mosesdecoder/scripts
+TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
+LC=$SCRIPTS/tokenizer/lowercase.perl
+CLEAN=$SCRIPTS/training/clean-corpus-n.perl
+BPEROOT=subword-nmt/subword_nmt
+BPE_TOKENS=10000
+
+URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz"
+GZ=de-en.tgz
+
+if [ ! -d "$SCRIPTS" ]; then
+ echo "Please set SCRIPTS variable correctly to point to Moses scripts."
+ exit
+fi
+
+src=de
+tgt=en
+lang=de-en
+prep=iwslt14.tokenized.de-en
+tmp=$prep/tmp
+orig=orig
+
+mkdir -p $orig $tmp $prep
+
+echo "Downloading data from ${URL}..."
+cd $orig
+wget "$URL"
+
+if [ -f $GZ ]; then
+ echo "Data successfully downloaded."
+else
+ echo "Data not successfully downloaded."
+ exit
+fi
+
+tar zxvf $GZ
+cd ..
+
+echo "pre-processing train data..."
+for l in $src $tgt; do
+ f=train.tags.$lang.$l
+ tok=train.tags.$lang.tok.$l
+
+ cat $orig/$lang/$f | \
+ grep -v '<url>' | \
+ grep -v '<talkid>' | \
+ grep -v '<keywords>' | \
+ sed -e 's/<title>//g' | \
+ sed -e 's/<\/title>//g' | \
+ sed -e 's/<description>//g' | \
+ sed -e 's/<\/description>//g' | \
+ perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
+ echo ""
+done
+perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
+for l in $src $tgt; do
+ perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
+done
+
+echo "pre-processing valid/test data..."
+for l in $src $tgt; do
+ for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
+ fname=${o##*/}
+ f=$tmp/${fname%.*}
+ echo $o $f
+ grep '