martin committed
Commit ccdff04 · 1 Parent(s): 51a6224

tts use api

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +0 -6
  2. Dockerfile +0 -46
  3. app.py +24 -32
  4. cosyvoice/__init__.py +0 -0
  5. cosyvoice/cli/__init__.py +0 -0
  6. cosyvoice/cli/cosyvoice.py +0 -68
  7. cosyvoice/cli/frontend.py +0 -106
  8. cosyvoice/cli/model.py +0 -32
  9. cosyvoice/flow/decoder.py +0 -238
  10. cosyvoice/flow/flow.py +0 -196
  11. cosyvoice/flow/flow_matching.py +0 -315
  12. cosyvoice/flow/length_regulator.py +0 -65
  13. cosyvoice/hifigan/f0_predictor.py +0 -55
  14. cosyvoice/hifigan/generator.py +0 -566
  15. cosyvoice/matcha/audio.py +0 -90
  16. cosyvoice/matcha/decoder.py +0 -511
  17. cosyvoice/matcha/flow_matching.py +0 -141
  18. cosyvoice/matcha/transformer.py +0 -443
  19. cosyvoice/transformer/__init__.py +0 -0
  20. cosyvoice/transformer/activation.py +0 -87
  21. cosyvoice/transformer/attention.py +0 -322
  22. cosyvoice/transformer/convolution.py +0 -147
  23. cosyvoice/transformer/decoder.py +0 -418
  24. cosyvoice/transformer/decoder_layer.py +0 -132
  25. cosyvoice/transformer/embedding.py +0 -293
  26. cosyvoice/transformer/encoder.py +0 -633
  27. cosyvoice/transformer/encoder_layer.py +0 -237
  28. cosyvoice/transformer/label_smoothing_loss.py +0 -98
  29. cosyvoice/transformer/positionwise_feed_forward.py +0 -116
  30. cosyvoice/transformer/subsampling.py +0 -391
  31. cosyvoice/utils/__init__.py +0 -0
  32. cosyvoice/utils/audio.py +0 -90
  33. cosyvoice/utils/class_utils.py +0 -78
  34. cosyvoice/utils/common.py +0 -169
  35. cosyvoice/utils/executor.py +0 -151
  36. cosyvoice/utils/file_utils.py +0 -49
  37. cosyvoice/utils/frontend_utils.py +0 -142
  38. cosyvoice/utils/mask.py +0 -226
  39. cosyvoice/utils/scheduler.py +0 -761
  40. cosyvoice/utils/train_utils.py +0 -350
  41. funasr_detach/__init__.py +0 -38
  42. funasr_detach/auto/__init__.py +0 -0
  43. funasr_detach/auto/auto_frontend.py +0 -90
  44. funasr_detach/auto/auto_model.py +0 -573
  45. funasr_detach/auto/auto_tokenizer.py +0 -7
  46. funasr_detach/bin/__init__.py +0 -0
  47. funasr_detach/bin/compute_audio_cmvn.py +0 -152
  48. funasr_detach/bin/inference.py +0 -33
  49. funasr_detach/bin/tokenize_text.py +0 -281
  50. funasr_detach/bin/train.py +0 -227
.gitattributes CHANGED
@@ -2,11 +2,5 @@
  *.wav filter=lfs diff=lfs merge=lfs -text
  assets/user.png filter=lfs diff=lfs merge=lfs -text
  assets/assistant.png filter=lfs diff=lfs merge=lfs -text
- speakers/闫雨婷_prompt.wav filter=lfs diff=lfs merge=lfs -text
- speakers/闫雨婷RAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
- speakers/闫雨婷VOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
- speakers/Tingting_prompt.wav filter=lfs diff=lfs merge=lfs -text
- speakers/TingtingRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
- speakers/TingtingVOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
  assets/yuewen.jpeg filter=lfs diff=lfs merge=lfs -text
  assets/request_rap_zh.wav filter=lfs diff=lfs merge=lfs -text
Dockerfile DELETED
@@ -1,46 +0,0 @@
- FROM nvidia/cuda:12.1.0-base-ubuntu20.04
-
- ENV TZ=Asia/Shanghai
- RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime \
-     && echo $TZ > /etc/timezone
-
- RUN apt-get update \
-     && apt-get install -y build-essential \
-     && apt-get install -y wget \
-     && apt-get install -y software-properties-common curl zip unzip git-lfs awscli libssl-dev openssh-server vim \
-     && apt-get install -y net-tools iputils-ping iproute2
-
- RUN apt-get install --reinstall ca-certificates && update-ca-certificates
-
- RUN add-apt-repository -y 'ppa:deadsnakes/ppa' && apt update
- RUN apt install python3.10 python3.10-dev python3.10-distutils python3.10-venv -y \
-     && apt-get clean \
-     && rm -rf /var/lib/apt/lists/*
-
- RUN wget -qO- https://bootstrap.pypa.io/get-pip.py | python3.10
- RUN ln -s /usr/bin/python3.10 /usr/bin/python
- RUN pip uninstall -y Pillow && pip install pillow
-
- # https://huggingface.co/docs/hub/spaces-sdks-docker#permissions
- RUN useradd -m -u 1000 user
- USER user
-
- ENV HOME="/home/user" \
-     PATH="/home/user/.local/bin:${PATH}"
-
- RUN python3.10 -m pip install pipx
- RUN pipx install poetry
-
- RUN poetry --version || { echo 'Poetry installation check failed' ; exit 1; }
-
- WORKDIR /workspace
-
- COPY --chown=user requirements.txt .
- RUN pip install -r requirements.txt
-
- COPY --chown=user . .
-
- RUN pip install gradio
- RUN pip install openai
- RUN chmod +x start_app.sh
- CMD ["./start_app.sh", "/tmp/hf_model"]
app.py CHANGED
@@ -4,15 +4,13 @@ import gradio as gr
  import time
  from pathlib import Path
  
- from tokenizer import StepAudioTokenizer
- from tts import StepAudioTTS
- from yuewen_api import call_audiochat, call_asr
+ from yuewen_api import call_audiochat, call_asr, call_tts
  
  CACHE_DIR = "/tmp/gradio/"
- CACHE_CLEAN_AGE = 864000
+ CACHE_CLEAN_AGE = 86400
  
  CHINESE_PROMPT_CONTENT = """你是一个为对话而设计的人工智能模型,目前无法连接到互联网。
- 当你需要唱歌或说唱时,请以(RAP)开头。当你需要快速说话时,请以(快速)开头。当你需要慢速说话时,请以(慢速)开头。
+ 当你需要唱歌时,请以(哼唱)开头。当你需要rap或说唱时,请以(RAP)开头。当你需要快速说话时,请以(快速)开头。当你需要慢速说话时,请以(慢速)开头。
  现在,你需要倾听用户的语音内容,并以礼貌、简洁、口语化的文本进行回复。你需要尽量用户的语种进行回复。"""
  
  ENGLISH_PROMPT_CONTENT = """You are an AI designed for conversation, currently unable to connect to the internet.
@@ -89,20 +87,15 @@ def add_message(chatbot, history, mic, text):
      return chatbot, history, None
  
  
- def save_tmp_audio(audio, sr):
+ def get_tmp_audio_path():
      import tempfile
-     import torchaudio
  
-     with tempfile.NamedTemporaryFile(
-         dir=CACHE_DIR, delete=False, suffix=".wav"
-     ) as temp_audio:
-         temp_audio_path = temp_audio.name
-         torchaudio.save(temp_audio_path, audio, sr)
+     temp_audio = tempfile.NamedTemporaryFile(dir=CACHE_DIR, delete=False, suffix=".mp3")
  
      return temp_audio.name
  
  
- def predict(chatbot, history, tts_model, user_prompt, enable_asr):
+ def predict(chatbot, history, user_prompt, enable_asr):
      """Generate a response from the model."""
      start_time = time.time()
      try:
@@ -126,8 +119,8 @@ def predict(chatbot, history, tts_model, user_prompt, enable_asr):
  
          text = call_audiochat(messages)
          print(f"predict {text=}")
-         audio, sr = tts_model(text, "Tingting")
-         audio_path = save_tmp_audio(audio, sr)
+         audio_path = get_tmp_audio_path()
+         call_tts(text, audio_path)
          print(f"save_tmp_audio {audio_path=}")
          chatbot.append({"role": "assistant", "content": text})
          chatbot.append({"role": "assistant", "content": {"path": audio_path}})
@@ -142,17 +135,15 @@ def predict(chatbot, history, tts_model, user_prompt, enable_asr):
      return chatbot, history
  
  
- def _launch_demo(args, tts_model):
-     with gr.Blocks(delete_cache=(86400, CACHE_CLEAN_AGE)) as demo:
+ def _launch_demo(args):
+     with gr.Blocks(delete_cache=(3600, CACHE_CLEAN_AGE)) as demo:
          # 保存 chat 历史,不需要每次再重新拼格式
          history = gr.State([])
          gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
          gr.Markdown(
              """<font size=4>This preview demonstrates core functionalities. To unlock the cormplete real-time voice conversation system with end-to-end encryption and advanced features, download the [Yuewen APP](https://m.yuewen.cn/call-app) with the link or via QR Code.</font>"""
          )
-         with gr.Accordion(
-             label="Click to view the QR code ", open=False
-         ):
+         with gr.Accordion(label="Click to view the QR code ", open=False):
              gr.Image(
                  value="assets/yuewen.jpeg",
                  interactive=False,
@@ -161,7 +152,8 @@ def _launch_demo(args, tts_model):
                  show_fullscreen_button=False,
              )
          with gr.Accordion(
-             label="The performance of English prompts is not as stable as that of Chinese prompts. You can click here to change sys prompt.", open=False
+             label="The performance of English prompts is not as stable as that of Chinese prompts. You can click here to change sys prompt.",
+             open=False,
          ):
              prompt_choice = gr.Radio(
                  choices=list(PROMPT_TEMPLATE.keys()),
@@ -222,7 +214,7 @@ def _launch_demo(args, tts_model):
              print(f"update_examples error")
              return chatbot, history
          else:
-             chatbot, history = predict(chatbot, history, tts_model, user_prompt, enable_asr)
+             chatbot, history = predict(chatbot, history, user_prompt, enable_asr)
              print(f"update_examples done")
              return chatbot, history
  
@@ -230,7 +222,13 @@ def _launch_demo(args, tts_model):
          gr.Examples(
              fn=update_examples,
              examples=CHAT_EXAMPLES,
-             inputs=[example_comment, example_text, example_audio, user_prompt, enable_asr],
+             inputs=[
+                 example_comment,
+                 example_text,
+                 example_audio,
+                 user_prompt,
+                 enable_asr,
+             ],
              outputs=[chatbot, history],
              run_on_click=True,
          )
@@ -241,7 +239,7 @@ def _launch_demo(args, tts_model):
              gr.Warning(error)
              return chatbot, history, None, None
          else:
-             chatbot, history = predict(chatbot, history, tts_model, user_prompt, enable_asr)
+             chatbot, history = predict(chatbot, history, user_prompt, enable_asr)
              return chatbot, history, None, None
  
      gen_btn.click(
@@ -266,7 +264,7 @@ def _launch_demo(args, tts_model):
          while history and history[-1]["role"] == "assistant":
              print(f"discard {history[-1]}")
              history.pop()
-         return predict(chatbot, history, tts_model, user_prompt, enable_asr)
+         return predict(chatbot, history, user_prompt, enable_asr)
  
      regen_btn.click(
          regenerate,
@@ -295,10 +293,4 @@ if __name__ == "__main__":
          "--server-name", type=str, default="0.0.0.0", help="Demo server name."
      )
      args = parser.parse_args()
-     tokenizer = StepAudioTokenizer(
-         os.path.join(args.model_path, "Step-Audio-Tokenizer")
-     )
-     tts_model = StepAudioTTS(
-         os.path.join(args.model_path, "Step-Audio-TTS-3B"), tokenizer
-     )
-     _launch_demo(args, tts_model)
+     _launch_demo(args)
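Note on the new TTS path: app.py now only calls `call_tts(text, audio_path)` from `yuewen_api`, and that module is not included in this commit, so its implementation is not visible here. The sketch below is a rough, hedged illustration of what an API-backed `call_tts` consistent with that call site might look like, assuming an OpenAI-compatible speech endpoint (the removed Dockerfile did install the `openai` package). The environment variable names, model, and voice are invented placeholders, not values from this repository.

# Hedged sketch only: yuewen_api is not part of this diff, so the endpoint,
# model, and voice below are placeholders rather than the real configuration.
import os

from openai import OpenAI  # the removed Dockerfile installed the `openai` package


_client = OpenAI(
    api_key=os.environ.get("TTS_API_KEY", ""),  # assumed environment variable
    base_url=os.environ.get("TTS_API_BASE"),    # assumed OpenAI-compatible endpoint
)


def call_tts(text: str, audio_path: str) -> str:
    """Synthesize `text` with the remote TTS API and write the audio to `audio_path`."""
    response = _client.audio.speech.create(
        model="example-tts-model",  # placeholder model name
        voice="example-voice",      # placeholder voice name
        input=text,
    )
    response.write_to_file(audio_path)  # save the returned audio bytes (mp3 by default)
    return audio_path

In the demo itself, `get_tmp_audio_path()` reserves a `.mp3` file under `CACHE_DIR` with `tempfile.NamedTemporaryFile(delete=False)` and returns its name, and `predict()` passes that path to `call_tts` before appending it to the chatbot as an audio message.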
cosyvoice/__init__.py DELETED
File without changes
cosyvoice/cli/__init__.py DELETED
File without changes
cosyvoice/cli/cosyvoice.py DELETED
@@ -1,68 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import os
15
- import uuid
16
- import time
17
- from tqdm import tqdm
18
- import torch
19
- import torchaudio
20
- from hyperpyyaml import load_hyperpyyaml
21
- from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
- from cosyvoice.cli.model import CosyVoiceModel
23
-
24
-
25
- class CosyVoice:
26
-
27
- def __init__(
28
- self,
29
- model_dir,
30
- ):
31
- self.model_dir = model_dir
32
- with open("{}/cosyvoice.yaml".format(model_dir), "r") as f:
33
- configs = load_hyperpyyaml(f)
34
- self.frontend = CosyVoiceFrontEnd(
35
- configs["feat_extractor"],
36
- "{}/campplus.onnx".format(model_dir),
37
- "{}/speech_tokenizer_v1.onnx".format(model_dir),
38
- )
39
- self.model = CosyVoiceModel(configs["flow"], configs["hift"])
40
- self.model.load(
41
- "{}/flow.pt".format(model_dir),
42
- "{}/hift.pt".format(model_dir),
43
- )
44
- self.model.flow = self.model.flow.to(torch.bfloat16)
45
- del configs
46
-
47
- def token_to_wav_offline(
48
- self,
49
- speech_token,
50
- speech_feat,
51
- speech_feat_len,
52
- prompt_token,
53
- prompt_token_len,
54
- embedding,
55
- ):
56
- tts_mel = self.model.flow.inference(
57
- token=speech_token.to(self.model.device),
58
- token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(
59
- self.model.device
60
- ),
61
- prompt_token=prompt_token.to(self.model.device),
62
- prompt_token_len=prompt_token_len.to(self.model.device),
63
- prompt_feat=speech_feat.to(self.model.device),
64
- prompt_feat_len=speech_feat_len.to(self.model.device),
65
- embedding=embedding.to(self.model.device),
66
- )
67
- tts_speech = self.model.hift.inference(mel=tts_mel.float())[0].cpu()
68
- return tts_speech
cosyvoice/cli/frontend.py DELETED
@@ -1,106 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import onnxruntime
15
- import torch
16
- import numpy as np
17
- import whisper
18
- from typing import Callable
19
- import torchaudio.compliance.kaldi as kaldi
20
-
21
-
22
- class CosyVoiceFrontEnd:
23
-
24
- def __init__(
25
- self,
26
- feat_extractor: Callable,
27
- campplus_model: str,
28
- speech_tokenizer_model: str,
29
- ):
30
- self.feat_extractor = feat_extractor
31
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
- option = onnxruntime.SessionOptions()
33
- option.graph_optimization_level = (
34
- onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
35
- )
36
- option.intra_op_num_threads = 1
37
- self.campplus_session = onnxruntime.InferenceSession(
38
- campplus_model, sess_options=option, providers=["CPUExecutionProvider"]
39
- )
40
- self.speech_tokenizer_session = onnxruntime.InferenceSession(
41
- speech_tokenizer_model,
42
- sess_options=option,
43
- providers=[
44
- (
45
- "CUDAExecutionProvider"
46
- if torch.cuda.is_available()
47
- else "CPUExecutionProvider"
48
- )
49
- ],
50
- )
51
-
52
- def _extract_speech_token(self, speech):
53
- assert (
54
- speech.shape[1] / 16000 <= 30
55
- ), "do not support extract speech token for audio longer than 30s"
56
- feat = whisper.log_mel_spectrogram(speech, n_mels=128)
57
- speech_token = (
58
- self.speech_tokenizer_session.run(
59
- None,
60
- {
61
- self.speech_tokenizer_session.get_inputs()[0]
62
- .name: feat.detach()
63
- .cpu()
64
- .numpy(),
65
- self.speech_tokenizer_session.get_inputs()[1].name: np.array(
66
- [feat.shape[2]], dtype=np.int32
67
- ),
68
- },
69
- )[0]
70
- .flatten()
71
- .tolist()
72
- )
73
- speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
74
- speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(
75
- self.device
76
- )
77
- return speech_token, speech_token_len
78
-
79
- def _extract_spk_embedding(self, speech):
80
- feat = kaldi.fbank(speech, num_mel_bins=80, dither=0, sample_frequency=16000)
81
- feat = feat - feat.mean(dim=0, keepdim=True)
82
- embedding = (
83
- self.campplus_session.run(
84
- None,
85
- {
86
- self.campplus_session.get_inputs()[0]
87
- .name: feat.unsqueeze(dim=0)
88
- .cpu()
89
- .numpy()
90
- },
91
- )[0]
92
- .flatten()
93
- .tolist()
94
- )
95
- embedding = torch.tensor([embedding]).to(self.device)
96
- return embedding
97
-
98
- def _extract_speech_feat(self, speech):
99
- speech_feat = (
100
- self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
101
- )
102
- speech_feat = speech_feat.unsqueeze(dim=0)
103
- speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(
104
- self.device
105
- )
106
- return speech_feat, speech_feat_len
cosyvoice/cli/model.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import torch
15
-
16
-
17
- class CosyVoiceModel:
18
-
19
- def __init__(
20
- self,
21
- flow: torch.nn.Module,
22
- hift: torch.nn.Module,
23
- ):
24
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- self.flow = flow
26
- self.hift = hift
27
-
28
- def load(self, flow_model, hift_model):
29
- self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
30
- self.flow.to(self.device).eval()
31
- self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
32
- self.hift.to(self.device).eval()
cosyvoice/flow/decoder.py DELETED
@@ -1,238 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import torch
15
- import torch.nn as nn
16
- from einops import pack, rearrange, repeat
17
- from cosyvoice.matcha.decoder import (
18
- SinusoidalPosEmb,
19
- Block1D,
20
- ResnetBlock1D,
21
- Downsample1D,
22
- TimestepEmbedding,
23
- Upsample1D,
24
- )
25
- from cosyvoice.matcha.transformer import BasicTransformerBlock
26
-
27
-
28
- class ConditionalDecoder(nn.Module):
29
- def __init__(
30
- self,
31
- in_channels,
32
- out_channels,
33
- channels=(256, 256),
34
- dropout=0.05,
35
- attention_head_dim=64,
36
- n_blocks=1,
37
- num_mid_blocks=2,
38
- num_heads=4,
39
- act_fn="snake",
40
- ):
41
- """
42
- This decoder requires an input with the same shape of the target. So, if your text content
43
- is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
44
- """
45
- super().__init__()
46
- channels = tuple(channels)
47
- self.in_channels = in_channels
48
- self.out_channels = out_channels
49
-
50
- self.time_embeddings = SinusoidalPosEmb(in_channels)
51
- time_embed_dim = channels[0] * 4
52
- self.time_mlp = TimestepEmbedding(
53
- in_channels=in_channels,
54
- time_embed_dim=time_embed_dim,
55
- act_fn="silu",
56
- )
57
- self.down_blocks = nn.ModuleList([])
58
- self.mid_blocks = nn.ModuleList([])
59
- self.up_blocks = nn.ModuleList([])
60
-
61
- output_channel = in_channels
62
- for i in range(len(channels)): # pylint: disable=consider-using-enumerate
63
- input_channel = output_channel
64
- output_channel = channels[i]
65
- is_last = i == len(channels) - 1
66
- resnet = ResnetBlock1D(
67
- dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
68
- )
69
- transformer_blocks = nn.ModuleList(
70
- [
71
- BasicTransformerBlock(
72
- dim=output_channel,
73
- num_attention_heads=num_heads,
74
- attention_head_dim=attention_head_dim,
75
- dropout=dropout,
76
- activation_fn=act_fn,
77
- )
78
- for _ in range(n_blocks)
79
- ]
80
- )
81
- downsample = (
82
- Downsample1D(output_channel)
83
- if not is_last
84
- else nn.Conv1d(output_channel, output_channel, 3, padding=1)
85
- )
86
- self.down_blocks.append(
87
- nn.ModuleList([resnet, transformer_blocks, downsample])
88
- )
89
-
90
- for _ in range(num_mid_blocks):
91
- input_channel = channels[-1]
92
- out_channels = channels[-1]
93
- resnet = ResnetBlock1D(
94
- dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
95
- )
96
-
97
- transformer_blocks = nn.ModuleList(
98
- [
99
- BasicTransformerBlock(
100
- dim=output_channel,
101
- num_attention_heads=num_heads,
102
- attention_head_dim=attention_head_dim,
103
- dropout=dropout,
104
- activation_fn=act_fn,
105
- )
106
- for _ in range(n_blocks)
107
- ]
108
- )
109
-
110
- self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
111
-
112
- channels = channels[::-1] + (channels[0],)
113
- for i in range(len(channels) - 1):
114
- input_channel = channels[i] * 2
115
- output_channel = channels[i + 1]
116
- is_last = i == len(channels) - 2
117
- resnet = ResnetBlock1D(
118
- dim=input_channel,
119
- dim_out=output_channel,
120
- time_emb_dim=time_embed_dim,
121
- )
122
- transformer_blocks = nn.ModuleList(
123
- [
124
- BasicTransformerBlock(
125
- dim=output_channel,
126
- num_attention_heads=num_heads,
127
- attention_head_dim=attention_head_dim,
128
- dropout=dropout,
129
- activation_fn=act_fn,
130
- )
131
- for _ in range(n_blocks)
132
- ]
133
- )
134
- upsample = (
135
- Upsample1D(output_channel, use_conv_transpose=True)
136
- if not is_last
137
- else nn.Conv1d(output_channel, output_channel, 3, padding=1)
138
- )
139
- self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
140
- self.final_block = Block1D(channels[-1], channels[-1])
141
- self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
142
- self.initialize_weights()
143
-
144
- def initialize_weights(self):
145
- for m in self.modules():
146
- if isinstance(m, nn.Conv1d):
147
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
148
- if m.bias is not None:
149
- nn.init.constant_(m.bias, 0)
150
- elif isinstance(m, nn.GroupNorm):
151
- nn.init.constant_(m.weight, 1)
152
- nn.init.constant_(m.bias, 0)
153
- elif isinstance(m, nn.Linear):
154
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
155
- if m.bias is not None:
156
- nn.init.constant_(m.bias, 0)
157
-
158
- def forward(self, x, mask, mu, t, spks=None, cond=None):
159
- """Forward pass of the UNet1DConditional model.
160
-
161
- Args:
162
- x (torch.Tensor): shape (batch_size, in_channels, time)
163
- mask (_type_): shape (batch_size, 1, time)
164
- t (_type_): shape (batch_size)
165
- spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
166
- cond (_type_, optional): placeholder for future use. Defaults to None.
167
-
168
- Raises:
169
- ValueError: _description_
170
- ValueError: _description_
171
-
172
- Returns:
173
- _type_: _description_
174
- """
175
-
176
- t = self.time_embeddings(t).to(t.dtype)
177
- t = self.time_mlp(t)
178
-
179
- x = pack([x, mu], "b * t")[0]
180
-
181
- if spks is not None:
182
- spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
183
- x = pack([x, spks], "b * t")[0]
184
- if cond is not None:
185
- x = pack([x, cond], "b * t")[0]
186
-
187
- hiddens = []
188
- masks = [mask]
189
- for resnet, transformer_blocks, downsample in self.down_blocks:
190
- mask_down = masks[-1]
191
- x = resnet(
192
- x.to(torch.bfloat16), mask_down.to(torch.bfloat16), t.to(torch.bfloat16)
193
- )
194
- x = rearrange(x, "b c t -> b t c").contiguous()
195
- # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
196
- for transformer_block in transformer_blocks:
197
- x = transformer_block(
198
- hidden_states=x,
199
- # attention_mask=attn_mask,
200
- timestep=t,
201
- )
202
- x = rearrange(x, "b t c -> b c t").contiguous()
203
- hiddens.append(x) # Save hidden states for skip connections
204
- x = downsample(x * mask_down)
205
- masks.append(mask_down[:, :, ::2])
206
- masks = masks[:-1]
207
- mask_mid = masks[-1]
208
-
209
- for resnet, transformer_blocks in self.mid_blocks:
210
- x = resnet(x, mask_mid, t)
211
- x = rearrange(x, "b c t -> b t c").contiguous()
212
- # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
213
- for transformer_block in transformer_blocks:
214
- x = transformer_block(
215
- hidden_states=x,
216
- # attention_mask=attn_mask,
217
- timestep=t,
218
- )
219
- x = rearrange(x, "b t c -> b c t").contiguous()
220
-
221
- for resnet, transformer_blocks, upsample in self.up_blocks:
222
- mask_up = masks.pop()
223
- skip = hiddens.pop()
224
- x = pack([x[:, :, : skip.shape[-1]], skip], "b * t")[0]
225
- x = resnet(x, mask_up, t)
226
- x = rearrange(x, "b c t -> b t c").contiguous()
227
- # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
228
- for transformer_block in transformer_blocks:
229
- x = transformer_block(
230
- hidden_states=x,
231
- # attention_mask=attn_mask,
232
- timestep=t,
233
- )
234
- x = rearrange(x, "b t c -> b c t").contiguous()
235
- x = upsample(x * mask_up)
236
- x = self.final_block(x, mask_up)
237
- output = self.final_proj(x * mask_up)
238
- return output * mask
cosyvoice/flow/flow.py DELETED
@@ -1,196 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import logging
15
- import random
16
- from typing import Dict, Optional
17
- import torch
18
- import torch.nn as nn
19
- from torch.nn import functional as F
20
- from omegaconf import DictConfig
21
- from cosyvoice.utils.mask import make_pad_mask
22
- import time
23
-
24
-
25
- class MaskedDiffWithXvec(torch.nn.Module):
26
- def __init__(
27
- self,
28
- input_size: int = 512,
29
- output_size: int = 80,
30
- spk_embed_dim: int = 192,
31
- output_type: str = "mel",
32
- vocab_size: int = 4096,
33
- input_frame_rate: int = 50,
34
- only_mask_loss: bool = True,
35
- encoder: torch.nn.Module = None,
36
- length_regulator: torch.nn.Module = None,
37
- decoder: torch.nn.Module = None,
38
- decoder_conf: Dict = {
39
- "in_channels": 240,
40
- "out_channel": 80,
41
- "spk_emb_dim": 80,
42
- "n_spks": 1,
43
- "cfm_params": DictConfig(
44
- {
45
- "sigma_min": 1e-06,
46
- "solver": "euler",
47
- "t_scheduler": "cosine",
48
- "training_cfg_rate": 0.2,
49
- "inference_cfg_rate": 0.7,
50
- "reg_loss_type": "l1",
51
- }
52
- ),
53
- "decoder_params": {
54
- "channels": [256, 256],
55
- "dropout": 0.0,
56
- "attention_head_dim": 64,
57
- "n_blocks": 4,
58
- "num_mid_blocks": 12,
59
- "num_heads": 8,
60
- "act_fn": "gelu",
61
- },
62
- },
63
- mel_feat_conf: Dict = {
64
- "n_fft": 1024,
65
- "num_mels": 80,
66
- "sampling_rate": 22050,
67
- "hop_size": 256,
68
- "win_size": 1024,
69
- "fmin": 0,
70
- "fmax": 8000,
71
- },
72
- ):
73
- super().__init__()
74
- self.input_size = input_size
75
- self.output_size = output_size
76
- self.decoder_conf = decoder_conf
77
- self.mel_feat_conf = mel_feat_conf
78
- self.vocab_size = vocab_size
79
- self.output_type = output_type
80
- self.input_frame_rate = input_frame_rate
81
- logging.info(f"input frame rate={self.input_frame_rate}")
82
- self.input_embedding = nn.Embedding(vocab_size, input_size)
83
- self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
84
- self.encoder = encoder
85
- self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
86
- self.decoder = decoder
87
- self.length_regulator = length_regulator
88
- self.only_mask_loss = only_mask_loss
89
-
90
- def forward(
91
- self,
92
- batch: dict,
93
- device: torch.device,
94
- ) -> Dict[str, Optional[torch.Tensor]]:
95
- token = batch["speech_token"].to(device)
96
- token_len = batch["speech_token_len"].to(device)
97
- feat = batch["speech_feat"].to(device)
98
- feat_len = batch["speech_feat_len"].to(device)
99
- embedding = batch["embedding"].to(device)
100
-
101
- # xvec projection
102
- embedding = F.normalize(embedding, dim=1)
103
- embedding = self.spk_embed_affine_layer(embedding)
104
-
105
- # concat text and prompt_text
106
- mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
107
- token = self.input_embedding(torch.clamp(token, min=0)) * mask
108
-
109
- # text encode
110
- h, h_lengths = self.encoder(token, token_len)
111
- h = self.encoder_proj(h)
112
- h, h_lengths = self.length_regulator(h, feat_len)
113
-
114
- # get conditions
115
- conds = torch.zeros(feat.shape, device=token.device)
116
- for i, j in enumerate(feat_len):
117
- if random.random() < 0.5:
118
- continue
119
- index = random.randint(0, int(0.3 * j))
120
- conds[i, :index] = feat[i, :index]
121
- conds = conds.transpose(1, 2)
122
-
123
- mask = (~make_pad_mask(feat_len)).to(h)
124
- feat = F.interpolate(
125
- feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest"
126
- ).squeeze(dim=1)
127
- loss, _ = self.decoder.compute_loss(
128
- feat.transpose(1, 2).contiguous(),
129
- mask.unsqueeze(1),
130
- h.transpose(1, 2).contiguous(),
131
- embedding,
132
- cond=conds,
133
- )
134
- return {"loss": loss}
135
-
136
- @torch.inference_mode()
137
- def inference(
138
- self,
139
- token,
140
- token_len,
141
- prompt_token,
142
- prompt_token_len,
143
- prompt_feat,
144
- prompt_feat_len,
145
- embedding,
146
- ):
147
- assert token.shape[0] == 1
148
- # xvec projection
149
- embedding = F.normalize(embedding, dim=1)
150
- embedding = self.spk_embed_affine_layer(embedding)
151
-
152
- # concat text and prompt_text
153
- token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
154
- # text encode
155
- token, token_len = (
156
- torch.concat([prompt_token, token], dim=1),
157
- prompt_token_len + token_len,
158
- )
159
- token = self.input_embedding(torch.clamp(token, min=0))
160
- h, _ = self.encoder.inference(token, token_len)
161
- h = self.encoder_proj(h)
162
- mel_len1, mel_len2 = prompt_feat.shape[1], int(
163
- token_len2
164
- / self.input_frame_rate
165
- * self.mel_feat_conf["sampling_rate"]
166
- / self.mel_feat_conf["hop_size"]
167
- )
168
-
169
- h, _ = self.length_regulator.inference(
170
- h[:, :token_len1],
171
- h[:, token_len1:],
172
- mel_len1,
173
- mel_len2,
174
- )
175
-
176
- # get conditions
177
- conds = torch.zeros(
178
- [1, mel_len1 + mel_len2, self.output_size], device=token.device
179
- )
180
- conds[:, :mel_len1] = prompt_feat
181
- conds = conds.transpose(1, 2)
182
-
183
- # mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
184
- mask = torch.ones(
185
- [1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16
186
- )
187
- feat = self.decoder(
188
- mu=h.transpose(1, 2).contiguous(),
189
- mask=mask.unsqueeze(1),
190
- spks=embedding,
191
- cond=conds,
192
- n_timesteps=10,
193
- )
194
- feat = feat[:, :, mel_len1:]
195
- assert feat.shape[2] == mel_len2
196
- return feat
cosyvoice/flow/flow_matching.py DELETED
@@ -1,315 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import time
15
- import torch
16
- import torch.nn.functional as F
17
- from cosyvoice.matcha.flow_matching import BASECFM
18
-
19
-
20
- class ConditionalCFM(BASECFM):
21
- def __init__(
22
- self,
23
- in_channels,
24
- cfm_params,
25
- n_spks=1,
26
- spk_emb_dim=64,
27
- estimator: torch.nn.Module = None,
28
- ):
29
- super().__init__(
30
- n_feats=in_channels,
31
- cfm_params=cfm_params,
32
- n_spks=n_spks,
33
- spk_emb_dim=spk_emb_dim,
34
- )
35
- self.t_scheduler = cfm_params.t_scheduler
36
- self.training_cfg_rate = cfm_params.training_cfg_rate
37
- self.inference_cfg_rate = cfm_params.inference_cfg_rate
38
- in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
39
- # Just change the architecture of the estimator here
40
- self.estimator = estimator
41
- self.inference_graphs = {}
42
- self.inference_buffers = {}
43
- # self.capture_inference()
44
-
45
- @torch.inference_mode()
46
- def forward(
47
- self,
48
- mu,
49
- mask,
50
- n_timesteps,
51
- temperature=1.0,
52
- spks=None,
53
- cond=None,
54
- ):
55
- """Forward diffusion
56
-
57
- Args:
58
- mu (torch.Tensor): output of encoder
59
- shape: (batch_size, n_feats, mel_timesteps)
60
- mask (torch.Tensor): output_mask
61
- shape: (batch_size, 1, mel_timesteps)
62
- n_timesteps (int): number of diffusion steps
63
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
64
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
65
- shape: (batch_size, spk_emb_dim)
66
- cond: Not used but kept for future purposes
67
-
68
- Returns:
69
- sample: generated mel-spectrogram
70
- shape: (batch_size, n_feats, mel_timesteps)
71
- """
72
- z = torch.randn_like(mu) * temperature
73
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
74
- if self.t_scheduler == "cosine":
75
- t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
76
- return self.solve_euler(
77
- z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
78
- )
79
-
80
- @torch.inference_mode()
81
- def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))):
82
- start_time = time.time()
83
- print(
84
- f"capture_inference for ConditionalCFM solve euler, seq_len_to_capture: {seq_len_to_capture}"
85
- )
86
- for seq_len in seq_len_to_capture:
87
- static_z = torch.randn(
88
- 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
89
- )
90
- static_t_span = torch.linspace(
91
- 0, 1, 11, device=torch.device("cuda"), dtype=torch.bfloat16
92
- ) # only capture at 10 steps
93
- static_mu = torch.randn(
94
- 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
95
- )
96
- static_mask = torch.ones(
97
- 1, 1, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
98
- )
99
- static_spks = torch.randn(
100
- 1, 80, device=torch.device("cuda"), dtype=torch.bfloat16
101
- )
102
- static_cond = torch.randn(
103
- 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
104
- )
105
- static_out = torch.randn(
106
- 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
107
- )
108
-
109
- self._solve_euler_impl(
110
- static_z,
111
- t_span=static_t_span,
112
- mu=static_mu,
113
- mask=static_mask,
114
- spks=static_spks,
115
- cond=static_cond,
116
- )
117
- torch.cuda.synchronize()
118
-
119
- g = torch.cuda.CUDAGraph()
120
- with torch.cuda.graph(g):
121
- static_out = self._solve_euler_impl(
122
- static_z,
123
- t_span=static_t_span,
124
- mu=static_mu,
125
- mask=static_mask,
126
- spks=static_spks,
127
- cond=static_cond,
128
- )
129
-
130
- self.inference_buffers[seq_len] = {
131
- "z": static_z,
132
- "t_span": static_t_span,
133
- "mu": static_mu,
134
- "mask": static_mask,
135
- "spks": static_spks,
136
- "cond": static_cond,
137
- "out": static_out,
138
- }
139
- self.inference_graphs[seq_len] = g
140
- end_time = time.time()
141
- print(
142
- f"capture_inference for ConditionalCFM solve euler, time elapsed: {end_time - start_time}"
143
- )
144
-
145
- def solve_euler(self, x, t_span, mu, mask, spks, cond):
146
- if hasattr(self, "inference_graphs") and len(self.inference_graphs) > 0:
147
- curr_seq_len = x.shape[2]
148
-
149
- available_lengths = sorted(list(self.inference_graphs.keys()))
150
-
151
- if curr_seq_len <= max(available_lengths):
152
- target_len = min(available_lengths, key=lambda x: abs(x - curr_seq_len))
153
- if target_len == curr_seq_len:
154
- padded_x = x
155
- padded_mu = mu
156
- padded_mask = mask
157
- if cond is not None:
158
- padded_cond = cond
159
- else:
160
- padded_x = torch.randn(
161
- (x.shape[0], x.shape[1], target_len),
162
- dtype=x.dtype,
163
- device=x.device,
164
- )
165
- padded_x[:, :, :curr_seq_len] = x
166
-
167
- padded_mu = torch.randn(
168
- (mu.shape[0], mu.shape[1], target_len),
169
- dtype=mu.dtype,
170
- device=mu.device,
171
- )
172
- padded_mu[:, :, :curr_seq_len] = mu
173
-
174
- # FIXME(ys): uses zeros and maskgroupnorm
175
- padded_mask = torch.ones(
176
- (mask.shape[0], mask.shape[1], target_len),
177
- dtype=mask.dtype,
178
- device=mask.device,
179
- )
180
-
181
- if cond is not None:
182
- padded_cond = torch.randn(
183
- (cond.shape[0], cond.shape[1], target_len),
184
- dtype=cond.dtype,
185
- device=cond.device,
186
- )
187
- padded_cond[:, :, :curr_seq_len] = cond
188
-
189
- buffer = self.inference_buffers[target_len]
190
- buffer["z"].copy_(padded_x)
191
- buffer["t_span"].copy_(t_span)
192
- buffer["mu"].copy_(padded_mu)
193
- buffer["mask"].copy_(padded_mask)
194
- buffer["spks"].copy_(spks)
195
- if cond is not None:
196
- buffer["cond"].copy_(padded_cond)
197
-
198
- self.inference_graphs[target_len].replay()
199
-
200
- output = buffer["out"][:, :, :curr_seq_len]
201
- return output
202
-
203
- return self._solve_euler_impl(x, t_span, mu, mask, spks, cond)
204
-
205
- def _solve_euler_impl(self, x, t_span, mu, mask, spks, cond):
206
- """
207
- Fixed euler solver for ODEs.
208
- Args:
209
- x (torch.Tensor): random noise
210
- t_span (torch.Tensor): n_timesteps interpolated
211
- shape: (n_timesteps + 1,)
212
- mu (torch.Tensor): output of encoder
213
- shape: (batch_size, n_feats, mel_timesteps)
214
- mask (torch.Tensor): output_mask
215
- shape: (batch_size, 1, mel_timesteps)
216
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
217
- shape: (batch_size, spk_emb_dim)
218
- cond: Not used but kept for future purposes
219
- """
220
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
221
- t = t.unsqueeze(dim=0)
222
-
223
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
224
- # Or in future might add like a return_all_steps flag
225
- sol = []
226
-
227
- for step in range(1, len(t_span)):
228
- if self.inference_cfg_rate > 0:
229
- x_double = torch.cat([x, x], dim=0)
230
- mask_double = torch.cat([mask, mask], dim=0)
231
- mu_double = torch.cat([mu, torch.zeros_like(mu)], dim=0)
232
- t_double = torch.cat([t, t], dim=0)
233
- spks_double = (
234
- torch.cat([spks, torch.zeros_like(spks)], dim=0)
235
- if spks is not None
236
- else None
237
- )
238
- cond_double = torch.cat([cond, torch.zeros_like(cond)], dim=0)
239
-
240
- dphi_dt_double = self.forward_estimator(
241
- x_double, mask_double, mu_double, t_double, spks_double, cond_double
242
- )
243
-
244
- dphi_dt, cfg_dphi_dt = torch.chunk(dphi_dt_double, 2, dim=0)
245
- dphi_dt = (
246
- 1.0 + self.inference_cfg_rate
247
- ) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt
248
- else:
249
- dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
250
-
251
- x = x + dt * dphi_dt
252
- t = t + dt
253
- sol.append(x)
254
- if step < len(t_span) - 1:
255
- dt = t_span[step + 1] - t
256
-
257
- return sol[-1]
258
-
259
- def forward_estimator(self, x, mask, mu, t, spks, cond):
260
- if isinstance(self.estimator, torch.nn.Module):
261
- return self.estimator.forward(x, mask, mu, t, spks, cond)
262
- else:
263
- ort_inputs = {
264
- "x": x.cpu().numpy(),
265
- "mask": mask.cpu().numpy(),
266
- "mu": mu.cpu().numpy(),
267
- "t": t.cpu().numpy(),
268
- "spks": spks.cpu().numpy(),
269
- "cond": cond.cpu().numpy(),
270
- }
271
- output = self.estimator.run(None, ort_inputs)[0]
272
- return torch.tensor(output, dtype=x.dtype, device=x.device)
273
-
274
- def compute_loss(self, x1, mask, mu, spks=None, cond=None):
275
- """Computes diffusion loss
276
-
277
- Args:
278
- x1 (torch.Tensor): Target
279
- shape: (batch_size, n_feats, mel_timesteps)
280
- mask (torch.Tensor): target mask
281
- shape: (batch_size, 1, mel_timesteps)
282
- mu (torch.Tensor): output of encoder
283
- shape: (batch_size, n_feats, mel_timesteps)
284
- spks (torch.Tensor, optional): speaker embedding. Defaults to None.
285
- shape: (batch_size, spk_emb_dim)
286
-
287
- Returns:
288
- loss: conditional flow matching loss
289
- y: conditional flow
290
- shape: (batch_size, n_feats, mel_timesteps)
291
- """
292
- b, _, t = mu.shape
293
-
294
- # random timestep
295
- t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
296
- if self.t_scheduler == "cosine":
297
- t = 1 - torch.cos(t * 0.5 * torch.pi)
298
- # sample noise p(x_0)
299
- z = torch.randn_like(x1)
300
-
301
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
302
- u = x1 - (1 - self.sigma_min) * z
303
-
304
- # during training, we randomly drop condition to trade off mode coverage and sample fidelity
305
- if self.training_cfg_rate > 0:
306
- cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
307
- mu = mu * cfg_mask.view(-1, 1, 1)
308
- spks = spks * cfg_mask.view(-1, 1)
309
- cond = cond * cfg_mask.view(-1, 1, 1)
310
-
311
- pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
312
- loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (
313
- torch.sum(mask) * u.shape[1]
314
- )
315
- return loss, y
cosyvoice/flow/length_regulator.py DELETED
@@ -1,65 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Tuple
15
- import torch.nn as nn
16
- import torch
17
- from torch.nn import functional as F
18
- from cosyvoice.utils.mask import make_pad_mask
19
-
20
-
21
- class InterpolateRegulator(nn.Module):
22
- def __init__(
23
- self,
24
- channels: int,
25
- sampling_ratios: Tuple,
26
- out_channels: int = None,
27
- groups: int = 1,
28
- ):
29
- super().__init__()
30
- self.sampling_ratios = sampling_ratios
31
- out_channels = out_channels or channels
32
- model = nn.ModuleList([])
33
- if len(sampling_ratios) > 0:
34
- for _ in sampling_ratios:
35
- module = nn.Conv1d(channels, channels, 3, 1, 1)
36
- norm = nn.GroupNorm(groups, channels)
37
- act = nn.Mish()
38
- model.extend([module, norm, act])
39
- model.append(nn.Conv1d(channels, out_channels, 1, 1))
40
- self.model = nn.Sequential(*model)
41
-
42
- def forward(self, x, ylens=None):
43
- # x in (B, T, D)
44
- mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
45
- x = F.interpolate(
46
- x.transpose(1, 2).contiguous(), size=ylens.max(), mode="linear"
47
- )
48
- out = self.model(x).transpose(1, 2).contiguous()
49
- olens = ylens
50
- return out * mask, olens
51
-
52
- def inference(self, x1, x2, mel_len1, mel_len2):
53
- # x in (B, T, D)
54
- x2 = F.interpolate(
55
- x2.transpose(1, 2).contiguous(), size=mel_len2, mode="linear"
56
- )
57
- if x1.shape[1] != 0:
58
- x1 = F.interpolate(
59
- x1.transpose(1, 2).contiguous(), size=mel_len1, mode="linear"
60
- )
61
- x = torch.concat([x1, x2], dim=2)
62
- else:
63
- x = x2
64
- out = self.model(x).transpose(1, 2).contiguous()
65
- return out, mel_len1 + mel_len2
cosyvoice/hifigan/f0_predictor.py DELETED
@@ -1,55 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import torch
15
- import torch.nn as nn
16
- from torch.nn.utils import weight_norm
17
-
18
-
19
- class ConvRNNF0Predictor(nn.Module):
20
- def __init__(
21
- self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512
22
- ):
23
- super().__init__()
24
-
25
- self.num_class = num_class
26
- self.condnet = nn.Sequential(
27
- weight_norm(
28
- nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
29
- ),
30
- nn.ELU(),
31
- weight_norm(
32
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
33
- ),
34
- nn.ELU(),
35
- weight_norm(
36
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
37
- ),
38
- nn.ELU(),
39
- weight_norm(
40
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
41
- ),
42
- nn.ELU(),
43
- weight_norm(
44
- nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
45
- ),
46
- nn.ELU(),
47
- )
48
- self.classifier = nn.Linear(
49
- in_features=cond_channels, out_features=self.num_class
50
- )
51
-
52
- def forward(self, x: torch.Tensor) -> torch.Tensor:
53
- x = self.condnet(x)
54
- x = x.transpose(1, 2)
55
- return torch.abs(self.classifier(x).squeeze(-1))
cosyvoice/hifigan/generator.py DELETED
@@ -1,566 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """HIFI-GAN"""
16
-
17
- import typing as tp
18
- import time
19
- import numpy as np
20
- from scipy.signal import get_window
21
- import torch
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
- from torch.nn import Conv1d
25
- from torch.nn import ConvTranspose1d
26
- from torch.nn.utils import remove_weight_norm
27
- from torch.nn.utils import weight_norm
28
- from torch.distributions.uniform import Uniform
29
-
30
- from cosyvoice.transformer.activation import Snake
31
- from cosyvoice.utils.common import get_padding
32
- from cosyvoice.utils.common import init_weights
33
-
34
-
35
- """hifigan based generator implementation.
36
-
37
- This code is modified from https://github.com/jik876/hifi-gan
38
- ,https://github.com/kan-bayashi/ParallelWaveGAN and
39
- https://github.com/NVIDIA/BigVGAN
40
-
41
- """
42
-
43
-
44
- class ResBlock(torch.nn.Module):
45
- """Residual block module in HiFiGAN/BigVGAN."""
46
-
47
- def __init__(
48
- self,
49
- channels: int = 512,
50
- kernel_size: int = 3,
51
- dilations: tp.List[int] = [1, 3, 5],
52
- ):
53
- super(ResBlock, self).__init__()
54
- self.convs1 = nn.ModuleList()
55
- self.convs2 = nn.ModuleList()
56
-
57
- for dilation in dilations:
58
- self.convs1.append(
59
- weight_norm(
60
- Conv1d(
61
- channels,
62
- channels,
63
- kernel_size,
64
- 1,
65
- dilation=dilation,
66
- padding=get_padding(kernel_size, dilation),
67
- )
68
- )
69
- )
70
- self.convs2.append(
71
- weight_norm(
72
- Conv1d(
73
- channels,
74
- channels,
75
- kernel_size,
76
- 1,
77
- dilation=1,
78
- padding=get_padding(kernel_size, 1),
79
- )
80
- )
81
- )
82
- self.convs1.apply(init_weights)
83
- self.convs2.apply(init_weights)
84
- self.activations1 = nn.ModuleList(
85
- [Snake(channels, alpha_logscale=False) for _ in range(len(self.convs1))]
86
- )
87
- self.activations2 = nn.ModuleList(
88
- [Snake(channels, alpha_logscale=False) for _ in range(len(self.convs2))]
89
- )
90
-
91
- def forward(self, x: torch.Tensor) -> torch.Tensor:
92
- for idx in range(len(self.convs1)):
93
- xt = self.activations1[idx](x)
94
- xt = self.convs1[idx](xt)
95
- xt = self.activations2[idx](xt)
96
- xt = self.convs2[idx](xt)
97
- x = xt + x
98
- return x
99
-
100
- def remove_weight_norm(self):
101
- for idx in range(len(self.convs1)):
102
- remove_weight_norm(self.convs1[idx])
103
- remove_weight_norm(self.convs2[idx])
104
-
105
-
106
- class SineGen(torch.nn.Module):
107
- """Definition of sine generator
108
- SineGen(samp_rate, harmonic_num = 0,
109
- sine_amp = 0.1, noise_std = 0.003,
110
- voiced_threshold = 0,
111
- flag_for_pulse=False)
112
- samp_rate: sampling rate in Hz
113
- harmonic_num: number of harmonic overtones (default 0)
114
- sine_amp: amplitude of sine-wavefrom (default 0.1)
115
- noise_std: std of Gaussian noise (default 0.003)
116
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
117
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
118
- Note: when flag_for_pulse is True, the first time step of a voiced
119
- segment is always sin(np.pi) or cos(0)
120
- """
121
-
122
- def __init__(
123
- self,
124
- samp_rate,
125
- harmonic_num=0,
126
- sine_amp=0.1,
127
- noise_std=0.003,
128
- voiced_threshold=0,
129
- ):
130
- super(SineGen, self).__init__()
131
- self.sine_amp = sine_amp
132
- self.noise_std = noise_std
133
- self.harmonic_num = harmonic_num
134
- self.sampling_rate = samp_rate
135
- self.voiced_threshold = voiced_threshold
136
-
137
- def _f02uv(self, f0):
138
- # generate uv signal
139
- uv = (f0 > self.voiced_threshold).type(torch.float32)
140
- return uv
141
-
142
- @torch.no_grad()
143
- def forward(self, f0):
144
- """
145
- :param f0: [B, 1, sample_len], Hz
146
- :return: [B, 1, sample_len]
147
- """
148
-
149
- F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(
150
- f0.device
151
- )
152
- for i in range(self.harmonic_num + 1):
153
- F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate
154
-
155
- theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
156
- u_dist = Uniform(low=-np.pi, high=np.pi)
157
- phase_vec = u_dist.sample(
158
- sample_shape=(f0.size(0), self.harmonic_num + 1, 1)
159
- ).to(F_mat.device)
160
- phase_vec[:, 0, :] = 0
161
-
162
- # generate sine waveforms
163
- sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
164
-
165
- # generate uv signal
166
- uv = self._f02uv(f0)
167
-
168
- # noise: for unvoiced should be similar to sine_amp
169
- # std = self.sine_amp/3 -> max value ~ self.sine_amp
170
- # . for voiced regions is self.noise_std
171
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
172
- noise = noise_amp * torch.randn_like(sine_waves)
173
-
174
- # first: set the unvoiced part to 0 by uv
175
- # then: additive noise
176
- sine_waves = sine_waves * uv + noise
177
- return sine_waves, uv, noise
178
-
179
-
180
- class SourceModuleHnNSF(torch.nn.Module):
181
- """SourceModule for hn-nsf
182
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
183
- add_noise_std=0.003, voiced_threshod=0)
184
- sampling_rate: sampling_rate in Hz
185
- harmonic_num: number of harmonic above F0 (default: 0)
186
- sine_amp: amplitude of sine source signal (default: 0.1)
187
- add_noise_std: std of additive Gaussian noise (default: 0.003)
188
- note that amplitude of noise in unvoiced is decided
189
- by sine_amp
190
- voiced_threshold: threshold to set U/V given F0 (default: 0)
191
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
192
- F0_sampled (batchsize, length, 1)
193
- Sine_source (batchsize, length, 1)
194
- noise_source (batchsize, length 1)
195
- uv (batchsize, length, 1)
196
- """
197
-
198
- def __init__(
199
- self,
200
- sampling_rate,
201
- upsample_scale,
202
- harmonic_num=0,
203
- sine_amp=0.1,
204
- add_noise_std=0.003,
205
- voiced_threshod=0,
206
- ):
207
- super(SourceModuleHnNSF, self).__init__()
208
-
209
- self.sine_amp = sine_amp
210
- self.noise_std = add_noise_std
211
-
212
- # to produce sine waveforms
213
- self.l_sin_gen = SineGen(
214
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
215
- )
216
-
217
- # to merge source harmonics into a single excitation
218
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
219
- self.l_tanh = torch.nn.Tanh()
220
-
221
- def forward(self, x):
222
- """
223
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
224
- F0_sampled (batchsize, length, 1)
225
- Sine_source (batchsize, length, 1)
226
- noise_source (batchsize, length 1)
227
- """
228
- # source for harmonic branch
229
- with torch.no_grad():
230
- sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
231
- sine_wavs = sine_wavs.transpose(1, 2)
232
- uv = uv.transpose(1, 2)
233
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
234
-
235
- # source for noise branch, in the same shape as uv
236
- noise = torch.randn_like(uv) * self.sine_amp / 3
237
- return sine_merge, noise, uv
238
-
239
-
240
- class HiFTGenerator(nn.Module):
241
- """
242
- HiFTNet Generator: Neural Source Filter + ISTFTNet
243
- https://arxiv.org/abs/2309.09493
244
- """
245
-
246
- def __init__(
247
- self,
248
- in_channels: int = 80,
249
- base_channels: int = 512,
250
- nb_harmonics: int = 8,
251
- sampling_rate: int = 22050,
252
- nsf_alpha: float = 0.1,
253
- nsf_sigma: float = 0.003,
254
- nsf_voiced_threshold: float = 10,
255
- upsample_rates: tp.List[int] = [8, 8],
256
- upsample_kernel_sizes: tp.List[int] = [16, 16],
257
- istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
258
- resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
259
- resblock_dilation_sizes: tp.List[tp.List[int]] = [
260
- [1, 3, 5],
261
- [1, 3, 5],
262
- [1, 3, 5],
263
- ],
264
- source_resblock_kernel_sizes: tp.List[int] = [7, 11],
265
- source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
266
- lrelu_slope: float = 0.1,
267
- audio_limit: float = 0.99,
268
- f0_predictor: torch.nn.Module = None,
269
- ):
270
- super(HiFTGenerator, self).__init__()
271
-
272
- self.out_channels = 1
273
- self.nb_harmonics = nb_harmonics
274
- self.sampling_rate = sampling_rate
275
- self.istft_params = istft_params
276
- self.lrelu_slope = lrelu_slope
277
- self.audio_limit = audio_limit
278
-
279
- self.num_kernels = len(resblock_kernel_sizes)
280
- self.num_upsamples = len(upsample_rates)
281
- self.upsample_rates = upsample_rates
282
- self.m_source = SourceModuleHnNSF(
283
- sampling_rate=sampling_rate,
284
- upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
285
- harmonic_num=nb_harmonics,
286
- sine_amp=nsf_alpha,
287
- add_noise_std=nsf_sigma,
288
- voiced_threshod=nsf_voiced_threshold,
289
- )
290
- self.f0_upsamp = torch.nn.Upsample(
291
- scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]
292
- )
293
-
294
- self.conv_pre = weight_norm(Conv1d(in_channels, base_channels, 7, 1, padding=3))
295
-
296
- # Up
297
- self.ups = nn.ModuleList()
298
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
299
- self.ups.append(
300
- weight_norm(
301
- ConvTranspose1d(
302
- base_channels // (2**i),
303
- base_channels // (2 ** (i + 1)),
304
- k,
305
- u,
306
- padding=(k - u) // 2,
307
- )
308
- )
309
- )
310
-
311
- # Down
312
- self.source_downs = nn.ModuleList()
313
- self.source_resblocks = nn.ModuleList()
314
- downsample_rates = [1] + upsample_rates[::-1][:-1]
315
- downsample_cum_rates = np.cumprod(downsample_rates)
316
- for i, (u, k, d) in enumerate(
317
- zip(
318
- downsample_cum_rates[::-1],
319
- source_resblock_kernel_sizes,
320
- source_resblock_dilation_sizes,
321
- )
322
- ):
323
- if u == 1:
324
- self.source_downs.append(
325
- Conv1d(
326
- istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1
327
- )
328
- )
329
- else:
330
- self.source_downs.append(
331
- Conv1d(
332
- istft_params["n_fft"] + 2,
333
- base_channels // (2 ** (i + 1)),
334
- u * 2,
335
- u,
336
- padding=(u // 2),
337
- )
338
- )
339
-
340
- self.source_resblocks.append(
341
- ResBlock(base_channels // (2 ** (i + 1)), k, d)
342
- )
343
-
344
- self.resblocks = nn.ModuleList()
345
- for i in range(len(self.ups)):
346
- ch = base_channels // (2 ** (i + 1))
347
- for _, (k, d) in enumerate(
348
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
349
- ):
350
- self.resblocks.append(ResBlock(ch, k, d))
351
-
352
- self.conv_post = weight_norm(
353
- Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)
354
- )
355
- self.ups.apply(init_weights)
356
- self.conv_post.apply(init_weights)
357
- self.reflection_pad = nn.ReflectionPad1d((1, 0))
358
- self.stft_window = torch.from_numpy(
359
- get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)
360
- ).cuda()
361
- self.f0_predictor = f0_predictor
362
- self.inference_buffers = {}
363
- self.inference_graphs = {}
364
-
365
- def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
366
- f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
367
-
368
- har_source, _, _ = self.m_source(f0)
369
- return har_source.transpose(1, 2)
370
-
371
- def _stft(self, x):
372
- spec = torch.stft(
373
- x,
374
- self.istft_params["n_fft"],
375
- self.istft_params["hop_len"],
376
- self.istft_params["n_fft"],
377
- window=self.stft_window,
378
- return_complex=True,
379
- )
380
- spec = torch.view_as_real(spec) # [B, F, TT, 2]
381
- return spec[..., 0], spec[..., 1]
382
-
383
- def _istft(self, magnitude, phase):
384
- magnitude = torch.clip(magnitude, max=1e2)
385
- real = magnitude * torch.cos(phase)
386
- img = magnitude * torch.sin(phase)
387
- inverse_transform = torch.istft(
388
- torch.complex(real, img),
389
- self.istft_params["n_fft"],
390
- self.istft_params["hop_len"],
391
- self.istft_params["n_fft"],
392
- window=self.stft_window,
393
- )
394
- return inverse_transform
395
-
396
- def forward(
397
- self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)
398
- ) -> torch.Tensor:
399
- f0 = self.f0_predictor(x)
400
- s = self._f02source(f0)
401
-
402
- # use cache_source to avoid glitch
403
- if cache_source.shape[2] != 0:
404
- s[:, :, : cache_source.shape[2]] = cache_source
405
-
406
- s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
407
- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
408
-
409
- x = self.conv_pre(x)
410
- for i in range(self.num_upsamples):
411
- x = F.leaky_relu(x, self.lrelu_slope)
412
- x = self.ups[i](x)
413
-
414
- if i == self.num_upsamples - 1:
415
- x = self.reflection_pad(x)
416
-
417
- # fusion
418
- si = self.source_downs[i](s_stft)
419
- si = self.source_resblocks[i](si)
420
- x = x + si
421
-
422
- xs = None
423
- for j in range(self.num_kernels):
424
- if xs is None:
425
- xs = self.resblocks[i * self.num_kernels + j](x)
426
- else:
427
- xs += self.resblocks[i * self.num_kernels + j](x)
428
- x = xs / self.num_kernels
429
-
430
- x = F.leaky_relu(x)
431
- x = self.conv_post(x)
432
- magnitude = torch.exp(x[:, : self.istft_params["n_fft"] // 2 + 1, :])
433
- phase = torch.sin(
434
- x[:, self.istft_params["n_fft"] // 2 + 1 :, :]
435
- ) # actually, sin is redundancy
436
-
437
- x = self._istft(magnitude, phase)
438
- x = torch.clamp(x, -self.audio_limit, self.audio_limit)
439
- return x, s
440
-
441
- def remove_weight_norm(self):
442
- print("Removing weight norm...")
443
- for l in self.ups:
444
- remove_weight_norm(l)
445
- for l in self.resblocks:
446
- l.remove_weight_norm()
447
- remove_weight_norm(self.conv_pre)
448
- remove_weight_norm(self.conv_post)
449
- self.source_module.remove_weight_norm()
450
- for l in self.source_downs:
451
- remove_weight_norm(l)
452
- for l in self.source_resblocks:
453
- l.remove_weight_norm()
454
-
455
- @torch.inference_mode()
456
- def _inference_impl(self, mel: torch.Tensor, s_stft: torch.Tensor) -> torch.Tensor:
457
- x = self.conv_pre(mel)
458
- for i in range(self.num_upsamples):
459
- x = F.leaky_relu(x, self.lrelu_slope)
460
- x = self.ups[i](x)
461
-
462
- if i == self.num_upsamples - 1:
463
- x = self.reflection_pad(x)
464
-
465
- # fusion
466
- si = self.source_downs[i](s_stft)
467
- si = self.source_resblocks[i](si)
468
- x = x + si
469
-
470
- xs = None
471
- for j in range(self.num_kernels):
472
- if xs is None:
473
- xs = self.resblocks[i * self.num_kernels + j](x)
474
- else:
475
- xs += self.resblocks[i * self.num_kernels + j](x)
476
- x = xs / self.num_kernels
477
-
478
- x = F.leaky_relu(x)
479
- x = self.conv_post(x)
480
- magnitude = torch.exp(x[:, : self.istft_params["n_fft"] // 2 + 1, :])
481
- phase = torch.sin(
482
- x[:, self.istft_params["n_fft"] // 2 + 1 :, :]
483
- ) # actually, sin is redundancy
484
- # print(f"mel: {mel.shape}, magnitude: {magnitude.shape}, phase: {phase.shape}")
485
- return magnitude, phase
486
-
487
- @torch.inference_mode()
488
- def inference(
489
- self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)
490
- ) -> torch.Tensor:
491
- curr_seq_len = mel.shape[2]
492
- f0 = self.f0_predictor(mel)
493
- s = self._f02source(f0)
494
- s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
495
- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
496
-
497
- target_len = None
498
- for seq_len in sorted(self.inference_buffers.keys()):
499
- if curr_seq_len <= seq_len:
500
- target_len = seq_len
501
- break
502
-
503
- if target_len is not None:
504
- buffer = self.inference_buffers[target_len]
505
-
506
- if curr_seq_len < target_len:
507
- padded_mel = torch.zeros_like(buffer["mel"])
508
- padded_mel[:, :, :curr_seq_len] = mel
509
- buffer["mel"].copy_(padded_mel)
510
- padded_s_stft = torch.zeros_like(buffer["s_stft"])
511
- cur_s_stft_len = s_stft.shape[2]
512
- padded_s_stft[:, :, :cur_s_stft_len] = s_stft
513
- buffer["s_stft"].copy_(padded_s_stft)
514
-
515
- else:
516
- buffer["mel"].copy_(mel)
517
- buffer["s_stft"].copy_(s_stft)
518
- cur_s_stft_len = s_stft.shape[2]
519
-
520
- self.inference_graphs[target_len].replay()
521
-
522
- magnitude, phase = (
523
- buffer["magnitude"][:, :, :cur_s_stft_len],
524
- buffer["phase"][:, :, :cur_s_stft_len],
525
- )
526
- else:
527
- magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
528
-
529
- x = self._istft(magnitude, phase)
530
- x = torch.clamp(x, -self.audio_limit, self.audio_limit)
531
- return x, s
532
-
533
- @torch.inference_mode()
534
- def capture_inference(self, seq_len_to_capture=[64, 128, 256, 512, 1024]):
535
- start_time = time.time()
536
- print(
537
- f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}"
538
- )
539
- for seq_len in seq_len_to_capture:
540
- mel = torch.randn(
541
- 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
542
- )
543
- f0 = self.f0_predictor(mel)
544
- s = self._f02source(f0)
545
- s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
546
- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
547
-
548
- magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
549
- torch.cuda.synchronize()
550
-
551
- g = torch.cuda.CUDAGraph()
552
- with torch.cuda.graph(g):
553
- magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
554
- inference_buffer = {
555
- "mel": mel,
556
- "s_stft": s_stft,
557
- "magnitude": magnitude,
558
- "phase": phase,
559
- }
560
- self.inference_buffers[seq_len] = inference_buffer
561
- self.inference_graphs[seq_len] = g
562
-
563
- end_time = time.time()
564
- print(
565
- f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture} takes {end_time - start_time} seconds"
566
- )
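
The vocoder file deleted above built its excitation by stacking sine harmonics at integer multiples of the upsampled F0 and accumulating phase with a cumulative sum (SineGen), then merging them through a linear layer and tanh (SourceModuleHnNSF). A minimal sketch of that harmonic-source idea follows; the function name, shapes, and constants are illustrative, not the deleted implementation.

import torch

def harmonic_source(f0, sample_rate=22050, num_harmonics=8, sine_amp=0.1):
    # f0: [B, 1, T] fundamental frequency in Hz, already upsampled to sample rate
    harmonic_index = torch.arange(1, num_harmonics + 2, device=f0.device).view(1, -1, 1)
    normalized_freq = f0 * harmonic_index / sample_rate                  # cycles per sample, per harmonic
    phase = 2 * torch.pi * (torch.cumsum(normalized_freq, dim=-1) % 1.0) # accumulated phase
    sines = sine_amp * torch.sin(phase)                                  # [B, num_harmonics + 1, T]
    voiced = (f0 > 0).float()                                            # crude U/V mask from F0
    return sines * voiced

f0 = torch.full((1, 1, 2205), 220.0)   # 0.1 s of a steady 220 Hz tone (toy input)
print(harmonic_source(f0).shape)       # torch.Size([1, 9, 2205])

The deleted HiFTGenerator then took the STFT of this source, fused it with the upsampled mel features, and inverted the predicted magnitude/phase with an iSTFT.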
 
cosyvoice/matcha/audio.py DELETED
@@ -1,90 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.utils.data
4
- from librosa.filters import mel as librosa_mel_fn
5
- from scipy.io.wavfile import read
6
-
7
- MAX_WAV_VALUE = 32768.0
8
-
9
-
10
- def load_wav(full_path):
11
- sampling_rate, data = read(full_path)
12
- return data, sampling_rate
13
-
14
-
15
- def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
- return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
-
18
-
19
- def dynamic_range_decompression(x, C=1):
20
- return np.exp(x) / C
21
-
22
-
23
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24
- return torch.log(torch.clamp(x, min=clip_val) * C)
25
-
26
-
27
- def dynamic_range_decompression_torch(x, C=1):
28
- return torch.exp(x) / C
29
-
30
-
31
- def spectral_normalize_torch(magnitudes):
32
- output = dynamic_range_compression_torch(magnitudes)
33
- return output
34
-
35
-
36
- def spectral_de_normalize_torch(magnitudes):
37
- output = dynamic_range_decompression_torch(magnitudes)
38
- return output
39
-
40
-
41
- mel_basis = {}
42
- hann_window = {}
43
-
44
-
45
- def mel_spectrogram(
46
- y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
47
- ):
48
- if torch.min(y) < -1.0:
49
- print("min value is ", torch.min(y))
50
- if torch.max(y) > 1.0:
51
- print("max value is ", torch.max(y))
52
-
53
- global mel_basis, hann_window # pylint: disable=global-statement
54
- if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
55
- mel = librosa_mel_fn(
56
- sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
57
- )
58
- mel_basis[str(fmax) + "_" + str(y.device)] = (
59
- torch.from_numpy(mel).float().to(y.device)
60
- )
61
- hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
62
-
63
- y = torch.nn.functional.pad(
64
- y.unsqueeze(1),
65
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
66
- mode="reflect",
67
- )
68
- y = y.squeeze(1)
69
-
70
- spec = torch.view_as_real(
71
- torch.stft(
72
- y,
73
- n_fft,
74
- hop_length=hop_size,
75
- win_length=win_size,
76
- window=hann_window[str(y.device)],
77
- center=center,
78
- pad_mode="reflect",
79
- normalized=False,
80
- onesided=True,
81
- return_complex=True,
82
- )
83
- )
84
-
85
- spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
86
-
87
- spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
88
- spec = spectral_normalize_torch(spec)
89
-
90
- return spec
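
The helper deleted above computed a log-compressed mel spectrogram by applying a librosa mel filterbank to a torch STFT. A rough torchaudio-based equivalent is sketched below, assuming the same 80-mel, 1024-point FFT, hop-256 configuration; torchaudio is not a dependency of the deleted code, and its filterbank normalization can differ from the librosa version.

import torch
import torchaudio

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=1024, win_length=1024, hop_length=256,
    f_min=0.0, f_max=8000.0, n_mels=80, power=1.0, center=False,
)

def log_mel(wav: torch.Tensor, clip_val: float = 1e-5) -> torch.Tensor:
    # wav: [B, num_samples] in [-1, 1]; returns [B, 80, frames]
    # reflect-pad so the frame count matches the center=False STFT used originally
    pad = (1024 - 256) // 2
    wav = torch.nn.functional.pad(wav.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)
    mel = mel_transform(wav)
    return torch.log(torch.clamp(mel, min=clip_val))  # dynamic range compression

print(log_mel(torch.randn(1, 22050)).shape)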
 
cosyvoice/matcha/decoder.py DELETED
@@ -1,511 +0,0 @@
1
- import math
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from conformer import ConformerBlock
8
- from diffusers.models.activations import get_activation
9
- from einops import pack, rearrange, repeat
10
-
11
- from cosyvoice.matcha.transformer import BasicTransformerBlock
12
-
13
-
14
- class SinusoidalPosEmb(torch.nn.Module):
15
- def __init__(self, dim):
16
- super().__init__()
17
- self.dim = dim
18
- assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
19
-
20
- def forward(self, x, scale=1000):
21
- if x.ndim < 1:
22
- x = x.unsqueeze(0)
23
- device = x.device
24
- half_dim = self.dim // 2
25
- emb = math.log(10000) / (half_dim - 1)
26
- emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
27
- emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
28
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
29
- return emb
30
-
31
-
32
- class MaskedGroupNorm(nn.GroupNorm):
33
- """
34
- Masked verstion of the Group normalization.
35
-
36
- Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py
37
-
38
- Receives a N-dim tensor of sequence lengths per batch element
39
- along with the regular input for masking.
40
-
41
- Check pytorch's GroupNorm implementation for argument details.
42
- """
43
-
44
- def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
45
- super(MaskedGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
46
-
47
- def forward(self, inp, mask=None):
48
- assert (
49
- inp.shape[1] % self.num_groups == 0
50
- ), "Feature size not divisible by groups"
51
-
52
- # effective length of each sequence in the batch
53
- seq_lengths = mask.sum(-1, keepdim=True) # [batch_size, 1]
54
-
55
- # reshape the input into groups
56
- features_per_group = inp.shape[1] // self.num_groups
57
- inp_r = inp.reshape(
58
- inp.shape[0], self.num_groups, features_per_group, inp.shape[-1]
59
- )
60
- mask_r = mask.unsqueeze(1) # [batch_size, 1, 1, length]
61
-
62
- # masked mean and variance
63
- masked_inp = inp_r * mask_r
64
- n = seq_lengths * features_per_group  # number of valid elements per group
65
- mean = masked_inp.sum([2, 3], keepdim=True) / (n.view(-1, 1, 1, 1) + 1e-5)
66
- var = ((masked_inp - mean * mask_r) ** 2).sum([2, 3], keepdim=True) / (
67
- n.view(-1, 1, 1, 1) + 1e-5
68
- )
69
-
70
- # normalize
71
- inp_r = (inp_r - mean) / (torch.sqrt(var + self.eps))
72
- out = inp_r.reshape(inp.shape[0], self.num_channels, inp.shape[-1])
73
-
74
- # apply the affine transform
75
- if self.affine:
76
- out = out * self.weight[None, :, None] + self.bias[None, :, None]
77
-
78
- return out
79
-
80
-
81
- class Block1D(torch.nn.Module):
82
- def __init__(self, dim, dim_out, groups=8):
83
- super().__init__()
84
- self.block = torch.nn.Sequential(
85
- torch.nn.Conv1d(dim, dim_out, 3, padding=1),
86
- torch.nn.GroupNorm(groups, dim_out),
87
- # MaskedGroupNorm(groups, dim_out),
88
- nn.Mish(),
89
- )
90
-
91
- def forward(self, x, mask):
92
- output = self.block(x * mask)
93
- return output * mask
94
- return x * mask
95
-
96
-
97
- class ResnetBlock1D(torch.nn.Module):
98
- def __init__(self, dim, dim_out, time_emb_dim, groups=8):
99
- super().__init__()
100
- self.mlp = torch.nn.Sequential(
101
- nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
102
- )
103
-
104
- self.block1 = Block1D(dim, dim_out, groups=groups)
105
- self.block2 = Block1D(dim_out, dim_out, groups=groups)
106
-
107
- self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
108
-
109
- def forward(self, x, mask, time_emb):
110
- h = self.block1(x, mask)
111
- h += self.mlp(time_emb).unsqueeze(-1)
112
- h = self.block2(h, mask)
113
- output = h + self.res_conv(x * mask)
114
- return output
115
-
116
-
117
- class Downsample1D(nn.Module):
118
- def __init__(self, dim):
119
- super().__init__()
120
- self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
121
-
122
- def forward(self, x):
123
- return self.conv(x)
124
-
125
-
126
- class TimestepEmbedding(nn.Module):
127
- def __init__(
128
- self,
129
- in_channels: int,
130
- time_embed_dim: int,
131
- act_fn: str = "silu",
132
- out_dim: int = None,
133
- post_act_fn: Optional[str] = None,
134
- cond_proj_dim=None,
135
- ):
136
- super().__init__()
137
-
138
- self.linear_1 = nn.Linear(in_channels, time_embed_dim)
139
-
140
- if cond_proj_dim is not None:
141
- self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
142
- else:
143
- self.cond_proj = None
144
-
145
- self.act = get_activation(act_fn)
146
-
147
- if out_dim is not None:
148
- time_embed_dim_out = out_dim
149
- else:
150
- time_embed_dim_out = time_embed_dim
151
- self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
152
-
153
- if post_act_fn is None:
154
- self.post_act = None
155
- else:
156
- self.post_act = get_activation(post_act_fn)
157
-
158
- def forward(self, sample, condition=None):
159
- if condition is not None:
160
- sample = sample + self.cond_proj(condition)
161
- sample = self.linear_1(sample)
162
-
163
- if self.act is not None:
164
- sample = self.act(sample)
165
-
166
- sample = self.linear_2(sample)
167
-
168
- if self.post_act is not None:
169
- sample = self.post_act(sample)
170
- return sample
171
-
172
-
173
- class Upsample1D(nn.Module):
174
- """A 1D upsampling layer with an optional convolution.
175
-
176
- Parameters:
177
- channels (`int`):
178
- number of channels in the inputs and outputs.
179
- use_conv (`bool`, default `False`):
180
- option to use a convolution.
181
- use_conv_transpose (`bool`, default `False`):
182
- option to use a convolution transpose.
183
- out_channels (`int`, optional):
184
- number of output channels. Defaults to `channels`.
185
- """
186
-
187
- def __init__(
188
- self,
189
- channels,
190
- use_conv=False,
191
- use_conv_transpose=True,
192
- out_channels=None,
193
- name="conv",
194
- ):
195
- super().__init__()
196
- self.channels = channels
197
- self.out_channels = out_channels or channels
198
- self.use_conv = use_conv
199
- self.use_conv_transpose = use_conv_transpose
200
- self.name = name
201
-
202
- self.conv = None
203
- if use_conv_transpose:
204
- self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
205
- elif use_conv:
206
- self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
207
-
208
- def forward(self, inputs):
209
- assert inputs.shape[1] == self.channels
210
- if self.use_conv_transpose:
211
- return self.conv(inputs)
212
-
213
- outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
214
-
215
- if self.use_conv:
216
- outputs = self.conv(outputs)
217
-
218
- return outputs
219
-
220
-
221
- class ConformerWrapper(ConformerBlock):
222
- def __init__( # pylint: disable=useless-super-delegation
223
- self,
224
- *,
225
- dim,
226
- dim_head=64,
227
- heads=8,
228
- ff_mult=4,
229
- conv_expansion_factor=2,
230
- conv_kernel_size=31,
231
- attn_dropout=0,
232
- ff_dropout=0,
233
- conv_dropout=0,
234
- conv_causal=False,
235
- ):
236
- super().__init__(
237
- dim=dim,
238
- dim_head=dim_head,
239
- heads=heads,
240
- ff_mult=ff_mult,
241
- conv_expansion_factor=conv_expansion_factor,
242
- conv_kernel_size=conv_kernel_size,
243
- attn_dropout=attn_dropout,
244
- ff_dropout=ff_dropout,
245
- conv_dropout=conv_dropout,
246
- conv_causal=conv_causal,
247
- )
248
-
249
- def forward(
250
- self,
251
- hidden_states,
252
- attention_mask,
253
- encoder_hidden_states=None,
254
- encoder_attention_mask=None,
255
- timestep=None,
256
- ):
257
- return super().forward(x=hidden_states, mask=attention_mask.bool())
258
-
259
-
260
- class Decoder(nn.Module):
261
- def __init__(
262
- self,
263
- in_channels,
264
- out_channels,
265
- channels=(256, 256),
266
- dropout=0.05,
267
- attention_head_dim=64,
268
- n_blocks=1,
269
- num_mid_blocks=2,
270
- num_heads=4,
271
- act_fn="snake",
272
- down_block_type="transformer",
273
- mid_block_type="transformer",
274
- up_block_type="transformer",
275
- ):
276
- super().__init__()
277
- channels = tuple(channels)
278
- self.in_channels = in_channels
279
- self.out_channels = out_channels
280
-
281
- self.time_embeddings = SinusoidalPosEmb(in_channels)
282
- time_embed_dim = channels[0] * 4
283
- self.time_mlp = TimestepEmbedding(
284
- in_channels=in_channels,
285
- time_embed_dim=time_embed_dim,
286
- act_fn="silu",
287
- )
288
-
289
- self.down_blocks = nn.ModuleList([])
290
- self.mid_blocks = nn.ModuleList([])
291
- self.up_blocks = nn.ModuleList([])
292
-
293
- output_channel = in_channels
294
- for i in range(len(channels)): # pylint: disable=consider-using-enumerate
295
- input_channel = output_channel
296
- output_channel = channels[i]
297
- is_last = i == len(channels) - 1
298
- resnet = ResnetBlock1D(
299
- dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
300
- )
301
- transformer_blocks = nn.ModuleList(
302
- [
303
- self.get_block(
304
- down_block_type,
305
- output_channel,
306
- attention_head_dim,
307
- num_heads,
308
- dropout,
309
- act_fn,
310
- )
311
- for _ in range(n_blocks)
312
- ]
313
- )
314
- downsample = (
315
- Downsample1D(output_channel)
316
- if not is_last
317
- else nn.Conv1d(output_channel, output_channel, 3, padding=1)
318
- )
319
-
320
- self.down_blocks.append(
321
- nn.ModuleList([resnet, transformer_blocks, downsample])
322
- )
323
-
324
- for i in range(num_mid_blocks):
325
- input_channel = channels[-1]
326
- out_channels = channels[-1]
327
-
328
- resnet = ResnetBlock1D(
329
- dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
330
- )
331
-
332
- transformer_blocks = nn.ModuleList(
333
- [
334
- self.get_block(
335
- mid_block_type,
336
- output_channel,
337
- attention_head_dim,
338
- num_heads,
339
- dropout,
340
- act_fn,
341
- )
342
- for _ in range(n_blocks)
343
- ]
344
- )
345
-
346
- self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
347
-
348
- channels = channels[::-1] + (channels[0],)
349
- for i in range(len(channels) - 1):
350
- input_channel = channels[i]
351
- output_channel = channels[i + 1]
352
- is_last = i == len(channels) - 2
353
-
354
- resnet = ResnetBlock1D(
355
- dim=2 * input_channel,
356
- dim_out=output_channel,
357
- time_emb_dim=time_embed_dim,
358
- )
359
- transformer_blocks = nn.ModuleList(
360
- [
361
- self.get_block(
362
- up_block_type,
363
- output_channel,
364
- attention_head_dim,
365
- num_heads,
366
- dropout,
367
- act_fn,
368
- )
369
- for _ in range(n_blocks)
370
- ]
371
- )
372
- upsample = (
373
- Upsample1D(output_channel, use_conv_transpose=True)
374
- if not is_last
375
- else nn.Conv1d(output_channel, output_channel, 3, padding=1)
376
- )
377
-
378
- self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
379
-
380
- self.final_block = Block1D(channels[-1], channels[-1])
381
- self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
382
-
383
- self.initialize_weights()
384
- # nn.init.normal_(self.final_proj.weight)
385
-
386
- @staticmethod
387
- def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
388
- if block_type == "conformer":
389
- block = ConformerWrapper(
390
- dim=dim,
391
- dim_head=attention_head_dim,
392
- heads=num_heads,
393
- ff_mult=1,
394
- conv_expansion_factor=2,
395
- ff_dropout=dropout,
396
- attn_dropout=dropout,
397
- conv_dropout=dropout,
398
- conv_kernel_size=31,
399
- )
400
- elif block_type == "transformer":
401
- block = BasicTransformerBlock(
402
- dim=dim,
403
- num_attention_heads=num_heads,
404
- attention_head_dim=attention_head_dim,
405
- dropout=dropout,
406
- activation_fn=act_fn,
407
- )
408
- else:
409
- raise ValueError(f"Unknown block type {block_type}")
410
-
411
- return block
412
-
413
- def initialize_weights(self):
414
- for m in self.modules():
415
- if isinstance(m, nn.Conv1d):
416
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
417
-
418
- if m.bias is not None:
419
- nn.init.constant_(m.bias, 0)
420
-
421
- elif isinstance(m, nn.GroupNorm):
422
- nn.init.constant_(m.weight, 1)
423
- nn.init.constant_(m.bias, 0)
424
-
425
- elif isinstance(m, nn.Linear):
426
- nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
427
-
428
- if m.bias is not None:
429
- nn.init.constant_(m.bias, 0)
430
-
431
- def forward(self, x, mask, mu, t, spks=None, cond=None):
432
- """Forward pass of the UNet1DConditional model.
433
-
434
- Args:
435
- x (torch.Tensor): shape (batch_size, in_channels, time)
436
- mask (_type_): shape (batch_size, 1, time)
437
- t (_type_): shape (batch_size)
438
- spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
439
- cond (_type_, optional): placeholder for future use. Defaults to None.
440
-
441
- Raises:
442
- ValueError: _description_
443
- ValueError: _description_
444
-
445
- Returns:
446
- _type_: _description_
447
- """
448
-
449
- t = self.time_embeddings(t)
450
- t = self.time_mlp(t)
451
-
452
- x = pack([x, mu], "b * t")[0]
453
-
454
- if spks is not None:
455
- spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
456
- x = pack([x, spks], "b * t")[0]
457
-
458
- hiddens = []
459
- masks = [mask]
460
- for resnet, transformer_blocks, downsample in self.down_blocks:
461
- mask_down = masks[-1]
462
- x = resnet(x, mask_down, t)
463
- x = rearrange(x, "b c t -> b t c")
464
- mask_down = rearrange(mask_down, "b 1 t -> b t")
465
- for transformer_block in transformer_blocks:
466
- x = transformer_block(
467
- hidden_states=x,
468
- attention_mask=mask_down,
469
- timestep=t,
470
- )
471
- x = rearrange(x, "b t c -> b c t")
472
- mask_down = rearrange(mask_down, "b t -> b 1 t")
473
- hiddens.append(x) # Save hidden states for skip connections
474
- x = downsample(x * mask_down)
475
- masks.append(mask_down[:, :, ::2])
476
-
477
- masks = masks[:-1]
478
- mask_mid = masks[-1]
479
-
480
- for resnet, transformer_blocks in self.mid_blocks:
481
- x = resnet(x, mask_mid, t)
482
- x = rearrange(x, "b c t -> b t c")
483
- mask_mid = rearrange(mask_mid, "b 1 t -> b t")
484
- for transformer_block in transformer_blocks:
485
- x = transformer_block(
486
- hidden_states=x,
487
- attention_mask=mask_mid,
488
- timestep=t,
489
- )
490
- x = rearrange(x, "b t c -> b c t")
491
- mask_mid = rearrange(mask_mid, "b t -> b 1 t")
492
-
493
- for resnet, transformer_blocks, upsample in self.up_blocks:
494
- mask_up = masks.pop()
495
- x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
496
- x = rearrange(x, "b c t -> b t c")
497
- mask_up = rearrange(mask_up, "b 1 t -> b t")
498
- for transformer_block in transformer_blocks:
499
- x = transformer_block(
500
- hidden_states=x,
501
- attention_mask=mask_up,
502
- timestep=t,
503
- )
504
- x = rearrange(x, "b t c -> b c t")
505
- mask_up = rearrange(mask_up, "b t -> b 1 t")
506
- x = upsample(x * mask_up)
507
-
508
- x = self.final_block(x, mask_up)
509
- output = self.final_proj(x * mask_up)
510
-
511
- return output * mask
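
The U-Net estimator deleted above conditions every residual block on a timestep embedding produced by SinusoidalPosEmb and TimestepEmbedding. The embedding itself is small enough to restate; the sketch below uses an illustrative dim of 320 and is a simplified standalone function, not the deleted class.

import math
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int, scale: float = 1000.0) -> torch.Tensor:
    # t: [B] flow/diffusion timesteps in [0, 1]; returns [B, dim]
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / (half - 1))
    angles = scale * t.unsqueeze(1) * freqs.unsqueeze(0)    # [B, half]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)  # [B, dim]

emb = sinusoidal_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=320)
print(emb.shape)   # torch.Size([3, 320])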
 
cosyvoice/matcha/flow_matching.py DELETED
@@ -1,141 +0,0 @@
1
- from abc import ABC
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
- from cosyvoice.matcha.decoder import Decoder
7
-
8
-
9
- class BASECFM(torch.nn.Module, ABC):
10
- def __init__(
11
- self,
12
- n_feats,
13
- cfm_params,
14
- n_spks=1,
15
- spk_emb_dim=128,
16
- ):
17
- super().__init__()
18
- self.n_feats = n_feats
19
- self.n_spks = n_spks
20
- self.spk_emb_dim = spk_emb_dim
21
- self.solver = cfm_params.solver
22
- if hasattr(cfm_params, "sigma_min"):
23
- self.sigma_min = cfm_params.sigma_min
24
- else:
25
- self.sigma_min = 1e-4
26
-
27
- self.estimator = None
28
-
29
- @torch.inference_mode()
30
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
31
- """Forward diffusion
32
-
33
- Args:
34
- mu (torch.Tensor): output of encoder
35
- shape: (batch_size, n_feats, mel_timesteps)
36
- mask (torch.Tensor): output_mask
37
- shape: (batch_size, 1, mel_timesteps)
38
- n_timesteps (int): number of diffusion steps
39
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
40
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
41
- shape: (batch_size, spk_emb_dim)
42
- cond: Not used but kept for future purposes
43
-
44
- Returns:
45
- sample: generated mel-spectrogram
46
- shape: (batch_size, n_feats, mel_timesteps)
47
- """
48
- z = torch.randn_like(mu) * temperature
49
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50
- return self.solve_euler(
51
- z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
52
- )
53
-
54
- def solve_euler(self, x, t_span, mu, mask, spks, cond):
55
- """
56
- Fixed euler solver for ODEs.
57
- Args:
58
- x (torch.Tensor): random noise
59
- t_span (torch.Tensor): n_timesteps interpolated
60
- shape: (n_timesteps + 1,)
61
- mu (torch.Tensor): output of encoder
62
- shape: (batch_size, n_feats, mel_timesteps)
63
- mask (torch.Tensor): output_mask
64
- shape: (batch_size, 1, mel_timesteps)
65
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
66
- shape: (batch_size, spk_emb_dim)
67
- cond: Not used but kept for future purposes
68
- """
69
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
70
-
71
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
72
- # Or in future might add like a return_all_steps flag
73
- sol = []
74
-
75
- for step in range(1, len(t_span)):
76
- dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
77
-
78
- x = x + dt * dphi_dt
79
- t = t + dt
80
- sol.append(x)
81
- if step < len(t_span) - 1:
82
- dt = t_span[step + 1] - t
83
-
84
- return sol[-1]
85
-
86
- def compute_loss(self, x1, mask, mu, spks=None, cond=None):
87
- """Computes diffusion loss
88
-
89
- Args:
90
- x1 (torch.Tensor): Target
91
- shape: (batch_size, n_feats, mel_timesteps)
92
- mask (torch.Tensor): target mask
93
- shape: (batch_size, 1, mel_timesteps)
94
- mu (torch.Tensor): output of encoder
95
- shape: (batch_size, n_feats, mel_timesteps)
96
- spks (torch.Tensor, optional): speaker embedding. Defaults to None.
97
- shape: (batch_size, spk_emb_dim)
98
-
99
- Returns:
100
- loss: conditional flow matching loss
101
- y: conditional flow
102
- shape: (batch_size, n_feats, mel_timesteps)
103
- """
104
- b, _, t = mu.shape
105
-
106
- # random timestep
107
- t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
108
- # sample noise p(x_0)
109
- z = torch.randn_like(x1)
110
-
111
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
112
- u = x1 - (1 - self.sigma_min) * z
113
-
114
- loss = F.mse_loss(
115
- self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
116
- ) / (torch.sum(mask) * u.shape[1])
117
- return loss, y
118
-
119
-
120
- class CFM(BASECFM):
121
- def __init__(
122
- self,
123
- in_channels,
124
- out_channel,
125
- cfm_params,
126
- decoder_params,
127
- n_spks=1,
128
- spk_emb_dim=64,
129
- ):
130
- super().__init__(
131
- n_feats=in_channels,
132
- cfm_params=cfm_params,
133
- n_spks=n_spks,
134
- spk_emb_dim=spk_emb_dim,
135
- )
136
-
137
- in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
138
- # Just change the architecture of the estimator here
139
- self.estimator = Decoder(
140
- in_channels=in_channels, out_channels=out_channel, **decoder_params
141
- )
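
The conditional flow-matching wrapper deleted above integrates the learned velocity field from t = 0 to t = 1 with a fixed-step Euler loop over its estimator (the Decoder above). A self-contained sketch of that loop is given below with a stand-in estimator, since the real one is gone; the names and shapes are illustrative.

import torch

def solve_euler(x, estimator, mu, mask, n_timesteps=10):
    # x: noise with the target's shape; mu: encoder output; mask: [B, 1, T]
    t_span = torch.linspace(0.0, 1.0, n_timesteps + 1, device=x.device)
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        dphi_dt = estimator(x, mask, mu, t)   # predicted velocity at time t
        x = x + dt * dphi_dt                  # Euler update
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

# Toy run with a dummy estimator that just moves x toward mu.
mu = torch.randn(1, 80, 50)
mask = torch.ones(1, 1, 50)
dummy = lambda x, mask, mu, t: (mu - x) * mask
mel = solve_euler(torch.randn_like(mu), dummy, mu, mask)
print(mel.shape)   # torch.Size([1, 80, 50])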
 
cosyvoice/matcha/transformer.py DELETED
@@ -1,443 +0,0 @@
1
- from typing import Any, Dict, Optional
2
-
3
- import torch
4
- import torch.nn as nn
5
- from diffusers.models.attention import (
6
- GEGLU,
7
- GELU,
8
- AdaLayerNorm,
9
- AdaLayerNormZero,
10
- ApproximateGELU,
11
- )
12
- from diffusers.models.attention_processor import Attention
13
- from diffusers.models.lora import LoRACompatibleLinear
14
- from diffusers.utils.torch_utils import maybe_allow_in_graph
15
-
16
-
17
- class SnakeBeta(nn.Module):
18
- """
19
- A modified Snake function which uses separate parameters for the magnitude of the periodic components
20
- Shape:
21
- - Input: (B, C, T)
22
- - Output: (B, C, T), same shape as the input
23
- Parameters:
24
- - alpha - trainable parameter that controls frequency
25
- - beta - trainable parameter that controls magnitude
26
- References:
27
- - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28
- https://arxiv.org/abs/2006.08195
29
- Examples:
30
- >>> a1 = snakebeta(256)
31
- >>> x = torch.randn(256)
32
- >>> x = a1(x)
33
- """
34
-
35
- def __init__(
36
- self,
37
- in_features,
38
- out_features,
39
- alpha=1.0,
40
- alpha_trainable=True,
41
- alpha_logscale=True,
42
- ):
43
- """
44
- Initialization.
45
- INPUT:
46
- - in_features: shape of the input
47
- - alpha - trainable parameter that controls frequency
48
- - beta - trainable parameter that controls magnitude
49
- alpha is initialized to 1 by default, higher values = higher-frequency.
50
- beta is initialized to 1 by default, higher values = higher-magnitude.
51
- alpha will be trained along with the rest of your model.
52
- """
53
- super().__init__()
54
- self.in_features = (
55
- out_features if isinstance(out_features, list) else [out_features]
56
- )
57
- self.proj = LoRACompatibleLinear(in_features, out_features)
58
-
59
- # initialize alpha
60
- self.alpha_logscale = alpha_logscale
61
- if self.alpha_logscale: # log scale alphas initialized to zeros
62
- self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
63
- self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
64
- else: # linear scale alphas initialized to ones
65
- self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
66
- self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
67
-
68
- self.alpha.requires_grad = alpha_trainable
69
- self.beta.requires_grad = alpha_trainable
70
-
71
- self.no_div_by_zero = 0.000000001
72
-
73
- def forward(self, x):
74
- """
75
- Forward pass of the function.
76
- Applies the function to the input elementwise.
77
- SnakeBeta ∶= x + 1/b * sin^2 (xa)
78
- """
79
- x = self.proj(x)
80
- if self.alpha_logscale:
81
- alpha = torch.exp(self.alpha)
82
- beta = torch.exp(self.beta)
83
- else:
84
- alpha = self.alpha
85
- beta = self.beta
86
-
87
- x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
88
- torch.sin(x * alpha), 2
89
- )
90
-
91
- return x
92
-
93
-
94
- class FeedForward(nn.Module):
95
- r"""
96
- A feed-forward layer.
97
-
98
- Parameters:
99
- dim (`int`): The number of channels in the input.
100
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
101
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
102
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
103
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
104
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
105
- """
106
-
107
- def __init__(
108
- self,
109
- dim: int,
110
- dim_out: Optional[int] = None,
111
- mult: int = 4,
112
- dropout: float = 0.0,
113
- activation_fn: str = "geglu",
114
- final_dropout: bool = False,
115
- ):
116
- super().__init__()
117
- inner_dim = int(dim * mult)
118
- dim_out = dim_out if dim_out is not None else dim
119
-
120
- if activation_fn == "gelu":
121
- act_fn = GELU(dim, inner_dim)
122
- if activation_fn == "gelu-approximate":
123
- act_fn = GELU(dim, inner_dim, approximate="tanh")
124
- elif activation_fn == "geglu":
125
- act_fn = GEGLU(dim, inner_dim)
126
- elif activation_fn == "geglu-approximate":
127
- act_fn = ApproximateGELU(dim, inner_dim)
128
- elif activation_fn == "snakebeta":
129
- act_fn = SnakeBeta(dim, inner_dim)
130
-
131
- self.net = nn.ModuleList([])
132
- # project in
133
- self.net.append(act_fn)
134
- # project dropout
135
- self.net.append(nn.Dropout(dropout))
136
- # project out
137
- self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
138
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
139
- if final_dropout:
140
- self.net.append(nn.Dropout(dropout))
141
-
142
- def forward(self, hidden_states):
143
- for module in self.net:
144
- hidden_states = module(hidden_states)
145
- return hidden_states
146
-
147
-
148
- @maybe_allow_in_graph
149
- class BasicTransformerBlock(nn.Module):
150
- r"""
151
- A basic Transformer block.
152
-
153
- Parameters:
154
- dim (`int`): The number of channels in the input and output.
155
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
156
- attention_head_dim (`int`): The number of channels in each head.
157
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
158
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
159
- only_cross_attention (`bool`, *optional*):
160
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
161
- double_self_attention (`bool`, *optional*):
162
- Whether to use two self-attention layers. In this case no cross attention layers are used.
163
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
164
- num_embeds_ada_norm (:
165
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
166
- attention_bias (:
167
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
168
- """
169
-
170
- def __init__(
171
- self,
172
- dim: int,
173
- num_attention_heads: int,
174
- attention_head_dim: int,
175
- dropout=0.0,
176
- cross_attention_dim: Optional[int] = None,
177
- activation_fn: str = "geglu",
178
- num_embeds_ada_norm: Optional[int] = None,
179
- attention_bias: bool = False,
180
- only_cross_attention: bool = False,
181
- double_self_attention: bool = False,
182
- upcast_attention: bool = False,
183
- norm_elementwise_affine: bool = True,
184
- norm_type: str = "layer_norm",
185
- final_dropout: bool = False,
186
- ):
187
- super().__init__()
188
- self.only_cross_attention = only_cross_attention
189
-
190
- self.use_ada_layer_norm_zero = (
191
- num_embeds_ada_norm is not None
192
- ) and norm_type == "ada_norm_zero"
193
- self.use_ada_layer_norm = (
194
- num_embeds_ada_norm is not None
195
- ) and norm_type == "ada_norm"
196
-
197
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
198
- raise ValueError(
199
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
200
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
201
- )
202
-
203
- # Define 3 blocks. Each block has its own normalization layer.
204
- # 1. Self-Attn
205
- if self.use_ada_layer_norm:
206
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
207
- elif self.use_ada_layer_norm_zero:
208
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
209
- else:
210
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
211
- self.attn1 = Attention(
212
- query_dim=dim,
213
- heads=num_attention_heads,
214
- dim_head=attention_head_dim,
215
- dropout=dropout,
216
- bias=attention_bias,
217
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
218
- upcast_attention=upcast_attention,
219
- )
220
-
221
- # 2. Cross-Attn
222
- if cross_attention_dim is not None or double_self_attention:
223
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
224
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
225
- # the second cross attention block.
226
- self.norm2 = (
227
- AdaLayerNorm(dim, num_embeds_ada_norm)
228
- if self.use_ada_layer_norm
229
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
230
- )
231
- self.attn2 = Attention(
232
- query_dim=dim,
233
- cross_attention_dim=(
234
- cross_attention_dim if not double_self_attention else None
235
- ),
236
- heads=num_attention_heads,
237
- dim_head=attention_head_dim,
238
- dropout=dropout,
239
- bias=attention_bias,
240
- upcast_attention=upcast_attention,
241
- # scale_qk=False, # uncomment this to not to use flash attention
242
- ) # is self-attn if encoder_hidden_states is none
243
- else:
244
- self.norm2 = None
245
- self.attn2 = None
246
-
247
- # 3. Feed-forward
248
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
249
- self.ff = FeedForward(
250
- dim,
251
- dropout=dropout,
252
- activation_fn=activation_fn,
253
- final_dropout=final_dropout,
254
- )
255
-
256
- # let chunk size default to None
257
- self._chunk_size = None
258
- self._chunk_dim = 0
259
-
260
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
261
- # Sets chunk feed-forward
262
- self._chunk_size = chunk_size
263
- self._chunk_dim = dim
264
-
265
- def forward_native(
266
- self,
267
- hidden_states: torch.FloatTensor,
268
- attention_mask: Optional[torch.FloatTensor] = None,
269
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
270
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
271
- timestep: Optional[torch.LongTensor] = None,
272
- cross_attention_kwargs: Dict[str, Any] = None,
273
- class_labels: Optional[torch.LongTensor] = None,
274
- ):
275
- # Notice that normalization is always applied before the real computation in the following blocks.
276
- # 1. Self-Attention
277
- if self.use_ada_layer_norm:
278
- norm_hidden_states = self.norm1(hidden_states, timestep)
279
- elif self.use_ada_layer_norm_zero:
280
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
281
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
282
- )
283
- else:
284
- norm_hidden_states = self.norm1(hidden_states)
285
-
286
- cross_attention_kwargs = (
287
- cross_attention_kwargs if cross_attention_kwargs is not None else {}
288
- )
289
-
290
- attn_output = self.attn1(
291
- norm_hidden_states,
292
- encoder_hidden_states=(
293
- encoder_hidden_states if self.only_cross_attention else None
294
- ),
295
- attention_mask=(
296
- encoder_attention_mask if self.only_cross_attention else attention_mask
297
- ),
298
- **cross_attention_kwargs,
299
- )
300
- if self.use_ada_layer_norm_zero:
301
- attn_output = gate_msa.unsqueeze(1) * attn_output
302
- hidden_states = attn_output + hidden_states
303
-
304
- # 2. Cross-Attention
305
- if self.attn2 is not None:
306
- norm_hidden_states = (
307
- self.norm2(hidden_states, timestep)
308
- if self.use_ada_layer_norm
309
- else self.norm2(hidden_states)
310
- )
311
-
312
- attn_output = self.attn2(
313
- norm_hidden_states,
314
- encoder_hidden_states=encoder_hidden_states,
315
- attention_mask=encoder_attention_mask,
316
- **cross_attention_kwargs,
317
- )
318
- hidden_states = attn_output + hidden_states
319
-
320
- # 3. Feed-forward
321
- norm_hidden_states = self.norm3(hidden_states)
322
-
323
- if self.use_ada_layer_norm_zero:
324
- norm_hidden_states = (
325
- norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
326
- )
327
-
328
- if self._chunk_size is not None:
329
- # "feed_forward_chunk_size" can be used to save memory
330
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
331
- raise ValueError(
332
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
333
- )
334
-
335
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
336
- ff_output = torch.cat(
337
- [
338
- self.ff(hid_slice)
339
- for hid_slice in norm_hidden_states.chunk(
340
- num_chunks, dim=self._chunk_dim
341
- )
342
- ],
343
- dim=self._chunk_dim,
344
- )
345
- else:
346
- ff_output = self.ff(norm_hidden_states)
347
-
348
- if self.use_ada_layer_norm_zero:
349
- ff_output = gate_mlp.unsqueeze(1) * ff_output
350
-
351
- hidden_states = ff_output + hidden_states
352
-
353
- return hidden_states
354
-
355
- def forward(
356
- self,
357
- hidden_states: torch.FloatTensor,
358
- attention_mask: Optional[torch.FloatTensor] = None,
359
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
360
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
361
- timestep: Optional[torch.LongTensor] = None,
362
- cross_attention_kwargs: Dict[str, Any] = None,
363
- class_labels: Optional[torch.LongTensor] = None,
364
- ):
365
- # Notice that normalization is always applied before the real computation in the following blocks.
366
- # 1. Self-Attention
367
- if self.use_ada_layer_norm:
368
- norm_hidden_states = self.norm1(hidden_states, timestep)
369
- elif self.use_ada_layer_norm_zero:
370
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
371
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
372
- )
373
- else:
374
- norm_hidden_states = self.norm1(hidden_states)
375
-
376
- cross_attention_kwargs = (
377
- cross_attention_kwargs if cross_attention_kwargs is not None else {}
378
- )
379
-
380
- attn_output = self.attn1(
381
- norm_hidden_states,
382
- encoder_hidden_states=(
383
- encoder_hidden_states if self.only_cross_attention else None
384
- ),
385
- attention_mask=(
386
- encoder_attention_mask if self.only_cross_attention else attention_mask
387
- ),
388
- **cross_attention_kwargs,
389
- )
390
- if self.use_ada_layer_norm_zero:
391
- attn_output = gate_msa.unsqueeze(1) * attn_output
392
- hidden_states = attn_output + hidden_states
393
-
394
- # 2. Cross-Attention
395
- if self.attn2 is not None:
396
- norm_hidden_states = (
397
- self.norm2(hidden_states, timestep)
398
- if self.use_ada_layer_norm
399
- else self.norm2(hidden_states)
400
- )
401
-
402
- attn_output = self.attn2(
403
- norm_hidden_states,
404
- encoder_hidden_states=encoder_hidden_states,
405
- attention_mask=encoder_attention_mask,
406
- **cross_attention_kwargs,
407
- )
408
- hidden_states = attn_output + hidden_states
409
-
410
- # 3. Feed-forward
411
- norm_hidden_states = self.norm3(hidden_states)
412
-
413
- if self.use_ada_layer_norm_zero:
414
- norm_hidden_states = (
415
- norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
416
- )
417
-
418
- if self._chunk_size is not None:
419
- # "feed_forward_chunk_size" can be used to save memory
420
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
421
- raise ValueError(
422
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
423
- )
424
-
425
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
426
- ff_output = torch.cat(
427
- [
428
- self.ff(hid_slice)
429
- for hid_slice in norm_hidden_states.chunk(
430
- num_chunks, dim=self._chunk_dim
431
- )
432
- ],
433
- dim=self._chunk_dim,
434
- )
435
- else:
436
- ff_output = self.ff(norm_hidden_states)
437
-
438
- if self.use_ada_layer_norm_zero:
439
- ff_output = gate_mlp.unsqueeze(1) * ff_output
440
-
441
- hidden_states = ff_output + hidden_states
442
-
443
- return hidden_states
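
The SnakeBeta activation deleted above computes x + (1/β)·sin²(αx) with per-channel, log-scaled α and β, wrapped around a LoRA-compatible input projection. The sketch below keeps just the activation, with the projection dropped and names chosen for illustration.

import torch
import torch.nn as nn

class SnakeBetaActivation(nn.Module):
    def __init__(self, channels: int, eps: float = 1e-9):
        super().__init__()
        # log-scale parameters, as in the removed module's alpha_logscale=True path
        self.log_alpha = nn.Parameter(torch.zeros(channels))
        self.log_beta = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, C] (channel-last, matching the feed-forward usage)
        alpha, beta = self.log_alpha.exp(), self.log_beta.exp()
        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)

act = SnakeBetaActivation(256)
print(act(torch.randn(2, 10, 256)).shape)   # torch.Size([2, 10, 256])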
 
cosyvoice/transformer/__init__.py DELETED
File without changes
cosyvoice/transformer/activation.py DELETED
@@ -1,87 +0,0 @@
- # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
- # 2020 Northwestern Polytechnical University (Pengcheng Guo)
- # 2020 Mobvoi Inc (Binbin Zhang)
- # 2024 Alibaba Inc (Xiang Lyu)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Swish() activation function for Conformer."""
-
- import torch
- from torch import nn, sin, pow
- from torch.nn import Parameter
-
-
- class Swish(torch.nn.Module):
- """Construct an Swish object."""
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Return Swish activation function."""
- return x * torch.sigmoid(x)
-
-
- # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
- # LICENSE is in incl_licenses directory.
- class Snake(nn.Module):
- """
- Implementation of a sine-based periodic activation function
- Shape:
- - Input: (B, C, T)
- - Output: (B, C, T), same shape as the input
- Parameters:
- - alpha - trainable parameter
- References:
- - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
- https://arxiv.org/abs/2006.08195
- Examples:
- >>> a1 = snake(256)
- >>> x = torch.randn(256)
- >>> x = a1(x)
- """
-
- def __init__(
- self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
- ):
- """
- Initialization.
- INPUT:
- - in_features: shape of the input
- - alpha: trainable parameter
- alpha is initialized to 1 by default, higher values = higher-frequency.
- alpha will be trained along with the rest of your model.
- """
- super(Snake, self).__init__()
- self.in_features = in_features
-
- # initialize alpha
- self.alpha_logscale = alpha_logscale
- if self.alpha_logscale: # log scale alphas initialized to zeros
- self.alpha = Parameter(torch.zeros(in_features) * alpha)
- else: # linear scale alphas initialized to ones
- self.alpha = Parameter(torch.ones(in_features) * alpha)
-
- self.alpha.requires_grad = alpha_trainable
-
- self.no_div_by_zero = 0.000000001
-
- def forward(self, x):
- """
- Forward pass of the function.
- Applies the function to the input elementwise.
- Snake ∶= x + 1/a * sin^2 (xa)
- """
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
- if self.alpha_logscale:
- alpha = torch.exp(alpha)
- x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
- return x
cosyvoice/transformer/attention.py DELETED
@@ -1,322 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- # 2022 Xingchen Song ([email protected])
4
- # 2024 Alibaba Inc (Xiang Lyu)
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- """Multi-Head Attention layer definition."""
18
-
19
- import math
20
- from typing import Tuple
21
-
22
- import torch
23
- from torch import nn
24
-
25
-
26
- class MultiHeadedAttention(nn.Module):
27
- """Multi-Head Attention layer.
28
-
29
- Args:
30
- n_head (int): The number of heads.
31
- n_feat (int): The number of features.
32
- dropout_rate (float): Dropout rate.
33
-
34
- """
35
-
36
- def __init__(
37
- self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
38
- ):
39
- """Construct an MultiHeadedAttention object."""
40
- super().__init__()
41
- assert n_feat % n_head == 0
42
- # We assume d_v always equals d_k
43
- self.d_k = n_feat // n_head
44
- self.h = n_head
45
- self.linear_q = nn.Linear(n_feat, n_feat)
46
- self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
47
- self.linear_v = nn.Linear(n_feat, n_feat)
48
- self.linear_out = nn.Linear(n_feat, n_feat)
49
- self.dropout = nn.Dropout(p=dropout_rate)
50
-
51
- def forward_qkv(
52
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
53
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
54
- """Transform query, key and value.
55
-
56
- Args:
57
- query (torch.Tensor): Query tensor (#batch, time1, size).
58
- key (torch.Tensor): Key tensor (#batch, time2, size).
59
- value (torch.Tensor): Value tensor (#batch, time2, size).
60
-
61
- Returns:
62
- torch.Tensor: Transformed query tensor, size
63
- (#batch, n_head, time1, d_k).
64
- torch.Tensor: Transformed key tensor, size
65
- (#batch, n_head, time2, d_k).
66
- torch.Tensor: Transformed value tensor, size
67
- (#batch, n_head, time2, d_k).
68
-
69
- """
70
- n_batch = query.size(0)
71
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
72
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
73
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
74
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
75
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
76
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
77
-
78
- return q, k, v
79
-
80
- def forward_attention(
81
- self,
82
- value: torch.Tensor,
83
- scores: torch.Tensor,
84
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
85
- ) -> torch.Tensor:
86
- """Compute attention context vector.
87
-
88
- Args:
89
- value (torch.Tensor): Transformed value, size
90
- (#batch, n_head, time2, d_k).
91
- scores (torch.Tensor): Attention score, size
92
- (#batch, n_head, time1, time2).
93
- mask (torch.Tensor): Mask, size (#batch, 1, time2) or
94
- (#batch, time1, time2), (0, 0, 0) means fake mask.
95
-
96
- Returns:
97
- torch.Tensor: Transformed value (#batch, time1, d_model)
98
- weighted by the attention score (#batch, time1, time2).
99
-
100
- """
101
- n_batch = value.size(0)
102
- # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
103
- # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
104
- # 1st chunk to ease the onnx export.]
105
- # 2. pytorch training
106
- if mask.size(2) > 0: # time2 > 0
107
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
108
- # For last chunk, time2 might be larger than scores.size(-1)
109
- mask = mask[:, :, :, : scores.size(-1)] # (batch, 1, *, time2)
110
- scores = scores.masked_fill(mask, -float("inf"))
111
- attn = torch.softmax(scores, dim=-1).masked_fill(
112
- mask, 0.0
113
- ) # (batch, head, time1, time2)
114
- # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
115
- # 1. onnx(16/-1, -1/-1, 16/0)
116
- # 2. jit (16/-1, -1/-1, 16/0, 16/4)
117
- else:
118
- attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
119
-
120
- p_attn = self.dropout(attn)
121
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
122
- x = (
123
- x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
124
- ) # (batch, time1, d_model)
125
-
126
- return self.linear_out(x) # (batch, time1, d_model)
127
-
128
- def forward(
129
- self,
130
- query: torch.Tensor,
131
- key: torch.Tensor,
132
- value: torch.Tensor,
133
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
134
- pos_emb: torch.Tensor = torch.empty(0),
135
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
136
- ) -> Tuple[torch.Tensor, torch.Tensor]:
137
- """Compute scaled dot product attention.
138
-
139
- Args:
140
- query (torch.Tensor): Query tensor (#batch, time1, size).
141
- key (torch.Tensor): Key tensor (#batch, time2, size).
142
- value (torch.Tensor): Value tensor (#batch, time2, size).
143
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
144
- (#batch, time1, time2).
145
- 1.When applying cross attention between decoder and encoder,
146
- the batch padding mask for input is in (#batch, 1, T) shape.
147
- 2.When applying self attention of encoder,
148
- the mask is in (#batch, T, T) shape.
149
- 3.When applying self attention of decoder,
150
- the mask is in (#batch, L, L) shape.
151
- 4.If the different position in decoder see different block
152
- of the encoder, such as Mocha, the passed in mask could be
153
- in (#batch, L, T) shape. But there is no such case in current
154
- CosyVoice.
155
- cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
156
- where `cache_t == chunk_size * num_decoding_left_chunks`
157
- and `head * d_k == size`
158
-
159
-
160
- Returns:
161
- torch.Tensor: Output tensor (#batch, time1, d_model).
162
- torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
163
- where `cache_t == chunk_size * num_decoding_left_chunks`
164
- and `head * d_k == size`
165
-
166
- """
167
- q, k, v = self.forward_qkv(query, key, value)
168
-
169
- # NOTE(xcsong):
170
- # when export onnx model, for 1st chunk, we feed
171
- # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
172
- # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
173
- # In all modes, `if cache.size(0) > 0` will alwayse be `True`
174
- # and we will always do splitting and
175
- # concatnation(this will simplify onnx export). Note that
176
- # it's OK to concat & split zero-shaped tensors(see code below).
177
- # when export jit model, for 1st chunk, we always feed
178
- # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
179
- # >>> a = torch.ones((1, 2, 0, 4))
180
- # >>> b = torch.ones((1, 2, 3, 4))
181
- # >>> c = torch.cat((a, b), dim=2)
182
- # >>> torch.equal(b, c) # True
183
- # >>> d = torch.split(a, 2, dim=-1)
184
- # >>> torch.equal(d[0], d[1]) # True
185
- if cache.size(0) > 0:
186
- key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
187
- k = torch.cat([key_cache, k], dim=2)
188
- v = torch.cat([value_cache, v], dim=2)
189
- # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
190
- # non-trivial to calculate `next_cache_start` here.
191
- new_cache = torch.cat((k, v), dim=-1)
192
-
193
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
194
- return self.forward_attention(v, scores, mask), new_cache
195
-
196
-
197
- class RelPositionMultiHeadedAttention(MultiHeadedAttention):
198
- """Multi-Head Attention layer with relative position encoding.
199
- Paper: https://arxiv.org/abs/1901.02860
200
- Args:
201
- n_head (int): The number of heads.
202
- n_feat (int): The number of features.
203
- dropout_rate (float): Dropout rate.
204
- """
205
-
206
- def __init__(
207
- self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
208
- ):
209
- """Construct an RelPositionMultiHeadedAttention object."""
210
- super().__init__(n_head, n_feat, dropout_rate, key_bias)
211
- # linear transformation for positional encoding
212
- self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
213
- # these two learnable bias are used in matrix c and matrix d
214
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
215
- self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
216
- self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
217
- torch.nn.init.xavier_uniform_(self.pos_bias_u)
218
- torch.nn.init.xavier_uniform_(self.pos_bias_v)
219
-
220
- def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
221
- """Compute relative positional encoding.
222
-
223
- Args:
224
- x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
225
- time1 means the length of query vector.
226
-
227
- Returns:
228
- torch.Tensor: Output tensor.
229
-
230
- """
231
- zero_pad = torch.zeros(
232
- (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
233
- )
234
- x_padded = torch.cat([zero_pad, x], dim=-1)
235
-
236
- x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
237
- x = x_padded[:, :, 1:].view_as(x)[
238
- :, :, :, : x.size(-1) // 2 + 1
239
- ] # only keep the positions from 0 to time2
240
- return x
241
-
242
- def forward(
243
- self,
244
- query: torch.Tensor,
245
- key: torch.Tensor,
246
- value: torch.Tensor,
247
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
248
- pos_emb: torch.Tensor = torch.empty(0),
249
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
250
- ) -> Tuple[torch.Tensor, torch.Tensor]:
251
- """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
252
- Args:
253
- query (torch.Tensor): Query tensor (#batch, time1, size).
254
- key (torch.Tensor): Key tensor (#batch, time2, size).
255
- value (torch.Tensor): Value tensor (#batch, time2, size).
256
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
257
- (#batch, time1, time2), (0, 0, 0) means fake mask.
258
- pos_emb (torch.Tensor): Positional embedding tensor
259
- (#batch, time2, size).
260
- cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
261
- where `cache_t == chunk_size * num_decoding_left_chunks`
262
- and `head * d_k == size`
263
- Returns:
264
- torch.Tensor: Output tensor (#batch, time1, d_model).
265
- torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
266
- where `cache_t == chunk_size * num_decoding_left_chunks`
267
- and `head * d_k == size`
268
- """
269
- q, k, v = self.forward_qkv(query, key, value)
270
- q = q.transpose(1, 2) # (batch, time1, head, d_k)
271
-
272
- # NOTE(xcsong):
273
- # when export onnx model, for 1st chunk, we feed
274
- # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
275
- # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
276
- # In all modes, `if cache.size(0) > 0` will alwayse be `True`
277
- # and we will always do splitting and
278
- # concatnation(this will simplify onnx export). Note that
279
- # it's OK to concat & split zero-shaped tensors(see code below).
280
- # when export jit model, for 1st chunk, we always feed
281
- # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
282
- # >>> a = torch.ones((1, 2, 0, 4))
283
- # >>> b = torch.ones((1, 2, 3, 4))
284
- # >>> c = torch.cat((a, b), dim=2)
285
- # >>> torch.equal(b, c) # True
286
- # >>> d = torch.split(a, 2, dim=-1)
287
- # >>> torch.equal(d[0], d[1]) # True
288
- if cache.size(0) > 0:
289
- key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
290
- k = torch.cat([key_cache, k], dim=2)
291
- v = torch.cat([value_cache, v], dim=2)
292
- # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
293
- # non-trivial to calculate `next_cache_start` here.
294
- new_cache = torch.cat((k, v), dim=-1)
295
-
296
- n_batch_pos = pos_emb.size(0)
297
- p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
298
- p = p.transpose(1, 2) # (batch, head, time1, d_k)
299
-
300
- # (batch, head, time1, d_k)
301
- q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
302
- # (batch, head, time1, d_k)
303
- q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
304
-
305
- # compute attention score
306
- # first compute matrix a and matrix c
307
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
308
- # (batch, head, time1, time2)
309
- matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
310
-
311
- # compute matrix b and matrix d
312
- # (batch, head, time1, time2)
313
- matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
314
- # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
315
- if matrix_ac.shape != matrix_bd.shape:
316
- matrix_bd = self.rel_shift(matrix_bd)
317
-
318
- scores = (matrix_ac + matrix_bd) / math.sqrt(
319
- self.d_k
320
- ) # (batch, head, time1, time2)
321
-
322
- return self.forward_attention(v, scores, mask), new_cache
cosyvoice/transformer/convolution.py DELETED
@@ -1,147 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """ConvolutionModule definition."""
17
-
18
- from typing import Tuple
19
-
20
- import torch
21
- from torch import nn
22
-
23
-
24
- class ConvolutionModule(nn.Module):
25
- """ConvolutionModule in Conformer model."""
26
-
27
- def __init__(
28
- self,
29
- channels: int,
30
- kernel_size: int = 15,
31
- activation: nn.Module = nn.ReLU(),
32
- norm: str = "batch_norm",
33
- causal: bool = False,
34
- bias: bool = True,
35
- ):
36
- """Construct an ConvolutionModule object.
37
- Args:
38
- channels (int): The number of channels of conv layers.
39
- kernel_size (int): Kernel size of conv layers.
40
- causal (int): Whether use causal convolution or not
41
- """
42
- super().__init__()
43
-
44
- self.pointwise_conv1 = nn.Conv1d(
45
- channels,
46
- 2 * channels,
47
- kernel_size=1,
48
- stride=1,
49
- padding=0,
50
- bias=bias,
51
- )
52
- # self.lorder is used to distinguish if it's a causal convolution,
53
- # if self.lorder > 0: it's a causal convolution, the input will be
54
- # padded with self.lorder frames on the left in forward.
55
- # else: it's a symmetrical convolution
56
- if causal:
57
- padding = 0
58
- self.lorder = kernel_size - 1
59
- else:
60
- # kernel_size should be an odd number for none causal convolution
61
- assert (kernel_size - 1) % 2 == 0
62
- padding = (kernel_size - 1) // 2
63
- self.lorder = 0
64
- self.depthwise_conv = nn.Conv1d(
65
- channels,
66
- channels,
67
- kernel_size,
68
- stride=1,
69
- padding=padding,
70
- groups=channels,
71
- bias=bias,
72
- )
73
-
74
- assert norm in ["batch_norm", "layer_norm"]
75
- if norm == "batch_norm":
76
- self.use_layer_norm = False
77
- self.norm = nn.BatchNorm1d(channels)
78
- else:
79
- self.use_layer_norm = True
80
- self.norm = nn.LayerNorm(channels)
81
-
82
- self.pointwise_conv2 = nn.Conv1d(
83
- channels,
84
- channels,
85
- kernel_size=1,
86
- stride=1,
87
- padding=0,
88
- bias=bias,
89
- )
90
- self.activation = activation
91
-
92
- def forward(
93
- self,
94
- x: torch.Tensor,
95
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
96
- cache: torch.Tensor = torch.zeros((0, 0, 0)),
97
- ) -> Tuple[torch.Tensor, torch.Tensor]:
98
- """Compute convolution module.
99
- Args:
100
- x (torch.Tensor): Input tensor (#batch, time, channels).
101
- mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
102
- (0, 0, 0) means fake mask.
103
- cache (torch.Tensor): left context cache, it is only
104
- used in causal convolution (#batch, channels, cache_t),
105
- (0, 0, 0) meas fake cache.
106
- Returns:
107
- torch.Tensor: Output tensor (#batch, time, channels).
108
- """
109
- # exchange the temporal dimension and the feature dimension
110
- x = x.transpose(1, 2) # (#batch, channels, time)
111
-
112
- # mask batch padding
113
- if mask_pad.size(2) > 0: # time > 0
114
- x.masked_fill_(~mask_pad, 0.0)
115
-
116
- if self.lorder > 0:
117
- if cache.size(2) == 0: # cache_t == 0
118
- x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
119
- else:
120
- assert cache.size(0) == x.size(0) # equal batch
121
- assert cache.size(1) == x.size(1) # equal channel
122
- x = torch.cat((cache, x), dim=2)
123
- assert x.size(2) > self.lorder
124
- new_cache = x[:, :, -self.lorder :]
125
- else:
126
- # It's better we just return None if no cache is required,
127
- # However, for JIT export, here we just fake one tensor instead of
128
- # None.
129
- new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
130
-
131
- # GLU mechanism
132
- x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
133
- x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
134
-
135
- # 1D Depthwise Conv
136
- x = self.depthwise_conv(x)
137
- if self.use_layer_norm:
138
- x = x.transpose(1, 2)
139
- x = self.activation(self.norm(x))
140
- if self.use_layer_norm:
141
- x = x.transpose(1, 2)
142
- x = self.pointwise_conv2(x)
143
- # mask batch padding
144
- if mask_pad.size(2) > 0: # time > 0
145
- x.masked_fill_(~mask_pad, 0.0)
146
-
147
- return x.transpose(1, 2), new_cache
cosyvoice/transformer/decoder.py DELETED
@@ -1,418 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Decoder definition."""
17
- from typing import Tuple, List, Optional
18
-
19
- import torch
20
- import torch.utils.checkpoint as ckpt
21
- import logging
22
-
23
- from cosyvoice.transformer.decoder_layer import DecoderLayer
24
- from cosyvoice.transformer.positionwise_feed_forward import (
25
- PositionwiseFeedForward,
26
- )
27
- from cosyvoice.utils.class_utils import (
28
- COSYVOICE_EMB_CLASSES,
29
- COSYVOICE_ATTENTION_CLASSES,
30
- COSYVOICE_ACTIVATION_CLASSES,
31
- )
32
- from cosyvoice.utils.mask import subsequent_mask, make_pad_mask
33
-
34
-
35
- class TransformerDecoder(torch.nn.Module):
36
- """Base class of Transfomer decoder module.
37
- Args:
38
- vocab_size: output dim
39
- encoder_output_size: dimension of attention
40
- attention_heads: the number of heads of multi head attention
41
- linear_units: the hidden units number of position-wise feedforward
42
- num_blocks: the number of decoder blocks
43
- dropout_rate: dropout rate
44
- self_attention_dropout_rate: dropout rate for attention
45
- input_layer: input layer type
46
- use_output_layer: whether to use output layer
47
- pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
48
- normalize_before:
49
- True: use layer_norm before each sub-block of a layer.
50
- False: use layer_norm after each sub-block of a layer.
51
- src_attention: if false, encoder-decoder cross attention is not
52
- applied, such as CIF model
53
- key_bias: whether use bias in attention.linear_k, False for whisper models.
54
- gradient_checkpointing: rerunning a forward-pass segment for each
55
- checkpointed segment during backward.
56
- tie_word_embedding: Tie or clone module weights depending of whether we are
57
- using TorchScript or not
58
- """
59
-
60
- def __init__(
61
- self,
62
- vocab_size: int,
63
- encoder_output_size: int,
64
- attention_heads: int = 4,
65
- linear_units: int = 2048,
66
- num_blocks: int = 6,
67
- dropout_rate: float = 0.1,
68
- positional_dropout_rate: float = 0.1,
69
- self_attention_dropout_rate: float = 0.0,
70
- src_attention_dropout_rate: float = 0.0,
71
- input_layer: str = "embed",
72
- use_output_layer: bool = True,
73
- normalize_before: bool = True,
74
- src_attention: bool = True,
75
- key_bias: bool = True,
76
- activation_type: str = "relu",
77
- gradient_checkpointing: bool = False,
78
- tie_word_embedding: bool = False,
79
- ):
80
- super().__init__()
81
- attention_dim = encoder_output_size
82
- activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
83
-
84
- self.embed = torch.nn.Sequential(
85
- (
86
- torch.nn.Identity()
87
- if input_layer == "no_pos"
88
- else torch.nn.Embedding(vocab_size, attention_dim)
89
- ),
90
- COSYVOICE_EMB_CLASSES[input_layer](attention_dim, positional_dropout_rate),
91
- )
92
-
93
- self.normalize_before = normalize_before
94
- self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
95
- self.use_output_layer = use_output_layer
96
- if use_output_layer:
97
- self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
98
- else:
99
- self.output_layer = torch.nn.Identity()
100
- self.num_blocks = num_blocks
101
- self.decoders = torch.nn.ModuleList(
102
- [
103
- DecoderLayer(
104
- attention_dim,
105
- COSYVOICE_ATTENTION_CLASSES["selfattn"](
106
- attention_heads,
107
- attention_dim,
108
- self_attention_dropout_rate,
109
- key_bias,
110
- ),
111
- (
112
- COSYVOICE_ATTENTION_CLASSES["selfattn"](
113
- attention_heads,
114
- attention_dim,
115
- src_attention_dropout_rate,
116
- key_bias,
117
- )
118
- if src_attention
119
- else None
120
- ),
121
- PositionwiseFeedForward(
122
- attention_dim, linear_units, dropout_rate, activation
123
- ),
124
- dropout_rate,
125
- normalize_before,
126
- )
127
- for _ in range(self.num_blocks)
128
- ]
129
- )
130
-
131
- self.gradient_checkpointing = gradient_checkpointing
132
- self.tie_word_embedding = tie_word_embedding
133
-
134
- def forward(
135
- self,
136
- memory: torch.Tensor,
137
- memory_mask: torch.Tensor,
138
- ys_in_pad: torch.Tensor,
139
- ys_in_lens: torch.Tensor,
140
- r_ys_in_pad: torch.Tensor = torch.empty(0),
141
- reverse_weight: float = 0.0,
142
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
143
- """Forward decoder.
144
- Args:
145
- memory: encoded memory, float32 (batch, maxlen_in, feat)
146
- memory_mask: encoder memory mask, (batch, 1, maxlen_in)
147
- ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
148
- ys_in_lens: input lengths of this batch (batch)
149
- r_ys_in_pad: not used in transformer decoder, in order to unify api
150
- with bidirectional decoder
151
- reverse_weight: not used in transformer decoder, in order to unify
152
- api with bidirectional decode
153
- Returns:
154
- (tuple): tuple containing:
155
- x: decoded token score before softmax (batch, maxlen_out,
156
- vocab_size) if use_output_layer is True,
157
- torch.tensor(0.0), in order to unify api with bidirectional decoder
158
- olens: (batch, )
159
- NOTE(xcsong):
160
- We pass the `__call__` method of the modules instead of `forward` to the
161
- checkpointing API because `__call__` attaches all the hooks of the module.
162
- https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
163
- """
164
- tgt = ys_in_pad
165
- maxlen = tgt.size(1)
166
- # tgt_mask: (B, 1, L)
167
- tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
168
- tgt_mask = tgt_mask.to(tgt.device)
169
- # m: (1, L, L)
170
- m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
171
- # tgt_mask: (B, L, L)
172
- tgt_mask = tgt_mask & m
173
- x, _ = self.embed(tgt)
174
- if self.gradient_checkpointing and self.training:
175
- x = self.forward_layers_checkpointed(x, tgt_mask, memory, memory_mask)
176
- else:
177
- x = self.forward_layers(x, tgt_mask, memory, memory_mask)
178
- if self.normalize_before:
179
- x = self.after_norm(x)
180
- if self.use_output_layer:
181
- x = self.output_layer(x)
182
- olens = tgt_mask.sum(1)
183
- return x, torch.tensor(0.0), olens
184
-
185
- def forward_layers(
186
- self,
187
- x: torch.Tensor,
188
- tgt_mask: torch.Tensor,
189
- memory: torch.Tensor,
190
- memory_mask: torch.Tensor,
191
- ) -> torch.Tensor:
192
- for layer in self.decoders:
193
- x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, memory_mask)
194
- return x
195
-
196
- @torch.jit.unused
197
- def forward_layers_checkpointed(
198
- self,
199
- x: torch.Tensor,
200
- tgt_mask: torch.Tensor,
201
- memory: torch.Tensor,
202
- memory_mask: torch.Tensor,
203
- ) -> torch.Tensor:
204
- for layer in self.decoders:
205
- x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
206
- layer.__call__, x, tgt_mask, memory, memory_mask
207
- )
208
- return x
209
-
210
- def forward_one_step(
211
- self,
212
- memory: torch.Tensor,
213
- memory_mask: torch.Tensor,
214
- tgt: torch.Tensor,
215
- tgt_mask: torch.Tensor,
216
- cache: Optional[List[torch.Tensor]] = None,
217
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
218
- """Forward one step.
219
- This is only used for decoding.
220
- Args:
221
- memory: encoded memory, float32 (batch, maxlen_in, feat)
222
- memory_mask: encoded memory mask, (batch, 1, maxlen_in)
223
- tgt: input token ids, int64 (batch, maxlen_out)
224
- tgt_mask: input token mask, (batch, maxlen_out)
225
- dtype=torch.uint8 in PyTorch 1.2-
226
- dtype=torch.bool in PyTorch 1.2+ (include 1.2)
227
- cache: cached output list of (batch, max_time_out-1, size)
228
- Returns:
229
- y, cache: NN output value and cache per `self.decoders`.
230
- y.shape` is (batch, maxlen_out, token)
231
- """
232
- x, _ = self.embed(tgt)
233
- new_cache = []
234
- for i, decoder in enumerate(self.decoders):
235
- if cache is None:
236
- c = None
237
- else:
238
- c = cache[i]
239
- x, tgt_mask, memory, memory_mask = decoder(
240
- x, tgt_mask, memory, memory_mask, cache=c
241
- )
242
- new_cache.append(x)
243
- if self.normalize_before:
244
- y = self.after_norm(x[:, -1])
245
- else:
246
- y = x[:, -1]
247
- if self.use_output_layer:
248
- y = torch.log_softmax(self.output_layer(y), dim=-1)
249
- return y, new_cache
250
-
251
- def tie_or_clone_weights(self, jit_mode: bool = True):
252
- """Tie or clone module weights (between word_emb and output_layer)
253
- depending of whether we are using TorchScript or not"""
254
- if not self.use_output_layer:
255
- return
256
- if jit_mode:
257
- logging.info("clone emb.weight to output.weight")
258
- self.output_layer.weight = torch.nn.Parameter(self.embed[0].weight.clone())
259
- else:
260
- logging.info("tie emb.weight with output.weight")
261
- self.output_layer.weight = self.embed[0].weight
262
-
263
- if getattr(self.output_layer, "bias", None) is not None:
264
- self.output_layer.bias.data = torch.nn.functional.pad(
265
- self.output_layer.bias.data,
266
- (
267
- 0,
268
- self.output_layer.weight.shape[0] - self.output_layer.bias.shape[0],
269
- ),
270
- "constant",
271
- 0,
272
- )
273
-
274
-
275
- class BiTransformerDecoder(torch.nn.Module):
276
- """Base class of Transfomer decoder module.
277
- Args:
278
- vocab_size: output dim
279
- encoder_output_size: dimension of attention
280
- attention_heads: the number of heads of multi head attention
281
- linear_units: the hidden units number of position-wise feedforward
282
- num_blocks: the number of decoder blocks
283
- r_num_blocks: the number of right to left decoder blocks
284
- dropout_rate: dropout rate
285
- self_attention_dropout_rate: dropout rate for attention
286
- input_layer: input layer type
287
- use_output_layer: whether to use output layer
288
- pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
289
- normalize_before:
290
- True: use layer_norm before each sub-block of a layer.
291
- False: use layer_norm after each sub-block of a layer.
292
- key_bias: whether use bias in attention.linear_k, False for whisper models.
293
- """
294
-
295
- def __init__(
296
- self,
297
- vocab_size: int,
298
- encoder_output_size: int,
299
- attention_heads: int = 4,
300
- linear_units: int = 2048,
301
- num_blocks: int = 6,
302
- r_num_blocks: int = 0,
303
- dropout_rate: float = 0.1,
304
- positional_dropout_rate: float = 0.1,
305
- self_attention_dropout_rate: float = 0.0,
306
- src_attention_dropout_rate: float = 0.0,
307
- input_layer: str = "embed",
308
- use_output_layer: bool = True,
309
- normalize_before: bool = True,
310
- key_bias: bool = True,
311
- gradient_checkpointing: bool = False,
312
- tie_word_embedding: bool = False,
313
- ):
314
-
315
- super().__init__()
316
- self.tie_word_embedding = tie_word_embedding
317
- self.left_decoder = TransformerDecoder(
318
- vocab_size,
319
- encoder_output_size,
320
- attention_heads,
321
- linear_units,
322
- num_blocks,
323
- dropout_rate,
324
- positional_dropout_rate,
325
- self_attention_dropout_rate,
326
- src_attention_dropout_rate,
327
- input_layer,
328
- use_output_layer,
329
- normalize_before,
330
- key_bias=key_bias,
331
- gradient_checkpointing=gradient_checkpointing,
332
- tie_word_embedding=tie_word_embedding,
333
- )
334
-
335
- self.right_decoder = TransformerDecoder(
336
- vocab_size,
337
- encoder_output_size,
338
- attention_heads,
339
- linear_units,
340
- r_num_blocks,
341
- dropout_rate,
342
- positional_dropout_rate,
343
- self_attention_dropout_rate,
344
- src_attention_dropout_rate,
345
- input_layer,
346
- use_output_layer,
347
- normalize_before,
348
- key_bias=key_bias,
349
- gradient_checkpointing=gradient_checkpointing,
350
- tie_word_embedding=tie_word_embedding,
351
- )
352
-
353
- def forward(
354
- self,
355
- memory: torch.Tensor,
356
- memory_mask: torch.Tensor,
357
- ys_in_pad: torch.Tensor,
358
- ys_in_lens: torch.Tensor,
359
- r_ys_in_pad: torch.Tensor,
360
- reverse_weight: float = 0.0,
361
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
362
- """Forward decoder.
363
- Args:
364
- memory: encoded memory, float32 (batch, maxlen_in, feat)
365
- memory_mask: encoder memory mask, (batch, 1, maxlen_in)
366
- ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
367
- ys_in_lens: input lengths of this batch (batch)
368
- r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
369
- used for right to left decoder
370
- reverse_weight: used for right to left decoder
371
- Returns:
372
- (tuple): tuple containing:
373
- x: decoded token score before softmax (batch, maxlen_out,
374
- vocab_size) if use_output_layer is True,
375
- r_x: x: decoded token score (right to left decoder)
376
- before softmax (batch, maxlen_out, vocab_size)
377
- if use_output_layer is True,
378
- olens: (batch, )
379
- """
380
- l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
381
- r_x = torch.tensor(0.0)
382
- if reverse_weight > 0.0:
383
- r_x, _, olens = self.right_decoder(
384
- memory, memory_mask, r_ys_in_pad, ys_in_lens
385
- )
386
- return l_x, r_x, olens
387
-
388
- def forward_one_step(
389
- self,
390
- memory: torch.Tensor,
391
- memory_mask: torch.Tensor,
392
- tgt: torch.Tensor,
393
- tgt_mask: torch.Tensor,
394
- cache: Optional[List[torch.Tensor]] = None,
395
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
396
- """Forward one step.
397
- This is only used for decoding.
398
- Args:
399
- memory: encoded memory, float32 (batch, maxlen_in, feat)
400
- memory_mask: encoded memory mask, (batch, 1, maxlen_in)
401
- tgt: input token ids, int64 (batch, maxlen_out)
402
- tgt_mask: input token mask, (batch, maxlen_out)
403
- dtype=torch.uint8 in PyTorch 1.2-
404
- dtype=torch.bool in PyTorch 1.2+ (include 1.2)
405
- cache: cached output list of (batch, max_time_out-1, size)
406
- Returns:
407
- y, cache: NN output value and cache per `self.decoders`.
408
- y.shape` is (batch, maxlen_out, token)
409
- """
410
- return self.left_decoder.forward_one_step(
411
- memory, memory_mask, tgt, tgt_mask, cache
412
- )
413
-
414
- def tie_or_clone_weights(self, jit_mode: bool = True):
415
- """Tie or clone module weights (between word_emb and output_layer)
416
- depending of whether we are using TorchScript or not"""
417
- self.left_decoder.tie_or_clone_weights(jit_mode)
418
- self.right_decoder.tie_or_clone_weights(jit_mode)
cosyvoice/transformer/decoder_layer.py DELETED
@@ -1,132 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Decoder self-attention layer definition."""
16
- from typing import Optional, Tuple
17
-
18
- import torch
19
- from torch import nn
20
-
21
-
22
- class DecoderLayer(nn.Module):
23
- """Single decoder layer module.
24
-
25
- Args:
26
- size (int): Input dimension.
27
- self_attn (torch.nn.Module): Self-attention module instance.
28
- `MultiHeadedAttention` instance can be used as the argument.
29
- src_attn (torch.nn.Module): Inter-attention module instance.
30
- `MultiHeadedAttention` instance can be used as the argument.
31
- If `None` is passed, Inter-attention is not used, such as
32
- CIF, GPT, and other decoder only model.
33
- feed_forward (torch.nn.Module): Feed-forward module instance.
34
- `PositionwiseFeedForward` instance can be used as the argument.
35
- dropout_rate (float): Dropout rate.
36
- normalize_before (bool):
37
- True: use layer_norm before each sub-block.
38
- False: to use layer_norm after each sub-block.
39
- """
40
-
41
- def __init__(
42
- self,
43
- size: int,
44
- self_attn: nn.Module,
45
- src_attn: Optional[nn.Module],
46
- feed_forward: nn.Module,
47
- dropout_rate: float,
48
- normalize_before: bool = True,
49
- ):
50
- """Construct an DecoderLayer object."""
51
- super().__init__()
52
- self.size = size
53
- self.self_attn = self_attn
54
- self.src_attn = src_attn
55
- self.feed_forward = feed_forward
56
- self.norm1 = nn.LayerNorm(size, eps=1e-5)
57
- self.norm2 = nn.LayerNorm(size, eps=1e-5)
58
- self.norm3 = nn.LayerNorm(size, eps=1e-5)
59
- self.dropout = nn.Dropout(dropout_rate)
60
- self.normalize_before = normalize_before
61
-
62
- def forward(
63
- self,
64
- tgt: torch.Tensor,
65
- tgt_mask: torch.Tensor,
66
- memory: torch.Tensor,
67
- memory_mask: torch.Tensor,
68
- cache: Optional[torch.Tensor] = None,
69
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70
- """Compute decoded features.
71
-
72
- Args:
73
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74
- tgt_mask (torch.Tensor): Mask for input tensor
75
- (#batch, maxlen_out).
76
- memory (torch.Tensor): Encoded memory
77
- (#batch, maxlen_in, size).
78
- memory_mask (torch.Tensor): Encoded memory mask
79
- (#batch, maxlen_in).
80
- cache (torch.Tensor): cached tensors.
81
- (#batch, maxlen_out - 1, size).
82
-
83
- Returns:
84
- torch.Tensor: Output tensor (#batch, maxlen_out, size).
85
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88
-
89
- """
90
- residual = tgt
91
- if self.normalize_before:
92
- tgt = self.norm1(tgt)
93
-
94
- if cache is None:
95
- tgt_q = tgt
96
- tgt_q_mask = tgt_mask
97
- else:
98
- # compute only the last frame query keeping dim: max_time_out -> 1
99
- assert cache.shape == (
100
- tgt.shape[0],
101
- tgt.shape[1] - 1,
102
- self.size,
103
- ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104
- tgt_q = tgt[:, -1:, :]
105
- residual = residual[:, -1:, :]
106
- tgt_q_mask = tgt_mask[:, -1:, :]
107
-
108
- x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
109
- if not self.normalize_before:
110
- x = self.norm1(x)
111
-
112
- if self.src_attn is not None:
113
- residual = x
114
- if self.normalize_before:
115
- x = self.norm2(x)
116
- x = residual + self.dropout(
117
- self.src_attn(x, memory, memory, memory_mask)[0]
118
- )
119
- if not self.normalize_before:
120
- x = self.norm2(x)
121
-
122
- residual = x
123
- if self.normalize_before:
124
- x = self.norm3(x)
125
- x = residual + self.dropout(self.feed_forward(x))
126
- if not self.normalize_before:
127
- x = self.norm3(x)
128
-
129
- if cache is not None:
130
- x = torch.cat([cache, x], dim=1)
131
-
132
- return x, tgt_mask, memory, memory_mask
cosyvoice/transformer/embedding.py DELETED
@@ -1,293 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Positonal Encoding Module."""
17
-
18
- import math
19
- from typing import Tuple, Union
20
-
21
- import torch
22
- import torch.nn.functional as F
23
- import numpy as np
24
-
25
-
26
- class PositionalEncoding(torch.nn.Module):
27
- """Positional encoding.
28
-
29
- :param int d_model: embedding dim
30
- :param float dropout_rate: dropout rate
31
- :param int max_len: maximum input length
32
-
33
- PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
- PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
- """
36
-
37
- def __init__(
38
- self,
39
- d_model: int,
40
- dropout_rate: float,
41
- max_len: int = 5000,
42
- reverse: bool = False,
43
- ):
44
- """Construct an PositionalEncoding object."""
45
- super().__init__()
46
- self.d_model = d_model
47
- self.xscale = math.sqrt(self.d_model)
48
- self.dropout = torch.nn.Dropout(p=dropout_rate)
49
- self.max_len = max_len
50
-
51
- self.pe = torch.zeros(self.max_len, self.d_model)
52
- position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
53
- div_term = torch.exp(
54
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
55
- * -(math.log(10000.0) / self.d_model)
56
- )
57
- self.pe[:, 0::2] = torch.sin(position * div_term)
58
- self.pe[:, 1::2] = torch.cos(position * div_term)
59
- self.pe = self.pe.unsqueeze(0)
60
-
61
- def forward(
62
- self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
63
- ) -> Tuple[torch.Tensor, torch.Tensor]:
64
- """Add positional encoding.
65
-
66
- Args:
67
- x (torch.Tensor): Input. Its shape is (batch, time, ...)
68
- offset (int, torch.tensor): position offset
69
-
70
- Returns:
71
- torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
72
- torch.Tensor: for compatibility to RelPositionalEncoding
73
- """
74
-
75
- self.pe = self.pe.to(x.device)
76
- pos_emb = self.position_encoding(offset, x.size(1), False)
77
- x = x * self.xscale + pos_emb
78
- return self.dropout(x), self.dropout(pos_emb)
79
-
80
- def position_encoding(
81
- self, offset: Union[int, torch.Tensor], size: int, apply_dropout: bool = True
82
- ) -> torch.Tensor:
83
- """For getting encoding in a streaming fashion
84
-
85
- Attention!!!!!
86
- we apply dropout only once at the whole utterance level in a none
87
- streaming way, but will call this function several times with
88
- increasing input size in a streaming scenario, so the dropout will
89
- be applied several times.
90
-
91
- Args:
92
- offset (int or torch.tensor): start offset
93
- size (int): required size of position encoding
94
-
95
- Returns:
96
- torch.Tensor: Corresponding encoding
97
- """
98
- # How to subscript a Union type:
99
- # https://github.com/pytorch/pytorch/issues/69434
100
- if isinstance(offset, int):
101
- assert offset + size <= self.max_len
102
- pos_emb = self.pe[:, offset : offset + size]
103
- elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
- assert offset + size <= self.max_len
105
- pos_emb = self.pe[:, offset : offset + size]
106
- else: # for batched streaming decoding on GPU
107
- assert torch.max(offset) + size <= self.max_len
108
- index = offset.unsqueeze(1) + torch.arange(0, size).to(
109
- offset.device
110
- ) # B X T
111
- flag = index > 0
112
- # remove negative offset
113
- index = index * flag
114
- pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
115
-
116
- if apply_dropout:
117
- pos_emb = self.dropout(pos_emb)
118
- return pos_emb
119
-
120
-
121
- class RelPositionalEncoding(PositionalEncoding):
122
- """Relative positional encoding module.
123
- See : Appendix B in https://arxiv.org/abs/1901.02860
124
- Args:
125
- d_model (int): Embedding dimension.
126
- dropout_rate (float): Dropout rate.
127
- max_len (int): Maximum input length.
128
- """
129
-
130
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
131
- """Initialize class."""
132
- super().__init__(d_model, dropout_rate, max_len, reverse=True)
133
-
134
- def forward(
135
- self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
136
- ) -> Tuple[torch.Tensor, torch.Tensor]:
137
- """Compute positional encoding.
138
- Args:
139
- x (torch.Tensor): Input tensor (batch, time, `*`).
140
- Returns:
141
- torch.Tensor: Encoded tensor (batch, time, `*`).
142
- torch.Tensor: Positional embedding tensor (1, time, `*`).
143
- """
144
- self.pe = self.pe.to(x.device)
145
- x = x * self.xscale
146
- pos_emb = self.position_encoding(offset, x.size(1), False)
147
- return self.dropout(x), self.dropout(pos_emb)
148
-
149
-
150
- class WhisperPositionalEncoding(PositionalEncoding):
151
- """Sinusoids position encoding used in openai-whisper.encoder"""
152
-
153
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
154
- super().__init__(d_model, dropout_rate, max_len)
155
- self.xscale = 1.0
156
- log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
157
- inv_timescales = torch.exp(
158
- -log_timescale_increment * torch.arange(d_model // 2)
159
- )
160
- scaled_time = (
161
- torch.arange(max_len)[:, np.newaxis] * inv_timescales[np.newaxis, :]
162
- )
163
- pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
164
- delattr(self, "pe")
165
- self.register_buffer("pe", pe.unsqueeze(0))
166
-
167
-
168
- class LearnablePositionalEncoding(PositionalEncoding):
169
- """Learnable position encoding used in openai-whisper.decoder"""
170
-
171
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
- super().__init__(d_model, dropout_rate, max_len)
173
- # NOTE(xcsong): overwrite self.pe & self.xscale
174
- self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
- self.xscale = 1.0
176
-
177
-
178
- class NoPositionalEncoding(torch.nn.Module):
179
- """No position encoding"""
180
-
181
- def __init__(self, d_model: int, dropout_rate: float):
182
- super().__init__()
183
- self.d_model = d_model
184
- self.dropout = torch.nn.Dropout(p=dropout_rate)
185
-
186
- def forward(
187
- self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
188
- ) -> Tuple[torch.Tensor, torch.Tensor]:
189
- """Just return zero vector for interface compatibility"""
190
- pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
191
- return self.dropout(x), pos_emb
192
-
193
- def position_encoding(
194
- self, offset: Union[int, torch.Tensor], size: int
195
- ) -> torch.Tensor:
196
- return torch.zeros(1, size, self.d_model)
197
-
198
-
199
- class EspnetRelPositionalEncoding(torch.nn.Module):
200
- """Relative positional encoding module (new implementation).
201
-
202
- Details can be found in https://github.com/espnet/espnet/pull/2816.
203
-
204
- See : Appendix B in https://arxiv.org/abs/1901.02860
205
-
206
- Args:
207
- d_model (int): Embedding dimension.
208
- dropout_rate (float): Dropout rate.
209
- max_len (int): Maximum input length.
210
-
211
- """
212
-
213
- def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
214
- """Construct an PositionalEncoding object."""
215
- super(EspnetRelPositionalEncoding, self).__init__()
216
- self.d_model = d_model
217
- self.xscale = math.sqrt(self.d_model)
218
- self.dropout = torch.nn.Dropout(p=dropout_rate)
219
- self.pe = None
220
- self.extend_pe(torch.tensor(0.0).expand(1, max_len))
221
-
222
- def extend_pe(self, x: torch.Tensor):
223
- """Reset the positional encodings."""
224
- if self.pe is not None:
225
- # self.pe contains both positive and negative parts
226
- # the length of self.pe is 2 * input_len - 1
227
- if self.pe.size(1) >= x.size(1) * 2 - 1:
228
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
229
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
230
- return
231
- # Suppose `i` means to the position of query vecotr and `j` means the
232
- # position of key vector. We use position relative positions when keys
233
- # are to the left (i>j) and negative relative positions otherwise (i<j).
234
- pe_positive = torch.zeros(x.size(1), self.d_model)
235
- pe_negative = torch.zeros(x.size(1), self.d_model)
236
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
237
- div_term = torch.exp(
238
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
239
- * -(math.log(10000.0) / self.d_model)
240
- )
241
- pe_positive[:, 0::2] = torch.sin(position * div_term)
242
- pe_positive[:, 1::2] = torch.cos(position * div_term)
243
- pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
244
- pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
245
-
246
- # Reserve the order of positive indices and concat both positive and
247
- # negative indices. This is used to support the shifting trick
248
- # as in https://arxiv.org/abs/1901.02860
249
- pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
250
- pe_negative = pe_negative[1:].unsqueeze(0)
251
- pe = torch.cat([pe_positive, pe_negative], dim=1)
252
- self.pe = pe.to(device=x.device, dtype=x.dtype)
253
-
254
- def forward(
255
- self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
256
- ) -> Tuple[torch.Tensor, torch.Tensor]:
257
- """Add positional encoding.
258
-
259
- Args:
260
- x (torch.Tensor): Input tensor (batch, time, `*`).
261
-
262
- Returns:
263
- torch.Tensor: Encoded tensor (batch, time, `*`).
264
-
265
- """
266
- self.extend_pe(x)
267
- x = x * self.xscale
268
- pos_emb = self.position_encoding(size=x.size(1), offset=offset)
269
- return self.dropout(x), self.dropout(pos_emb)
270
-
271
- def position_encoding(
272
- self, offset: Union[int, torch.Tensor], size: int
273
- ) -> torch.Tensor:
274
- """For getting encoding in a streaming fashion
275
-
276
- Attention!!!!!
277
- we apply dropout only once at the whole utterance level in a none
278
- streaming way, but will call this function several times with
279
- increasing input size in a streaming scenario, so the dropout will
280
- be applied several times.
281
-
282
- Args:
283
- offset (int or torch.tensor): start offset
284
- size (int): required size of position encoding
285
-
286
- Returns:
287
- torch.Tensor: Corresponding encoding
288
- """
289
- pos_emb = self.pe[
290
- :,
291
- self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
292
- ]
293
- return pos_emb
cosyvoice/transformer/encoder.py DELETED
@@ -1,633 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- # 2022 Xingchen Song ([email protected])
3
- # 2024 Alibaba Inc (Xiang Lyu)
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- # Modified from ESPnet(https://github.com/espnet/espnet)
17
- """Encoder definition."""
18
- from typing import Tuple
19
- import time
20
-
21
- import torch
22
- import torch.utils.checkpoint as ckpt
23
- import torch.nn.functional as F
24
-
25
- from cosyvoice.transformer.convolution import ConvolutionModule
26
- from cosyvoice.transformer.encoder_layer import (
27
- TransformerEncoderLayer,
28
- )
29
- from cosyvoice.transformer.encoder_layer import (
30
- ConformerEncoderLayer,
31
- )
32
- from cosyvoice.transformer.positionwise_feed_forward import (
33
- PositionwiseFeedForward,
34
- )
35
- from cosyvoice.utils.class_utils import (
36
- COSYVOICE_EMB_CLASSES,
37
- COSYVOICE_SUBSAMPLE_CLASSES,
38
- COSYVOICE_ATTENTION_CLASSES,
39
- COSYVOICE_ACTIVATION_CLASSES,
40
- )
41
- from cosyvoice.utils.mask import make_pad_mask
42
- from cosyvoice.utils.mask import add_optional_chunk_mask
43
-
44
-
45
- class BaseEncoder(torch.nn.Module):
46
-
47
- def __init__(
48
- self,
49
- input_size: int,
50
- output_size: int = 256,
51
- attention_heads: int = 4,
52
- linear_units: int = 2048,
53
- num_blocks: int = 6,
54
- dropout_rate: float = 0.1,
55
- positional_dropout_rate: float = 0.1,
56
- attention_dropout_rate: float = 0.0,
57
- input_layer: str = "conv2d",
58
- pos_enc_layer_type: str = "abs_pos",
59
- normalize_before: bool = True,
60
- static_chunk_size: int = 0,
61
- use_dynamic_chunk: bool = False,
62
- global_cmvn: torch.nn.Module = None,
63
- use_dynamic_left_chunk: bool = False,
64
- gradient_checkpointing: bool = False,
65
- ):
66
- """
67
- Args:
68
- input_size (int): input dim
69
- output_size (int): dimension of attention
70
- attention_heads (int): the number of heads of multi head attention
71
- linear_units (int): the hidden units number of position-wise feed
72
- forward
73
- num_blocks (int): the number of decoder blocks
74
- dropout_rate (float): dropout rate
75
- attention_dropout_rate (float): dropout rate in attention
76
- positional_dropout_rate (float): dropout rate after adding
77
- positional encoding
78
- input_layer (str): input layer type.
79
- optional [linear, conv2d, conv2d6, conv2d8]
80
- pos_enc_layer_type (str): Encoder positional encoding layer type.
81
- opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
82
- normalize_before (bool):
83
- True: use layer_norm before each sub-block of a layer.
84
- False: use layer_norm after each sub-block of a layer.
85
- static_chunk_size (int): chunk size for static chunk training and
86
- decoding
87
- use_dynamic_chunk (bool): whether use dynamic chunk size for
88
- training or not, You can only use fixed chunk(chunk_size > 0)
89
- or dyanmic chunk size(use_dynamic_chunk = True)
90
- global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
91
- use_dynamic_left_chunk (bool): whether use dynamic left chunk in
92
- dynamic chunk training
93
- key_bias: whether use bias in attention.linear_k, False for whisper models.
94
- gradient_checkpointing: rerunning a forward-pass segment for each
95
- checkpointed segment during backward.
96
- """
97
- super().__init__()
98
- self._output_size = output_size
99
-
100
- self.global_cmvn = global_cmvn
101
- self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
102
- input_size,
103
- output_size,
104
- dropout_rate,
105
- COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
106
- output_size, positional_dropout_rate
107
- ),
108
- )
109
-
110
- self.normalize_before = normalize_before
111
- self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
112
- self.static_chunk_size = static_chunk_size
113
- self.use_dynamic_chunk = use_dynamic_chunk
114
- self.use_dynamic_left_chunk = use_dynamic_left_chunk
115
- self.gradient_checkpointing = gradient_checkpointing
116
-
117
- def output_size(self) -> int:
118
- return self._output_size
119
-
120
- def forward(
121
- self,
122
- xs: torch.Tensor,
123
- xs_lens: torch.Tensor,
124
- decoding_chunk_size: int = 0,
125
- num_decoding_left_chunks: int = -1,
126
- ) -> Tuple[torch.Tensor, torch.Tensor]:
127
- """Embed positions in tensor.
128
-
129
- Args:
130
- xs: padded input tensor (B, T, D)
131
- xs_lens: input length (B)
132
- decoding_chunk_size: decoding chunk size for dynamic chunk
133
- 0: default for training, use random dynamic chunk.
134
- <0: for decoding, use full chunk.
135
- >0: for decoding, use fixed chunk size as set.
136
- num_decoding_left_chunks: number of left chunks, this is for decoding,
137
- the chunk size is decoding_chunk_size.
138
- >=0: use num_decoding_left_chunks
139
- <0: use all left chunks
140
- Returns:
141
- encoder output tensor xs, and subsampled masks
142
- xs: padded output tensor (B, T' ~= T/subsample_rate, D)
143
- masks: torch.Tensor batch padding mask after subsample
144
- (B, 1, T' ~= T/subsample_rate)
145
- NOTE(xcsong):
146
- We pass the `__call__` method of the modules instead of `forward` to the
147
- checkpointing API because `__call__` attaches all the hooks of the module.
148
- https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
149
- """
150
- T = xs.size(1)
151
- masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
152
- if self.global_cmvn is not None:
153
- xs = self.global_cmvn(xs)
154
- xs, pos_emb, masks = self.embed(xs, masks)
155
- mask_pad = masks # (B, 1, T/subsample_rate)
156
- chunk_masks = add_optional_chunk_mask(
157
- xs,
158
- masks,
159
- self.use_dynamic_chunk,
160
- self.use_dynamic_left_chunk,
161
- decoding_chunk_size,
162
- self.static_chunk_size,
163
- num_decoding_left_chunks,
164
- )
165
- print(f"chunk_masks shape: {chunk_masks.shape}")
166
- if self.gradient_checkpointing and self.training:
167
- xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
168
- else:
169
- xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
170
- if self.normalize_before:
171
- xs = self.after_norm(xs)
172
- # Here we assume the mask is not changed in encoder layers, so just
173
- # return the masks before encoder layers, and the masks will be used
174
- # for cross attention with decoder later
175
- return xs, masks
176
-
177
- def forward_layers(
178
- self,
179
- xs: torch.Tensor,
180
- chunk_masks: torch.Tensor,
181
- pos_emb: torch.Tensor,
182
- mask_pad: torch.Tensor,
183
- ) -> torch.Tensor:
184
- for layer in self.encoders:
185
- xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
186
- return xs
187
-
188
- @torch.jit.unused
189
- def forward_layers_checkpointed(
190
- self,
191
- xs: torch.Tensor,
192
- chunk_masks: torch.Tensor,
193
- pos_emb: torch.Tensor,
194
- mask_pad: torch.Tensor,
195
- ) -> torch.Tensor:
196
- for layer in self.encoders:
197
- xs, chunk_masks, _, _ = ckpt.checkpoint(
198
- layer.__call__, xs, chunk_masks, pos_emb, mask_pad
199
- )
200
- return xs
201
-
202
- @torch.jit.export
203
- def forward_chunk(
204
- self,
205
- xs: torch.Tensor,
206
- offset: int,
207
- required_cache_size: int,
208
- att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
209
- cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
210
- att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
211
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
212
- """ Forward just one chunk
213
-
214
- Args:
215
- xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
216
- where `time == (chunk_size - 1) * subsample_rate + \
217
- subsample.right_context + 1`
218
- offset (int): current offset in encoder output time stamp
219
- required_cache_size (int): cache size required for next chunk
220
- computation
221
- >=0: actual cache size
222
- <0: means all history cache is required
223
- att_cache (torch.Tensor): cache tensor for KEY & VALUE in
224
- transformer/conformer attention, with shape
225
- (elayers, head, cache_t1, d_k * 2), where
226
- `head * d_k == hidden-dim` and
227
- `cache_t1 == chunk_size * num_decoding_left_chunks`.
228
- cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
229
- (elayers, b=1, hidden-dim, cache_t2), where
230
- `cache_t2 == cnn.lorder - 1`
231
-
232
- Returns:
233
- torch.Tensor: output of current input xs,
234
- with shape (b=1, chunk_size, hidden-dim).
235
- torch.Tensor: new attention cache required for next chunk, with
236
- dynamic shape (elayers, head, ?, d_k * 2)
237
- depending on required_cache_size.
238
- torch.Tensor: new conformer cnn cache required for next chunk, with
239
- same shape as the original cnn_cache.
240
-
241
- """
242
- assert xs.size(0) == 1
243
- # tmp_masks is just for interface compatibility
244
- tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
245
- tmp_masks = tmp_masks.unsqueeze(1)
246
- if self.global_cmvn is not None:
247
- xs = self.global_cmvn(xs)
248
- # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
249
- xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
250
- # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
251
- elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
252
- chunk_size = xs.size(1)
253
- attention_key_size = cache_t1 + chunk_size
254
- pos_emb = self.embed.position_encoding(
255
- offset=offset - cache_t1, size=attention_key_size
256
- )
257
- if required_cache_size < 0:
258
- next_cache_start = 0
259
- elif required_cache_size == 0:
260
- next_cache_start = attention_key_size
261
- else:
262
- next_cache_start = max(attention_key_size - required_cache_size, 0)
263
- r_att_cache = []
264
- r_cnn_cache = []
265
- for i, layer in enumerate(self.encoders):
266
- # NOTE(xcsong): Before layer.forward
267
- # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
268
- # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
269
- xs, _, new_att_cache, new_cnn_cache = layer(
270
- xs,
271
- att_mask,
272
- pos_emb,
273
- att_cache=att_cache[i : i + 1] if elayers > 0 else att_cache,
274
- cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
275
- )
276
- # NOTE(xcsong): After layer.forward
277
- # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
278
- # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
279
- r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
280
- r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
281
- if self.normalize_before:
282
- xs = self.after_norm(xs)
283
-
284
- # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
285
- # ? may be larger than cache_t1, it depends on required_cache_size
286
- r_att_cache = torch.cat(r_att_cache, dim=0)
287
- # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
288
- r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
289
-
290
- return (xs, r_att_cache, r_cnn_cache)
291
-
292
- @torch.jit.unused
293
- def forward_chunk_by_chunk(
294
- self,
295
- xs: torch.Tensor,
296
- decoding_chunk_size: int,
297
- num_decoding_left_chunks: int = -1,
298
- ) -> Tuple[torch.Tensor, torch.Tensor]:
299
- """Forward input chunk by chunk with chunk_size in a streaming
300
- fashion
301
-
302
- Here we should pay special attention to computation cache in the
303
- streaming style forward chunk by chunk. Three things should be taken
304
- into account for computation in the current network:
305
- 1. transformer/conformer encoder layers output cache
306
- 2. convolution in conformer
307
- 3. convolution in subsampling
308
-
309
- However, we don't implement subsampling cache for:
310
- 1. We can control subsampling module to output the right result by
311
- overlapping input instead of cache left context, even though it
312
- wastes some computation, but subsampling only takes a very
313
- small fraction of computation in the whole model.
314
- 2. Typically, there are several convolution layers with subsampling
315
- in subsampling module, it is tricky and complicated to do cache
316
- with different convolution layers with different subsampling
317
- rate.
318
- 3. Currently, nn.Sequential is used to stack all the convolution
319
- layers in subsampling, we need to rewrite it to make it work
320
- with cache, which is not preferred.
321
- Args:
322
- xs (torch.Tensor): (1, max_len, dim)
323
- chunk_size (int): decoding chunk size
324
- """
325
- assert decoding_chunk_size > 0
326
- # The model is trained by static or dynamic chunk
327
- assert self.static_chunk_size > 0 or self.use_dynamic_chunk
328
- subsampling = self.embed.subsampling_rate
329
- context = self.embed.right_context + 1 # Add current frame
330
- stride = subsampling * decoding_chunk_size
331
- decoding_window = (decoding_chunk_size - 1) * subsampling + context
332
- num_frames = xs.size(1)
333
- att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
334
- cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
335
- outputs = []
336
- offset = 0
337
- required_cache_size = decoding_chunk_size * num_decoding_left_chunks
338
-
339
- # Feed forward overlap input step by step
340
- for cur in range(0, num_frames - context + 1, stride):
341
- end = min(cur + decoding_window, num_frames)
342
- chunk_xs = xs[:, cur:end, :]
343
- (y, att_cache, cnn_cache) = self.forward_chunk(
344
- chunk_xs, offset, required_cache_size, att_cache, cnn_cache
345
- )
346
- outputs.append(y)
347
- offset += y.size(1)
348
- ys = torch.cat(outputs, 1)
349
- masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
350
- return ys, masks
351
-
352
-
353
- class TransformerEncoder(BaseEncoder):
354
- """Transformer encoder module."""
355
-
356
- def __init__(
357
- self,
358
- input_size: int,
359
- output_size: int = 256,
360
- attention_heads: int = 4,
361
- linear_units: int = 2048,
362
- num_blocks: int = 6,
363
- dropout_rate: float = 0.1,
364
- positional_dropout_rate: float = 0.1,
365
- attention_dropout_rate: float = 0.0,
366
- input_layer: str = "conv2d",
367
- pos_enc_layer_type: str = "abs_pos",
368
- normalize_before: bool = True,
369
- static_chunk_size: int = 0,
370
- use_dynamic_chunk: bool = False,
371
- global_cmvn: torch.nn.Module = None,
372
- use_dynamic_left_chunk: bool = False,
373
- key_bias: bool = True,
374
- selfattention_layer_type: str = "selfattn",
375
- activation_type: str = "relu",
376
- gradient_checkpointing: bool = False,
377
- ):
378
- """Construct TransformerEncoder
379
-
380
- See Encoder for the meaning of each parameter.
381
- """
382
- super().__init__(
383
- input_size,
384
- output_size,
385
- attention_heads,
386
- linear_units,
387
- num_blocks,
388
- dropout_rate,
389
- positional_dropout_rate,
390
- attention_dropout_rate,
391
- input_layer,
392
- pos_enc_layer_type,
393
- normalize_before,
394
- static_chunk_size,
395
- use_dynamic_chunk,
396
- global_cmvn,
397
- use_dynamic_left_chunk,
398
- gradient_checkpointing,
399
- )
400
- activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
401
- self.encoders = torch.nn.ModuleList(
402
- [
403
- TransformerEncoderLayer(
404
- output_size,
405
- COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
406
- attention_heads, output_size, attention_dropout_rate, key_bias
407
- ),
408
- PositionwiseFeedForward(
409
- output_size, linear_units, dropout_rate, activation
410
- ),
411
- dropout_rate,
412
- normalize_before,
413
- )
414
- for _ in range(num_blocks)
415
- ]
416
- )
417
-
418
-
419
- class ConformerEncoder(BaseEncoder):
420
- """Conformer encoder module."""
421
-
422
- def __init__(
423
- self,
424
- input_size: int,
425
- output_size: int = 256,
426
- attention_heads: int = 4,
427
- linear_units: int = 2048,
428
- num_blocks: int = 6,
429
- dropout_rate: float = 0.1,
430
- positional_dropout_rate: float = 0.1,
431
- attention_dropout_rate: float = 0.0,
432
- input_layer: str = "conv2d",
433
- pos_enc_layer_type: str = "rel_pos",
434
- normalize_before: bool = True,
435
- static_chunk_size: int = 0,
436
- use_dynamic_chunk: bool = False,
437
- global_cmvn: torch.nn.Module = None,
438
- use_dynamic_left_chunk: bool = False,
439
- positionwise_conv_kernel_size: int = 1,
440
- macaron_style: bool = True,
441
- selfattention_layer_type: str = "rel_selfattn",
442
- activation_type: str = "swish",
443
- use_cnn_module: bool = True,
444
- cnn_module_kernel: int = 15,
445
- causal: bool = False,
446
- cnn_module_norm: str = "batch_norm",
447
- key_bias: bool = True,
448
- gradient_checkpointing: bool = False,
449
- ):
450
- """Construct ConformerEncoder
451
-
452
- Args:
453
- input_size to use_dynamic_chunk, see in BaseEncoder
454
- positionwise_conv_kernel_size (int): Kernel size of positionwise
455
- conv1d layer.
456
- macaron_style (bool): Whether to use macaron style for
457
- positionwise layer.
458
- selfattention_layer_type (str): Encoder attention layer type,
459
- the parameter has no effect now, it's just for configure
460
- compatibility.
461
- activation_type (str): Encoder activation function type.
462
- use_cnn_module (bool): Whether to use convolution module.
463
- cnn_module_kernel (int): Kernel size of convolution module.
464
- causal (bool): whether to use causal convolution or not.
465
- key_bias: whether to use bias in attention.linear_k; False for whisper models.
466
- """
467
- super().__init__(
468
- input_size,
469
- output_size,
470
- attention_heads,
471
- linear_units,
472
- num_blocks,
473
- dropout_rate,
474
- positional_dropout_rate,
475
- attention_dropout_rate,
476
- input_layer,
477
- pos_enc_layer_type,
478
- normalize_before,
479
- static_chunk_size,
480
- use_dynamic_chunk,
481
- global_cmvn,
482
- use_dynamic_left_chunk,
483
- gradient_checkpointing,
484
- )
485
- activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
486
-
487
- # self-attention module definition
488
- encoder_selfattn_layer_args = (
489
- attention_heads,
490
- output_size,
491
- attention_dropout_rate,
492
- key_bias,
493
- )
494
- # feed-forward module definition
495
- positionwise_layer_args = (
496
- output_size,
497
- linear_units,
498
- dropout_rate,
499
- activation,
500
- )
501
- # convolution module definition
502
- convolution_layer_args = (
503
- output_size,
504
- cnn_module_kernel,
505
- activation,
506
- cnn_module_norm,
507
- causal,
508
- )
509
-
510
- self.encoders = torch.nn.ModuleList(
511
- [
512
- ConformerEncoderLayer(
513
- output_size,
514
- COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
515
- *encoder_selfattn_layer_args
516
- ),
517
- PositionwiseFeedForward(*positionwise_layer_args),
518
- (
519
- PositionwiseFeedForward(*positionwise_layer_args)
520
- if macaron_style
521
- else None
522
- ),
523
- (
524
- ConvolutionModule(*convolution_layer_args)
525
- if use_cnn_module
526
- else None
527
- ),
528
- dropout_rate,
529
- normalize_before,
530
- )
531
- for _ in range(num_blocks)
532
- ]
533
- )
534
- self.inference_buffers = {}
535
- self.inference_graphs = {}
536
-
537
- @torch.inference_mode()
538
- def capture_inference(self, seq_len_to_capture=[128, 256, 512, 1024]):
539
- device = next(self.parameters()).device
540
- start_time = time.time()
541
- print(
542
- f"Start capture_inference for ConformerEncoder, seq_len_to_capture: {seq_len_to_capture}"
543
- )
544
-
545
- for seq_len in seq_len_to_capture:
546
- xs = torch.randn(
547
- 1, seq_len, self._output_size, device=device, dtype=torch.bfloat16
548
- )
549
- xs_lens = torch.tensor([seq_len], device=device, dtype=torch.int32)
550
- decoding_chunk_size = 0
551
- num_decoding_left_chunks = -1
552
-
553
- T = xs.size(1)
554
- masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
555
- if self.global_cmvn is not None:
556
- xs = self.global_cmvn(xs)
557
- xs, pos_emb, masks = self.embed(xs, masks)
558
- mask_pad = masks # (B, 1, T/subsample_rate)
559
- chunk_masks = add_optional_chunk_mask(
560
- xs,
561
- masks,
562
- self.use_dynamic_chunk,
563
- self.use_dynamic_left_chunk,
564
- decoding_chunk_size,
565
- self.static_chunk_size,
566
- num_decoding_left_chunks,
567
- )
568
-
569
- g = torch.cuda.CUDAGraph()
570
- with torch.cuda.graph(g):
571
- out = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
572
-
573
- self.inference_graphs[seq_len] = g
574
- self.inference_buffers[seq_len] = {
575
- "xs": xs,
576
- "chunk_masks": chunk_masks,
577
- "pos_emb": pos_emb,
578
- "mask_pad": mask_pad,
579
- "out": out,
580
- }
581
- end_time = time.time()
582
- print(
583
- f"Finish capture_inference for ConformerEncoder, time elapsed: {end_time - start_time}"
584
- )
585
-
586
- @torch.inference_mode()
587
- def inference(self, xs: torch.Tensor, xs_lens: torch.Tensor):
588
- curr_seq_len = xs.shape[1]
589
- target_len = None
590
-
591
- for seq_len in sorted(self.inference_graphs.keys()):
592
- if seq_len >= curr_seq_len:
593
- target_len = seq_len
594
- break
595
-
596
- if target_len is not None:
597
- xs = F.pad(xs, (0, 0, 0, target_len - curr_seq_len), "constant", 0)
598
-
599
- decoding_chunk_size = 0
600
- num_decoding_left_chunks = -1
601
-
602
- T = xs.size(1)
603
- masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
604
- if self.global_cmvn is not None:
605
- xs = self.global_cmvn(xs)
606
- xs, pos_emb, masks = self.embed(xs, masks)
607
- mask_pad = masks # (B, 1, T/subsample_rate)
608
- chunk_masks = add_optional_chunk_mask(
609
- xs,
610
- masks,
611
- self.use_dynamic_chunk,
612
- self.use_dynamic_left_chunk,
613
- decoding_chunk_size,
614
- self.static_chunk_size,
615
- num_decoding_left_chunks,
616
- )
617
-
618
- if target_len is not None:
619
- buffer = self.inference_buffers[target_len]
620
- buffer["xs"].copy_(xs)
621
- buffer["chunk_masks"].copy_(chunk_masks)
622
- buffer["pos_emb"].copy_(pos_emb)
623
- buffer["mask_pad"].copy_(mask_pad)
624
-
625
- self.inference_graphs[target_len].replay()
626
-
627
- out = buffer["out"][:, :curr_seq_len, :]
628
- else:
629
- out = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
630
-
631
- if self.normalize_before:
632
- out = self.after_norm(out)
633
- return out, masks
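Note on the CUDA-graph path above: capture_inference records one graph per pre-chosen sequence length, and inference pads the incoming batch up to the smallest captured bucket before replaying it. A minimal, self-contained sketch of that length-bucketing step (bucket sizes, shapes, and helper names here are illustrative, not taken from the deleted file):

import torch
import torch.nn.functional as F
from typing import Optional, Tuple

# Illustrative bucket sizes, mirroring seq_len_to_capture in capture_inference.
BUCKETS = [128, 256, 512, 1024]

def pad_to_bucket(xs: torch.Tensor) -> Tuple[torch.Tensor, Optional[int]]:
    """Pad a (B, T, D) batch up to the smallest bucket >= T.

    Returns the padded tensor and the chosen bucket, or the original tensor
    and None when T exceeds every bucket (the eager fallback path).
    """
    cur_len = xs.shape[1]
    target = next((b for b in sorted(BUCKETS) if b >= cur_len), None)
    if target is None:
        return xs, None
    # F.pad pads trailing dims first: (D_left, D_right, T_left, T_right).
    return F.pad(xs, (0, 0, 0, target - cur_len), "constant", 0), target

x = torch.randn(1, 300, 512)
padded, bucket = pad_to_bucket(x)
print(padded.shape, bucket)  # torch.Size([1, 512, 512]) 512

In the method above, the padded frames are excluded by make_pad_mask and the output is sliced back to the original length, so replaying a larger graph costs compute but not correctness.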
 
cosyvoice/transformer/encoder_layer.py DELETED
@@ -1,237 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- # 2022 Xingchen Song ([email protected])
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Encoder self-attention layer definition."""
17
-
18
- from typing import Optional, Tuple
19
-
20
- import torch
21
- from torch import nn
22
-
23
-
24
- class TransformerEncoderLayer(nn.Module):
25
- """Encoder layer module.
26
-
27
- Args:
28
- size (int): Input dimension.
29
- self_attn (torch.nn.Module): Self-attention module instance.
30
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
- instance can be used as the argument.
32
- feed_forward (torch.nn.Module): Feed-forward module instance.
33
- `PositionwiseFeedForward`, instance can be used as the argument.
34
- dropout_rate (float): Dropout rate.
35
- normalize_before (bool):
36
- True: use layer_norm before each sub-block.
37
- False: to use layer_norm after each sub-block.
38
- """
39
-
40
- def __init__(
41
- self,
42
- size: int,
43
- self_attn: torch.nn.Module,
44
- feed_forward: torch.nn.Module,
45
- dropout_rate: float,
46
- normalize_before: bool = True,
47
- ):
48
- """Construct an EncoderLayer object."""
49
- super().__init__()
50
- self.self_attn = self_attn
51
- self.feed_forward = feed_forward
52
- self.norm1 = nn.LayerNorm(size, eps=1e-5)
53
- self.norm2 = nn.LayerNorm(size, eps=1e-5)
54
- self.dropout = nn.Dropout(dropout_rate)
55
- self.size = size
56
- self.normalize_before = normalize_before
57
-
58
- def forward(
59
- self,
60
- x: torch.Tensor,
61
- mask: torch.Tensor,
62
- pos_emb: torch.Tensor,
63
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
- att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
- cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
- """Compute encoded features.
68
-
69
- Args:
70
- x (torch.Tensor): (#batch, time, size)
71
- mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
- (0, 0, 0) means fake mask.
73
- pos_emb (torch.Tensor): just for interface compatibility
74
- to ConformerEncoderLayer
75
- mask_pad (torch.Tensor): does not used in transformer layer,
76
- just for unified api with conformer.
77
- att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
- (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
- cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
- (#batch=1, size, cache_t2), not used here, it's for interface
81
- compatibility to ConformerEncoderLayer.
82
- Returns:
83
- torch.Tensor: Output tensor (#batch, time, size).
84
- torch.Tensor: Mask tensor (#batch, time, time).
85
- torch.Tensor: att_cache tensor,
86
- (#batch=1, head, cache_t1 + time, d_k * 2).
87
- torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
88
-
89
- """
90
- residual = x
91
- if self.normalize_before:
92
- x = self.norm1(x)
93
- x_att, new_att_cache = self.self_attn(
94
- x, x, x, mask, pos_emb=pos_emb, cache=att_cache
95
- )
96
- x = residual + self.dropout(x_att)
97
- if not self.normalize_before:
98
- x = self.norm1(x)
99
-
100
- residual = x
101
- if self.normalize_before:
102
- x = self.norm2(x)
103
- x = residual + self.dropout(self.feed_forward(x))
104
- if not self.normalize_before:
105
- x = self.norm2(x)
106
-
107
- fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
108
- return x, mask, new_att_cache, fake_cnn_cache
109
-
110
-
111
- class ConformerEncoderLayer(nn.Module):
112
- """Encoder layer module.
113
- Args:
114
- size (int): Input dimension.
115
- self_attn (torch.nn.Module): Self-attention module instance.
116
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
117
- instance can be used as the argument.
118
- feed_forward (torch.nn.Module): Feed-forward module instance.
119
- `PositionwiseFeedForward` instance can be used as the argument.
120
- feed_forward_macaron (torch.nn.Module): Additional feed-forward module
121
- instance.
122
- `PositionwiseFeedForward` instance can be used as the argument.
123
- conv_module (torch.nn.Module): Convolution module instance.
124
- `ConvlutionModule` instance can be used as the argument.
125
- dropout_rate (float): Dropout rate.
126
- normalize_before (bool):
127
- True: use layer_norm before each sub-block.
128
- False: use layer_norm after each sub-block.
129
- """
130
-
131
- def __init__(
132
- self,
133
- size: int,
134
- self_attn: torch.nn.Module,
135
- feed_forward: Optional[nn.Module] = None,
136
- feed_forward_macaron: Optional[nn.Module] = None,
137
- conv_module: Optional[nn.Module] = None,
138
- dropout_rate: float = 0.1,
139
- normalize_before: bool = True,
140
- ):
141
- """Construct an EncoderLayer object."""
142
- super().__init__()
143
- self.self_attn = self_attn
144
- self.feed_forward = feed_forward
145
- self.feed_forward_macaron = feed_forward_macaron
146
- self.conv_module = conv_module
147
- self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
148
- self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
149
- if feed_forward_macaron is not None:
150
- self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
151
- self.ff_scale = 0.5
152
- else:
153
- self.ff_scale = 1.0
154
- if self.conv_module is not None:
155
- self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
156
- self.norm_final = nn.LayerNorm(
157
- size, eps=1e-5
158
- ) # for the final output of the block
159
- self.dropout = nn.Dropout(dropout_rate)
160
- self.size = size
161
- self.normalize_before = normalize_before
162
-
163
- def forward(
164
- self,
165
- x: torch.Tensor,
166
- mask: torch.Tensor,
167
- pos_emb: torch.Tensor,
168
- mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
169
- att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
170
- cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
171
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
- """Compute encoded features.
173
-
174
- Args:
175
- x (torch.Tensor): (#batch, time, size)
176
- mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
177
- (0, 0, 0) means fake mask.
178
- pos_emb (torch.Tensor): positional encoding, must not be None
179
- for ConformerEncoderLayer.
180
- mask_pad (torch.Tensor): batch padding mask used for conv module.
181
- (#batch, 1,time), (0, 0, 0) means fake mask.
182
- att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
183
- (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
184
- cnn_cache (torch.Tensor): Convolution cache in conformer layer
185
- (#batch=1, size, cache_t2)
186
- Returns:
187
- torch.Tensor: Output tensor (#batch, time, size).
188
- torch.Tensor: Mask tensor (#batch, time, time).
189
- torch.Tensor: att_cache tensor,
190
- (#batch=1, head, cache_t1 + time, d_k * 2).
191
- torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
192
- """
193
-
194
- # whether to use macaron style
195
- if self.feed_forward_macaron is not None:
196
- residual = x
197
- if self.normalize_before:
198
- x = self.norm_ff_macaron(x)
199
- x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
200
- if not self.normalize_before:
201
- x = self.norm_ff_macaron(x)
202
-
203
- # multi-headed self-attention module
204
- residual = x
205
- if self.normalize_before:
206
- x = self.norm_mha(x)
207
- x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
208
- x = residual + self.dropout(x_att)
209
- if not self.normalize_before:
210
- x = self.norm_mha(x)
211
-
212
- # convolution module
213
- # Fake new cnn cache here, and then change it in conv_module
214
- new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
215
- if self.conv_module is not None:
216
- residual = x
217
- if self.normalize_before:
218
- x = self.norm_conv(x)
219
- x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
220
- x = residual + self.dropout(x)
221
-
222
- if not self.normalize_before:
223
- x = self.norm_conv(x)
224
-
225
- # feed forward module
226
- residual = x
227
- if self.normalize_before:
228
- x = self.norm_ff(x)
229
-
230
- x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
231
- if not self.normalize_before:
232
- x = self.norm_ff(x)
233
-
234
- if self.conv_module is not None:
235
- x = self.norm_final(x)
236
-
237
- return x, mask, new_att_cache, new_cnn_cache
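Both layer classes above wrap every sub-block (macaron FFN, self-attention, convolution, FFN) in the same residual pattern controlled by normalize_before. A standalone sketch of the pre-norm case, y = x + dropout(sublayer(layer_norm(x))); the class and argument names are illustrative, not part of the deleted code:

import torch
from torch import nn

class PreNormResidual(nn.Module):
    """Minimal pre-norm residual wrapper (the normalize_before=True pattern)."""

    def __init__(self, size: int, sublayer: nn.Module, dropout_rate: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-5)
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, run the sub-block, then add the residual.
        return x + self.dropout(self.sublayer(self.norm(x)))

block = PreNormResidual(256, nn.Linear(256, 256))
print(block(torch.randn(2, 10, 256)).shape)  # torch.Size([2, 10, 256])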
 
cosyvoice/transformer/label_smoothing_loss.py DELETED
@@ -1,98 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Label smoothing module."""
16
-
17
- import torch
18
- from torch import nn
19
-
20
-
21
- class LabelSmoothingLoss(nn.Module):
22
- """Label-smoothing loss.
23
-
24
- In a standard CE loss, the label's data distribution is:
25
- [0,1,2] ->
26
- [
27
- [1.0, 0.0, 0.0],
28
- [0.0, 1.0, 0.0],
29
- [0.0, 0.0, 1.0],
30
- ]
31
-
32
- In the smoothed version of the CE loss, some probability mass
33
- are taken from the true label prob (1.0) and are divided
34
- among other labels.
35
-
36
- e.g.
37
- smoothing=0.1
38
- [0,1,2] ->
39
- [
40
- [0.9, 0.05, 0.05],
41
- [0.05, 0.9, 0.05],
42
- [0.05, 0.05, 0.9],
43
- ]
44
-
45
- Args:
46
- size (int): the number of class
47
- padding_idx (int): padding class id which will be ignored for loss
48
- smoothing (float): smoothing rate (0.0 means the conventional CE)
49
- normalize_length (bool):
50
- normalize loss by sequence length if True
51
- normalize loss by batch size if False
52
- """
53
-
54
- def __init__(
55
- self,
56
- size: int,
57
- padding_idx: int,
58
- smoothing: float,
59
- normalize_length: bool = False,
60
- ):
61
- """Construct a LabelSmoothingLoss object."""
62
- super(LabelSmoothingLoss, self).__init__()
63
- self.criterion = nn.KLDivLoss(reduction="none")
64
- self.padding_idx = padding_idx
65
- self.confidence = 1.0 - smoothing
66
- self.smoothing = smoothing
67
- self.size = size
68
- self.normalize_length = normalize_length
69
-
70
- def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
71
- """Compute loss between x and target.
72
-
73
- The model outputs and data labels tensors are flatten to
74
- (batch*seqlen, class) shape and a mask is applied to the
75
- padding part which should not be calculated for loss.
76
-
77
- Args:
78
- x (torch.Tensor): prediction (batch, seqlen, class)
79
- target (torch.Tensor):
80
- target signal masked with self.padding_id (batch, seqlen)
81
- Returns:
82
- loss (torch.Tensor) : The KL loss, scalar float value
83
- """
84
- assert x.size(2) == self.size
85
- batch_size = x.size(0)
86
- x = x.view(-1, self.size)
87
- target = target.view(-1)
88
- # use zeros_like instead of torch.no_grad() for true_dist,
89
- # since no_grad() can not be exported by JIT
90
- true_dist = torch.zeros_like(x)
91
- true_dist.fill_(self.smoothing / (self.size - 1))
92
- ignore = target == self.padding_idx # (B,)
93
- total = len(target) - ignore.sum().item()
94
- target = target.masked_fill(ignore, 0) # avoid -1 index
95
- true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
96
- kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
97
- denom = total if self.normalize_length else batch_size
98
- return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
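For reference, the smoothed target distribution that the forward pass scatters above can be reproduced standalone; this mirrors the docstring example (size=3, smoothing=0.1) and sketches only the distribution-building step, not the full loss:

import torch

size, smoothing = 3, 0.1
target = torch.tensor([0, 1, 2])
# Off-target classes share the smoothing mass, the true class keeps 1 - smoothing.
true_dist = torch.full((target.size(0), size), smoothing / (size - 1))
true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
print(true_dist)
# tensor([[0.9000, 0.0500, 0.0500],
#         [0.0500, 0.9000, 0.0500],
#         [0.0500, 0.0500, 0.9000]])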
 
cosyvoice/transformer/positionwise_feed_forward.py DELETED
@@ -1,116 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Positionwise feed forward layer definition."""
16
-
17
- import torch
18
-
19
-
20
- class PositionwiseFeedForward(torch.nn.Module):
21
- """Positionwise feed forward layer.
22
-
23
- The feed-forward block is applied at each position of the sequence.
24
- The output dim is same with the input dim.
25
-
26
- Args:
27
- idim (int): Input dimension.
28
- hidden_units (int): The number of hidden units.
29
- dropout_rate (float): Dropout rate.
30
- activation (torch.nn.Module): Activation function
31
- """
32
-
33
- def __init__(
34
- self,
35
- idim: int,
36
- hidden_units: int,
37
- dropout_rate: float,
38
- activation: torch.nn.Module = torch.nn.ReLU(),
39
- ):
40
- """Construct a PositionwiseFeedForward object."""
41
- super(PositionwiseFeedForward, self).__init__()
42
- self.w_1 = torch.nn.Linear(idim, hidden_units)
43
- self.activation = activation
44
- self.dropout = torch.nn.Dropout(dropout_rate)
45
- self.w_2 = torch.nn.Linear(hidden_units, idim)
46
-
47
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
- """Forward function.
49
-
50
- Args:
51
- xs: input tensor (B, L, D)
52
- Returns:
53
- output tensor, (B, L, D)
54
- """
55
- return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
-
57
-
58
- class MoEFFNLayer(torch.nn.Module):
59
- """
60
- Mixture of expert with Positionwise feed forward layer
61
- See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
- The output dim is same with the input dim.
63
-
64
- Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
- https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
- Args:
67
- n_expert: number of expert.
68
- n_expert_per_token: The actual number of experts used for each frame
69
- idim (int): Input dimension.
70
- hidden_units (int): The number of hidden units.
71
- dropout_rate (float): Dropout rate.
72
- activation (torch.nn.Module): Activation function
73
- """
74
-
75
- def __init__(
76
- self,
77
- n_expert: int,
78
- n_expert_per_token: int,
79
- idim: int,
80
- hidden_units: int,
81
- dropout_rate: float,
82
- activation: torch.nn.Module = torch.nn.ReLU(),
83
- ):
84
- super(MoEFFNLayer, self).__init__()
85
- self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
- self.experts = torch.nn.ModuleList(
87
- PositionwiseFeedForward(idim, hidden_units, dropout_rate, activation)
88
- for _ in range(n_expert)
89
- )
90
- self.n_expert_per_token = n_expert_per_token
91
-
92
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
93
- """Forward function.
94
- Args:
95
- xs: input tensor (B, L, D)
96
- Returns:
97
- output tensor, (B, L, D)
98
-
99
- """
100
- B, L, D = xs.size() # batch size, sequence length, embedding dimension (idim)
101
- xs = xs.view(-1, D) # (B*L, D)
102
- router = self.gate(xs) # (B*L, n_expert)
103
- logits, indices = torch.topk(
104
- router, self.n_expert_per_token
105
- ) # probs:(B*L, n_expert), indices: (B*L, n_expert)
106
- weights = torch.nn.functional.softmax(logits, dim=1, dtype=torch.float).to(
107
- dtype=xs.dtype
108
- ) # (B*L, n_expert_per_token)
109
- output = torch.zeros_like(xs) # (B*L, D)
110
- for i, expert in enumerate(self.experts):
111
- mask = indices == i
112
- batch_idx, ith_expert = torch.where(mask)
113
- output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
- xs[batch_idx]
115
- )
116
- return output.view(B, L, D)
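The routing step in MoEFFNLayer.forward above boils down to a top-k over the gate logits followed by a softmax restricted to the selected experts. A standalone sketch of just that step (expert counts and shapes are illustrative):

import torch

n_expert, n_expert_per_token = 4, 2
tokens = torch.randn(6, 8)                     # (B*L, D), illustrative shapes
gate = torch.nn.Linear(8, n_expert, bias=False)

router = gate(tokens)                          # (B*L, n_expert)
logits, indices = torch.topk(router, n_expert_per_token)
weights = torch.softmax(logits, dim=1)         # (B*L, n_expert_per_token)
print(indices.shape, weights.sum(dim=1))       # weights sum to 1 per token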
 
cosyvoice/transformer/subsampling.py DELETED
@@ -1,391 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
- # 2024 Alibaba Inc (Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Subsampling layer definition."""
17
-
18
- from typing import Tuple, Union
19
-
20
- import torch
21
-
22
-
23
- class BaseSubsampling(torch.nn.Module):
24
-
25
- def __init__(self):
26
- super().__init__()
27
- self.right_context = 0
28
- self.subsampling_rate = 1
29
-
30
- def position_encoding(
31
- self, offset: Union[int, torch.Tensor], size: int
32
- ) -> torch.Tensor:
33
- return self.pos_enc.position_encoding(offset, size)
34
-
35
-
36
- class EmbedinigNoSubsampling(BaseSubsampling):
37
- """Embedding input without subsampling"""
38
-
39
- def __init__(
40
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
41
- ):
42
- super().__init__()
43
- self.embed = torch.nn.Embedding(idim, odim)
44
- self.pos_enc = pos_enc_class
45
-
46
- def forward(
47
- self,
48
- x: torch.Tensor,
49
- x_mask: torch.Tensor,
50
- offset: Union[int, torch.Tensor] = 0,
51
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
52
- """Input x.
53
-
54
- Args:
55
- x (torch.Tensor): Input tensor (#batch, time, idim).
56
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
57
-
58
- Returns:
59
- torch.Tensor: linear input tensor (#batch, time', odim),
60
- where time' = time .
61
- torch.Tensor: linear input mask (#batch, 1, time'),
62
- where time' = time .
63
-
64
- """
65
- x = self.embed(x)
66
- x, pos_emb = self.pos_enc(x, offset)
67
- return x, pos_emb, x_mask
68
-
69
-
70
- class LinearNoSubsampling(BaseSubsampling):
71
- """Linear transform the input without subsampling
72
-
73
- Args:
74
- idim (int): Input dimension.
75
- odim (int): Output dimension.
76
- dropout_rate (float): Dropout rate.
77
-
78
- """
79
-
80
- def __init__(
81
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
82
- ):
83
- """Construct a linear object."""
84
- super().__init__()
85
- self.out = torch.nn.Sequential(
86
- torch.nn.Linear(idim, odim),
87
- torch.nn.LayerNorm(odim, eps=1e-5),
88
- torch.nn.Dropout(dropout_rate),
89
- )
90
- self.pos_enc = pos_enc_class
91
- self.right_context = 0
92
- self.subsampling_rate = 1
93
-
94
- def forward(
95
- self,
96
- x: torch.Tensor,
97
- x_mask: torch.Tensor,
98
- offset: Union[int, torch.Tensor] = 0,
99
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
100
- """Input x.
101
-
102
- Args:
103
- x (torch.Tensor): Input tensor (#batch, time, idim).
104
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
105
-
106
- Returns:
107
- torch.Tensor: linear input tensor (#batch, time', odim),
108
- where time' = time .
109
- torch.Tensor: linear input mask (#batch, 1, time'),
110
- where time' = time .
111
-
112
- """
113
- x = self.out(x)
114
- x, pos_emb = self.pos_enc(x, offset)
115
- return x, pos_emb, x_mask
116
-
117
-
118
- class Conv1dSubsampling2(BaseSubsampling):
119
- """Convolutional 1D subsampling (to 1/2 length).
120
- It is designed for Whisper, ref:
121
- https://github.com/openai/whisper/blob/main/whisper/model.py
122
-
123
- Args:
124
- idim (int): Input dimension.
125
- odim (int): Output dimension.
126
- dropout_rate (float): Dropout rate.
127
-
128
- """
129
-
130
- def __init__(
131
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
132
- ):
133
- """Construct a Conv1dSubsampling2 object."""
134
- super().__init__()
135
- self.conv = torch.nn.Sequential(
136
- torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
137
- torch.nn.GELU(),
138
- torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
139
- torch.nn.GELU(),
140
- )
141
- self.pos_enc = pos_enc_class
142
- # The right context for every conv layer is computed by:
143
- # (kernel_size - 1) * frame_rate_of_this_layer
144
- self.subsampling_rate = 2
145
- # 4 = (3 - 1) * 1 + (3 - 1) * 1
146
- self.right_context = 4
147
-
148
- def forward(
149
- self,
150
- x: torch.Tensor,
151
- x_mask: torch.Tensor,
152
- offset: Union[int, torch.Tensor] = 0,
153
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
154
- """Subsample x.
155
-
156
- Args:
157
- x (torch.Tensor): Input tensor (#batch, time, idim).
158
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
159
-
160
- Returns:
161
- torch.Tensor: Subsampled tensor (#batch, time', odim),
162
- where time' = time // 2.
163
- torch.Tensor: Subsampled mask (#batch, 1, time'),
164
- where time' = time // 2.
165
- torch.Tensor: positional encoding
166
-
167
- """
168
- time = x.size(1)
169
- x = x.transpose(1, 2) # (b, f, t)
170
- x = self.conv(x)
171
- x = x.transpose(1, 2) # (b, t, f)
172
- x, pos_emb = self.pos_enc(x, offset)
173
- return x, pos_emb, x_mask[:, :, (time + 1) % 2 :: 2]
174
-
175
-
176
- class Conv2dSubsampling4(BaseSubsampling):
177
- """Convolutional 2D subsampling (to 1/4 length).
178
-
179
- Args:
180
- idim (int): Input dimension.
181
- odim (int): Output dimension.
182
- dropout_rate (float): Dropout rate.
183
-
184
- """
185
-
186
- def __init__(
187
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
188
- ):
189
- """Construct an Conv2dSubsampling4 object."""
190
- super().__init__()
191
- self.conv = torch.nn.Sequential(
192
- torch.nn.Conv2d(1, odim, 3, 2),
193
- torch.nn.ReLU(),
194
- torch.nn.Conv2d(odim, odim, 3, 2),
195
- torch.nn.ReLU(),
196
- )
197
- self.out = torch.nn.Sequential(
198
- torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
199
- )
200
- self.pos_enc = pos_enc_class
201
- # The right context for every conv layer is computed by:
202
- # (kernel_size - 1) * frame_rate_of_this_layer
203
- self.subsampling_rate = 4
204
- # 6 = (3 - 1) * 1 + (3 - 1) * 2
205
- self.right_context = 6
206
-
207
- def forward(
208
- self,
209
- x: torch.Tensor,
210
- x_mask: torch.Tensor,
211
- offset: Union[int, torch.Tensor] = 0,
212
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
213
- """Subsample x.
214
-
215
- Args:
216
- x (torch.Tensor): Input tensor (#batch, time, idim).
217
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
218
-
219
- Returns:
220
- torch.Tensor: Subsampled tensor (#batch, time', odim),
221
- where time' = time // 4.
222
- torch.Tensor: Subsampled mask (#batch, 1, time'),
223
- where time' = time // 4.
224
- torch.Tensor: positional encoding
225
-
226
- """
227
- x = x.unsqueeze(1) # (b, c=1, t, f)
228
- x = self.conv(x)
229
- b, c, t, f = x.size()
230
- x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
231
- x, pos_emb = self.pos_enc(x, offset)
232
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
233
-
234
-
235
- class Conv2dSubsampling6(BaseSubsampling):
236
- """Convolutional 2D subsampling (to 1/6 length).
237
- Args:
238
- idim (int): Input dimension.
239
- odim (int): Output dimension.
240
- dropout_rate (float): Dropout rate.
241
- pos_enc (torch.nn.Module): Custom position encoding layer.
242
- """
243
-
244
- def __init__(
245
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
246
- ):
247
- """Construct a Conv2dSubsampling6 object."""
248
- super().__init__()
249
- self.conv = torch.nn.Sequential(
250
- torch.nn.Conv2d(1, odim, 3, 2),
251
- torch.nn.ReLU(),
252
- torch.nn.Conv2d(odim, odim, 5, 3),
253
- torch.nn.ReLU(),
254
- )
255
- self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
256
- self.pos_enc = pos_enc_class
257
- # 10 = (3 - 1) * 1 + (5 - 1) * 2
258
- self.subsampling_rate = 6
259
- self.right_context = 10
260
-
261
- def forward(
262
- self,
263
- x: torch.Tensor,
264
- x_mask: torch.Tensor,
265
- offset: Union[int, torch.Tensor] = 0,
266
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
267
- """Subsample x.
268
- Args:
269
- x (torch.Tensor): Input tensor (#batch, time, idim).
270
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
271
-
272
- Returns:
273
- torch.Tensor: Subsampled tensor (#batch, time', odim),
274
- where time' = time // 6.
275
- torch.Tensor: Subsampled mask (#batch, 1, time'),
276
- where time' = time // 6.
277
- torch.Tensor: positional encoding
278
- """
279
- x = x.unsqueeze(1) # (b, c, t, f)
280
- x = self.conv(x)
281
- b, c, t, f = x.size()
282
- x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
283
- x, pos_emb = self.pos_enc(x, offset)
284
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
285
-
286
-
287
- class Conv2dSubsampling8(BaseSubsampling):
288
- """Convolutional 2D subsampling (to 1/8 length).
289
-
290
- Args:
291
- idim (int): Input dimension.
292
- odim (int): Output dimension.
293
- dropout_rate (float): Dropout rate.
294
-
295
- """
296
-
297
- def __init__(
298
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
299
- ):
300
- """Construct an Conv2dSubsampling8 object."""
301
- super().__init__()
302
- self.conv = torch.nn.Sequential(
303
- torch.nn.Conv2d(1, odim, 3, 2),
304
- torch.nn.ReLU(),
305
- torch.nn.Conv2d(odim, odim, 3, 2),
306
- torch.nn.ReLU(),
307
- torch.nn.Conv2d(odim, odim, 3, 2),
308
- torch.nn.ReLU(),
309
- )
310
- self.linear = torch.nn.Linear(
311
- odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim
312
- )
313
- self.pos_enc = pos_enc_class
314
- self.subsampling_rate = 8
315
- # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
316
- self.right_context = 14
317
-
318
- def forward(
319
- self,
320
- x: torch.Tensor,
321
- x_mask: torch.Tensor,
322
- offset: Union[int, torch.Tensor] = 0,
323
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
324
- """Subsample x.
325
-
326
- Args:
327
- x (torch.Tensor): Input tensor (#batch, time, idim).
328
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
329
-
330
- Returns:
331
- torch.Tensor: Subsampled tensor (#batch, time', odim),
332
- where time' = time // 8.
333
- torch.Tensor: Subsampled mask (#batch, 1, time'),
334
- where time' = time // 8.
335
- torch.Tensor: positional encoding
336
- """
337
- x = x.unsqueeze(1) # (b, c, t, f)
338
- x = self.conv(x)
339
- b, c, t, f = x.size()
340
- x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
341
- x, pos_emb = self.pos_enc(x, offset)
342
- return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
343
-
344
-
345
- class LegacyLinearNoSubsampling(BaseSubsampling):
346
- """Linear transform the input without subsampling
347
-
348
- Args:
349
- idim (int): Input dimension.
350
- odim (int): Output dimension.
351
- dropout_rate (float): Dropout rate.
352
-
353
- """
354
-
355
- def __init__(
356
- self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
357
- ):
358
- """Construct a linear object."""
359
- super().__init__()
360
- self.out = torch.nn.Sequential(
361
- torch.nn.Linear(idim, odim),
362
- torch.nn.LayerNorm(odim, eps=1e-5),
363
- torch.nn.Dropout(dropout_rate),
364
- torch.nn.ReLU(),
365
- )
366
- self.pos_enc = pos_enc_class
367
- self.right_context = 0
368
- self.subsampling_rate = 1
369
-
370
- def forward(
371
- self,
372
- x: torch.Tensor,
373
- x_mask: torch.Tensor,
374
- offset: Union[int, torch.Tensor] = 0,
375
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
376
- """Input x.
377
-
378
- Args:
379
- x (torch.Tensor): Input tensor (#batch, time, idim).
380
- x_mask (torch.Tensor): Input mask (#batch, 1, time).
381
-
382
- Returns:
383
- torch.Tensor: linear input tensor (#batch, time', odim),
384
- where time' = time .
385
- torch.Tensor: linear input mask (#batch, 1, time'),
386
- where time' = time .
387
-
388
- """
389
- x = self.out(x)
390
- x, pos_emb = self.pos_enc(x, offset)
391
- return x, pos_emb, x_mask
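A quick consistency check for Conv2dSubsampling4 above: two stride-2, kernel-3 convolutions without padding reduce T to ((T - 1) // 2 - 1) // 2, which is exactly the length produced by the mask slicing x_mask[:, :, 2::2][:, :, 2::2]. A small standalone verification:

import torch

for T in (16, 100, 255):
    conv_t = ((T - 1) // 2 - 1) // 2                 # length after the two convs
    mask = torch.ones(1, 1, T, dtype=torch.bool)
    mask_t = mask[:, :, 2::2][:, :, 2::2].size(2)    # length after mask slicing
    print(T, conv_t, mask_t)
    assert conv_t == mask_t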
 
cosyvoice/utils/__init__.py DELETED
File without changes
cosyvoice/utils/audio.py DELETED
@@ -1,90 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.utils.data
4
- from librosa.filters import mel as librosa_mel_fn
5
- from scipy.io.wavfile import read
6
-
7
- MAX_WAV_VALUE = 32768.0
8
-
9
-
10
- def load_wav(full_path):
11
- sampling_rate, data = read(full_path)
12
- return data, sampling_rate
13
-
14
-
15
- def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
- return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
-
18
-
19
- def dynamic_range_decompression(x, C=1):
20
- return np.exp(x) / C
21
-
22
-
23
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24
- return torch.log(torch.clamp(x, min=clip_val) * C)
25
-
26
-
27
- def dynamic_range_decompression_torch(x, C=1):
28
- return torch.exp(x) / C
29
-
30
-
31
- def spectral_normalize_torch(magnitudes):
32
- output = dynamic_range_compression_torch(magnitudes)
33
- return output
34
-
35
-
36
- def spectral_de_normalize_torch(magnitudes):
37
- output = dynamic_range_decompression_torch(magnitudes)
38
- return output
39
-
40
-
41
- mel_basis = {}
42
- hann_window = {}
43
-
44
-
45
- def mel_spectrogram(
46
- y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
47
- ):
48
- # if torch.min(y) < -1.0:
49
- # print("min value is ", torch.min(y))
50
- # if torch.max(y) > 1.0:
51
- # print("max value is ", torch.max(y))
52
-
53
- global mel_basis, hann_window # pylint: disable=global-statement
54
- if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
55
- mel = librosa_mel_fn(
56
- sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
57
- )
58
- mel_basis[str(fmax) + "_" + str(y.device)] = (
59
- torch.from_numpy(mel).float().to(y.device)
60
- )
61
- hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
62
-
63
- y = torch.nn.functional.pad(
64
- y.unsqueeze(1),
65
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
66
- mode="reflect",
67
- )
68
- y = y.squeeze(1)
69
-
70
- spec = torch.view_as_real(
71
- torch.stft(
72
- y,
73
- n_fft,
74
- hop_length=hop_size,
75
- win_length=win_size,
76
- window=hann_window[str(y.device)],
77
- center=center,
78
- pad_mode="reflect",
79
- normalized=False,
80
- onesided=True,
81
- return_complex=True,
82
- )
83
- )
84
-
85
- spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
86
-
87
- spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
88
- spec = spectral_normalize_torch(spec)
89
-
90
- return spec
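The two dynamic-range helpers above are a clamped log and its exact inverse (up to the clip floor); reproduced here standalone as a round-trip check:

import torch

def compress(x, C=1, clip_val=1e-5):
    # Same form as dynamic_range_compression_torch above.
    return torch.log(torch.clamp(x, min=clip_val) * C)

def decompress(x, C=1):
    # Same form as dynamic_range_decompression_torch above.
    return torch.exp(x) / C

mag = torch.tensor([0.0, 1e-6, 0.5, 2.0])
print(decompress(compress(mag)))  # values below clip_val come back as 1e-5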
 
cosyvoice/utils/class_utils.py DELETED
@@ -1,78 +0,0 @@
1
- # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
- # 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import torch
16
-
17
- from cosyvoice.transformer.activation import Swish
18
- from cosyvoice.transformer.subsampling import (
19
- LinearNoSubsampling,
20
- EmbedinigNoSubsampling,
21
- Conv1dSubsampling2,
22
- Conv2dSubsampling4,
23
- Conv2dSubsampling6,
24
- Conv2dSubsampling8,
25
- )
26
- from cosyvoice.transformer.embedding import (
27
- PositionalEncoding,
28
- RelPositionalEncoding,
29
- WhisperPositionalEncoding,
30
- LearnablePositionalEncoding,
31
- NoPositionalEncoding,
32
- )
33
- from cosyvoice.transformer.attention import (
34
- MultiHeadedAttention,
35
- RelPositionMultiHeadedAttention,
36
- )
37
- from cosyvoice.transformer.embedding import (
38
- EspnetRelPositionalEncoding,
39
- )
40
- from cosyvoice.transformer.subsampling import (
41
- LegacyLinearNoSubsampling,
42
- )
43
-
44
-
45
- COSYVOICE_ACTIVATION_CLASSES = {
46
- "hardtanh": torch.nn.Hardtanh,
47
- "tanh": torch.nn.Tanh,
48
- "relu": torch.nn.ReLU,
49
- "selu": torch.nn.SELU,
50
- "swish": getattr(torch.nn, "SiLU", Swish),
51
- "gelu": torch.nn.GELU,
52
- }
53
-
54
- COSYVOICE_SUBSAMPLE_CLASSES = {
55
- "linear": LinearNoSubsampling,
56
- "linear_legacy": LegacyLinearNoSubsampling,
57
- "embed": EmbedinigNoSubsampling,
58
- "conv1d2": Conv1dSubsampling2,
59
- "conv2d": Conv2dSubsampling4,
60
- "conv2d6": Conv2dSubsampling6,
61
- "conv2d8": Conv2dSubsampling8,
62
- "paraformer_dummy": torch.nn.Identity,
63
- }
64
-
65
- COSYVOICE_EMB_CLASSES = {
66
- "embed": PositionalEncoding,
67
- "abs_pos": PositionalEncoding,
68
- "rel_pos": RelPositionalEncoding,
69
- "rel_pos_espnet": EspnetRelPositionalEncoding,
70
- "no_pos": NoPositionalEncoding,
71
- "abs_pos_whisper": WhisperPositionalEncoding,
72
- "embed_learnable_pe": LearnablePositionalEncoding,
73
- }
74
-
75
- COSYVOICE_ATTENTION_CLASSES = {
76
- "selfattn": MultiHeadedAttention,
77
- "rel_selfattn": RelPositionMultiHeadedAttention,
78
- }
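These registries are how string options in the encoder constructors (e.g. input_layer="conv2d", activation_type="swish") are resolved to concrete classes, as in COSYVOICE_SUBSAMPLE_CLASSES[input_layer](...) above. A toy version of the same lookup pattern, with an illustrative registry:

import torch

ACTIVATIONS = {
    "relu": torch.nn.ReLU,
    "gelu": torch.nn.GELU,
    "swish": torch.nn.SiLU,
}

def build_activation(name: str) -> torch.nn.Module:
    # Map a config string to a module class, then instantiate it.
    if name not in ACTIVATIONS:
        raise KeyError(f"unknown activation: {name}")
    return ACTIVATIONS[name]()

print(build_activation("swish")(torch.tensor([-1.0, 0.0, 1.0])))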
 
cosyvoice/utils/common.py DELETED
@@ -1,169 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
- # 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # Modified from ESPnet(https://github.com/espnet/espnet)
16
- """Utility functions for Transformer."""
17
-
18
- import random
19
- from typing import List
20
-
21
- import numpy as np
22
- import torch
23
-
24
- IGNORE_ID = -1
25
-
26
-
27
- def pad_list(xs: List[torch.Tensor], pad_value: int):
28
- """Perform padding for the list of tensors.
29
-
30
- Args:
31
- xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
32
- pad_value (float): Value for padding.
33
-
34
- Returns:
35
- Tensor: Padded tensor (B, Tmax, `*`).
36
-
37
- Examples:
38
- >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
39
- >>> x
40
- [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
41
- >>> pad_list(x, 0)
42
- tensor([[1., 1., 1., 1.],
43
- [1., 1., 0., 0.],
44
- [1., 0., 0., 0.]])
45
-
46
- """
47
- max_len = max([len(item) for item in xs])
48
- batchs = len(xs)
49
- ndim = xs[0].ndim
50
- if ndim == 1:
51
- pad_res = torch.zeros(batchs, max_len, dtype=xs[0].dtype, device=xs[0].device)
52
- elif ndim == 2:
53
- pad_res = torch.zeros(
54
- batchs, max_len, xs[0].shape[1], dtype=xs[0].dtype, device=xs[0].device
55
- )
56
- elif ndim == 3:
57
- pad_res = torch.zeros(
58
- batchs,
59
- max_len,
60
- xs[0].shape[1],
61
- xs[0].shape[2],
62
- dtype=xs[0].dtype,
63
- device=xs[0].device,
64
- )
65
- else:
66
- raise ValueError(f"Unsupported ndim: {ndim}")
67
- pad_res.fill_(pad_value)
68
- for i in range(batchs):
69
- pad_res[i, : len(xs[i])] = xs[i]
70
- return pad_res
71
-
72
-
73
- def th_accuracy(
74
- pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
75
- ) -> torch.Tensor:
76
- """Calculate accuracy.
77
-
78
- Args:
79
- pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
80
- pad_targets (LongTensor): Target label tensors (B, Lmax).
81
- ignore_label (int): Ignore label id.
82
-
83
- Returns:
84
- torch.Tensor: Accuracy value (0.0 - 1.0).
85
-
86
- """
87
- pad_pred = pad_outputs.view(
88
- pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
89
- ).argmax(2)
90
- mask = pad_targets != ignore_label
91
- numerator = torch.sum(
92
- pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
93
- )
94
- denominator = torch.sum(mask)
95
- return (numerator / denominator).detach()
96
-
97
-
98
- def get_padding(kernel_size, dilation=1):
99
- return int((kernel_size * dilation - dilation) / 2)
100
-
101
-
102
- def init_weights(m, mean=0.0, std=0.01):
103
- classname = m.__class__.__name__
104
- if classname.find("Conv") != -1:
105
- m.weight.data.normal_(mean, std)
106
-
107
-
108
- # Repetition Aware Sampling in VALL-E 2
109
- def ras_sampling(
110
- weighted_scores,
111
- decoded_tokens,
112
- sampling,
113
- top_p=0.8,
114
- top_k=25,
115
- win_size=10,
116
- tau_r=0.1,
117
- ):
118
- top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
119
- rep_num = (
120
- (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids)
121
- .sum()
122
- .item()
123
- )
124
- if rep_num >= win_size * tau_r:
125
- top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
126
- return top_ids
127
-
128
-
129
- def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
130
- prob, indices = [], []
131
- cum_prob = 0.0
132
- sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(
133
- descending=True, stable=True
134
- )
135
- for i in range(len(sorted_idx)):
136
- # sampling both top-p and numbers.
137
- if cum_prob < top_p and len(prob) < top_k:
138
- cum_prob += sorted_value[i]
139
- prob.append(sorted_value[i])
140
- indices.append(sorted_idx[i])
141
- else:
142
- break
143
- prob = torch.tensor(prob).to(weighted_scores)
144
- indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
145
- top_ids = indices[prob.multinomial(1, replacement=True)]
146
- return top_ids
147
-
148
-
149
- def random_sampling(weighted_scores, decoded_tokens, sampling):
150
- top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
151
- return top_ids
152
-
153
-
154
- def fade_in_out(fade_in_mel, fade_out_mel, window):
155
- device = fade_in_mel.device
156
- fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
157
- mel_overlap_len = int(window.shape[0] / 2)
158
- fade_in_mel[..., :mel_overlap_len] = (
159
- fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
160
- + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
161
- )
162
- return fade_in_mel.to(device)
163
-
164
-
165
- def set_all_random_seed(seed):
166
- random.seed(seed)
167
- np.random.seed(seed)
168
- torch.manual_seed(seed)
169
- torch.cuda.manual_seed_all(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/utils/executor.py DELETED
@@ -1,151 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
- # 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import logging
17
- from contextlib import nullcontext
18
- import os
19
-
20
- import torch
21
- import torch.distributed as dist
22
-
23
- from cosyvoice.utils.train_utils import (
24
- update_parameter_and_lr,
25
- log_per_step,
26
- log_per_save,
27
- batch_forward,
28
- batch_backward,
29
- save_model,
30
- cosyvoice_join,
31
- )
32
-
33
-
34
- class Executor:
35
-
36
- def __init__(self):
37
- self.step = 0
38
- self.epoch = 0
39
- self.rank = int(os.environ.get("RANK", 0))
40
- self.device = torch.device("cuda:{}".format(self.rank))
41
-
42
- def train_one_epoc(
43
- self,
44
- model,
45
- optimizer,
46
- scheduler,
47
- train_data_loader,
48
- cv_data_loader,
49
- writer,
50
- info_dict,
51
- group_join,
52
- ):
53
- """Train one epoch"""
54
-
55
- lr = optimizer.param_groups[0]["lr"]
56
- logging.info(
57
- "Epoch {} TRAIN info lr {} rank {}".format(self.epoch, lr, self.rank)
58
- )
59
- logging.info(
60
- "using accumulate grad, new batch size is {} times"
61
- " larger than before".format(info_dict["accum_grad"])
62
- )
63
- # A context manager to be used in conjunction with an instance of
64
- # torch.nn.parallel.DistributedDataParallel to be able to train
65
- # with uneven inputs across participating processes.
66
- model.train()
67
- model_context = (
68
- model.join if info_dict["train_engine"] == "torch_ddp" else nullcontext
69
- )
70
- with model_context():
71
- for batch_idx, batch_dict in enumerate(train_data_loader):
72
- info_dict["tag"] = "TRAIN"
73
- info_dict["step"] = self.step
74
- info_dict["epoch"] = self.epoch
75
- info_dict["batch_idx"] = batch_idx
76
- if cosyvoice_join(group_join, info_dict):
77
- break
78
-
79
- # Disable gradient synchronizations across DDP processes.
80
- # Within this context, gradients will be accumulated on module
81
- # variables, which will later be synchronized.
82
- if (
83
- info_dict["train_engine"] == "torch_ddp"
84
- and (batch_idx + 1) % info_dict["accum_grad"] != 0
85
- ):
86
- context = model.no_sync
87
- # Used for single gpu training and DDP gradient synchronization
88
- # processes.
89
- else:
90
- context = nullcontext
91
-
92
- with context():
93
- info_dict = batch_forward(model, batch_dict, info_dict)
94
- info_dict = batch_backward(model, info_dict)
95
-
96
- info_dict = update_parameter_and_lr(
97
- model, optimizer, scheduler, info_dict
98
- )
99
- log_per_step(writer, info_dict)
100
- # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
101
- if (
102
- info_dict["save_per_step"] > 0
103
- and (self.step + 1) % info_dict["save_per_step"] == 0
104
- and (batch_idx + 1) % info_dict["accum_grad"] == 0
105
- ):
106
- dist.barrier()
107
- self.cv(
108
- model, cv_data_loader, writer, info_dict, on_batch_end=False
109
- )
110
- model.train()
111
- if (batch_idx + 1) % info_dict["accum_grad"] == 0:
112
- self.step += 1
113
- dist.barrier()
114
- self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
115
-
116
- @torch.inference_mode()
117
- def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
118
- """Cross validation on"""
119
- logging.info(
120
- "Epoch {} Step {} on_batch_end {} CV rank {}".format(
121
- self.epoch, self.step + 1, on_batch_end, self.rank
122
- )
123
- )
124
- model.eval()
125
- total_num_utts, total_loss_dict = 0, {} # avoid division by 0
126
- for batch_idx, batch_dict in enumerate(cv_data_loader):
127
- info_dict["tag"] = "CV"
128
- info_dict["step"] = self.step
129
- info_dict["epoch"] = self.epoch
130
- info_dict["batch_idx"] = batch_idx
131
-
132
- num_utts = len(batch_dict["utts"])
133
- total_num_utts += num_utts
134
-
135
- info_dict = batch_forward(model, batch_dict, info_dict)
136
-
137
- for k, v in info_dict["loss_dict"].items():
138
- if k not in total_loss_dict:
139
- total_loss_dict[k] = []
140
- total_loss_dict[k].append(v.item() * num_utts)
141
- log_per_step(None, info_dict)
142
- for k, v in total_loss_dict.items():
143
- total_loss_dict[k] = sum(v) / total_num_utts
144
- info_dict["loss_dict"] = total_loss_dict
145
- log_per_save(writer, info_dict)
146
- model_name = (
147
- "epoch_{}_whole".format(self.epoch)
148
- if on_batch_end
149
- else "epoch_{}_step_{}".format(self.epoch, self.step + 1)
150
- )
151
- save_model(model, model_name, info_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/utils/file_utils.py DELETED
@@ -1,49 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
- # 2024 Alibaba Inc (authors: Xiang Lyu)
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import json
17
- import torchaudio
18
- import logging
19
-
20
- logging.getLogger("matplotlib").setLevel(logging.WARNING)
21
- logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
22
-
23
-
24
- def read_lists(list_file):
25
- lists = []
26
- with open(list_file, "r", encoding="utf8") as fin:
27
- for line in fin:
28
- lists.append(line.strip())
29
- return lists
30
-
31
-
32
- def read_json_lists(list_file):
33
- lists = read_lists(list_file)
34
- results = {}
35
- for fn in lists:
36
- with open(fn, "r", encoding="utf8") as fin:
37
- results.update(json.load(fin))
38
- return results
39
-
40
-
41
- def load_wav(wav, target_sr):
42
- speech, sample_rate = torchaudio.load(wav)
43
- speech = speech.mean(dim=0, keepdim=True)
44
- if sample_rate != target_sr:
45
- # assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
46
- speech = torchaudio.transforms.Resample(
47
- orig_freq=sample_rate, new_freq=target_sr
48
- )(speech)
49
- return speech
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/utils/frontend_utils.py DELETED
@@ -1,142 +0,0 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import re
16
-
17
- chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
18
-
19
-
20
- # whether contain chinese character
21
- def contains_chinese(text):
22
- return bool(chinese_char_pattern.search(text))
23
-
24
-
25
- # replace special symbol
26
- def replace_corner_mark(text):
27
- text = text.replace("²", "平方")
28
- text = text.replace("³", "立方")
29
- return text
30
-
31
-
32
- # remove meaningless symbol
33
- def remove_bracket(text):
34
- text = text.replace("(", "").replace(")", "")
35
- text = text.replace("【", "").replace("】", "")
36
- text = text.replace("`", "").replace("`", "")
37
- text = text.replace("——", " ")
38
- return text
39
-
40
-
41
- # spell Arabic numerals
42
- def spell_out_number(text: str, inflect_parser):
43
- new_text = []
44
- st = None
45
- for i, c in enumerate(text):
46
- if not c.isdigit():
47
- if st is not None:
48
- num_str = inflect_parser.number_to_words(text[st:i])
49
- new_text.append(num_str)
50
- st = None
51
- new_text.append(c)
52
- else:
53
- if st is None:
54
- st = i
55
- if st is not None and st < len(text):
56
- num_str = inflect_parser.number_to_words(text[st:])
57
- new_text.append(num_str)
58
- return "".join(new_text)
59
-
60
-
61
- # split paragrah logic:
62
- # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
63
- # 2. cal sentence len according to lang
64
- # 3. split sentence according to puncatation
65
- def split_paragraph(
66
- text: str,
67
- tokenize,
68
- lang="zh",
69
- token_max_n=80,
70
- token_min_n=60,
71
- merge_len=20,
72
- comma_split=False,
73
- ):
74
- def calc_utt_length(_text: str):
75
- if lang == "zh":
76
- return len(_text)
77
- else:
78
- return len(tokenize(_text))
79
-
80
- def should_merge(_text: str):
81
- if lang == "zh":
82
- return len(_text) < merge_len
83
- else:
84
- return len(tokenize(_text)) < merge_len
85
-
86
- if lang == "zh":
87
- pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
88
- else:
89
- pounc = [".", "?", "!", ";", ":"]
90
- if comma_split:
91
- pounc.extend([",", ","])
92
-
93
- if text[-1] not in pounc:
94
- if lang == "zh":
95
- text += "。"
96
- else:
97
- text += "."
98
-
99
- st = 0
100
- utts = []
101
- for i, c in enumerate(text):
102
- if c in pounc:
103
- if len(text[st:i]) > 0:
104
- utts.append(text[st:i] + c)
105
- if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
106
- tmp = utts.pop(-1)
107
- utts.append(tmp + text[i + 1])
108
- st = i + 2
109
- else:
110
- st = i + 1
111
-
112
- final_utts = []
113
- cur_utt = ""
114
- for utt in utts:
115
- if (
116
- calc_utt_length(cur_utt + utt) > token_max_n
117
- and calc_utt_length(cur_utt) > token_min_n
118
- ):
119
- final_utts.append(cur_utt)
120
- cur_utt = ""
121
- cur_utt = cur_utt + utt
122
- if len(cur_utt) > 0:
123
- if should_merge(cur_utt) and len(final_utts) != 0:
124
- final_utts[-1] = final_utts[-1] + cur_utt
125
- else:
126
- final_utts.append(cur_utt)
127
-
128
- return final_utts
129
-
130
-
131
- # remove blank between chinese character
132
- def replace_blank(text: str):
133
- out_str = []
134
- for i, c in enumerate(text):
135
- if c == " ":
136
- if (text[i + 1].isascii() and text[i + 1] != " ") and (
137
- text[i - 1].isascii() and text[i - 1] != " "
138
- ):
139
- out_str.append(c)
140
- else:
141
- out_str.append(c)
142
- return "".join(out_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/utils/mask.py DELETED
@@ -1,226 +0,0 @@
1
- # Copyright (c) 2019 Shigeki Karita
2
- # 2020 Mobvoi Inc (Binbin Zhang)
3
- # 2024 Alibaba Inc (authors: Xiang Lyu)
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import torch
18
-
19
- '''
20
- def subsequent_mask(
21
- size: int,
22
- device: torch.device = torch.device("cpu"),
23
- ) -> torch.Tensor:
24
- """Create mask for subsequent steps (size, size).
25
-
26
- This mask is used only in decoder which works in an auto-regressive mode.
27
- This means the current step could only do attention with its left steps.
28
-
29
- In encoder, fully attention is used when streaming is not necessary and
30
- the sequence is not long. In this case, no attention mask is needed.
31
-
32
- When streaming is need, chunk-based attention is used in encoder. See
33
- subsequent_chunk_mask for the chunk-based attention mask.
34
-
35
- Args:
36
- size (int): size of mask
37
- str device (str): "cpu" or "cuda" or torch.Tensor.device
38
- dtype (torch.device): result dtype
39
-
40
- Returns:
41
- torch.Tensor: mask
42
-
43
- Examples:
44
- >>> subsequent_mask(3)
45
- [[1, 0, 0],
46
- [1, 1, 0],
47
- [1, 1, 1]]
48
- """
49
- ret = torch.ones(size, size, device=device, dtype=torch.bool)
50
- return torch.tril(ret)
51
- '''
52
-
53
-
54
- def subsequent_mask(
55
- size: int,
56
- device: torch.device = torch.device("cpu"),
57
- ) -> torch.Tensor:
58
- """Create mask for subsequent steps (size, size).
59
-
60
- This mask is used only in decoder which works in an auto-regressive mode.
61
- This means the current step could only do attention with its left steps.
62
-
63
- In encoder, fully attention is used when streaming is not necessary and
64
- the sequence is not long. In this case, no attention mask is needed.
65
-
66
- When streaming is need, chunk-based attention is used in encoder. See
67
- subsequent_chunk_mask for the chunk-based attention mask.
68
-
69
- Args:
70
- size (int): size of mask
71
- str device (str): "cpu" or "cuda" or torch.Tensor.device
72
- dtype (torch.device): result dtype
73
-
74
- Returns:
75
- torch.Tensor: mask
76
-
77
- Examples:
78
- >>> subsequent_mask(3)
79
- [[1, 0, 0],
80
- [1, 1, 0],
81
- [1, 1, 1]]
82
- """
83
- arange = torch.arange(size, device=device)
84
- mask = arange.expand(size, size)
85
- arange = arange.unsqueeze(-1)
86
- mask = mask <= arange
87
- return mask
88
-
89
-
90
- def subsequent_chunk_mask(
91
- size: int,
92
- chunk_size: int,
93
- num_left_chunks: int = -1,
94
- device: torch.device = torch.device("cpu"),
95
- ) -> torch.Tensor:
96
- """Create mask for subsequent steps (size, size) with chunk size,
97
- this is for streaming encoder
98
-
99
- Args:
100
- size (int): size of mask
101
- chunk_size (int): size of chunk
102
- num_left_chunks (int): number of left chunks
103
- <0: use full chunk
104
- >=0: use num_left_chunks
105
- device (torch.device): "cpu" or "cuda" or torch.Tensor.device
106
-
107
- Returns:
108
- torch.Tensor: mask
109
-
110
- Examples:
111
- >>> subsequent_chunk_mask(4, 2)
112
- [[1, 1, 0, 0],
113
- [1, 1, 0, 0],
114
- [1, 1, 1, 1],
115
- [1, 1, 1, 1]]
116
- """
117
- ret = torch.zeros(size, size, device=device, dtype=torch.bool)
118
- for i in range(size):
119
- if num_left_chunks < 0:
120
- start = 0
121
- else:
122
- start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
123
- ending = min((i // chunk_size + 1) * chunk_size, size)
124
- ret[i, start:ending] = True
125
- return ret
126
-
127
-
128
- def add_optional_chunk_mask(
129
- xs: torch.Tensor,
130
- masks: torch.Tensor,
131
- use_dynamic_chunk: bool,
132
- use_dynamic_left_chunk: bool,
133
- decoding_chunk_size: int,
134
- static_chunk_size: int,
135
- num_decoding_left_chunks: int,
136
- enable_full_context: bool = True,
137
- ):
138
- """Apply optional mask for encoder.
139
-
140
- Args:
141
- xs (torch.Tensor): padded input, (B, L, D), L for max length
142
- mask (torch.Tensor): mask for xs, (B, 1, L)
143
- use_dynamic_chunk (bool): whether to use dynamic chunk or not
144
- use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
145
- training.
146
- decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
147
- 0: default for training, use random dynamic chunk.
148
- <0: for decoding, use full chunk.
149
- >0: for decoding, use fixed chunk size as set.
150
- static_chunk_size (int): chunk size for static chunk training/decoding
151
- if it's greater than 0, if use_dynamic_chunk is true,
152
- this parameter will be ignored
153
- num_decoding_left_chunks: number of left chunks, this is for decoding,
154
- the chunk size is decoding_chunk_size.
155
- >=0: use num_decoding_left_chunks
156
- <0: use all left chunks
157
- enable_full_context (bool):
158
- True: chunk size is either [1, 25] or full context(max_len)
159
- False: chunk size ~ U[1, 25]
160
-
161
- Returns:
162
- torch.Tensor: chunk mask of the input xs.
163
- """
164
- # Whether to use chunk mask or not
165
- if use_dynamic_chunk:
166
- max_len = xs.size(1)
167
- if decoding_chunk_size < 0:
168
- chunk_size = max_len
169
- num_left_chunks = -1
170
- elif decoding_chunk_size > 0:
171
- chunk_size = decoding_chunk_size
172
- num_left_chunks = num_decoding_left_chunks
173
- else:
174
- # chunk size is either [1, 25] or full context(max_len).
175
- # Since we use 4 times subsampling and allow up to 1s(100 frames)
176
- # delay, the maximum frame is 100 / 4 = 25.
177
- chunk_size = torch.randint(1, max_len, (1,)).item()
178
- num_left_chunks = -1
179
- if chunk_size > max_len // 2 and enable_full_context:
180
- chunk_size = max_len
181
- else:
182
- chunk_size = chunk_size % 25 + 1
183
- if use_dynamic_left_chunk:
184
- max_left_chunks = (max_len - 1) // chunk_size
185
- num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
186
- chunk_masks = subsequent_chunk_mask(
187
- xs.size(1), chunk_size, num_left_chunks, xs.device
188
- ) # (L, L)
189
- chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
190
- chunk_masks = masks & chunk_masks # (B, L, L)
191
- elif static_chunk_size > 0:
192
- num_left_chunks = num_decoding_left_chunks
193
- chunk_masks = subsequent_chunk_mask(
194
- xs.size(1), static_chunk_size, num_left_chunks, xs.device
195
- ) # (L, L)
196
- chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
197
- chunk_masks = masks & chunk_masks # (B, L, L)
198
- else:
199
- chunk_masks = masks
200
- return chunk_masks
201
-
202
-
203
- def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
204
- """Make mask tensor containing indices of padded part.
205
-
206
- See description of make_non_pad_mask.
207
-
208
- Args:
209
- lengths (torch.Tensor): Batch of lengths (B,).
210
- Returns:
211
- torch.Tensor: Mask tensor containing indices of padded part.
212
-
213
- Examples:
214
- >>> lengths = [5, 3, 2]
215
- >>> make_pad_mask(lengths)
216
- masks = [[0, 0, 0, 0 ,0],
217
- [0, 0, 0, 1, 1],
218
- [0, 0, 1, 1, 1]]
219
- """
220
- batch_size = lengths.size(0)
221
- max_len = max_len if max_len > 0 else lengths.max().item()
222
- seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
223
- seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
224
- seq_length_expand = lengths.unsqueeze(-1)
225
- mask = seq_range_expand >= seq_length_expand
226
- return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/utils/scheduler.py DELETED
@@ -1,761 +0,0 @@
1
- # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
- # 2022 Ximalaya Inc (Yuguang Yang)
3
- # 2024 Alibaba Inc (authors: Xiang Lyu)
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- # Modified from ESPnet(https://github.com/espnet/espnet)
17
- # NeMo(https://github.com/NVIDIA/NeMo)
18
-
19
- from typing import Union
20
-
21
- import math
22
- import warnings
23
- import torch
24
- from torch.optim.lr_scheduler import _LRScheduler
25
-
26
-
27
- class WarmupLR(_LRScheduler):
28
- """The WarmupLR scheduler
29
-
30
- This scheduler is almost same as NoamLR Scheduler except for following
31
- difference:
32
-
33
- NoamLR:
34
- lr = optimizer.lr * model_size ** -0.5
35
- * min(step ** -0.5, step * warmup_step ** -1.5)
36
- WarmupLR:
37
- lr = optimizer.lr * warmup_step ** 0.5
38
- * min(step ** -0.5, step * warmup_step ** -1.5)
39
-
40
- Note that the maximum lr equals to optimizer.lr in this scheduler.
41
-
42
- """
43
-
44
- def __init__(
45
- self,
46
- optimizer: torch.optim.Optimizer,
47
- warmup_steps: Union[int, float] = 25000,
48
- last_epoch: int = -1,
49
- ):
50
- self.warmup_steps = warmup_steps
51
-
52
- # __init__() must be invoked before setting field
53
- # because step() is also invoked in __init__()
54
- super().__init__(optimizer, last_epoch)
55
-
56
- def __repr__(self):
57
- return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
58
-
59
- def get_lr(self):
60
- step_num = self.last_epoch + 1
61
- if self.warmup_steps == 0:
62
- return [lr * step_num**-0.5 for lr in self.base_lrs]
63
- else:
64
- return [
65
- lr
66
- * self.warmup_steps**0.5
67
- * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
68
- for lr in self.base_lrs
69
- ]
70
-
71
- def set_step(self, step: int):
72
- self.last_epoch = step
73
-
74
-
75
- class WarmupPolicy(_LRScheduler):
76
- """Adds warmup kwargs and warmup logic to lr policy.
77
- All arguments should be passed as kwargs for clarity,
78
- Args:
79
- warmup_steps: Number of training steps in warmup stage
80
- warmup_ratio: Ratio of warmup steps to total steps
81
- max_steps: Total number of steps while training or `None` for
82
- infinite training
83
- """
84
-
85
- def __init__(
86
- self,
87
- optimizer,
88
- *,
89
- warmup_steps=None,
90
- warmup_ratio=None,
91
- max_steps=None,
92
- min_lr=0.0,
93
- last_epoch=-1,
94
- ):
95
- assert not (
96
- warmup_steps is not None and warmup_ratio is not None
97
- ), "Either use particular number of step or ratio"
98
- assert (
99
- warmup_ratio is None or max_steps is not None
100
- ), "If there is a ratio, there should be a total steps"
101
-
102
- # It is necessary to assign all attributes *before* __init__,
103
- # as class is wrapped by an inner class.
104
- self.max_steps = max_steps
105
- if warmup_steps is not None:
106
- self.warmup_steps = warmup_steps
107
- elif warmup_ratio is not None:
108
- self.warmup_steps = int(warmup_ratio * max_steps)
109
- else:
110
- self.warmup_steps = 0
111
-
112
- self.min_lr = min_lr
113
- super().__init__(optimizer, last_epoch)
114
-
115
- def get_lr(self):
116
- if not self._get_lr_called_within_step:
117
- warnings.warn(
118
- "To get the last learning rate computed "
119
- "by the scheduler, please use `get_last_lr()`.",
120
- UserWarning,
121
- stacklevel=2,
122
- )
123
-
124
- step = self.last_epoch
125
-
126
- if step <= self.warmup_steps and self.warmup_steps > 0:
127
- return self._get_warmup_lr(step)
128
-
129
- if step > self.max_steps:
130
- return [self.min_lr for _ in self.base_lrs]
131
-
132
- return self._get_lr(step)
133
-
134
- def _get_warmup_lr(self, step):
135
- lr_val = (step + 1) / (self.warmup_steps + 1)
136
- return [initial_lr * lr_val for initial_lr in self.base_lrs]
137
-
138
- def _get_lr(self, step):
139
- """Simple const lr policy"""
140
- return self.base_lrs
141
-
142
-
143
- class SquareRootConstantPolicy(_LRScheduler):
144
- """Adds warmup kwargs and warmup logic to lr policy.
145
- All arguments should be passed as kwargs for clarity,
146
- Args:
147
- warmup_steps: Number of training steps in warmup stage
148
- warmup_ratio: Ratio of warmup steps to total steps
149
- max_steps: Total number of steps while training or `None` for
150
- infinite training
151
- """
152
-
153
- def __init__(
154
- self,
155
- optimizer,
156
- *,
157
- constant_steps=None,
158
- constant_ratio=None,
159
- max_steps=None,
160
- min_lr=0.0,
161
- last_epoch=-1,
162
- ):
163
- assert not (
164
- constant_steps is not None and constant_ratio is not None
165
- ), "Either use particular number of step or ratio"
166
- assert (
167
- constant_ratio is None or max_steps is not None
168
- ), "If there is a ratio, there should be a total steps"
169
-
170
- # It is necessary to assign all attributes *before* __init__,
171
- # as class is wrapped by an inner class.
172
- self.max_steps = max_steps
173
- if constant_steps is not None:
174
- self.constant_steps = constant_steps
175
- elif constant_ratio is not None:
176
- self.constant_steps = int(constant_ratio * max_steps)
177
- else:
178
- self.constant_steps = 0
179
-
180
- self.constant_lr = 1 / (constant_steps**0.5)
181
- self.min_lr = min_lr
182
- super().__init__(optimizer, last_epoch)
183
-
184
- def get_lr(self):
185
- if not self._get_lr_called_within_step:
186
- warnings.warn(
187
- "To get the last learning rate computed "
188
- "by the scheduler, please use `get_last_lr()`.",
189
- UserWarning,
190
- stacklevel=2,
191
- )
192
-
193
- step = self.last_epoch
194
-
195
- if step <= self.constant_steps:
196
- return [self.constant_lr for _ in self.base_lrs]
197
-
198
- if step > self.max_steps:
199
- return [self.min_lr for _ in self.base_lrs]
200
-
201
- return self._get_lr(step)
202
-
203
- def _get_lr(self, step):
204
- """Simple const lr policy"""
205
- return self.base_lrs
206
-
207
-
208
- class WarmupHoldPolicy(WarmupPolicy):
209
- """Variant of WarmupPolicy which maintains high
210
- learning rate for a defined number of steps.
211
- All arguments should be passed as kwargs for clarity,
212
- Args:
213
- warmup_steps: Number of training steps in warmup stage
214
- warmup_ratio: Ratio of warmup steps to total steps
215
- hold_steps: Number of training steps to
216
- hold the learning rate after warm up
217
- hold_ratio: Ratio of hold steps to total steps
218
- max_steps: Total number of steps while training or `None` for
219
- infinite training
220
- """
221
-
222
- def __init__(
223
- self,
224
- optimizer,
225
- *,
226
- warmup_steps=None,
227
- warmup_ratio=None,
228
- hold_steps=None,
229
- hold_ratio=None,
230
- max_steps=None,
231
- min_lr=0.0,
232
- last_epoch=-1,
233
- ):
234
- assert not (
235
- hold_steps is not None and hold_ratio is not None
236
- ), "Either use particular number of step or ratio"
237
- assert (
238
- hold_ratio is None or max_steps is not None
239
- ), "If there is a ratio, there should be a total steps"
240
-
241
- self.min_lr = min_lr
242
- self._last_warmup_lr = 0.0
243
-
244
- # Necessary to duplicate as class attributes are hidden in inner class
245
- self.max_steps = max_steps
246
- if warmup_steps is not None:
247
- self.warmup_steps = warmup_steps
248
- elif warmup_ratio is not None:
249
- self.warmup_steps = int(warmup_ratio * max_steps)
250
- else:
251
- self.warmup_steps = 0
252
-
253
- if hold_steps is not None:
254
- self.hold_steps = hold_steps + self.warmup_steps
255
- elif hold_ratio is not None:
256
- self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
257
- else:
258
- self.hold_steps = 0
259
-
260
- super().__init__(
261
- optimizer,
262
- warmup_steps=warmup_steps,
263
- warmup_ratio=warmup_ratio,
264
- max_steps=max_steps,
265
- last_epoch=last_epoch,
266
- min_lr=min_lr,
267
- )
268
-
269
- def get_lr(self):
270
- if not self._get_lr_called_within_step:
271
- warnings.warn(
272
- "To get the last learning rate computed by the scheduler,"
273
- " "
274
- "please use `get_last_lr()`.",
275
- UserWarning,
276
- stacklevel=2,
277
- )
278
-
279
- step = self.last_epoch
280
-
281
- # Warmup phase
282
- if step <= self.warmup_steps and self.warmup_steps > 0:
283
- return self._get_warmup_lr(step)
284
-
285
- # Hold phase
286
- if (step >= self.warmup_steps) and (step < self.hold_steps):
287
- return self.base_lrs
288
-
289
- if step > self.max_steps:
290
- return [self.min_lr for _ in self.base_lrs]
291
-
292
- return self._get_lr(step)
293
-
294
-
295
- class WarmupAnnealHoldPolicy(_LRScheduler):
296
- """Adds warmup kwargs and warmup logic to lr policy.
297
- All arguments should be passed as kwargs for clarity,
298
- Args:
299
- warmup_steps: Number of training steps in warmup stage
300
- warmup_ratio: Ratio of warmup steps to total steps
301
- max_steps: Total number of steps while training or `None` for
302
- infinite training
303
- min_lr: Minimum lr to hold the learning rate after decay at.
304
- constant_steps: Number of steps to keep lr constant at.
305
- constant_ratio: Ratio of steps to keep lr constant.
306
- """
307
-
308
- def __init__(
309
- self,
310
- optimizer,
311
- *,
312
- warmup_steps=None,
313
- warmup_ratio=None,
314
- constant_steps=None,
315
- constant_ratio=None,
316
- max_steps=None,
317
- min_lr=0.0,
318
- last_epoch=-1,
319
- ):
320
- assert not (
321
- warmup_steps is not None and warmup_ratio is not None
322
- ), "Either use particular number of step or ratio"
323
- assert not (
324
- constant_steps is not None and constant_ratio is not None
325
- ), "Either use constant_steps or constant_ratio"
326
- assert (
327
- warmup_ratio is None or max_steps is not None
328
- ), "If there is a ratio, there should be a total steps"
329
-
330
- # It is necessary to assign all attributes *before* __init__,
331
- # as class is wrapped by an inner class.
332
- self.max_steps = max_steps
333
-
334
- if warmup_steps is not None:
335
- self.warmup_steps = warmup_steps
336
- elif warmup_ratio is not None:
337
- self.warmup_steps = int(warmup_ratio * max_steps)
338
- else:
339
- self.warmup_steps = 0
340
-
341
- if constant_steps is not None:
342
- self.constant_steps = constant_steps
343
- elif constant_ratio is not None:
344
- self.constant_steps = int(constant_ratio * max_steps)
345
- else:
346
- self.constant_steps = 0
347
-
348
- self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
349
-
350
- self.min_lr = min_lr
351
- super().__init__(optimizer, last_epoch)
352
-
353
- def get_lr(self):
354
- if not self._get_lr_called_within_step:
355
- warnings.warn(
356
- "To get the last learning rate computed "
357
- "by the scheduler, please use `get_last_lr()`.",
358
- UserWarning,
359
- stacklevel=2,
360
- )
361
-
362
- step = self.last_epoch
363
-
364
- # Warmup steps
365
- if self.warmup_steps > 0 and step <= self.warmup_steps:
366
- return self._get_warmup_lr(step)
367
-
368
- # Constant steps after warmup and decay
369
- if (
370
- self.constant_steps > 0
371
- and (self.warmup_steps + self.decay_steps) < step <= self.max_steps
372
- ):
373
- return self._get_constant_lr(step)
374
-
375
- # Min lr after max steps of updates
376
- if step > self.max_steps:
377
- return [self.min_lr for _ in self.base_lrs]
378
-
379
- return self._get_lr(step)
380
-
381
- def _get_warmup_lr(self, step):
382
- lr_val = (step + 1) / (self.warmup_steps + 1)
383
- return [initial_lr * lr_val for initial_lr in self.base_lrs]
384
-
385
- def _get_constant_lr(self, step):
386
- return [self.min_lr for _ in self.base_lrs]
387
-
388
- def _get_lr(self, step):
389
- """Simple const lr policy"""
390
- return self.base_lrs
391
-
392
-
393
- def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
394
- mult = ((max_steps - step) / max_steps) ** 0.5
395
- out_lr = initial_lr * mult
396
- out_lr = max(out_lr, min_lr)
397
- return out_lr
398
-
399
-
400
- def _square_annealing(initial_lr, step, max_steps, min_lr):
401
- mult = ((max_steps - step) / max_steps) ** 2
402
- out_lr = initial_lr * mult
403
- out_lr = max(out_lr, min_lr)
404
- return out_lr
405
-
406
-
407
- def _cosine_annealing(initial_lr, step, max_steps, min_lr):
408
- mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
409
- out_lr = (initial_lr - min_lr) * mult + min_lr
410
- return out_lr
411
-
412
-
413
- def _linear_warmup_with_cosine_annealing(
414
- max_lr, warmup_steps, step, decay_steps, min_lr
415
- ):
416
- assert max_lr > min_lr
417
- # Use linear warmup for the initial part.
418
- if warmup_steps > 0 and step <= warmup_steps:
419
- return max_lr * float(step) / float(warmup_steps)
420
-
421
- # For any steps larger than `decay_steps`, use `min_lr`.
422
- if step > warmup_steps + decay_steps:
423
- return min_lr
424
-
425
- # If we are done with the warmup period, use the decay style.
426
- num_steps_ = step - warmup_steps
427
- decay_steps_ = decay_steps
428
- decay_ratio = float(num_steps_) / float(decay_steps_)
429
- assert decay_ratio >= 0.0
430
- assert decay_ratio <= 1.0
431
- delta_lr = max_lr - min_lr
432
-
433
- coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
434
-
435
- return min_lr + coeff * delta_lr
436
-
437
-
438
- def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
439
- if cycle:
440
- multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
441
- decay_steps *= multiplier
442
- else:
443
- step = min(step, decay_steps)
444
- p = step / decay_steps
445
- lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
446
- lr += min_lr
447
- return lr
448
-
449
-
450
- def _noam_hold_annealing(
451
- initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr
452
- ):
453
- # hold_steps = total number of steps
454
- # to hold the LR, not the warmup + hold steps.
455
- T_warmup_decay = max(1, warmup_steps**decay_rate)
456
- T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
457
- lr = (initial_lr * T_warmup_decay) / T_hold_decay
458
- lr = max(lr, min_lr)
459
- return lr
460
-
461
-
462
- class SquareAnnealing(WarmupPolicy):
463
-
464
- def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
465
- super().__init__(
466
- optimizer=optimizer,
467
- max_steps=max_steps,
468
- last_epoch=last_epoch,
469
- min_lr=min_lr,
470
- **kwargs,
471
- )
472
-
473
- def _get_lr(self, step):
474
- new_lrs = [
475
- _square_annealing(
476
- initial_lr=initial_lr,
477
- step=step - self.warmup_steps,
478
- max_steps=self.max_steps - self.warmup_steps,
479
- min_lr=self.min_lr,
480
- )
481
- for initial_lr in self.base_lrs
482
- ]
483
- return new_lrs
484
-
485
-
486
- class SquareRootAnnealing(WarmupPolicy):
487
-
488
- def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
489
- super().__init__(
490
- optimizer=optimizer,
491
- max_steps=max_steps,
492
- last_epoch=last_epoch,
493
- min_lr=min_lr,
494
- **kwargs,
495
- )
496
-
497
- def _get_lr(self, step):
498
- new_lrs = [
499
- _squareroot_annealing(
500
- initial_lr=initial_lr,
501
- step=step,
502
- max_steps=self.max_steps,
503
- min_lr=self.min_lr,
504
- )
505
- for initial_lr in self.base_lrs
506
- ]
507
- return new_lrs
508
-
509
-
510
- class CosineAnnealing(WarmupAnnealHoldPolicy):
511
-
512
- def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
513
- super().__init__(
514
- optimizer=optimizer,
515
- max_steps=max_steps,
516
- last_epoch=last_epoch,
517
- min_lr=min_lr,
518
- **kwargs,
519
- )
520
-
521
- def _get_lr(self, step):
522
- for initial_lr in self.base_lrs:
523
- if initial_lr < self.min_lr:
524
- raise ValueError(
525
- f"{self} received an initial learning rate "
526
- f"that was lower than the minimum learning rate."
527
- )
528
-
529
- if self.constant_steps is None or self.constant_steps == 0:
530
- new_lrs = [
531
- _cosine_annealing(
532
- initial_lr=initial_lr,
533
- step=step - self.warmup_steps,
534
- max_steps=self.max_steps - self.warmup_steps,
535
- min_lr=self.min_lr,
536
- )
537
- for initial_lr in self.base_lrs
538
- ]
539
- else:
540
- new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
541
- return new_lrs
542
-
543
- def _get_warmup_lr(self, step):
544
- if self.constant_steps is None or self.constant_steps == 0:
545
- return super()._get_warmup_lr(step)
546
- else:
547
- # Use linear warmup for the initial part.
548
- return self._get_linear_warmup_with_cosine_annealing_lr(step)
549
-
550
- def _get_constant_lr(self, step):
551
- # Only called when `constant_steps` > 0.
552
- return self._get_linear_warmup_with_cosine_annealing_lr(step)
553
-
554
- def _get_linear_warmup_with_cosine_annealing_lr(self, step):
555
- # Cosine Schedule for Megatron LM,
556
- # slightly different warmup schedule + constant LR at the end.
557
- new_lrs = [
558
- _linear_warmup_with_cosine_annealing(
559
- max_lr=self.base_lrs[0],
560
- warmup_steps=self.warmup_steps,
561
- step=step,
562
- decay_steps=self.decay_steps,
563
- min_lr=self.min_lr,
564
- )
565
- for _ in self.base_lrs
566
- ]
567
- return new_lrs
568
-
569
-
570
- class NoamAnnealing(_LRScheduler):
571
-
572
- def __init__(
573
- self,
574
- optimizer,
575
- *,
576
- d_model,
577
- warmup_steps=None,
578
- warmup_ratio=None,
579
- max_steps=None,
580
- min_lr=0.0,
581
- last_epoch=-1,
582
- ):
583
- self._normalize = d_model ** (-0.5)
584
- assert not (
585
- warmup_steps is not None and warmup_ratio is not None
586
- ), "Either use particular number of step or ratio"
587
- assert (
588
- warmup_ratio is None or max_steps is not None
589
- ), "If there is a ratio, there should be a total steps"
590
-
591
- # It is necessary to assign all attributes *before* __init__,
592
- # as class is wrapped by an inner class.
593
- self.max_steps = max_steps
594
- if warmup_steps is not None:
595
- self.warmup_steps = warmup_steps
596
- elif warmup_ratio is not None:
597
- self.warmup_steps = int(warmup_ratio * max_steps)
598
- else:
599
- self.warmup_steps = 0
600
-
601
- self.min_lr = min_lr
602
- super().__init__(optimizer, last_epoch)
603
-
604
- def get_lr(self):
605
- if not self._get_lr_called_within_step:
606
- warnings.warn(
607
- "To get the last learning rate computed "
608
- "by the scheduler, please use `get_last_lr()`.",
609
- UserWarning,
610
- stacklevel=2,
611
- )
612
-
613
- step = max(1, self.last_epoch)
614
-
615
- for initial_lr in self.base_lrs:
616
- if initial_lr < self.min_lr:
617
- raise ValueError(
618
- f"{self} received an initial learning rate "
619
- f"that was lower than the minimum learning rate."
620
- )
621
-
622
- new_lrs = [
623
- self._noam_annealing(initial_lr=initial_lr, step=step)
624
- for initial_lr in self.base_lrs
625
- ]
626
- return new_lrs
627
-
628
- def _noam_annealing(self, initial_lr, step):
629
- if self.warmup_steps > 0:
630
- mult = self._normalize * min(
631
- step ** (-0.5), step * (self.warmup_steps ** (-1.5))
632
- )
633
- else:
634
- mult = self._normalize * step ** (-0.5)
635
-
636
- out_lr = initial_lr * mult
637
- if step > self.warmup_steps:
638
- out_lr = max(out_lr, self.min_lr)
639
- return out_lr
640
-
641
-
642
- class NoamHoldAnnealing(WarmupHoldPolicy):
643
-
644
- def __init__(
645
- self,
646
- optimizer,
647
- *,
648
- max_steps,
649
- decay_rate=0.5,
650
- min_lr=0.0,
651
- last_epoch=-1,
652
- **kwargs,
653
- ):
654
- """
655
- From Nemo:
656
- Implementation of the Noam Hold Annealing policy
657
- from the SqueezeFormer paper.
658
-
659
- Unlike NoamAnnealing, the peak learning rate
660
- can be explicitly set for this scheduler.
661
- The schedule first performs linear warmup,
662
- then holds the peak LR, then decays with some schedule for
663
- the remainder of the steps.
664
- Therefore the min-lr is still dependent
665
- on the hyper parameters selected.
666
-
667
- It's schedule is determined by three factors-
668
-
669
- Warmup Steps: Initial stage, where linear warmup
670
- occurs uptil the peak LR is reached. Unlike NoamAnnealing,
671
- the peak LR is explicitly stated here instead of a scaling factor.
672
-
673
- Hold Steps: Intermediate stage, where the peak LR
674
- is maintained for some number of steps. In this region,
675
- the high peak LR allows the model to converge faster
676
- if training is stable. However the high LR
677
- may also cause instability during training.
678
- Should usually be a significant fraction of training
679
- steps (around 30-40% of the entire training steps).
680
-
681
- Decay Steps: Final stage, where the LR rapidly decays
682
- with some scaling rate (set by decay rate).
683
- To attain Noam decay, use 0.5,
684
- for Squeezeformer recommended decay, use 1.0.
685
- The fast decay after prolonged high LR during
686
- hold phase allows for rapid convergence.
687
-
688
- References:
689
- - [Squeezeformer:
690
- An Efficient Transformer for Automatic Speech Recognition]
691
- (https://arxiv.org/abs/2206.00888)
692
-
693
- Args:
694
- optimizer: Pytorch compatible Optimizer object.
695
- warmup_steps: Number of training steps in warmup stage
696
- warmup_ratio: Ratio of warmup steps to total steps
697
- hold_steps: Number of training steps to
698
- hold the learning rate after warm up
699
- hold_ratio: Ratio of hold steps to total steps
700
- max_steps: Total number of steps while training or `None` for
701
- infinite training
702
- decay_rate: Float value describing the polynomial decay
703
- after the hold period. Default value
704
- of 0.5 corresponds to Noam decay.
705
- min_lr: Minimum learning rate.
706
- """
707
- self.decay_rate = decay_rate
708
- super().__init__(
709
- optimizer=optimizer,
710
- max_steps=max_steps,
711
- last_epoch=last_epoch,
712
- min_lr=min_lr,
713
- **kwargs,
714
- )
715
-
716
- def _get_lr(self, step):
717
- if self.warmup_steps is None or self.warmup_steps == 0:
718
- raise ValueError("Noam scheduler cannot be used without warmup steps")
719
-
720
- if self.hold_steps > 0:
721
- hold_steps = self.hold_steps - self.warmup_steps
722
- else:
723
- hold_steps = 0
724
-
725
- new_lrs = [
726
- _noam_hold_annealing(
727
- initial_lr,
728
- step=step,
729
- warmup_steps=self.warmup_steps,
730
- hold_steps=hold_steps,
731
- decay_rate=self.decay_rate,
732
- min_lr=self.min_lr,
733
- )
734
- for initial_lr in self.base_lrs
735
- ]
736
- return new_lrs
737
-
738
- def set_step(self, step: int):
739
- self.last_epoch = step
740
-
741
-
742
- class ConstantLR(_LRScheduler):
743
- """The ConstantLR scheduler
744
-
745
- This scheduler keeps a constant lr
746
-
747
- """
748
-
749
- def __init__(
750
- self,
751
- optimizer: torch.optim.Optimizer,
752
- ):
753
- # __init__() must be invoked before setting field
754
- # because step() is also invoked in __init__()
755
- super().__init__(optimizer)
756
-
757
- def get_lr(self):
758
- return self.base_lrs
759
-
760
- def set_step(self, step: int):
761
- self.last_epoch = step
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/utils/train_utils.py DELETED
@@ -1,350 +0,0 @@
1
- # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
- # 2023 Horizon Inc. (authors: Xingchen Song)
3
- # 2024 Alibaba Inc (authors: Xiang Lyu)
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- from contextlib import nullcontext
18
- import logging
19
- import os
20
- import torch
21
- import json
22
- import re
23
- import datetime
24
- import yaml
25
-
26
- import deepspeed
27
- import torch.optim as optim
28
- import torch.distributed as dist
29
-
30
- from torch.utils.tensorboard import SummaryWriter
31
- from torch.utils.data import DataLoader
32
- from torch.nn.utils import clip_grad_norm_
33
-
34
- from deepspeed.runtime.zero.stage_1_and_2 import (
35
- estimate_zero2_model_states_mem_needs_all_live,
36
- )
37
-
38
- from cosyvoice.dataset.dataset import Dataset
39
- from cosyvoice.utils.scheduler import (
40
- WarmupLR,
41
- NoamHoldAnnealing,
42
- ConstantLR,
43
- )
44
-
45
-
46
- def init_distributed(args):
47
- world_size = int(os.environ.get("WORLD_SIZE", 1))
48
- local_rank = int(os.environ.get("LOCAL_RANK", 0))
49
- rank = int(os.environ.get("RANK", 0))
50
- logging.info(
51
- "training on multiple gpus, this gpu {}".format(local_rank)
52
- + ", rank {}, world_size {}".format(rank, world_size)
53
- )
54
- if args.train_engine == "torch_ddp":
55
- torch.cuda.set_device(local_rank)
56
- dist.init_process_group(args.dist_backend)
57
- else:
58
- deepspeed.init_distributed(dist_backend=args.dist_backend)
59
- return world_size, local_rank, rank
60
-
61
-
62
- def init_dataset_and_dataloader(args, configs):
63
- train_dataset = Dataset(
64
- args.train_data,
65
- data_pipeline=configs["data_pipeline"],
66
- mode="train",
67
- shuffle=True,
68
- partition=True,
69
- )
70
- cv_dataset = Dataset(
71
- args.cv_data,
72
- data_pipeline=configs["data_pipeline"],
73
- mode="train",
74
- shuffle=False,
75
- partition=False,
76
- )
77
-
78
- # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
79
- train_data_loader = DataLoader(
80
- train_dataset,
81
- batch_size=None,
82
- pin_memory=args.pin_memory,
83
- num_workers=args.num_workers,
84
- prefetch_factor=args.prefetch,
85
- )
86
- cv_data_loader = DataLoader(
87
- cv_dataset,
88
- batch_size=None,
89
- pin_memory=args.pin_memory,
90
- num_workers=args.num_workers,
91
- prefetch_factor=args.prefetch,
92
- )
93
- return train_dataset, cv_dataset, train_data_loader, cv_data_loader
94
-
95
-
96
- def check_modify_and_save_config(args, configs):
97
- if args.train_engine == "torch_ddp":
98
- configs["train_conf"]["dtype"] = "fp32"
99
- else:
100
- with open(args.deepspeed_config, "r") as fin:
101
- ds_configs = json.load(fin)
102
- if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
103
- configs["train_conf"]["dtype"] = "fp16"
104
- elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
105
- configs["train_conf"]["dtype"] = "bf16"
106
- else:
107
- configs["train_conf"]["dtype"] = "fp32"
108
- assert ds_configs["train_micro_batch_size_per_gpu"] == 1
109
- # if use deepspeed, override ddp config
110
- configs["train_conf"]["save_per_step"] = int(
111
- configs["train_conf"]["save_per_step"]
112
- * configs["train_conf"]["accum_grad"]
113
- / ds_configs["gradient_accumulation_steps"]
114
- )
115
- configs["train_conf"]["accum_grad"] = ds_configs["gradient_accumulation_steps"]
116
- configs["train_conf"]["grad_clip"] = ds_configs["gradient_clipping"]
117
- configs["train_conf"]["log_interval"] = ds_configs["steps_per_print"]
118
- return configs
119
-
120
-
121
- def wrap_cuda_model(args, model):
122
- local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
123
- world_size = int(os.environ.get("WORLD_SIZE", 1))
124
- if args.train_engine == "torch_ddp": # native pytorch ddp
125
- assert torch.cuda.is_available()
126
- model.cuda()
127
- model = torch.nn.parallel.DistributedDataParallel(
128
- model, find_unused_parameters=True
129
- )
130
- else:
131
- if int(os.environ.get("RANK", 0)) == 0:
132
- logging.info("Estimating model states memory needs (zero2)...")
133
- estimate_zero2_model_states_mem_needs_all_live(
134
- model,
135
- num_gpus_per_node=local_world_size,
136
- num_nodes=world_size // local_world_size,
137
- )
138
- return model
139
-
140
-
141
- def init_optimizer_and_scheduler(args, configs, model):
142
- if configs["train_conf"]["optim"] == "adam":
143
- optimizer = optim.Adam(
144
- model.parameters(), **configs["train_conf"]["optim_conf"]
145
- )
146
- elif configs["train_conf"]["optim"] == "adamw":
147
- optimizer = optim.AdamW(
148
- model.parameters(), **configs["train_conf"]["optim_conf"]
149
- )
150
- else:
151
- raise ValueError("unknown optimizer: " + configs["train_conf"])
152
-
153
- if configs["train_conf"]["scheduler"] == "warmuplr":
154
- scheduler_type = WarmupLR
155
- scheduler = WarmupLR(optimizer, **configs["train_conf"]["scheduler_conf"])
156
- elif configs["train_conf"]["scheduler"] == "NoamHoldAnnealing":
157
- scheduler_type = NoamHoldAnnealing
158
- scheduler = NoamHoldAnnealing(
159
- optimizer, **configs["train_conf"]["scheduler_conf"]
160
- )
161
- elif configs["train_conf"]["scheduler"] == "constantlr":
162
- scheduler_type = ConstantLR
163
- scheduler = ConstantLR(optimizer)
164
- else:
165
- raise ValueError("unknown scheduler: " + configs["train_conf"])
166
-
167
- # use deepspeed optimizer for speedup
168
- if args.train_engine == "deepspeed":
169
-
170
- def scheduler(opt):
171
- return scheduler_type(opt, **configs["train_conf"]["scheduler_conf"])
172
-
173
- model, optimizer, _, scheduler = deepspeed.initialize(
174
- args=args,
175
- model=model,
176
- optimizer=None,
177
- lr_scheduler=scheduler,
178
- model_parameters=model.parameters(),
179
- )
180
-
181
- return model, optimizer, scheduler
182
-
183
-
184
- def init_summarywriter(args):
185
- writer = None
186
- if int(os.environ.get("RANK", 0)) == 0:
187
- os.makedirs(args.model_dir, exist_ok=True)
188
- writer = SummaryWriter(args.tensorboard_dir)
189
- return writer
190
-
191
-
192
- def save_model(model, model_name, info_dict):
193
- rank = int(os.environ.get("RANK", 0))
194
- model_dir = info_dict["model_dir"]
195
- save_model_path = os.path.join(model_dir, "{}.pt".format(model_name))
196
-
197
- if info_dict["train_engine"] == "torch_ddp":
198
- if rank == 0:
199
- torch.save(model.module.state_dict(), save_model_path)
200
- else:
201
- with torch.no_grad():
202
- model.save_checkpoint(
203
- save_dir=model_dir, tag=model_name, client_state=info_dict
204
- )
205
- if rank == 0:
206
- info_path = re.sub(".pt$", ".yaml", save_model_path)
207
- info_dict["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
208
- with open(info_path, "w") as fout:
209
- data = yaml.dump(info_dict)
210
- fout.write(data)
211
- logging.info(
212
- "[Rank {}] Checkpoint: save to checkpoint {}".format(rank, save_model_path)
213
- )
214
-
215
-
216
- def cosyvoice_join(group_join, info_dict):
217
- world_size = int(os.environ.get("WORLD_SIZE", 1))
218
- local_rank = int(os.environ.get("LOCAL_RANK", 0))
219
- rank = int(os.environ.get("RANK", 0))
220
-
221
- if info_dict["batch_idx"] != 0:
222
- # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
223
- try:
224
- dist.monitored_barrier(
225
- group=group_join, timeout=group_join.options._timeout
226
- )
227
- return False
228
- except RuntimeError as e:
229
- logging.info(
230
- "Detected uneven workload distribution: {}\n".format(e)
231
- + "Break current worker to manually join all workers, "
232
- + "world_size {}, current rank {}, current local_rank {}\n".format(
233
- world_size, rank, local_rank
234
- )
235
- )
236
- return True
237
- else:
238
- return False
239
-
240
-
241
- def batch_forward(model, batch, info_dict):
242
- device = int(os.environ.get("LOCAL_RANK", 0))
243
-
244
- dtype = info_dict["dtype"]
245
- if dtype == "fp16":
246
- dtype = torch.float16
247
- elif dtype == "bf16":
248
- dtype = torch.bfloat16
249
- else: # fp32
250
- dtype = torch.float32
251
-
252
- if info_dict["train_engine"] == "torch_ddp":
253
- autocast = nullcontext()
254
- else:
255
- autocast = torch.cuda.amp.autocast(
256
- enabled=True, dtype=dtype, cache_enabled=False
257
- )
258
-
259
- with autocast:
260
- info_dict["loss_dict"] = model(batch, device)
261
- return info_dict
262
-
263
-
264
- def batch_backward(model, info_dict):
265
- if info_dict["train_engine"] == "deepspeed":
266
- scaled_loss = model.backward(info_dict["loss_dict"]["loss"])
267
- else:
268
- scaled_loss = info_dict["loss_dict"]["loss"] / info_dict["accum_grad"]
269
- scaled_loss.backward()
270
-
271
- info_dict["loss_dict"]["loss"] = scaled_loss
272
- return info_dict
273
-
274
-
275
- def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
276
- grad_norm = 0.0
277
- if info_dict["train_engine"] == "deepspeed":
278
- info_dict["is_gradient_accumulation_boundary"] = (
279
- model.is_gradient_accumulation_boundary()
280
- )
281
- model.step()
282
- grad_norm = model.get_global_grad_norm()
283
- elif (info_dict["batch_idx"] + 1) % info_dict["accum_grad"] == 0:
284
- grad_norm = clip_grad_norm_(model.parameters(), info_dict["grad_clip"])
285
- if torch.isfinite(grad_norm):
286
- optimizer.step()
287
- optimizer.zero_grad()
288
- scheduler.step()
289
- info_dict["lr"] = optimizer.param_groups[0]["lr"]
290
- info_dict["grad_norm"] = grad_norm
291
- return info_dict
292
-
293
-
294
- def log_per_step(writer, info_dict):
295
- tag = info_dict["tag"]
296
- epoch = info_dict.get("epoch", 0)
297
- step = info_dict["step"]
298
- batch_idx = info_dict["batch_idx"]
299
- loss_dict = info_dict["loss_dict"]
300
- rank = int(os.environ.get("RANK", 0))
301
-
302
- # only rank 0 write to tensorboard to avoid multi-process write
303
- if writer is not None:
304
- if (
305
- info_dict["train_engine"] == "deepspeed"
306
- and info_dict["is_gradient_accumulation_boundary"] is True
307
- ) or (
308
- info_dict["train_engine"] == "torch_ddp"
309
- and (info_dict["batch_idx"] + 1) % info_dict["accum_grad"] == 0
310
- ):
311
- for k in ["epoch", "lr", "grad_norm"]:
312
- writer.add_scalar("{}/{}".format(tag, k), info_dict[k], step + 1)
313
- for k, v in loss_dict.items():
314
- writer.add_scalar("{}/{}".format(tag, k), v, step + 1)
315
-
316
- # TRAIN & CV, Shell log (stdout)
317
- if (info_dict["batch_idx"] + 1) % info_dict["log_interval"] == 0:
318
- log_str = "{} Batch {}/{} ".format(tag, epoch, batch_idx + 1)
319
- for name, value in loss_dict.items():
320
- log_str += "{} {:.6f} ".format(name, value)
321
- if tag == "TRAIN":
322
- log_str += "lr {:.8f} grad_norm {:.6f}".format(
323
- info_dict["lr"], info_dict["grad_norm"]
324
- )
325
- log_str += " rank {}".format(rank)
326
- logging.debug(log_str)
327
-
328
-
329
- def log_per_save(writer, info_dict):
330
- tag = info_dict["tag"]
331
- epoch = info_dict["epoch"]
332
- step = info_dict["step"]
333
- loss_dict = info_dict["loss_dict"]
334
- lr = info_dict["lr"]
335
- rank = int(os.environ.get("RANK", 0))
336
- logging.info(
337
- "Epoch {} Step {} CV info lr {} {} rank {}".format(
338
- epoch,
339
- step + 1,
340
- lr,
341
- rank,
342
- " ".join(["{}_{}".format(k, v) for k, v in loss_dict.items()]),
343
- )
344
- )
345
-
346
- if writer is not None:
347
- for k in ["epoch", "lr"]:
348
- writer.add_scalar("{}/{}".format(tag, k), info_dict[k], step + 1)
349
- for k, v in loss_dict.items():
350
- writer.add_scalar("{}/{}".format(tag, k), v, step + 1)
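For reference, a per-step training loop would wire the helpers deleted above together roughly as follows. This is only a minimal sketch of the intended usage, not the removed executor: it assumes model, optimizer, scheduler, writer, dataloader, and an info_dict carrying the keys used above (train_engine, accum_grad, grad_clip, dtype, log_interval, step) are already set up.

# Minimal sketch (assumed usage of the deleted helpers, not the removed executor)
for batch_idx, batch in enumerate(dataloader):
    info_dict["tag"] = "TRAIN"
    info_dict["batch_idx"] = batch_idx
    info_dict = batch_forward(model, batch, info_dict)    # forward under autocast, fills loss_dict
    info_dict = batch_backward(model, info_dict)          # scaled backward pass
    info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)  # clip, step, zero_grad
    log_per_step(writer, info_dict)                       # tensorboard + stdout logging
    # step/epoch bookkeeping and checkpointing (save_model) were handled by the removed executor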
funasr_detach/__init__.py DELETED
@@ -1,38 +0,0 @@
- """Initialize funasr package."""
-
- import os
- import pkgutil
- import importlib
-
- dirname = os.path.dirname(__file__)
- version_file = os.path.join(dirname, "version.txt")
- with open(version_file, "r") as f:
-     __version__ = f.read().strip()
-
-
- import importlib
- import pkgutil
-
-
- def import_submodules(package, recursive=True):
-     if isinstance(package, str):
-         package = importlib.import_module(package)
-     results = {}
-     for loader, name, is_pkg in pkgutil.walk_packages(
-         package.__path__, package.__name__ + "."
-     ):
-         try:
-             results[name] = importlib.import_module(name)
-         except Exception as e:
-             # Uncomment the line below to see the details of import failures.
-             # print(f"Failed to import {name}: {e}")
-             pass
-         if recursive and is_pkg:
-             results.update(import_submodules(name))
-     return results
-
-
- import_submodules(__name__)
-
- from funasr_detach.auto.auto_model import AutoModel
- from funasr_detach.auto.auto_frontend import AutoFrontend
funasr_detach/auto/__init__.py DELETED
File without changes
funasr_detach/auto/auto_frontend.py DELETED
@@ -1,90 +0,0 @@
1
- import time
2
- import logging
3
- from tqdm import tqdm
4
-
5
- from funasr_detach.register import tables
6
- from funasr_detach.download.download_from_hub import download_model
7
- from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
8
- from funasr_detach.auto.auto_model import prepare_data_iterator
9
- from funasr_detach.auto.auto_model import prepare_data_iterator
10
-
11
-
12
- class AutoFrontend:
13
- def __init__(self, **kwargs):
14
- assert "model" in kwargs
15
- if "model_conf" not in kwargs:
16
- logging.info(
17
- "download models from model hub: {}".format(
18
- kwargs.get("model_hub", "ms")
19
- )
20
- )
21
- kwargs = download_model(**kwargs)
22
-
23
- # build frontend
24
- frontend = kwargs.get("frontend", None)
25
- if frontend is not None:
26
- frontend_class = tables.frontend_classes.get(frontend)
27
- frontend = frontend_class(**kwargs["frontend_conf"])
28
-
29
- self.frontend = frontend
30
- if "frontend" in kwargs:
31
- del kwargs["frontend"]
32
- self.kwargs = kwargs
33
-
34
- def __call__(self, input, input_len=None, kwargs=None, **cfg):
35
-
36
- kwargs = self.kwargs if kwargs is None else kwargs
37
- kwargs.update(cfg)
38
-
39
- key_list, data_list = prepare_data_iterator(input, input_len=input_len)
40
- batch_size = kwargs.get("batch_size", 1)
41
- device = kwargs.get("device", "cpu")
42
- if device == "cpu":
43
- batch_size = 1
44
-
45
- meta_data = {}
46
-
47
- result_list = []
48
- num_samples = len(data_list)
49
- pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
50
-
51
- time0 = time.perf_counter()
52
- for beg_idx in range(0, num_samples, batch_size):
53
- end_idx = min(num_samples, beg_idx + batch_size)
54
- data_batch = data_list[beg_idx:end_idx]
55
- key_batch = key_list[beg_idx:end_idx]
56
-
57
- # extract fbank feats
58
- time1 = time.perf_counter()
59
- audio_sample_list = load_audio_text_image_video(
60
- data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
61
- )
62
- time2 = time.perf_counter()
63
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
64
- speech, speech_lengths = extract_fbank(
65
- audio_sample_list,
66
- data_type=kwargs.get("data_type", "sound"),
67
- frontend=self.frontend,
68
- **kwargs,
69
- )
70
- time3 = time.perf_counter()
71
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
72
- meta_data["batch_data_time"] = (
73
- speech_lengths.sum().item()
74
- * self.frontend.frame_shift
75
- * self.frontend.lfr_n
76
- / 1000
77
- )
78
-
79
- speech.to(device=device), speech_lengths.to(device=device)
80
- batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
81
- result_list.append(batch)
82
-
83
- pbar.update(1)
84
- description = f"{meta_data}, "
85
- pbar.set_description(description)
86
-
87
- time_end = time.perf_counter()
88
- pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
89
-
90
- return result_list
 
funasr_detach/auto/auto_model.py DELETED
@@ -1,573 +0,0 @@
1
- import json
2
- import time
3
- import copy
4
- import torch
5
- import random
6
- import string
7
- import logging
8
- import os.path
9
- import numpy as np
10
- from tqdm import tqdm
11
-
12
- from funasr_detach.register import tables
13
- from funasr_detach.utils.load_utils import load_bytes
14
- from funasr_detach.download.file import download_from_url
15
- from funasr_detach.download.download_from_hub import download_model
16
- from funasr_detach.utils.vad_utils import slice_padding_audio_samples
17
- from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
18
- from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
19
- from funasr_detach.utils.load_utils import load_audio_text_image_video
20
- from funasr_detach.utils.timestamp_tools import timestamp_sentence
21
- from funasr_detach.models.campplus.utils import sv_chunk, postprocess, distribute_spk
22
-
23
- try:
24
- from funasr_detach.models.campplus.cluster_backend import ClusterBackend
25
- except:
26
- print("If you want to use speaker diarization, please `pip install hdbscan`")
27
-
28
-
29
- def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
30
- """
31
-
32
- :param input:
33
- :param input_len:
34
- :param data_type:
35
- :param frontend:
36
- :return:
37
- """
38
- data_list = []
39
- key_list = []
40
- filelist = [".scp", ".txt", ".json", ".jsonl"]
41
-
42
- chars = string.ascii_letters + string.digits
43
- if isinstance(data_in, str) and data_in.startswith("http"): # url
44
- data_in = download_from_url(data_in)
45
- if isinstance(data_in, str) and os.path.exists(
46
- data_in
47
- ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
48
- _, file_extension = os.path.splitext(data_in)
49
- file_extension = file_extension.lower()
50
- if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
51
- with open(data_in, encoding="utf-8") as fin:
52
- for line in fin:
53
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
54
- if data_in.endswith(
55
- ".jsonl"
56
- ): # file.jsonl: json.dumps({"source": data})
57
- lines = json.loads(line.strip())
58
- data = lines["source"]
59
- key = data["key"] if "key" in data else key
60
- else: # filelist, wav.scp, text.txt: id \t data or data
61
- lines = line.strip().split(maxsplit=1)
62
- data = lines[1] if len(lines) > 1 else lines[0]
63
- key = lines[0] if len(lines) > 1 else key
64
-
65
- data_list.append(data)
66
- key_list.append(key)
67
- else:
68
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
69
- data_list = [data_in]
70
- key_list = [key]
71
- elif isinstance(data_in, (list, tuple)):
72
- if data_type is not None and isinstance(
73
- data_type, (list, tuple)
74
- ): # multiple inputs
75
- data_list_tmp = []
76
- for data_in_i, data_type_i in zip(data_in, data_type):
77
- key_list, data_list_i = prepare_data_iterator(
78
- data_in=data_in_i, data_type=data_type_i
79
- )
80
- data_list_tmp.append(data_list_i)
81
- data_list = []
82
- for item in zip(*data_list_tmp):
83
- data_list.append(item)
84
- else:
85
- # [audio sample point, fbank, text]
86
- data_list = data_in
87
- key_list = [
88
- "rand_key_" + "".join(random.choice(chars) for _ in range(13))
89
- for _ in range(len(data_in))
90
- ]
91
- else: # raw text; audio sample point, fbank; bytes
92
- if isinstance(data_in, bytes): # audio bytes
93
- data_in = load_bytes(data_in)
94
- if key is None:
95
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
96
- data_list = [data_in]
97
- key_list = [key]
98
-
99
- return key_list, data_list
100
-
101
-
102
- class AutoModel:
103
-
104
- def __init__(self, **kwargs):
105
- if not kwargs.get("disable_log", False):
106
- tables.print()
107
-
108
- model, kwargs = self.build_model(**kwargs)
109
-
110
- # if vad_model is not None, build vad model else None
111
- vad_model = kwargs.get("vad_model", None)
112
- vad_kwargs = kwargs.get("vad_model_revision", None)
113
- if vad_model is not None:
114
- logging.info("Building VAD model.")
115
- vad_kwargs = {
116
- "model": vad_model,
117
- "model_revision": vad_kwargs,
118
- "device": kwargs["device"],
119
- }
120
- vad_model, vad_kwargs = self.build_model(**vad_kwargs)
121
-
122
- # if punc_model is not None, build punc model else None
123
- punc_model = kwargs.get("punc_model", None)
124
- punc_kwargs = kwargs.get("punc_model_revision", None)
125
- if punc_model is not None:
126
- logging.info("Building punc model.")
127
- punc_kwargs = {
128
- "model": punc_model,
129
- "model_revision": punc_kwargs,
130
- "device": kwargs["device"],
131
- }
132
- punc_model, punc_kwargs = self.build_model(**punc_kwargs)
133
-
134
- # if spk_model is not None, build spk model else None
135
- spk_model = kwargs.get("spk_model", None)
136
- spk_kwargs = kwargs.get("spk_model_revision", None)
137
- if spk_model is not None:
138
- logging.info("Building SPK model.")
139
- spk_kwargs = {
140
- "model": spk_model,
141
- "model_revision": spk_kwargs,
142
- "device": kwargs["device"],
143
- }
144
- spk_model, spk_kwargs = self.build_model(**spk_kwargs)
145
- self.cb_model = ClusterBackend().to(kwargs["device"])
146
- spk_mode = kwargs.get("spk_mode", "punc_segment")
147
- if spk_mode not in ["default", "vad_segment", "punc_segment"]:
148
- logging.error(
149
- "spk_mode should be one of default, vad_segment and punc_segment."
150
- )
151
- self.spk_mode = spk_mode
152
-
153
- self.kwargs = kwargs
154
- self.model = model
155
- self.vad_model = vad_model
156
- self.vad_kwargs = vad_kwargs
157
- self.punc_model = punc_model
158
- self.punc_kwargs = punc_kwargs
159
- self.spk_model = spk_model
160
- self.spk_kwargs = spk_kwargs
161
- self.model_path = kwargs.get("model_path")
162
-
163
- def build_model(self, **kwargs):
164
- assert "model" in kwargs
165
- if "model_conf" not in kwargs:
166
- logging.info(
167
- "download models from model hub: {}".format(
168
- kwargs.get("model_hub", "ms")
169
- )
170
- )
171
- kwargs = download_model(**kwargs)
172
-
173
- set_all_random_seed(kwargs.get("seed", 0))
174
-
175
- device = kwargs.get("device", "cuda")
176
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
177
- device = "cpu"
178
- kwargs["batch_size"] = 1
179
- kwargs["device"] = device
180
-
181
- if kwargs.get("ncpu", None):
182
- torch.set_num_threads(kwargs.get("ncpu"))
183
-
184
- # build tokenizer
185
- tokenizer = kwargs.get("tokenizer", None)
186
- if tokenizer is not None:
187
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
188
- tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
189
- kwargs["tokenizer"] = tokenizer
190
- kwargs["token_list"] = tokenizer.token_list
191
- vocab_size = len(tokenizer.token_list)
192
- else:
193
- vocab_size = -1
194
-
195
- # build frontend
196
- frontend = kwargs.get("frontend", None)
197
- if frontend is not None:
198
- frontend_class = tables.frontend_classes.get(frontend)
199
- frontend = frontend_class(**kwargs["frontend_conf"])
200
- kwargs["frontend"] = frontend
201
- kwargs["input_size"] = frontend.output_size()
202
-
203
- # build model
204
- model_class = tables.model_classes.get(kwargs["model"])
205
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
206
-
207
- model.to(device)
208
-
209
- # init_param
210
- init_param = kwargs.get("init_param", None)
211
- if init_param is not None:
212
- logging.info(f"Loading pretrained params from {init_param}")
213
- load_pretrained_model(
214
- model=model,
215
- path=init_param,
216
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
217
- oss_bucket=kwargs.get("oss_bucket", None),
218
- scope_map=kwargs.get("scope_map", None),
219
- excludes=kwargs.get("excludes", None),
220
- )
221
-
222
- return model, kwargs
223
-
224
- def __call__(self, *args, **cfg):
225
- kwargs = self.kwargs
226
- kwargs.update(cfg)
227
- res = self.model(*args, kwargs)
228
- return res
229
-
230
- def generate(self, input, input_len=None, **cfg):
231
- if self.vad_model is None:
232
- return self.inference(input, input_len=input_len, **cfg)
233
-
234
- else:
235
- return self.inference_with_vad(input, input_len=input_len, **cfg)
236
-
237
- def inference(
238
- self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
239
- ):
240
- kwargs = self.kwargs if kwargs is None else kwargs
241
- kwargs.update(cfg)
242
- model = self.model if model is None else model
243
- model = model.cuda()
244
- model.eval()
245
-
246
- batch_size = kwargs.get("batch_size", 1)
247
- # if kwargs.get("device", "cpu") == "cpu":
248
- # batch_size = 1
249
-
250
- key_list, data_list = prepare_data_iterator(
251
- input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
252
- )
253
-
254
- speed_stats = {}
255
- asr_result_list = []
256
- num_samples = len(data_list)
257
- disable_pbar = kwargs.get("disable_pbar", False)
258
- pbar = (
259
- tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
260
- if not disable_pbar
261
- else None
262
- )
263
- time_speech_total = 0.0
264
- time_escape_total = 0.0
265
- for beg_idx in range(0, num_samples, batch_size):
266
- end_idx = min(num_samples, beg_idx + batch_size)
267
- data_batch = data_list[beg_idx:end_idx]
268
- key_batch = key_list[beg_idx:end_idx]
269
- batch = {"data_in": data_batch, "key": key_batch}
270
- if (end_idx - beg_idx) == 1 and kwargs.get(
271
- "data_type", None
272
- ) == "fbank": # fbank
273
- batch["data_in"] = data_batch[0]
274
- batch["data_lengths"] = input_len
275
-
276
- time1 = time.perf_counter()
277
- with torch.no_grad():
278
- results, meta_data = model.inference(**batch, **kwargs)
279
- time2 = time.perf_counter()
280
-
281
- asr_result_list.extend(results)
282
-
283
- # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
284
- batch_data_time = meta_data.get("batch_data_time", -1)
285
- time_escape = time2 - time1
286
- speed_stats["load_data"] = meta_data.get("load_data", 0.0)
287
- speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
288
- speed_stats["forward"] = f"{time_escape:0.3f}"
289
- speed_stats["batch_size"] = f"{len(results)}"
290
- speed_stats["time_cost"] = f"{(time_escape)}"
291
- speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
292
- description = f"{speed_stats}, "
293
- if pbar:
294
- pbar.update(1)
295
- pbar.set_description(description)
296
- time_speech_total += batch_data_time
297
- time_escape_total += time_escape
298
-
299
- if pbar:
300
- # pbar.update(1)
301
- pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
302
- torch.cuda.empty_cache()
303
- return asr_result_list
304
-
305
- def inference_with_vad(self, input, input_len=None, **cfg):
306
-
307
- # step.1: compute the vad model
308
- self.vad_kwargs.update(cfg)
309
- beg_vad = time.time()
310
- res = self.inference(
311
- input,
312
- input_len=input_len,
313
- model=self.vad_model,
314
- kwargs=self.vad_kwargs,
315
- **cfg,
316
- )
317
- end_vad = time.time()
318
- print(f"time cost vad: {end_vad - beg_vad:0.3f}")
319
-
320
- # step.2 compute asr model
321
- model = self.model
322
- kwargs = self.kwargs
323
- kwargs.update(cfg)
324
- batch_size = int(kwargs.get("batch_size_s", 300)) * 1000
325
- batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
326
- kwargs["batch_size"] = batch_size
327
-
328
- key_list, data_list = prepare_data_iterator(
329
- input, input_len=input_len, data_type=kwargs.get("data_type", None)
330
- )
331
- results_ret_list = []
332
- time_speech_total_all_samples = 1e-6
333
-
334
- beg_total = time.time()
335
- pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
336
- for i in range(len(res)):
337
- key = res[i]["key"]
338
- vadsegments = res[i]["value"]
339
- input_i = data_list[i]
340
- speech = load_audio_text_image_video(
341
- input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000)
342
- )
343
- speech_lengths = len(speech)
344
- n = len(vadsegments)
345
- data_with_index = [(vadsegments[i], i) for i in range(n)]
346
- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
347
- results_sorted = []
348
-
349
- if not len(sorted_data):
350
- logging.info("decoding, utt: {}, empty speech".format(key))
351
- continue
352
-
353
- if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
354
- batch_size = max(
355
- batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
356
- )
357
-
358
- batch_size_ms_cum = 0
359
- beg_idx = 0
360
- beg_asr_total = time.time()
361
- time_speech_total_per_sample = speech_lengths / 16000
362
- time_speech_total_all_samples += time_speech_total_per_sample
363
-
364
- all_segments = []
365
- for j, _ in enumerate(range(0, n)):
366
- # pbar_sample.update(1)
367
- batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
368
- if (
369
- j < n - 1
370
- and (
371
- batch_size_ms_cum
372
- + sorted_data[j + 1][0][1]
373
- - sorted_data[j + 1][0][0]
374
- )
375
- < batch_size
376
- and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
377
- < batch_size_threshold_ms
378
- ):
379
- continue
380
- batch_size_ms_cum = 0
381
- end_idx = j + 1
382
- speech_j, speech_lengths_j = slice_padding_audio_samples(
383
- speech, speech_lengths, sorted_data[beg_idx:end_idx]
384
- )
385
- results = self.inference(
386
- speech_j,
387
- input_len=None,
388
- model=model,
389
- kwargs=kwargs,
390
- disable_pbar=True,
391
- **cfg,
392
- )
393
- if self.spk_model is not None:
394
- # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
395
- for _b in range(len(speech_j)):
396
- vad_segments = [
397
- [
398
- sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
399
- sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
400
- np.array(speech_j[_b]),
401
- ]
402
- ]
403
- segments = sv_chunk(vad_segments)
404
- all_segments.extend(segments)
405
- speech_b = [i[2] for i in segments]
406
- spk_res = self.inference(
407
- speech_b,
408
- input_len=None,
409
- model=self.spk_model,
410
- kwargs=kwargs,
411
- disable_pbar=True,
412
- **cfg,
413
- )
414
- results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
415
- beg_idx = end_idx
416
- if len(results) < 1:
417
- continue
418
- results_sorted.extend(results)
419
-
420
- restored_data = [0] * n
421
- for j in range(n):
422
- index = sorted_data[j][1]
423
- restored_data[index] = results_sorted[j]
424
- result = {}
425
-
426
- # results combine for texts, timestamps, speaker embeddings and others
427
- # TODO: rewrite for clean code
428
- for j in range(n):
429
- for k, v in restored_data[j].items():
430
- if k.startswith("timestamp"):
431
- if k not in result:
432
- result[k] = []
433
- for t in restored_data[j][k]:
434
- t[0] += vadsegments[j][0]
435
- t[1] += vadsegments[j][0]
436
- result[k].extend(restored_data[j][k])
437
- elif k == "spk_embedding":
438
- if k not in result:
439
- result[k] = restored_data[j][k]
440
- else:
441
- result[k] = torch.cat(
442
- [result[k], restored_data[j][k]], dim=0
443
- )
444
- elif "text" in k:
445
- if k not in result:
446
- result[k] = restored_data[j][k]
447
- else:
448
- result[k] += " " + restored_data[j][k]
449
- else:
450
- if k not in result:
451
- result[k] = restored_data[j][k]
452
- else:
453
- result[k] += restored_data[j][k]
454
-
455
- return_raw_text = kwargs.get("return_raw_text", False)
456
- # step.3 compute punc model
457
- if self.punc_model is not None:
458
- self.punc_kwargs.update(cfg)
459
- punc_res = self.inference(
460
- result["text"],
461
- model=self.punc_model,
462
- kwargs=self.punc_kwargs,
463
- disable_pbar=True,
464
- **cfg,
465
- )
466
- raw_text = copy.copy(result["text"])
467
- if return_raw_text:
468
- result["raw_text"] = raw_text
469
- result["text"] = punc_res[0]["text"]
470
- else:
471
- raw_text = None
472
-
473
- # speaker embedding cluster after resorted
474
- if self.spk_model is not None and kwargs.get("return_spk_res", True):
475
- if raw_text is None:
476
- logging.error("Missing punc_model, which is required by spk_model.")
477
- all_segments = sorted(all_segments, key=lambda x: x[0])
478
- spk_embedding = result["spk_embedding"]
479
- labels = self.cb_model(
480
- spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
481
- )
482
- # del result['spk_embedding']
483
- sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
484
- if self.spk_mode == "vad_segment": # recover sentence_list
485
- sentence_list = []
486
- for res, vadsegment in zip(restored_data, vadsegments):
487
- if "timestamp" not in res:
488
- logging.error(
489
- "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
490
- and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
491
- can predict timestamp, and speaker diarization relies on timestamps."
492
- )
493
- sentence_list.append(
494
- {
495
- "start": vadsegment[0],
496
- "end": vadsegment[1],
497
- "sentence": res["text"],
498
- "timestamp": res["timestamp"],
499
- }
500
- )
501
- elif self.spk_mode == "punc_segment":
502
- if "timestamp" not in result:
503
- logging.error(
504
- "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
505
- and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
506
- can predict timestamp, and speaker diarization relies on timestamps."
507
- )
508
- sentence_list = timestamp_sentence(
509
- punc_res[0]["punc_array"],
510
- result["timestamp"],
511
- raw_text,
512
- return_raw_text=return_raw_text,
513
- )
514
- distribute_spk(sentence_list, sv_output)
515
- result["sentence_info"] = sentence_list
516
- elif kwargs.get("sentence_timestamp", False):
517
- sentence_list = timestamp_sentence(
518
- punc_res[0]["punc_array"],
519
- result["timestamp"],
520
- raw_text,
521
- return_raw_text=return_raw_text,
522
- )
523
- result["sentence_info"] = sentence_list
524
- if "spk_embedding" in result:
525
- del result["spk_embedding"]
526
-
527
- result["key"] = key
528
- results_ret_list.append(result)
529
- end_asr_total = time.time()
530
- time_escape_total_per_sample = end_asr_total - beg_asr_total
531
- pbar_total.update(1)
532
- pbar_total.set_description(
533
- f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
534
- f"time_speech: {time_speech_total_per_sample: 0.3f}, "
535
- f"time_escape: {time_escape_total_per_sample:0.3f}"
536
- )
537
-
538
- return results_ret_list
539
-
540
- def infer_encoder(
541
- self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
542
- ):
543
- kwargs = self.kwargs if kwargs is None else kwargs
544
- kwargs.update(cfg)
545
- model = self.model if model is None else model
546
- model = model.cuda()
547
- model.eval()
548
-
549
- batch_size = kwargs.get("batch_size", 1)
550
-
551
- key_list, data_list = prepare_data_iterator(
552
- input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
553
- )
554
-
555
- asr_result_list = []
556
- num_samples = len(data_list)
557
- for beg_idx in range(0, num_samples, batch_size):
558
- end_idx = min(num_samples, beg_idx + batch_size)
559
- data_batch = data_list[beg_idx:end_idx]
560
- key_batch = key_list[beg_idx:end_idx]
561
- batch = {"data_in": data_batch, "key": key_batch}
562
- if (end_idx - beg_idx) == 1 and kwargs.get(
563
- "data_type", None
564
- ) == "fbank": # fbank
565
- batch["data_in"] = data_batch[0]
566
- batch["data_lengths"] = input_len
567
-
568
- with torch.no_grad():
569
- results, meta_data, cache = model.infer_encoder(**batch, **kwargs)
570
- asr_result_list.extend(results)
571
-
572
- torch.cuda.empty_cache()
573
- return asr_result_list, cache
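The removed AutoModel chained ASR with optional VAD, punctuation, and speaker models. A hypothetical sketch of how it was driven; all model identifiers and the audio path below are placeholders:

from funasr_detach.auto.auto_model import AutoModel

model = AutoModel(
    model="some-asr-model",        # placeholder ASR model id or local path
    vad_model="some-vad-model",    # optional: routes generate() through inference_with_vad()
    punc_model="some-punc-model",  # optional: punctuation restoration on the joined text
    device="cuda",
)
res = model.generate(input="example.wav", batch_size_s=300)  # placeholder audio path
print(res[0]["key"], res[0]["text"])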
funasr_detach/auto/auto_tokenizer.py DELETED
@@ -1,7 +0,0 @@
- class AutoTokenizer:
-     """
-     Undo
-     """
-
-     def __init__(self):
-         pass
funasr_detach/bin/__init__.py DELETED
File without changes
funasr_detach/bin/compute_audio_cmvn.py DELETED
@@ -1,152 +0,0 @@
1
- import os
2
- import json
3
- import numpy as np
4
- import torch
5
- import hydra
6
- import logging
7
- from omegaconf import DictConfig, OmegaConf
8
-
9
- from funasr_detach.register import tables
10
- from funasr_detach.download.download_from_hub import download_model
11
- from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
12
-
13
-
14
- @hydra.main(config_name=None, version_base=None)
15
- def main_hydra(kwargs: DictConfig):
16
- if kwargs.get("debug", False):
17
- import pdb
18
-
19
- pdb.set_trace()
20
-
21
- assert "model" in kwargs
22
- if "model_conf" not in kwargs:
23
- logging.info(
24
- "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
25
- )
26
- kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
27
-
28
- main(**kwargs)
29
-
30
-
31
- def main(**kwargs):
32
- print(kwargs)
33
- # set random seed
34
- tables.print()
35
- set_all_random_seed(kwargs.get("seed", 0))
36
- torch.backends.cudnn.enabled = kwargs.get(
37
- "cudnn_enabled", torch.backends.cudnn.enabled
38
- )
39
- torch.backends.cudnn.benchmark = kwargs.get(
40
- "cudnn_benchmark", torch.backends.cudnn.benchmark
41
- )
42
- torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
43
-
44
- tokenizer = kwargs.get("tokenizer", None)
45
-
46
- # build frontend if frontend is not None
47
- frontend = kwargs.get("frontend", None)
48
- if frontend is not None:
49
- frontend_class = tables.frontend_classes.get(frontend)
50
- frontend = frontend_class(**kwargs["frontend_conf"])
51
- kwargs["frontend"] = frontend
52
- kwargs["input_size"] = frontend.output_size()
53
-
54
- # dataset
55
- dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
56
- dataset_train = dataset_class(
57
- kwargs.get("train_data_set_list"),
58
- frontend=frontend,
59
- tokenizer=None,
60
- is_training=False,
61
- **kwargs.get("dataset_conf")
62
- )
63
-
64
- # dataloader
65
- batch_sampler = kwargs["dataset_conf"].get(
66
- "batch_sampler", "DynamicBatchLocalShuffleSampler"
67
- )
68
- batch_sampler_train = None
69
- if batch_sampler is not None:
70
- batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
71
- dataset_conf = kwargs.get("dataset_conf")
72
- dataset_conf["batch_type"] = "example"
73
- dataset_conf["batch_size"] = 1
74
- batch_sampler_train = batch_sampler_class(
75
- dataset_train, is_training=False, **dataset_conf
76
- )
77
-
78
- dataloader_train = torch.utils.data.DataLoader(
79
- dataset_train,
80
- collate_fn=dataset_train.collator,
81
- batch_sampler=batch_sampler_train,
82
- num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
83
- pin_memory=True,
84
- )
85
-
86
- iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))
87
-
88
- total_frames = 0
89
- for batch_idx, batch in enumerate(dataloader_train):
90
- if batch_idx >= iter_stop:
91
- break
92
-
93
- fbank = batch["speech"].numpy()[0, :, :]
94
- if total_frames == 0:
95
- mean_stats = np.sum(fbank, axis=0)
96
- var_stats = np.sum(np.square(fbank), axis=0)
97
- else:
98
- mean_stats += np.sum(fbank, axis=0)
99
- var_stats += np.sum(np.square(fbank), axis=0)
100
- total_frames += fbank.shape[0]
101
-
102
- cmvn_info = {
103
- "mean_stats": list(mean_stats.tolist()),
104
- "var_stats": list(var_stats.tolist()),
105
- "total_frames": total_frames,
106
- }
107
- cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
108
- # import pdb;pdb.set_trace()
109
- with open(cmvn_file, "w") as fout:
110
- fout.write(json.dumps(cmvn_info))
111
-
112
- mean = -1.0 * mean_stats / total_frames
113
- var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
114
- dims = mean.shape[0]
115
- am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
116
- with open(am_mvn, "w") as fout:
117
- fout.write(
118
- "<Nnet>"
119
- + "\n"
120
- + "<Splice> "
121
- + str(dims)
122
- + " "
123
- + str(dims)
124
- + "\n"
125
- + "[ 0 ]"
126
- + "\n"
127
- + "<AddShift> "
128
- + str(dims)
129
- + " "
130
- + str(dims)
131
- + "\n"
132
- )
133
- mean_str = (
134
- str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
135
- )
136
- fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
137
- fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
138
- var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
139
- fout.write("<LearnRateCoef> 0 " + var_str + "\n")
140
- fout.write("</Nnet>" + "\n")
141
-
142
-
143
- """
144
- python funasr/bin/compute_audio_cmvn.py \
145
- --config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
146
- --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
147
- ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
148
- ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
149
- ++dataset_conf.num_workers=0
150
- """
151
- if __name__ == "__main__":
152
- main_hydra()
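The am.mvn file written above encodes an additive shift (the negative feature mean) and a multiplicative rescale (the inverse standard deviation). Applying the stats saved in cmvn.json to a feature matrix reduces to the sketch below, assuming NumPy arrays of matching dimension:

import numpy as np

def apply_cmvn(fbank, mean_stats, var_stats, total_frames):
    # same formulas as the script above: shift by -mean, rescale by 1/std
    mean = mean_stats / total_frames
    std = np.sqrt(var_stats / total_frames - mean * mean)
    return (fbank - mean) / std

# fbank: (frames, dims) array; cmvn_info: dict loaded from the cmvn.json written above
normalized = apply_cmvn(
    fbank,
    np.array(cmvn_info["mean_stats"]),
    np.array(cmvn_info["var_stats"]),
    cmvn_info["total_frames"],
)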
funasr_detach/bin/inference.py DELETED
@@ -1,33 +0,0 @@
- import hydra
- import logging
- from omegaconf import DictConfig, OmegaConf, ListConfig
-
- from funasr_detach.auto.auto_model import AutoModel
-
-
- @hydra.main(config_name=None, version_base=None)
- def main_hydra(cfg: DictConfig):
-     def to_plain_list(cfg_item):
-         if isinstance(cfg_item, ListConfig):
-             return OmegaConf.to_container(cfg_item, resolve=True)
-         elif isinstance(cfg_item, DictConfig):
-             return {k: to_plain_list(v) for k, v in cfg_item.items()}
-         else:
-             return cfg_item
-
-     kwargs = to_plain_list(cfg)
-     log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
-
-     logging.basicConfig(level=log_level)
-
-     if kwargs.get("debug", False):
-         import pdb
-
-         pdb.set_trace()
-     model = AutoModel(**kwargs)
-     res = model.generate(input=kwargs["input"])
-     print(res)
-
-
- if __name__ == "__main__":
-     main_hydra()
funasr_detach/bin/tokenize_text.py DELETED
@@ -1,281 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- from collections import Counter
4
- import logging
5
- from pathlib import Path
6
- import sys
7
- from typing import List
8
- from typing import Optional
9
-
10
-
11
- from funasr_detach.utils.cli_utils import get_commandline_args
12
- from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
13
- from funasr_detach.tokenizer.cleaner import TextCleaner
14
- from funasr_detach.tokenizer.phoneme_tokenizer import g2p_classes
15
- from funasr_detach.utils.types import str2bool
16
- from funasr_detach.utils.types import str_or_none
17
-
18
-
19
- def field2slice(field: Optional[str]) -> slice:
20
- """Convert field string to slice
21
-
22
- Note that field string accepts 1-based integer.
23
-
24
- Examples:
25
- >>> field2slice("1-")
26
- slice(0, None, None)
27
- >>> field2slice("1-3")
28
- slice(0, 3, None)
29
- >>> field2slice("-3")
30
- slice(None, 3, None)
31
- """
32
- field = field.strip()
33
- try:
34
- if "-" in field:
35
- # e.g. "2-" or "2-5" or "-7"
36
- s1, s2 = field.split("-", maxsplit=1)
37
- if s1.strip() == "":
38
- s1 = None
39
- else:
40
- s1 = int(s1)
41
- if s1 == 0:
42
- raise ValueError("1-based string")
43
- if s2.strip() == "":
44
- s2 = None
45
- else:
46
- s2 = int(s2)
47
- else:
48
- # e.g. "2"
49
- s1 = int(field)
50
- s2 = s1 + 1
51
- if s1 == 0:
52
- raise ValueError("must be 1 or more value")
53
- except ValueError:
54
- raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
55
-
56
- if s1 is None:
57
- slic = slice(None, s2)
58
- else:
59
- # -1 because of 1-based integer following "cut" command
60
- # e.g "1-3" -> slice(0, 3)
61
- slic = slice(s1 - 1, s2)
62
- return slic
63
-
64
-
65
- def tokenize(
66
- input: str,
67
- output: str,
68
- field: Optional[str],
69
- delimiter: Optional[str],
70
- token_type: str,
71
- space_symbol: str,
72
- non_linguistic_symbols: Optional[str],
73
- bpemodel: Optional[str],
74
- log_level: str,
75
- write_vocabulary: bool,
76
- vocabulary_size: int,
77
- remove_non_linguistic_symbols: bool,
78
- cutoff: int,
79
- add_symbol: List[str],
80
- cleaner: Optional[str],
81
- g2p: Optional[str],
82
- ):
83
-
84
- logging.basicConfig(
85
- level=log_level,
86
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
87
- )
88
- if input == "-":
89
- fin = sys.stdin
90
- else:
91
- fin = Path(input).open("r", encoding="utf-8")
92
- if output == "-":
93
- fout = sys.stdout
94
- else:
95
- p = Path(output)
96
- p.parent.mkdir(parents=True, exist_ok=True)
97
- fout = p.open("w", encoding="utf-8")
98
-
99
- cleaner = TextCleaner(cleaner)
100
- tokenizer = build_tokenizer(
101
- token_type=token_type,
102
- bpemodel=bpemodel,
103
- delimiter=delimiter,
104
- space_symbol=space_symbol,
105
- non_linguistic_symbols=non_linguistic_symbols,
106
- remove_non_linguistic_symbols=remove_non_linguistic_symbols,
107
- g2p_type=g2p,
108
- )
109
-
110
- counter = Counter()
111
- if field is not None:
112
- field = field2slice(field)
113
-
114
- for line in fin:
115
- line = line.rstrip()
116
- if field is not None:
117
- # e.g. field="2-"
118
- # uttidA hello world!! -> hello world!!
119
- tokens = line.split(delimiter)
120
- tokens = tokens[field]
121
- if delimiter is None:
122
- line = " ".join(tokens)
123
- else:
124
- line = delimiter.join(tokens)
125
-
126
- line = cleaner(line)
127
- tokens = tokenizer.text2tokens(line)
128
- if not write_vocabulary:
129
- fout.write(" ".join(tokens) + "\n")
130
- else:
131
- for t in tokens:
132
- counter[t] += 1
133
-
134
- if not write_vocabulary:
135
- return
136
-
137
- ## FIXME
138
- ## del duplicate add_symbols in counter
139
- for symbol_and_id in add_symbol:
140
- # e.g symbol="<blank>:0"
141
- try:
142
- symbol, idx = symbol_and_id.split(":")
143
- except ValueError:
144
- raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
145
- symbol = symbol.strip()
146
- if symbol in counter:
147
- del counter[symbol]
148
-
149
- # ======= write_vocabulary mode from here =======
150
- # Sort by the number of occurrences in descending order
151
- # and filter lower frequency words than cutoff value
152
- words_and_counts = list(
153
- filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
154
- )
155
- # Restrict the vocabulary size
156
- if vocabulary_size > 0:
157
- if vocabulary_size < len(add_symbol):
158
- raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
159
- words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
160
-
161
- # Parse the values of --add_symbol
162
- for symbol_and_id in add_symbol:
163
- # e.g symbol="<blank>:0"
164
- try:
165
- symbol, idx = symbol_and_id.split(":")
166
- idx = int(idx)
167
- except ValueError:
168
- raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
169
- symbol = symbol.strip()
170
-
171
- # e.g. idx=0 -> append as the first symbol
172
- # e.g. idx=-1 -> append as the last symbol
173
- if idx < 0:
174
- idx = len(words_and_counts) + 1 + idx
175
- words_and_counts.insert(idx, (symbol, None))
176
-
177
- # Write words
178
- for w, c in words_and_counts:
179
- fout.write(w + "\n")
180
-
181
- # Logging
182
- total_count = sum(counter.values())
183
- invocab_count = sum(c for w, c in words_and_counts if c is not None)
184
- logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
185
-
186
-
187
- def get_parser() -> argparse.ArgumentParser:
188
- parser = argparse.ArgumentParser(
189
- description="Tokenize texts",
190
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
191
- )
192
- parser.add_argument(
193
- "--log_level",
194
- type=lambda x: x.upper(),
195
- default="INFO",
196
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
197
- help="The verbose level of logging",
198
- )
199
-
200
- parser.add_argument(
201
- "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
202
- )
203
- parser.add_argument(
204
- "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
205
- )
206
- parser.add_argument(
207
- "--field",
208
- "-f",
209
- help="The target columns of the input text as 1-based integer. e.g 2-",
210
- )
211
- parser.add_argument(
212
- "--token_type",
213
- "-t",
214
- default="char",
215
- choices=["char", "bpe", "word", "phn"],
216
- help="Token type",
217
- )
218
- parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
219
- parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
220
- parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
221
- parser.add_argument(
222
- "--non_linguistic_symbols",
223
- type=str_or_none,
224
- help="non_linguistic_symbols file path",
225
- )
226
- parser.add_argument(
227
- "--remove_non_linguistic_symbols",
228
- type=str2bool,
229
- default=False,
230
- help="Remove non-language-symbols from tokens",
231
- )
232
- parser.add_argument(
233
- "--cleaner",
234
- type=str_or_none,
235
- choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
236
- default=None,
237
- help="Apply text cleaning",
238
- )
239
- parser.add_argument(
240
- "--g2p",
241
- type=str_or_none,
242
- choices=g2p_classes,
243
- default=None,
244
- help="Specify g2p method if --token_type=phn",
245
- )
246
-
247
- group = parser.add_argument_group("write_vocabulary mode related")
248
- group.add_argument(
249
- "--write_vocabulary",
250
- type=str2bool,
251
- default=False,
252
- help="Write tokens list instead of tokenized text per line",
253
- )
254
- group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
255
- group.add_argument(
256
- "--cutoff",
257
- default=0,
258
- type=int,
259
- help="cut-off frequency used for write-vocabulary mode",
260
- )
261
- group.add_argument(
262
- "--add_symbol",
263
- type=str,
264
- default=[],
265
- action="append",
266
- help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
267
- )
268
-
269
- return parser
270
-
271
-
272
- def main(cmd=None):
273
- print(get_commandline_args(), file=sys.stderr)
274
- parser = get_parser()
275
- args = parser.parse_args(cmd)
276
- kwargs = vars(args)
277
- tokenize(**kwargs)
278
-
279
-
280
- if __name__ == "__main__":
281
- main()
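Besides the CLI entry point, the removed tokenize() helper could also be called directly; the sketch below builds a character-level vocabulary. The input file path is a placeholder, and every keyword mirrors a parameter of the function above:

from funasr_detach.bin.tokenize_text import tokenize

tokenize(
    input="text.txt",               # placeholder: "id<TAB>sentence" per line
    output="-",                     # "-" writes to stdout
    field="2-",                     # drop the leading utterance-id column
    delimiter=None,
    token_type="char",
    space_symbol="<space>",
    non_linguistic_symbols=None,
    bpemodel=None,
    log_level="INFO",
    write_vocabulary=True,          # emit a token list instead of tokenized text
    vocabulary_size=0,
    remove_non_linguistic_symbols=False,
    cutoff=0,
    add_symbol=["<blank>:0", "<unk>:1"],
    cleaner=None,
    g2p=None,
)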
funasr_detach/bin/train.py DELETED
@@ -1,227 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- encoding: utf-8 -*-
3
-
4
- import os
5
- import sys
6
- import torch
7
- import hydra
8
- import logging
9
- import argparse
10
- from io import BytesIO
11
- import torch.distributed as dist
12
- from collections.abc import Sequence
13
- from omegaconf import DictConfig, OmegaConf
14
- from torch.nn.parallel import DistributedDataParallel as DDP
15
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16
-
17
- from funasr_detach.register import tables
18
- from funasr_detach.optimizers import optim_classes
19
- from funasr_detach.train_utils.trainer import Trainer
20
- from funasr_detach.schedulers import scheduler_classes
21
- from funasr_detach.train_utils.initialize import initialize
22
- from funasr_detach.download.download_from_hub import download_model
23
- from funasr_detach.models.lora.utils import mark_only_lora_as_trainable
24
- from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
25
- from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
26
-
27
- # from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
28
- # from funasr_detach.tokenizer.token_id_converter import TokenIDConverter
29
- # from funasr_detach.tokenizer.funtoken import build_tokenizer
30
-
31
-
32
- @hydra.main(config_name=None, version_base=None)
33
- def main_hydra(kwargs: DictConfig):
34
- if kwargs.get("debug", False):
35
- import pdb
36
-
37
- pdb.set_trace()
38
-
39
- assert "model" in kwargs
40
- if "model_conf" not in kwargs:
41
- logging.info(
42
- "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
43
- )
44
- kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
45
-
46
- main(**kwargs)
47
-
48
-
49
- def main(**kwargs):
50
- print(kwargs)
51
-
52
- # set random seed
53
- set_all_random_seed(kwargs.get("seed", 0))
54
- torch.backends.cudnn.enabled = kwargs.get(
55
- "cudnn_enabled", torch.backends.cudnn.enabled
56
- )
57
- torch.backends.cudnn.benchmark = kwargs.get(
58
- "cudnn_benchmark", torch.backends.cudnn.benchmark
59
- )
60
- torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
61
-
62
- local_rank = int(os.environ.get("LOCAL_RANK", 0))
63
- if local_rank == 0:
64
- tables.print()
65
- # Check if we are using DDP or FSDP
66
- use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
67
- use_fsdp = kwargs.get("use_fsdp", None)
68
- if use_ddp or use_fsdp:
69
- dist.init_process_group(
70
- backend=kwargs.get("backend", "nccl"), init_method="env://"
71
- )
72
- torch.cuda.set_device(local_rank)
73
-
74
- # save config.yaml
75
- if (
76
- (use_ddp or use_fsdp)
77
- and dist.get_rank() == 0
78
- or not (use_ddp or use_fsdp)
79
- and local_rank == 0
80
- ):
81
- os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
82
- yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
83
- OmegaConf.save(config=kwargs, f=yaml_file)
84
- logging.info("config.yaml is saved to: %s", yaml_file)
85
-
86
- tokenizer = kwargs.get("tokenizer", None)
87
- if tokenizer is not None:
88
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
89
- tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
90
- kwargs["tokenizer"] = tokenizer
91
-
92
- # build frontend if frontend is none None
93
- frontend = kwargs.get("frontend", None)
94
- if frontend is not None:
95
- frontend_class = tables.frontend_classes.get(frontend)
96
- frontend = frontend_class(**kwargs["frontend_conf"])
97
- kwargs["frontend"] = frontend
98
- kwargs["input_size"] = frontend.output_size()
99
-
100
- # build model
101
- model_class = tables.model_classes.get(kwargs["model"])
102
- model = model_class(
103
- **kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)
104
- )
105
-
106
- # init_param
107
- init_param = kwargs.get("init_param", None)
108
- if init_param is not None:
109
- if not isinstance(init_param, (list, tuple)):
110
- init_param = (init_param,)
111
- logging.info("init_param is not None: %s", init_param)
112
- for p in init_param:
113
- logging.info(f"Loading pretrained params from {p}")
114
- load_pretrained_model(
115
- model=model,
116
- path=p,
117
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
118
- oss_bucket=kwargs.get("oss_bucket", None),
119
- scope_map=kwargs.get("scope_map", None),
120
- excludes=kwargs.get("excludes", None),
121
- )
122
- else:
123
- initialize(model, kwargs.get("init", "kaiming_normal"))
124
-
125
- # freeze_param
126
- freeze_param = kwargs.get("freeze_param", None)
127
- if freeze_param is not None:
128
- freeze_param = eval(freeze_param)
129
- if isinstance(freeze_param, Sequence):
130
- freeze_param = (freeze_param,)
131
- logging.info("freeze_param is not None: %s", freeze_param)
132
- for t in freeze_param:
133
- for k, p in model.named_parameters():
134
- if k.startswith(t + ".") or k == t:
135
- logging.info(f"Setting {k}.requires_grad = False")
136
- p.requires_grad = False
137
-
138
- if use_ddp:
139
- model = model.cuda(local_rank)
140
- model = DDP(
141
- model,
142
- device_ids=[local_rank],
143
- find_unused_parameters=kwargs.get("train_conf", {}).get(
144
- "find_unused_parameters", False
145
- ),
146
- )
147
- elif use_fsdp:
148
- model = FSDP(model).cuda(local_rank)
149
- else:
150
- model = model.to(device=kwargs.get("device", "cuda"))
151
-
152
- # optim
153
- optim = kwargs.get("optim", "adam")
154
- assert optim in optim_classes
155
- optim_class = optim_classes.get(optim)
156
- optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
157
-
158
- # scheduler
159
- scheduler = kwargs.get("scheduler", "warmuplr")
160
- assert scheduler in scheduler_classes
161
- scheduler_class = scheduler_classes.get(scheduler)
162
- scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
163
-
164
- # dataset
165
- dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
166
- dataset_tr = dataset_class(
167
- kwargs.get("train_data_set_list"),
168
- frontend=frontend,
169
- tokenizer=tokenizer,
170
- is_training=True,
171
- **kwargs.get("dataset_conf"),
172
- )
173
- dataset_val = dataset_class(
174
- kwargs.get("valid_data_set_list"),
175
- frontend=frontend,
176
- tokenizer=tokenizer,
177
- is_training=False,
178
- **kwargs.get("dataset_conf"),
179
- )
180
-
181
- # dataloader
182
- batch_sampler = kwargs["dataset_conf"].get(
183
- "batch_sampler", "DynamicBatchLocalShuffleSampler"
184
- )
185
- batch_sampler_val = None
186
- if batch_sampler is not None:
187
- batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
188
- batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
189
- batch_sampler_val = batch_sampler_class(
190
- dataset_val, is_training=False, **kwargs.get("dataset_conf")
191
- )
192
- dataloader_tr = torch.utils.data.DataLoader(
193
- dataset_tr,
194
- collate_fn=dataset_tr.collator,
195
- batch_sampler=batch_sampler,
196
- num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
197
- pin_memory=True,
198
- )
199
-
200
- dataloader_val = torch.utils.data.DataLoader(
201
- dataset_val,
202
- collate_fn=dataset_val.collator,
203
- batch_sampler=batch_sampler_val,
204
- num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
205
- pin_memory=True,
206
- )
207
- trainer = Trainer(
208
- model=model,
209
- optim=optim,
210
- scheduler=scheduler,
211
- dataloader_train=dataloader_tr,
212
- dataloader_val=dataloader_val,
213
- local_rank=local_rank,
214
- use_ddp=use_ddp,
215
- use_fsdp=use_fsdp,
216
- output_dir=kwargs.get("output_dir", "./exp"),
217
- resume=kwargs.get("resume", True),
218
- **kwargs.get("train_conf"),
219
- )
220
- trainer.run()
221
-
222
- if use_ddp or use_fsdp:
223
- torch.distributed.destroy_process_group()
224
-
225
-
226
- if __name__ == "__main__":
227
- main_hydra()