initial commit
Browse files- .gitattributes +1 -0
- .gitignore +2 -0
- README.md +65 -0
- build.bat +1 -0
- build.sh +1 -0
- docker/Dockerfile +188 -0
- err2020/audio/emt16k.wav +0 -0
- err2020/conformer_ctc3/__init__.py +0 -0
- err2020/conformer_ctc3/__pycache__/__init__.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/asr_datamodule.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/conformer.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/decode.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/encoder_interface.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/model.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/optim.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/scaling.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/__pycache__/train.cpython-39.pyc +0 -0
- err2020/conformer_ctc3/asr_datamodule.py +458 -0
- err2020/conformer_ctc3/conformer.py +1598 -0
- err2020/conformer_ctc3/decode.py +1052 -0
- err2020/conformer_ctc3/encoder_interface.py +43 -0
- err2020/conformer_ctc3/exp/jit_trace.pt +3 -0
- err2020/conformer_ctc3/export.py +292 -0
- err2020/conformer_ctc3/jit_pretrained.py +413 -0
- err2020/conformer_ctc3/lstmp.py +102 -0
- err2020/conformer_ctc3/model.py +122 -0
- err2020/conformer_ctc3/optim.py +320 -0
- err2020/conformer_ctc3/pretrained.py +461 -0
- err2020/conformer_ctc3/scaling.py +1015 -0
- err2020/conformer_ctc3/test_model.py +82 -0
- err2020/conformer_ctc3/train.py +1109 -0
- err2020/conformer_ctc3_usage.ipynb +500 -0
- err2020/data/lang_bpe_500/bpe.model +3 -0
- requirements.txt +1 -0
- run.bat +9 -0
- run.sh +11 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
err2020/audio/oden_kypsis16k.wav filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
**/.ipynb_checkpoints/
|
2 |
+
.idea/
|
README.md
CHANGED
@@ -1,3 +1,68 @@
|
|
1 |
---
|
|
|
|
|
2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
language:
|
3 |
+
- et
|
4 |
license: apache-2.0
|
5 |
+
metrics:
|
6 |
+
- wer
|
7 |
+
model-index:
|
8 |
+
- name: conformer-ctc et
|
9 |
+
results:
|
10 |
+
- task:
|
11 |
+
name: Automatic Speech Recognition
|
12 |
+
type: automatic-speech-recognition
|
13 |
+
dataset:
|
14 |
+
name: ERR2020
|
15 |
+
args: et
|
16 |
+
metrics:
|
17 |
+
- name: Wer
|
18 |
+
type: wer
|
19 |
+
value: 12.1
|
20 |
---
|
21 |
+
|
22 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
23 |
+
should probably proofread and complete it, then remove this comment. -->
|
24 |
+
|
25 |
+
# conformer-ctc et
|
26 |
+
|
27 |
+
Icefall conformer-ctc3 based recipe (https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc3) trained Estonian ASR model using ERR2020 dataset
|
28 |
+
- WER on ERR2020: 12.1
|
29 |
+
- WER on mozilla commonvoice_11: 24.5
|
30 |
+
|
31 |
+
|
32 |
+
For usage:
|
33 |
+
- clone this repo (`git clone https://huggingface.co/rristo/icefall_conformer_ctc3_et`)
|
34 |
+
- go to repo (`cd icefall_conformer_ctc3_et`)
|
35 |
+
- build docker image for needed libraries (`build.sh` or `build.bat`)
|
36 |
+
- run docker container (`run.sh`or `run.sh`). This mounts current directory
|
37 |
+
- run notebook `err2020/conformer_ctc3_usage.ipynb` for example usage
|
38 |
+
- currently expects audio to be in .wav format
|
39 |
+
|
40 |
+
## Model description
|
41 |
+
|
42 |
+
ASR model for Estonian, uses Estonian Public Broadcasting data ERR2020 data (around 230 hours of audio)
|
43 |
+
|
44 |
+
## Intended uses & limitations
|
45 |
+
|
46 |
+
Pretty much a toy model, trained on limited amount of data. Might not work well on data out of domain
|
47 |
+
(especially spontaneous/noisy data).
|
48 |
+
|
49 |
+
## Training and evaluation data
|
50 |
+
|
51 |
+
Trained on ERR2020 data, evaluated on ERR2020 and mozilla commonvoice test data.
|
52 |
+
|
53 |
+
## Training procedure
|
54 |
+
|
55 |
+
Used Icefall conformer-ctc3 based recipe (https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc3)
|
56 |
+
|
57 |
+
### Training results
|
58 |
+
|
59 |
+
|
60 |
+
TODO
|
61 |
+
|
62 |
+
### Framework versions
|
63 |
+
|
64 |
+
- icefall
|
65 |
+
- k2
|
66 |
+
- kaldifeat==1.24
|
67 |
+
- lhotse==1.15.0
|
68 |
+
- torch==2.0.0
|
build.bat
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
docker build -t icefall -f docker/Dockerfile .
|
build.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
docker build -t icefall -f docker/Dockerfile .
|
docker/Dockerfile
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# ==================================================================
|
4 |
+
# Initial setup
|
5 |
+
# ------------------------------------------------------------------
|
6 |
+
|
7 |
+
# Ubuntu 20.04 as base image
|
8 |
+
FROM ubuntu:20.04
|
9 |
+
RUN yes| unminimize
|
10 |
+
|
11 |
+
# Set ENV variables
|
12 |
+
ENV LANG C.UTF-8
|
13 |
+
ENV SHELL=/bin/bash
|
14 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
15 |
+
|
16 |
+
ENV APT_INSTALL="apt-get install -y --no-install-recommends"
|
17 |
+
ENV PIP_INSTALL="python3 -m pip --no-cache-dir install --upgrade"
|
18 |
+
ENV GIT_CLONE="git clone --depth 10"
|
19 |
+
|
20 |
+
|
21 |
+
# ==================================================================
|
22 |
+
# Tools
|
23 |
+
# ------------------------------------------------------------------
|
24 |
+
|
25 |
+
RUN apt-get update && \
|
26 |
+
$APT_INSTALL \
|
27 |
+
apt-utils \
|
28 |
+
gcc \
|
29 |
+
make \
|
30 |
+
pkg-config \
|
31 |
+
apt-transport-https \
|
32 |
+
build-essential \
|
33 |
+
ca-certificates \
|
34 |
+
wget \
|
35 |
+
rsync \
|
36 |
+
git \
|
37 |
+
vim \
|
38 |
+
mlocate \
|
39 |
+
libssl-dev \
|
40 |
+
curl \
|
41 |
+
openssh-client \
|
42 |
+
unzip \
|
43 |
+
unrar \
|
44 |
+
zip \
|
45 |
+
csvkit \
|
46 |
+
emacs \
|
47 |
+
joe \
|
48 |
+
jq \
|
49 |
+
dialog \
|
50 |
+
man-db \
|
51 |
+
manpages \
|
52 |
+
manpages-dev \
|
53 |
+
manpages-posix \
|
54 |
+
manpages-posix-dev \
|
55 |
+
nano \
|
56 |
+
iputils-ping \
|
57 |
+
sudo \
|
58 |
+
ffmpeg \
|
59 |
+
libsm6 \
|
60 |
+
libxext6 \
|
61 |
+
libboost-all-dev \
|
62 |
+
cifs-utils \
|
63 |
+
software-properties-common
|
64 |
+
|
65 |
+
|
66 |
+
#RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
67 |
+
#RUN bash Miniconda3-latest-Linux-x86_64.sh -p /miniconda -b
|
68 |
+
#RUN rm Miniconda3-latest-Linux-x86_64.sh
|
69 |
+
#ENV PATH=/miniconda/bin:${PATH}
|
70 |
+
#RUN conda update -y conda
|
71 |
+
|
72 |
+
## conda
|
73 |
+
#RUN conda install -c anaconda -y python=3.9.7
|
74 |
+
|
75 |
+
|
76 |
+
# ==================================================================
|
77 |
+
# Python
|
78 |
+
# ------------------------------------------------------------------
|
79 |
+
|
80 |
+
#Based on https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa
|
81 |
+
|
82 |
+
# Adding repository for python3.9
|
83 |
+
RUN add-apt-repository ppa:deadsnakes/ppa -y && \
|
84 |
+
|
85 |
+
# Installing python3.9
|
86 |
+
$APT_INSTALL \
|
87 |
+
python3.9 \
|
88 |
+
python3.9-dev \
|
89 |
+
python3.9-venv \
|
90 |
+
python3-distutils-extra
|
91 |
+
|
92 |
+
# Add symlink so python and python3 commands use same python3.9 executable
|
93 |
+
RUN ln -s /usr/bin/python3.9 /usr/local/bin/python3 && \
|
94 |
+
ln -s /usr/bin/python3.9 /usr/local/bin/python
|
95 |
+
|
96 |
+
# Installing pip
|
97 |
+
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9
|
98 |
+
ENV PATH=$PATH:/root/.local/bin
|
99 |
+
|
100 |
+
RUN pip install torch==2.0.0+cpu torchvision==0.15.1+cpu torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cpu
|
101 |
+
|
102 |
+
|
103 |
+
# ==================================================================
|
104 |
+
# JupyterLab
|
105 |
+
# ------------------------------------------------------------------
|
106 |
+
|
107 |
+
# Based on https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html#pip
|
108 |
+
|
109 |
+
RUN $PIP_INSTALL jupyterlab==3.4.6
|
110 |
+
|
111 |
+
# ==================================================================
|
112 |
+
# Additional Python Packages
|
113 |
+
# ------------------------------------------------------------------
|
114 |
+
|
115 |
+
RUN $PIP_INSTALL \
|
116 |
+
numpy==1.23.4 \
|
117 |
+
scipy==1.9.2 \
|
118 |
+
pandas==1.5.0 \
|
119 |
+
cloudpickle==2.2.0 \
|
120 |
+
scikit-image==0.19.3 \
|
121 |
+
scikit-learn==1.1.2 \
|
122 |
+
matplotlib==3.6.1 \
|
123 |
+
ipython==8.5.0 \
|
124 |
+
ipykernel==6.16.0 \
|
125 |
+
ipywidgets==8.0.2 \
|
126 |
+
cython==0.29.32 \
|
127 |
+
tqdm==4.64.1 \
|
128 |
+
pillow==9.2.0 \
|
129 |
+
seaborn==0.12.0 \
|
130 |
+
future==0.18.2 \
|
131 |
+
jsonify==0.5 \
|
132 |
+
opencv-python==4.6.0.66 \
|
133 |
+
awscli==1.25.91 \
|
134 |
+
jupyterlab-snippets==0.4.1
|
135 |
+
|
136 |
+
# ==================================================================
|
137 |
+
# CMake
|
138 |
+
# ------------------------------------------------------------------
|
139 |
+
|
140 |
+
RUN git clone https://github.com/Kitware/CMake ~/cmake && \
|
141 |
+
cd ~/cmake && \
|
142 |
+
./bootstrap && \
|
143 |
+
make -j"$(nproc)" install
|
144 |
+
|
145 |
+
|
146 |
+
# ==================================================================
|
147 |
+
# Node.js and Jupyter Notebook Extensions
|
148 |
+
# ------------------------------------------------------------------
|
149 |
+
|
150 |
+
RUN curl -sL https://deb.nodesource.com/setup_16.x | bash && \
|
151 |
+
$APT_INSTALL nodejs && \
|
152 |
+
$PIP_INSTALL jupyter_contrib_nbextensions jupyterlab-git && \
|
153 |
+
jupyter contrib nbextension install --user
|
154 |
+
|
155 |
+
|
156 |
+
# ==================================================================
|
157 |
+
# Icefall stuff
|
158 |
+
# ------------------------------------------------------------------
|
159 |
+
|
160 |
+
#k2
|
161 |
+
RUN cd /opt && \
|
162 |
+
git clone https://github.com/k2-fsa/k2.git && \
|
163 |
+
cd k2 && \
|
164 |
+
mkdir build-cpu && \
|
165 |
+
cd build-cpu && \
|
166 |
+
cmake -DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Debug .. && \
|
167 |
+
make -j5
|
168 |
+
|
169 |
+
ENV PYTHONPATH "${PYTHONPATH}:/opt/k2/build-cpu/../k2/python"
|
170 |
+
ENV PYTHONPATH "${PYTHONPATH}:/opt/k2/build-cpu/lib"
|
171 |
+
|
172 |
+
#icefall
|
173 |
+
RUN mkdir /opt/install/
|
174 |
+
COPY requirements.txt /opt/install/requirements.txt
|
175 |
+
RUN pip3 install -r /opt/install/requirements.txt
|
176 |
+
RUN cd /opt && git clone https://github.com/k2-fsa/icefall
|
177 |
+
RUN cd /opt/icefall && pip install -r requirements.txt
|
178 |
+
ENV PYTHONPATH "${PYTHONPATH}:/opt/icefall/"
|
179 |
+
RUN pip install kaldifeat
|
180 |
+
RUN mkdir /opt/notebooks
|
181 |
+
|
182 |
+
# ==================================================================
|
183 |
+
# Startup
|
184 |
+
# ------------------------------------------------------------------
|
185 |
+
|
186 |
+
EXPOSE 8888 6006
|
187 |
+
WORKDIR /opt/notebooks
|
188 |
+
CMD jupyter lab --allow-root --ip=0.0.0.0 --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True
|
err2020/audio/emt16k.wav
ADDED
Binary file (408 kB). View file
|
|
err2020/conformer_ctc3/__init__.py
ADDED
File without changes
|
err2020/conformer_ctc3/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (140 Bytes). View file
|
|
err2020/conformer_ctc3/__pycache__/asr_datamodule.cpython-39.pyc
ADDED
Binary file (9.95 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/conformer.cpython-39.pyc
ADDED
Binary file (43.6 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/decode.cpython-39.pyc
ADDED
Binary file (24.7 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/encoder_interface.cpython-39.pyc
ADDED
Binary file (1.34 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/model.cpython-39.pyc
ADDED
Binary file (3.65 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/optim.cpython-39.pyc
ADDED
Binary file (9.99 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/scaling.cpython-39.pyc
ADDED
Binary file (30.5 kB). View file
|
|
err2020/conformer_ctc3/__pycache__/train.cpython-39.pyc
ADDED
Binary file (24.7 kB). View file
|
|
err2020/conformer_ctc3/asr_datamodule.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 Piotr Żelasko
|
2 |
+
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import inspect
|
21 |
+
import logging
|
22 |
+
from functools import lru_cache
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, Dict, Optional
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
28 |
+
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
29 |
+
CutConcatenate,
|
30 |
+
CutMix,
|
31 |
+
DynamicBucketingSampler,
|
32 |
+
K2SpeechRecognitionDataset,
|
33 |
+
PrecomputedFeatures,
|
34 |
+
SingleCutSampler,
|
35 |
+
SpecAugment,
|
36 |
+
)
|
37 |
+
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
38 |
+
AudioSamples,
|
39 |
+
OnTheFlyFeatures,
|
40 |
+
)
|
41 |
+
from lhotse.utils import fix_random_seed
|
42 |
+
from torch.utils.data import DataLoader
|
43 |
+
|
44 |
+
from icefall.utils import str2bool
|
45 |
+
|
46 |
+
|
47 |
+
class _SeedWorkers:
|
48 |
+
def __init__(self, seed: int):
|
49 |
+
self.seed = seed
|
50 |
+
|
51 |
+
def __call__(self, worker_id: int):
|
52 |
+
fix_random_seed(self.seed + worker_id)
|
53 |
+
|
54 |
+
|
55 |
+
class LibriSpeechAsrDataModule:
|
56 |
+
"""
|
57 |
+
DataModule for k2 ASR experiments.
|
58 |
+
It assumes there is always one train and valid dataloader,
|
59 |
+
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
60 |
+
and test-other).
|
61 |
+
|
62 |
+
It contains all the common data pipeline modules used in ASR
|
63 |
+
experiments, e.g.:
|
64 |
+
- dynamic batch size,
|
65 |
+
- bucketing samplers,
|
66 |
+
- cut concatenation,
|
67 |
+
- augmentation,
|
68 |
+
- on-the-fly feature extraction
|
69 |
+
|
70 |
+
This class should be derived for specific corpora used in ASR tasks.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, args: argparse.Namespace):
|
74 |
+
self.args = args
|
75 |
+
|
76 |
+
@classmethod
|
77 |
+
def add_arguments(cls, parser: argparse.ArgumentParser):
|
78 |
+
group = parser.add_argument_group(
|
79 |
+
title="ASR data related options",
|
80 |
+
description="These options are used for the preparation of "
|
81 |
+
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
82 |
+
"effective batch sizes, sampling strategies, applied data "
|
83 |
+
"augmentations, etc.",
|
84 |
+
)
|
85 |
+
group.add_argument(
|
86 |
+
"--full-libri",
|
87 |
+
type=str2bool,
|
88 |
+
default=True,
|
89 |
+
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
|
90 |
+
)
|
91 |
+
group.add_argument(
|
92 |
+
"--manifest-dir",
|
93 |
+
type=Path,
|
94 |
+
default=Path("data/fbank"),
|
95 |
+
help="Path to directory with train/valid/test cuts.",
|
96 |
+
)
|
97 |
+
group.add_argument(
|
98 |
+
"--max-duration",
|
99 |
+
type=int,
|
100 |
+
default=200.0,
|
101 |
+
help="Maximum pooled recordings duration (seconds) in a "
|
102 |
+
"single batch. You can reduce it if it causes CUDA OOM.",
|
103 |
+
)
|
104 |
+
group.add_argument(
|
105 |
+
"--bucketing-sampler",
|
106 |
+
type=str2bool,
|
107 |
+
default=True,
|
108 |
+
help="When enabled, the batches will come from buckets of "
|
109 |
+
"similar duration (saves padding frames).",
|
110 |
+
)
|
111 |
+
group.add_argument(
|
112 |
+
"--num-buckets",
|
113 |
+
type=int,
|
114 |
+
default=30,
|
115 |
+
help="The number of buckets for the DynamicBucketingSampler"
|
116 |
+
"(you might want to increase it for larger datasets).",
|
117 |
+
)
|
118 |
+
group.add_argument(
|
119 |
+
"--concatenate-cuts",
|
120 |
+
type=str2bool,
|
121 |
+
default=False,
|
122 |
+
help="When enabled, utterances (cuts) will be concatenated "
|
123 |
+
"to minimize the amount of padding.",
|
124 |
+
)
|
125 |
+
group.add_argument(
|
126 |
+
"--duration-factor",
|
127 |
+
type=float,
|
128 |
+
default=1.0,
|
129 |
+
help="Determines the maximum duration of a concatenated cut "
|
130 |
+
"relative to the duration of the longest cut in a batch.",
|
131 |
+
)
|
132 |
+
group.add_argument(
|
133 |
+
"--gap",
|
134 |
+
type=float,
|
135 |
+
default=1.0,
|
136 |
+
help="The amount of padding (in seconds) inserted between "
|
137 |
+
"concatenated cuts. This padding is filled with noise when "
|
138 |
+
"noise augmentation is used.",
|
139 |
+
)
|
140 |
+
group.add_argument(
|
141 |
+
"--on-the-fly-feats",
|
142 |
+
type=str2bool,
|
143 |
+
default=False,
|
144 |
+
help="When enabled, use on-the-fly cut mixing and feature "
|
145 |
+
"extraction. Will drop existing precomputed feature manifests "
|
146 |
+
"if available.",
|
147 |
+
)
|
148 |
+
group.add_argument(
|
149 |
+
"--shuffle",
|
150 |
+
type=str2bool,
|
151 |
+
default=True,
|
152 |
+
help="When enabled (=default), the examples will be "
|
153 |
+
"shuffled for each epoch.",
|
154 |
+
)
|
155 |
+
group.add_argument(
|
156 |
+
"--drop-last",
|
157 |
+
type=str2bool,
|
158 |
+
default=True,
|
159 |
+
help="Whether to drop last batch. Used by sampler.",
|
160 |
+
)
|
161 |
+
group.add_argument(
|
162 |
+
"--return-cuts",
|
163 |
+
type=str2bool,
|
164 |
+
default=True,
|
165 |
+
help="When enabled, each batch will have the "
|
166 |
+
"field: batch['supervisions']['cut'] with the cuts that "
|
167 |
+
"were used to construct it.",
|
168 |
+
)
|
169 |
+
|
170 |
+
group.add_argument(
|
171 |
+
"--num-workers",
|
172 |
+
type=int,
|
173 |
+
default=2,
|
174 |
+
help="The number of training dataloader workers that "
|
175 |
+
"collect the batches.",
|
176 |
+
)
|
177 |
+
|
178 |
+
group.add_argument(
|
179 |
+
"--enable-spec-aug",
|
180 |
+
type=str2bool,
|
181 |
+
default=True,
|
182 |
+
help="When enabled, use SpecAugment for training dataset.",
|
183 |
+
)
|
184 |
+
|
185 |
+
group.add_argument(
|
186 |
+
"--spec-aug-time-warp-factor",
|
187 |
+
type=int,
|
188 |
+
default=80,
|
189 |
+
help="Used only when --enable-spec-aug is True. "
|
190 |
+
"It specifies the factor for time warping in SpecAugment. "
|
191 |
+
"Larger values mean more warping. "
|
192 |
+
"A value less than 1 means to disable time warp.",
|
193 |
+
)
|
194 |
+
|
195 |
+
group.add_argument(
|
196 |
+
"--enable-musan",
|
197 |
+
type=str2bool,
|
198 |
+
default=True,
|
199 |
+
help="When enabled, select noise from MUSAN and mix it"
|
200 |
+
"with training dataset. ",
|
201 |
+
)
|
202 |
+
|
203 |
+
group.add_argument(
|
204 |
+
"--input-strategy",
|
205 |
+
type=str,
|
206 |
+
default="PrecomputedFeatures",
|
207 |
+
help="AudioSamples or PrecomputedFeatures",
|
208 |
+
)
|
209 |
+
|
210 |
+
def train_dataloaders(
|
211 |
+
self,
|
212 |
+
cuts_train: CutSet,
|
213 |
+
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
214 |
+
) -> DataLoader:
|
215 |
+
"""
|
216 |
+
Args:
|
217 |
+
cuts_train:
|
218 |
+
CutSet for training.
|
219 |
+
sampler_state_dict:
|
220 |
+
The state dict for the training sampler.
|
221 |
+
"""
|
222 |
+
transforms = []
|
223 |
+
if self.args.enable_musan:
|
224 |
+
logging.info("Enable MUSAN")
|
225 |
+
logging.info("About to get Musan cuts")
|
226 |
+
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
227 |
+
transforms.append(
|
228 |
+
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
logging.info("Disable MUSAN")
|
232 |
+
|
233 |
+
if self.args.concatenate_cuts:
|
234 |
+
logging.info(
|
235 |
+
f"Using cut concatenation with duration factor "
|
236 |
+
f"{self.args.duration_factor} and gap {self.args.gap}."
|
237 |
+
)
|
238 |
+
# Cut concatenation should be the first transform in the list,
|
239 |
+
# so that if we e.g. mix noise in, it will fill the gaps between
|
240 |
+
# different utterances.
|
241 |
+
transforms = [
|
242 |
+
CutConcatenate(
|
243 |
+
duration_factor=self.args.duration_factor, gap=self.args.gap
|
244 |
+
)
|
245 |
+
] + transforms
|
246 |
+
|
247 |
+
input_transforms = []
|
248 |
+
if self.args.enable_spec_aug:
|
249 |
+
logging.info("Enable SpecAugment")
|
250 |
+
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
251 |
+
# Set the value of num_frame_masks according to Lhotse's version.
|
252 |
+
# In different Lhotse's versions, the default of num_frame_masks is
|
253 |
+
# different.
|
254 |
+
num_frame_masks = 10
|
255 |
+
num_frame_masks_parameter = inspect.signature(
|
256 |
+
SpecAugment.__init__
|
257 |
+
).parameters["num_frame_masks"]
|
258 |
+
if num_frame_masks_parameter.default == 1:
|
259 |
+
num_frame_masks = 2
|
260 |
+
logging.info(f"Num frame mask: {num_frame_masks}")
|
261 |
+
input_transforms.append(
|
262 |
+
SpecAugment(
|
263 |
+
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
264 |
+
num_frame_masks=num_frame_masks,
|
265 |
+
features_mask_size=27,
|
266 |
+
num_feature_masks=2,
|
267 |
+
frames_mask_size=100,
|
268 |
+
)
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
logging.info("Disable SpecAugment")
|
272 |
+
|
273 |
+
logging.info("About to create train dataset")
|
274 |
+
train = K2SpeechRecognitionDataset(
|
275 |
+
input_strategy=eval(self.args.input_strategy)(),
|
276 |
+
cut_transforms=transforms,
|
277 |
+
input_transforms=input_transforms,
|
278 |
+
return_cuts=self.args.return_cuts,
|
279 |
+
)
|
280 |
+
|
281 |
+
if self.args.on_the_fly_feats:
|
282 |
+
# NOTE: the PerturbSpeed transform should be added only if we
|
283 |
+
# remove it from data prep stage.
|
284 |
+
# Add on-the-fly speed perturbation; since originally it would
|
285 |
+
# have increased epoch size by 3, we will apply prob 2/3 and use
|
286 |
+
# 3x more epochs.
|
287 |
+
# Speed perturbation probably should come first before
|
288 |
+
# concatenation, but in principle the transforms order doesn't have
|
289 |
+
# to be strict (e.g. could be randomized)
|
290 |
+
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
291 |
+
# Drop feats to be on the safe side.
|
292 |
+
train = K2SpeechRecognitionDataset(
|
293 |
+
cut_transforms=transforms,
|
294 |
+
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
295 |
+
input_transforms=input_transforms,
|
296 |
+
return_cuts=self.args.return_cuts,
|
297 |
+
)
|
298 |
+
|
299 |
+
if self.args.bucketing_sampler:
|
300 |
+
logging.info("Using DynamicBucketingSampler.")
|
301 |
+
train_sampler = DynamicBucketingSampler(
|
302 |
+
cuts_train,
|
303 |
+
max_duration=self.args.max_duration,
|
304 |
+
shuffle=self.args.shuffle,
|
305 |
+
num_buckets=self.args.num_buckets,
|
306 |
+
drop_last=self.args.drop_last,
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
logging.info("Using SingleCutSampler.")
|
310 |
+
train_sampler = SingleCutSampler(
|
311 |
+
cuts_train,
|
312 |
+
max_duration=self.args.max_duration,
|
313 |
+
shuffle=self.args.shuffle,
|
314 |
+
)
|
315 |
+
logging.info("About to create train dataloader")
|
316 |
+
|
317 |
+
if sampler_state_dict is not None:
|
318 |
+
logging.info("Loading sampler state dict")
|
319 |
+
train_sampler.load_state_dict(sampler_state_dict)
|
320 |
+
|
321 |
+
# 'seed' is derived from the current random state, which will have
|
322 |
+
# previously been set in the main process.
|
323 |
+
seed = torch.randint(0, 100000, ()).item()
|
324 |
+
worker_init_fn = _SeedWorkers(seed)
|
325 |
+
|
326 |
+
train_dl = DataLoader(
|
327 |
+
train,
|
328 |
+
sampler=train_sampler,
|
329 |
+
batch_size=None,
|
330 |
+
num_workers=self.args.num_workers,
|
331 |
+
persistent_workers=False,
|
332 |
+
worker_init_fn=worker_init_fn,
|
333 |
+
)
|
334 |
+
|
335 |
+
return train_dl
|
336 |
+
|
337 |
+
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
338 |
+
transforms = []
|
339 |
+
if self.args.concatenate_cuts:
|
340 |
+
transforms = [
|
341 |
+
CutConcatenate(
|
342 |
+
duration_factor=self.args.duration_factor, gap=self.args.gap
|
343 |
+
)
|
344 |
+
] + transforms
|
345 |
+
|
346 |
+
logging.info("About to create dev dataset")
|
347 |
+
if self.args.on_the_fly_feats:
|
348 |
+
validate = K2SpeechRecognitionDataset(
|
349 |
+
cut_transforms=transforms,
|
350 |
+
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
351 |
+
return_cuts=self.args.return_cuts,
|
352 |
+
)
|
353 |
+
else:
|
354 |
+
validate = K2SpeechRecognitionDataset(
|
355 |
+
cut_transforms=transforms,
|
356 |
+
return_cuts=self.args.return_cuts,
|
357 |
+
)
|
358 |
+
valid_sampler = DynamicBucketingSampler(
|
359 |
+
cuts_valid,
|
360 |
+
max_duration=self.args.max_duration,
|
361 |
+
shuffle=False,
|
362 |
+
)
|
363 |
+
logging.info("About to create dev dataloader")
|
364 |
+
valid_dl = DataLoader(
|
365 |
+
validate,
|
366 |
+
sampler=valid_sampler,
|
367 |
+
batch_size=None,
|
368 |
+
num_workers=2,
|
369 |
+
persistent_workers=False,
|
370 |
+
)
|
371 |
+
|
372 |
+
return valid_dl
|
373 |
+
|
374 |
+
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
375 |
+
logging.debug("About to create test dataset")
|
376 |
+
test = K2SpeechRecognitionDataset(
|
377 |
+
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
378 |
+
if self.args.on_the_fly_feats
|
379 |
+
else eval(self.args.input_strategy)(),
|
380 |
+
return_cuts=self.args.return_cuts,
|
381 |
+
)
|
382 |
+
sampler = DynamicBucketingSampler(
|
383 |
+
cuts,
|
384 |
+
max_duration=self.args.max_duration,
|
385 |
+
shuffle=False,
|
386 |
+
)
|
387 |
+
logging.debug("About to create test dataloader")
|
388 |
+
test_dl = DataLoader(
|
389 |
+
test,
|
390 |
+
batch_size=None,
|
391 |
+
sampler=sampler,
|
392 |
+
num_workers=self.args.num_workers,
|
393 |
+
)
|
394 |
+
return test_dl
|
395 |
+
|
396 |
+
@lru_cache()
|
397 |
+
def train_clean_100_cuts(self) -> CutSet:
|
398 |
+
logging.info("About to get train-clean-100 cuts")
|
399 |
+
return load_manifest_lazy(
|
400 |
+
self.args.manifest_dir / "err2020_cuts_train.jsonl.gz"
|
401 |
+
)
|
402 |
+
|
403 |
+
# @lru_cache()
|
404 |
+
# def train_clean_360_cuts(self) -> CutSet:
|
405 |
+
# logging.info("About to get train-clean-360 cuts")
|
406 |
+
# return load_manifest_lazy(
|
407 |
+
# self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
408 |
+
# )
|
409 |
+
|
410 |
+
# @lru_cache()
|
411 |
+
# def train_other_500_cuts(self) -> CutSet:
|
412 |
+
# logging.info("About to get train-other-500 cuts")
|
413 |
+
# return load_manifest_lazy(
|
414 |
+
# self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
415 |
+
# )
|
416 |
+
|
417 |
+
@lru_cache()
|
418 |
+
def train_all_shuf_cuts(self) -> CutSet:
|
419 |
+
logging.info(
|
420 |
+
"About to get the shuffled train-clean-100, \
|
421 |
+
train-clean-360 and train-other-500 cuts"
|
422 |
+
)
|
423 |
+
return load_manifest_lazy(
|
424 |
+
# self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
425 |
+
self.args.manifest_dir / "err2020_cuts_train-all-shuf.jsonl.gz"
|
426 |
+
)
|
427 |
+
|
428 |
+
@lru_cache()
|
429 |
+
def dev_clean_cuts(self) -> CutSet:
|
430 |
+
logging.info("About to get dev-clean cuts")
|
431 |
+
return load_manifest_lazy(
|
432 |
+
# self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
433 |
+
self.args.manifest_dir / "err2020_cuts_validation.jsonl.gz"
|
434 |
+
)
|
435 |
+
|
436 |
+
# @lru_cache()
|
437 |
+
# def dev_other_cuts(self) -> CutSet:
|
438 |
+
# logging.info("About to get dev-other cuts")
|
439 |
+
# return load_manifest_lazy(
|
440 |
+
# # self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
441 |
+
# self.args.manifest_dir / "err2020_cuts_validation.jsonl.gz"
|
442 |
+
# )
|
443 |
+
|
444 |
+
@lru_cache()
|
445 |
+
def test_clean_cuts(self) -> CutSet:
|
446 |
+
logging.info("About to get test-clean cuts")
|
447 |
+
return load_manifest_lazy(
|
448 |
+
# self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
449 |
+
self.args.manifest_dir / "err2020_cuts_test.jsonl.gz"
|
450 |
+
)
|
451 |
+
|
452 |
+
# @lru_cache()
|
453 |
+
# def test_other_cuts(self) -> CutSet:
|
454 |
+
# logging.info("About to get test-other cuts")
|
455 |
+
# return load_manifest_lazy(
|
456 |
+
# self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
457 |
+
# self.args.manifest_dir / "err2020_cuts_test.jsonl.gz"
|
458 |
+
# )
|
err2020/conformer_ctc3/conformer.py
ADDED
@@ -0,0 +1,1598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import math
|
20 |
+
import warnings
|
21 |
+
from typing import List, Optional, Tuple
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from encoder_interface import EncoderInterface
|
25 |
+
from scaling import (
|
26 |
+
ActivationBalancer,
|
27 |
+
BasicNorm,
|
28 |
+
DoubleSwish,
|
29 |
+
ScaledConv1d,
|
30 |
+
ScaledConv2d,
|
31 |
+
ScaledLinear,
|
32 |
+
)
|
33 |
+
from torch import Tensor, nn
|
34 |
+
|
35 |
+
from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
|
36 |
+
|
37 |
+
|
38 |
+
class Conformer(EncoderInterface):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
num_features (int): Number of input features
|
42 |
+
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
43 |
+
d_model (int): attention dimension, also the output dimension
|
44 |
+
nhead (int): number of head
|
45 |
+
dim_feedforward (int): feedforward dimention
|
46 |
+
num_encoder_layers (int): number of encoder layers
|
47 |
+
dropout (float): dropout rate
|
48 |
+
layer_dropout (float): layer-dropout rate.
|
49 |
+
cnn_module_kernel (int): Kernel size of convolution module
|
50 |
+
vgg_frontend (bool): whether to use vgg frontend.
|
51 |
+
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
|
52 |
+
you want to train a streaming model, this is expected to be True.
|
53 |
+
When setting True, it will use a masking strategy to make the attention
|
54 |
+
see only limited left and right context.
|
55 |
+
short_chunk_threshold (float): a threshold to determinize the chunk size
|
56 |
+
to be used in masking training, if the randomly generated chunk size
|
57 |
+
is greater than ``max_len * short_chunk_threshold`` (max_len is the
|
58 |
+
max sequence length of current batch) then it will use
|
59 |
+
full context in training (i.e. with chunk size equals to max_len).
|
60 |
+
This will be used only when dynamic_chunk_training is True.
|
61 |
+
short_chunk_size (int): see docs above, if the randomly generated chunk
|
62 |
+
size equals to or less than ``max_len * short_chunk_threshold``, the
|
63 |
+
chunk size will be sampled uniformly from 1 to short_chunk_size.
|
64 |
+
This also will be used only when dynamic_chunk_training is True.
|
65 |
+
num_left_chunks (int): the left context (in chunks) attention can see, the
|
66 |
+
chunk size is decided by short_chunk_threshold and short_chunk_size.
|
67 |
+
A minus value means seeing full left context.
|
68 |
+
This also will be used only when dynamic_chunk_training is True.
|
69 |
+
causal (bool): Whether to use causal convolution in conformer encoder
|
70 |
+
layer. This MUST be True when using dynamic_chunk_training.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
num_features: int,
|
76 |
+
subsampling_factor: int = 4,
|
77 |
+
d_model: int = 256,
|
78 |
+
nhead: int = 4,
|
79 |
+
dim_feedforward: int = 2048,
|
80 |
+
num_encoder_layers: int = 12,
|
81 |
+
dropout: float = 0.1,
|
82 |
+
layer_dropout: float = 0.075,
|
83 |
+
cnn_module_kernel: int = 31,
|
84 |
+
dynamic_chunk_training: bool = False,
|
85 |
+
short_chunk_threshold: float = 0.75,
|
86 |
+
short_chunk_size: int = 25,
|
87 |
+
num_left_chunks: int = -1,
|
88 |
+
causal: bool = False,
|
89 |
+
) -> None:
|
90 |
+
super(Conformer, self).__init__()
|
91 |
+
|
92 |
+
self.num_features = num_features
|
93 |
+
self.subsampling_factor = subsampling_factor
|
94 |
+
if subsampling_factor != 4:
|
95 |
+
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
96 |
+
|
97 |
+
# self.encoder_embed converts the input of shape (N, T, num_features)
|
98 |
+
# to the shape (N, T//subsampling_factor, d_model).
|
99 |
+
# That is, it does two things simultaneously:
|
100 |
+
# (1) subsampling: T -> T//subsampling_factor
|
101 |
+
# (2) embedding: num_features -> d_model
|
102 |
+
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
103 |
+
|
104 |
+
self.encoder_layers = num_encoder_layers
|
105 |
+
self.d_model = d_model
|
106 |
+
self.cnn_module_kernel = cnn_module_kernel
|
107 |
+
self.causal = causal
|
108 |
+
self.dynamic_chunk_training = dynamic_chunk_training
|
109 |
+
self.short_chunk_threshold = short_chunk_threshold
|
110 |
+
self.short_chunk_size = short_chunk_size
|
111 |
+
self.num_left_chunks = num_left_chunks
|
112 |
+
|
113 |
+
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
114 |
+
|
115 |
+
encoder_layer = ConformerEncoderLayer(
|
116 |
+
d_model,
|
117 |
+
nhead,
|
118 |
+
dim_feedforward,
|
119 |
+
dropout,
|
120 |
+
layer_dropout,
|
121 |
+
cnn_module_kernel,
|
122 |
+
causal,
|
123 |
+
)
|
124 |
+
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
125 |
+
self._init_state: List[torch.Tensor] = [torch.empty(0)]
|
126 |
+
|
127 |
+
def forward(
|
128 |
+
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
129 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
130 |
+
"""
|
131 |
+
Args:
|
132 |
+
x:
|
133 |
+
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
134 |
+
x_lens:
|
135 |
+
A tensor of shape (batch_size,) containing the number of frames in
|
136 |
+
`x` before padding.
|
137 |
+
warmup:
|
138 |
+
A floating point value that gradually increases from 0 throughout
|
139 |
+
training; when it is >= 1.0 we are "fully warmed up". It is used
|
140 |
+
to turn modules on sequentially.
|
141 |
+
Returns:
|
142 |
+
Return a tuple containing 2 tensors:
|
143 |
+
- embeddings: its shape is (batch_size, output_seq_len, d_model)
|
144 |
+
- lengths, a tensor of shape (batch_size,) containing the number
|
145 |
+
of frames in `embeddings` before padding.
|
146 |
+
"""
|
147 |
+
x = self.encoder_embed(x)
|
148 |
+
x, pos_emb = self.encoder_pos(x)
|
149 |
+
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
150 |
+
|
151 |
+
# Caution: We assume the subsampling factor is 4!
|
152 |
+
|
153 |
+
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
154 |
+
#
|
155 |
+
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
156 |
+
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
157 |
+
|
158 |
+
if not is_jit_tracing():
|
159 |
+
assert x.size(0) == lengths.max().item()
|
160 |
+
|
161 |
+
src_key_padding_mask = make_pad_mask(lengths)
|
162 |
+
|
163 |
+
if self.dynamic_chunk_training:
|
164 |
+
assert (
|
165 |
+
self.causal
|
166 |
+
), "Causal convolution is required for streaming conformer."
|
167 |
+
max_len = x.size(0)
|
168 |
+
chunk_size = torch.randint(1, max_len, (1,)).item()
|
169 |
+
if chunk_size > (max_len * self.short_chunk_threshold):
|
170 |
+
chunk_size = max_len
|
171 |
+
else:
|
172 |
+
chunk_size = chunk_size % self.short_chunk_size + 1
|
173 |
+
|
174 |
+
mask = ~subsequent_chunk_mask(
|
175 |
+
size=x.size(0),
|
176 |
+
chunk_size=chunk_size,
|
177 |
+
num_left_chunks=self.num_left_chunks,
|
178 |
+
device=x.device,
|
179 |
+
)
|
180 |
+
x = self.encoder(
|
181 |
+
x,
|
182 |
+
pos_emb,
|
183 |
+
mask=mask,
|
184 |
+
src_key_padding_mask=src_key_padding_mask,
|
185 |
+
warmup=warmup,
|
186 |
+
) # (T, N, C)
|
187 |
+
else:
|
188 |
+
x = self.encoder(
|
189 |
+
x,
|
190 |
+
pos_emb,
|
191 |
+
mask=None,
|
192 |
+
src_key_padding_mask=src_key_padding_mask,
|
193 |
+
warmup=warmup,
|
194 |
+
) # (T, N, C)
|
195 |
+
|
196 |
+
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
197 |
+
return x, lengths
|
198 |
+
|
199 |
+
@torch.jit.export
|
200 |
+
def get_init_state(
|
201 |
+
self, left_context: int, device: torch.device
|
202 |
+
) -> List[torch.Tensor]:
|
203 |
+
"""Return the initial cache state of the model.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
left_context: The left context size (in frames after subsampling).
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
Return the initial state of the model, it is a list containing two
|
210 |
+
tensors, the first one is the cache for attentions which has a shape
|
211 |
+
of (num_encoder_layers, left_context, encoder_dim), the second one
|
212 |
+
is the cache of conv_modules which has a shape of
|
213 |
+
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
|
214 |
+
|
215 |
+
NOTE: the returned tensors are on the given device.
|
216 |
+
"""
|
217 |
+
if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
|
218 |
+
# Note: It is OK to share the init state as it is
|
219 |
+
# not going to be modified by the model
|
220 |
+
return self._init_state
|
221 |
+
|
222 |
+
init_states: List[torch.Tensor] = [
|
223 |
+
torch.zeros(
|
224 |
+
(
|
225 |
+
self.encoder_layers,
|
226 |
+
left_context,
|
227 |
+
self.d_model,
|
228 |
+
),
|
229 |
+
device=device,
|
230 |
+
),
|
231 |
+
torch.zeros(
|
232 |
+
(
|
233 |
+
self.encoder_layers,
|
234 |
+
self.cnn_module_kernel - 1,
|
235 |
+
self.d_model,
|
236 |
+
),
|
237 |
+
device=device,
|
238 |
+
),
|
239 |
+
]
|
240 |
+
|
241 |
+
self._init_state = init_states
|
242 |
+
|
243 |
+
return init_states
|
244 |
+
|
245 |
+
@torch.jit.export
|
246 |
+
def streaming_forward(
|
247 |
+
self,
|
248 |
+
x: torch.Tensor,
|
249 |
+
x_lens: torch.Tensor,
|
250 |
+
states: Optional[List[Tensor]] = None,
|
251 |
+
processed_lens: Optional[Tensor] = None,
|
252 |
+
left_context: int = 64,
|
253 |
+
right_context: int = 4,
|
254 |
+
chunk_size: int = 16,
|
255 |
+
simulate_streaming: bool = False,
|
256 |
+
warmup: float = 1.0,
|
257 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
258 |
+
"""
|
259 |
+
Args:
|
260 |
+
x:
|
261 |
+
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
262 |
+
x_lens:
|
263 |
+
A tensor of shape (batch_size,) containing the number of frames in
|
264 |
+
`x` before padding.
|
265 |
+
states:
|
266 |
+
The decode states for previous frames which contains the cached data.
|
267 |
+
It has two elements, the first element is the attn_cache which has
|
268 |
+
a shape of (encoder_layers, left_context, batch, attention_dim),
|
269 |
+
the second element is the conv_cache which has a shape of
|
270 |
+
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
271 |
+
Note: states will be modified in this function.
|
272 |
+
processed_lens:
|
273 |
+
How many frames (after subsampling) have been processed for each sequence.
|
274 |
+
left_context:
|
275 |
+
How many previous frames the attention can see in current chunk.
|
276 |
+
Note: It's not that each individual frame has `left_context` frames
|
277 |
+
of left context, some have more.
|
278 |
+
right_context:
|
279 |
+
How many future frames the attention can see in current chunk.
|
280 |
+
Note: It's not that each individual frame has `right_context` frames
|
281 |
+
of right context, some have more.
|
282 |
+
chunk_size:
|
283 |
+
The chunk size for decoding, this will be used to simulate streaming
|
284 |
+
decoding using masking.
|
285 |
+
simulate_streaming:
|
286 |
+
If setting True, it will use a masking strategy to simulate streaming
|
287 |
+
fashion (i.e. every chunk data only see limited left context and
|
288 |
+
right context). The whole sequence is supposed to be send at a time
|
289 |
+
When using simulate_streaming.
|
290 |
+
warmup:
|
291 |
+
A floating point value that gradually increases from 0 throughout
|
292 |
+
training; when it is >= 1.0 we are "fully warmed up". It is used
|
293 |
+
to turn modules on sequentially.
|
294 |
+
Returns:
|
295 |
+
Return a tuple containing 2 tensors:
|
296 |
+
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
297 |
+
- logit_lens, a tensor of shape (batch_size,) containing the number
|
298 |
+
of frames in `logits` before padding.
|
299 |
+
- decode_states, the updated states including the information
|
300 |
+
of current chunk.
|
301 |
+
"""
|
302 |
+
|
303 |
+
# x: [N, T, C]
|
304 |
+
# Caution: We assume the subsampling factor is 4!
|
305 |
+
|
306 |
+
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
307 |
+
#
|
308 |
+
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
309 |
+
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
310 |
+
|
311 |
+
if not simulate_streaming:
|
312 |
+
assert states is not None
|
313 |
+
assert processed_lens is not None
|
314 |
+
assert (
|
315 |
+
len(states) == 2
|
316 |
+
and states[0].shape
|
317 |
+
== (self.encoder_layers, left_context, x.size(0), self.d_model)
|
318 |
+
and states[1].shape
|
319 |
+
== (
|
320 |
+
self.encoder_layers,
|
321 |
+
self.cnn_module_kernel - 1,
|
322 |
+
x.size(0),
|
323 |
+
self.d_model,
|
324 |
+
)
|
325 |
+
), f"""The length of states MUST be equal to 2, and the shape of
|
326 |
+
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
|
327 |
+
given {states[0].shape}. the shape of second element should be
|
328 |
+
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
|
329 |
+
given {states[1].shape}."""
|
330 |
+
|
331 |
+
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
|
332 |
+
|
333 |
+
src_key_padding_mask = make_pad_mask(lengths)
|
334 |
+
|
335 |
+
processed_mask = torch.arange(left_context, device=x.device).expand(
|
336 |
+
x.size(0), left_context
|
337 |
+
)
|
338 |
+
processed_lens = processed_lens.view(x.size(0), 1)
|
339 |
+
processed_mask = (processed_lens <= processed_mask).flip(1)
|
340 |
+
|
341 |
+
src_key_padding_mask = torch.cat(
|
342 |
+
[processed_mask, src_key_padding_mask], dim=1
|
343 |
+
)
|
344 |
+
|
345 |
+
embed = self.encoder_embed(x)
|
346 |
+
|
347 |
+
# cut off 1 frame on each size of embed as they see the padding
|
348 |
+
# value which causes a training and decoding mismatch.
|
349 |
+
embed = embed[:, 1:-1, :]
|
350 |
+
|
351 |
+
embed, pos_enc = self.encoder_pos(embed, left_context)
|
352 |
+
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
353 |
+
|
354 |
+
x, states = self.encoder.chunk_forward(
|
355 |
+
embed,
|
356 |
+
pos_enc,
|
357 |
+
src_key_padding_mask=src_key_padding_mask,
|
358 |
+
warmup=warmup,
|
359 |
+
states=states,
|
360 |
+
left_context=left_context,
|
361 |
+
right_context=right_context,
|
362 |
+
) # (T, B, F)
|
363 |
+
if right_context > 0:
|
364 |
+
x = x[0:-right_context, ...]
|
365 |
+
lengths -= right_context
|
366 |
+
else:
|
367 |
+
assert states is None
|
368 |
+
states = [] # just to make torch.script.jit happy
|
369 |
+
# this branch simulates streaming decoding using mask as we are
|
370 |
+
# using in training time.
|
371 |
+
src_key_padding_mask = make_pad_mask(lengths)
|
372 |
+
x = self.encoder_embed(x)
|
373 |
+
x, pos_emb = self.encoder_pos(x)
|
374 |
+
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
375 |
+
|
376 |
+
assert x.size(0) == lengths.max().item()
|
377 |
+
|
378 |
+
if chunk_size < 0:
|
379 |
+
# use full attention
|
380 |
+
chunk_size = x.size(0)
|
381 |
+
left_context = -1
|
382 |
+
|
383 |
+
num_left_chunks = -1
|
384 |
+
if left_context >= 0:
|
385 |
+
assert left_context % chunk_size == 0
|
386 |
+
num_left_chunks = left_context // chunk_size
|
387 |
+
|
388 |
+
mask = ~subsequent_chunk_mask(
|
389 |
+
size=x.size(0),
|
390 |
+
chunk_size=chunk_size,
|
391 |
+
num_left_chunks=num_left_chunks,
|
392 |
+
device=x.device,
|
393 |
+
)
|
394 |
+
x = self.encoder(
|
395 |
+
x,
|
396 |
+
pos_emb,
|
397 |
+
mask=mask,
|
398 |
+
src_key_padding_mask=src_key_padding_mask,
|
399 |
+
warmup=warmup,
|
400 |
+
) # (T, N, C)
|
401 |
+
|
402 |
+
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
403 |
+
|
404 |
+
return x, lengths, states
|
405 |
+
|
406 |
+
|
407 |
+
class ConformerEncoderLayer(nn.Module):
|
408 |
+
"""
|
409 |
+
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
410 |
+
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
|
411 |
+
|
412 |
+
Args:
|
413 |
+
d_model: the number of expected features in the input (required).
|
414 |
+
nhead: the number of heads in the multiheadattention models (required).
|
415 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
416 |
+
dropout: the dropout value (default=0.1).
|
417 |
+
cnn_module_kernel (int): Kernel size of convolution module.
|
418 |
+
causal (bool): Whether to use causal convolution in conformer encoder
|
419 |
+
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
|
420 |
+
|
421 |
+
Examples::
|
422 |
+
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
423 |
+
>>> src = torch.rand(10, 32, 512)
|
424 |
+
>>> pos_emb = torch.rand(32, 19, 512)
|
425 |
+
>>> out = encoder_layer(src, pos_emb)
|
426 |
+
"""
|
427 |
+
|
428 |
+
def __init__(
|
429 |
+
self,
|
430 |
+
d_model: int,
|
431 |
+
nhead: int,
|
432 |
+
dim_feedforward: int = 2048,
|
433 |
+
dropout: float = 0.1,
|
434 |
+
layer_dropout: float = 0.075,
|
435 |
+
cnn_module_kernel: int = 31,
|
436 |
+
causal: bool = False,
|
437 |
+
) -> None:
|
438 |
+
super(ConformerEncoderLayer, self).__init__()
|
439 |
+
|
440 |
+
self.layer_dropout = layer_dropout
|
441 |
+
|
442 |
+
self.d_model = d_model
|
443 |
+
|
444 |
+
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
|
445 |
+
|
446 |
+
self.feed_forward = nn.Sequential(
|
447 |
+
ScaledLinear(d_model, dim_feedforward),
|
448 |
+
ActivationBalancer(channel_dim=-1),
|
449 |
+
DoubleSwish(),
|
450 |
+
nn.Dropout(dropout),
|
451 |
+
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
452 |
+
)
|
453 |
+
|
454 |
+
self.feed_forward_macaron = nn.Sequential(
|
455 |
+
ScaledLinear(d_model, dim_feedforward),
|
456 |
+
ActivationBalancer(channel_dim=-1),
|
457 |
+
DoubleSwish(),
|
458 |
+
nn.Dropout(dropout),
|
459 |
+
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
460 |
+
)
|
461 |
+
|
462 |
+
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
|
463 |
+
|
464 |
+
self.norm_final = BasicNorm(d_model)
|
465 |
+
|
466 |
+
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
467 |
+
self.balancer = ActivationBalancer(
|
468 |
+
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
469 |
+
)
|
470 |
+
|
471 |
+
self.dropout = nn.Dropout(dropout)
|
472 |
+
|
473 |
+
def forward(
|
474 |
+
self,
|
475 |
+
src: Tensor,
|
476 |
+
pos_emb: Tensor,
|
477 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
478 |
+
src_mask: Optional[Tensor] = None,
|
479 |
+
warmup: float = 1.0,
|
480 |
+
) -> Tensor:
|
481 |
+
"""
|
482 |
+
Pass the input through the encoder layer.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
src: the sequence to the encoder layer (required).
|
486 |
+
pos_emb: Positional embedding tensor (required).
|
487 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
488 |
+
src_mask: the mask for the src sequence (optional).
|
489 |
+
warmup: controls selective bypass of layers; if < 1.0, we will
|
490 |
+
bypass layers more frequently.
|
491 |
+
Shape:
|
492 |
+
src: (S, N, E).
|
493 |
+
pos_emb: (N, 2*S-1, E)
|
494 |
+
src_mask: (S, S).
|
495 |
+
src_key_padding_mask: (N, S).
|
496 |
+
S is the source sequence length, N is the batch size, E is the feature number
|
497 |
+
"""
|
498 |
+
src_orig = src
|
499 |
+
|
500 |
+
warmup_scale = min(0.1 + warmup, 1.0)
|
501 |
+
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
502 |
+
# completely bypass it.
|
503 |
+
if self.training:
|
504 |
+
alpha = (
|
505 |
+
warmup_scale
|
506 |
+
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
|
507 |
+
else 0.1
|
508 |
+
)
|
509 |
+
else:
|
510 |
+
alpha = 1.0
|
511 |
+
|
512 |
+
# macaron style feed forward module
|
513 |
+
src = src + self.dropout(self.feed_forward_macaron(src))
|
514 |
+
|
515 |
+
# multi-headed self-attention module
|
516 |
+
src_att = self.self_attn(
|
517 |
+
src,
|
518 |
+
src,
|
519 |
+
src,
|
520 |
+
pos_emb=pos_emb,
|
521 |
+
attn_mask=src_mask,
|
522 |
+
key_padding_mask=src_key_padding_mask,
|
523 |
+
)[0]
|
524 |
+
|
525 |
+
src = src + self.dropout(src_att)
|
526 |
+
|
527 |
+
# convolution module
|
528 |
+
conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
|
529 |
+
src = src + self.dropout(conv)
|
530 |
+
|
531 |
+
# feed forward module
|
532 |
+
src = src + self.dropout(self.feed_forward(src))
|
533 |
+
|
534 |
+
src = self.norm_final(self.balancer(src))
|
535 |
+
|
536 |
+
if alpha != 1.0:
|
537 |
+
src = alpha * src + (1 - alpha) * src_orig
|
538 |
+
|
539 |
+
return src
|
540 |
+
|
541 |
+
@torch.jit.export
|
542 |
+
def chunk_forward(
|
543 |
+
self,
|
544 |
+
src: Tensor,
|
545 |
+
pos_emb: Tensor,
|
546 |
+
states: List[Tensor],
|
547 |
+
src_mask: Optional[Tensor] = None,
|
548 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
549 |
+
warmup: float = 1.0,
|
550 |
+
left_context: int = 0,
|
551 |
+
right_context: int = 0,
|
552 |
+
) -> Tuple[Tensor, List[Tensor]]:
|
553 |
+
"""
|
554 |
+
Pass the input through the encoder layer.
|
555 |
+
|
556 |
+
Args:
|
557 |
+
src: the sequence to the encoder layer (required).
|
558 |
+
pos_emb: Positional embedding tensor (required).
|
559 |
+
states:
|
560 |
+
The decode states for previous frames which contains the cached data.
|
561 |
+
It has two elements, the first element is the attn_cache which has
|
562 |
+
a shape of (left_context, batch, attention_dim),
|
563 |
+
the second element is the conv_cache which has a shape of
|
564 |
+
(cnn_module_kernel-1, batch, conv_dim).
|
565 |
+
Note: states will be modified in this function.
|
566 |
+
src_mask: the mask for the src sequence (optional).
|
567 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
568 |
+
warmup: controls selective bypass of layers; if < 1.0, we will
|
569 |
+
bypass layers more frequently.
|
570 |
+
left_context:
|
571 |
+
How many previous frames the attention can see in current chunk.
|
572 |
+
Note: It's not that each individual frame has `left_context` frames
|
573 |
+
of left context, some have more.
|
574 |
+
right_context:
|
575 |
+
How many future frames the attention can see in current chunk.
|
576 |
+
Note: It's not that each individual frame has `right_context` frames
|
577 |
+
of right context, some have more.
|
578 |
+
|
579 |
+
Shape:
|
580 |
+
src: (S, N, E).
|
581 |
+
pos_emb: (N, 2*(S+left_context)-1, E).
|
582 |
+
src_mask: (S, S).
|
583 |
+
src_key_padding_mask: (N, S).
|
584 |
+
S is the source sequence length, N is the batch size, E is the feature number
|
585 |
+
"""
|
586 |
+
|
587 |
+
assert not self.training
|
588 |
+
assert len(states) == 2
|
589 |
+
assert states[0].shape == (left_context, src.size(1), src.size(2))
|
590 |
+
|
591 |
+
# macaron style feed forward module
|
592 |
+
src = src + self.dropout(self.feed_forward_macaron(src))
|
593 |
+
|
594 |
+
# We put the attention cache at this level (i.e. before the linear transformation)
|
595 |
+
# to reduce memory consumption: when decoding in streaming fashion, the
|
596 |
+
# batch size would be thousands (for 32GB machine), if we cache key & val
|
597 |
+
# separately, it needs extra several GB memory.
|
598 |
+
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
|
599 |
+
# separately) if needed.
|
600 |
+
key = torch.cat([states[0], src], dim=0)
|
601 |
+
val = key
|
602 |
+
if right_context > 0:
|
603 |
+
states[0] = key[
|
604 |
+
-(left_context + right_context) : -right_context, ... # noqa
|
605 |
+
]
|
606 |
+
else:
|
607 |
+
states[0] = key[-left_context:, ...]
|
608 |
+
|
609 |
+
# multi-headed self-attention module
|
610 |
+
src_att = self.self_attn(
|
611 |
+
src,
|
612 |
+
key,
|
613 |
+
val,
|
614 |
+
pos_emb=pos_emb,
|
615 |
+
attn_mask=src_mask,
|
616 |
+
key_padding_mask=src_key_padding_mask,
|
617 |
+
left_context=left_context,
|
618 |
+
)[0]
|
619 |
+
|
620 |
+
src = src + self.dropout(src_att)
|
621 |
+
|
622 |
+
# convolution module
|
623 |
+
conv, conv_cache = self.conv_module(src, states[1], right_context)
|
624 |
+
states[1] = conv_cache
|
625 |
+
|
626 |
+
src = src + self.dropout(conv)
|
627 |
+
|
628 |
+
# feed forward module
|
629 |
+
src = src + self.dropout(self.feed_forward(src))
|
630 |
+
|
631 |
+
src = self.norm_final(self.balancer(src))
|
632 |
+
|
633 |
+
return src, states
|
634 |
+
|
635 |
+
|
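(Editor's aside, not part of the diff: the `warmup` logic above stochastically blends a layer's output back toward its input during training. A minimal sketch of the same arithmetic, using the default constants shown above:)

import torch

layer_dropout = 0.075                    # default used by ConformerEncoderLayer above
warmup = 0.3                             # some value < 1.0 early in training
warmup_scale = min(0.1 + warmup, 1.0)    # 0.4 in this case
# with probability (1 - layer_dropout) keep warmup_scale, otherwise drop to 0.1
alpha = warmup_scale if torch.rand(()).item() <= (1.0 - layer_dropout) else 0.1
# the layer then returns alpha * processed + (1 - alpha) * original input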
636 |
+
class ConformerEncoder(nn.Module):
|
637 |
+
r"""ConformerEncoder is a stack of N encoder layers
|
638 |
+
|
639 |
+
Args:
|
640 |
+
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
641 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
642 |
+
|
643 |
+
Examples::
|
644 |
+
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
645 |
+
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
646 |
+
>>> src = torch.rand(10, 32, 512)
|
647 |
+
>>> pos_emb = torch.rand(32, 19, 512)
|
648 |
+
>>> out = conformer_encoder(src, pos_emb)
|
649 |
+
"""
|
650 |
+
|
651 |
+
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
652 |
+
super().__init__()
|
653 |
+
self.layers = nn.ModuleList(
|
654 |
+
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
655 |
+
)
|
656 |
+
self.num_layers = num_layers
|
657 |
+
|
658 |
+
def forward(
|
659 |
+
self,
|
660 |
+
src: Tensor,
|
661 |
+
pos_emb: Tensor,
|
662 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
663 |
+
mask: Optional[Tensor] = None,
|
664 |
+
warmup: float = 1.0,
|
665 |
+
) -> Tensor:
|
666 |
+
r"""Pass the input through the encoder layers in turn.
|
667 |
+
|
668 |
+
Args:
|
669 |
+
src: the sequence to the encoder (required).
|
670 |
+
pos_emb: Positional embedding tensor (required).
|
671 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
672 |
+
mask: the mask for the src sequence (optional).
|
673 |
+
warmup: controls selective bypass of layers; if < 1.0, we will
|
674 |
+
bypass layers more frequently.
|
675 |
+
|
676 |
+
Shape:
|
677 |
+
src: (S, N, E).
|
678 |
+
pos_emb: (N, 2*S-1, E)
|
679 |
+
mask: (S, S).
|
680 |
+
src_key_padding_mask: (N, S).
|
681 |
+
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
682 |
+
|
683 |
+
"""
|
684 |
+
output = src
|
685 |
+
|
686 |
+
for layer_index, mod in enumerate(self.layers):
|
687 |
+
output = mod(
|
688 |
+
output,
|
689 |
+
pos_emb,
|
690 |
+
src_mask=mask,
|
691 |
+
src_key_padding_mask=src_key_padding_mask,
|
692 |
+
warmup=warmup,
|
693 |
+
)
|
694 |
+
|
695 |
+
return output
|
696 |
+
|
697 |
+
@torch.jit.export
|
698 |
+
def chunk_forward(
|
699 |
+
self,
|
700 |
+
src: Tensor,
|
701 |
+
pos_emb: Tensor,
|
702 |
+
states: List[Tensor],
|
703 |
+
mask: Optional[Tensor] = None,
|
704 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
705 |
+
warmup: float = 1.0,
|
706 |
+
left_context: int = 0,
|
707 |
+
right_context: int = 0,
|
708 |
+
) -> Tuple[Tensor, List[Tensor]]:
|
709 |
+
r"""Pass the input through the encoder layers in turn.
|
710 |
+
|
711 |
+
Args:
|
712 |
+
src: the sequence to the encoder (required).
|
713 |
+
pos_emb: Positional embedding tensor (required).
|
714 |
+
states:
|
715 |
+
The decode states for previous frames which contains the cached data.
|
716 |
+
It has two elements, the first element is the attn_cache which has
|
717 |
+
a shape of (encoder_layers, left_context, batch, attention_dim),
|
718 |
+
the second element is the conv_cache which has a shape of
|
719 |
+
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
720 |
+
Note: states will be modified in this function.
|
721 |
+
mask: the mask for the src sequence (optional).
|
722 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
723 |
+
warmup: controls selective bypass of layers; if < 1.0, we will
|
724 |
+
bypass layers more frequently.
|
725 |
+
left_context:
|
726 |
+
How many previous frames the attention can see in current chunk.
|
727 |
+
Note: It's not that each individual frame has `left_context` frames
|
728 |
+
of left context, some have more.
|
729 |
+
right_context:
|
730 |
+
How many future frames the attention can see in current chunk.
|
731 |
+
Note: It's not that each individual frame has `right_context` frames
|
732 |
+
of right context, some have more.
|
733 |
+
Shape:
|
734 |
+
src: (S, N, E).
|
735 |
+
pos_emb: (N, 2*(S+left_context)-1, E).
|
736 |
+
mask: (S, S).
|
737 |
+
src_key_padding_mask: (N, S).
|
738 |
+
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
739 |
+
|
740 |
+
"""
|
741 |
+
assert not self.training
|
742 |
+
assert len(states) == 2
|
743 |
+
assert states[0].shape == (
|
744 |
+
self.num_layers,
|
745 |
+
left_context,
|
746 |
+
src.size(1),
|
747 |
+
src.size(2),
|
748 |
+
)
|
749 |
+
assert states[1].size(0) == self.num_layers
|
750 |
+
|
751 |
+
output = src
|
752 |
+
|
753 |
+
for layer_index, mod in enumerate(self.layers):
|
754 |
+
cache = [states[0][layer_index], states[1][layer_index]]
|
755 |
+
output, cache = mod.chunk_forward(
|
756 |
+
output,
|
757 |
+
pos_emb,
|
758 |
+
states=cache,
|
759 |
+
src_mask=mask,
|
760 |
+
src_key_padding_mask=src_key_padding_mask,
|
761 |
+
warmup=warmup,
|
762 |
+
left_context=left_context,
|
763 |
+
right_context=right_context,
|
764 |
+
)
|
765 |
+
states[0][layer_index] = cache[0]
|
766 |
+
states[1][layer_index] = cache[1]
|
767 |
+
|
768 |
+
return output, states
|
769 |
+
|
770 |
+
|
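(Editor's aside, a sketch under assumptions rather than code from this repository: the `states` caches described in `chunk_forward` above could be initialised with zeros before decoding the first chunk, with shapes matching the docstring; the model dimensions below are illustrative only.)

import torch

num_layers, left_context, batch = 12, 64, 1
attention_dim, cnn_module_kernel = 512, 31      # assumed model dimensions
attn_cache = torch.zeros(num_layers, left_context, batch, attention_dim)
conv_cache = torch.zeros(num_layers, cnn_module_kernel - 1, batch, attention_dim)
states = [attn_cache, conv_cache]               # what ConformerEncoder.chunk_forward expects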
771 |
+
class RelPositionalEncoding(torch.nn.Module):
|
772 |
+
"""Relative positional encoding module.
|
773 |
+
|
774 |
+
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
775 |
+
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
776 |
+
|
777 |
+
Args:
|
778 |
+
d_model: Embedding dimension.
|
779 |
+
dropout_rate: Dropout rate.
|
780 |
+
max_len: Maximum input length.
|
781 |
+
|
782 |
+
"""
|
783 |
+
|
784 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
785 |
+
"""Construct an PositionalEncoding object."""
|
786 |
+
super(RelPositionalEncoding, self).__init__()
|
787 |
+
if is_jit_tracing():
|
788 |
+
# 10k frames correspond to ~100k ms, i.e., 100 seconds.
|
789 |
+
# It assumes that the maximum input won't have more than
|
790 |
+
# 10k frames.
|
791 |
+
#
|
792 |
+
# TODO(fangjun): Use torch.jit.script() for this module
|
793 |
+
max_len = 10000
|
794 |
+
|
795 |
+
self.d_model = d_model
|
796 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
797 |
+
self.pe = None
|
798 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
799 |
+
|
800 |
+
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
|
801 |
+
"""Reset the positional encodings."""
|
802 |
+
x_size_1 = x.size(1) + left_context
|
803 |
+
if self.pe is not None:
|
804 |
+
# self.pe contains both positive and negative parts
|
805 |
+
# the length of self.pe is 2 * input_len - 1
|
806 |
+
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
807 |
+
# Note: TorchScript doesn't implement operator== for torch.Device
|
808 |
+
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
809 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
810 |
+
return
|
811 |
+
# Suppose `i` is the position of the query vector and `j` the position
|
812 |
+
# of the key vector. We use positive relative positions when keys
|
813 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
814 |
+
pe_positive = torch.zeros(x_size_1, self.d_model)
|
815 |
+
pe_negative = torch.zeros(x_size_1, self.d_model)
|
816 |
+
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
817 |
+
div_term = torch.exp(
|
818 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
819 |
+
* -(math.log(10000.0) / self.d_model)
|
820 |
+
)
|
821 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
822 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
823 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
824 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
825 |
+
|
826 |
+
# Reverse the order of positive indices and concat both positive and
|
827 |
+
# negative indices. This is used to support the shifting trick
|
828 |
+
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
829 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
830 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
831 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
832 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
833 |
+
|
834 |
+
def forward(
|
835 |
+
self,
|
836 |
+
x: torch.Tensor,
|
837 |
+
left_context: int = 0,
|
838 |
+
) -> Tuple[Tensor, Tensor]:
|
839 |
+
"""Add positional encoding.
|
840 |
+
|
841 |
+
Args:
|
842 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
843 |
+
left_context (int): left context (in frames) used during streaming decoding.
|
844 |
+
this is used only in real streaming decoding, in other circumstances,
|
845 |
+
it MUST be 0.
|
846 |
+
|
847 |
+
Returns:
|
848 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
849 |
+
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
850 |
+
|
851 |
+
"""
|
852 |
+
self.extend_pe(x, left_context)
|
853 |
+
x_size_1 = x.size(1) + left_context
|
854 |
+
pos_emb = self.pe[
|
855 |
+
:,
|
856 |
+
self.pe.size(1) // 2
|
857 |
+
- x_size_1
|
858 |
+
+ 1 : self.pe.size(1) // 2 # noqa E203
|
859 |
+
+ x.size(1),
|
860 |
+
]
|
861 |
+
return self.dropout(x), self.dropout(pos_emb)
|
862 |
+
|
863 |
+
|
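(Editor's aside, a quick shape check of the contract documented above; a sketch, not part of the diff. The returned positional embedding covers both positive and negative offsets, so its time dimension is 2*time - 1 when `left_context` is 0.)

import torch

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(4, 50, 256)                     # (batch, time, d_model)
x_out, pos_emb = pos_enc(x)
assert x_out.shape == (4, 50, 256)
assert pos_emb.shape == (1, 2 * 50 - 1, 256)    # positive and negative offsets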
864 |
+
class RelPositionMultiheadAttention(nn.Module):
|
865 |
+
r"""Multi-Head Attention layer with relative position encoding
|
866 |
+
|
867 |
+
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
868 |
+
|
869 |
+
Args:
|
870 |
+
embed_dim: total dimension of the model.
|
871 |
+
num_heads: parallel attention heads.
|
872 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
873 |
+
|
874 |
+
Examples::
|
875 |
+
|
876 |
+
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
877 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
878 |
+
"""
|
879 |
+
|
880 |
+
def __init__(
|
881 |
+
self,
|
882 |
+
embed_dim: int,
|
883 |
+
num_heads: int,
|
884 |
+
dropout: float = 0.0,
|
885 |
+
) -> None:
|
886 |
+
super(RelPositionMultiheadAttention, self).__init__()
|
887 |
+
self.embed_dim = embed_dim
|
888 |
+
self.num_heads = num_heads
|
889 |
+
self.dropout = dropout
|
890 |
+
self.head_dim = embed_dim // num_heads
|
891 |
+
assert (
|
892 |
+
self.head_dim * num_heads == self.embed_dim
|
893 |
+
), "embed_dim must be divisible by num_heads"
|
894 |
+
|
895 |
+
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
|
896 |
+
self.out_proj = ScaledLinear(
|
897 |
+
embed_dim, embed_dim, bias=True, initial_scale=0.25
|
898 |
+
)
|
899 |
+
|
900 |
+
# linear transformation for positional encoding.
|
901 |
+
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
|
902 |
+
# these two learnable bias are used in matrix c and matrix d
|
903 |
+
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
904 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
905 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
906 |
+
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
|
907 |
+
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
|
908 |
+
self._reset_parameters()
|
909 |
+
|
910 |
+
def _pos_bias_u(self):
|
911 |
+
return self.pos_bias_u * self.pos_bias_u_scale.exp()
|
912 |
+
|
913 |
+
def _pos_bias_v(self):
|
914 |
+
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
915 |
+
|
916 |
+
def _reset_parameters(self) -> None:
|
917 |
+
nn.init.normal_(self.pos_bias_u, std=0.01)
|
918 |
+
nn.init.normal_(self.pos_bias_v, std=0.01)
|
919 |
+
|
920 |
+
def forward(
|
921 |
+
self,
|
922 |
+
query: Tensor,
|
923 |
+
key: Tensor,
|
924 |
+
value: Tensor,
|
925 |
+
pos_emb: Tensor,
|
926 |
+
key_padding_mask: Optional[Tensor] = None,
|
927 |
+
need_weights: bool = False,
|
928 |
+
attn_mask: Optional[Tensor] = None,
|
929 |
+
left_context: int = 0,
|
930 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
931 |
+
r"""
|
932 |
+
Args:
|
933 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
934 |
+
pos_emb: Positional embedding tensor
|
935 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
936 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
937 |
+
the corresponding value on the attention layer will be ignored. When given
|
938 |
+
a byte mask and a value is non-zero, the corresponding value on the attention
|
939 |
+
layer will be ignored
|
940 |
+
need_weights: output attn_output_weights.
|
941 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
942 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
943 |
+
left_context (int): left context (in frames) used during streaming decoding.
|
944 |
+
this is used only in real streaming decoding, in other circumstances,
|
945 |
+
it MUST be 0.
|
946 |
+
|
947 |
+
Shape:
|
948 |
+
- Inputs:
|
949 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
950 |
+
the embedding dimension.
|
951 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
952 |
+
the embedding dimension.
|
953 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
954 |
+
the embedding dimension.
|
955 |
+
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
956 |
+
the embedding dimension.
|
957 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
958 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
959 |
+
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
960 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
961 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
962 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
963 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
|
964 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
965 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
966 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
967 |
+
is provided, it will be added to the attention weight.
|
968 |
+
|
969 |
+
- Outputs:
|
970 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
971 |
+
E is the embedding dimension.
|
972 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
973 |
+
L is the target sequence length, S is the source sequence length.
|
974 |
+
"""
|
975 |
+
return self.multi_head_attention_forward(
|
976 |
+
query,
|
977 |
+
key,
|
978 |
+
value,
|
979 |
+
pos_emb,
|
980 |
+
self.embed_dim,
|
981 |
+
self.num_heads,
|
982 |
+
self.in_proj.get_weight(),
|
983 |
+
self.in_proj.get_bias(),
|
984 |
+
self.dropout,
|
985 |
+
self.out_proj.get_weight(),
|
986 |
+
self.out_proj.get_bias(),
|
987 |
+
training=self.training,
|
988 |
+
key_padding_mask=key_padding_mask,
|
989 |
+
need_weights=need_weights,
|
990 |
+
attn_mask=attn_mask,
|
991 |
+
left_context=left_context,
|
992 |
+
)
|
993 |
+
|
994 |
+
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
|
995 |
+
"""Compute relative positional encoding.
|
996 |
+
|
997 |
+
Args:
|
998 |
+
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
|
999 |
+
time1 means the length of query vector.
|
1000 |
+
left_context (int): left context (in frames) used during streaming decoding.
|
1001 |
+
this is used only in real streaming decoding, in other circumstances,
|
1002 |
+
it MUST be 0.
|
1003 |
+
|
1004 |
+
Returns:
|
1005 |
+
Tensor: tensor of shape (batch, head, time1, time2)
|
1006 |
+
(note: time2 has the same value as time1, but it is for
|
1007 |
+
the key, while time1 is for the query).
|
1008 |
+
"""
|
1009 |
+
(batch_size, num_heads, time1, n) = x.shape
|
1010 |
+
|
1011 |
+
time2 = time1 + left_context
|
1012 |
+
if not is_jit_tracing():
|
1013 |
+
assert (
|
1014 |
+
n == left_context + 2 * time1 - 1
|
1015 |
+
), f"{n} == {left_context} + 2 * {time1} - 1"
|
1016 |
+
|
1017 |
+
if is_jit_tracing():
|
1018 |
+
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
1019 |
+
cols = torch.arange(time2)
|
1020 |
+
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
1021 |
+
indexes = rows + cols
|
1022 |
+
|
1023 |
+
x = x.reshape(-1, n)
|
1024 |
+
x = torch.gather(x, dim=1, index=indexes)
|
1025 |
+
x = x.reshape(batch_size, num_heads, time1, time2)
|
1026 |
+
return x
|
1027 |
+
else:
|
1028 |
+
# Note: TorchScript requires explicit arg for stride()
|
1029 |
+
batch_stride = x.stride(0)
|
1030 |
+
head_stride = x.stride(1)
|
1031 |
+
time1_stride = x.stride(2)
|
1032 |
+
n_stride = x.stride(3)
|
1033 |
+
return x.as_strided(
|
1034 |
+
(batch_size, num_heads, time1, time2),
|
1035 |
+
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
1036 |
+
storage_offset=n_stride * (time1 - 1),
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
def multi_head_attention_forward(
|
1040 |
+
self,
|
1041 |
+
query: Tensor,
|
1042 |
+
key: Tensor,
|
1043 |
+
value: Tensor,
|
1044 |
+
pos_emb: Tensor,
|
1045 |
+
embed_dim_to_check: int,
|
1046 |
+
num_heads: int,
|
1047 |
+
in_proj_weight: Tensor,
|
1048 |
+
in_proj_bias: Tensor,
|
1049 |
+
dropout_p: float,
|
1050 |
+
out_proj_weight: Tensor,
|
1051 |
+
out_proj_bias: Tensor,
|
1052 |
+
training: bool = True,
|
1053 |
+
key_padding_mask: Optional[Tensor] = None,
|
1054 |
+
need_weights: bool = False,
|
1055 |
+
attn_mask: Optional[Tensor] = None,
|
1056 |
+
left_context: int = 0,
|
1057 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
1058 |
+
r"""
|
1059 |
+
Args:
|
1060 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
1061 |
+
pos_emb: Positional embedding tensor
|
1062 |
+
embed_dim_to_check: total dimension of the model.
|
1063 |
+
num_heads: parallel attention heads.
|
1064 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
1065 |
+
dropout_p: probability of an element to be zeroed.
|
1066 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
1067 |
+
training: apply dropout if is ``True``.
|
1068 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
1069 |
+
be ignored by the attention. This is a binary mask. When the value is True,
|
1070 |
+
the corresponding value on the attention layer will be filled with -inf.
|
1071 |
+
need_weights: output attn_output_weights.
|
1072 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
1073 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
1074 |
+
left_context (int): left context (in frames) used during streaming decoding.
|
1075 |
+
this is used only in real streaming decoding, in other circumstances,
|
1076 |
+
it MUST be 0.
|
1077 |
+
|
1078 |
+
Shape:
|
1079 |
+
Inputs:
|
1080 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
1081 |
+
the embedding dimension.
|
1082 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
1083 |
+
the embedding dimension.
|
1084 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
1085 |
+
the embedding dimension.
|
1086 |
+
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
1087 |
+
length, N is the batch size, E is the embedding dimension.
|
1088 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
1089 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
1090 |
+
will be unchanged. If a BoolTensor is provided, the positions with the
|
1091 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
1092 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
1093 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
1094 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
|
1095 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
1096 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
1097 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
1098 |
+
is provided, it will be added to the attention weight.
|
1099 |
+
|
1100 |
+
Outputs:
|
1101 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
1102 |
+
E is the embedding dimension.
|
1103 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
1104 |
+
L is the target sequence length, S is the source sequence length.
|
1105 |
+
"""
|
1106 |
+
|
1107 |
+
tgt_len, bsz, embed_dim = query.size()
|
1108 |
+
if not is_jit_tracing():
|
1109 |
+
assert embed_dim == embed_dim_to_check
|
1110 |
+
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
1111 |
+
|
1112 |
+
head_dim = embed_dim // num_heads
|
1113 |
+
if not is_jit_tracing():
|
1114 |
+
assert (
|
1115 |
+
head_dim * num_heads == embed_dim
|
1116 |
+
), "embed_dim must be divisible by num_heads"
|
1117 |
+
|
1118 |
+
scaling = float(head_dim) ** -0.5
|
1119 |
+
|
1120 |
+
if torch.equal(query, key) and torch.equal(key, value):
|
1121 |
+
# self-attention
|
1122 |
+
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
|
1123 |
+
3, dim=-1
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
elif torch.equal(key, value):
|
1127 |
+
# encoder-decoder attention
|
1128 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
1129 |
+
_b = in_proj_bias
|
1130 |
+
_start = 0
|
1131 |
+
_end = embed_dim
|
1132 |
+
_w = in_proj_weight[_start:_end, :]
|
1133 |
+
if _b is not None:
|
1134 |
+
_b = _b[_start:_end]
|
1135 |
+
q = nn.functional.linear(query, _w, _b)
|
1136 |
+
|
1137 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
1138 |
+
_b = in_proj_bias
|
1139 |
+
_start = embed_dim
|
1140 |
+
_end = None
|
1141 |
+
_w = in_proj_weight[_start:, :]
|
1142 |
+
if _b is not None:
|
1143 |
+
_b = _b[_start:]
|
1144 |
+
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
1145 |
+
|
1146 |
+
else:
|
1147 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
1148 |
+
_b = in_proj_bias
|
1149 |
+
_start = 0
|
1150 |
+
_end = embed_dim
|
1151 |
+
_w = in_proj_weight[_start:_end, :]
|
1152 |
+
if _b is not None:
|
1153 |
+
_b = _b[_start:_end]
|
1154 |
+
q = nn.functional.linear(query, _w, _b)
|
1155 |
+
|
1156 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
1157 |
+
_b = in_proj_bias
|
1158 |
+
_start = embed_dim
|
1159 |
+
_end = embed_dim * 2
|
1160 |
+
_w = in_proj_weight[_start:_end, :]
|
1161 |
+
if _b is not None:
|
1162 |
+
_b = _b[_start:_end]
|
1163 |
+
k = nn.functional.linear(key, _w, _b)
|
1164 |
+
|
1165 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
1166 |
+
_b = in_proj_bias
|
1167 |
+
_start = embed_dim * 2
|
1168 |
+
_end = None
|
1169 |
+
_w = in_proj_weight[_start:, :]
|
1170 |
+
if _b is not None:
|
1171 |
+
_b = _b[_start:]
|
1172 |
+
v = nn.functional.linear(value, _w, _b)
|
1173 |
+
|
1174 |
+
if attn_mask is not None:
|
1175 |
+
assert (
|
1176 |
+
attn_mask.dtype == torch.float32
|
1177 |
+
or attn_mask.dtype == torch.float64
|
1178 |
+
or attn_mask.dtype == torch.float16
|
1179 |
+
or attn_mask.dtype == torch.uint8
|
1180 |
+
or attn_mask.dtype == torch.bool
|
1181 |
+
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
1182 |
+
attn_mask.dtype
|
1183 |
+
)
|
1184 |
+
if attn_mask.dtype == torch.uint8:
|
1185 |
+
warnings.warn(
|
1186 |
+
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
1187 |
+
)
|
1188 |
+
attn_mask = attn_mask.to(torch.bool)
|
1189 |
+
|
1190 |
+
if attn_mask.dim() == 2:
|
1191 |
+
attn_mask = attn_mask.unsqueeze(0)
|
1192 |
+
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
1193 |
+
raise RuntimeError("The size of the 2D attn_mask is not correct.")
|
1194 |
+
elif attn_mask.dim() == 3:
|
1195 |
+
if list(attn_mask.size()) != [
|
1196 |
+
bsz * num_heads,
|
1197 |
+
query.size(0),
|
1198 |
+
key.size(0),
|
1199 |
+
]:
|
1200 |
+
raise RuntimeError("The size of the 3D attn_mask is not correct.")
|
1201 |
+
else:
|
1202 |
+
raise RuntimeError(
|
1203 |
+
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
|
1204 |
+
)
|
1205 |
+
# attn_mask's dim is 3 now.
|
1206 |
+
|
1207 |
+
# convert ByteTensor key_padding_mask to bool
|
1208 |
+
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
1209 |
+
warnings.warn(
|
1210 |
+
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
1211 |
+
)
|
1212 |
+
key_padding_mask = key_padding_mask.to(torch.bool)
|
1213 |
+
|
1214 |
+
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
1215 |
+
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
1216 |
+
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
1217 |
+
|
1218 |
+
src_len = k.size(0)
|
1219 |
+
|
1220 |
+
if key_padding_mask is not None and not is_jit_tracing():
|
1221 |
+
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
1222 |
+
key_padding_mask.size(0), bsz
|
1223 |
+
)
|
1224 |
+
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
1225 |
+
key_padding_mask.size(1), src_len
|
1226 |
+
)
|
1227 |
+
|
1228 |
+
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
1229 |
+
|
1230 |
+
pos_emb_bsz = pos_emb.size(0)
|
1231 |
+
if not is_jit_tracing():
|
1232 |
+
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
1233 |
+
|
1234 |
+
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
1235 |
+
# (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
|
1236 |
+
p = p.permute(0, 2, 3, 1)
|
1237 |
+
|
1238 |
+
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
1239 |
+
1, 2
|
1240 |
+
) # (batch, head, time1, d_k)
|
1241 |
+
|
1242 |
+
q_with_bias_v = (q + self._pos_bias_v()).transpose(
|
1243 |
+
1, 2
|
1244 |
+
) # (batch, head, time1, d_k)
|
1245 |
+
|
1246 |
+
# compute attention score
|
1247 |
+
# first compute matrix a and matrix c
|
1248 |
+
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
1249 |
+
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
1250 |
+
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
|
1251 |
+
|
1252 |
+
# compute matrix b and matrix d
|
1253 |
+
matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1)
|
1254 |
+
matrix_bd = self.rel_shift(matrix_bd, left_context)
|
1255 |
+
|
1256 |
+
attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2)
|
1257 |
+
|
1258 |
+
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
|
1259 |
+
|
1260 |
+
if not is_jit_tracing():
|
1261 |
+
assert list(attn_output_weights.size()) == [
|
1262 |
+
bsz * num_heads,
|
1263 |
+
tgt_len,
|
1264 |
+
src_len,
|
1265 |
+
]
|
1266 |
+
|
1267 |
+
if attn_mask is not None:
|
1268 |
+
if attn_mask.dtype == torch.bool:
|
1269 |
+
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
1270 |
+
else:
|
1271 |
+
attn_output_weights += attn_mask
|
1272 |
+
|
1273 |
+
if key_padding_mask is not None:
|
1274 |
+
attn_output_weights = attn_output_weights.view(
|
1275 |
+
bsz, num_heads, tgt_len, src_len
|
1276 |
+
)
|
1277 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
1278 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
1279 |
+
float("-inf"),
|
1280 |
+
)
|
1281 |
+
attn_output_weights = attn_output_weights.view(
|
1282 |
+
bsz * num_heads, tgt_len, src_len
|
1283 |
+
)
|
1284 |
+
|
1285 |
+
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
1286 |
+
|
1287 |
+
# If we are using dynamic_chunk_training and setting a limited
|
1288 |
+
# num_left_chunks, the attention may only see the padding values which
|
1289 |
+
# will also be masked out by `key_padding_mask`, at this circumstances,
|
1290 |
+
# the whole column of `attn_output_weights` will be `-inf`
|
1291 |
+
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
1292 |
+
# positions to avoid invalid loss value below.
|
1293 |
+
if (
|
1294 |
+
attn_mask is not None
|
1295 |
+
and attn_mask.dtype == torch.bool
|
1296 |
+
and key_padding_mask is not None
|
1297 |
+
):
|
1298 |
+
if attn_mask.size(0) != 1:
|
1299 |
+
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
|
1300 |
+
combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
|
1301 |
+
else:
|
1302 |
+
# attn_mask.shape == (1, tgt_len, src_len)
|
1303 |
+
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
|
1304 |
+
1
|
1305 |
+
).unsqueeze(2)
|
1306 |
+
|
1307 |
+
attn_output_weights = attn_output_weights.view(
|
1308 |
+
bsz, num_heads, tgt_len, src_len
|
1309 |
+
)
|
1310 |
+
attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
|
1311 |
+
attn_output_weights = attn_output_weights.view(
|
1312 |
+
bsz * num_heads, tgt_len, src_len
|
1313 |
+
)
|
1314 |
+
|
1315 |
+
attn_output_weights = nn.functional.dropout(
|
1316 |
+
attn_output_weights, p=dropout_p, training=training
|
1317 |
+
)
|
1318 |
+
|
1319 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
1320 |
+
|
1321 |
+
if not is_jit_tracing():
|
1322 |
+
assert list(attn_output.size()) == [
|
1323 |
+
bsz * num_heads,
|
1324 |
+
tgt_len,
|
1325 |
+
head_dim,
|
1326 |
+
]
|
1327 |
+
|
1328 |
+
attn_output = (
|
1329 |
+
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
1330 |
+
)
|
1331 |
+
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
|
1332 |
+
|
1333 |
+
if need_weights:
|
1334 |
+
# average attention weights over heads
|
1335 |
+
attn_output_weights = attn_output_weights.view(
|
1336 |
+
bsz, num_heads, tgt_len, src_len
|
1337 |
+
)
|
1338 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
1339 |
+
else:
|
1340 |
+
return attn_output, None
|
1341 |
+
|
1342 |
+
|
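(Editor's aside, a sketch under assumptions rather than code from this diff: the boolean `attn_mask` convention above is that `True` marks positions a query may NOT attend to. A tiny chunked mask in that convention, for 6 frames split into chunks of 2 with one left chunk, could be built like this.)

import torch

size, chunk_size, num_left_chunks = 6, 2, 1
allowed = torch.zeros(size, size, dtype=torch.bool)
for i in range(size):
    start = max(0, (i // chunk_size - num_left_chunks) * chunk_size)
    end = (i // chunk_size + 1) * chunk_size
    allowed[i, start:end] = True     # current chunk plus allowed left chunks
attn_mask = ~allowed                 # True = masked out, matching the convention above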
1343 |
+
class ConvolutionModule(nn.Module):
|
1344 |
+
"""ConvolutionModule in Conformer model.
|
1345 |
+
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
1346 |
+
|
1347 |
+
Args:
|
1348 |
+
channels (int): The number of channels of conv layers.
|
1349 |
+
kernel_size (int): Kernel size of conv layers.
|
1350 |
+
bias (bool): Whether to use bias in conv layers (default=True).
|
1351 |
+
causal (bool): Whether to use causal convolution.
|
1352 |
+
"""
|
1353 |
+
|
1354 |
+
def __init__(
|
1355 |
+
self,
|
1356 |
+
channels: int,
|
1357 |
+
kernel_size: int,
|
1358 |
+
bias: bool = True,
|
1359 |
+
causal: bool = False,
|
1360 |
+
) -> None:
|
1361 |
+
"""Construct an ConvolutionModule object."""
|
1362 |
+
super(ConvolutionModule, self).__init__()
|
1363 |
+
# kernel_size should be an odd number for 'SAME' padding
|
1364 |
+
assert (kernel_size - 1) % 2 == 0
|
1365 |
+
self.causal = causal
|
1366 |
+
|
1367 |
+
self.pointwise_conv1 = ScaledConv1d(
|
1368 |
+
channels,
|
1369 |
+
2 * channels,
|
1370 |
+
kernel_size=1,
|
1371 |
+
stride=1,
|
1372 |
+
padding=0,
|
1373 |
+
bias=bias,
|
1374 |
+
)
|
1375 |
+
|
1376 |
+
# after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
|
1377 |
+
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
|
1378 |
+
# but sometimes, for some reason, for layer 0 the rms ends up being very large,
|
1379 |
+
# between 50 and 100 for different channels. This will cause very peaky and
|
1380 |
+
# sparse derivatives for the sigmoid gating function, which will tend to make
|
1381 |
+
# the loss function not learn effectively. (for most layers the average absolute values
|
1382 |
+
# are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
|
1383 |
+
# at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
|
1384 |
+
# layers, which likely breaks down as 0.5 for the "linear" half and
|
1385 |
+
# 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
|
1386 |
+
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
1387 |
+
# it will be in a better position to start learning something, i.e. to latch onto
|
1388 |
+
# the correct range.
|
1389 |
+
self.deriv_balancer1 = ActivationBalancer(
|
1390 |
+
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
1391 |
+
)
|
1392 |
+
|
1393 |
+
self.lorder = kernel_size - 1
|
1394 |
+
padding = (kernel_size - 1) // 2
|
1395 |
+
if self.causal:
|
1396 |
+
padding = 0
|
1397 |
+
|
1398 |
+
self.depthwise_conv = ScaledConv1d(
|
1399 |
+
channels,
|
1400 |
+
channels,
|
1401 |
+
kernel_size,
|
1402 |
+
stride=1,
|
1403 |
+
padding=padding,
|
1404 |
+
groups=channels,
|
1405 |
+
bias=bias,
|
1406 |
+
)
|
1407 |
+
|
1408 |
+
self.deriv_balancer2 = ActivationBalancer(
|
1409 |
+
channel_dim=1, min_positive=0.05, max_positive=1.0
|
1410 |
+
)
|
1411 |
+
|
1412 |
+
self.activation = DoubleSwish()
|
1413 |
+
|
1414 |
+
self.pointwise_conv2 = ScaledConv1d(
|
1415 |
+
channels,
|
1416 |
+
channels,
|
1417 |
+
kernel_size=1,
|
1418 |
+
stride=1,
|
1419 |
+
padding=0,
|
1420 |
+
bias=bias,
|
1421 |
+
initial_scale=0.25,
|
1422 |
+
)
|
1423 |
+
|
1424 |
+
def forward(
|
1425 |
+
self,
|
1426 |
+
x: Tensor,
|
1427 |
+
cache: Optional[Tensor] = None,
|
1428 |
+
right_context: int = 0,
|
1429 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
1430 |
+
) -> Tuple[Tensor, Tensor]:
|
1431 |
+
"""Compute convolution module.
|
1432 |
+
|
1433 |
+
Args:
|
1434 |
+
x: Input tensor (#time, batch, channels).
|
1435 |
+
cache: The cache of depthwise_conv, only used in real streaming
|
1436 |
+
decoding.
|
1437 |
+
right_context:
|
1438 |
+
How many future frames the attention can see in current chunk.
|
1439 |
+
Note: It's not that each individual frame has `right_context` frames
|
1440 |
+
of right context, some have more.
|
1441 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
1442 |
+
|
1443 |
+
Returns:
|
1444 |
+
If cache is None return the output tensor (#time, batch, channels).
|
1445 |
+
If cache is not None, return a tuple of Tensor, the first one is
|
1446 |
+
the output tensor (#time, batch, channels), the second one is the
|
1447 |
+
new cache for next chunk (#kernel_size - 1, batch, channels).
|
1448 |
+
|
1449 |
+
"""
|
1450 |
+
# exchange the temporal dimension and the feature dimension
|
1451 |
+
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
1452 |
+
|
1453 |
+
# GLU mechanism
|
1454 |
+
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
1455 |
+
|
1456 |
+
x = self.deriv_balancer1(x)
|
1457 |
+
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
1458 |
+
|
1459 |
+
# 1D Depthwise Conv
|
1460 |
+
if src_key_padding_mask is not None:
|
1461 |
+
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
1462 |
+
if self.causal and self.lorder > 0:
|
1463 |
+
if cache is None:
|
1464 |
+
# Make depthwise_conv causal by
|
1465 |
+
# manually padding self.lorder zeros to the left
|
1466 |
+
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
1467 |
+
else:
|
1468 |
+
assert not self.training, "Cache should be None in training time"
|
1469 |
+
assert cache.size(0) == self.lorder
|
1470 |
+
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
1471 |
+
if right_context > 0:
|
1472 |
+
cache = x.permute(2, 0, 1)[
|
1473 |
+
-(self.lorder + right_context) : (-right_context), # noqa
|
1474 |
+
...,
|
1475 |
+
]
|
1476 |
+
else:
|
1477 |
+
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
1478 |
+
x = self.depthwise_conv(x)
|
1479 |
+
|
1480 |
+
x = self.deriv_balancer2(x)
|
1481 |
+
x = self.activation(x)
|
1482 |
+
|
1483 |
+
x = self.pointwise_conv2(x) # (batch, channel, time)
|
1484 |
+
|
1485 |
+
# torch.jit.script requires return types be the same as annotated above
|
1486 |
+
if cache is None:
|
1487 |
+
cache = torch.empty(0)
|
1488 |
+
|
1489 |
+
return x.permute(2, 0, 1), cache
|
1490 |
+
|
1491 |
+
|
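(Editor's aside, a standalone sketch illustrating the padding arithmetic used above, not taken from the diff: a non-causal depthwise conv uses symmetric 'SAME' padding of (kernel_size - 1) // 2, while the causal variant pads lorder = kernel_size - 1 zeros on the left only, so no future frames are consumed.)

import torch
import torch.nn.functional as F

k = 31                              # cnn_module_kernel (odd)
x = torch.randn(1, 4, 20)           # (batch, channels, time)
same_pad = F.pad(x, ((k - 1) // 2, (k - 1) // 2))   # symmetric padding
causal_pad = F.pad(x, (k - 1, 0))                   # lorder zeros on the left only
assert same_pad.size(-1) == causal_pad.size(-1) == x.size(-1) + k - 1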
1492 |
+
class Conv2dSubsampling(nn.Module):
|
1493 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
1494 |
+
|
1495 |
+
Convert an input of shape (N, T, idim) to an output
|
1496 |
+
with shape (N, T', odim), where
|
1497 |
+
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
1498 |
+
|
1499 |
+
It is based on
|
1500 |
+
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
1501 |
+
"""
|
1502 |
+
|
1503 |
+
def __init__(
|
1504 |
+
self,
|
1505 |
+
in_channels: int,
|
1506 |
+
out_channels: int,
|
1507 |
+
layer1_channels: int = 8,
|
1508 |
+
layer2_channels: int = 32,
|
1509 |
+
layer3_channels: int = 128,
|
1510 |
+
) -> None:
|
1511 |
+
"""
|
1512 |
+
Args:
|
1513 |
+
in_channels:
|
1514 |
+
Number of channels in. The input shape is (N, T, in_channels).
|
1515 |
+
Caution: It requires: T >=7, in_channels >=7
|
1516 |
+
out_channels
|
1517 |
+
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
1518 |
+
layer1_channels:
|
1519 |
+
Number of channels in layer1
|
1520 |
+
layer2_channels:
|
1521 |
+
Number of channels in layer2
|
1522 |
+
"""
|
1523 |
+
assert in_channels >= 7
|
1524 |
+
super().__init__()
|
1525 |
+
|
1526 |
+
self.conv = nn.Sequential(
|
1527 |
+
ScaledConv2d(
|
1528 |
+
in_channels=1,
|
1529 |
+
out_channels=layer1_channels,
|
1530 |
+
kernel_size=3,
|
1531 |
+
padding=1,
|
1532 |
+
),
|
1533 |
+
ActivationBalancer(channel_dim=1),
|
1534 |
+
DoubleSwish(),
|
1535 |
+
ScaledConv2d(
|
1536 |
+
in_channels=layer1_channels,
|
1537 |
+
out_channels=layer2_channels,
|
1538 |
+
kernel_size=3,
|
1539 |
+
stride=2,
|
1540 |
+
),
|
1541 |
+
ActivationBalancer(channel_dim=1),
|
1542 |
+
DoubleSwish(),
|
1543 |
+
ScaledConv2d(
|
1544 |
+
in_channels=layer2_channels,
|
1545 |
+
out_channels=layer3_channels,
|
1546 |
+
kernel_size=3,
|
1547 |
+
stride=2,
|
1548 |
+
),
|
1549 |
+
ActivationBalancer(channel_dim=1),
|
1550 |
+
DoubleSwish(),
|
1551 |
+
)
|
1552 |
+
self.out = ScaledLinear(
|
1553 |
+
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
1554 |
+
)
|
1555 |
+
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
1556 |
+
# itself has learned scale, so the extra degree of freedom is not
|
1557 |
+
# needed.
|
1558 |
+
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
1559 |
+
# constrain median of output to be close to zero.
|
1560 |
+
self.out_balancer = ActivationBalancer(
|
1561 |
+
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
1562 |
+
)
|
1563 |
+
|
1564 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1565 |
+
"""Subsample x.
|
1566 |
+
|
1567 |
+
Args:
|
1568 |
+
x:
|
1569 |
+
Its shape is (N, T, idim).
|
1570 |
+
|
1571 |
+
Returns:
|
1572 |
+
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
1573 |
+
"""
|
1574 |
+
# On entry, x is (N, T, idim)
|
1575 |
+
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
1576 |
+
x = self.conv(x)
|
1577 |
+
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
1578 |
+
b, c, t, f = x.size()
|
1579 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
1580 |
+
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
|
1581 |
+
x = self.out_norm(x)
|
1582 |
+
x = self.out_balancer(x)
|
1583 |
+
return x
|
1584 |
+
|
1585 |
+
|
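(Editor's aside, a small worked example of the subsampling factor stated above, not part of the diff: the two stride-2 convolutions reduce T input frames to roughly T // 4.)

for T in (20, 100, 163):
    T_prime = ((T - 1) // 2 - 1) // 2      # output length of Conv2dSubsampling
    print(T, "->", T_prime)                # 20 -> 4, 100 -> 24, 163 -> 40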
1586 |
+
if __name__ == "__main__":
|
1587 |
+
torch.set_num_threads(1)
|
1588 |
+
torch.set_num_interop_threads(1)
|
1589 |
+
feature_dim = 50
|
1590 |
+
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
1591 |
+
batch_size = 5
|
1592 |
+
seq_len = 20
|
1593 |
+
# Just make sure the forward pass runs.
|
1594 |
+
f = c(
|
1595 |
+
torch.randn(batch_size, seq_len, feature_dim),
|
1596 |
+
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
1597 |
+
warmup=0.5,
|
1598 |
+
)
|
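Before moving on to decode.py, a brief usage sketch mirroring the self-test above (everything beyond the constructor arguments and call shown in that self-test is an assumption based on the upstream icefall recipe, not guaranteed by this diff; the import path is hypothetical):

import torch
from conformer import Conformer   # hypothetical import path within this repo

model = Conformer(num_features=80, d_model=512, nhead=8)
model.eval()

feats = torch.randn(1, 200, 80)                        # (N, T, num_features), e.g. fbank features
feat_lens = torch.full((1,), 200, dtype=torch.int64)   # number of valid frames per utterance
with torch.no_grad():
    out = model(feats, feat_lens)                      # encoder output; exact return structure
                                                       # follows the forward() defined in this file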
err2020/conformer_ctc3/decode.py
ADDED
@@ -0,0 +1,1052 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
#
|
3 |
+
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
4 |
+
# Zengwei Yao)
|
5 |
+
#
|
6 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
"""
|
20 |
+
Usage:
|
21 |
+
(1) decode in non-streaming mode (take ctc-decoding as an example)
|
22 |
+
./conformer_ctc3/decode.py \
|
23 |
+
--epoch 30 \
|
24 |
+
--avg 15 \
|
25 |
+
--exp-dir ./conformer_ctc3/exp \
|
26 |
+
--max-duration 600 \
|
27 |
+
--decoding-method ctc-decoding
|
28 |
+
|
29 |
+
(2) decode in streaming mode (take ctc-decoding as an example)
|
30 |
+
./conformer_ctc3/decode.py \
|
31 |
+
--epoch 30 \
|
32 |
+
--avg 15 \
|
33 |
+
--simulate-streaming 1 \
|
34 |
+
--causal-convolution 1 \
|
35 |
+
--decode-chunk-size 16 \
|
36 |
+
--left-context 64 \
|
37 |
+
--exp-dir ./conformer_ctc3/exp \
|
38 |
+
--max-duration 600 \
|
39 |
+
--decoding-method ctc-decoding
|
40 |
+
|
41 |
+
To evaluate symbol delay, you should:
|
42 |
+
(1) Generate cuts with word-time alignments:
|
43 |
+
./add_alignments.sh
|
44 |
+
(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
|
45 |
+
For example:
|
46 |
+
./conformer_ctc3/decode.py \
|
47 |
+
--epoch 30 \
|
48 |
+
--avg 15 \
|
49 |
+
--exp-dir ./conformer_ctc3/exp \
|
50 |
+
--max-duration 600 \
|
51 |
+
--decoding-method ctc-decoding \
|
52 |
+
--simulate-streaming 1 \
|
53 |
+
--causal-convolution 1 \
|
54 |
+
--decode-chunk-size 16 \
|
55 |
+
--left-context 64 \
|
56 |
+
--manifest-dir data/fbank_ali
|
57 |
+
Note: It supports calculating symbol delay with the following decoding methods:
|
58 |
+
- ctc-decoding
|
59 |
+
- 1best
|
60 |
+
"""
|
61 |
+
|
62 |
+
|
63 |
+
import argparse
|
64 |
+
import logging
|
65 |
+
import math
|
66 |
+
from collections import defaultdict
|
67 |
+
from pathlib import Path
|
68 |
+
from typing import Dict, List, Optional, Tuple
|
69 |
+
|
70 |
+
import k2
|
71 |
+
import sentencepiece as spm
|
72 |
+
import torch
|
73 |
+
import torch.nn as nn
|
74 |
+
from asr_datamodule import LibriSpeechAsrDataModule
|
75 |
+
from train import add_model_arguments, get_ctc_model, get_params
|
76 |
+
|
77 |
+
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
78 |
+
from icefall.checkpoint import (
|
79 |
+
average_checkpoints,
|
80 |
+
average_checkpoints_with_averaged_model,
|
81 |
+
find_checkpoints,
|
82 |
+
load_checkpoint,
|
83 |
+
)
|
84 |
+
from icefall.decode import (
|
85 |
+
get_lattice,
|
86 |
+
nbest_decoding,
|
87 |
+
nbest_oracle,
|
88 |
+
one_best_decoding,
|
89 |
+
rescore_with_n_best_list,
|
90 |
+
rescore_with_whole_lattice,
|
91 |
+
)
|
92 |
+
from icefall.lexicon import Lexicon
|
93 |
+
from icefall.utils import (
|
94 |
+
AttributeDict,
|
95 |
+
convert_timestamp,
|
96 |
+
get_texts,
|
97 |
+
make_pad_mask,
|
98 |
+
parse_bpe_start_end_pairs,
|
99 |
+
parse_fsa_timestamps_and_texts,
|
100 |
+
setup_logger,
|
101 |
+
store_transcripts_and_timestamps,
|
102 |
+
str2bool,
|
103 |
+
write_error_stats_with_timestamps,
|
104 |
+
)
|
105 |
+
|
106 |
+
LOG_EPS = math.log(1e-10)
|
107 |
+
|
108 |
+
|
109 |
+
def get_parser():
|
110 |
+
parser = argparse.ArgumentParser(
|
111 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
112 |
+
)
|
113 |
+
|
114 |
+
parser.add_argument(
|
115 |
+
"--epoch",
|
116 |
+
type=int,
|
117 |
+
default=30,
|
118 |
+
help="""It specifies the checkpoint to use for decoding.
|
119 |
+
Note: Epoch counts from 1.
|
120 |
+
You can specify --avg to use more checkpoints for model averaging.""",
|
121 |
+
)
|
122 |
+
|
123 |
+
parser.add_argument(
|
124 |
+
"--iter",
|
125 |
+
type=int,
|
126 |
+
default=0,
|
127 |
+
help="""If positive, --epoch is ignored and it
|
128 |
+
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
129 |
+
You can specify --avg to use more checkpoints for model averaging.
|
130 |
+
""",
|
131 |
+
)
|
132 |
+
|
133 |
+
parser.add_argument(
|
134 |
+
"--avg",
|
135 |
+
type=int,
|
136 |
+
default=15,
|
137 |
+
help="Number of checkpoints to average. Automatically select "
|
138 |
+
"consecutive checkpoints before the checkpoint specified by "
|
139 |
+
"'--epoch' and '--iter'",
|
140 |
+
)
|
141 |
+
|
142 |
+
parser.add_argument(
|
143 |
+
"--use-averaged-model",
|
144 |
+
type=str2bool,
|
145 |
+
default=True,
|
146 |
+
help="Whether to load averaged model. Currently it only supports "
|
147 |
+
"using --epoch. If True, it would decode with the averaged model "
|
148 |
+
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
149 |
+
"Actually only the models with epoch number of `epoch-avg` and "
|
150 |
+
"`epoch` are loaded for averaging. ",
|
151 |
+
)
|
152 |
+
|
153 |
+
parser.add_argument(
|
154 |
+
"--exp-dir",
|
155 |
+
type=str,
|
156 |
+
default="pruned_transducer_stateless4/exp",
|
157 |
+
help="The experiment dir",
|
158 |
+
)
|
159 |
+
|
160 |
+
parser.add_argument(
|
161 |
+
"--lang-dir",
|
162 |
+
type=Path,
|
163 |
+
default="data/lang_bpe_500",
|
164 |
+
help="The lang dir containing word table and LG graph",
|
165 |
+
)
|
166 |
+
|
167 |
+
parser.add_argument(
|
168 |
+
"--decoding-method",
|
169 |
+
type=str,
|
170 |
+
default="ctc-decoding",
|
171 |
+
help="""Decoding method.
|
172 |
+
Supported values are:
|
173 |
+
- (0) ctc-greedy-search. It uses a sentence piece model,
|
174 |
+
i.e., lang_dir/bpe.model, to convert word pieces to words.
|
175 |
+
It needs neither a lexicon nor an n-gram LM.
|
176 |
+
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
177 |
+
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
178 |
+
It needs neither a lexicon nor an n-gram LM.
|
179 |
+
- (2) 1best. Extract the best path from the decoding lattice as the
|
180 |
+
decoding result.
|
181 |
+
- (3) nbest. Extract n paths from the decoding lattice; the path
|
182 |
+
with the highest score is the decoding result.
|
183 |
+
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
184 |
+
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
185 |
+
the highest score is the decoding result.
|
186 |
+
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
187 |
+
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
188 |
+
is the decoding result.
|
189 |
+
|
190 |
+
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
191 |
+
rescoring method can achieve. Useful for debugging n-best
|
192 |
+
rescoring method.
|
193 |
+
""",
|
194 |
+
)
|
195 |
+
|
196 |
+
parser.add_argument(
|
197 |
+
"--num-paths",
|
198 |
+
type=int,
|
199 |
+
default=100,
|
200 |
+
help="""Number of paths for n-best based decoding method.
|
201 |
+
Used only when "method" is one of the following values:
|
202 |
+
nbest, nbest-rescoring, and nbest-oracle
|
203 |
+
""",
|
204 |
+
)
|
205 |
+
|
206 |
+
parser.add_argument(
|
207 |
+
"--nbest-scale",
|
208 |
+
type=float,
|
209 |
+
default=0.5,
|
210 |
+
help="""The scale to be applied to `lattice.scores`.
|
211 |
+
It's needed if you use any kinds of n-best based rescoring.
|
212 |
+
Used only when "method" is one of the following values:
|
213 |
+
nbest, nbest-rescoring, and nbest-oracle
|
214 |
+
A smaller value results in more unique paths.
|
215 |
+
""",
|
216 |
+
)
|
217 |
+
|
218 |
+
parser.add_argument(
|
219 |
+
"--lm-dir",
|
220 |
+
type=str,
|
221 |
+
default="data/lm",
|
222 |
+
help="""The n-gram LM dir.
|
223 |
+
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
224 |
+
""",
|
225 |
+
)
|
226 |
+
|
227 |
+
parser.add_argument(
|
228 |
+
"--simulate-streaming",
|
229 |
+
type=str2bool,
|
230 |
+
default=False,
|
231 |
+
help="""Whether to simulate streaming in decoding, this is a good way to
|
232 |
+
test a streaming model.
|
233 |
+
""",
|
234 |
+
)
|
235 |
+
|
236 |
+
parser.add_argument(
|
237 |
+
"--decode-chunk-size",
|
238 |
+
type=int,
|
239 |
+
default=16,
|
240 |
+
help="The chunk size for decoding (in frames after subsampling)",
|
241 |
+
)
|
242 |
+
|
243 |
+
parser.add_argument(
|
244 |
+
"--left-context",
|
245 |
+
type=int,
|
246 |
+
default=64,
|
247 |
+
help="left context can be seen during decoding (in frames after subsampling)",
|
248 |
+
)
|
249 |
+
|
250 |
+
parser.add_argument(
|
251 |
+
"--hlg-scale",
|
252 |
+
type=float,
|
253 |
+
default=0.8,
|
254 |
+
help="""The scale to be applied to `hlg.scores`.
|
255 |
+
""",
|
256 |
+
)
|
257 |
+
|
258 |
+
add_model_arguments(parser)
|
259 |
+
|
260 |
+
return parser
|
261 |
+
|
262 |
+
|
263 |
+
def get_decoding_params() -> AttributeDict:
|
264 |
+
"""Parameters for decoding."""
|
265 |
+
params = AttributeDict(
|
266 |
+
{
|
267 |
+
"frame_shift_ms": 10,
|
268 |
+
"search_beam": 20,
|
269 |
+
"output_beam": 8,
|
270 |
+
"min_active_states": 30,
|
271 |
+
"max_active_states": 10000,
|
272 |
+
"use_double_scores": True,
|
273 |
+
}
|
274 |
+
)
|
275 |
+
return params
|
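# Note (added comment, not part of the upstream recipe): main() below merges these
# decoding parameters on top of the training-time parameters, roughly
#     params = get_params()
#     params.update(get_decoding_params())
#     params.update(vars(args))
# so command-line arguments take the highest priority.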
276 |
+
|
277 |
+
|
278 |
+
def ctc_greedy_search(
|
279 |
+
ctc_probs: torch.Tensor,
|
280 |
+
nnet_output_lens: torch.Tensor,
|
281 |
+
sp: spm.SentencePieceProcessor,
|
282 |
+
subsampling_factor: int = 4,
|
283 |
+
frame_shift_ms: float = 10,
|
284 |
+
) -> Tuple[List[List[Tuple[float, float]]], List[List[str]]]:
|
285 |
+
"""Apply CTC greedy search
|
286 |
+
Args:
|
287 |
+
ctc_probs (torch.Tensor):
|
288 |
+
(batch, max_len, feat_dim)
|
289 |
+
nnet_output_lens (torch.Tensor):
|
290 |
+
(batch, )
|
291 |
+
sp:
|
292 |
+
The BPE model.
|
293 |
+
subsampling_factor:
|
294 |
+
The subsampling factor of the model.
|
295 |
+
frame_shift_ms:
|
296 |
+
Frame shift in milliseconds between two contiguous frames.
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
utt_time_pairs:
|
300 |
+
A list of pair lists. utt_time_pairs[i] is a list of
|
301 |
+
(start-time, end-time) pairs for each word in
|
302 |
+
utterance-i.
|
303 |
+
utt_words:
|
304 |
+
A list of str lists. utt_words[i] is a word list of utterance-i.
|
305 |
+
"""
|
306 |
+
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
|
307 |
+
topk_index = topk_index.squeeze(2) # (B, maxlen)
|
308 |
+
mask = make_pad_mask(nnet_output_lens)
|
309 |
+
topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen)
|
310 |
+
hyps = [hyp.tolist() for hyp in topk_index]
|
311 |
+
|
312 |
+
def get_first_tokens(tokens: List[int]) -> Tuple[List[int], List[bool]]:
|
313 |
+
is_first_token = []
|
314 |
+
first_tokens = []
|
315 |
+
for t in range(len(tokens)):
|
316 |
+
if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]):
|
317 |
+
is_first_token.append(True)
|
318 |
+
first_tokens.append(tokens[t])
|
319 |
+
else:
|
320 |
+
is_first_token.append(False)
|
321 |
+
return first_tokens, is_first_token
|
322 |
+
|
323 |
+
utt_time_pairs = []
|
324 |
+
utt_words = []
|
325 |
+
for utt in range(len(hyps)):
|
326 |
+
first_tokens, is_first_token = get_first_tokens(hyps[utt])
|
327 |
+
all_tokens = sp.id_to_piece(hyps[utt])
|
328 |
+
index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token)
|
329 |
+
words = sp.decode(first_tokens).split()
|
330 |
+
assert len(index_pairs) == len(words), (
|
331 |
+
len(index_pairs),
|
332 |
+
len(words),
|
333 |
+
all_tokens,
|
334 |
+
)
|
335 |
+
start = convert_timestamp(
|
336 |
+
frames=[i[0] for i in index_pairs],
|
337 |
+
subsampling_factor=subsampling_factor,
|
338 |
+
frame_shift_ms=frame_shift_ms,
|
339 |
+
)
|
340 |
+
end = convert_timestamp(
|
341 |
+
# The duration in frames is (end_frame_index - start_frame_index + 1)
|
342 |
+
frames=[i[1] + 1 for i in index_pairs],
|
343 |
+
subsampling_factor=subsampling_factor,
|
344 |
+
frame_shift_ms=frame_shift_ms,
|
345 |
+
)
|
346 |
+
utt_time_pairs.append(list(zip(start, end)))
|
347 |
+
utt_words.append(words)
|
348 |
+
|
349 |
+
return utt_time_pairs, utt_words
|
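# A minimal sketch, added for illustration only (the helper below is hypothetical
# and not used by the recipe), of the greedy-search idea in ctc_greedy_search():
# take the arg-max token per frame, then keep only the first frame of every
# non-blank run.
def _toy_greedy_collapse(frame_ids: List[int]) -> List[int]:
    """E.g. [0, 7, 7, 0, 3, 3, 3] -> [7, 3], with 0 being the blank id."""
    collapsed = []
    for t, token in enumerate(frame_ids):
        if token != 0 and (t == 0 or frame_ids[t - 1] != token):
            collapsed.append(token)
    return collapsed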
350 |
+
|
351 |
+
|
352 |
+
def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[Tuple[int, int]]]:
|
353 |
+
# modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
354 |
+
new_hyp: List[int] = []
|
355 |
+
time: List[Tuple[int, int]] = []
|
356 |
+
cur = 0
|
357 |
+
start, end = -1, -1
|
358 |
+
while cur < len(hyp):
|
359 |
+
if hyp[cur] != 0:
|
360 |
+
new_hyp.append(hyp[cur])
|
361 |
+
start = cur
|
362 |
+
prev = cur
|
363 |
+
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
364 |
+
if start != -1:
|
365 |
+
end = cur
|
366 |
+
cur += 1
|
367 |
+
if start != -1 and end != -1:
|
368 |
+
time.append((start, end))
|
369 |
+
start, end = -1, -1
|
370 |
+
return new_hyp, time
|
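# Worked example (added comment): with 0 as the blank id,
#     remove_duplicates_and_blank([0, 3, 3, 0, 5, 5, 5, 0])
# returns ([3, 5], [(1, 2), (4, 6)]), i.e. the collapsed token ids together with
# the (start_frame, end_frame) span of each kept token.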
371 |
+
|
372 |
+
|
373 |
+
def decode_one_batch(
|
374 |
+
params: AttributeDict,
|
375 |
+
model: nn.Module,
|
376 |
+
HLG: Optional[k2.Fsa],
|
377 |
+
H: Optional[k2.Fsa],
|
378 |
+
bpe_model: Optional[spm.SentencePieceProcessor],
|
379 |
+
batch: dict,
|
380 |
+
word_table: k2.SymbolTable,
|
381 |
+
sos_id: int,
|
382 |
+
eos_id: int,
|
383 |
+
G: Optional[k2.Fsa] = None,
|
384 |
+
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
385 |
+
"""Decode one batch and return the result in a dict. The dict has the
|
386 |
+
following format:
|
387 |
+
- key: It indicates the setting used for decoding. For example,
|
388 |
+
if no rescoring is used, the key is the string `no_rescore`.
|
389 |
+
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
390 |
+
where `xxx` is the value of `lm_scale`. An example key is
|
391 |
+
`lm_scale_0.7`
|
392 |
+
- value: It contains the decoding result. `len(value)` equals the
|
393 |
+
batch size. `value[i]` is the decoding result for the i-th
|
394 |
+
utterance in the given batch.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
params:
|
398 |
+
It's the return value of :func:`get_params`.
|
399 |
+
|
400 |
+
- params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
|
401 |
+
- params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
|
402 |
+
- params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
|
403 |
+
- params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
|
404 |
+
rescoring.
|
405 |
+
|
406 |
+
model:
|
407 |
+
The neural model.
|
408 |
+
HLG:
|
409 |
+
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
410 |
+
H:
|
411 |
+
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
412 |
+
bpe_model:
|
413 |
+
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
414 |
+
batch:
|
415 |
+
It is the return value from iterating
|
416 |
+
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
417 |
+
for the format of the `batch`.
|
418 |
+
word_table:
|
419 |
+
The word symbol table.
|
420 |
+
sos_id:
|
421 |
+
The token ID of the SOS.
|
422 |
+
eos_id:
|
423 |
+
The token ID of the EOS.
|
424 |
+
G:
|
425 |
+
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
426 |
+
or "whole-lattice-rescoring". In general, the G in HLG
|
427 |
+
is a 3-gram LM, while this G is a 4-gram LM.
|
428 |
+
Returns:
|
429 |
+
Return the decoding result. See above description for the format of
|
430 |
+
the returned dict. Note: If it decodes to nothing, then return None.
|
431 |
+
"""
|
432 |
+
if HLG is not None:
|
433 |
+
device = HLG.device
|
434 |
+
else:
|
435 |
+
device = H.device
|
436 |
+
feature = batch["inputs"]
|
437 |
+
assert feature.ndim == 3
|
438 |
+
feature = feature.to(device)
|
439 |
+
# at entry, feature is (N, T, C)
|
440 |
+
|
441 |
+
supervisions = batch["supervisions"]
|
442 |
+
feature_lens = supervisions["num_frames"].to(device)
|
443 |
+
|
444 |
+
if params.simulate_streaming:
|
445 |
+
feature_lens += params.left_context
|
446 |
+
feature = torch.nn.functional.pad(
|
447 |
+
feature,
|
448 |
+
pad=(0, 0, 0, params.left_context),
|
449 |
+
value=LOG_EPS,
|
450 |
+
)
|
451 |
+
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
452 |
+
x=feature,
|
453 |
+
x_lens=feature_lens,
|
454 |
+
chunk_size=params.decode_chunk_size,
|
455 |
+
left_context=params.left_context,
|
456 |
+
simulate_streaming=True,
|
457 |
+
)
|
458 |
+
else:
|
459 |
+
encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
|
460 |
+
|
461 |
+
nnet_output = model.get_ctc_output(encoder_out)
|
462 |
+
# nnet_output is (N, T, C)
|
463 |
+
|
464 |
+
if params.decoding_method == "ctc-greedy-search":
|
465 |
+
timestamps, hyps = ctc_greedy_search(
|
466 |
+
ctc_probs=nnet_output,
|
467 |
+
nnet_output_lens=encoder_out_lens,
|
468 |
+
sp=bpe_model,
|
469 |
+
subsampling_factor=params.subsampling_factor,
|
470 |
+
frame_shift_ms=params.frame_shift_ms,
|
471 |
+
)
|
472 |
+
key = "ctc-greedy-search"
|
473 |
+
return {key: (hyps, timestamps)}
|
474 |
+
|
475 |
+
supervision_segments = torch.stack(
|
476 |
+
(
|
477 |
+
supervisions["sequence_idx"],
|
478 |
+
supervisions["start_frame"] // params.subsampling_factor,
|
479 |
+
encoder_out_lens.cpu(),
|
480 |
+
),
|
481 |
+
1,
|
482 |
+
).to(torch.int32)
|
483 |
+
|
484 |
+
if H is None:
|
485 |
+
assert HLG is not None
|
486 |
+
decoding_graph = HLG
|
487 |
+
else:
|
488 |
+
assert HLG is None
|
489 |
+
assert bpe_model is not None
|
490 |
+
decoding_graph = H
|
491 |
+
|
492 |
+
lattice = get_lattice(
|
493 |
+
nnet_output=nnet_output,
|
494 |
+
decoding_graph=decoding_graph,
|
495 |
+
supervision_segments=supervision_segments,
|
496 |
+
search_beam=params.search_beam,
|
497 |
+
output_beam=params.output_beam,
|
498 |
+
min_active_states=params.min_active_states,
|
499 |
+
max_active_states=params.max_active_states,
|
500 |
+
subsampling_factor=params.subsampling_factor,
|
501 |
+
)
|
502 |
+
|
503 |
+
if params.decoding_method == "ctc-decoding":
|
504 |
+
best_path = one_best_decoding(
|
505 |
+
lattice=lattice, use_double_scores=params.use_double_scores
|
506 |
+
)
|
507 |
+
timestamps, hyps = parse_fsa_timestamps_and_texts(
|
508 |
+
best_paths=best_path,
|
509 |
+
sp=bpe_model,
|
510 |
+
subsampling_factor=params.subsampling_factor,
|
511 |
+
frame_shift_ms=params.frame_shift_ms,
|
512 |
+
)
|
513 |
+
key = "ctc-decoding"
|
514 |
+
return {key: (hyps, timestamps)}
|
515 |
+
|
516 |
+
if params.decoding_method == "nbest-oracle":
|
517 |
+
# Note: You can also pass rescored lattices to it.
|
518 |
+
# We choose the HLG decoded lattice for speed reasons
|
519 |
+
# as HLG decoding is faster and the oracle WER
|
520 |
+
# is only slightly worse than that of rescored lattices.
|
521 |
+
best_path = nbest_oracle(
|
522 |
+
lattice=lattice,
|
523 |
+
num_paths=params.num_paths,
|
524 |
+
ref_texts=supervisions["text"],
|
525 |
+
word_table=word_table,
|
526 |
+
nbest_scale=params.nbest_scale,
|
527 |
+
oov="<UNK>",
|
528 |
+
)
|
529 |
+
hyps = get_texts(best_path)
|
530 |
+
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
531 |
+
timestamps = [[] for _ in range(len(hyps))]
|
532 |
+
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa
|
533 |
+
return {key: (hyps, timestamps)}
|
534 |
+
|
535 |
+
if params.decoding_method in ["1best", "nbest"]:
|
536 |
+
if params.decoding_method == "1best":
|
537 |
+
best_path = one_best_decoding(
|
538 |
+
lattice=lattice, use_double_scores=params.use_double_scores
|
539 |
+
)
|
540 |
+
key = f"no_rescore_hlg_scale_{params.hlg_scale}"
|
541 |
+
timestamps, hyps = parse_fsa_timestamps_and_texts(
|
542 |
+
best_paths=best_path,
|
543 |
+
word_table=word_table,
|
544 |
+
subsampling_factor=params.subsampling_factor,
|
545 |
+
frame_shift_ms=params.frame_shift_ms,
|
546 |
+
)
|
547 |
+
else:
|
548 |
+
best_path = nbest_decoding(
|
549 |
+
lattice=lattice,
|
550 |
+
num_paths=params.num_paths,
|
551 |
+
use_double_scores=params.use_double_scores,
|
552 |
+
nbest_scale=params.nbest_scale,
|
553 |
+
)
|
554 |
+
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa
|
555 |
+
hyps = get_texts(best_path)
|
556 |
+
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
557 |
+
timestamps = [[] for _ in range(len(hyps))]
|
558 |
+
return {key: (hyps, timestamps)}
|
559 |
+
|
560 |
+
assert params.decoding_method in [
|
561 |
+
"nbest-rescoring",
|
562 |
+
"whole-lattice-rescoring",
|
563 |
+
]
|
564 |
+
|
565 |
+
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
566 |
+
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
567 |
+
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
568 |
+
|
569 |
+
if params.decoding_method == "nbest-rescoring":
|
570 |
+
best_path_dict = rescore_with_n_best_list(
|
571 |
+
lattice=lattice,
|
572 |
+
G=G,
|
573 |
+
num_paths=params.num_paths,
|
574 |
+
lm_scale_list=lm_scale_list,
|
575 |
+
nbest_scale=params.nbest_scale,
|
576 |
+
)
|
577 |
+
elif params.decoding_method == "whole-lattice-rescoring":
|
578 |
+
best_path_dict = rescore_with_whole_lattice(
|
579 |
+
lattice=lattice,
|
580 |
+
G_with_epsilon_loops=G,
|
581 |
+
lm_scale_list=lm_scale_list,
|
582 |
+
)
|
583 |
+
else:
|
584 |
+
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
585 |
+
|
586 |
+
ans = dict()
|
587 |
+
if best_path_dict is not None:
|
588 |
+
for lm_scale_str, best_path in best_path_dict.items():
|
589 |
+
hyps = get_texts(best_path)
|
590 |
+
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
591 |
+
timestamps = [[] for _ in range(len(hyps))]
|
592 |
+
ans[lm_scale_str] = (hyps, timestamps)
|
593 |
+
else:
|
594 |
+
ans = None
|
595 |
+
return ans
|
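# For reference (added comment): with --decoding-method ctc-decoding the dict
# returned above has the form {"ctc-decoding": (hyps, timestamps)}, where hyps[i]
# is the word list for the i-th utterance in the batch and timestamps[i] is its
# list of (start, end) pairs in seconds.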
596 |
+
|
597 |
+
|
598 |
+
def decode_dataset(
|
599 |
+
dl: torch.utils.data.DataLoader,
|
600 |
+
params: AttributeDict,
|
601 |
+
model: nn.Module,
|
602 |
+
HLG: Optional[k2.Fsa],
|
603 |
+
H: Optional[k2.Fsa],
|
604 |
+
bpe_model: Optional[spm.SentencePieceProcessor],
|
605 |
+
word_table: k2.SymbolTable,
|
606 |
+
sos_id: int,
|
607 |
+
eos_id: int,
|
608 |
+
G: Optional[k2.Fsa] = None,
|
609 |
+
) -> Dict[
|
610 |
+
str,
|
611 |
+
List[
|
612 |
+
Tuple[
|
613 |
+
str,
|
614 |
+
List[str],
|
615 |
+
List[str],
|
616 |
+
List[Tuple[float, float]],
|
617 |
+
List[Tuple[float, float]],
|
618 |
+
]
|
619 |
+
],
|
620 |
+
]:
|
621 |
+
"""Decode dataset.
|
622 |
+
|
623 |
+
Args:
|
624 |
+
dl:
|
625 |
+
PyTorch's dataloader containing the dataset to decode.
|
626 |
+
params:
|
627 |
+
It is returned by :func:`get_params`.
|
628 |
+
model:
|
629 |
+
The neural model.
|
630 |
+
HLG:
|
631 |
+
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
632 |
+
H:
|
633 |
+
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
634 |
+
bpe_model:
|
635 |
+
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
636 |
+
word_table:
|
637 |
+
It is the word symbol table.
|
638 |
+
sos_id:
|
639 |
+
The token ID for SOS.
|
640 |
+
eos_id:
|
641 |
+
The token ID for EOS.
|
642 |
+
G:
|
643 |
+
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
644 |
+
or "whole-lattice-rescoring". In general, the G in HLG
|
645 |
+
is a 3-gram LM, while this G is a 4-gram LM.
|
646 |
+
Returns:
|
647 |
+
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
648 |
+
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
649 |
+
Its value is a list of tuples. Each tuple contains five elements:
|
650 |
+
the cut ID, the reference words, the predicted words, and the
|
651 |
+
reference and predicted (start, end) timestamps.
|
652 |
+
"""
|
653 |
+
num_cuts = 0
|
654 |
+
|
655 |
+
try:
|
656 |
+
num_batches = len(dl)
|
657 |
+
except TypeError:
|
658 |
+
num_batches = "?"
|
659 |
+
|
660 |
+
results = defaultdict(list)
|
661 |
+
for batch_idx, batch in enumerate(dl):
|
662 |
+
texts = batch["supervisions"]["text"]
|
663 |
+
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
664 |
+
|
665 |
+
timestamps_ref = []
|
666 |
+
for cut in batch["supervisions"]["cut"]:
|
667 |
+
for s in cut.supervisions:
|
668 |
+
time = []
|
669 |
+
if s.alignment is not None and "word" in s.alignment:
|
670 |
+
time = [
|
671 |
+
(aliword.start, aliword.end)
|
672 |
+
for aliword in s.alignment["word"]
|
673 |
+
if aliword.symbol != ""
|
674 |
+
]
|
675 |
+
timestamps_ref.append(time)
|
676 |
+
|
677 |
+
hyps_dict = decode_one_batch(
|
678 |
+
params=params,
|
679 |
+
model=model,
|
680 |
+
HLG=HLG,
|
681 |
+
H=H,
|
682 |
+
bpe_model=bpe_model,
|
683 |
+
batch=batch,
|
684 |
+
word_table=word_table,
|
685 |
+
G=G,
|
686 |
+
sos_id=sos_id,
|
687 |
+
eos_id=eos_id,
|
688 |
+
)
|
689 |
+
|
690 |
+
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
691 |
+
this_batch = []
|
692 |
+
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
|
693 |
+
timestamps_ref
|
694 |
+
)
|
695 |
+
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
|
696 |
+
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
|
697 |
+
):
|
698 |
+
ref_words = ref_text.split()
|
699 |
+
this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))
|
700 |
+
|
701 |
+
results[name].extend(this_batch)
|
702 |
+
|
703 |
+
num_cuts += len(texts)
|
704 |
+
|
705 |
+
if batch_idx % 100 == 0:
|
706 |
+
batch_str = f"{batch_idx}/{num_batches}"
|
707 |
+
|
708 |
+
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
709 |
+
return results
|
710 |
+
|
711 |
+
|
712 |
+
def save_results(
|
713 |
+
params: AttributeDict,
|
714 |
+
test_set_name: str,
|
715 |
+
results_dict: Dict[
|
716 |
+
str,
|
717 |
+
List[
|
718 |
+
Tuple[
|
719 |
+
List[str],
|
720 |
+
List[str],
|
721 |
+
List[str],
|
722 |
+
List[Tuple[float, float]],
|
723 |
+
List[Tuple[float, float]],
|
724 |
+
]
|
725 |
+
],
|
726 |
+
],
|
727 |
+
):
|
728 |
+
test_set_wers = dict()
|
729 |
+
test_set_delays = dict()
|
730 |
+
for key, results in results_dict.items():
|
731 |
+
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
732 |
+
results = sorted(results)
|
733 |
+
store_transcripts_and_timestamps(filename=recog_path, texts=results)
|
734 |
+
logging.info(f"The transcripts are stored in {recog_path}")
|
735 |
+
|
736 |
+
# The following prints out WERs, per-word error statistics and aligned
|
737 |
+
# ref/hyp pairs.
|
738 |
+
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
739 |
+
with open(errs_filename, "w") as f:
|
740 |
+
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
|
741 |
+
f,
|
742 |
+
f"{test_set_name}-{key}",
|
743 |
+
results,
|
744 |
+
enable_log=True,
|
745 |
+
with_end_time=True,
|
746 |
+
)
|
747 |
+
test_set_wers[key] = wer
|
748 |
+
test_set_delays[key] = (mean_delay, var_delay)
|
749 |
+
|
750 |
+
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
751 |
+
|
752 |
+
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
753 |
+
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
754 |
+
with open(errs_info, "w") as f:
|
755 |
+
print("settings\tWER", file=f)
|
756 |
+
for key, val in test_set_wers:
|
757 |
+
print("{}\t{}".format(key, val), file=f)
|
758 |
+
|
759 |
+
# sort according to the mean start symbol delay
|
760 |
+
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
|
761 |
+
delays_info = (
|
762 |
+
params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
|
763 |
+
)
|
764 |
+
with open(delays_info, "w") as f:
|
765 |
+
print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
|
766 |
+
for key, val in test_set_delays:
|
767 |
+
print(
|
768 |
+
"{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
|
769 |
+
file=f,
|
770 |
+
)
|
771 |
+
|
772 |
+
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
773 |
+
note = "\tbest for {}".format(test_set_name)
|
774 |
+
for key, val in test_set_wers:
|
775 |
+
s += "{}\t{}{}\n".format(key, val, note)
|
776 |
+
note = ""
|
777 |
+
logging.info(s)
|
778 |
+
|
779 |
+
s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
|
780 |
+
test_set_name
|
781 |
+
)
|
782 |
+
note = "\tbest for {}".format(test_set_name)
|
783 |
+
for key, val in test_set_delays:
|
784 |
+
s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
|
785 |
+
note = ""
|
786 |
+
logging.info(s)
|
787 |
+
|
788 |
+
|
789 |
+
@torch.no_grad()
|
790 |
+
def main():
|
791 |
+
parser = get_parser()
|
792 |
+
LibriSpeechAsrDataModule.add_arguments(parser)
|
793 |
+
args = parser.parse_args()
|
794 |
+
args.exp_dir = Path(args.exp_dir)
|
795 |
+
args.lang_dir = Path(args.lang_dir)
|
796 |
+
args.lm_dir = Path(args.lm_dir)
|
797 |
+
|
798 |
+
params = get_params()
|
799 |
+
# add decoding params
|
800 |
+
params.update(get_decoding_params())
|
801 |
+
params.update(vars(args))
|
802 |
+
|
803 |
+
assert params.decoding_method in (
|
804 |
+
"ctc-greedy-search",
|
805 |
+
"ctc-decoding",
|
806 |
+
"1best",
|
807 |
+
"nbest",
|
808 |
+
"nbest-rescoring",
|
809 |
+
"whole-lattice-rescoring",
|
810 |
+
"nbest-oracle",
|
811 |
+
)
|
812 |
+
params.res_dir = params.exp_dir / params.decoding_method
|
813 |
+
|
814 |
+
if params.iter > 0:
|
815 |
+
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
816 |
+
else:
|
817 |
+
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
818 |
+
|
819 |
+
if params.simulate_streaming:
|
820 |
+
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
821 |
+
params.suffix += f"-left-context-{params.left_context}"
|
822 |
+
|
823 |
+
if params.simulate_streaming:
|
824 |
+
assert (
|
825 |
+
params.causal_convolution
|
826 |
+
), "Decoding in streaming requires causal convolution"
|
827 |
+
|
828 |
+
if params.use_averaged_model:
|
829 |
+
params.suffix += "-use-averaged-model"
|
830 |
+
|
831 |
+
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
832 |
+
logging.info("Decoding started")
|
833 |
+
|
834 |
+
device = torch.device("cpu")
|
835 |
+
if torch.cuda.is_available():
|
836 |
+
device = torch.device("cuda", 0)
|
837 |
+
|
838 |
+
logging.info(f"Device: {device}")
|
839 |
+
logging.info(params)
|
840 |
+
|
841 |
+
lexicon = Lexicon(params.lang_dir)
|
842 |
+
max_token_id = max(lexicon.tokens)
|
843 |
+
num_classes = max_token_id + 1 # +1 for the blank
|
844 |
+
|
845 |
+
graph_compiler = BpeCtcTrainingGraphCompiler(
|
846 |
+
params.lang_dir,
|
847 |
+
device=device,
|
848 |
+
sos_token="<sos/eos>",
|
849 |
+
eos_token="<sos/eos>",
|
850 |
+
)
|
851 |
+
sos_id = graph_compiler.sos_id
|
852 |
+
eos_id = graph_compiler.eos_id
|
853 |
+
|
854 |
+
params.vocab_size = num_classes
|
855 |
+
params.sos_id = sos_id
|
856 |
+
params.eos_id = eos_id
|
857 |
+
|
858 |
+
if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
|
859 |
+
HLG = None
|
860 |
+
H = k2.ctc_topo(
|
861 |
+
max_token=max_token_id,
|
862 |
+
modified=False,
|
863 |
+
device=device,
|
864 |
+
)
|
865 |
+
bpe_model = spm.SentencePieceProcessor()
|
866 |
+
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
867 |
+
else:
|
868 |
+
H = None
|
869 |
+
bpe_model = None
|
870 |
+
HLG = k2.Fsa.from_dict(
|
871 |
+
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
872 |
+
)
|
873 |
+
assert HLG.requires_grad is False
|
874 |
+
|
875 |
+
HLG.scores *= params.hlg_scale
|
876 |
+
if not hasattr(HLG, "lm_scores"):
|
877 |
+
HLG.lm_scores = HLG.scores.clone()
|
878 |
+
|
879 |
+
if params.decoding_method in (
|
880 |
+
"nbest-rescoring",
|
881 |
+
"whole-lattice-rescoring",
|
882 |
+
):
|
883 |
+
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
884 |
+
logging.info("Loading G_4_gram.fst.txt")
|
885 |
+
logging.warning("It may take 8 minutes.")
|
886 |
+
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
887 |
+
first_word_disambig_id = lexicon.word_table["#0"]
|
888 |
+
|
889 |
+
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
890 |
+
# G.aux_labels is not needed in later computations, so
|
891 |
+
# remove it here.
|
892 |
+
del G.aux_labels
|
893 |
+
# CAUTION: The following line is crucial.
|
894 |
+
# Arcs entering the back-off state have label equal to #0.
|
895 |
+
# We have to change it to 0 here.
|
896 |
+
G.labels[G.labels >= first_word_disambig_id] = 0
|
897 |
+
# See https://github.com/k2-fsa/k2/issues/874
|
898 |
+
# for why we need to set G.properties to None
|
899 |
+
G.__dict__["_properties"] = None
|
900 |
+
G = k2.Fsa.from_fsas([G]).to(device)
|
901 |
+
G = k2.arc_sort(G)
|
902 |
+
# Save a dummy value so that it can be loaded in C++.
|
903 |
+
# See https://github.com/pytorch/pytorch/issues/67902
|
904 |
+
# for why we need to do this.
|
905 |
+
G.dummy = 1
|
906 |
+
|
907 |
+
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
908 |
+
else:
|
909 |
+
logging.info("Loading pre-compiled G_4_gram.pt")
|
910 |
+
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
911 |
+
G = k2.Fsa.from_dict(d)
|
912 |
+
|
913 |
+
if params.decoding_method == "whole-lattice-rescoring":
|
914 |
+
# Add epsilon self-loops to G as we will compose
|
915 |
+
# it with the whole lattice later
|
916 |
+
G = k2.add_epsilon_self_loops(G)
|
917 |
+
G = k2.arc_sort(G)
|
918 |
+
G = G.to(device)
|
919 |
+
|
920 |
+
# G.lm_scores is used to replace HLG.lm_scores during
|
921 |
+
# LM rescoring.
|
922 |
+
G.lm_scores = G.scores.clone()
|
923 |
+
else:
|
924 |
+
G = None
|
925 |
+
|
926 |
+
logging.info("About to create model")
|
927 |
+
model = get_ctc_model(params)
|
928 |
+
|
929 |
+
if not params.use_averaged_model:
|
930 |
+
if params.iter > 0:
|
931 |
+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
932 |
+
: params.avg
|
933 |
+
]
|
934 |
+
if len(filenames) == 0:
|
935 |
+
raise ValueError(
|
936 |
+
f"No checkpoints found for"
|
937 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
938 |
+
)
|
939 |
+
elif len(filenames) < params.avg:
|
940 |
+
raise ValueError(
|
941 |
+
f"Not enough checkpoints ({len(filenames)}) found for"
|
942 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
943 |
+
)
|
944 |
+
logging.info(f"averaging {filenames}")
|
945 |
+
model.to(device)
|
946 |
+
model.load_state_dict(average_checkpoints(filenames, device=device))
|
947 |
+
elif params.avg == 1:
|
948 |
+
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
949 |
+
else:
|
950 |
+
start = params.epoch - params.avg + 1
|
951 |
+
filenames = []
|
952 |
+
for i in range(start, params.epoch + 1):
|
953 |
+
if i >= 1:
|
954 |
+
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
955 |
+
logging.info(f"averaging {filenames}")
|
956 |
+
model.to(device)
|
957 |
+
model.load_state_dict(average_checkpoints(filenames, device=device))
|
958 |
+
else:
|
959 |
+
if params.iter > 0:
|
960 |
+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
961 |
+
: params.avg + 1
|
962 |
+
]
|
963 |
+
if len(filenames) == 0:
|
964 |
+
raise ValueError(
|
965 |
+
f"No checkpoints found for"
|
966 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
967 |
+
)
|
968 |
+
elif len(filenames) < params.avg + 1:
|
969 |
+
raise ValueError(
|
970 |
+
f"Not enough checkpoints ({len(filenames)}) found for"
|
971 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
972 |
+
)
|
973 |
+
filename_start = filenames[-1]
|
974 |
+
filename_end = filenames[0]
|
975 |
+
logging.info(
|
976 |
+
"Calculating the averaged model over iteration checkpoints"
|
977 |
+
f" from {filename_start} (excluded) to {filename_end}"
|
978 |
+
)
|
979 |
+
model.to(device)
|
980 |
+
model.load_state_dict(
|
981 |
+
average_checkpoints_with_averaged_model(
|
982 |
+
filename_start=filename_start,
|
983 |
+
filename_end=filename_end,
|
984 |
+
device=device,
|
985 |
+
)
|
986 |
+
)
|
987 |
+
else:
|
988 |
+
assert params.avg > 0, params.avg
|
989 |
+
start = params.epoch - params.avg
|
990 |
+
assert start >= 1, start
|
991 |
+
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
992 |
+
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
993 |
+
logging.info(
|
994 |
+
f"Calculating the averaged model over epoch range from "
|
995 |
+
f"{start} (excluded) to {params.epoch}"
|
996 |
+
)
|
997 |
+
model.to(device)
|
998 |
+
model.load_state_dict(
|
999 |
+
average_checkpoints_with_averaged_model(
|
1000 |
+
filename_start=filename_start,
|
1001 |
+
filename_end=filename_end,
|
1002 |
+
device=device,
|
1003 |
+
)
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
model.to(device)
|
1007 |
+
model.eval()
|
1008 |
+
|
1009 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
1010 |
+
logging.info(f"Number of model parameters: {num_param}")
|
1011 |
+
|
1012 |
+
# we need cut ids to display recognition results.
|
1013 |
+
args.return_cuts = True
|
1014 |
+
librispeech = LibriSpeechAsrDataModule(args)
|
1015 |
+
|
1016 |
+
test_clean_cuts = librispeech.test_clean_cuts()
|
1017 |
+
#test_other_cuts = librispeech.test_other_cuts()
|
1018 |
+
|
1019 |
+
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
1020 |
+
#test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
1021 |
+
|
1022 |
+
#test_sets = ["test-clean", "test-other"]
|
1023 |
+
#test_dl = [test_clean_dl, test_other_dl]
|
1024 |
+
|
1025 |
+
test_sets = ["test-clean"]
|
1026 |
+
test_dl = [test_clean_dl]
|
1027 |
+
|
1028 |
+
for test_set, test_dl in zip(test_sets, test_dl):
|
1029 |
+
results_dict = decode_dataset(
|
1030 |
+
dl=test_dl,
|
1031 |
+
params=params,
|
1032 |
+
model=model,
|
1033 |
+
HLG=HLG,
|
1034 |
+
H=H,
|
1035 |
+
bpe_model=bpe_model,
|
1036 |
+
word_table=lexicon.word_table,
|
1037 |
+
G=G,
|
1038 |
+
sos_id=sos_id,
|
1039 |
+
eos_id=eos_id,
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
save_results(
|
1043 |
+
params=params,
|
1044 |
+
test_set_name=test_set,
|
1045 |
+
results_dict=results_dict,
|
1046 |
+
)
|
1047 |
+
|
1048 |
+
logging.info("Done!")
|
1049 |
+
|
1050 |
+
|
1051 |
+
if __name__ == "__main__":
|
1052 |
+
main()
|
err2020/conformer_ctc3/encoder_interface.py
ADDED
@@ -0,0 +1,43 @@
|
1 |
+
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
2 |
+
#
|
3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from typing import Tuple
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
|
22 |
+
|
23 |
+
class EncoderInterface(nn.Module):
|
24 |
+
def forward(
|
25 |
+
self, x: torch.Tensor, x_lens: torch.Tensor
|
26 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
27 |
+
"""
|
28 |
+
Args:
|
29 |
+
x:
|
30 |
+
A tensor of shape (batch_size, input_seq_len, num_features)
|
31 |
+
containing the input features.
|
32 |
+
x_lens:
|
33 |
+
A tensor of shape (batch_size,) containing the number of frames
|
34 |
+
in `x` before padding.
|
35 |
+
Returns:
|
36 |
+
Return a tuple containing two tensors:
|
37 |
+
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
38 |
+
containing unnormalized probabilities, i.e., the output of a
|
39 |
+
linear layer.
|
40 |
+
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
41 |
+
the number of frames in `encoder_out` before padding.
|
42 |
+
"""
|
43 |
+
raise NotImplementedError("Please implement it in a subclass")
|
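# Illustrative only (this class is hypothetical and not part of the recipe): a
# concrete encoder implements forward() with the contract documented above, e.g.
# an identity encoder that simply passes the features through unchanged.
class _IdentityEncoder(EncoderInterface):
    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No subsampling and no projection: output equals input.
        return x, x_lens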
err2020/conformer_ctc3/exp/jit_trace.pt
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d364f175859dcfff11dcd5eea7032f568b77af7c9ffa15ff9f405c69983d58b
|
3 |
+
size 330828854
|
err2020/conformer_ctc3/export.py
ADDED
@@ -0,0 +1,292 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
#
|
3 |
+
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
4 |
+
#
|
5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
# This script converts several saved checkpoints
|
20 |
+
# to a single one using model averaging.
|
21 |
+
"""
|
22 |
+
Usage:
|
23 |
+
|
24 |
+
(1) Export to torchscript model using torch.jit.trace()
|
25 |
+
|
26 |
+
./conformer_ctc3/export.py \
|
27 |
+
--exp-dir ./conformer_ctc3/exp \
|
28 |
+
--lang-dir data/lang_bpe_500 \
|
29 |
+
--epoch 20 \
|
30 |
+
--avg 10 \
|
31 |
+
--jit-trace 1
|
32 |
+
|
33 |
+
It will generate the file `jit_trace.pt`.
|
34 |
+
|
35 |
+
(2) Export `model.state_dict()`
|
36 |
+
|
37 |
+
./conformer_ctc3/export.py \
|
38 |
+
--exp-dir ./conformer_ctc3/exp \
|
39 |
+
--lang-dir data/lang_bpe_500 \
|
40 |
+
--epoch 20 \
|
41 |
+
--avg 10
|
42 |
+
|
43 |
+
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
44 |
+
load it by `icefall.checkpoint.load_checkpoint()`.
|
45 |
+
|
46 |
+
To use the generated file with `conformer_ctc3/decode.py`,
|
47 |
+
you can do:
|
48 |
+
|
49 |
+
cd /path/to/exp_dir
|
50 |
+
ln -s pretrained.pt epoch-9999.pt
|
51 |
+
|
52 |
+
cd /path/to/egs/librispeech/ASR
|
53 |
+
./conformer_ctc3/decode.py \
|
54 |
+
--exp-dir ./conformer_ctc3/exp \
|
55 |
+
--epoch 9999 \
|
56 |
+
--avg 1 \
|
57 |
+
--max-duration 100 \
|
58 |
+
--lang-dir data/lang_bpe_500
|
59 |
+
"""
|
60 |
+
|
61 |
+
import argparse
|
62 |
+
import logging
|
63 |
+
from pathlib import Path
|
64 |
+
|
65 |
+
import torch
|
66 |
+
from scaling_converter import convert_scaled_to_non_scaled
|
67 |
+
from train import add_model_arguments, get_ctc_model, get_params
|
68 |
+
|
69 |
+
from icefall.checkpoint import (
|
70 |
+
average_checkpoints,
|
71 |
+
average_checkpoints_with_averaged_model,
|
72 |
+
find_checkpoints,
|
73 |
+
load_checkpoint,
|
74 |
+
)
|
75 |
+
from icefall.lexicon import Lexicon
|
76 |
+
from icefall.utils import str2bool
|
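# A minimal sketch (the helper name and default path are illustrative, not part of
# the upstream recipe) of how the model traced with `--jit-trace 1` can be loaded
# and run on dummy fbank features, mirroring what jit_pretrained.py does.
def _smoke_test_traced_model(filename: str = "conformer_ctc3/exp/jit_trace.pt"):
    traced = torch.jit.load(filename, map_location="cpu")
    x = torch.zeros(1, 100, 80, dtype=torch.float32)  # (batch, num_frames, feature_dim)
    x_lens = torch.tensor([100], dtype=torch.int64)  # valid frames per utterance
    nnet_output, nnet_output_lens = traced(x, x_lens)
    return nnet_output, nnet_output_lens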
77 |
+
|
78 |
+
|
79 |
+
def get_parser():
|
80 |
+
parser = argparse.ArgumentParser(
|
81 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
82 |
+
)
|
83 |
+
|
84 |
+
parser.add_argument(
|
85 |
+
"--epoch",
|
86 |
+
type=int,
|
87 |
+
default=28,
|
88 |
+
help="""It specifies the checkpoint to use for averaging.
|
89 |
+
Note: Epoch counts from 0.
|
90 |
+
You can specify --avg to use more checkpoints for model averaging.""",
|
91 |
+
)
|
92 |
+
|
93 |
+
parser.add_argument(
|
94 |
+
"--iter",
|
95 |
+
type=int,
|
96 |
+
default=0,
|
97 |
+
help="""If positive, --epoch is ignored and it
|
98 |
+
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
99 |
+
You can specify --avg to use more checkpoints for model averaging.
|
100 |
+
""",
|
101 |
+
)
|
102 |
+
|
103 |
+
parser.add_argument(
|
104 |
+
"--avg",
|
105 |
+
type=int,
|
106 |
+
default=15,
|
107 |
+
help="Number of checkpoints to average. Automatically select "
|
108 |
+
"consecutive checkpoints before the checkpoint specified by "
|
109 |
+
"'--epoch' and '--iter'",
|
110 |
+
)
|
111 |
+
|
112 |
+
parser.add_argument(
|
113 |
+
"--use-averaged-model",
|
114 |
+
type=str2bool,
|
115 |
+
default=True,
|
116 |
+
help="Whether to load averaged model. Currently it only supports "
|
117 |
+
"using --epoch. If True, it would decode with the averaged model "
|
118 |
+
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
119 |
+
"Actually only the models with epoch number of `epoch-avg` and "
|
120 |
+
"`epoch` are loaded for averaging. ",
|
121 |
+
)
|
122 |
+
|
123 |
+
parser.add_argument(
|
124 |
+
"--exp-dir",
|
125 |
+
type=str,
|
126 |
+
default="pruned_transducer_stateless4/exp",
|
127 |
+
help="""It specifies the directory where all training related
|
128 |
+
files, e.g., checkpoints, log, etc, are saved
|
129 |
+
""",
|
130 |
+
)
|
131 |
+
|
132 |
+
parser.add_argument(
|
133 |
+
"--lang-dir",
|
134 |
+
type=Path,
|
135 |
+
default="data/lang_bpe_500",
|
136 |
+
help="The lang dir containing word table and LG graph",
|
137 |
+
)
|
138 |
+
|
139 |
+
parser.add_argument(
|
140 |
+
"--jit-trace",
|
141 |
+
type=str2bool,
|
142 |
+
default=False,
|
143 |
+
help="""True to save a model after applying torch.jit.script.
|
144 |
+
""",
|
145 |
+
)
|
146 |
+
|
147 |
+
parser.add_argument(
|
148 |
+
"--streaming-model",
|
149 |
+
type=str2bool,
|
150 |
+
default=False,
|
151 |
+
help="""Whether to export a streaming model, if the models in exp-dir
|
152 |
+
are streaming models, this should be True.
|
153 |
+
""",
|
154 |
+
)
|
155 |
+
|
156 |
+
add_model_arguments(parser)
|
157 |
+
|
158 |
+
return parser
|
159 |
+
|
160 |
+
|
161 |
+
def main():
|
162 |
+
args = get_parser().parse_args()
|
163 |
+
args.exp_dir = Path(args.exp_dir)
|
164 |
+
|
165 |
+
params = get_params()
|
166 |
+
params.update(vars(args))
|
167 |
+
|
168 |
+
device = torch.device("cpu")
|
169 |
+
if torch.cuda.is_available():
|
170 |
+
device = torch.device("cuda", 0)
|
171 |
+
|
172 |
+
logging.info(f"device: {device}")
|
173 |
+
|
174 |
+
lexicon = Lexicon(params.lang_dir)
|
175 |
+
max_token_id = max(lexicon.tokens)
|
176 |
+
num_classes = max_token_id + 1 # +1 for the blank
|
177 |
+
params.vocab_size = num_classes
|
178 |
+
|
179 |
+
if params.streaming_model:
|
180 |
+
assert params.causal_convolution
|
181 |
+
|
182 |
+
logging.info(params)
|
183 |
+
|
184 |
+
logging.info("About to create model")
|
185 |
+
model = get_ctc_model(params)
|
186 |
+
|
187 |
+
model.to(device)
|
188 |
+
|
189 |
+
if not params.use_averaged_model:
|
190 |
+
if params.iter > 0:
|
191 |
+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
192 |
+
: params.avg
|
193 |
+
]
|
194 |
+
if len(filenames) == 0:
|
195 |
+
raise ValueError(
|
196 |
+
f"No checkpoints found for"
|
197 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
198 |
+
)
|
199 |
+
elif len(filenames) < params.avg:
|
200 |
+
raise ValueError(
|
201 |
+
f"Not enough checkpoints ({len(filenames)}) found for"
|
202 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
203 |
+
)
|
204 |
+
logging.info(f"averaging {filenames}")
|
205 |
+
model.load_state_dict(average_checkpoints(filenames, device=device))
|
206 |
+
elif params.avg == 1:
|
207 |
+
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
208 |
+
else:
|
209 |
+
start = params.epoch - params.avg + 1
|
210 |
+
filenames = []
|
211 |
+
for i in range(start, params.epoch + 1):
|
212 |
+
if i >= 1:
|
213 |
+
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
214 |
+
logging.info(f"averaging {filenames}")
|
215 |
+
model.load_state_dict(average_checkpoints(filenames, device=device))
|
216 |
+
else:
|
217 |
+
if params.iter > 0:
|
218 |
+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
219 |
+
: params.avg + 1
|
220 |
+
]
|
221 |
+
if len(filenames) == 0:
|
222 |
+
raise ValueError(
|
223 |
+
f"No checkpoints found for"
|
224 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
225 |
+
)
|
226 |
+
elif len(filenames) < params.avg + 1:
|
227 |
+
raise ValueError(
|
228 |
+
f"Not enough checkpoints ({len(filenames)}) found for"
|
229 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
230 |
+
)
|
231 |
+
filename_start = filenames[-1]
|
232 |
+
filename_end = filenames[0]
|
233 |
+
logging.info(
|
234 |
+
"Calculating the averaged model over iteration checkpoints"
|
235 |
+
f" from {filename_start} (excluded) to {filename_end}"
|
236 |
+
)
|
237 |
+
model.load_state_dict(
|
238 |
+
average_checkpoints_with_averaged_model(
|
239 |
+
filename_start=filename_start,
|
240 |
+
filename_end=filename_end,
|
241 |
+
device=device,
|
242 |
+
)
|
243 |
+
)
|
244 |
+
else:
|
245 |
+
assert params.avg > 0, params.avg
|
246 |
+
start = params.epoch - params.avg
|
247 |
+
assert start >= 1, start
|
248 |
+
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
249 |
+
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
250 |
+
logging.info(
|
251 |
+
f"Calculating the averaged model over epoch range from "
|
252 |
+
f"{start} (excluded) to {params.epoch}"
|
253 |
+
)
|
254 |
+
model.load_state_dict(
|
255 |
+
average_checkpoints_with_averaged_model(
|
256 |
+
filename_start=filename_start,
|
257 |
+
filename_end=filename_end,
|
258 |
+
device=device,
|
259 |
+
)
|
260 |
+
)
|
261 |
+
|
262 |
+
model.to("cpu")
|
263 |
+
model.eval()
|
264 |
+
|
265 |
+
if params.jit_trace:
|
266 |
+
# TODO: will support streaming mode
|
267 |
+
assert not params.streaming_model
|
268 |
+
convert_scaled_to_non_scaled(model, inplace=True)
|
269 |
+
|
270 |
+
logging.info("Using torch.jit.trace()")
|
271 |
+
|
272 |
+
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
273 |
+
x_lens = torch.tensor([100], dtype=torch.int64)
|
274 |
+
traced_model = torch.jit.trace(model, (x, x_lens))
|
275 |
+
|
276 |
+
filename = params.exp_dir / "jit_trace.pt"
|
277 |
+
traced_model.save(str(filename))
|
278 |
+
logging.info(f"Saved to {filename}")
|
279 |
+
else:
|
280 |
+
logging.info("Not using torch.jit.trace()")
|
281 |
+
# Save it using a format so that it can be loaded
|
282 |
+
# by :func:`load_checkpoint`
|
283 |
+
filename = params.exp_dir / "pretrained.pt"
|
284 |
+
torch.save({"model": model.state_dict()}, str(filename))
|
285 |
+
logging.info(f"Saved to {filename}")
|
286 |
+
|
287 |
+
|
288 |
+
if __name__ == "__main__":
|
289 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
290 |
+
|
291 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
292 |
+
main()
|
err2020/conformer_ctc3/jit_pretrained.py
ADDED
@@ -0,0 +1,413 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
3 |
+
# Mingshuang Luo,)
|
4 |
+
# Zengwei Yao)
|
5 |
+
#
|
6 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
|
20 |
+
|
21 |
+
"""
|
22 |
+
Usage (for non-streaming mode):
|
23 |
+
|
24 |
+
(1) ctc-decoding
|
25 |
+
./conformer_ctc3/jit_pretrained.py \
|
26 |
+
--model-filename ./conformer_ctc3/exp/jit_trace.pt \
|
27 |
+
--bpe-model data/lang_bpe_500/bpe.model \
|
28 |
+
--method ctc-decoding \
|
29 |
+
--sample-rate 16000 \
|
30 |
+
/path/to/foo.wav \
|
31 |
+
/path/to/bar.wav
|
32 |
+
|
33 |
+
(2) 1best
|
34 |
+
./conformer_ctc3/jit_pretrained.py \
|
35 |
+
--model-filename ./conformer_ctc3/exp/jit_trace.pt \
|
36 |
+
--HLG data/lang_bpe_500/HLG.pt \
|
37 |
+
--words-file data/lang_bpe_500/words.txt \
|
38 |
+
--method 1best \
|
39 |
+
--sample-rate 16000 \
|
40 |
+
/path/to/foo.wav \
|
41 |
+
/path/to/bar.wav
|
42 |
+
|
43 |
+
(3) nbest-rescoring
|
44 |
+
./conformer_ctc3/jit_pretrained.py \
|
45 |
+
--model-filename ./conformer_ctc3/exp/jit_trace.pt \
|
46 |
+
--HLG data/lang_bpe_500/HLG.pt \
|
47 |
+
--words-file data/lang_bpe_500/words.txt \
|
48 |
+
--G data/lm/G_4_gram.pt \
|
49 |
+
--method nbest-rescoring \
|
50 |
+
--sample-rate 16000 \
|
51 |
+
/path/to/foo.wav \
|
52 |
+
/path/to/bar.wav
|
53 |
+
|
54 |
+
(4) whole-lattice-rescoring
|
55 |
+
./conformer_ctc3/jit_pretrained.py \
|
56 |
+
--model-filename ./conformer_ctc3/exp/jit_trace.pt \
|
57 |
+
--HLG data/lang_bpe_500/HLG.pt \
|
58 |
+
--words-file data/lang_bpe_500/words.txt \
|
59 |
+
--G data/lm/G_4_gram.pt \
|
60 |
+
--method whole-lattice-rescoring \
|
61 |
+
--sample-rate 16000 \
|
62 |
+
/path/to/foo.wav \
|
63 |
+
/path/to/bar.wav
|
64 |
+
"""
|
65 |
+
|
66 |
+
|
67 |
+
import argparse
|
68 |
+
import logging
|
69 |
+
import math
|
70 |
+
from typing import List
|
71 |
+
|
72 |
+
import k2
|
73 |
+
import kaldifeat
|
74 |
+
import sentencepiece as spm
|
75 |
+
import torch
|
76 |
+
import torchaudio
|
77 |
+
from decode import get_decoding_params
|
78 |
+
from torch.nn.utils.rnn import pad_sequence
|
79 |
+
from train import add_model_arguments, get_params
|
80 |
+
|
81 |
+
from icefall.decode import (
|
82 |
+
get_lattice,
|
83 |
+
one_best_decoding,
|
84 |
+
rescore_with_n_best_list,
|
85 |
+
rescore_with_whole_lattice,
|
86 |
+
)
|
87 |
+
from icefall.utils import get_texts
|
88 |
+
|
89 |
+
|
90 |
+
def get_parser():
|
91 |
+
parser = argparse.ArgumentParser(
|
92 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
93 |
+
)
|
94 |
+
|
95 |
+
parser.add_argument(
|
96 |
+
"--model-filename",
|
97 |
+
type=str,
|
98 |
+
required=True,
|
99 |
+
help="Path to the torchscript model.",
|
100 |
+
)
|
101 |
+
|
102 |
+
parser.add_argument(
|
103 |
+
"--words-file",
|
104 |
+
type=str,
|
105 |
+
help="""Path to words.txt.
|
106 |
+
Used only when method is not ctc-decoding.
|
107 |
+
""",
|
108 |
+
)
|
109 |
+
|
110 |
+
parser.add_argument(
|
111 |
+
"--HLG",
|
112 |
+
type=str,
|
113 |
+
help="""Path to HLG.pt.
|
114 |
+
Used only when method is not ctc-decoding.
|
115 |
+
""",
|
116 |
+
)
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--bpe-model",
|
120 |
+
type=str,
|
121 |
+
help="""Path to bpe.model.
|
122 |
+
Used only when method is ctc-decoding.
|
123 |
+
""",
|
124 |
+
)
|
125 |
+
|
126 |
+
parser.add_argument(
|
127 |
+
"--method",
|
128 |
+
type=str,
|
129 |
+
default="1best",
|
130 |
+
help="""Decoding method.
|
131 |
+
Possible values are:
|
132 |
+
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
133 |
+
piece model, i.e., lang_dir/bpe.model, to convert
|
134 |
+
word pieces to words. It needs neither a lexicon
|
135 |
+
nor an n-gram LM.
|
136 |
+
(1) 1best - Use the best path as decoding output. Only
|
137 |
+
the transformer encoder output is used for decoding.
|
138 |
+
We call it HLG decoding.
|
139 |
+
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
140 |
+
rescore them with an LM, the path with
|
141 |
+
the highest score is the decoding result.
|
142 |
+
We call it HLG decoding + n-gram LM rescoring.
|
143 |
+
(3) whole-lattice-rescoring - Use an LM to rescore the
|
144 |
+
decoding lattice and then use 1best to decode the
|
145 |
+
rescored lattice.
|
146 |
+
We call it HLG decoding + n-gram LM rescoring.
|
147 |
+
""",
|
148 |
+
)
|
149 |
+
|
150 |
+
parser.add_argument(
|
151 |
+
"--G",
|
152 |
+
type=str,
|
153 |
+
help="""An LM for rescoring.
|
154 |
+
Used only when method is
|
155 |
+
whole-lattice-rescoring or nbest-rescoring.
|
156 |
+
It's usually a 4-gram LM.
|
157 |
+
""",
|
158 |
+
)
|
159 |
+
|
160 |
+
parser.add_argument(
|
161 |
+
"--num-paths",
|
162 |
+
type=int,
|
163 |
+
default=100,
|
164 |
+
help="""
|
165 |
+
Used only when method is nbest-rescoring.
|
166 |
+
It specifies the size of n-best list.""",
|
167 |
+
)
|
168 |
+
|
169 |
+
parser.add_argument(
|
170 |
+
"--ngram-lm-scale",
|
171 |
+
type=float,
|
172 |
+
default=1.3,
|
173 |
+
help="""
|
174 |
+
Used only when method is whole-lattice-rescoring or nbest-rescoring.
|
175 |
+
It specifies the scale for n-gram LM scores.
|
176 |
+
(Note: You need to tune it on a dataset.)
|
177 |
+
""",
|
178 |
+
)
|
179 |
+
|
180 |
+
parser.add_argument(
|
181 |
+
"--nbest-scale",
|
182 |
+
type=float,
|
183 |
+
default=0.5,
|
184 |
+
help="""
|
185 |
+
Used only when method is nbest-rescoring.
|
186 |
+
It specifies the scale for lattice.scores when
|
187 |
+
extracting n-best lists. A smaller value results in
|
188 |
+
more unique number of paths with the risk of missing
|
189 |
+
the best path.
|
190 |
+
""",
|
191 |
+
)
|
192 |
+
|
193 |
+
parser.add_argument(
|
194 |
+
"--num-classes",
|
195 |
+
type=int,
|
196 |
+
default=500,
|
197 |
+
help="""
|
198 |
+
Vocab size in the BPE model.
|
199 |
+
""",
|
200 |
+
)
|
201 |
+
|
202 |
+
parser.add_argument(
|
203 |
+
"--sample-rate",
|
204 |
+
type=int,
|
205 |
+
default=16000,
|
206 |
+
help="The sample rate of the input sound file",
|
207 |
+
)
|
208 |
+
|
209 |
+
parser.add_argument(
|
210 |
+
"sound_files",
|
211 |
+
type=str,
|
212 |
+
nargs="+",
|
213 |
+
help="The input sound file(s) to transcribe. "
|
214 |
+
"Supported formats are those supported by torchaudio.load(). "
|
215 |
+
"For example, wav and flac are supported. "
|
216 |
+
"The sample rate has to be 16kHz.",
|
217 |
+
)
|
218 |
+
|
219 |
+
add_model_arguments(parser)
|
220 |
+
|
221 |
+
return parser
|
222 |
+
|
223 |
+
|
224 |
+
def read_sound_files(
|
225 |
+
filenames: List[str], expected_sample_rate: float
|
226 |
+
) -> List[torch.Tensor]:
|
227 |
+
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
228 |
+
Args:
|
229 |
+
filenames:
|
230 |
+
A list of sound filenames.
|
231 |
+
expected_sample_rate:
|
232 |
+
The expected sample rate of the sound files.
|
233 |
+
Returns:
|
234 |
+
Return a list of 1-D float32 torch tensors.
|
235 |
+
"""
|
236 |
+
ans = []
|
237 |
+
for f in filenames:
|
238 |
+
wave, sample_rate = torchaudio.load(f)
|
239 |
+
assert sample_rate == expected_sample_rate, (
|
240 |
+
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
241 |
+
)
|
242 |
+
# We use only the first channel
|
243 |
+
ans.append(wave[0])
|
244 |
+
return ans
|
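# Usage sketch (added comment): read_sound_files(["/path/to/foo.wav"], 16000)
# returns one 1-D float32 tensor per file, keeping only the first channel.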
245 |
+
|
246 |
+
|
247 |
+
def main():
|
248 |
+
parser = get_parser()
|
249 |
+
args = parser.parse_args()
|
250 |
+
|
251 |
+
params = get_params()
|
252 |
+
# add decoding params
|
253 |
+
params.update(get_decoding_params())
|
254 |
+
params.update(vars(args))
|
255 |
+
params.vocab_size = params.num_classes
|
256 |
+
|
257 |
+
logging.info(f"{params}")
|
258 |
+
|
259 |
+
device = torch.device("cpu")
|
260 |
+
|
261 |
+
logging.info(f"device: {device}")
|
262 |
+
|
263 |
+
model = torch.jit.load(args.model_filename)
|
264 |
+
model.to(device)
|
265 |
+
model.eval()
|
266 |
+
|
267 |
+
logging.info("Constructing Fbank computer")
|
268 |
+
opts = kaldifeat.FbankOptions()
|
269 |
+
opts.device = device
|
270 |
+
opts.frame_opts.dither = 0
|
271 |
+
opts.frame_opts.snip_edges = False
|
272 |
+
opts.frame_opts.samp_freq = params.sample_rate
|
273 |
+
opts.mel_opts.num_bins = params.feature_dim
|
274 |
+
|
275 |
+
fbank = kaldifeat.Fbank(opts)
|
276 |
+
|
277 |
+
logging.info(f"Reading sound files: {params.sound_files}")
|
278 |
+
waves = read_sound_files(
|
279 |
+
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
280 |
+
)
|
281 |
+
waves = [w.to(device) for w in waves]
|
282 |
+
|
283 |
+
logging.info("Decoding started")
|
284 |
+
features = fbank(waves)
|
285 |
+
feature_lengths = [f.size(0) for f in features]
|
286 |
+
|
287 |
+
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
288 |
+
feature_lengths = torch.tensor(feature_lengths, device=device)
|
289 |
+
|
290 |
+
nnet_output, _ = model(features, feature_lengths)
|
291 |
+
|
292 |
+
batch_size = nnet_output.shape[0]
|
293 |
+
supervision_segments = torch.tensor(
|
294 |
+
[
|
295 |
+
[i, 0, feature_lengths[i] // params.subsampling_factor]
|
296 |
+
for i in range(batch_size)
|
297 |
+
],
|
298 |
+
dtype=torch.int32,
|
299 |
+
)
|
300 |
+
|
301 |
+
if params.method == "ctc-decoding":
|
302 |
+
logging.info("Use CTC decoding")
|
303 |
+
bpe_model = spm.SentencePieceProcessor()
|
304 |
+
bpe_model.load(params.bpe_model)
|
305 |
+
max_token_id = params.num_classes - 1
|
306 |
+
|
307 |
+
H = k2.ctc_topo(
|
308 |
+
max_token=max_token_id,
|
309 |
+
modified=False,
|
310 |
+
device=device,
|
311 |
+
)
|
312 |
+
|
313 |
+
lattice = get_lattice(
|
314 |
+
nnet_output=nnet_output,
|
315 |
+
decoding_graph=H,
|
316 |
+
supervision_segments=supervision_segments,
|
317 |
+
search_beam=params.search_beam,
|
318 |
+
output_beam=params.output_beam,
|
319 |
+
min_active_states=params.min_active_states,
|
320 |
+
max_active_states=params.max_active_states,
|
321 |
+
subsampling_factor=params.subsampling_factor,
|
322 |
+
)
|
323 |
+
|
324 |
+
best_path = one_best_decoding(
|
325 |
+
lattice=lattice, use_double_scores=params.use_double_scores
|
326 |
+
)
|
327 |
+
token_ids = get_texts(best_path)
|
328 |
+
hyps = bpe_model.decode(token_ids)
|
329 |
+
hyps = [s.split() for s in hyps]
|
330 |
+
elif params.method in [
|
331 |
+
"1best",
|
332 |
+
"nbest-rescoring",
|
333 |
+
"whole-lattice-rescoring",
|
334 |
+
]:
|
335 |
+
logging.info(f"Loading HLG from {params.HLG}")
|
336 |
+
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
337 |
+
HLG = HLG.to(device)
|
338 |
+
if not hasattr(HLG, "lm_scores"):
|
339 |
+
# For whole-lattice-rescoring and attention-decoder
|
340 |
+
HLG.lm_scores = HLG.scores.clone()
|
341 |
+
|
342 |
+
if params.method in [
|
343 |
+
"nbest-rescoring",
|
344 |
+
"whole-lattice-rescoring",
|
345 |
+
]:
|
346 |
+
logging.info(f"Loading G from {params.G}")
|
347 |
+
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
348 |
+
G = G.to(device)
|
349 |
+
if params.method == "whole-lattice-rescoring":
|
350 |
+
# Add epsilon self-loops to G as we will compose
|
351 |
+
# it with the whole lattice later
|
352 |
+
G = k2.add_epsilon_self_loops(G)
|
353 |
+
G = k2.arc_sort(G)
|
354 |
+
|
355 |
+
# G.lm_scores is used to replace HLG.lm_scores during
|
356 |
+
# LM rescoring.
|
357 |
+
G.lm_scores = G.scores.clone()
|
358 |
+
|
359 |
+
lattice = get_lattice(
|
360 |
+
nnet_output=nnet_output,
|
361 |
+
decoding_graph=HLG,
|
362 |
+
supervision_segments=supervision_segments,
|
363 |
+
search_beam=params.search_beam,
|
364 |
+
output_beam=params.output_beam,
|
365 |
+
min_active_states=params.min_active_states,
|
366 |
+
max_active_states=params.max_active_states,
|
367 |
+
subsampling_factor=params.subsampling_factor,
|
368 |
+
)
|
369 |
+
|
370 |
+
if params.method == "1best":
|
371 |
+
logging.info("Use HLG decoding")
|
372 |
+
best_path = one_best_decoding(
|
373 |
+
lattice=lattice, use_double_scores=params.use_double_scores
|
374 |
+
)
|
375 |
+
if params.method == "nbest-rescoring":
|
376 |
+
logging.info("Use HLG decoding + LM rescoring")
|
377 |
+
best_path_dict = rescore_with_n_best_list(
|
378 |
+
lattice=lattice,
|
379 |
+
G=G,
|
380 |
+
num_paths=params.num_paths,
|
381 |
+
lm_scale_list=[params.ngram_lm_scale],
|
382 |
+
nbest_scale=params.nbest_scale,
|
383 |
+
)
|
384 |
+
best_path = next(iter(best_path_dict.values()))
|
385 |
+
elif params.method == "whole-lattice-rescoring":
|
386 |
+
logging.info("Use HLG decoding + LM rescoring")
|
387 |
+
best_path_dict = rescore_with_whole_lattice(
|
388 |
+
lattice=lattice,
|
389 |
+
G_with_epsilon_loops=G,
|
390 |
+
lm_scale_list=[params.ngram_lm_scale],
|
391 |
+
)
|
392 |
+
best_path = next(iter(best_path_dict.values()))
|
393 |
+
|
394 |
+
hyps = get_texts(best_path)
|
395 |
+
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
396 |
+
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
397 |
+
else:
|
398 |
+
raise ValueError(f"Unsupported decoding method: {params.method}")
|
399 |
+
|
400 |
+
s = "\n"
|
401 |
+
for filename, hyp in zip(params.sound_files, hyps):
|
402 |
+
words = " ".join(hyp)
|
403 |
+
s += f"{filename}:\n{words}\n\n"
|
404 |
+
logging.info(s)
|
405 |
+
|
406 |
+
logging.info("Decoding Done")
|
407 |
+
|
408 |
+
|
409 |
+
if __name__ == "__main__":
|
410 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
411 |
+
|
412 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
413 |
+
main()
|
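The main() above boils down to three steps: compute fbank features with kaldifeat, run the jit-traced model, then decode its log-probs. A minimal sketch of just the feature-extraction and forward steps, assuming repo-relative paths and an 80-dim feature setup (both are assumptions, not taken from this file):

import math
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

# Assumed paths; point these at the traced model and any 16 kHz wav.
model = torch.jit.load("err2020/conformer_ctc3/exp/jit_trace.pt")
model.eval()

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80  # assumption: the feature_dim used in training
fbank = kaldifeat.Fbank(opts)

wave, sr = torchaudio.load("err2020/audio/emt16k.wav")
assert sr == 16000
features = fbank([wave[0]])  # list of (num_frames, 80) tensors
feature_lengths = torch.tensor([f.size(0) for f in features])
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))

with torch.no_grad():
    nnet_output, out_lens = model(features, feature_lengths)  # log-probs over BPE tokens

From there the lattice construction and one-best decoding proceed exactly as in main() above.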
err2020/conformer_ctc3/lstmp.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class LSTMP(nn.Module):
|
9 |
+
"""LSTM with projection.
|
10 |
+
|
11 |
+
PyTorch does not support exporting LSTM with projection to ONNX.
|
12 |
+
This class reimplements LSTM with projection using basic matrix-matrix
|
13 |
+
and matrix-vector operations. It is not intended for training.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, lstm: nn.LSTM):
|
17 |
+
"""
|
18 |
+
Args:
|
19 |
+
lstm:
|
20 |
+
LSTM with proj_size. We support only uni-directional,
|
21 |
+
1-layer LSTM with projection at present.
|
22 |
+
"""
|
23 |
+
super().__init__()
|
24 |
+
assert lstm.bidirectional is False, lstm.bidirectional
|
25 |
+
assert lstm.num_layers == 1, lstm.num_layers
|
26 |
+
assert 0 < lstm.proj_size < lstm.hidden_size, (
|
27 |
+
lstm.proj_size,
|
28 |
+
lstm.hidden_size,
|
29 |
+
)
|
30 |
+
|
31 |
+
assert lstm.batch_first is False, lstm.batch_first
|
32 |
+
|
33 |
+
state_dict = lstm.state_dict()
|
34 |
+
|
35 |
+
w_ih = state_dict["weight_ih_l0"]
|
36 |
+
w_hh = state_dict["weight_hh_l0"]
|
37 |
+
|
38 |
+
b_ih = state_dict["bias_ih_l0"]
|
39 |
+
b_hh = state_dict["bias_hh_l0"]
|
40 |
+
|
41 |
+
w_hr = state_dict["weight_hr_l0"]
|
42 |
+
self.input_size = lstm.input_size
|
43 |
+
self.proj_size = lstm.proj_size
|
44 |
+
self.hidden_size = lstm.hidden_size
|
45 |
+
|
46 |
+
self.w_ih = w_ih
|
47 |
+
self.w_hh = w_hh
|
48 |
+
self.b = b_ih + b_hh
|
49 |
+
self.w_hr = w_hr
|
50 |
+
|
51 |
+
def forward(
|
52 |
+
self,
|
53 |
+
input: torch.Tensor,
|
54 |
+
hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
55 |
+
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
input:
|
59 |
+
A tensor of shape [T, N, input_size]
|
60 |
+
hx:
|
61 |
+
A tuple containing:
|
62 |
+
- h0: a tensor of shape (1, N, proj_size)
|
63 |
+
- c0: a tensor of shape (1, N, hidden_size)
|
64 |
+
Returns:
|
65 |
+
Return a tuple containing:
|
66 |
+
- output: a tensor of shape (T, N, proj_size).
|
67 |
+
- A tuple containing:
|
68 |
+
- h: a tensor of shape (1, N, proj_size)
|
69 |
+
- c: a tensor of shape (1, N, hidden_size)
|
70 |
+
|
71 |
+
"""
|
72 |
+
x_list = input.unbind(dim=0) # We use batch_first=False
|
73 |
+
|
74 |
+
if hx is not None:
|
75 |
+
h0, c0 = hx
|
76 |
+
else:
|
77 |
+
h0 = torch.zeros(1, input.size(1), self.proj_size)
|
78 |
+
c0 = torch.zeros(1, input.size(1), self.hidden_size)
|
79 |
+
h0 = h0.squeeze(0)
|
80 |
+
c0 = c0.squeeze(0)
|
81 |
+
y_list = []
|
82 |
+
for x in x_list:
|
83 |
+
gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh)
|
84 |
+
i, f, g, o = gates.chunk(4, dim=1)
|
85 |
+
|
86 |
+
i = i.sigmoid()
|
87 |
+
f = f.sigmoid()
|
88 |
+
g = g.tanh()
|
89 |
+
o = o.sigmoid()
|
90 |
+
|
91 |
+
c = f * c0 + i * g
|
92 |
+
h = o * c.tanh()
|
93 |
+
|
94 |
+
h = F.linear(h, self.w_hr)
|
95 |
+
y_list.append(h)
|
96 |
+
|
97 |
+
c0 = c
|
98 |
+
h0 = h
|
99 |
+
|
100 |
+
y = torch.stack(y_list, dim=0)
|
101 |
+
|
102 |
+
return y, (h0.unsqueeze(0), c0.unsqueeze(0))
|
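Since LSTMP is meant to be a drop-in, inference-only re-implementation of a projected LSTM, a quick sanity check is to compare it against nn.LSTM with proj_size on random input. A minimal sketch (sizes are arbitrary; `from lstmp import LSTMP` assumes you run it from this directory):

import torch
import torch.nn as nn
from lstmp import LSTMP  # assumes the working directory is conformer_ctc3/

lstm = nn.LSTM(input_size=80, hidden_size=512, proj_size=256, num_layers=1)
lstm.eval()
lstmp = LSTMP(lstm)

x = torch.randn(10, 3, 80)  # (T, N, input_size); batch_first=False
with torch.no_grad():
    y_ref, (h_ref, c_ref) = lstm(x)
    y, (h, c) = lstmp(x)

# The two implementations should agree up to floating-point noise.
assert torch.allclose(y, y_ref, atol=1e-5)
assert torch.allclose(h, h_ref, atol=1e-5)
assert torch.allclose(c, c_ref, atol=1e-5)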
err2020/conformer_ctc3/model.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
2 |
+
# Wei Kang,
|
3 |
+
# Zengwei Yao)
|
4 |
+
#
|
5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
|
20 |
+
import math
|
21 |
+
from typing import Tuple
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
from encoder_interface import EncoderInterface
|
26 |
+
from scaling import ScaledLinear
|
27 |
+
|
28 |
+
|
29 |
+
class CTCModel(nn.Module):
|
30 |
+
"""It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
|
31 |
+
"Connectionist Temporal Classification: Labelling Unsegmented
|
32 |
+
Sequence Data with Recurrent Neural Networks"
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
encoder: EncoderInterface,
|
38 |
+
encoder_dim: int,
|
39 |
+
vocab_size: int,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
Args:
|
43 |
+
encoder:
|
44 |
+
It is the transcription network in the paper. It accepts
|
45 |
+
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
46 |
+
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
47 |
+
`logit_lens` of shape (N,).
|
48 |
+
encoder_dim:
|
49 |
+
The feature embedding dimension.
|
50 |
+
vocab_size:
|
51 |
+
The vocabulary size.
|
52 |
+
"""
|
53 |
+
super().__init__()
|
54 |
+
assert isinstance(encoder, EncoderInterface), type(encoder)
|
55 |
+
|
56 |
+
self.encoder = encoder
|
57 |
+
self.ctc_output_module = nn.Sequential(
|
58 |
+
nn.Dropout(p=0.1),
|
59 |
+
ScaledLinear(encoder_dim, vocab_size),
|
60 |
+
)
|
61 |
+
|
62 |
+
def get_ctc_output(
|
63 |
+
self,
|
64 |
+
encoder_out: torch.Tensor,
|
65 |
+
delay_penalty: float = 0.0,
|
66 |
+
blank_threshold: float = 0.99,
|
67 |
+
):
|
68 |
+
"""Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty.
|
69 |
+
We first split utterance into sub-utterances according to the
|
70 |
+
blank probs, and then add sawtooth-like "blank-bonus" values to
|
71 |
+
the blank probs.
|
72 |
+
See https://github.com/k2-fsa/icefall/pull/669 for details.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
encoder_out:
|
76 |
+
A tensor with shape of (N, T, C).
|
77 |
+
delay_penalty:
|
78 |
+
A constant used to scale the delay penalty score.
|
79 |
+
blank_threshold:
|
80 |
+
The threshold used to split utterance into sub-utterances.
|
81 |
+
"""
|
82 |
+
output = self.ctc_output_module(encoder_out)
|
83 |
+
log_prob = nn.functional.log_softmax(output, dim=-1)
|
84 |
+
|
85 |
+
if self.training and delay_penalty > 0:
|
86 |
+
T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device)
|
87 |
+
# split into sub-utterances using the blank-id
|
88 |
+
mask = log_prob[:, :, 0] >= math.log(blank_threshold) # (B, T)
|
89 |
+
mask[:, 0] = True
|
90 |
+
cummax_out = (T_arange * mask).cummax(dim=-1)[0] # (B, T)
|
91 |
+
# the sawtooth "blank-bonus" value
|
92 |
+
penalty = T_arange - cummax_out # (B, T)
|
93 |
+
penalty_all = torch.zeros_like(log_prob)
|
94 |
+
penalty_all[:, :, 0] = delay_penalty * penalty
|
95 |
+
# apply latency penalty on probs
|
96 |
+
log_prob = log_prob + penalty_all
|
97 |
+
|
98 |
+
return log_prob
|
99 |
+
|
100 |
+
def forward(
|
101 |
+
self,
|
102 |
+
x: torch.Tensor,
|
103 |
+
x_lens: torch.Tensor,
|
104 |
+
warmup: float = 1.0,
|
105 |
+
delay_penalty: float = 0.0,
|
106 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
107 |
+
"""
|
108 |
+
Args:
|
109 |
+
x:
|
110 |
+
A 3-D tensor of shape (N, T, C).
|
111 |
+
x_lens:
|
112 |
+
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
113 |
+
before padding.
|
114 |
+
warmup: a floating point value which increases throughout training;
|
115 |
+
values >= 1.0 are fully warmed up and have all modules present.
|
116 |
+
delay_penalty:
|
117 |
+
A constant used to scale the delay penalty score.
|
118 |
+
"""
|
119 |
+
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
|
120 |
+
assert torch.all(encoder_out_lens > 0)
|
121 |
+
nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty)
|
122 |
+
return nnet_output, encoder_out_lens
|
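The sawtooth "blank-bonus" in get_ctc_output() is easiest to see on a toy example: the penalty counts frames since the most recent confident-blank frame and resets to zero at each of them, so the bonus on the blank symbol grows the longer the model delays emitting. A small sketch of just that indexing (blank id 0 and the threshold come from the code above; the log-probs are made up):

import math
import torch

# Made-up blank log-probs for one utterance of 8 frames; frames 0, 3 and 7 are
# confidently blank (log-prob close to 0).
log_prob_blank = torch.tensor([[-0.001, -2.0, -3.0, -0.001, -1.5, -2.5, -3.5, -0.001]])
blank_threshold = 0.99
T = log_prob_blank.shape[1]
T_arange = torch.arange(T)

mask = log_prob_blank >= math.log(blank_threshold)  # confident-blank frames
mask[:, 0] = True
cummax_out = (T_arange * mask).cummax(dim=-1)[0]    # index of the most recent confident blank
penalty = T_arange - cummax_out                     # frames elapsed since that blank
print(penalty)  # tensor([[0, 1, 2, 0, 1, 2, 3, 0]])

In training, `delay_penalty * penalty` is then added to the blank log-probs, exactly as in the code above.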
err2020/conformer_ctc3/optim.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
from typing import List, Optional, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch.optim import Optimizer
|
22 |
+
|
23 |
+
|
24 |
+
class Eve(Optimizer):
|
25 |
+
r"""
|
26 |
+
Implements Eve algorithm. This is a modified version of AdamW with a special
|
27 |
+
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
28 |
+
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
29 |
+
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
30 |
+
will be close to invariant to the absolute scale on the parameter matrix.
|
31 |
+
|
32 |
+
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
33 |
+
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
34 |
+
Eve is unpublished so far.
|
35 |
+
|
36 |
+
Arguments:
|
37 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
38 |
+
parameter groups
|
39 |
+
lr (float, optional): learning rate (default: 1e-3)
|
40 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
41 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
42 |
+
eps (float, optional): term added to the denominator to improve
|
43 |
+
numerical stability (default: 1e-8)
|
44 |
+
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
45 |
+
this value means that the weight would decay significantly after
|
46 |
+
about 3k minibatches. Is not multiplied by learning rate, but
|
47 |
+
is conditional on RMS-value of parameter being > target_rms.
|
48 |
+
target_rms (float, optional): target root-mean-square value of
|
49 |
+
parameters, if they fall below this we will stop applying weight decay.
|
50 |
+
|
51 |
+
|
52 |
+
.. _Adam\: A Method for Stochastic Optimization:
|
53 |
+
https://arxiv.org/abs/1412.6980
|
54 |
+
.. _Decoupled Weight Decay Regularization:
|
55 |
+
https://arxiv.org/abs/1711.05101
|
56 |
+
.. _On the Convergence of Adam and Beyond:
|
57 |
+
https://openreview.net/forum?id=ryQu7f-RZ
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
params,
|
63 |
+
lr=1e-3,
|
64 |
+
betas=(0.9, 0.98),
|
65 |
+
eps=1e-8,
|
66 |
+
weight_decay=1e-3,
|
67 |
+
target_rms=0.1,
|
68 |
+
):
|
69 |
+
|
70 |
+
if not 0.0 <= lr:
|
71 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
72 |
+
if not 0.0 <= eps:
|
73 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
74 |
+
if not 0.0 <= betas[0] < 1.0:
|
75 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
76 |
+
if not 0.0 <= betas[1] < 1.0:
|
77 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
78 |
+
if not 0 <= weight_decay <= 0.1:
|
79 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
80 |
+
if not 0 < target_rms <= 10.0:
|
81 |
+
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
82 |
+
defaults = dict(
|
83 |
+
lr=lr,
|
84 |
+
betas=betas,
|
85 |
+
eps=eps,
|
86 |
+
weight_decay=weight_decay,
|
87 |
+
target_rms=target_rms,
|
88 |
+
)
|
89 |
+
super(Eve, self).__init__(params, defaults)
|
90 |
+
|
91 |
+
def __setstate__(self, state):
|
92 |
+
super(Eve, self).__setstate__(state)
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
def step(self, closure=None):
|
96 |
+
"""Performs a single optimization step.
|
97 |
+
|
98 |
+
Arguments:
|
99 |
+
closure (callable, optional): A closure that reevaluates the model
|
100 |
+
and returns the loss.
|
101 |
+
"""
|
102 |
+
loss = None
|
103 |
+
if closure is not None:
|
104 |
+
with torch.enable_grad():
|
105 |
+
loss = closure()
|
106 |
+
|
107 |
+
for group in self.param_groups:
|
108 |
+
for p in group["params"]:
|
109 |
+
if p.grad is None:
|
110 |
+
continue
|
111 |
+
|
112 |
+
# Perform optimization step
|
113 |
+
grad = p.grad
|
114 |
+
if grad.is_sparse:
|
115 |
+
raise RuntimeError("AdamW does not support sparse gradients")
|
116 |
+
|
117 |
+
state = self.state[p]
|
118 |
+
|
119 |
+
# State initialization
|
120 |
+
if len(state) == 0:
|
121 |
+
state["step"] = 0
|
122 |
+
# Exponential moving average of gradient values
|
123 |
+
state["exp_avg"] = torch.zeros_like(
|
124 |
+
p, memory_format=torch.preserve_format
|
125 |
+
)
|
126 |
+
# Exponential moving average of squared gradient values
|
127 |
+
state["exp_avg_sq"] = torch.zeros_like(
|
128 |
+
p, memory_format=torch.preserve_format
|
129 |
+
)
|
130 |
+
|
131 |
+
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
132 |
+
|
133 |
+
beta1, beta2 = group["betas"]
|
134 |
+
|
135 |
+
state["step"] += 1
|
136 |
+
bias_correction1 = 1 - beta1 ** state["step"]
|
137 |
+
bias_correction2 = 1 - beta2 ** state["step"]
|
138 |
+
|
139 |
+
# Decay the first and second moment running average coefficient
|
140 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
141 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
142 |
+
denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
|
143 |
+
group["eps"]
|
144 |
+
)
|
145 |
+
|
146 |
+
step_size = group["lr"] / bias_correction1
|
147 |
+
target_rms = group["target_rms"]
|
148 |
+
weight_decay = group["weight_decay"]
|
149 |
+
|
150 |
+
if p.numel() > 1:
|
151 |
+
# avoid applying this weight-decay on "scaling factors"
|
152 |
+
# (which are scalar).
|
153 |
+
is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
|
154 |
+
p.mul_(1 - (weight_decay * is_above_target_rms))
|
155 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
156 |
+
|
157 |
+
# Constrain the range of scalar weights
|
158 |
+
if p.numel() == 1:
|
159 |
+
p.clamp_(min=-10, max=2)
|
160 |
+
|
161 |
+
return loss
|
162 |
+
|
163 |
+
|
164 |
+
class LRScheduler(object):
|
165 |
+
"""
|
166 |
+
Base-class for learning rate schedulers where the learning-rate depends on both the
|
167 |
+
batch and the epoch.
|
168 |
+
"""
|
169 |
+
|
170 |
+
def __init__(self, optimizer: Optimizer, verbose: bool = False):
|
171 |
+
# Attach optimizer
|
172 |
+
if not isinstance(optimizer, Optimizer):
|
173 |
+
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
|
174 |
+
self.optimizer = optimizer
|
175 |
+
self.verbose = verbose
|
176 |
+
|
177 |
+
for group in optimizer.param_groups:
|
178 |
+
group.setdefault("initial_lr", group["lr"])
|
179 |
+
|
180 |
+
self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
|
181 |
+
|
182 |
+
self.epoch = 0
|
183 |
+
self.batch = 0
|
184 |
+
|
185 |
+
def state_dict(self):
|
186 |
+
"""Returns the state of the scheduler as a :class:`dict`.
|
187 |
+
|
188 |
+
It contains an entry for every variable in self.__dict__ which
|
189 |
+
is not the optimizer.
|
190 |
+
"""
|
191 |
+
return {
|
192 |
+
"base_lrs": self.base_lrs,
|
193 |
+
"epoch": self.epoch,
|
194 |
+
"batch": self.batch,
|
195 |
+
}
|
196 |
+
|
197 |
+
def load_state_dict(self, state_dict):
|
198 |
+
"""Loads the schedulers state.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
state_dict (dict): scheduler state. Should be an object returned
|
202 |
+
from a call to :meth:`state_dict`.
|
203 |
+
"""
|
204 |
+
self.__dict__.update(state_dict)
|
205 |
+
|
206 |
+
def get_last_lr(self) -> List[float]:
|
207 |
+
"""Return last computed learning rate by current scheduler. Will be a list of float."""
|
208 |
+
return self._last_lr
|
209 |
+
|
210 |
+
def get_lr(self):
|
211 |
+
# Compute list of learning rates from self.epoch and self.batch and
|
212 |
+
# self.base_lrs; this must be overloaded by the user.
|
213 |
+
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
|
214 |
+
raise NotImplementedError
|
215 |
+
|
216 |
+
def step_batch(self, batch: Optional[int] = None) -> None:
|
217 |
+
# Step the batch index, or just set it. If `batch` is specified, it
|
218 |
+
# must be the batch index from the start of training, i.e. summed over
|
219 |
+
# all epochs.
|
220 |
+
# You can call this in any order; if you don't provide 'batch', it should
|
221 |
+
# of course be called once per batch.
|
222 |
+
if batch is not None:
|
223 |
+
self.batch = batch
|
224 |
+
else:
|
225 |
+
self.batch = self.batch + 1
|
226 |
+
self._set_lrs()
|
227 |
+
|
228 |
+
def step_epoch(self, epoch: Optional[int] = None):
|
229 |
+
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
|
230 |
+
# you should call this at the start of the epoch; if you don't provide the 'epoch'
|
231 |
+
# arg, you should call it at the end of the epoch.
|
232 |
+
if epoch is not None:
|
233 |
+
self.epoch = epoch
|
234 |
+
else:
|
235 |
+
self.epoch = self.epoch + 1
|
236 |
+
self._set_lrs()
|
237 |
+
|
238 |
+
def _set_lrs(self):
|
239 |
+
values = self.get_lr()
|
240 |
+
assert len(values) == len(self.optimizer.param_groups)
|
241 |
+
|
242 |
+
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
|
243 |
+
param_group, lr = data
|
244 |
+
param_group["lr"] = lr
|
245 |
+
self.print_lr(self.verbose, i, lr)
|
246 |
+
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
247 |
+
|
248 |
+
def print_lr(self, is_verbose, group, lr):
|
249 |
+
"""Display the current learning rate."""
|
250 |
+
if is_verbose:
|
251 |
+
print(
|
252 |
+
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
253 |
+
f" of group {group} to {lr:.4e}."
|
254 |
+
)
|
255 |
+
|
256 |
+
|
257 |
+
class Eden(LRScheduler):
|
258 |
+
"""
|
259 |
+
Eden scheduler.
|
260 |
+
lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
|
261 |
+
(((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
|
262 |
+
|
263 |
+
E.g. suggest initial-lr = 0.003 (passed to optimizer).
|
264 |
+
|
265 |
+
Args:
|
266 |
+
optimizer: the optimizer to change the learning rates on
|
267 |
+
lr_batches: the number of batches after which we start significantly
|
268 |
+
decreasing the learning rate, suggest 5000.
|
269 |
+
lr_epochs: the number of epochs after which we start significantly
|
270 |
+
decreasing the learning rate, suggest 6 if you plan to do e.g.
|
271 |
+
20 to 40 epochs, but may need smaller number if dataset is huge
|
272 |
+
and you will do few epochs.
|
273 |
+
"""
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
optimizer: Optimizer,
|
278 |
+
lr_batches: Union[int, float],
|
279 |
+
lr_epochs: Union[int, float],
|
280 |
+
verbose: bool = False,
|
281 |
+
):
|
282 |
+
super(Eden, self).__init__(optimizer, verbose)
|
283 |
+
self.lr_batches = lr_batches
|
284 |
+
self.lr_epochs = lr_epochs
|
285 |
+
|
286 |
+
def get_lr(self):
|
287 |
+
factor = (
|
288 |
+
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
|
289 |
+
) ** -0.25 * (
|
290 |
+
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
|
291 |
+
)
|
292 |
+
return [x * factor for x in self.base_lrs]
|
293 |
+
|
294 |
+
|
295 |
+
def _test_eden():
|
296 |
+
m = torch.nn.Linear(100, 100)
|
297 |
+
optim = Eve(m.parameters(), lr=0.003)
|
298 |
+
|
299 |
+
scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
|
300 |
+
|
301 |
+
for epoch in range(10):
|
302 |
+
scheduler.step_epoch(epoch) # sets epoch to `epoch`
|
303 |
+
|
304 |
+
for step in range(20):
|
305 |
+
x = torch.randn(200, 100).detach()
|
306 |
+
x.requires_grad = True
|
307 |
+
y = m(x)
|
308 |
+
dy = torch.randn(200, 100).detach()
|
309 |
+
f = (y * dy).sum()
|
310 |
+
f.backward()
|
311 |
+
|
312 |
+
optim.step()
|
313 |
+
scheduler.step_batch()
|
314 |
+
optim.zero_grad()
|
315 |
+
print("last lr = ", scheduler.get_last_lr())
|
316 |
+
print("state dict = ", scheduler.state_dict())
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == "__main__":
|
320 |
+
_test_eden()
|
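To get a feel for the Eden formula above, here are the scaling factors at a few points with the suggested settings (initial lr 0.003, lr_batches=5000, lr_epochs=6); the numbers come straight from the formula, not from a training run:

lr_batches, lr_epochs, initial_lr = 5000, 6, 0.003

def eden_factor(batch: int, epoch: int) -> float:
    return (
        ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    )

print(eden_factor(0, 0))       # 1.00  -> lr = 0.003 at the very start
print(eden_factor(5000, 6))    # ~0.71 -> lr ~ 0.0021
print(eden_factor(50000, 30))  # ~0.14 -> lr ~ 0.0004 late in training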
err2020/conformer_ctc3/pretrained.py
ADDED
@@ -0,0 +1,461 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
3 |
+
# Mingshuang Luo,)
|
4 |
+
# Zengwei Yao)
|
5 |
+
#
|
6 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
|
20 |
+
|
21 |
+
"""
|
22 |
+
Usage (for non-streaming mode):
|
23 |
+
|
24 |
+
(1) ctc-decoding
|
25 |
+
./conformer_ctc3/pretrained.py \
|
26 |
+
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
27 |
+
--bpe-model data/lang_bpe_500/bpe.model \
|
28 |
+
--method ctc-decoding \
|
29 |
+
--sample-rate 16000 \
|
30 |
+
test_wavs/1089-134686-0001.wav
|
31 |
+
|
32 |
+
(2) 1best
|
33 |
+
./conformer_ctc3/pretrained.py \
|
34 |
+
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
35 |
+
--HLG data/lang_bpe_500/HLG.pt \
|
36 |
+
--words-file data/lang_bpe_500/words.txt \
|
37 |
+
--method 1best \
|
38 |
+
--sample-rate 16000 \
|
39 |
+
test_wavs/1089-134686-0001.wav
|
40 |
+
|
41 |
+
(3) nbest-rescoring
|
42 |
+
./conformer_ctc3/pretrained.py \
|
43 |
+
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
44 |
+
--HLG data/lang_bpe_500/HLG.pt \
|
45 |
+
--words-file data/lang_bpe_500/words.txt \
|
46 |
+
--G data/lm/G_4_gram.pt \
|
47 |
+
--method nbest-rescoring \
|
48 |
+
--sample-rate 16000 \
|
49 |
+
test_wavs/1089-134686-0001.wav
|
50 |
+
|
51 |
+
(4) whole-lattice-rescoring
|
52 |
+
./conformer_ctc3/pretrained.py \
|
53 |
+
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
54 |
+
--HLG data/lang_bpe_500/HLG.pt \
|
55 |
+
--words-file data/lang_bpe_500/words.txt \
|
56 |
+
--G data/lm/G_4_gram.pt \
|
57 |
+
--method whole-lattice-rescoring \
|
58 |
+
--sample-rate 16000 \
|
59 |
+
test_wavs/1089-134686-0001.wav
|
60 |
+
"""
|
61 |
+
|
62 |
+
|
63 |
+
import argparse
|
64 |
+
import logging
|
65 |
+
import math
|
66 |
+
from typing import List
|
67 |
+
|
68 |
+
import k2
|
69 |
+
import kaldifeat
|
70 |
+
import sentencepiece as spm
|
71 |
+
import torch
|
72 |
+
import torchaudio
|
73 |
+
from decode import get_decoding_params
|
74 |
+
from torch.nn.utils.rnn import pad_sequence
|
75 |
+
from train import add_model_arguments, get_ctc_model, get_params
|
76 |
+
|
77 |
+
from icefall.decode import (
|
78 |
+
get_lattice,
|
79 |
+
one_best_decoding,
|
80 |
+
rescore_with_n_best_list,
|
81 |
+
rescore_with_whole_lattice,
|
82 |
+
)
|
83 |
+
from icefall.utils import get_texts, str2bool
|
84 |
+
|
85 |
+
|
86 |
+
def get_parser():
|
87 |
+
parser = argparse.ArgumentParser(
|
88 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
89 |
+
)
|
90 |
+
|
91 |
+
parser.add_argument(
|
92 |
+
"--checkpoint",
|
93 |
+
type=str,
|
94 |
+
required=True,
|
95 |
+
help="Path to the checkpoint. "
|
96 |
+
"The checkpoint is assumed to be saved by "
|
97 |
+
"icefall.checkpoint.save_checkpoint().",
|
98 |
+
)
|
99 |
+
|
100 |
+
parser.add_argument(
|
101 |
+
"--words-file",
|
102 |
+
type=str,
|
103 |
+
help="""Path to words.txt.
|
104 |
+
Used only when method is not ctc-decoding.
|
105 |
+
""",
|
106 |
+
)
|
107 |
+
|
108 |
+
parser.add_argument(
|
109 |
+
"--HLG",
|
110 |
+
type=str,
|
111 |
+
help="""Path to HLG.pt.
|
112 |
+
Used only when method is not ctc-decoding.
|
113 |
+
""",
|
114 |
+
)
|
115 |
+
|
116 |
+
parser.add_argument(
|
117 |
+
"--bpe-model",
|
118 |
+
type=str,
|
119 |
+
help="""Path to bpe.model.
|
120 |
+
Used only when method is ctc-decoding.
|
121 |
+
""",
|
122 |
+
)
|
123 |
+
|
124 |
+
parser.add_argument(
|
125 |
+
"--method",
|
126 |
+
type=str,
|
127 |
+
default="1best",
|
128 |
+
help="""Decoding method.
|
129 |
+
Possible values are:
|
130 |
+
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
131 |
+
piece model, i.e., lang_dir/bpe.model, to convert
|
132 |
+
word pieces to words. It needs neither a lexicon
|
133 |
+
nor an n-gram LM.
|
134 |
+
(1) 1best - Use the best path as decoding output. Only
|
135 |
+
the transformer encoder output is used for decoding.
|
136 |
+
We call it HLG decoding.
|
137 |
+
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
138 |
+
rescore them with an LM; the path with
|
139 |
+
the highest score is the decoding result.
|
140 |
+
We call it HLG decoding + n-gram LM rescoring.
|
141 |
+
(3) whole-lattice-rescoring - Use an LM to rescore the
|
142 |
+
decoding lattice and then use 1best to decode the
|
143 |
+
rescored lattice.
|
144 |
+
We call it HLG decoding + n-gram LM rescoring.
|
145 |
+
""",
|
146 |
+
)
|
147 |
+
|
148 |
+
parser.add_argument(
|
149 |
+
"--G",
|
150 |
+
type=str,
|
151 |
+
help="""An LM for rescoring.
|
152 |
+
Used only when method is
|
153 |
+
whole-lattice-rescoring or nbest-rescoring.
|
154 |
+
It's usually a 4-gram LM.
|
155 |
+
""",
|
156 |
+
)
|
157 |
+
|
158 |
+
parser.add_argument(
|
159 |
+
"--num-paths",
|
160 |
+
type=int,
|
161 |
+
default=100,
|
162 |
+
help="""
|
163 |
+
Used only when method is nbest-rescoring.
|
164 |
+
It specifies the size of n-best list.""",
|
165 |
+
)
|
166 |
+
|
167 |
+
parser.add_argument(
|
168 |
+
"--ngram-lm-scale",
|
169 |
+
type=float,
|
170 |
+
default=1.3,
|
171 |
+
help="""
|
172 |
+
Used only when method is whole-lattice-rescoring or nbest-rescoring.
|
173 |
+
It specifies the scale for n-gram LM scores.
|
174 |
+
(Note: You need to tune it on a dataset.)
|
175 |
+
""",
|
176 |
+
)
|
177 |
+
|
178 |
+
parser.add_argument(
|
179 |
+
"--nbest-scale",
|
180 |
+
type=float,
|
181 |
+
default=0.5,
|
182 |
+
help="""
|
183 |
+
Used only when method is nbest-rescoring.
|
184 |
+
It specifies the scale for lattice.scores when
|
185 |
+
extracting n-best lists. A smaller value results in
|
186 |
+
more unique paths, at the risk of missing
|
187 |
+
the best path.
|
188 |
+
""",
|
189 |
+
)
|
190 |
+
|
191 |
+
parser.add_argument(
|
192 |
+
"--num-classes",
|
193 |
+
type=int,
|
194 |
+
default=500,
|
195 |
+
help="""
|
196 |
+
Vocab size in the BPE model.
|
197 |
+
""",
|
198 |
+
)
|
199 |
+
|
200 |
+
parser.add_argument(
|
201 |
+
"--simulate-streaming",
|
202 |
+
type=str2bool,
|
203 |
+
default=False,
|
204 |
+
help="""Whether to simulate streaming in decoding, this is a good way to
|
205 |
+
test a streaming model.
|
206 |
+
""",
|
207 |
+
)
|
208 |
+
|
209 |
+
parser.add_argument(
|
210 |
+
"--decode-chunk-size",
|
211 |
+
type=int,
|
212 |
+
default=16,
|
213 |
+
help="The chunk size for decoding (in frames after subsampling)",
|
214 |
+
)
|
215 |
+
|
216 |
+
parser.add_argument(
|
217 |
+
"--left-context",
|
218 |
+
type=int,
|
219 |
+
default=64,
|
220 |
+
help="left context can be seen during decoding (in frames after subsampling)",
|
221 |
+
)
|
222 |
+
|
223 |
+
parser.add_argument(
|
224 |
+
"--sample-rate",
|
225 |
+
type=int,
|
226 |
+
default=16000,
|
227 |
+
help="The sample rate of the input sound file",
|
228 |
+
)
|
229 |
+
|
230 |
+
parser.add_argument(
|
231 |
+
"sound_files",
|
232 |
+
type=str,
|
233 |
+
nargs="+",
|
234 |
+
help="The input sound file(s) to transcribe. "
|
235 |
+
"Supported formats are those supported by torchaudio.load(). "
|
236 |
+
"For example, wav and flac are supported. "
|
237 |
+
"The sample rate has to be 16kHz.",
|
238 |
+
)
|
239 |
+
|
240 |
+
add_model_arguments(parser)
|
241 |
+
|
242 |
+
return parser
|
243 |
+
|
244 |
+
|
245 |
+
def read_sound_files(
|
246 |
+
filenames: List[str], expected_sample_rate: float
|
247 |
+
) -> List[torch.Tensor]:
|
248 |
+
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
249 |
+
Args:
|
250 |
+
filenames:
|
251 |
+
A list of sound filenames.
|
252 |
+
expected_sample_rate:
|
253 |
+
The expected sample rate of the sound files.
|
254 |
+
Returns:
|
255 |
+
Return a list of 1-D float32 torch tensors.
|
256 |
+
"""
|
257 |
+
ans = []
|
258 |
+
for f in filenames:
|
259 |
+
wave, sample_rate = torchaudio.load(f)
|
260 |
+
assert sample_rate == expected_sample_rate, (
|
261 |
+
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
262 |
+
)
|
263 |
+
# We use only the first channel
|
264 |
+
ans.append(wave[0])
|
265 |
+
return ans
|
266 |
+
|
267 |
+
|
268 |
+
def main():
|
269 |
+
parser = get_parser()
|
270 |
+
args = parser.parse_args()
|
271 |
+
|
272 |
+
params = get_params()
|
273 |
+
# add decoding params
|
274 |
+
params.update(get_decoding_params())
|
275 |
+
params.update(vars(args))
|
276 |
+
params.vocab_size = params.num_classes
|
277 |
+
|
278 |
+
if params.simulate_streaming:
|
279 |
+
assert (
|
280 |
+
params.causal_convolution
|
281 |
+
), "Decoding in streaming requires causal convolution"
|
282 |
+
|
283 |
+
logging.info(f"{params}")
|
284 |
+
|
285 |
+
device = torch.device("cpu")
|
286 |
+
if torch.cuda.is_available():
|
287 |
+
device = torch.device("cuda", 0)
|
288 |
+
|
289 |
+
logging.info(f"device: {device}")
|
290 |
+
|
291 |
+
logging.info("About to create model")
|
292 |
+
model = get_ctc_model(params)
|
293 |
+
|
294 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
295 |
+
logging.info(f"Number of model parameters: {num_param}")
|
296 |
+
|
297 |
+
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
298 |
+
model.load_state_dict(checkpoint["model"], strict=False)
|
299 |
+
model.to(device)
|
300 |
+
model.eval()
|
301 |
+
|
302 |
+
logging.info("Constructing Fbank computer")
|
303 |
+
opts = kaldifeat.FbankOptions()
|
304 |
+
opts.device = device
|
305 |
+
opts.frame_opts.dither = 0
|
306 |
+
opts.frame_opts.snip_edges = False
|
307 |
+
opts.frame_opts.samp_freq = params.sample_rate
|
308 |
+
opts.mel_opts.num_bins = params.feature_dim
|
309 |
+
|
310 |
+
fbank = kaldifeat.Fbank(opts)
|
311 |
+
|
312 |
+
logging.info(f"Reading sound files: {params.sound_files}")
|
313 |
+
waves = read_sound_files(
|
314 |
+
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
315 |
+
)
|
316 |
+
waves = [w.to(device) for w in waves]
|
317 |
+
|
318 |
+
logging.info("Decoding started")
|
319 |
+
features = fbank(waves)
|
320 |
+
feature_lengths = [f.size(0) for f in features]
|
321 |
+
|
322 |
+
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
323 |
+
feature_lengths = torch.tensor(feature_lengths, device=device)
|
324 |
+
|
325 |
+
# model forward
|
326 |
+
if params.simulate_streaming:
|
327 |
+
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
328 |
+
x=features,
|
329 |
+
x_lens=feature_lengths,
|
330 |
+
chunk_size=params.decode_chunk_size,
|
331 |
+
left_context=params.left_context,
|
332 |
+
simulate_streaming=True,
|
333 |
+
)
|
334 |
+
else:
|
335 |
+
encoder_out, encoder_out_lens = model.encoder(
|
336 |
+
x=features, x_lens=feature_lengths
|
337 |
+
)
|
338 |
+
nnet_output = model.get_ctc_output(encoder_out)
|
339 |
+
|
340 |
+
batch_size = nnet_output.shape[0]
|
341 |
+
supervision_segments = torch.tensor(
|
342 |
+
[
|
343 |
+
[i, 0, feature_lengths[i] // params.subsampling_factor]
|
344 |
+
for i in range(batch_size)
|
345 |
+
],
|
346 |
+
dtype=torch.int32,
|
347 |
+
)
|
348 |
+
|
349 |
+
if params.method == "ctc-decoding":
|
350 |
+
logging.info("Use CTC decoding")
|
351 |
+
bpe_model = spm.SentencePieceProcessor()
|
352 |
+
bpe_model.load(params.bpe_model)
|
353 |
+
max_token_id = params.num_classes - 1
|
354 |
+
|
355 |
+
H = k2.ctc_topo(
|
356 |
+
max_token=max_token_id,
|
357 |
+
modified=False,
|
358 |
+
device=device,
|
359 |
+
)
|
360 |
+
|
361 |
+
lattice = get_lattice(
|
362 |
+
nnet_output=nnet_output,
|
363 |
+
decoding_graph=H,
|
364 |
+
supervision_segments=supervision_segments,
|
365 |
+
search_beam=params.search_beam,
|
366 |
+
output_beam=params.output_beam,
|
367 |
+
min_active_states=params.min_active_states,
|
368 |
+
max_active_states=params.max_active_states,
|
369 |
+
subsampling_factor=params.subsampling_factor,
|
370 |
+
)
|
371 |
+
|
372 |
+
best_path = one_best_decoding(
|
373 |
+
lattice=lattice, use_double_scores=params.use_double_scores
|
374 |
+
)
|
375 |
+
token_ids = get_texts(best_path)
|
376 |
+
hyps = bpe_model.decode(token_ids)
|
377 |
+
hyps = [s.split() for s in hyps]
|
378 |
+
elif params.method in [
|
379 |
+
"1best",
|
380 |
+
"nbest-rescoring",
|
381 |
+
"whole-lattice-rescoring",
|
382 |
+
]:
|
383 |
+
logging.info(f"Loading HLG from {params.HLG}")
|
384 |
+
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
385 |
+
HLG = HLG.to(device)
|
386 |
+
if not hasattr(HLG, "lm_scores"):
|
387 |
+
# For whole-lattice-rescoring and attention-decoder
|
388 |
+
HLG.lm_scores = HLG.scores.clone()
|
389 |
+
|
390 |
+
if params.method in [
|
391 |
+
"nbest-rescoring",
|
392 |
+
"whole-lattice-rescoring",
|
393 |
+
]:
|
394 |
+
logging.info(f"Loading G from {params.G}")
|
395 |
+
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
396 |
+
G = G.to(device)
|
397 |
+
if params.method == "whole-lattice-rescoring":
|
398 |
+
# Add epsilon self-loops to G as we will compose
|
399 |
+
# it with the whole lattice later
|
400 |
+
G = k2.add_epsilon_self_loops(G)
|
401 |
+
G = k2.arc_sort(G)
|
402 |
+
|
403 |
+
# G.lm_scores is used to replace HLG.lm_scores during
|
404 |
+
# LM rescoring.
|
405 |
+
G.lm_scores = G.scores.clone()
|
406 |
+
|
407 |
+
lattice = get_lattice(
|
408 |
+
nnet_output=nnet_output,
|
409 |
+
decoding_graph=HLG,
|
410 |
+
supervision_segments=supervision_segments,
|
411 |
+
search_beam=params.search_beam,
|
412 |
+
output_beam=params.output_beam,
|
413 |
+
min_active_states=params.min_active_states,
|
414 |
+
max_active_states=params.max_active_states,
|
415 |
+
subsampling_factor=params.subsampling_factor,
|
416 |
+
)
|
417 |
+
|
418 |
+
if params.method == "1best":
|
419 |
+
logging.info("Use HLG decoding")
|
420 |
+
best_path = one_best_decoding(
|
421 |
+
lattice=lattice, use_double_scores=params.use_double_scores
|
422 |
+
)
|
423 |
+
if params.method == "nbest-rescoring":
|
424 |
+
logging.info("Use HLG decoding + LM rescoring")
|
425 |
+
best_path_dict = rescore_with_n_best_list(
|
426 |
+
lattice=lattice,
|
427 |
+
G=G,
|
428 |
+
num_paths=params.num_paths,
|
429 |
+
lm_scale_list=[params.ngram_lm_scale],
|
430 |
+
nbest_scale=params.nbest_scale,
|
431 |
+
)
|
432 |
+
best_path = next(iter(best_path_dict.values()))
|
433 |
+
elif params.method == "whole-lattice-rescoring":
|
434 |
+
logging.info("Use HLG decoding + LM rescoring")
|
435 |
+
best_path_dict = rescore_with_whole_lattice(
|
436 |
+
lattice=lattice,
|
437 |
+
G_with_epsilon_loops=G,
|
438 |
+
lm_scale_list=[params.ngram_lm_scale],
|
439 |
+
)
|
440 |
+
best_path = next(iter(best_path_dict.values()))
|
441 |
+
|
442 |
+
hyps = get_texts(best_path)
|
443 |
+
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
444 |
+
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
445 |
+
else:
|
446 |
+
raise ValueError(f"Unsupported decoding method: {params.method}")
|
447 |
+
|
448 |
+
s = "\n"
|
449 |
+
for filename, hyp in zip(params.sound_files, hyps):
|
450 |
+
words = " ".join(hyp)
|
451 |
+
s += f"{filename}:\n{words}\n\n"
|
452 |
+
logging.info(s)
|
453 |
+
|
454 |
+
logging.info("Decoding Done")
|
455 |
+
|
456 |
+
|
457 |
+
if __name__ == "__main__":
|
458 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
459 |
+
|
460 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
461 |
+
main()
|
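The help text for --ngram-lm-scale says the scale needs tuning on a dataset. Since rescore_with_whole_lattice accepts a list of scales and returns one best path per scale, a hypothetical sweep, reusing `lattice`, `G` and `word_sym_table` built exactly as in main() above, could look like:

# Hypothetical tuning sweep; `lattice`, `G`, `word_sym_table` as constructed in main().
lm_scales = [0.3, 0.6, 0.9, 1.3, 1.7]
best_path_dict = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
    lm_scale_list=lm_scales,
)
for key, best_path in best_path_dict.items():
    hyps = get_texts(best_path)
    texts = [" ".join(word_sym_table[i] for i in ids) for ids in hyps]
    print(key, texts)  # score each scale's output (e.g. WER) on a held-out set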
err2020/conformer_ctc3/scaling.py
ADDED
@@ -0,0 +1,1015 @@
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
|
2 |
+
#
|
3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
import collections
|
19 |
+
import random
|
20 |
+
from itertools import repeat
|
21 |
+
from typing import Optional, Tuple
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.backends.cudnn.rnn as rnn
|
25 |
+
import torch.nn as nn
|
26 |
+
from torch import _VF, Tensor
|
27 |
+
|
28 |
+
from icefall.utils import is_jit_tracing
|
29 |
+
|
30 |
+
|
31 |
+
def _ntuple(n):
|
32 |
+
def parse(x):
|
33 |
+
if isinstance(x, collections.Iterable):
|
34 |
+
return x
|
35 |
+
return tuple(repeat(x, n))
|
36 |
+
|
37 |
+
return parse
|
38 |
+
|
39 |
+
|
40 |
+
_single = _ntuple(1)
|
41 |
+
_pair = _ntuple(2)
|
42 |
+
|
43 |
+
|
44 |
+
class ActivationBalancerFunction(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(
|
47 |
+
ctx,
|
48 |
+
x: Tensor,
|
49 |
+
channel_dim: int,
|
50 |
+
min_positive: float, # e.g. 0.05
|
51 |
+
max_positive: float, # e.g. 0.95
|
52 |
+
max_factor: float, # e.g. 0.01
|
53 |
+
min_abs: float, # e.g. 0.2
|
54 |
+
max_abs: float, # e.g. 100.0
|
55 |
+
) -> Tensor:
|
56 |
+
if x.requires_grad:
|
57 |
+
if channel_dim < 0:
|
58 |
+
channel_dim += x.ndim
|
59 |
+
|
60 |
+
# sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
61 |
+
# The above line is not torch scriptable for torch 1.6.0
|
62 |
+
# torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa
|
63 |
+
sum_dims = []
|
64 |
+
for d in range(x.ndim):
|
65 |
+
if d != channel_dim:
|
66 |
+
sum_dims.append(d)
|
67 |
+
|
68 |
+
xgt0 = x > 0
|
69 |
+
proportion_positive = torch.mean(
|
70 |
+
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
|
71 |
+
)
|
72 |
+
factor1 = (
|
73 |
+
(min_positive - proportion_positive).relu()
|
74 |
+
* (max_factor / min_positive)
|
75 |
+
if min_positive != 0.0
|
76 |
+
else 0.0
|
77 |
+
)
|
78 |
+
factor2 = (
|
79 |
+
(proportion_positive - max_positive).relu()
|
80 |
+
* (max_factor / (max_positive - 1.0))
|
81 |
+
if max_positive != 1.0
|
82 |
+
else 0.0
|
83 |
+
)
|
84 |
+
factor = factor1 + factor2
|
85 |
+
if isinstance(factor, float):
|
86 |
+
factor = torch.zeros_like(proportion_positive)
|
87 |
+
|
88 |
+
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
|
89 |
+
below_threshold = mean_abs < min_abs
|
90 |
+
above_threshold = mean_abs > max_abs
|
91 |
+
|
92 |
+
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
|
93 |
+
ctx.max_factor = max_factor
|
94 |
+
ctx.sum_dims = sum_dims
|
95 |
+
return x
|
96 |
+
|
97 |
+
@staticmethod
|
98 |
+
def backward(
|
99 |
+
ctx, x_grad: Tensor
|
100 |
+
) -> Tuple[Tensor, None, None, None, None, None, None]:
|
101 |
+
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
|
102 |
+
dtype = x_grad.dtype
|
103 |
+
scale_factor = (
|
104 |
+
(below_threshold.to(dtype) - above_threshold.to(dtype))
|
105 |
+
* (xgt0.to(dtype) - 0.5)
|
106 |
+
* (ctx.max_factor * 2.0)
|
107 |
+
)
|
108 |
+
|
109 |
+
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
110 |
+
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
111 |
+
|
112 |
+
|
113 |
+
class GradientFilterFunction(torch.autograd.Function):
|
114 |
+
@staticmethod
|
115 |
+
def forward(
|
116 |
+
ctx,
|
117 |
+
x: Tensor,
|
118 |
+
batch_dim: int, # e.g., 1
|
119 |
+
threshold: float, # e.g., 10.0
|
120 |
+
*params: Tensor, # module parameters
|
121 |
+
) -> Tuple[Tensor, ...]:
|
122 |
+
if x.requires_grad:
|
123 |
+
if batch_dim < 0:
|
124 |
+
batch_dim += x.ndim
|
125 |
+
ctx.batch_dim = batch_dim
|
126 |
+
ctx.threshold = threshold
|
127 |
+
return (x,) + params
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def backward(
|
131 |
+
ctx,
|
132 |
+
x_grad: Tensor,
|
133 |
+
*param_grads: Tensor,
|
134 |
+
) -> Tuple[Tensor, ...]:
|
135 |
+
eps = 1.0e-20
|
136 |
+
dim = ctx.batch_dim
|
137 |
+
norm_dims = [d for d in range(x_grad.ndim) if d != dim]
|
138 |
+
norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt()
|
139 |
+
median_norm = norm_of_batch.median()
|
140 |
+
|
141 |
+
cutoff = median_norm * ctx.threshold
|
142 |
+
inv_mask = (cutoff + norm_of_batch) / (cutoff + eps)
|
143 |
+
mask = 1.0 / (inv_mask + eps)
|
144 |
+
x_grad = x_grad * mask
|
145 |
+
|
146 |
+
avg_mask = 1.0 / (inv_mask.mean() + eps)
|
147 |
+
param_grads = [avg_mask * g for g in param_grads]
|
148 |
+
|
149 |
+
return (x_grad, None, None) + tuple(param_grads)
|
150 |
+
|
151 |
+
|
152 |
+
class GradientFilter(torch.nn.Module):
|
153 |
+
"""This is used to filter out elements that have extremely large gradients
|
154 |
+
in batch and the module parameters with soft masks.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
batch_dim (int):
|
158 |
+
The batch dimension.
|
159 |
+
threshold (float):
|
160 |
+
For each element in batch, its gradient will be
|
161 |
+
filtered out if the gradient norm is larger than
|
162 |
+
`threshold * median`, where `median` is the median
|
163 |
+
value of gradient norms of all elements in the batch.
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
|
167 |
+
super(GradientFilter, self).__init__()
|
168 |
+
self.batch_dim = batch_dim
|
169 |
+
self.threshold = threshold
|
170 |
+
|
171 |
+
def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]:
|
172 |
+
if torch.jit.is_scripting() or is_jit_tracing():
|
173 |
+
return (x,) + params
|
174 |
+
else:
|
175 |
+
return GradientFilterFunction.apply(
|
176 |
+
x,
|
177 |
+
self.batch_dim,
|
178 |
+
self.threshold,
|
179 |
+
*params,
|
180 |
+
)
|
181 |
+
|
182 |
+
|
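A minimal usage sketch for GradientFilter (not part of the original file; it assumes this scaling.py is importable as `scaling`): the forward pass just passes the tensors through, while the backward pass soft-masks batch elements whose gradient norm exceeds `threshold` times the median.

    import torch
    import torch.nn as nn
    from scaling import GradientFilter  # assumed import path

    filt = GradientFilter(batch_dim=1, threshold=10.0)
    proj = nn.Linear(8, 8)

    x = torch.randn(20, 4, 8, requires_grad=True)  # (time, batch, channel)
    # both the activations and the module parameters get soft gradient masks
    x_f, w_f, b_f = filt(x, proj.weight, proj.bias)
    y = torch.nn.functional.linear(x_f, w_f, b_f)
    y.sum().backward()
    # per-element gradient norms; outliers are scaled down towards the median
    print((x.grad ** 2).mean(dim=(0, 2)).sqrt())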
183 |
+
class BasicNorm(torch.nn.Module):
|
184 |
+
"""
|
185 |
+
This is intended to be a simpler, and hopefully cheaper, replacement for
|
186 |
+
LayerNorm. The observation this is based on is that Transformer-type
|
187 |
+
networks, especially with pre-norm, sometimes seem to set one of the
|
188 |
+
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
189 |
+
the LayerNorm because the output magnitude is then not strongly dependent
|
190 |
+
on the other (useful) features. Presumably the weight and bias of the
|
191 |
+
LayerNorm are required to allow it to do this.
|
192 |
+
|
193 |
+
So the idea is to introduce this large constant value as an explicit
|
194 |
+
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
195 |
+
doesn't have to do this trick. We make the "eps" learnable.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
num_channels: the number of channels, e.g. 512.
|
199 |
+
channel_dim: the axis/dimension corresponding to the channel,
|
200 |
+
interpreted as an offset from the input's ndim if negative.
|
201 |
+
this is NOT the num_channels; it should typically be one of
|
202 |
+
{-2, -1, 0, 1, 2, 3}.
|
203 |
+
eps: the initial "epsilon" that we add as ballast in:
|
204 |
+
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
205 |
+
Note: our epsilon is actually large, but we keep the name
|
206 |
+
to indicate the connection with conventional LayerNorm.
|
207 |
+
learn_eps: if true, we learn epsilon; if false, we keep it
|
208 |
+
at the initial value.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
num_channels: int,
|
214 |
+
channel_dim: int = -1, # CAUTION: see documentation.
|
215 |
+
eps: float = 0.25,
|
216 |
+
learn_eps: bool = True,
|
217 |
+
) -> None:
|
218 |
+
super(BasicNorm, self).__init__()
|
219 |
+
self.num_channels = num_channels
|
220 |
+
self.channel_dim = channel_dim
|
221 |
+
if learn_eps:
|
222 |
+
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
223 |
+
else:
|
224 |
+
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
225 |
+
|
226 |
+
def forward(self, x: Tensor) -> Tensor:
|
227 |
+
if not is_jit_tracing():
|
228 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
229 |
+
scales = (
|
230 |
+
torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp()
|
231 |
+
) ** -0.5
|
232 |
+
return x * scales
|
233 |
+
|
234 |
+
|
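A small sketch (not from the original file, assuming `scaling` is importable) of what BasicNorm computes: each position is scaled by `(mean(x**2) + exp(eps))**-0.5` over the channel dimension, so the output has roughly unit RMS without any learnable per-channel weight or bias.

    import torch
    from scaling import BasicNorm  # assumed import path

    norm = BasicNorm(num_channels=512, channel_dim=-1)
    x = 3.0 * torch.randn(10, 100, 512)
    y = norm(x)
    # per-position RMS of the output is close to 1 (eps=0.25 adds a small floor)
    print((y ** 2).mean(dim=-1).sqrt().mean())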
235 |
+
class ScaledLinear(nn.Linear):
|
236 |
+
"""
|
237 |
+
A modified version of nn.Linear where the parameters are scaled before
|
238 |
+
use, via:
|
239 |
+
weight = self.weight * self.weight_scale.exp()
|
240 |
+
bias = self.bias * self.bias_scale.exp()
|
241 |
+
|
242 |
+
Args:
|
243 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
244 |
+
e.g. in_features, out_features, bias=False.
|
245 |
+
|
246 |
+
initial_scale: you can override this if you want to increase
|
247 |
+
or decrease the initial magnitude of the module's output
|
248 |
+
(affects the initialization of weight_scale and bias_scale).
|
249 |
+
Another option, if you want to do something like this, is
|
250 |
+
to re-initialize the parameters.
|
251 |
+
initial_speed: this affects how fast the parameter will
|
252 |
+
learn near the start of training; you can set it to a
|
253 |
+
value less than one if you suspect that a module
|
254 |
+
is contributing to instability near the start of training.
|
255 |
+
Note: regardless of the use of this option, it's best to
|
256 |
+
use schedulers like Noam that have a warm-up period.
|
257 |
+
Alternatively you can set it to more than 1 if you want it to
|
258 |
+
initially train faster. Must be greater than 0.
|
259 |
+
"""
|
260 |
+
|
261 |
+
def __init__(
|
262 |
+
self,
|
263 |
+
*args,
|
264 |
+
initial_scale: float = 1.0,
|
265 |
+
initial_speed: float = 1.0,
|
266 |
+
**kwargs,
|
267 |
+
):
|
268 |
+
super(ScaledLinear, self).__init__(*args, **kwargs)
|
269 |
+
initial_scale = torch.tensor(initial_scale).log()
|
270 |
+
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
271 |
+
if self.bias is not None:
|
272 |
+
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
273 |
+
else:
|
274 |
+
self.register_parameter("bias_scale", None)
|
275 |
+
|
276 |
+
self._reset_parameters(
|
277 |
+
initial_speed
|
278 |
+
) # Overrides the reset_parameters in nn.Linear
|
279 |
+
|
280 |
+
def _reset_parameters(self, initial_speed: float):
|
281 |
+
std = 0.1 / initial_speed
|
282 |
+
a = (3**0.5) * std
|
283 |
+
nn.init.uniform_(self.weight, -a, a)
|
284 |
+
if self.bias is not None:
|
285 |
+
nn.init.constant_(self.bias, 0.0)
|
286 |
+
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
287 |
+
scale = fan_in**-0.5 # 1/sqrt(fan_in)
|
288 |
+
with torch.no_grad():
|
289 |
+
self.weight_scale += torch.tensor(scale / std).log()
|
290 |
+
|
291 |
+
def get_weight(self):
|
292 |
+
return self.weight * self.weight_scale.exp()
|
293 |
+
|
294 |
+
def get_bias(self):
|
295 |
+
if self.bias is None or self.bias_scale is None:
|
296 |
+
return None
|
297 |
+
else:
|
298 |
+
return self.bias * self.bias_scale.exp()
|
299 |
+
|
300 |
+
def forward(self, input: Tensor) -> Tensor:
|
301 |
+
return torch.nn.functional.linear(input, self.get_weight(), self.get_bias())
|
302 |
+
|
303 |
+
|
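A quick sketch (same assumptions as above) showing that ScaledLinear behaves like nn.Linear with an effective weight of `weight * weight_scale.exp()`:

    import torch
    from scaling import ScaledLinear  # assumed import path

    layer = ScaledLinear(256, 512, initial_scale=0.5)
    x = torch.randn(8, 256)
    y = layer(x)
    # forward() is exactly F.linear with the rescaled parameters
    assert torch.allclose(
        y, torch.nn.functional.linear(x, layer.get_weight(), layer.get_bias())
    )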
304 |
+
class ScaledConv1d(nn.Conv1d):
|
305 |
+
# See docs for ScaledLinear
|
306 |
+
def __init__(
|
307 |
+
self,
|
308 |
+
*args,
|
309 |
+
initial_scale: float = 1.0,
|
310 |
+
initial_speed: float = 1.0,
|
311 |
+
**kwargs,
|
312 |
+
):
|
313 |
+
super(ScaledConv1d, self).__init__(*args, **kwargs)
|
314 |
+
initial_scale = torch.tensor(initial_scale).log()
|
315 |
+
|
316 |
+
self.bias_scale: Optional[nn.Parameter] # for torchscript
|
317 |
+
|
318 |
+
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
319 |
+
if self.bias is not None:
|
320 |
+
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
321 |
+
else:
|
322 |
+
self.register_parameter("bias_scale", None)
|
323 |
+
self._reset_parameters(
|
324 |
+
initial_speed
|
325 |
+
) # Overrides the reset_parameters in base class
|
326 |
+
|
327 |
+
def _reset_parameters(self, initial_speed: float):
|
328 |
+
std = 0.1 / initial_speed
|
329 |
+
a = (3**0.5) * std
|
330 |
+
nn.init.uniform_(self.weight, -a, a)
|
331 |
+
if self.bias is not None:
|
332 |
+
nn.init.constant_(self.bias, 0.0)
|
333 |
+
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
334 |
+
scale = fan_in**-0.5 # 1/sqrt(fan_in)
|
335 |
+
with torch.no_grad():
|
336 |
+
self.weight_scale += torch.tensor(scale / std).log()
|
337 |
+
|
338 |
+
def get_weight(self):
|
339 |
+
return self.weight * self.weight_scale.exp()
|
340 |
+
|
341 |
+
def get_bias(self):
|
342 |
+
bias = self.bias
|
343 |
+
bias_scale = self.bias_scale
|
344 |
+
if bias is None or bias_scale is None:
|
345 |
+
return None
|
346 |
+
else:
|
347 |
+
return bias * bias_scale.exp()
|
348 |
+
|
349 |
+
def forward(self, input: Tensor) -> Tensor:
|
350 |
+
F = torch.nn.functional
|
351 |
+
if self.padding_mode != "zeros":
|
352 |
+
return F.conv1d(
|
353 |
+
F.pad(
|
354 |
+
input,
|
355 |
+
self._reversed_padding_repeated_twice,
|
356 |
+
mode=self.padding_mode,
|
357 |
+
),
|
358 |
+
self.get_weight(),
|
359 |
+
self.get_bias(),
|
360 |
+
self.stride,
|
361 |
+
(0,),
|
362 |
+
self.dilation,
|
363 |
+
self.groups,
|
364 |
+
)
|
365 |
+
return F.conv1d(
|
366 |
+
input,
|
367 |
+
self.get_weight(),
|
368 |
+
self.get_bias(),
|
369 |
+
self.stride,
|
370 |
+
self.padding,
|
371 |
+
self.dilation,
|
372 |
+
self.groups,
|
373 |
+
)
|
374 |
+
|
375 |
+
|
376 |
+
class ScaledConv2d(nn.Conv2d):
|
377 |
+
# See docs for ScaledLinear
|
378 |
+
def __init__(
|
379 |
+
self,
|
380 |
+
*args,
|
381 |
+
initial_scale: float = 1.0,
|
382 |
+
initial_speed: float = 1.0,
|
383 |
+
**kwargs,
|
384 |
+
):
|
385 |
+
super(ScaledConv2d, self).__init__(*args, **kwargs)
|
386 |
+
initial_scale = torch.tensor(initial_scale).log()
|
387 |
+
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
388 |
+
if self.bias is not None:
|
389 |
+
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
390 |
+
else:
|
391 |
+
self.register_parameter("bias_scale", None)
|
392 |
+
self._reset_parameters(
|
393 |
+
initial_speed
|
394 |
+
) # Overrides the reset_parameters in base class
|
395 |
+
|
396 |
+
def _reset_parameters(self, initial_speed: float):
|
397 |
+
std = 0.1 / initial_speed
|
398 |
+
a = (3**0.5) * std
|
399 |
+
nn.init.uniform_(self.weight, -a, a)
|
400 |
+
if self.bias is not None:
|
401 |
+
nn.init.constant_(self.bias, 0.0)
|
402 |
+
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
403 |
+
scale = fan_in**-0.5 # 1/sqrt(fan_in)
|
404 |
+
with torch.no_grad():
|
405 |
+
self.weight_scale += torch.tensor(scale / std).log()
|
406 |
+
|
407 |
+
def get_weight(self):
|
408 |
+
return self.weight * self.weight_scale.exp()
|
409 |
+
|
410 |
+
def get_bias(self):
|
411 |
+
# see https://github.com/pytorch/pytorch/issues/24135
|
412 |
+
bias = self.bias
|
413 |
+
bias_scale = self.bias_scale
|
414 |
+
if bias is None or bias_scale is None:
|
415 |
+
return None
|
416 |
+
else:
|
417 |
+
return bias * bias_scale.exp()
|
418 |
+
|
419 |
+
def _conv_forward(self, input, weight):
|
420 |
+
F = torch.nn.functional
|
421 |
+
if self.padding_mode != "zeros":
|
422 |
+
return F.conv2d(
|
423 |
+
F.pad(
|
424 |
+
input,
|
425 |
+
self._reversed_padding_repeated_twice,
|
426 |
+
mode=self.padding_mode,
|
427 |
+
),
|
428 |
+
weight,
|
429 |
+
self.get_bias(),
|
430 |
+
self.stride,
|
431 |
+
(0, 0),
|
432 |
+
self.dilation,
|
433 |
+
self.groups,
|
434 |
+
)
|
435 |
+
return F.conv2d(
|
436 |
+
input,
|
437 |
+
weight,
|
438 |
+
self.get_bias(),
|
439 |
+
self.stride,
|
440 |
+
self.padding,
|
441 |
+
self.dilation,
|
442 |
+
self.groups,
|
443 |
+
)
|
444 |
+
|
445 |
+
def forward(self, input: Tensor) -> Tensor:
|
446 |
+
return self._conv_forward(input, self.get_weight())
|
447 |
+
|
448 |
+
|
449 |
+
class ScaledLSTM(nn.LSTM):
|
450 |
+
# See docs for ScaledLinear.
|
451 |
+
# This class implements LSTM with scaling mechanism, using `torch._VF.lstm`
|
452 |
+
# Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
|
453 |
+
def __init__(
|
454 |
+
self,
|
455 |
+
*args,
|
456 |
+
initial_scale: float = 1.0,
|
457 |
+
initial_speed: float = 1.0,
|
458 |
+
grad_norm_threshold: float = 10.0,
|
459 |
+
**kwargs,
|
460 |
+
):
|
461 |
+
if "bidirectional" in kwargs:
|
462 |
+
assert kwargs["bidirectional"] is False
|
463 |
+
super(ScaledLSTM, self).__init__(*args, **kwargs)
|
464 |
+
initial_scale = torch.tensor(initial_scale).log()
|
465 |
+
self._scales_names = []
|
466 |
+
self._scales = []
|
467 |
+
for name in self._flat_weights_names:
|
468 |
+
scale_name = name + "_scale"
|
469 |
+
self._scales_names.append(scale_name)
|
470 |
+
param = nn.Parameter(initial_scale.clone().detach())
|
471 |
+
setattr(self, scale_name, param)
|
472 |
+
self._scales.append(param)
|
473 |
+
|
474 |
+
self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold)
|
475 |
+
|
476 |
+
self._reset_parameters(
|
477 |
+
initial_speed
|
478 |
+
) # Overrides the reset_parameters in base class
|
479 |
+
|
480 |
+
def _reset_parameters(self, initial_speed: float):
|
481 |
+
std = 0.1 / initial_speed
|
482 |
+
a = (3**0.5) * std
|
483 |
+
scale = self.hidden_size**-0.5
|
484 |
+
v = scale / std
|
485 |
+
for idx, name in enumerate(self._flat_weights_names):
|
486 |
+
if "weight" in name:
|
487 |
+
nn.init.uniform_(self._flat_weights[idx], -a, a)
|
488 |
+
with torch.no_grad():
|
489 |
+
self._scales[idx] += torch.tensor(v).log()
|
490 |
+
elif "bias" in name:
|
491 |
+
nn.init.constant_(self._flat_weights[idx], 0.0)
|
492 |
+
|
493 |
+
def _flatten_parameters(self, flat_weights) -> None:
|
494 |
+
"""Resets parameter data pointer so that they can use faster code paths.
|
495 |
+
|
496 |
+
Right now, this works only if the module is on the GPU and cuDNN is enabled.
|
497 |
+
Otherwise, it's a no-op.
|
498 |
+
|
499 |
+
This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
500 |
+
"""
|
501 |
+
# Short-circuits if _flat_weights is only partially instantiated
|
502 |
+
if len(flat_weights) != len(self._flat_weights_names):
|
503 |
+
return
|
504 |
+
|
505 |
+
for w in flat_weights:
|
506 |
+
if not isinstance(w, Tensor):
|
507 |
+
return
|
508 |
+
# Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
|
509 |
+
# or the tensors in flat_weights are of different dtypes
|
510 |
+
|
511 |
+
first_fw = flat_weights[0]
|
512 |
+
dtype = first_fw.dtype
|
513 |
+
for fw in flat_weights:
|
514 |
+
if (
|
515 |
+
not isinstance(fw.data, Tensor)
|
516 |
+
or not (fw.data.dtype == dtype)
|
517 |
+
or not fw.data.is_cuda
|
518 |
+
or not torch.backends.cudnn.is_acceptable(fw.data)
|
519 |
+
):
|
520 |
+
return
|
521 |
+
|
522 |
+
# If any parameters alias, we fall back to the slower, copying code path. This is
|
523 |
+
# a sufficient check, because overlapping parameter buffers that don't completely
|
524 |
+
# alias would break the assumptions of the uniqueness check in
|
525 |
+
# Module.named_parameters().
|
526 |
+
unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
|
527 |
+
if len(unique_data_ptrs) != len(flat_weights):
|
528 |
+
return
|
529 |
+
|
530 |
+
with torch.cuda.device_of(first_fw):
|
531 |
+
|
532 |
+
# Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
|
533 |
+
# an inplace operation on self._flat_weights
|
534 |
+
with torch.no_grad():
|
535 |
+
if torch._use_cudnn_rnn_flatten_weight():
|
536 |
+
num_weights = 4 if self.bias else 2
|
537 |
+
if self.proj_size > 0:
|
538 |
+
num_weights += 1
|
539 |
+
torch._cudnn_rnn_flatten_weight(
|
540 |
+
flat_weights,
|
541 |
+
num_weights,
|
542 |
+
self.input_size,
|
543 |
+
rnn.get_cudnn_mode(self.mode),
|
544 |
+
self.hidden_size,
|
545 |
+
self.proj_size,
|
546 |
+
self.num_layers,
|
547 |
+
self.batch_first,
|
548 |
+
bool(self.bidirectional),
|
549 |
+
)
|
550 |
+
|
551 |
+
def _get_flat_weights(self):
|
552 |
+
"""Get scaled weights, and resets their data pointer."""
|
553 |
+
flat_weights = []
|
554 |
+
for idx in range(len(self._flat_weights_names)):
|
555 |
+
flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp())
|
556 |
+
self._flatten_parameters(flat_weights)
|
557 |
+
return flat_weights
|
558 |
+
|
559 |
+
def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None):
|
560 |
+
# This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
561 |
+
# The change for calling `_VF.lstm()` is:
|
562 |
+
# self._flat_weights -> self._get_flat_weights()
|
563 |
+
if hx is None:
|
564 |
+
h_zeros = torch.zeros(
|
565 |
+
self.num_layers,
|
566 |
+
input.size(1),
|
567 |
+
self.proj_size if self.proj_size > 0 else self.hidden_size,
|
568 |
+
dtype=input.dtype,
|
569 |
+
device=input.device,
|
570 |
+
)
|
571 |
+
c_zeros = torch.zeros(
|
572 |
+
self.num_layers,
|
573 |
+
input.size(1),
|
574 |
+
self.hidden_size,
|
575 |
+
dtype=input.dtype,
|
576 |
+
device=input.device,
|
577 |
+
)
|
578 |
+
hx = (h_zeros, c_zeros)
|
579 |
+
|
580 |
+
self.check_forward_args(input, hx, None)
|
581 |
+
|
582 |
+
flat_weights = self._get_flat_weights()
|
583 |
+
input, *flat_weights = self.grad_filter(input, *flat_weights)
|
584 |
+
|
585 |
+
result = _VF.lstm(
|
586 |
+
input,
|
587 |
+
hx,
|
588 |
+
flat_weights,
|
589 |
+
self.bias,
|
590 |
+
self.num_layers,
|
591 |
+
self.dropout,
|
592 |
+
self.training,
|
593 |
+
self.bidirectional,
|
594 |
+
self.batch_first,
|
595 |
+
)
|
596 |
+
|
597 |
+
output = result[0]
|
598 |
+
hidden = result[1:]
|
599 |
+
return output, hidden
|
600 |
+
|
601 |
+
|
602 |
+
class ActivationBalancer(torch.nn.Module):
|
603 |
+
"""
|
604 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
605 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
606 |
+
time. It does this by multiplying negative derivative values by up to
|
607 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
608 |
+
interpolated from 1 at the threshold to those extremal values when none
|
609 |
+
of the inputs are positive.
|
610 |
+
|
611 |
+
|
612 |
+
Args:
|
613 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
614 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
615 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
616 |
+
that (x > 0), below which we start to modify the derivatives.
|
617 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
618 |
+
that (x > 0), above which we start to modify the derivatives.
|
619 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
620 |
+
either the sign constraint or the magnitude constraint;
|
621 |
+
e.g. with max_factor=0.02, the derivatives would be multiplied by
|
622 |
+
values in the range [0.98..1.02].
|
623 |
+
min_abs: the minimum average-absolute-value per channel, which
|
624 |
+
we allow, before we start to modify the derivatives to prevent
|
625 |
+
this.
|
626 |
+
max_abs: the maximum average-absolute-value per channel, which
|
627 |
+
we allow, before we start to modify the derivatives to prevent
|
628 |
+
this.
|
629 |
+
balance_prob: the probability to apply the ActivationBalancer.
|
630 |
+
"""
|
631 |
+
|
632 |
+
def __init__(
|
633 |
+
self,
|
634 |
+
channel_dim: int,
|
635 |
+
min_positive: float = 0.05,
|
636 |
+
max_positive: float = 0.95,
|
637 |
+
max_factor: float = 0.01,
|
638 |
+
min_abs: float = 0.2,
|
639 |
+
max_abs: float = 100.0,
|
640 |
+
balance_prob: float = 0.25,
|
641 |
+
):
|
642 |
+
super(ActivationBalancer, self).__init__()
|
643 |
+
self.channel_dim = channel_dim
|
644 |
+
self.min_positive = min_positive
|
645 |
+
self.max_positive = max_positive
|
646 |
+
self.max_factor = max_factor
|
647 |
+
self.min_abs = min_abs
|
648 |
+
self.max_abs = max_abs
|
649 |
+
assert 0 < balance_prob <= 1, balance_prob
|
650 |
+
self.balance_prob = balance_prob
|
651 |
+
|
652 |
+
def forward(self, x: Tensor) -> Tensor:
|
653 |
+
if random.random() >= self.balance_prob:
|
654 |
+
return x
|
655 |
+
|
656 |
+
return ActivationBalancerFunction.apply(
|
657 |
+
x,
|
658 |
+
self.channel_dim,
|
659 |
+
self.min_positive,
|
660 |
+
self.max_positive,
|
661 |
+
self.max_factor / self.balance_prob,
|
662 |
+
self.min_abs,
|
663 |
+
self.max_abs,
|
664 |
+
)
|
665 |
+
|
666 |
+
|
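A minimal sketch (same assumptions as above) of how ActivationBalancer is used: the forward pass leaves the values untouched, and only the backward pass nudges gradients so that each channel's sign proportion and average magnitude stay inside the configured range.

    import torch
    from scaling import ActivationBalancer  # assumed import path

    balancer = ActivationBalancer(channel_dim=-1, min_positive=0.05, max_positive=0.95)
    x = torch.randn(4, 16, requires_grad=True)
    y = balancer(x)
    assert torch.equal(y, x)  # values are unchanged in the forward pass
    y.sum().backward()        # x.grad is a gently modified version of all-ones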
667 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
668 |
+
"""
|
669 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
670 |
+
This is a definition, originally motivated by its close numerical
|
671 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
672 |
+
|
673 |
+
Memory-efficient derivative computation:
|
674 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
675 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
676 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
677 |
+
double_swish'(x) = x * s'(x) + s(x).
|
678 |
+
= x * s(x) * (1-s(x)) + s(x).
|
679 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
680 |
+
... so we just need to remember s(x) but not x itself.
|
681 |
+
"""
|
682 |
+
|
683 |
+
@staticmethod
|
684 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
685 |
+
x = x.detach()
|
686 |
+
s = torch.sigmoid(x - 1.0)
|
687 |
+
y = x * s
|
688 |
+
ctx.save_for_backward(s, y)
|
689 |
+
return y
|
690 |
+
|
691 |
+
@staticmethod
|
692 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
693 |
+
s, y = ctx.saved_tensors
|
694 |
+
return (y * (1 - s) + s) * y_grad
|
695 |
+
|
696 |
+
|
697 |
+
class DoubleSwish(torch.nn.Module):
|
698 |
+
def forward(self, x: Tensor) -> Tensor:
|
699 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
700 |
+
that we approximate closely with x * sigmoid(x-1).
|
701 |
+
"""
|
702 |
+
if torch.jit.is_scripting() or is_jit_tracing():
|
703 |
+
return x * torch.sigmoid(x - 1.0)
|
704 |
+
else:
|
705 |
+
return DoubleSwishFunction.apply(x)
|
706 |
+
|
707 |
+
|
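A short numeric check (not part of the original file) of the derivative formula given in the DoubleSwishFunction docstring, double_swish'(x) = double_swish(x) * (1 - s(x)) + s(x) with s(x) = sigmoid(x - 1):

    import torch
    from scaling import DoubleSwish  # assumed import path

    x = torch.linspace(-4.0, 4.0, 9, dtype=torch.double, requires_grad=True)
    (g,) = torch.autograd.grad(DoubleSwish()(x).sum(), x)

    s = torch.sigmoid(x - 1.0)
    expected = x * s * (1 - s) + s  # same as y * (1 - s) + s with y = x * s
    print(torch.allclose(g, expected))  # True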
708 |
+
class ScaledEmbedding(nn.Module):
|
709 |
+
r"""This is a modified version of nn.Embedding that introduces a learnable scale
|
710 |
+
on the parameters. Note: due to how we initialize it, it's best used with
|
711 |
+
schedulers like Noam that have a warmup period.
|
712 |
+
|
713 |
+
It is a simple lookup table that stores embeddings of a fixed dictionary and size.
|
714 |
+
|
715 |
+
This module is often used to store word embeddings and retrieve them using indices.
|
716 |
+
The input to the module is a list of indices, and the output is the corresponding
|
717 |
+
word embeddings.
|
718 |
+
|
719 |
+
Args:
|
720 |
+
num_embeddings (int): size of the dictionary of embeddings
|
721 |
+
embedding_dim (int): the size of each embedding vector
|
722 |
+
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
723 |
+
(initialized to zeros) whenever it encounters the index.
|
724 |
+
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
725 |
+
the words in the mini-batch. Default ``False``.
|
726 |
+
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
727 |
+
See Notes for more details regarding sparse gradients.
|
728 |
+
|
729 |
+
initial_speed (float, optional): This affects how fast the parameter will
|
730 |
+
learn near the start of training; you can set it to a value less than
|
731 |
+
one if you suspect that a module is contributing to instability near
|
732 |
+
the start of training. Note: regardless of the use of this option,
|
733 |
+
it's best to use schedulers like Noam that have a warm-up period.
|
734 |
+
Alternatively you can set it to more than 1 if you want it to
|
735 |
+
initially train faster. Must be greater than 0.
|
736 |
+
|
737 |
+
|
738 |
+
Attributes:
|
739 |
+
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
740 |
+
initialized from :math:`\mathcal{N}(0, 1)`
|
741 |
+
|
742 |
+
Shape:
|
743 |
+
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
|
744 |
+
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
745 |
+
|
746 |
+
.. note::
|
747 |
+
Keep in mind that only a limited number of optimizers support
|
748 |
+
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
|
749 |
+
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
|
750 |
+
|
751 |
+
.. note::
|
752 |
+
With :attr:`padding_idx` set, the embedding vector at
|
753 |
+
:attr:`padding_idx` is initialized to all zeros. However, note that this
|
754 |
+
vector can be modified afterwards, e.g., using a customized
|
755 |
+
initialization method, and thus changing the vector used to pad the
|
756 |
+
output. The gradient for this vector from :class:`~torch.nn.Embedding`
|
757 |
+
is always zero.
|
758 |
+
|
759 |
+
Examples::
|
760 |
+
|
761 |
+
>>> # an Embedding module containing 10 tensors of size 3
|
762 |
+
>>> embedding = nn.Embedding(10, 3)
|
763 |
+
>>> # a batch of 2 samples of 4 indices each
|
764 |
+
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
765 |
+
>>> embedding(input)
|
766 |
+
tensor([[[-0.0251, -1.6902, 0.7172],
|
767 |
+
[-0.6431, 0.0748, 0.6969],
|
768 |
+
[ 1.4970, 1.3448, -0.9685],
|
769 |
+
[-0.3677, -2.7265, -0.1685]],
|
770 |
+
|
771 |
+
[[ 1.4970, 1.3448, -0.9685],
|
772 |
+
[ 0.4362, -0.4004, 0.9400],
|
773 |
+
[-0.6431, 0.0748, 0.6969],
|
774 |
+
[ 0.9124, -2.3616, 1.1151]]])
|
775 |
+
|
776 |
+
|
777 |
+
>>> # example with padding_idx
|
778 |
+
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
779 |
+
>>> input = torch.LongTensor([[0,2,0,5]])
|
780 |
+
>>> embedding(input)
|
781 |
+
tensor([[[ 0.0000, 0.0000, 0.0000],
|
782 |
+
[ 0.1535, -2.0309, 0.9315],
|
783 |
+
[ 0.0000, 0.0000, 0.0000],
|
784 |
+
[-0.1655, 0.9897, 0.0635]]])
|
785 |
+
|
786 |
+
"""
|
787 |
+
__constants__ = [
|
788 |
+
"num_embeddings",
|
789 |
+
"embedding_dim",
|
790 |
+
"padding_idx",
|
791 |
+
"scale_grad_by_freq",
|
792 |
+
"sparse",
|
793 |
+
]
|
794 |
+
|
795 |
+
num_embeddings: int
|
796 |
+
embedding_dim: int
|
797 |
+
padding_idx: int
|
798 |
+
scale_grad_by_freq: bool
|
799 |
+
weight: Tensor
|
800 |
+
sparse: bool
|
801 |
+
|
802 |
+
def __init__(
|
803 |
+
self,
|
804 |
+
num_embeddings: int,
|
805 |
+
embedding_dim: int,
|
806 |
+
padding_idx: Optional[int] = None,
|
807 |
+
scale_grad_by_freq: bool = False,
|
808 |
+
sparse: bool = False,
|
809 |
+
initial_speed: float = 1.0,
|
810 |
+
) -> None:
|
811 |
+
super(ScaledEmbedding, self).__init__()
|
812 |
+
self.num_embeddings = num_embeddings
|
813 |
+
self.embedding_dim = embedding_dim
|
814 |
+
if padding_idx is not None:
|
815 |
+
if padding_idx > 0:
|
816 |
+
assert (
|
817 |
+
padding_idx < self.num_embeddings
|
818 |
+
), "Padding_idx must be within num_embeddings"
|
819 |
+
elif padding_idx < 0:
|
820 |
+
assert (
|
821 |
+
padding_idx >= -self.num_embeddings
|
822 |
+
), "Padding_idx must be within num_embeddings"
|
823 |
+
padding_idx = self.num_embeddings + padding_idx
|
824 |
+
self.padding_idx = padding_idx
|
825 |
+
self.scale_grad_by_freq = scale_grad_by_freq
|
826 |
+
|
827 |
+
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
828 |
+
self.sparse = sparse
|
829 |
+
|
830 |
+
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
831 |
+
self.reset_parameters(initial_speed)
|
832 |
+
|
833 |
+
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
834 |
+
std = 0.1 / initial_speed
|
835 |
+
nn.init.normal_(self.weight, std=std)
|
836 |
+
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
|
837 |
+
|
838 |
+
if self.padding_idx is not None:
|
839 |
+
with torch.no_grad():
|
840 |
+
self.weight[self.padding_idx].fill_(0)
|
841 |
+
|
842 |
+
def forward(self, input: Tensor) -> Tensor:
|
843 |
+
F = torch.nn.functional
|
844 |
+
scale = self.scale.exp()
|
845 |
+
if input.numel() < self.num_embeddings:
|
846 |
+
return (
|
847 |
+
F.embedding(
|
848 |
+
input,
|
849 |
+
self.weight,
|
850 |
+
self.padding_idx,
|
851 |
+
None,
|
852 |
+
2.0, # None, 2.0 relate to normalization
|
853 |
+
self.scale_grad_by_freq,
|
854 |
+
self.sparse,
|
855 |
+
)
|
856 |
+
* scale
|
857 |
+
)
|
858 |
+
else:
|
859 |
+
return F.embedding(
|
860 |
+
input,
|
861 |
+
self.weight * scale,
|
862 |
+
self.padding_idx,
|
863 |
+
None,
|
864 |
+
2.0, # None, 2.0 relates to normalization
|
865 |
+
self.scale_grad_by_freq,
|
866 |
+
self.sparse,
|
867 |
+
)
|
868 |
+
|
869 |
+
def extra_repr(self) -> str:
|
870 |
+
# s = "{num_embeddings}, {embedding_dim}, scale={scale}"
|
871 |
+
s = "{num_embeddings}, {embedding_dim}"
|
872 |
+
if self.padding_idx is not None:
|
873 |
+
s += ", padding_idx={padding_idx}"
|
874 |
+
if self.scale_grad_by_freq is not False:
|
875 |
+
s += ", scale_grad_by_freq={scale_grad_by_freq}"
|
876 |
+
if self.sparse is not False:
|
877 |
+
s += ", sparse=True"
|
878 |
+
return s.format(**self.__dict__)
|
879 |
+
|
880 |
+
|
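A minimal sketch (same assumptions as above) of ScaledEmbedding: lookups behave like nn.Embedding scaled by `scale.exp()`, and the row at `padding_idx` stays all zeros.

    import torch
    from scaling import ScaledEmbedding  # assumed import path

    emb = ScaledEmbedding(num_embeddings=500, embedding_dim=512, padding_idx=0)
    tokens = torch.tensor([[0, 3, 17, 42]])
    out = emb(tokens)
    print(out.shape)              # torch.Size([1, 4, 512])
    print(out[0, 0].abs().sum())  # zero, because the padding row is zero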
881 |
+
def _test_activation_balancer_sign():
|
882 |
+
probs = torch.arange(0, 1, 0.01)
|
883 |
+
N = 1000
|
884 |
+
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
885 |
+
x = x.detach()
|
886 |
+
x.requires_grad = True
|
887 |
+
m = ActivationBalancer(
|
888 |
+
channel_dim=0,
|
889 |
+
min_positive=0.05,
|
890 |
+
max_positive=0.95,
|
891 |
+
max_factor=0.2,
|
892 |
+
min_abs=0.0,
|
893 |
+
)
|
894 |
+
|
895 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
896 |
+
|
897 |
+
y = m(x)
|
898 |
+
y.backward(gradient=y_grad)
|
899 |
+
print("_test_activation_balancer_sign: x = ", x)
|
900 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
901 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
902 |
+
|
903 |
+
|
904 |
+
def _test_activation_balancer_magnitude():
|
905 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
906 |
+
N = 1000
|
907 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
|
908 |
+
x = x.detach()
|
909 |
+
x.requires_grad = True
|
910 |
+
m = ActivationBalancer(
|
911 |
+
channel_dim=0,
|
912 |
+
min_positive=0.0,
|
913 |
+
max_positive=1.0,
|
914 |
+
max_factor=0.2,
|
915 |
+
min_abs=0.2,
|
916 |
+
max_abs=0.8,
|
917 |
+
)
|
918 |
+
|
919 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
920 |
+
|
921 |
+
y = m(x)
|
922 |
+
y.backward(gradient=y_grad)
|
923 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
924 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
925 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
926 |
+
|
927 |
+
|
928 |
+
def _test_basic_norm():
|
929 |
+
num_channels = 128
|
930 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
931 |
+
|
932 |
+
x = torch.randn(500, num_channels)
|
933 |
+
|
934 |
+
y = m(x)
|
935 |
+
|
936 |
+
assert y.shape == x.shape
|
937 |
+
x_rms = (x**2).mean().sqrt()
|
938 |
+
y_rms = (y**2).mean().sqrt()
|
939 |
+
print("x rms = ", x_rms)
|
940 |
+
print("y rms = ", y_rms)
|
941 |
+
assert y_rms < x_rms
|
942 |
+
assert y_rms > 0.5 * x_rms
|
943 |
+
|
944 |
+
|
945 |
+
def _test_double_swish_deriv():
|
946 |
+
x = torch.randn(10, 12, dtype=torch.double) * 0.5
|
947 |
+
x.requires_grad = True
|
948 |
+
m = DoubleSwish()
|
949 |
+
torch.autograd.gradcheck(m, x)
|
950 |
+
|
951 |
+
|
952 |
+
def _test_scaled_lstm():
|
953 |
+
N, L = 2, 30
|
954 |
+
dim_in, dim_hidden = 10, 20
|
955 |
+
m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True)
|
956 |
+
x = torch.randn(L, N, dim_in)
|
957 |
+
h0 = torch.randn(1, N, dim_hidden)
|
958 |
+
c0 = torch.randn(1, N, dim_hidden)
|
959 |
+
y, (h, c) = m(x, (h0, c0))
|
960 |
+
assert y.shape == (L, N, dim_hidden)
|
961 |
+
assert h.shape == (1, N, dim_hidden)
|
962 |
+
assert c.shape == (1, N, dim_hidden)
|
963 |
+
|
964 |
+
|
965 |
+
def _test_grad_filter():
|
966 |
+
threshold = 50.0
|
967 |
+
time, batch, channel = 200, 5, 128
|
968 |
+
grad_filter = GradientFilter(batch_dim=1, threshold=threshold)
|
969 |
+
|
970 |
+
for i in range(2):
|
971 |
+
x = torch.randn(time, batch, channel, requires_grad=True)
|
972 |
+
w = nn.Parameter(torch.ones(5))
|
973 |
+
b = nn.Parameter(torch.zeros(5))
|
974 |
+
|
975 |
+
x_out, w_out, b_out = grad_filter(x, w, b)
|
976 |
+
|
977 |
+
w_out_grad = torch.randn_like(w)
|
978 |
+
b_out_grad = torch.randn_like(b)
|
979 |
+
x_out_grad = torch.rand_like(x)
|
980 |
+
if i % 2 == 1:
|
981 |
+
# The gradient norm of the first element must be larger than
|
982 |
+
# `threshold * median`, where `median` is the median value
|
983 |
+
# of gradient norms of all elements in batch.
|
984 |
+
x_out_grad[:, 0, :] = torch.full((time, channel), threshold)
|
985 |
+
|
986 |
+
torch.autograd.backward(
|
987 |
+
[x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad]
|
988 |
+
)
|
989 |
+
|
990 |
+
print(
|
991 |
+
"_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa
|
992 |
+
i % 2 == 1,
|
993 |
+
)
|
994 |
+
|
995 |
+
print(
|
996 |
+
"_test_grad_filter: x_out_grad norm = ",
|
997 |
+
(x_out_grad**2).mean(dim=(0, 2)).sqrt(),
|
998 |
+
)
|
999 |
+
print(
|
1000 |
+
"_test_grad_filter: x.grad norm = ",
|
1001 |
+
(x.grad**2).mean(dim=(0, 2)).sqrt(),
|
1002 |
+
)
|
1003 |
+
print("_test_grad_filter: w_out_grad = ", w_out_grad)
|
1004 |
+
print("_test_grad_filter: w.grad = ", w.grad)
|
1005 |
+
print("_test_grad_filter: b_out_grad = ", b_out_grad)
|
1006 |
+
print("_test_grad_filter: b.grad = ", b.grad)
|
1007 |
+
|
1008 |
+
|
1009 |
+
if __name__ == "__main__":
|
1010 |
+
_test_activation_balancer_sign()
|
1011 |
+
_test_activation_balancer_magnitude()
|
1012 |
+
_test_basic_norm()
|
1013 |
+
_test_double_swish_deriv()
|
1014 |
+
_test_scaled_lstm()
|
1015 |
+
_test_grad_filter()
|
err2020/conformer_ctc3/test_model.py
ADDED
@@ -0,0 +1,82 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
"""
|
20 |
+
To run this file, do:
|
21 |
+
|
22 |
+
cd icefall/egs/librispeech/ASR
|
23 |
+
python ./conformer_ctc3/test_model.py
|
24 |
+
"""
|
25 |
+
|
26 |
+
import torch
|
27 |
+
|
28 |
+
from train import get_params, get_ctc_model
|
29 |
+
|
30 |
+
|
31 |
+
def test_model():
|
32 |
+
params = get_params()
|
33 |
+
params.vocab_size = 500
|
34 |
+
params.blank_id = 0
|
35 |
+
params.context_size = 2
|
36 |
+
params.unk_id = 2
|
37 |
+
|
38 |
+
params.dynamic_chunk_training = False
|
39 |
+
params.short_chunk_size = 25
|
40 |
+
params.num_left_chunks = 4
|
41 |
+
params.causal_convolution = False
|
42 |
+
|
43 |
+
model = get_ctc_model(params)
|
44 |
+
|
45 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
46 |
+
print(f"Number of model parameters: {num_param}")
|
47 |
+
|
48 |
+
features = torch.randn(2, 100, 80)
|
49 |
+
feature_lengths = torch.full((2,), 100)
|
50 |
+
model(x=features, x_lens=feature_lengths)
|
51 |
+
|
52 |
+
|
53 |
+
def test_model_streaming():
|
54 |
+
params = get_params()
|
55 |
+
params.vocab_size = 500
|
56 |
+
params.blank_id = 0
|
57 |
+
params.context_size = 2
|
58 |
+
params.unk_id = 2
|
59 |
+
|
60 |
+
params.dynamic_chunk_training = True
|
61 |
+
params.short_chunk_size = 25
|
62 |
+
params.num_left_chunks = 4
|
63 |
+
params.causal_convolution = True
|
64 |
+
|
65 |
+
model = get_ctc_model(params)
|
66 |
+
|
67 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
68 |
+
print(f"Number of model parameters: {num_param}")
|
69 |
+
|
70 |
+
features = torch.randn(2, 100, 80)
|
71 |
+
feature_lengths = torch.full((2,), 100)
|
72 |
+
encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths)
|
73 |
+
model.get_ctc_output(encoder_out)
|
74 |
+
|
75 |
+
|
76 |
+
def main():
|
77 |
+
test_model()
|
78 |
+
test_model_streaming()
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
main()
|
err2020/conformer_ctc3/train.py
ADDED
@@ -0,0 +1,1109 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
3 |
+
# Wei Kang,
|
4 |
+
# Mingshuang Luo,)
|
5 |
+
# Zengwei Yao)
|
6 |
+
#
|
7 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""
|
21 |
+
Usage:
|
22 |
+
|
23 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
24 |
+
|
25 |
+
./conformer_ctc3/train.py \
|
26 |
+
--world-size 4 \
|
27 |
+
--num-epochs 30 \
|
28 |
+
--start-epoch 1 \
|
29 |
+
--exp-dir conformer_ctc3/exp \
|
30 |
+
--full-libri 1 \
|
31 |
+
--max-duration 300
|
32 |
+
|
33 |
+
# For mix precision training:
|
34 |
+
|
35 |
+
./conformer_ctc3/train.py \
|
36 |
+
--world-size 4 \
|
37 |
+
--num-epochs 30 \
|
38 |
+
--start-epoch 1 \
|
39 |
+
--use-fp16 1 \
|
40 |
+
--exp-dir conformer_ctc3/exp \
|
41 |
+
--full-libri 1 \
|
42 |
+
--max-duration 550
|
43 |
+
|
44 |
+
# train a streaming model
|
45 |
+
./conformer_ctc3/train.py \
|
46 |
+
--world-size 4 \
|
47 |
+
--num-epochs 30 \
|
48 |
+
--start-epoch 1 \
|
49 |
+
--exp-dir conformer_ctc3/exp \
|
50 |
+
--full-libri 1 \
|
51 |
+
--dynamic-chunk-training 1 \
|
52 |
+
--causal-convolution 1 \
|
53 |
+
--short-chunk-size 25 \
|
54 |
+
--num-left-chunks 4 \
|
55 |
+
--max-duration 300 \
|
56 |
+
--delay-penalty 0.0
|
57 |
+
"""
|
58 |
+
|
59 |
+
import argparse
|
60 |
+
import copy
|
61 |
+
import logging
|
62 |
+
from pathlib import Path
|
63 |
+
from shutil import copyfile
|
64 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
65 |
+
|
66 |
+
import k2
|
67 |
+
import optim
|
68 |
+
import torch
|
69 |
+
import torch.multiprocessing as mp
|
70 |
+
import torch.nn as nn
|
71 |
+
from asr_datamodule import LibriSpeechAsrDataModule
|
72 |
+
from conformer import Conformer
|
73 |
+
from lhotse.cut import Cut
|
74 |
+
from lhotse.dataset.sampling.base import CutSampler
|
75 |
+
from lhotse.utils import fix_random_seed
|
76 |
+
from model import CTCModel
|
77 |
+
from optim import Eden, Eve
|
78 |
+
from torch import Tensor
|
79 |
+
from torch.cuda.amp import GradScaler
|
80 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
81 |
+
from torch.utils.tensorboard import SummaryWriter
|
82 |
+
|
83 |
+
from icefall import diagnostics
|
84 |
+
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
85 |
+
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
86 |
+
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
87 |
+
from icefall.checkpoint import (
|
88 |
+
save_checkpoint_with_global_batch_idx,
|
89 |
+
update_averaged_model,
|
90 |
+
)
|
91 |
+
from icefall.dist import cleanup_dist, setup_dist
|
92 |
+
from icefall.env import get_env_info
|
93 |
+
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
94 |
+
from icefall.lexicon import Lexicon
|
95 |
+
from icefall.utils import (
|
96 |
+
AttributeDict,
|
97 |
+
MetricsTracker,
|
98 |
+
encode_supervisions,
|
99 |
+
setup_logger,
|
100 |
+
str2bool,
|
101 |
+
)
|
102 |
+
|
103 |
+
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
104 |
+
|
105 |
+
|
106 |
+
def add_model_arguments(parser: argparse.ArgumentParser):
|
107 |
+
parser.add_argument(
|
108 |
+
"--dynamic-chunk-training",
|
109 |
+
type=str2bool,
|
110 |
+
default=False,
|
111 |
+
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
112 |
+
model, this requires to be True.
|
113 |
+
""",
|
114 |
+
)
|
115 |
+
|
116 |
+
parser.add_argument(
|
117 |
+
"--causal-convolution",
|
118 |
+
type=str2bool,
|
119 |
+
default=False,
|
120 |
+
help="""Whether to use causal convolution, this requires to be True when
|
121 |
+
using dynamic_chunk_training.
|
122 |
+
""",
|
123 |
+
)
|
124 |
+
|
125 |
+
parser.add_argument(
|
126 |
+
"--short-chunk-size",
|
127 |
+
type=int,
|
128 |
+
default=25,
|
129 |
+
help="""Chunk length of dynamic training, the chunk size would be either
|
130 |
+
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
131 |
+
""",
|
132 |
+
)
|
133 |
+
|
134 |
+
parser.add_argument(
|
135 |
+
"--num-left-chunks",
|
136 |
+
type=int,
|
137 |
+
default=4,
|
138 |
+
help="How many left context can be seen in chunks when calculating attention.",
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
+
def get_parser():
|
143 |
+
parser = argparse.ArgumentParser(
|
144 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
145 |
+
)
|
146 |
+
|
147 |
+
parser.add_argument(
|
148 |
+
"--world-size",
|
149 |
+
type=int,
|
150 |
+
default=1,
|
151 |
+
help="Number of GPUs for DDP training.",
|
152 |
+
)
|
153 |
+
|
154 |
+
parser.add_argument(
|
155 |
+
"--master-port",
|
156 |
+
type=int,
|
157 |
+
default=12354,
|
158 |
+
help="Master port to use for DDP training.",
|
159 |
+
)
|
160 |
+
|
161 |
+
parser.add_argument(
|
162 |
+
"--tensorboard",
|
163 |
+
type=str2bool,
|
164 |
+
default=True,
|
165 |
+
help="Should various information be logged in tensorboard.",
|
166 |
+
)
|
167 |
+
|
168 |
+
parser.add_argument(
|
169 |
+
"--num-epochs",
|
170 |
+
type=int,
|
171 |
+
default=30,
|
172 |
+
help="Number of epochs to train.",
|
173 |
+
)
|
174 |
+
|
175 |
+
parser.add_argument(
|
176 |
+
"--start-epoch",
|
177 |
+
type=int,
|
178 |
+
default=1,
|
179 |
+
help="""Resume training from this epoch. It should be positive.
|
180 |
+
If larger than 1, it will load checkpoint from
|
181 |
+
exp-dir/epoch-{start_epoch-1}.pt
|
182 |
+
""",
|
183 |
+
)
|
184 |
+
|
185 |
+
parser.add_argument(
|
186 |
+
"--start-batch",
|
187 |
+
type=int,
|
188 |
+
default=0,
|
189 |
+
help="""If positive, --start-epoch is ignored and
|
190 |
+
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
|
191 |
+
""",
|
192 |
+
)
|
193 |
+
|
194 |
+
parser.add_argument(
|
195 |
+
"--exp-dir",
|
196 |
+
type=str,
|
197 |
+
default="conformer_ctc3/exp",
|
198 |
+
help="""The experiment dir.
|
199 |
+
It specifies the directory where all training related
|
200 |
+
files, e.g., checkpoints, log, etc, are saved
|
201 |
+
""",
|
202 |
+
)
|
203 |
+
|
204 |
+
parser.add_argument(
|
205 |
+
"--lang-dir",
|
206 |
+
type=str,
|
207 |
+
default="data/lang_bpe_500",
|
208 |
+
help="""The lang dir
|
209 |
+
It contains language related input files such as
|
210 |
+
"lexicon.txt"
|
211 |
+
""",
|
212 |
+
)
|
213 |
+
|
214 |
+
parser.add_argument(
|
215 |
+
"--initial-lr",
|
216 |
+
type=float,
|
217 |
+
default=0.003,
|
218 |
+
help="""The initial learning rate. This value should not need to be
|
219 |
+
changed.""",
|
220 |
+
)
|
221 |
+
|
222 |
+
parser.add_argument(
|
223 |
+
"--lr-batches",
|
224 |
+
type=float,
|
225 |
+
default=5000,
|
226 |
+
help="""Number of steps that affects how rapidly the learning rate decreases.
|
227 |
+
We suggest not to change this.""",
|
228 |
+
)
|
229 |
+
|
230 |
+
parser.add_argument(
|
231 |
+
"--lr-epochs",
|
232 |
+
type=float,
|
233 |
+
default=6,
|
234 |
+
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
235 |
+
""",
|
236 |
+
)
|
237 |
+
|
238 |
+
parser.add_argument(
|
239 |
+
"--seed",
|
240 |
+
type=int,
|
241 |
+
default=42,
|
242 |
+
help="The seed for random generators intended for reproducibility",
|
243 |
+
)
|
244 |
+
|
245 |
+
parser.add_argument(
|
246 |
+
"--print-diagnostics",
|
247 |
+
type=str2bool,
|
248 |
+
default=False,
|
249 |
+
help="Accumulate stats on activations, print them and exit.",
|
250 |
+
)
|
251 |
+
|
252 |
+
parser.add_argument(
|
253 |
+
"--save-every-n",
|
254 |
+
type=int,
|
255 |
+
default=8000,
|
256 |
+
help="""Save checkpoint after processing this number of batches"
|
257 |
+
periodically. We save checkpoint to exp-dir/ whenever
|
258 |
+
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
259 |
+
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
260 |
+
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
261 |
+
end of each epoch where `xxx` is the epoch number counting from 0.
|
262 |
+
""",
|
263 |
+
)
|
264 |
+
|
265 |
+
parser.add_argument(
|
266 |
+
"--keep-last-k",
|
267 |
+
type=int,
|
268 |
+
default=20,
|
269 |
+
help="""Only keep this number of checkpoints on disk.
|
270 |
+
For instance, if it is 3, there are only 3 checkpoints
|
271 |
+
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
272 |
+
It does not affect checkpoints with name `epoch-xxx.pt`.
|
273 |
+
""",
|
274 |
+
)
|
275 |
+
|
276 |
+
parser.add_argument(
|
277 |
+
"--average-period",
|
278 |
+
type=int,
|
279 |
+
default=100,
|
280 |
+
help="""Update the averaged model, namely `model_avg`, after processing
|
281 |
+
this number of batches. `model_avg` is a separate version of model,
|
282 |
+
in which each floating-point parameter is the average of all the
|
283 |
+
parameters from the start of training. Each time we take the average,
|
284 |
+
we do: `model_avg = model * (average_period / batch_idx_train) +
|
285 |
+
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
286 |
+
""",
|
287 |
+
)
|
288 |
+
|
289 |
+
parser.add_argument(
|
290 |
+
"--use-fp16",
|
291 |
+
type=str2bool,
|
292 |
+
default=False,
|
293 |
+
help="Whether to use half precision training.",
|
294 |
+
)
|
295 |
+
|
296 |
+
parser.add_argument(
|
297 |
+
"--delay-penalty",
|
298 |
+
type=float,
|
299 |
+
default=0.0,
|
300 |
+
help="""A constant used to scale the symbol delay penalty,
|
301 |
+
to encourage symbol emit earlier for streaming models.
|
302 |
+
It is almost the same as the `delay_penalty` in our `rnnt_loss`, See
|
303 |
+
https://github.com/k2-fsa/k2/issues/955 and
|
304 |
+
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
|
305 |
+
)
|
306 |
+
|
307 |
+
parser.add_argument(
|
308 |
+
"--nnet-delay-penalty",
|
309 |
+
type=float,
|
310 |
+
default=0.0,
|
311 |
+
help="""A constant to penalize symbol delay, which is applied on
|
312 |
+
the nnet_output after log-softmax.
|
313 |
+
We recommend using --delay-penalty instead.
|
314 |
+
See https://github.com/k2-fsa/icefall/pull/669 for details.""",
|
315 |
+
)
|
316 |
+
|
317 |
+
add_model_arguments(parser)
|
318 |
+
|
319 |
+
return parser
|
320 |
+
|
321 |
+
|
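The `--average-period` help above describes a running average of model parameters. A tiny standalone sketch of that update rule (the real implementation lives in `icefall.checkpoint.update_averaged_model`; this is just the formula, with hypothetical names):

    def averaged_param(p_model, p_avg, batch_idx_train, average_period):
        # model_avg = model * (average_period / batch_idx_train)
        #           + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
        w = average_period / batch_idx_train
        return p_model * w + p_avg * (1.0 - w)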
322 |
+
def get_params() -> AttributeDict:
|
323 |
+
"""Return a dict containing training parameters.
|
324 |
+
|
325 |
+
All training related parameters that are not passed from the commandline
|
326 |
+
are saved in the variable `params`.
|
327 |
+
|
328 |
+
Commandline options are merged into `params` after they are parsed, so
|
329 |
+
you can also access them via `params`.
|
330 |
+
|
331 |
+
Explanation of options saved in `params`:
|
332 |
+
|
333 |
+
- best_train_loss: Best training loss so far. It is used to select
|
334 |
+
the model that has the lowest training loss. It is
|
335 |
+
updated during the training.
|
336 |
+
|
337 |
+
- best_valid_loss: Best validation loss so far. It is used to select
|
338 |
+
the model that has the lowest validation loss. It is
|
339 |
+
updated during the training.
|
340 |
+
|
341 |
+
- best_train_epoch: It is the epoch that has the best training loss.
|
342 |
+
|
343 |
+
- best_valid_epoch: It is the epoch that has the best validation loss.
|
344 |
+
|
345 |
+
- batch_idx_train: Used for writing statistics to tensorboard. It
|
346 |
+
contains number of batches trained so far across
|
347 |
+
epochs.
|
348 |
+
|
349 |
+
- log_interval: Print training loss if batch_idx % log_interval is 0
|
350 |
+
|
351 |
+
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
352 |
+
|
353 |
+
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
354 |
+
|
355 |
+
- feature_dim: The model input dim. It has to match the one used
|
356 |
+
in computing features.
|
357 |
+
|
358 |
+
- subsampling_factor: The subsampling factor for the model.
|
359 |
+
|
360 |
+
- encoder_dim: Hidden dim for multi-head attention model.
|
361 |
+
|
362 |
+
- num_decoder_layers: Number of decoder layers of the transformer decoder.
|
363 |
+
|
364 |
+
- warm_step: The warm_step for Noam optimizer.
|
365 |
+
"""
|
366 |
+
params = AttributeDict(
|
367 |
+
{
|
368 |
+
"best_train_loss": float("inf"),
|
369 |
+
"best_valid_loss": float("inf"),
|
370 |
+
"best_train_epoch": -1,
|
371 |
+
"best_valid_epoch": -1,
|
372 |
+
"batch_idx_train": 0,
|
373 |
+
"log_interval": 50,
|
374 |
+
"reset_interval": 200,
|
375 |
+
"valid_interval": 3000, # For the 100h subset, use 800
|
376 |
+
# parameters for conformer
|
377 |
+
"feature_dim": 80,
|
378 |
+
"subsampling_factor": 4,
|
379 |
+
"encoder_dim": 512,
|
380 |
+
"nhead": 8,
|
381 |
+
"dim_feedforward": 2048,
|
382 |
+
"num_encoder_layers": 12,
|
383 |
+
# parameters for loss
|
384 |
+
"beam_size": 10,
|
385 |
+
"reduction": "none",
|
386 |
+
"use_double_scores": True,
|
387 |
+
# parameters for Noam
|
388 |
+
"model_warm_step": 3000, # arg given to model, not for lrate
|
389 |
+
"env_info": get_env_info(),
|
390 |
+
}
|
391 |
+
)
|
392 |
+
|
393 |
+
return params
|
394 |
+
|
395 |
+
|
396 |
+
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
397 |
+
# TODO: We can add an option to switch between Conformer and Transformer
|
398 |
+
encoder = Conformer(
|
399 |
+
num_features=params.feature_dim,
|
400 |
+
subsampling_factor=params.subsampling_factor,
|
401 |
+
d_model=params.encoder_dim,
|
402 |
+
nhead=params.nhead,
|
403 |
+
dim_feedforward=params.dim_feedforward,
|
404 |
+
num_encoder_layers=params.num_encoder_layers,
|
405 |
+
dynamic_chunk_training=params.dynamic_chunk_training,
|
406 |
+
short_chunk_size=params.short_chunk_size,
|
407 |
+
num_left_chunks=params.num_left_chunks,
|
408 |
+
causal=params.causal_convolution,
|
409 |
+
)
|
410 |
+
return encoder
|
411 |
+
|
412 |
+
|
413 |
+
def get_ctc_model(params: AttributeDict) -> nn.Module:
|
414 |
+
encoder = get_encoder_model(params)
|
415 |
+
model = CTCModel(
|
416 |
+
encoder=encoder,
|
417 |
+
encoder_dim=params.encoder_dim,
|
418 |
+
vocab_size=params.vocab_size,
|
419 |
+
)
|
420 |
+
return model
|
421 |
+
|
422 |
+
|
423 |
+
def load_checkpoint_if_available(
|
424 |
+
params: AttributeDict,
|
425 |
+
model: nn.Module,
|
426 |
+
model_avg: nn.Module = None,
|
427 |
+
optimizer: Optional[torch.optim.Optimizer] = None,
|
428 |
+
scheduler: Optional[LRSchedulerType] = None,
|
429 |
+
) -> Optional[Dict[str, Any]]:
|
430 |
+
"""Load checkpoint from file.
|
431 |
+
|
432 |
+
If params.start_batch is positive, it will load the checkpoint from
|
433 |
+
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
434 |
+
params.start_epoch is larger than 1, it will load the checkpoint from
|
435 |
+
`params.start_epoch - 1`.
|
436 |
+
|
437 |
+
Apart from loading state dict for `model` and `optimizer` it also updates
|
438 |
+
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
439 |
+
and `best_valid_loss` in `params`.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
params:
|
443 |
+
The return value of :func:`get_params`.
|
444 |
+
model:
|
445 |
+
The training model.
|
446 |
+
model_avg:
|
447 |
+
The stored model averaged from the start of training.
|
448 |
+
optimizer:
|
449 |
+
The optimizer that we are using.
|
450 |
+
scheduler:
|
451 |
+
The scheduler that we are using.
|
452 |
+
Returns:
|
453 |
+
Return a dict containing previously saved training info.
|
454 |
+
"""
|
455 |
+
if params.start_batch > 0:
|
456 |
+
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
457 |
+
elif params.start_epoch > 1:
|
458 |
+
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
459 |
+
else:
|
460 |
+
return None
|
461 |
+
|
462 |
+
assert filename.is_file(), f"{filename} does not exist!"
|
463 |
+
|
464 |
+
saved_params = load_checkpoint(
|
465 |
+
filename,
|
466 |
+
model=model,
|
467 |
+
model_avg=model_avg,
|
468 |
+
optimizer=optimizer,
|
469 |
+
scheduler=scheduler,
|
470 |
+
)
|
471 |
+
|
472 |
+
keys = [
|
473 |
+
"best_train_epoch",
|
474 |
+
"best_valid_epoch",
|
475 |
+
"batch_idx_train",
|
476 |
+
"best_train_loss",
|
477 |
+
"best_valid_loss",
|
478 |
+
]
|
479 |
+
for k in keys:
|
480 |
+
params[k] = saved_params[k]
|
481 |
+
|
482 |
+
if params.start_batch > 0:
|
483 |
+
if "cur_epoch" in saved_params:
|
484 |
+
params["start_epoch"] = saved_params["cur_epoch"]
|
485 |
+
|
486 |
+
return saved_params
|
487 |
+
|
488 |
+
|
489 |
+
def save_checkpoint(
|
490 |
+
params: AttributeDict,
|
491 |
+
model: Union[nn.Module, DDP],
|
492 |
+
model_avg: Optional[nn.Module] = None,
|
493 |
+
optimizer: Optional[torch.optim.Optimizer] = None,
|
494 |
+
scheduler: Optional[LRSchedulerType] = None,
|
495 |
+
sampler: Optional[CutSampler] = None,
|
496 |
+
scaler: Optional[GradScaler] = None,
|
497 |
+
rank: int = 0,
|
498 |
+
) -> None:
|
499 |
+
"""Save model, optimizer, scheduler and training stats to file.
|
500 |
+
|
501 |
+
Args:
|
502 |
+
params:
|
503 |
+
It is returned by :func:`get_params`.
|
504 |
+
model:
|
505 |
+
The training model.
|
506 |
+
model_avg:
|
507 |
+
The stored model averaged from the start of training.
|
508 |
+
optimizer:
|
509 |
+
The optimizer used in the training.
|
510 |
+
sampler:
|
511 |
+
The sampler for the training dataset.
|
512 |
+
scaler:
|
513 |
+
The scaler used for mix precision training.
|
514 |
+
"""
|
515 |
+
if rank != 0:
|
516 |
+
return
|
517 |
+
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
518 |
+
save_checkpoint_impl(
|
519 |
+
filename=filename,
|
520 |
+
model=model,
|
521 |
+
model_avg=model_avg,
|
522 |
+
params=params,
|
523 |
+
optimizer=optimizer,
|
524 |
+
scheduler=scheduler,
|
525 |
+
sampler=sampler,
|
526 |
+
scaler=scaler,
|
527 |
+
rank=rank,
|
528 |
+
)
|
529 |
+
|
530 |
+
if params.best_train_epoch == params.cur_epoch:
|
531 |
+
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
532 |
+
copyfile(src=filename, dst=best_train_filename)
|
533 |
+
|
534 |
+
if params.best_valid_epoch == params.cur_epoch:
|
535 |
+
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
536 |
+
copyfile(src=filename, dst=best_valid_filename)
|
537 |
+
|
538 |
+
|
539 |
+
def compute_loss(
|
540 |
+
params: AttributeDict,
|
541 |
+
model: Union[nn.Module, DDP],
|
542 |
+
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
|
543 |
+
batch: dict,
|
544 |
+
is_training: bool,
|
545 |
+
warmup: float = 1.0,
|
546 |
+
) -> Tuple[Tensor, MetricsTracker]:
|
547 |
+
"""
|
548 |
+
Compute CTC loss given the model and its inputs.
|
549 |
+
|
550 |
+
Args:
|
551 |
+
params:
|
552 |
+
Parameters for training. See :func:`get_params`.
|
553 |
+
model:
|
554 |
+
The model for training. It is an instance of Conformer in our case.
|
555 |
+
graph_compiler:
|
556 |
+
It is used to build a decoding graph from a ctc topo and training
|
557 |
+
transcript. The training transcript is contained in the given `batch`,
|
558 |
+
while the ctc topo is built when this compiler is instantiated.
|
559 |
+
batch:
|
560 |
+
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
561 |
+
for the content in it.
|
562 |
+
is_training:
|
563 |
+
True for training. False for validation. When it is True, this
|
564 |
+
function enables autograd during computation; when it is False, it
|
565 |
+
disables autograd.
|
566 |
+
warmup: a floating point value which increases throughout training;
|
567 |
+
values >= 1.0 are fully warmed up and have all modules present.
|
568 |
+
"""
|
569 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
570 |
+
feature = batch["inputs"]
|
571 |
+
# at entry, feature is (N, T, C)
|
572 |
+
assert feature.ndim == 3
|
573 |
+
feature = feature.to(device)
|
574 |
+
|
575 |
+
supervisions = batch["supervisions"]
|
576 |
+
feature_lens = supervisions["num_frames"].to(device)
|
577 |
+
|
578 |
+
with torch.set_grad_enabled(is_training):
|
579 |
+
nnet_output, encoder_out_lens = model(
|
580 |
+
feature,
|
581 |
+
feature_lens,
|
582 |
+
warmup=warmup,
|
583 |
+
delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0,
|
584 |
+
)
|
585 |
+
assert torch.all(encoder_out_lens > 0)
|
586 |
+
|
587 |
+
# NOTE: We need `encode_supervisions` to sort sequences with
|
588 |
+
# different duration in decreasing order, required by
|
589 |
+
# `k2.intersect_dense` called in `k2.ctc_loss`
|
590 |
+
supervision_segments, texts = encode_supervisions(
|
591 |
+
supervisions, subsampling_factor=params.subsampling_factor
|
592 |
+
)
|
593 |
+
|
594 |
+
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
|
595 |
+
# Works with a BPE model
|
596 |
+
token_ids = graph_compiler.texts_to_ids(texts)
|
597 |
+
decoding_graph = graph_compiler.compile(token_ids)
|
598 |
+
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
|
599 |
+
# Works with a phone lexicon
|
600 |
+
decoding_graph = graph_compiler.compile(texts)
|
601 |
+
else:
|
602 |
+
raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
|
603 |
+
|
604 |
+
dense_fsa_vec = k2.DenseFsaVec(
|
605 |
+
nnet_output,
|
606 |
+
supervision_segments,
|
607 |
+
allow_truncate=params.subsampling_factor - 1,
|
608 |
+
)
|
609 |
+
|
610 |
+
ctc_loss = k2.ctc_loss(
|
611 |
+
decoding_graph=decoding_graph,
|
612 |
+
dense_fsa_vec=dense_fsa_vec,
|
613 |
+
output_beam=params.beam_size,
|
614 |
+
delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
|
615 |
+
reduction=params.reduction,
|
616 |
+
use_double_scores=params.use_double_scores,
|
617 |
+
)
|
618 |
+
ctc_loss_is_finite = torch.isfinite(ctc_loss)
|
619 |
+
if not torch.all(ctc_loss_is_finite):
|
620 |
+
logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
|
621 |
+
ctc_loss = ctc_loss[ctc_loss_is_finite]
|
622 |
+
|
623 |
+
# If the CTC loss is inf or nan for every utterance in the batch,
|
624 |
+
# we stop the training process by raising an exception
|
625 |
+
if torch.all(~ctc_loss_is_finite):
|
626 |
+
raise ValueError(
|
627 |
+
"There are too many utterances in this batch "
|
628 |
+
"leading to inf or nan losses."
|
629 |
+
)
|
630 |
+
loss = ctc_loss.sum()
|
631 |
+
|
632 |
+
assert loss.requires_grad == is_training
|
633 |
+
|
634 |
+
info = MetricsTracker()
|
635 |
+
# info["frames"] is an approximate number for two reasons:
|
636 |
+
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
|
637 |
+
# (2) If some utterances in the batch lead to inf/nan loss, they
|
638 |
+
# are filtered out.
|
639 |
+
info["frames"] = supervision_segments[:, 2].sum().item()
|
640 |
+
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
641 |
+
info["utterances"] = feature.size(0)
|
642 |
+
# averaged input duration in frames over utterances
|
643 |
+
info["utt_duration"] = feature_lens.sum().item()
|
644 |
+
# averaged padding proportion over utterances
|
645 |
+
info["utt_pad_proportion"] = (
|
646 |
+
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
647 |
+
)
|
648 |
+
|
649 |
+
# Note: We use reduction=sum while computing the loss.
|
650 |
+
info["loss"] = loss.detach().cpu().item()
|
651 |
+
|
652 |
+
return loss, info
|
653 |
+
|
654 |
+
|
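Before calling into k2, compute_loss() describes each utterance as a (sequence_index, start_frame, num_frames) row on the subsampled time axis (via encode_supervisions, which also sorts by decreasing duration). A torch-only sketch of that bookkeeping, with made-up lengths:

```python
# Torch-only sketch: supervision segments on the subsampled axis, analogous
# to what compute_loss() hands to k2.DenseFsaVec. The lengths are invented
# for illustration; encode_supervisions additionally sorts by duration.
import torch

feature_lens = torch.tensor([980, 712, 455])   # input frames per utterance
subsampling_factor = 4

supervision_segments = torch.tensor(
    [[i, 0, int(n) // subsampling_factor] for i, n in enumerate(feature_lens)],
    dtype=torch.int32,
)
print(supervision_segments)
# tensor([[  0,   0, 245],
#         [  1,   0, 178],
#         [  2,   0, 113]], dtype=torch.int32)
```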
655 |
+
def compute_validation_loss(
|
656 |
+
params: AttributeDict,
|
657 |
+
model: Union[nn.Module, DDP],
|
658 |
+
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
|
659 |
+
valid_dl: torch.utils.data.DataLoader,
|
660 |
+
world_size: int = 1,
|
661 |
+
) -> MetricsTracker:
|
662 |
+
"""Run the validation process."""
|
663 |
+
model.eval()
|
664 |
+
|
665 |
+
tot_loss = MetricsTracker()
|
666 |
+
|
667 |
+
for batch_idx, batch in enumerate(valid_dl):
|
668 |
+
loss, loss_info = compute_loss(
|
669 |
+
params=params,
|
670 |
+
model=model,
|
671 |
+
graph_compiler=graph_compiler,
|
672 |
+
batch=batch,
|
673 |
+
is_training=False,
|
674 |
+
)
|
675 |
+
assert loss.requires_grad is False
|
676 |
+
tot_loss = tot_loss + loss_info
|
677 |
+
|
678 |
+
if world_size > 1:
|
679 |
+
tot_loss.reduce(loss.device)
|
680 |
+
|
681 |
+
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
682 |
+
if loss_value < params.best_valid_loss:
|
683 |
+
params.best_valid_epoch = params.cur_epoch
|
684 |
+
params.best_valid_loss = loss_value
|
685 |
+
|
686 |
+
return tot_loss
|
687 |
+
|
688 |
+
|
689 |
+
def train_one_epoch(
|
690 |
+
params: AttributeDict,
|
691 |
+
model: Union[nn.Module, DDP],
|
692 |
+
optimizer: torch.optim.Optimizer,
|
693 |
+
scheduler: LRSchedulerType,
|
694 |
+
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
|
695 |
+
train_dl: torch.utils.data.DataLoader,
|
696 |
+
valid_dl: torch.utils.data.DataLoader,
|
697 |
+
scaler: GradScaler,
|
698 |
+
model_avg: Optional[nn.Module] = None,
|
699 |
+
tb_writer: Optional[SummaryWriter] = None,
|
700 |
+
world_size: int = 1,
|
701 |
+
rank: int = 0,
|
702 |
+
) -> None:
|
703 |
+
"""Train the model for one epoch.
|
704 |
+
|
705 |
+
The training loss from the mean of all frames is saved in
|
706 |
+
`params.train_loss`. It runs the validation process every
|
707 |
+
`params.valid_interval` batches.
|
708 |
+
|
709 |
+
Args:
|
710 |
+
params:
|
711 |
+
It is returned by :func:`get_params`.
|
712 |
+
model:
|
713 |
+
The model for training.
|
714 |
+
optimizer:
|
715 |
+
The optimizer we are using.
|
716 |
+
scheduler:
|
717 |
+
The learning rate scheduler, we call step() every step.
|
718 |
+
graph_compiler:
|
719 |
+
It is used to build a decoding graph from a ctc topo and training
|
720 |
+
transcript. The training transcript is contained in the given `batch`,
|
721 |
+
while the ctc topo is built when this compiler is instantiated.
|
722 |
+
train_dl:
|
723 |
+
Dataloader for the training dataset.
|
724 |
+
valid_dl:
|
725 |
+
Dataloader for the validation dataset.
|
726 |
+
scaler:
|
727 |
+
The scaler used for mix precision training.
|
728 |
+
model_avg:
|
729 |
+
The stored model averaged from the start of training.
|
730 |
+
tb_writer:
|
731 |
+
Writer to write log messages to tensorboard.
|
732 |
+
world_size:
|
733 |
+
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
734 |
+
rank:
|
735 |
+
The rank of the node in DDP training. If no DDP is used, it should
|
736 |
+
be set to 0.
|
737 |
+
"""
|
738 |
+
model.train()
|
739 |
+
|
740 |
+
tot_loss = MetricsTracker()
|
741 |
+
|
742 |
+
for batch_idx, batch in enumerate(train_dl):
|
743 |
+
params.batch_idx_train += 1
|
744 |
+
batch_size = len(batch["supervisions"]["text"])
|
745 |
+
|
746 |
+
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
747 |
+
loss, loss_info = compute_loss(
|
748 |
+
params=params,
|
749 |
+
model=model,
|
750 |
+
graph_compiler=graph_compiler,
|
751 |
+
batch=batch,
|
752 |
+
is_training=True,
|
753 |
+
warmup=(params.batch_idx_train / params.model_warm_step),
|
754 |
+
)
|
755 |
+
# summary stats
|
756 |
+
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
757 |
+
|
758 |
+
# NOTE: We use reduction==sum and loss is computed over utterances
|
759 |
+
# in the batch and there is no normalization to it so far.
|
760 |
+
scaler.scale(loss).backward()
|
761 |
+
scheduler.step_batch(params.batch_idx_train)
|
762 |
+
scaler.step(optimizer)
|
763 |
+
scaler.update()
|
764 |
+
optimizer.zero_grad()
|
765 |
+
|
766 |
+
if params.print_diagnostics and batch_idx == 30:
|
767 |
+
return
|
768 |
+
|
769 |
+
if (
|
770 |
+
rank == 0
|
771 |
+
and params.batch_idx_train > 0
|
772 |
+
and params.batch_idx_train % params.average_period == 0
|
773 |
+
):
|
774 |
+
update_averaged_model(
|
775 |
+
params=params,
|
776 |
+
model_cur=model,
|
777 |
+
model_avg=model_avg,
|
778 |
+
)
|
779 |
+
|
780 |
+
if (
|
781 |
+
params.batch_idx_train > 0
|
782 |
+
and params.batch_idx_train % params.save_every_n == 0
|
783 |
+
):
|
784 |
+
save_checkpoint_with_global_batch_idx(
|
785 |
+
out_dir=params.exp_dir,
|
786 |
+
global_batch_idx=params.batch_idx_train,
|
787 |
+
model=model,
|
788 |
+
model_avg=model_avg,
|
789 |
+
params=params,
|
790 |
+
optimizer=optimizer,
|
791 |
+
scheduler=scheduler,
|
792 |
+
sampler=train_dl.sampler,
|
793 |
+
scaler=scaler,
|
794 |
+
rank=rank,
|
795 |
+
)
|
796 |
+
remove_checkpoints(
|
797 |
+
out_dir=params.exp_dir,
|
798 |
+
topk=params.keep_last_k,
|
799 |
+
rank=rank,
|
800 |
+
)
|
801 |
+
|
802 |
+
if batch_idx % params.log_interval == 0:
|
803 |
+
cur_lr = scheduler.get_last_lr()[0]
|
804 |
+
logging.info(
|
805 |
+
f"Epoch {params.cur_epoch}, "
|
806 |
+
f"batch {batch_idx}, loss[{loss_info}], "
|
807 |
+
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
808 |
+
f"lr: {cur_lr:.2e}"
|
809 |
+
)
|
810 |
+
|
811 |
+
if tb_writer is not None:
|
812 |
+
tb_writer.add_scalar(
|
813 |
+
"train/learning_rate", cur_lr, params.batch_idx_train
|
814 |
+
)
|
815 |
+
|
816 |
+
loss_info.write_summary(
|
817 |
+
tb_writer, "train/current_", params.batch_idx_train
|
818 |
+
)
|
819 |
+
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
820 |
+
|
821 |
+
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
822 |
+
logging.info("Computing validation loss")
|
823 |
+
valid_info = compute_validation_loss(
|
824 |
+
params=params,
|
825 |
+
model=model,
|
826 |
+
graph_compiler=graph_compiler,
|
827 |
+
valid_dl=valid_dl,
|
828 |
+
world_size=world_size,
|
829 |
+
)
|
830 |
+
model.train()
|
831 |
+
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
832 |
+
if tb_writer is not None:
|
833 |
+
valid_info.write_summary(
|
834 |
+
tb_writer, "train/valid_", params.batch_idx_train
|
835 |
+
)
|
836 |
+
|
837 |
+
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
838 |
+
params.train_loss = loss_value
|
839 |
+
if params.train_loss < params.best_train_loss:
|
840 |
+
params.best_train_epoch = params.cur_epoch
|
841 |
+
params.best_train_loss = params.train_loss
|
842 |
+
|
843 |
+
|
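train_one_epoch() scales the loss, steps the per-batch scheduler, then lets the GradScaler drive the optimizer. A self-contained toy sketch of that same mixed-precision update order (a linear model and cross-entropy stand in for the Conformer and the k2 CTC loss; the scheduler call is only noted in a comment):

```python
# Self-contained sketch of the mixed-precision update order used above.
import torch
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_fp16 = device.type == "cuda"

model = torch.nn.Linear(80, 500).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=use_fp16)

features = torch.randn(8, 100, 80, device=device)   # (N, T, C), as in compute_loss
targets = torch.randint(0, 500, (8, 100), device=device)

with autocast(enabled=use_fp16):
    logits = model(features)
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, 500), targets.reshape(-1)
    )

scaler.scale(loss).backward()
# (the recipe calls scheduler.step_batch(params.batch_idx_train) here)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```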
844 |
+
def run(rank, world_size, args):
|
845 |
+
"""
|
846 |
+
Args:
|
847 |
+
rank:
|
848 |
+
It is a value between 0 and `world_size-1`, which is
|
849 |
+
passed automatically by `mp.spawn()` in :func:`main`.
|
850 |
+
The node with rank 0 is responsible for saving checkpoint.
|
851 |
+
world_size:
|
852 |
+
Number of GPUs for DDP training.
|
853 |
+
args:
|
854 |
+
The return value of get_parser().parse_args()
|
855 |
+
"""
|
856 |
+
params = get_params()
|
857 |
+
params.update(vars(args))
|
858 |
+
if params.full_libri is False:
|
859 |
+
params.valid_interval = 1600
|
860 |
+
|
861 |
+
fix_random_seed(params.seed)
|
862 |
+
if world_size > 1:
|
863 |
+
setup_dist(rank, world_size, params.master_port)
|
864 |
+
|
865 |
+
setup_logger(f"{params.exp_dir}/log/log-train")
|
866 |
+
logging.info("Training started")
|
867 |
+
|
868 |
+
if args.tensorboard and rank == 0:
|
869 |
+
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
870 |
+
else:
|
871 |
+
tb_writer = None
|
872 |
+
|
873 |
+
lexicon = Lexicon(params.lang_dir)
|
874 |
+
max_token_id = max(lexicon.tokens)
|
875 |
+
params.vocab_size = max_token_id + 1 # +1 for the blank
|
876 |
+
|
877 |
+
device = torch.device("cpu")
|
878 |
+
if torch.cuda.is_available():
|
879 |
+
device = torch.device("cuda", rank)
|
880 |
+
logging.info(f"Device: {device}")
|
881 |
+
|
882 |
+
if "lang_bpe" in str(params.lang_dir):
|
883 |
+
graph_compiler = BpeCtcTrainingGraphCompiler(
|
884 |
+
params.lang_dir,
|
885 |
+
device=device,
|
886 |
+
sos_token="<sos/eos>",
|
887 |
+
eos_token="<sos/eos>",
|
888 |
+
)
|
889 |
+
elif "lang_phone" in str(params.lang_dir):
|
890 |
+
graph_compiler = CtcTrainingGraphCompiler(
|
891 |
+
lexicon,
|
892 |
+
device=device,
|
893 |
+
need_repeat_flag=params.delay_penalty > 0,
|
894 |
+
)
|
895 |
+
# Manually add the sos/eos ID with their default values
|
896 |
+
# from the BPE recipe which we're adapting here.
|
897 |
+
graph_compiler.sos_id = 1
|
898 |
+
graph_compiler.eos_id = 1
|
899 |
+
else:
|
900 |
+
raise ValueError(
|
901 |
+
f"Unsupported type of lang dir (we expected it to have "
|
902 |
+
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
|
903 |
+
)
|
904 |
+
|
905 |
+
if params.dynamic_chunk_training:
|
906 |
+
assert (
|
907 |
+
params.causal_convolution
|
908 |
+
), "dynamic_chunk_training requires causal convolution"
|
909 |
+
|
910 |
+
logging.info(params)
|
911 |
+
|
912 |
+
logging.info("About to create model")
|
913 |
+
model = get_ctc_model(params)
|
914 |
+
|
915 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
916 |
+
logging.info(f"Number of model parameters: {num_param}")
|
917 |
+
|
918 |
+
assert params.save_every_n >= params.average_period
|
919 |
+
model_avg: Optional[nn.Module] = None
|
920 |
+
if rank == 0:
|
921 |
+
# model_avg is only used with rank 0
|
922 |
+
model_avg = copy.deepcopy(model)
|
923 |
+
|
924 |
+
assert params.start_epoch > 0, params.start_epoch
|
925 |
+
checkpoints = load_checkpoint_if_available(
|
926 |
+
params=params, model=model, model_avg=model_avg
|
927 |
+
)
|
928 |
+
|
929 |
+
model.to(device)
|
930 |
+
if world_size > 1:
|
931 |
+
logging.info("Using DDP")
|
932 |
+
model = DDP(model, device_ids=[rank])
|
933 |
+
|
934 |
+
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
935 |
+
|
936 |
+
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
937 |
+
|
938 |
+
if checkpoints and "optimizer" in checkpoints:
|
939 |
+
logging.info("Loading optimizer state dict")
|
940 |
+
optimizer.load_state_dict(checkpoints["optimizer"])
|
941 |
+
|
942 |
+
if (
|
943 |
+
checkpoints
|
944 |
+
and "scheduler" in checkpoints
|
945 |
+
and checkpoints["scheduler"] is not None
|
946 |
+
):
|
947 |
+
logging.info("Loading scheduler state dict")
|
948 |
+
scheduler.load_state_dict(checkpoints["scheduler"])
|
949 |
+
|
950 |
+
if params.print_diagnostics:
|
951 |
+
diagnostic = diagnostics.attach_diagnostics(model)
|
952 |
+
|
953 |
+
librispeech = LibriSpeechAsrDataModule(args)
|
954 |
+
|
955 |
+
train_cuts = librispeech.train_clean_100_cuts()
|
956 |
+
# if params.full_libri:
|
957 |
+
# train_cuts += librispeech.train_clean_360_cuts()
|
958 |
+
# train_cuts += librispeech.train_other_500_cuts()
|
959 |
+
|
960 |
+
def remove_short_and_long_utt(c: Cut):
|
961 |
+
# Keep only utterances with duration between 1 second and 20 seconds
|
962 |
+
#
|
963 |
+
# Caution: There is a reason to select 20.0 here. Please see
|
964 |
+
# ../local/display_manifest_statistics.py
|
965 |
+
#
|
966 |
+
# You should use ../local/display_manifest_statistics.py to get
|
967 |
+
# an utterance duration distribution for your dataset to select
|
968 |
+
# the threshold
|
969 |
+
return 1.0 <= c.duration <= 20.0
|
970 |
+
|
971 |
+
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
972 |
+
|
973 |
+
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
974 |
+
# We only load the sampler's state dict when it loads a checkpoint
|
975 |
+
# saved in the middle of an epoch
|
976 |
+
sampler_state_dict = checkpoints["sampler"]
|
977 |
+
else:
|
978 |
+
sampler_state_dict = None
|
979 |
+
|
980 |
+
train_dl = librispeech.train_dataloaders(
|
981 |
+
train_cuts, sampler_state_dict=sampler_state_dict
|
982 |
+
)
|
983 |
+
|
984 |
+
valid_cuts = librispeech.dev_clean_cuts()
|
985 |
+
#valid_cuts += librispeech.dev_other_cuts()
|
986 |
+
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
987 |
+
|
988 |
+
if params.start_batch <= 0 and not params.print_diagnostics:
|
989 |
+
scan_pessimistic_batches_for_oom(
|
990 |
+
model=model,
|
991 |
+
train_dl=train_dl,
|
992 |
+
optimizer=optimizer,
|
993 |
+
graph_compiler=graph_compiler,
|
994 |
+
params=params,
|
995 |
+
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
996 |
+
)
|
997 |
+
|
998 |
+
scaler = GradScaler(enabled=params.use_fp16)
|
999 |
+
if checkpoints and "grad_scaler" in checkpoints:
|
1000 |
+
logging.info("Loading grad scaler state dict")
|
1001 |
+
scaler.load_state_dict(checkpoints["grad_scaler"])
|
1002 |
+
|
1003 |
+
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
1004 |
+
scheduler.step_epoch(epoch - 1)
|
1005 |
+
fix_random_seed(params.seed + epoch - 1)
|
1006 |
+
train_dl.sampler.set_epoch(epoch - 1)
|
1007 |
+
|
1008 |
+
if tb_writer is not None:
|
1009 |
+
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
1010 |
+
|
1011 |
+
params.cur_epoch = epoch
|
1012 |
+
|
1013 |
+
train_one_epoch(
|
1014 |
+
params=params,
|
1015 |
+
model=model,
|
1016 |
+
model_avg=model_avg,
|
1017 |
+
optimizer=optimizer,
|
1018 |
+
scheduler=scheduler,
|
1019 |
+
graph_compiler=graph_compiler,
|
1020 |
+
train_dl=train_dl,
|
1021 |
+
valid_dl=valid_dl,
|
1022 |
+
scaler=scaler,
|
1023 |
+
tb_writer=tb_writer,
|
1024 |
+
world_size=world_size,
|
1025 |
+
rank=rank,
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
if params.print_diagnostics:
|
1029 |
+
diagnostic.print_diagnostics()
|
1030 |
+
break
|
1031 |
+
|
1032 |
+
save_checkpoint(
|
1033 |
+
params=params,
|
1034 |
+
model=model,
|
1035 |
+
model_avg=model_avg,
|
1036 |
+
optimizer=optimizer,
|
1037 |
+
scheduler=scheduler,
|
1038 |
+
sampler=train_dl.sampler,
|
1039 |
+
scaler=scaler,
|
1040 |
+
rank=rank,
|
1041 |
+
)
|
1042 |
+
|
1043 |
+
logging.info("Done!")
|
1044 |
+
|
1045 |
+
if world_size > 1:
|
1046 |
+
torch.distributed.barrier()
|
1047 |
+
cleanup_dist()
|
1048 |
+
|
1049 |
+
|
1050 |
+
def scan_pessimistic_batches_for_oom(
|
1051 |
+
model: Union[nn.Module, DDP],
|
1052 |
+
train_dl: torch.utils.data.DataLoader,
|
1053 |
+
optimizer: torch.optim.Optimizer,
|
1054 |
+
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
|
1055 |
+
params: AttributeDict,
|
1056 |
+
warmup: float,
|
1057 |
+
):
|
1058 |
+
from lhotse.dataset import find_pessimistic_batches
|
1059 |
+
|
1060 |
+
logging.info(
|
1061 |
+
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
|
1062 |
+
)
|
1063 |
+
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
1064 |
+
for criterion, cuts in batches.items():
|
1065 |
+
batch = train_dl.dataset[cuts]
|
1066 |
+
try:
|
1067 |
+
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
1068 |
+
loss, _ = compute_loss(
|
1069 |
+
params=params,
|
1070 |
+
model=model,
|
1071 |
+
graph_compiler=graph_compiler,
|
1072 |
+
batch=batch,
|
1073 |
+
is_training=True,
|
1074 |
+
warmup=warmup,
|
1075 |
+
)
|
1076 |
+
loss.backward()
|
1077 |
+
optimizer.step()
|
1078 |
+
optimizer.zero_grad()
|
1079 |
+
except RuntimeError as e:
|
1080 |
+
if "CUDA out of memory" in str(e):
|
1081 |
+
logging.error(
|
1082 |
+
"Your GPU ran out of memory with the current "
|
1083 |
+
"max_duration setting. We recommend decreasing "
|
1084 |
+
"max_duration and trying again.\n"
|
1085 |
+
f"Failing criterion: {criterion} "
|
1086 |
+
f"(={crit_values[criterion]}) ..."
|
1087 |
+
)
|
1088 |
+
raise
|
1089 |
+
|
1090 |
+
|
1091 |
+
def main():
|
1092 |
+
parser = get_parser()
|
1093 |
+
LibriSpeechAsrDataModule.add_arguments(parser)
|
1094 |
+
args = parser.parse_args()
|
1095 |
+
args.exp_dir = Path(args.exp_dir)
|
1096 |
+
|
1097 |
+
world_size = args.world_size
|
1098 |
+
assert world_size >= 1
|
1099 |
+
if world_size > 1:
|
1100 |
+
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
1101 |
+
else:
|
1102 |
+
run(rank=0, world_size=1, args=args)
|
1103 |
+
|
1104 |
+
|
1105 |
+
torch.set_num_threads(1)
|
1106 |
+
torch.set_num_interop_threads(1)
|
1107 |
+
|
1108 |
+
if __name__ == "__main__":
|
1109 |
+
main()
|
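Training is launched through main() above. A hedged launch sketch follows; the flag names are inferred from the attributes used in run()/main() (world_size, num_epochs, exp_dir, lang_dir, use_fp16) and should be verified against `python conformer_ctc3/train.py --help` before relying on them.

```python
# Hypothetical launch sketch; flag names are assumptions, not confirmed.
import subprocess

subprocess.run(
    [
        "python", "conformer_ctc3/train.py",
        "--world-size", "1",
        "--num-epochs", "30",
        "--exp-dir", "conformer_ctc3/exp",
        "--lang-dir", "data/lang_bpe_500",
        "--use-fp16", "1",
    ],
    check=True,
)
```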
err2020/conformer_ctc3_usage.ipynb
ADDED
@@ -0,0 +1,500 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"id": "b6b6ded1-0a58-43cb-9065-4f4fae02a01b",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import argparse\n",
|
11 |
+
"import logging\n",
|
12 |
+
"import math\n",
|
13 |
+
"import re\n",
|
14 |
+
"from typing import List\n",
|
15 |
+
"import sys\n",
|
16 |
+
"sys.path.append('/opt/notebooks/err2020/conformer_ctc3/')\n",
|
17 |
+
"import k2\n",
|
18 |
+
"import kaldifeat\n",
|
19 |
+
"import sentencepiece as spm\n",
|
20 |
+
"import torch\n",
|
21 |
+
"import torchaudio\n",
|
22 |
+
"from decode import get_decoding_params\n",
|
23 |
+
"from torch.nn.utils.rnn import pad_sequence\n",
|
24 |
+
"from train import add_model_arguments, get_params\n",
|
25 |
+
"\n",
|
26 |
+
"from icefall.decode import (\n",
|
27 |
+
" get_lattice,\n",
|
28 |
+
" one_best_decoding,\n",
|
29 |
+
" rescore_with_n_best_list,\n",
|
30 |
+
" rescore_with_whole_lattice\n",
|
31 |
+
")\n",
|
32 |
+
"from icefall.utils import get_texts, parse_fsa_timestamps_and_texts"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "markdown",
|
37 |
+
"id": "52514f2f-1195-4e4f-8174-d21aa7462476",
|
38 |
+
"metadata": {},
|
39 |
+
"source": [
|
40 |
+
"## Helpers"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"id": "8ec024bf-7f91-47a9-9293-822fe2765c4b",
|
46 |
+
"metadata": {},
|
47 |
+
"source": [
|
48 |
+
"#### Load args helpers"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 4,
|
54 |
+
"id": "3d69d771-b421-417f-a6ff-e1d1c64ba934",
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"class Args:\n",
|
59 |
+
" model_filename='conformer_ctc3/exp/jit_trace.pt'\n",
|
60 |
+
" bpe_model_filename=\"data/lang_bpe_500/bpe.model\"\n",
|
61 |
+
" method=\"ctc-decoding\"\n",
|
62 |
+
" sample_rate=16000\n",
|
63 |
+
" num_classes=500 #bpe model size\n",
|
64 |
+
" frame_shift_ms=10\n",
|
65 |
+
" dither=0\n",
|
66 |
+
" snip_edges=False\n",
|
67 |
+
" num_bins=80\n",
|
68 |
+
" device='cpu'\n",
|
69 |
+
" \n",
|
70 |
+
" def args_from_dict(self, dct):\n",
|
71 |
+
" for key in dct:\n",
|
72 |
+
" setattr(self, key, dct[key])\n",
|
73 |
+
" \n",
|
74 |
+
" def __repr__(self):\n",
|
75 |
+
" text=''\n",
|
76 |
+
" for k, v in self.__dict__.items():\n",
|
77 |
+
" text+=f'{k} = {v}\\n'\n",
|
78 |
+
" return text"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"id": "57a3cd62-3037-4c99-9094-dd63429e660e",
|
84 |
+
"metadata": {},
|
85 |
+
"source": [
|
86 |
+
"#### Decoder helper"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": 5,
|
92 |
+
"id": "48306369-fb68-4abe-be62-0806d00059f8",
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [],
|
95 |
+
"source": [
|
96 |
+
"class ConformerCtc3Decoder:\n",
|
97 |
+
" def __init__(self, params_dct=None):\n",
|
98 |
+
" logging.info('loading args')\n",
|
99 |
+
" self.args=Args()\n",
|
100 |
+
" if params_dct is not None:\n",
|
101 |
+
" self.args.args_from_dict(params_dct)\n",
|
102 |
+
" logging.info('loading model')\n",
|
103 |
+
" self.load_model()\n",
|
104 |
+
" logging.info('loading fbank')\n",
|
105 |
+
" self.get_fbank()\n",
|
106 |
+
" \n",
|
107 |
+
" def update_args(self, dct):\n",
|
108 |
+
" self.args.args_from_dict(dct)\n",
|
109 |
+
" \n",
|
110 |
+
" def load_model_(self, model_filename, device):\n",
|
111 |
+
" device = torch.device(\"cpu\")\n",
|
112 |
+
" model = torch.jit.load(model_filename)\n",
|
113 |
+
" model.to(device)\n",
|
114 |
+
" model=model.eval()\n",
|
115 |
+
" self.model=model\n",
|
116 |
+
" \n",
|
117 |
+
" def load_model(self, model_filename=None, device=None):\n",
|
118 |
+
" if model_filename is not None:\n",
|
119 |
+
" self.args.model_filename=model_filename\n",
|
120 |
+
" if device is not None:\n",
|
121 |
+
" self.args.device=device\n",
|
122 |
+
" self.load_model_(self.args.model_filename, self.args.device)\n",
|
123 |
+
" \n",
|
124 |
+
" def get_fbank_(self, device='cpu'):\n",
|
125 |
+
" opts = kaldifeat.FbankOptions()\n",
|
126 |
+
" opts.device = device\n",
|
127 |
+
" opts.frame_opts.dither = self.args.dither\n",
|
128 |
+
" opts.frame_opts.snip_edges = self.args.snip_edges\n",
|
129 |
+
" #opts.frame_opts.samp_freq = sample_rate\n",
|
130 |
+
" opts.mel_opts.num_bins = self.args.num_bins\n",
|
131 |
+
"\n",
|
132 |
+
" fbank = kaldifeat.Fbank(opts)\n",
|
133 |
+
" return fbank\n",
|
134 |
+
" \n",
|
135 |
+
" def get_fbank(self):\n",
|
136 |
+
" self.fbank=self.get_fbank_(self.args.device)\n",
|
137 |
+
" \n",
|
138 |
+
" def read_sound_file_(self, filename: str, expected_sample_rate: float ) -> List[torch.Tensor]:\n",
|
139 |
+
" \"\"\"Read a sound file into a 1-D float32 torch tensor.\n",
|
140 |
+
" Args:\n",
|
141 |
+
" filenames:\n",
|
142 |
+
" A list of sound filenames.\n",
|
143 |
+
" expected_sample_rate:\n",
|
144 |
+
" The expected sample rate of the sound files.\n",
|
145 |
+
" Returns:\n",
|
146 |
+
" Return a 1-D float32 torch tensor.\n",
|
147 |
+
" \"\"\"\n",
|
148 |
+
" wave, sample_rate = torchaudio.load(filename)\n",
|
149 |
+
" assert sample_rate == expected_sample_rate, (\n",
|
150 |
+
" f\"expected sample rate: {expected_sample_rate}. \" f\"Given: {sample_rate}\"\n",
|
151 |
+
" )\n",
|
152 |
+
" # We use only the first channel\n",
|
153 |
+
" return wave[0]\n",
|
154 |
+
" \n",
|
155 |
+
" def format_trs(self, hyp, timestamps):\n",
|
156 |
+
" if len(hyp)!=len(timestamps):\n",
|
157 |
+
" print(f'len of hyp and timestamps is not the same len hyp {len(hyp)} and len of timestamps {len(timestamps)}')\n",
|
158 |
+
" return None\n",
|
159 |
+
" trs ={'text': ' '.join(hyp),\n",
|
160 |
+
" 'words': [{'word': w, 'start':timestamps[i][0], 'end': timestamps[i][1]} for i, w in enumerate(hyp)]\n",
|
161 |
+
" }\n",
|
162 |
+
" return trs\n",
|
163 |
+
" \n",
|
164 |
+
" def decode_(self, wave, fbank, model, device, method, bpe_model_filename, num_classes, \n",
|
165 |
+
" min_active_states, max_active_states, subsampling_factor, use_double_scores, \n",
|
166 |
+
" frame_shift_ms, search_beam, output_beam):\n",
|
167 |
+
" \n",
|
168 |
+
" wave = [wave.to(device)]\n",
|
169 |
+
" logging.info(\"Decoding started\")\n",
|
170 |
+
" features = fbank(wave)\n",
|
171 |
+
" feature_lengths = [f.size(0) for f in features]\n",
|
172 |
+
"\n",
|
173 |
+
" features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))\n",
|
174 |
+
" feature_lengths = torch.tensor(feature_lengths, device=device)\n",
|
175 |
+
"\n",
|
176 |
+
" nnet_output, _ = model(features, feature_lengths)\n",
|
177 |
+
"\n",
|
178 |
+
" batch_size = nnet_output.shape[0]\n",
|
179 |
+
" supervision_segments = torch.tensor(\n",
|
180 |
+
" [\n",
|
181 |
+
" [i, 0, feature_lengths[i] // subsampling_factor]\n",
|
182 |
+
" for i in range(batch_size)\n",
|
183 |
+
" ],\n",
|
184 |
+
" dtype=torch.int32,\n",
|
185 |
+
" )\n",
|
186 |
+
"\n",
|
187 |
+
" if method == \"ctc-decoding\":\n",
|
188 |
+
" logging.info(\"Use CTC decoding\")\n",
|
189 |
+
" bpe_model = spm.SentencePieceProcessor()\n",
|
190 |
+
" bpe_model.load(bpe_model_filename)\n",
|
191 |
+
" max_token_id = num_classes - 1\n",
|
192 |
+
"\n",
|
193 |
+
" H = k2.ctc_topo(\n",
|
194 |
+
" max_token=max_token_id,\n",
|
195 |
+
" modified=False,\n",
|
196 |
+
" device=device,\n",
|
197 |
+
" )\n",
|
198 |
+
"\n",
|
199 |
+
" lattice = get_lattice(\n",
|
200 |
+
" nnet_output=nnet_output,\n",
|
201 |
+
" decoding_graph=H,\n",
|
202 |
+
" supervision_segments=supervision_segments,\n",
|
203 |
+
" search_beam=search_beam,\n",
|
204 |
+
" output_beam=output_beam,\n",
|
205 |
+
" min_active_states=min_active_states,\n",
|
206 |
+
" max_active_states=max_active_states,\n",
|
207 |
+
" subsampling_factor=subsampling_factor,\n",
|
208 |
+
" )\n",
|
209 |
+
"\n",
|
210 |
+
" best_path = one_best_decoding(\n",
|
211 |
+
" lattice=lattice, use_double_scores=use_double_scores\n",
|
212 |
+
" )\n",
|
213 |
+
"\n",
|
214 |
+
" confidence=best_path.get_tot_scores(use_double_scores=False, log_semiring=False).detach()[0]\n",
|
215 |
+
"\n",
|
216 |
+
" timestamps, hyps = parse_fsa_timestamps_and_texts(\n",
|
217 |
+
" best_paths=best_path,\n",
|
218 |
+
" sp=bpe_model,\n",
|
219 |
+
" subsampling_factor=subsampling_factor,\n",
|
220 |
+
" frame_shift_ms=frame_shift_ms,\n",
|
221 |
+
" )\n",
|
222 |
+
" logging.info(f'confidence {confidence}')\n",
|
223 |
+
" logging.info(timestamps)\n",
|
224 |
+
" token_ids = get_texts(best_path)\n",
|
225 |
+
" return self.format_trs(hyps[0], timestamps[0])\n",
|
226 |
+
" \n",
|
227 |
+
" def transcribe_file(self, audio_filename):\n",
|
228 |
+
" wave=self.read_sound_file_(audio_filename, expected_sample_rate=self.args.sample_rate)\n",
|
229 |
+
" \n",
|
230 |
+
" trs=self.decode_(wave, self.fbank, self.model, self.args.device, self.args.method, \n",
|
231 |
+
" self.args.bpe_model_filename, self.args.num_classes,\n",
|
232 |
+
" self.args.min_active_states, self.args.max_active_states, \n",
|
233 |
+
" self.args.subsampling_factor, self.args.use_double_scores, \n",
|
234 |
+
" self.args.frame_shift_ms, self.args.search_beam, self.args.output_beam)\n",
|
235 |
+
" return trs"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "markdown",
|
240 |
+
"id": "b1464957-05b6-40f8-a1aa-c58edbed440c",
|
241 |
+
"metadata": {},
|
242 |
+
"source": [
|
243 |
+
"## Example usage"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": 6,
|
249 |
+
"id": "50ab7c8e-39b6-4783-8342-e79e91d2417e",
|
250 |
+
"metadata": {},
|
251 |
+
"outputs": [
|
252 |
+
{
|
253 |
+
"name": "stderr",
|
254 |
+
"output_type": "stream",
|
255 |
+
"text": [
|
256 |
+
"fatal: not a git repository (or any parent up to mount point /opt)\n",
|
257 |
+
"Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n",
|
258 |
+
"fatal: not a git repository (or any parent up to mount point /opt)\n",
|
259 |
+
"Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n",
|
260 |
+
"fatal: not a git repository (or any parent up to mount point /opt)\n",
|
261 |
+
"Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).\n"
|
262 |
+
]
|
263 |
+
}
|
264 |
+
],
|
265 |
+
"source": [
|
266 |
+
"#create transcriber/decoder object\n",
|
267 |
+
"#if you want to change parameters (for example model filename) you could create a dict (see class Args attribute names)\n",
|
268 |
+
"#and add it to as argument decoder initialization:\n",
|
269 |
+
"#conformerCtc3Decoder(get_params() | get_decoding_params() | {'model_filename':'my new model filename'})\n",
|
270 |
+
"transcriber=ConformerCtc3Decoder(get_params() | get_decoding_params())"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "code",
|
275 |
+
"execution_count": 7,
|
276 |
+
"id": "8020f371-7584-4f6c-990b-f2c023e24060",
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [
|
279 |
+
{
|
280 |
+
"name": "stdout",
|
281 |
+
"output_type": "stream",
|
282 |
+
"text": [
|
283 |
+
"CPU times: user 4.86 s, sys: 435 ms, total: 5.29 s\n",
|
284 |
+
"Wall time: 4.45 s\n"
|
285 |
+
]
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"data": {
|
289 |
+
"text/plain": [
|
290 |
+
"{'text': 'mina tahaksin homme täna ja homme kui saan all kolm krantsumadiseid veiki panna',\n",
|
291 |
+
" 'words': [{'word': 'mina', 'start': 0.8, 'end': 0.84},\n",
|
292 |
+
" {'word': 'tahaksin', 'start': 1.0, 'end': 1.32},\n",
|
293 |
+
" {'word': 'homme', 'start': 1.48, 'end': 1.76},\n",
|
294 |
+
" {'word': 'täna', 'start': 2.08, 'end': 2.12},\n",
|
295 |
+
" {'word': 'ja', 'start': 3.72, 'end': 3.76},\n",
|
296 |
+
" {'word': 'homme', 'start': 4.16, 'end': 4.44},\n",
|
297 |
+
" {'word': 'kui', 'start': 5.96, 'end': 6.0},\n",
|
298 |
+
" {'word': 'saan', 'start': 6.52, 'end': 6.84},\n",
|
299 |
+
" {'word': 'all', 'start': 7.36, 'end': 7.4},\n",
|
300 |
+
" {'word': 'kolm', 'start': 8.32, 'end': 8.36},\n",
|
301 |
+
" {'word': 'krantsumadiseid', 'start': 8.68, 'end': 9.72},\n",
|
302 |
+
" {'word': 'veiki', 'start': 9.76, 'end': 10.04},\n",
|
303 |
+
" {'word': 'panna', 'start': 10.16, 'end': 10.4}]}"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
"execution_count": 7,
|
307 |
+
"metadata": {},
|
308 |
+
"output_type": "execute_result"
|
309 |
+
}
|
310 |
+
],
|
311 |
+
"source": [
|
312 |
+
"#transribe audiofile (NB! model assumes sample rate of 16000)\n",
|
313 |
+
"%time transcriber.transcribe_file('audio/emt16k.wav')"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "code",
|
318 |
+
"execution_count": 10,
|
319 |
+
"id": "4d2a480d-f0aa-4474-bfdb-ad298a629ce5",
|
320 |
+
"metadata": {},
|
321 |
+
"outputs": [
|
322 |
+
{
|
323 |
+
"name": "stdout",
|
324 |
+
"output_type": "stream",
|
325 |
+
"text": [
|
326 |
+
"CPU times: user 16.2 s, sys: 1.8 s, total: 18 s\n",
|
327 |
+
"Wall time: 15.1 s\n"
|
328 |
+
]
|
329 |
+
}
|
330 |
+
],
|
331 |
+
"source": [
|
332 |
+
"%time trs=transcriber.transcribe_file('audio/oden_kypsis16k.wav')"
|
333 |
+
]
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"cell_type": "code",
|
337 |
+
"execution_count": 11,
|
338 |
+
"id": "d3827548-bca0-4409-95bc-9aa8ba377135",
|
339 |
+
"metadata": {},
|
340 |
+
"outputs": [
|
341 |
+
{
|
342 |
+
"data": {
|
343 |
+
"text/plain": [
|
344 |
+
"{'text': 'enamus ajast nagu klikkid neid allserva tekivad need luba küpsiseid mis on nagu ilusti kohati tõlgitud eesti keelde see idee arusaadavamaks ma tean et see on kukis inglise kees ma ei saa sellest ka aru nagu mis asi on kukis on ju ma saan aru et ta vaid minee eest ära luba küpsises tava ei anna noh anna minna ma luban küpssi juhmaoloog okei on ju ma ei tea mis ta teeb lihtsalt selle eestikeelseks tõlk või eesti keelde tõlkimine kui teinud seda nagu arusaadavamaks küpsised kuule kuule veebisaid küsib sinu käest tahad tähendab on okei kui me neid kugiseid kasutame sa mingi ja mida iga mul täiesti savi või noh et et jah',\n",
|
345 |
+
" 'words': [{'word': 'enamus', 'start': 3.56, 'end': 3.8},\n",
|
346 |
+
" {'word': 'ajast', 'start': 3.8, 'end': 4.04},\n",
|
347 |
+
" {'word': 'nagu', 'start': 4.2, 'end': 4.24},\n",
|
348 |
+
" {'word': 'klikkid', 'start': 4.72, 'end': 5.12},\n",
|
349 |
+
" {'word': 'neid', 'start': 5.16, 'end': 5.2},\n",
|
350 |
+
" {'word': 'allserva', 'start': 5.72, 'end': 6.2},\n",
|
351 |
+
" {'word': 'tekivad', 'start': 6.32, 'end': 6.64},\n",
|
352 |
+
" {'word': 'need', 'start': 7.4, 'end': 7.44},\n",
|
353 |
+
" {'word': 'luba', 'start': 7.72, 'end': 8.0},\n",
|
354 |
+
" {'word': 'küpsiseid', 'start': 8.08, 'end': 8.64},\n",
|
355 |
+
" {'word': 'mis', 'start': 9.68, 'end': 9.72},\n",
|
356 |
+
" {'word': 'on', 'start': 9.76, 'end': 9.8},\n",
|
357 |
+
" {'word': 'nagu', 'start': 9.92, 'end': 9.96},\n",
|
358 |
+
" {'word': 'ilusti', 'start': 10.04, 'end': 10.36},\n",
|
359 |
+
" {'word': 'kohati', 'start': 10.4, 'end': 10.68},\n",
|
360 |
+
" {'word': 'tõlgitud', 'start': 11.08, 'end': 11.4},\n",
|
361 |
+
" {'word': 'eesti', 'start': 11.6, 'end': 11.64},\n",
|
362 |
+
" {'word': 'keelde', 'start': 11.8, 'end': 12.08},\n",
|
363 |
+
" {'word': 'see', 'start': 12.68, 'end': 12.72},\n",
|
364 |
+
" {'word': 'idee', 'start': 12.8, 'end': 13.04},\n",
|
365 |
+
" {'word': 'arusaadavamaks', 'start': 13.2, 'end': 13.8},\n",
|
366 |
+
" {'word': 'ma', 'start': 13.92, 'end': 13.96},\n",
|
367 |
+
" {'word': 'tean', 'start': 14.04, 'end': 14.24},\n",
|
368 |
+
" {'word': 'et', 'start': 14.28, 'end': 14.36},\n",
|
369 |
+
" {'word': 'see', 'start': 14.4, 'end': 14.44},\n",
|
370 |
+
" {'word': 'on', 'start': 14.44, 'end': 14.52},\n",
|
371 |
+
" {'word': 'kukis', 'start': 14.56, 'end': 14.92},\n",
|
372 |
+
" {'word': 'inglise', 'start': 14.92, 'end': 15.2},\n",
|
373 |
+
" {'word': 'kees', 'start': 15.2, 'end': 15.44},\n",
|
374 |
+
" {'word': 'ma', 'start': 15.84, 'end': 15.88},\n",
|
375 |
+
" {'word': 'ei', 'start': 15.92, 'end': 16.0},\n",
|
376 |
+
" {'word': 'saa', 'start': 16.04, 'end': 16.08},\n",
|
377 |
+
" {'word': 'sellest', 'start': 16.24, 'end': 16.28},\n",
|
378 |
+
" {'word': 'ka', 'start': 16.56, 'end': 16.6},\n",
|
379 |
+
" {'word': 'aru', 'start': 16.76, 'end': 16.8},\n",
|
380 |
+
" {'word': 'nagu', 'start': 16.96, 'end': 17.0},\n",
|
381 |
+
" {'word': 'mis', 'start': 17.12, 'end': 17.16},\n",
|
382 |
+
" {'word': 'asi', 'start': 17.28, 'end': 17.32},\n",
|
383 |
+
" {'word': 'on', 'start': 17.36, 'end': 17.4},\n",
|
384 |
+
" {'word': 'kukis', 'start': 17.48, 'end': 17.8},\n",
|
385 |
+
" {'word': 'on', 'start': 17.88, 'end': 17.92},\n",
|
386 |
+
" {'word': 'ju', 'start': 17.96, 'end': 18.0},\n",
|
387 |
+
" {'word': 'ma', 'start': 18.28, 'end': 18.32},\n",
|
388 |
+
" {'word': 'saan', 'start': 18.36, 'end': 18.48},\n",
|
389 |
+
" {'word': 'aru', 'start': 18.52, 'end': 18.56},\n",
|
390 |
+
" {'word': 'et', 'start': 18.72, 'end': 18.76},\n",
|
391 |
+
" {'word': 'ta', 'start': 19.2, 'end': 19.24},\n",
|
392 |
+
" {'word': 'vaid', 'start': 19.32, 'end': 19.44},\n",
|
393 |
+
" {'word': 'minee', 'start': 19.48, 'end': 19.68},\n",
|
394 |
+
" {'word': 'eest', 'start': 19.76, 'end': 19.96},\n",
|
395 |
+
" {'word': 'ära', 'start': 20.12, 'end': 20.16},\n",
|
396 |
+
" {'word': 'luba', 'start': 21.56, 'end': 21.88},\n",
|
397 |
+
" {'word': 'küpsises', 'start': 21.96, 'end': 22.44},\n",
|
398 |
+
" {'word': 'tava', 'start': 22.6, 'end': 22.76},\n",
|
399 |
+
" {'word': 'ei', 'start': 22.84, 'end': 22.88},\n",
|
400 |
+
" {'word': 'anna', 'start': 23.0, 'end': 23.16},\n",
|
401 |
+
" {'word': 'noh', 'start': 23.4, 'end': 23.44},\n",
|
402 |
+
" {'word': 'anna', 'start': 23.64, 'end': 23.76},\n",
|
403 |
+
" {'word': 'minna', 'start': 24.0, 'end': 24.04},\n",
|
404 |
+
" {'word': 'ma', 'start': 24.16, 'end': 24.2},\n",
|
405 |
+
" {'word': 'luban', 'start': 24.24, 'end': 24.56},\n",
|
406 |
+
" {'word': 'küpssi', 'start': 24.64, 'end': 24.92},\n",
|
407 |
+
" {'word': 'juhmaoloog', 'start': 25.0, 'end': 25.28},\n",
|
408 |
+
" {'word': 'okei', 'start': 25.28, 'end': 25.56},\n",
|
409 |
+
" {'word': 'on', 'start': 25.64, 'end': 25.72},\n",
|
410 |
+
" {'word': 'ju', 'start': 25.72, 'end': 25.76},\n",
|
411 |
+
" {'word': 'ma', 'start': 25.84, 'end': 25.88},\n",
|
412 |
+
" {'word': 'ei', 'start': 25.92, 'end': 25.96},\n",
|
413 |
+
" {'word': 'tea', 'start': 26.0, 'end': 26.04},\n",
|
414 |
+
" {'word': 'mis', 'start': 26.28, 'end': 26.32},\n",
|
415 |
+
" {'word': 'ta', 'start': 26.36, 'end': 26.4},\n",
|
416 |
+
" {'word': 'teeb', 'start': 26.56, 'end': 26.8},\n",
|
417 |
+
" {'word': 'lihtsalt', 'start': 27.04, 'end': 27.08},\n",
|
418 |
+
" {'word': 'selle', 'start': 27.24, 'end': 27.28},\n",
|
419 |
+
" {'word': 'eestikeelseks', 'start': 28.04, 'end': 28.68},\n",
|
420 |
+
" {'word': 'tõlk', 'start': 28.8, 'end': 29.08},\n",
|
421 |
+
" {'word': 'või', 'start': 29.16, 'end': 29.2},\n",
|
422 |
+
" {'word': 'eesti', 'start': 29.48, 'end': 29.52},\n",
|
423 |
+
" {'word': 'keelde', 'start': 29.68, 'end': 30.04},\n",
|
424 |
+
" {'word': 'tõlkimine', 'start': 30.2, 'end': 30.68},\n",
|
425 |
+
" {'word': 'kui', 'start': 30.8, 'end': 30.84},\n",
|
426 |
+
" {'word': 'teinud', 'start': 30.96, 'end': 31.16},\n",
|
427 |
+
" {'word': 'seda', 'start': 31.2, 'end': 31.24},\n",
|
428 |
+
" {'word': 'nagu', 'start': 31.72, 'end': 31.76},\n",
|
429 |
+
" {'word': 'arusaadavamaks', 'start': 31.88, 'end': 32.6},\n",
|
430 |
+
" {'word': 'küpsised', 'start': 33.52, 'end': 33.88},\n",
|
431 |
+
" {'word': 'kuule', 'start': 36.96, 'end': 37.08},\n",
|
432 |
+
" {'word': 'kuule', 'start': 37.32, 'end': 37.44},\n",
|
433 |
+
" {'word': 'veebisaid', 'start': 37.8, 'end': 38.28},\n",
|
434 |
+
" {'word': 'küsib', 'start': 38.44, 'end': 38.56},\n",
|
435 |
+
" {'word': 'sinu', 'start': 38.6, 'end': 38.72},\n",
|
436 |
+
" {'word': 'käest', 'start': 38.76, 'end': 39.0},\n",
|
437 |
+
" {'word': 'tahad', 'start': 39.52, 'end': 39.72},\n",
|
438 |
+
" {'word': 'tähendab', 'start': 40.32, 'end': 40.36},\n",
|
439 |
+
" {'word': 'on', 'start': 40.8, 'end': 40.88},\n",
|
440 |
+
" {'word': 'okei', 'start': 40.88, 'end': 41.2},\n",
|
441 |
+
" {'word': 'kui', 'start': 41.24, 'end': 41.28},\n",
|
442 |
+
" {'word': 'me', 'start': 41.36, 'end': 41.4},\n",
|
443 |
+
" {'word': 'neid', 'start': 41.6, 'end': 41.64},\n",
|
444 |
+
" {'word': 'kugiseid', 'start': 42.2, 'end': 42.64},\n",
|
445 |
+
" {'word': 'kasutame', 'start': 42.8, 'end': 43.08},\n",
|
446 |
+
" {'word': 'sa', 'start': 43.56, 'end': 43.6},\n",
|
447 |
+
" {'word': 'mingi', 'start': 43.8, 'end': 43.84},\n",
|
448 |
+
" {'word': 'ja', 'start': 44.04, 'end': 44.08},\n",
|
449 |
+
" {'word': 'mida', 'start': 44.28, 'end': 44.32},\n",
|
450 |
+
" {'word': 'iga', 'start': 44.44, 'end': 44.48},\n",
|
451 |
+
" {'word': 'mul', 'start': 44.56, 'end': 44.6},\n",
|
452 |
+
" {'word': 'täiesti', 'start': 44.92, 'end': 44.96},\n",
|
453 |
+
" {'word': 'savi', 'start': 45.08, 'end': 45.28},\n",
|
454 |
+
" {'word': 'või', 'start': 45.36, 'end': 45.4},\n",
|
455 |
+
" {'word': 'noh', 'start': 45.44, 'end': 45.48},\n",
|
456 |
+
" {'word': 'et', 'start': 45.6, 'end': 45.64},\n",
|
457 |
+
" {'word': 'et', 'start': 47.36, 'end': 47.4},\n",
|
458 |
+
" {'word': 'jah', 'start': 47.56, 'end': 47.68}]}"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
"execution_count": 11,
|
462 |
+
"metadata": {},
|
463 |
+
"output_type": "execute_result"
|
464 |
+
}
|
465 |
+
],
|
466 |
+
"source": [
|
467 |
+
"trs"
|
468 |
+
]
|
469 |
+
},
|
470 |
+
{
|
471 |
+
"cell_type": "code",
|
472 |
+
"execution_count": null,
|
473 |
+
"id": "ea3b25b7-a1f9-4b21-911d-35159c5f3009",
|
474 |
+
"metadata": {},
|
475 |
+
"outputs": [],
|
476 |
+
"source": []
|
477 |
+
}
|
478 |
+
],
|
479 |
+
"metadata": {
|
480 |
+
"kernelspec": {
|
481 |
+
"display_name": "Python 3 (ipykernel)",
|
482 |
+
"language": "python",
|
483 |
+
"name": "python3"
|
484 |
+
},
|
485 |
+
"language_info": {
|
486 |
+
"codemirror_mode": {
|
487 |
+
"name": "ipython",
|
488 |
+
"version": 3
|
489 |
+
},
|
490 |
+
"file_extension": ".py",
|
491 |
+
"mimetype": "text/x-python",
|
492 |
+
"name": "python",
|
493 |
+
"nbconvert_exporter": "python",
|
494 |
+
"pygments_lexer": "ipython3",
|
495 |
+
"version": "3.9.16"
|
496 |
+
}
|
497 |
+
},
|
498 |
+
"nbformat": 4,
|
499 |
+
"nbformat_minor": 5
|
500 |
+
}
|
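The notebook's read_sound_file_ asserts that the input already has the expected 16 kHz sample rate. A small, hedged sketch for resampling other material first (torchaudio only; the file names are placeholders and are not part of the repo):

```python
# Hedged sketch (not part of the notebook): resample arbitrary audio to the
# 16 kHz the model expects before calling transcriber.transcribe_file().
import torchaudio
import torchaudio.functional as F

wave, sr = torchaudio.load("my_audio.wav")          # placeholder input file
if sr != 16000:
    wave = F.resample(wave, orig_freq=sr, new_freq=16000)
torchaudio.save("my_audio16k.wav", wave, 16000)      # placeholder output file

# trs = transcriber.transcribe_file("my_audio16k.wav")
```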
err2020/data/lang_bpe_500/bpe.model
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:14afda3d7b1a9b2d07ca4f55bdf2d9d7424bb795068cac61107bc2b58a26b7fd
|
3 |
+
size 245129
|
requirements.txt
ADDED
@@ -0,0 +1 @@
1 |
+
lhotse
|
run.bat
ADDED
@@ -0,0 +1,9 @@
1 |
+
set APP_PATH=%cd%
|
2 |
+
|
3 |
+
docker stop icefall_run
|
4 |
+
docker rm icefall_run
|
5 |
+
docker run -it --rm ^
|
6 |
+
-p 8888:8888 ^
|
7 |
+
-v %APP_PATH%:/opt/notebooks ^
|
8 |
+
--name icefall_run ^
|
9 |
+
icefall
|
run.sh
ADDED
@@ -0,0 +1,11 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
APP_PATH=$(pwd)
|
4 |
+
|
5 |
+
docker stop icefall_run
|
6 |
+
docker rm icefall_run
|
7 |
+
docker run -it --rm \
|
8 |
+
-p 8888:8888 \
|
9 |
+
-v "$APP_PATH":/opt/notebooks \
|
10 |
+
--name icefall_run \
|
11 |
+
icefall
|