ing0 committed on
Commit 33facbc · 1 Parent(s): a58e0e9
app.py ADDED
@@ -0,0 +1,318 @@
import gradio as gr
from openai import OpenAI
import requests
import json
# from volcenginesdkarkruntime import Ark
import torch
import torchaudio
from einops import rearrange
import argparse
import os
from tqdm import tqdm
import random
import numpy as np
import sys
from diffrhythm.infer.infer_utils import (
    get_reference_latent,
    get_lrc_token,
    get_style_prompt,
    prepare_model,
    get_negative_style_prompt
)
from diffrhythm.infer.infer import inference

device = 'cuda'
cfm, tokenizer, muq, vae = prepare_model(device)
cfm = torch.compile(cfm)

def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
    # lrc_list = lrc.split("\n")
    # print(lrc_list)
    # return "./gift_of_the_world.wav"
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
    style_prompt = get_style_prompt(muq, ref_audio_path)
    negative_style_prompt = get_negative_style_prompt(device)
    latent_prompt = get_reference_latent(device, max_frames)
    generated_song = inference(cfm_model=cfm,
                               vae_model=vae,
                               cond=latent_prompt,
                               text=lrc_prompt,
                               duration=max_frames,
                               style_prompt=style_prompt,
                               negative_style_prompt=negative_style_prompt,
                               start_time=start_time
                               )
    return generated_song

def R1_infer1(theme, tags_gen, language):
    try:
        client = OpenAI(api_key="XXXX", base_url="https://ark.cn-beijing.volces.com/api/v3")

        llm_prompt = """
        请围绕"{theme}"主题生成一首符合"{tags}"风格的完整歌词。生成的{language}语言的歌词。
        ### **歌曲结构要求**
        1. 歌词应富有变化,使情绪递进,整体连贯有层次感。**每行歌词长度应自然变化**,切勿长度一致,导致很格式化。
        2. **时间戳分配应根据歌曲的标签、歌词的情感、节奏来合理推测**,而非机械地按照歌词长度分配。
        ### **歌曲内容要求**
        1. **第一句歌词的时间戳应考虑前奏长度**,避免歌词从 `[00:00.00]` 直接开始。
        2. **严格按照 LRC 格式输出歌词**,每行格式为 `[mm:ss.xx]歌词内容`。
        3. 输出的歌词不能有空行、括号,不能有其他解释内容,例如:副歌、桥段、结尾。
        4. 输出必须是**纯净的 LRC**。
        """

        response = client.chat.completions.create(
            model="ep-20250215195652-lrff7",
            messages=[
                {"role": "system", "content": "You are a professional musician who has been invited to make music-related comments."},
                {"role": "user", "content": llm_prompt.format(theme=theme, tags=tags_gen, language=language)},
            ],
            stream=False
        )

        info = response.choices[0].message.content

        return info

    except requests.exceptions.RequestException as e:
        print(f'Request error: {e}')
        return {}


def R1_infer2(tags_lyrics, lyrics_input):
    client = OpenAI(api_key="XXX", base_url="https://ark.cn-beijing.volces.com/api/v3")

    llm_prompt = """
    {lyrics_input}这是一首歌的歌词,每一行是一句歌词,{tags_lyrics}是我希望这首歌的风格,我现在想要给这首歌的每一句歌词打时间戳得到LRC,我希望时间戳分配应根据歌曲的标签、歌词的情感、节奏来合理推测,而非机械地按照歌词长度分配。第一句歌词的时间戳应考虑前奏长度,避免歌词从 `[00:00.00]` 直接开始。严格按照 LRC 格式输出歌词,每行格式为 `[mm:ss.xx]歌词内容`。最后的结果只输出LRC,不需要其他的解释。
    """

    response = client.chat.completions.create(
        model="ep-20250215195652-lrff7",
        messages=[
            {"role": "system", "content": "You are a professional musician who has been invited to make music-related comments."},
            {"role": "user", "content": llm_prompt.format(lyrics_input=lyrics_input, tags_lyrics=tags_lyrics)},
        ],
        stream=False
    )

    info = response.choices[0].message.content

    return info

css = """
/* Pin the textarea height and force a scrollbar */
.lyrics-scroll-box textarea {
    height: 300px !important;      /* fixed height */
    max-height: 500px !important;  /* maximum height */
    overflow-y: auto !important;   /* vertical scrolling */
    white-space: pre-wrap;         /* preserve line breaks */
    line-height: 1.5;              /* comfortable line height */
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# DiffRhythm")

    with gr.Tabs() as tabs:

        # page 1
        with gr.Tab("Music Generate", id=0):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Best Practices Guide", open=False):
                        gr.Markdown("""
1. **Lyrics Format Requirements**
   - Each line must follow: `[mm:ss.xx]Lyric content`
   - Example of valid format:
     ```
     [00:07.23]Fight me fight me fight me
     [00:08.73]You made me so unlike me
     ```

2. **Generation Duration Limits**
   - The current version supports a maximum of **95 seconds** of generated music
   - Total timestamps should not exceed 01:35.00 (95 seconds)

3. **Audio Prompt Requirements**
   - Reference audio should be ≥10 seconds for optimal results
   - Shorter clips may lead to incoherent generation
""")
                    lrc = gr.Textbox(
                        label="Lrc",
                        placeholder="Input the full lyrics",
                        lines=12,
                        max_lines=50,
                        elem_classes="lyrics-scroll-box"
                    )
                    audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")

                with gr.Column():
                    lyrics_btn = gr.Button("Submit", variant="primary")
                    audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")

            gr.Examples(
                examples=[
                    ["./gift_of_the_world.wav"],
                    ["./most_beautiful_expectation.wav"],
                    ["./ltwyl.wav"]
                ],
                inputs=[audio_prompt],
                label="Audio Examples",
                examples_per_page=3
            )

            gr.Examples(
                examples=[
                    ["""[00:10.00]Moonlight spills through broken blinds
[00:13.20]Your shadow dances on the dashboard shrine
[00:16.85]Neon ghosts in gasoline rain
[00:20.40]I hear your laughter down the midnight train
[00:24.15]Static whispers through frayed wires
[00:27.65]Guitar strings hum our cathedral choirs
[00:31.30]Flicker screens show reruns of June
[00:34.90]I'm drowning in this mercury lagoon
[00:38.55]Electric veins pulse through concrete skies
[00:42.10]Your name echoes in the hollow where my heartbeat lies
[00:45.75]We're satellites trapped in parallel light
[00:49.25]Burning through the atmosphere of endless night
[01:00.00]Dusty vinyl spins reverse
[01:03.45]Our polaroid timeline bleeds through the verse
[01:07.10]Telescope aimed at dead stars
[01:10.65]Still tracing constellations through prison bars
[01:14.30]Electric veins pulse through concrete skies
[01:17.85]Your name echoes in the hollow where my heartbeat lies
[01:21.50]We're satellites trapped in parallel light
[01:25.05]Burning through the atmosphere of endless night
[02:10.00]Clockwork gears grind moonbeams to rust
[02:13.50]Our fingerprint smudged by interstellar dust
[02:17.15]Velvet thunder rolls through my veins
[02:20.70]Chasing phantom trains through solar plane
[02:24.35]Electric veins pulse through concrete skies
[02:27.90]Your name echoes in the hollow where my heartbeat lies"""],
                    ["""[00:05.00]Stardust whispers in your eyes
[00:09.30]Moonlight paints our silhouettes
[00:13.75]Tides bring secrets from the deep
[00:18.20]Where forever's breath is kept
[00:22.90]We dance through constellations' maze
[00:27.15]Footprints melt in cosmic waves
[00:31.65]Horizons hum our silent vow
[00:36.10]Time unravels here and now
[00:40.85]Eternal embers in the night oh oh oh
[00:45.25]Healing scars with liquid light
[00:49.70]Galaxies write our refrain
[00:54.15]Love reborn in endless rain
[01:15.30]Paper boats of memories
[01:19.75]Float through veins of ancient trees
[01:24.20]Your laughter spins aurora threads
[01:28.65]Weaving dawn through featherbed"""]
                ],
                inputs=[lrc],  # bind only to the lyrics input
                label="Lrc Examples",
                examples_per_page=2
            )

        # page 2
        with gr.Tab("LLM Generate LRC", id=1):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Notice", open=False):
                        gr.Markdown("**Two Generation Modes:**\n1. Generate from theme & tags\n2. Add timestamps to existing lyrics")

                    with gr.Group():
                        gr.Markdown("### Method 1: Generate from Theme")
                        theme = gr.Textbox(label="theme", placeholder="Enter song theme, e.g. Love and Heartbreak")
                        tags_gen = gr.Textbox(label="tags", placeholder="Example: male pop confidence healing")
                        language = gr.Dropdown(["zh", "en"], label="language", value="en")
                        gen_from_theme_btn = gr.Button("Generate LRC (From Theme)", variant="primary")

                    with gr.Group(visible=True):
                        gr.Markdown("### Method 2: Add Timestamps to Lyrics")
                        tags_lyrics = gr.Textbox(label="tags", placeholder="Example: female ballad piano slow")
                        lyrics_input = gr.Textbox(
                            label="Raw Lyrics (without timestamps)",
                            placeholder="Enter plain lyrics (without timestamps), e.g.:\nYesterday\nAll my troubles...",
                            lines=12,
                            max_lines=50,
                            elem_classes="lyrics-scroll-box"
                        )
                        gen_from_lyrics_btn = gr.Button("Generate LRC (From Lyrics)", variant="primary")

                with gr.Column():
                    lrc_output = gr.Textbox(
                        label="Generated LRC Lyrics",
                        placeholder="Timed lyrics will appear here",
                        lines=50,
                        elem_classes="lrc-output",
                        show_copy_button=True
                    )

            # Examples section
            gr.Examples(
                examples=[
                    [
                        "Love and Heartbreak",
                        "female vocal emotional piano pop",
                        "en"
                    ],
                    [
                        "Heroic Epic",
                        "male choir orchestral powerful",
                        "zh"
                    ]
                ],
                inputs=[theme, tags_gen, language],
                label="Examples: Generate from Theme"
            )

            gr.Examples(
                examples=[
                    [
                        "acoustic folk happy",
                        """I'm sitting here in the boring room
It's just another rainy Sunday afternoon"""
                    ],
                    [
                        "electronic dance energetic",
                        """We're living in a material world
And I am a material girl"""
                    ]
                ],
                inputs=[tags_lyrics, lyrics_input],
                label="Examples: Generate from Lyrics"
            )

            # Bind functions
            gen_from_theme_btn.click(
                fn=R1_infer1,
                inputs=[theme, tags_gen, language],
                outputs=lrc_output
            )

            gen_from_lyrics_btn.click(
                fn=R1_infer2,
                inputs=[tags_lyrics, lyrics_input],
                outputs=lrc_output
            )

    tabs.select(
        lambda s: None,
        None,
        None
    )

    lyrics_btn.click(
        fn=infer_music,
        inputs=[lrc, audio_prompt],
        outputs=audio_output
    )

if __name__ == "__main__":
    demo.queue().launch(show_api=False, show_error=True)
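Reviewer note: a minimal, hypothetical sketch of driving the new infer_music() entry point without the Gradio UI is shown below. It assumes the checkpoints expected by prepare_model() are available on the machine and that app.py is importable; the LRC string and reference-audio path are placeholders taken from the examples above, and the return type of inference() is not assumed.

# Illustrative only: call app.py's infer_music directly (importing app loads the models).
from app import infer_music

lrc = "[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine"
ref_audio = "./gift_of_the_world.wav"  # one of the example prompts bundled with the Space

song = infer_music(lrc, ref_audio)  # returns whatever inference() produces
print(type(song))                   # inspect before assuming a waveform or file path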
diffrhythm/config/defaults.ini ADDED
@@ -0,0 +1,94 @@

[DEFAULTS]

# name of the run
exp_name = F5

# the batch size
batch_size = 8

# the chunk size
max_frames = 3000
min_frames = 10

# number of CPU workers for the DataLoader
num_workers = 4

# the random seed
seed = 42

# batches for gradient accumulation
accum_batches = 1

# number of steps between checkpoints
checkpoint_every = 10000

# trainer checkpoint file to restart training from
ckpt_path = ''

# model checkpoint file to start a new training run from
pretrained_ckpt_path = ''

# checkpoint path for the pretransform model, if needed
pretransform_ckpt_path = ''

# model config file specifying model hyperparameters
model_config = ''

# configuration for datasets
dataset_config = ''

# directory to save the checkpoints in
save_dir = ''

# max gradient norm
max_grad_norm = 1.0

# gradient accumulation steps
grad_accumulation_steps = 1

# learning rate
learning_rate = 7.5e-5

# epochs
epochs = 110

# warmup steps
num_warmup_updates = 2000

# save a checkpoint every N steps
save_per_updates = 5000

# save the last checkpoint every N steps
last_per_steps = 5000

prompt_path = "/mnt/sfs/music/lance/style-lance-full|/mnt/sfs/music/lance/style-lance-cnen-music-second"
lrc_path = "/mnt/sfs/music/lance/lrc-lance-emb-full|/mnt/sfs/music/lance/lrc-lance-cnen-second"
latent_path = "/mnt/sfs/music/lance/latent-lance|/mnt/sfs/music/lance/latent-lance-cnen-music-second-1|/mnt/sfs/music/lance/latent-lance-cnen-music-second-2"

audio_drop_prob = 0.3
cond_drop_prob = 0.0
style_drop_prob = 0.1
lrc_drop_prob = 0.1

align_lyrics = 0
lyrics_slice = 0
parse_lyrics = 1
skip_empty_lyrics = 0
lyrics_shift = -1

use_style_prompt = 1

tokenizer_type = gpt2

reset_lr = 0

resumable_with_seed = 666

downsample_rate = 2048

grad_ckpt = 0

dataset_path = "/mnt/sfs/music/hkchen/workspace/F5-TTS-HW/filelists/music123latent_asred_bpmstyle_cnen_pure1"

pure_prob = 0.0
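Reviewer note: the training code in this commit presumably consumes these defaults through its own argument parser. Purely as an illustration (not the project's actual loader), the file can be read with the standard-library configparser, with the pipe-separated path fields split into lists:

# Illustrative only: read defaults.ini with the standard library.
# The real training script may parse these values differently.
import configparser

cfg = configparser.ConfigParser()
cfg.read("diffrhythm/config/defaults.ini")
defaults = cfg["DEFAULTS"]

max_frames = defaults.getint("max_frames")          # 3000
learning_rate = defaults.getfloat("learning_rate")  # 7.5e-5
# multi-root fields are pipe-separated lists of directories
latent_roots = defaults.get("latent_path").strip('"').split("|")
print(max_frames, learning_rate, len(latent_roots))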
diffrhythm/config/diffrhythm-1b.json ADDED
@@ -0,0 +1,13 @@
{
    "model_type": "diffrhythm",
    "model": {
        "dim": 2048,
        "depth": 16,
        "heads": 32,
        "ff_mult": 4,
        "text_dim": 512,
        "conv_layers": 4,
        "mel_dim": 64,
        "text_num_embeds": 363
    }
}
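Reviewer note: as a rough sketch of how these hyperparameters map onto the DiT backbone added later in this commit (diffrhythm/model/dit.py), the nested "model" dict can be splatted into the constructor; the project's own factory code may wire this up differently.

# Illustrative only: build the DiT backbone from this config file.
import json
from diffrhythm.model.dit import DiT

with open("diffrhythm/config/diffrhythm-1b.json") as f:
    cfg = json.load(f)

assert cfg["model_type"] == "diffrhythm"
model = DiT(**cfg["model"])  # dim=2048, depth=16, heads=32, ff_mult=4, ...
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")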
diffrhythm/model/__init__.py ADDED
@@ -0,0 +1,6 @@
from diffrhythm.model.cfm import CFM

from diffrhythm.model.dit import DiT


__all__ = ["CFM", "DiT"]
diffrhythm/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (290 Bytes)
diffrhythm/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (508 Bytes)
diffrhythm/model/__pycache__/cfm.cpython-310.pyc ADDED
Binary file (6.28 kB)
diffrhythm/model/__pycache__/cfm.cpython-312.pyc ADDED
Binary file (10.7 kB)
diffrhythm/model/__pycache__/custom_dataset.cpython-310.pyc ADDED
Binary file (11.5 kB)
diffrhythm/model/__pycache__/custom_dataset_lrc_emb.cpython-310.pyc ADDED
Binary file (10.5 kB)
diffrhythm/model/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (8.04 kB)
diffrhythm/model/__pycache__/dit.cpython-310.pyc ADDED
Binary file (5.61 kB)
diffrhythm/model/__pycache__/modules.cpython-310.pyc ADDED
Binary file (15.9 kB)
diffrhythm/model/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (9.13 kB)
diffrhythm/model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.03 kB)
diffrhythm/model/cfm.py ADDED
@@ -0,0 +1,315 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Callable
12
+ from random import random
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.nn.utils.rnn import pad_sequence
19
+
20
+ from torchdiffeq import odeint
21
+
22
+ from diffrhythm.model.modules import MelSpec
23
+ from diffrhythm.model.utils import (
24
+ default,
25
+ exists,
26
+ list_str_to_idx,
27
+ list_str_to_tensor,
28
+ lens_to_mask,
29
+ mask_from_frac_lengths,
30
+ )
31
+
32
+ def custom_mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], device, max_seq_len): # noqa: F722 F821
33
+ max_seq_len = max_seq_len
34
+ seq = torch.arange(max_seq_len, device=device).long()
35
+ start_mask = seq[None, :] >= start[:, None]
36
+ end_mask = seq[None, :] < end[:, None]
37
+ return start_mask & end_mask
38
+
39
+ class CFM(nn.Module):
40
+ def __init__(
41
+ self,
42
+ transformer: nn.Module,
43
+ sigma=0.0,
44
+ odeint_kwargs: dict = dict(
45
+ # atol = 1e-5,
46
+ # rtol = 1e-5,
47
+ method="euler" # 'midpoint'
48
+ # method="adaptive_heun" # dopri5
49
+ ),
50
+ odeint_options: dict = dict(
51
+ min_step=0.05
52
+ ),
53
+ audio_drop_prob=0.3,
54
+ cond_drop_prob=0.2,
55
+ style_drop_prob=0.1,
56
+ lrc_drop_prob=0.1,
57
+ num_channels=None,
58
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
59
+ vocab_char_map: dict[str, int] | None = None,
60
+ use_style_prompt: bool = False
61
+ ):
62
+ super().__init__()
63
+
64
+ self.frac_lengths_mask = frac_lengths_mask
65
+
66
+ self.num_channels = num_channels
67
+
68
+ # classifier-free guidance
69
+ self.audio_drop_prob = audio_drop_prob
70
+ self.cond_drop_prob = cond_drop_prob
71
+ self.style_drop_prob = style_drop_prob
72
+ self.lrc_drop_prob = lrc_drop_prob
73
+
74
+ print(f"audio drop prob -> {self.audio_drop_prob}; style_drop_prob -> {self.style_drop_prob}; lrc_drop_prob: {self.lrc_drop_prob}")
75
+
76
+ # transformer
77
+ self.transformer = transformer
78
+ dim = transformer.dim
79
+ self.dim = dim
80
+
81
+ # conditional flow related
82
+ self.sigma = sigma
83
+
84
+ # sampling related
85
+ self.odeint_kwargs = odeint_kwargs
86
+ # print(f"ODE SOLVER: {self.odeint_kwargs['method']}")
87
+
88
+ self.odeint_options = odeint_options
89
+
90
+ # vocab map for tokenization
91
+ self.vocab_char_map = vocab_char_map
92
+
93
+ self.use_style_prompt = use_style_prompt
94
+
95
+ @property
96
+ def device(self):
97
+ return next(self.parameters()).device
98
+
99
+ @torch.no_grad()
100
+ def sample(
101
+ self,
102
+ cond: float["b n d"] | float["b nw"], # noqa: F722
103
+ text: int["b nt"] | list[str], # noqa: F722
104
+ duration: int | int["b"], # noqa: F821
105
+ *,
106
+ style_prompt = None,
107
+ style_prompt_lens = None,
108
+ negative_style_prompt = None,
109
+ lens: int["b"] | None = None, # noqa: F821
110
+ steps=32,
111
+ cfg_strength=4.0,
112
+ sway_sampling_coef=None,
113
+ seed: int | None = None,
114
+ max_duration=4096,
115
+ vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
116
+ no_ref_audio=False,
117
+ duplicate_test=False,
118
+ t_inter=0.1,
119
+ edit_mask=None,
120
+ start_time=None,
121
+ latent_pred_start_frame=0,
122
+ latent_pred_end_frame=2048,
123
+ ):
124
+ self.eval()
125
+
126
+ if next(self.parameters()).dtype == torch.float16:
127
+ cond = cond.half()
128
+
129
+ # raw wave
130
+
131
+ if cond.shape[1] > duration:
132
+ cond = cond[:, :duration, :]
133
+
134
+ if cond.ndim == 2:
135
+ cond = self.mel_spec(cond)
136
+ cond = cond.permute(0, 2, 1)
137
+ assert cond.shape[-1] == self.num_channels
138
+
139
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
140
+ if not exists(lens):
141
+ lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
142
+
143
+ # text
144
+
145
+ if isinstance(text, list):
146
+ if exists(self.vocab_char_map):
147
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
148
+ else:
149
+ text = list_str_to_tensor(text).to(device)
150
+ assert text.shape[0] == batch
151
+
152
+ if exists(text):
153
+ text_lens = (text != -1).sum(dim=-1)
154
+ #lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
155
+
156
+ # duration
157
+ # import pdb; pdb.set_trace()
158
+ cond_mask = lens_to_mask(lens)
159
+ if edit_mask is not None:
160
+ cond_mask = cond_mask & edit_mask
161
+
162
+ latent_pred_start_frame = torch.tensor([latent_pred_start_frame]).to(cond.device)
163
+ latent_pred_end_frame = duration
164
+ latent_pred_end_frame = torch.tensor([latent_pred_end_frame]).to(cond.device)
165
+ fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_start_frame, latent_pred_end_frame, device=cond.device, max_seq_len=duration)
166
+
167
+ fixed_span_mask = fixed_span_mask.unsqueeze(-1)
168
+ step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)
169
+
170
+ if isinstance(duration, int):
171
+ duration = torch.full((batch,), duration, device=device, dtype=torch.long)
172
+
173
+ # duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
174
+ duration = duration.clamp(max=max_duration)
175
+ max_duration = duration.amax()
176
+
177
+ # duplicate test corner for inner time step observation
178
+ if duplicate_test:
179
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
180
+
181
+ # cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) # [b, t, d]
182
+ # cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) # [b, max_duration]
183
+ # cond_mask = cond_mask.unsqueeze(-1) #[b, t, d]
184
+ # step_cond = torch.where(
185
+ # cond_mask, cond, torch.zeros_like(cond)
186
+ # ) # allow direct control (cut cond audio) with lens passed in
187
+
188
+ if batch > 1:
189
+ mask = lens_to_mask(duration)
190
+ else: # save memory and speed up, as single inference need no mask currently
191
+ mask = None
192
+
193
+ # test for no ref audio
194
+ if no_ref_audio:
195
+ cond = torch.zeros_like(cond)
196
+
197
+
198
+ def fn(t, x):
199
+ # at each step, conditioning is fixed
200
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
201
+
202
+ # predict flow
203
+ pred = self.transformer(
204
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, drop_prompt=False,
205
+ style_prompt=style_prompt, style_prompt_lens=style_prompt_lens, start_time=start_time
206
+ )
207
+ if cfg_strength < 1e-5:
208
+ return pred
209
+
210
+ null_pred = self.transformer(
211
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, drop_prompt=False,
212
+ style_prompt=negative_style_prompt, style_prompt_lens=style_prompt_lens, start_time=start_time
213
+ )
214
+ return pred + (pred - null_pred) * cfg_strength
215
+
216
+ # noise input
217
+ # seed the noise per item so batched inference matches single-sample inference regardless of batch size
218
+ # (small differences may remain, likely due to convolutional layers)
219
+ y0 = []
220
+ for dur in duration:
221
+ if exists(seed):
222
+ torch.manual_seed(seed)
223
+ y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
224
+ y0 = pad_sequence(y0, padding_value=0, batch_first=True)
225
+
226
+ t_start = 0
227
+
228
+ # duplicate test corner for inner time step observation
229
+ if duplicate_test:
230
+ t_start = t_inter
231
+ y0 = (1 - t_start) * y0 + t_start * test_cond
232
+ steps = int(steps * (1 - t_start))
233
+
234
+ t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
235
+ if sway_sampling_coef is not None:
236
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
237
+
238
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
239
+
240
+ sampled = trajectory[-1]
241
+ out = sampled
242
+ # out = torch.where(cond_mask, cond, out)
243
+ out = torch.where(fixed_span_mask, out, cond)
244
+
245
+ if exists(vocoder):
246
+ out = out.permute(0, 2, 1)
247
+ out = vocoder(out)
248
+
249
+ return out, trajectory
250
+
251
+ def forward(
252
+ self,
253
+ inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
254
+ text: int["b nt"] | list[str], # noqa: F722
255
+ style_prompt = None,
256
+ style_prompt_lens = None,
257
+ lens: int["b"] | None = None, # noqa: F821
258
+ noise_scheduler: str | None = None,
259
+ grad_ckpt = False,
260
+ start_time = None,
261
+ ):
262
+
263
+ batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
264
+
265
+ # lens and mask
266
+ if not exists(lens):
267
+ lens = torch.full((batch,), seq_len, device=device)
268
+
269
+ mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
270
+
271
+ # get a random span to mask out for training conditionally
272
+ frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
273
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
274
+
275
+ if exists(mask):
276
+ rand_span_mask = mask
277
+ # rand_span_mask &= mask
278
+
279
+ # mel is x1
280
+ x1 = inp
281
+
282
+ # x0 is gaussian noise
283
+ x0 = torch.randn_like(x1)
284
+
285
+ # time step
286
+ # time = torch.rand((batch,), dtype=dtype, device=self.device)
287
+ time = torch.normal(mean=0, std=1, size=(batch,), device=self.device)
288
+ time = torch.nn.functional.sigmoid(time)
289
+ # TODO. noise_scheduler
290
+
291
+ # sample xt (φ_t(x) in the paper)
292
+ t = time.unsqueeze(-1).unsqueeze(-1)
293
+ φ = (1 - t) * x0 + t * x1
294
+ flow = x1 - x0
295
+
296
+ # only predict what is within the random mask span for infilling
297
+ cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
298
+
299
+ # transformer and cfg training with a drop rate
300
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
301
+ drop_text = random() < self.lrc_drop_prob
302
+ drop_prompt = random() < self.style_drop_prob
303
+
304
+ # if you want to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
305
+ # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
306
+ pred = self.transformer(
307
+ x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
308
+ style_prompt=style_prompt, style_prompt_lens=style_prompt_lens, grad_ckpt=grad_ckpt, start_time=start_time
309
+ )
310
+
311
+ # flow matching loss
312
+ loss = F.mse_loss(pred, flow, reduction="none")
313
+ loss = loss[rand_span_mask]
314
+
315
+ return loss.mean(), cond, pred
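Reviewer note: to make the sampling and training math in cfm.py easier to follow, the two core formulas are restated below exactly as they appear in the code (training interpolates noise x0 toward the latent x1 and regresses the constant flow x1 - x0; sampling combines conditional and unconditional predictions with classifier-free guidance). This is a restatement for reference, not additional functionality.

# Restatement of the two core formulas used in cfm.py.
import torch

def interpolate(x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor):
    """phi_t(x) = (1 - t) * x0 + t * x1, with target flow x1 - x0 (see CFM.forward)."""
    t = t.view(-1, 1, 1)
    return (1 - t) * x0 + t * x1, x1 - x0

def cfg_combine(pred: torch.Tensor, null_pred: torch.Tensor, cfg_strength: float):
    """Classifier-free guidance as used in CFM.sample's fn(): pred + (pred - null_pred) * w."""
    return pred + (pred - null_pred) * cfg_strength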
diffrhythm/model/dit.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
19
+ from transformers.models.llama import LlamaConfig
20
+ from torch.utils.checkpoint import checkpoint
21
+
22
+ from diffrhythm.model.modules import (
23
+ TimestepEmbedding,
24
+ ConvNeXtV2Block,
25
+ ConvPositionEmbedding,
26
+ DiTBlock,
27
+ AdaLayerNormZero_Final,
28
+ precompute_freqs_cis,
29
+ get_pos_embed_indices,
30
+ )
31
+
32
+
33
+ # Text embedding
34
+
35
+
36
+ class TextEmbedding(nn.Module):
37
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
38
+ super().__init__()
39
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
40
+
41
+ if conv_layers > 0:
42
+ self.extra_modeling = True
43
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
44
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
45
+ self.text_blocks = nn.Sequential(
46
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
47
+ )
48
+ else:
49
+ self.extra_modeling = False
50
+
51
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
52
+ #text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
53
+ #text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
54
+ batch, text_len = text.shape[0], text.shape[1]
55
+ #text = F.pad(text, (0, seq_len - text_len), value=0)
56
+
57
+ if drop_text: # cfg for text
58
+ text = torch.zeros_like(text)
59
+
60
+ text = self.text_embed(text) # b n -> b n d
61
+
62
+ # possible extra modeling
63
+ if self.extra_modeling:
64
+ # sinus pos emb
65
+ batch_start = torch.zeros((batch,), dtype=torch.long)
66
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
67
+ text_pos_embed = self.freqs_cis[pos_idx]
68
+ text = text + text_pos_embed
69
+
70
+ # convnextv2 blocks
71
+ text = self.text_blocks(text)
72
+
73
+ return text
74
+
75
+
76
+ # noised input audio and context mixing embedding
77
+
78
+
79
+ class InputEmbedding(nn.Module):
80
+ def __init__(self, mel_dim, text_dim, out_dim, cond_dim):
81
+ super().__init__()
82
+ self.proj = nn.Linear(mel_dim * 2 + text_dim + cond_dim * 2, out_dim)
83
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
84
+
85
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], style_emb, time_emb, drop_audio_cond=False): # noqa: F722
86
+ if drop_audio_cond: # cfg for cond audio
87
+ cond = torch.zeros_like(cond)
88
+
89
+ style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
90
+ time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
91
+ # print(x.shape, cond.shape, text_embed.shape, style_emb.shape, time_emb.shape)
92
+ x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
93
+ x = self.conv_pos_embed(x) + x
94
+ return x
95
+
96
+
97
+ # Transformer backbone using DiT blocks
98
+
99
+
100
+ class DiT(nn.Module):
101
+ def __init__(
102
+ self,
103
+ *,
104
+ dim,
105
+ depth=8,
106
+ heads=8,
107
+ dim_head=64,
108
+ dropout=0.1,
109
+ ff_mult=4,
110
+ mel_dim=100,
111
+ text_num_embeds=256,
112
+ text_dim=None,
113
+ conv_layers=0,
114
+ long_skip_connection=False,
115
+ use_style_prompt=False
116
+ ):
117
+ super().__init__()
118
+
119
+ cond_dim = 512
120
+ self.time_embed = TimestepEmbedding(cond_dim)
121
+ self.start_time_embed = TimestepEmbedding(cond_dim)
122
+ if text_dim is None:
123
+ text_dim = mel_dim
124
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
125
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
126
+
127
+ #self.rotary_embed = RotaryEmbedding(dim_head)
128
+
129
+ self.dim = dim
130
+ self.depth = depth
131
+
132
+ #self.transformer_blocks = nn.ModuleList(
133
+ # [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, use_style_prompt=use_style_prompt) for _ in range(depth)]
134
+ #)
135
+ llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
136
+ llama_config._attn_implementation = 'sdpa'
137
+ self.transformer_blocks = nn.ModuleList(
138
+ [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
139
+ )
140
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
141
+
142
+ self.text_fusion_linears = nn.ModuleList(
143
+ [
144
+ nn.Sequential(
145
+ nn.Linear(cond_dim, dim),
146
+ nn.SiLU()
147
+ ) for i in range(depth // 2)
148
+ ]
149
+ )
150
+ for layer in self.text_fusion_linears:
151
+ for p in layer.parameters():
152
+ p.detach().zero_()
153
+
154
+ self.norm_out = AdaLayerNormZero_Final(dim, cond_dim) # final modulation
155
+ self.proj_out = nn.Linear(dim, mel_dim)
156
+
157
+ # if use_style_prompt:
158
+ # self.prompt_rnn = nn.LSTM(64, cond_dim, 1, batch_first=True)
159
+
160
+
161
+ def forward(
162
+ self,
163
+ x: float["b n d"], # noised input audio # noqa: F722
164
+ cond: float["b n d"], # masked cond audio # noqa: F722
165
+ text: int["b nt"], # text # noqa: F722
166
+ time: float["b"] | float[""], # time step # noqa: F821 F722
167
+ drop_audio_cond, # cfg for cond audio
168
+ drop_text, # cfg for text
169
+ drop_prompt=False,
170
+ style_prompt=None, # [b d t]
171
+ style_prompt_lens=None,
172
+ mask: bool["b n"] | None = None, # noqa: F722
173
+ grad_ckpt=False,
174
+ start_time=None,
175
+ ):
176
+ batch, seq_len = x.shape[0], x.shape[1]
177
+ if time.ndim == 0:
178
+ time = time.repeat(batch)
179
+
180
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
181
+ t = self.time_embed(time)
182
+ s_t = self.start_time_embed(start_time)
183
+ c = t + s_t
184
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
185
+
186
+ # import pdb; pdb.set_trace()
187
+ if drop_prompt:
188
+ style_prompt = torch.zeros_like(style_prompt)
189
+ # if self.training:
190
+ # packed_style_prompt = torch.nn.utils.rnn.pack_padded_sequence(style_prompt.transpose(1, 2), style_prompt_lens.cpu(), batch_first=True, enforce_sorted=False)
191
+ # else:
192
+ # packed_style_prompt = style_prompt.transpose(1, 2)
193
+ #print(packed_style_prompt.shape)
194
+ # _, style_emb = self.prompt_rnn.forward(packed_style_prompt)
195
+ # _, (h_n, c_n) = self.prompt_rnn.forward(packed_style_prompt)
196
+ # style_emb = h_n.squeeze(0) # 1, B, dim -> B, dim
197
+
198
+ style_emb = style_prompt # [b, 512]
199
+
200
+ x = self.input_embed(x, cond, text_embed, style_emb, c, drop_audio_cond=drop_audio_cond)
201
+
202
+ if self.long_skip_connection is not None:
203
+ residual = x
204
+
205
+ pos_ids = torch.arange(x.shape[1], device=x.device)
206
+ pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
207
+ for i, block in enumerate(self.transformer_blocks):
208
+ if not grad_ckpt:
209
+ x, *_ = block(x, position_ids=pos_ids)
210
+ else:
211
+ x, *_ = checkpoint(block, x, position_ids=pos_ids, use_reentrant=False)
212
+ if i < self.depth // 2:
213
+ x = x + self.text_fusion_linears[i](text_embed)
214
+
215
+ if self.long_skip_connection is not None:
216
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
217
+
218
+ x = self.norm_out(x, c)
219
+ output = self.proj_out(x)
220
+
221
+ return output
diffrhythm/model/modules.py ADDED
@@ -0,0 +1,611 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchaudio
19
+
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+
24
+ class FiLMLayer(nn.Module):
25
+ """
26
+ Feature-wise Linear Modulation (FiLM) layer
27
+ Reference: https://arxiv.org/abs/1709.07871
28
+ """
29
+ def __init__(self, in_channels, cond_channels):
30
+
31
+ super(FiLMLayer, self).__init__()
32
+ self.in_channels = in_channels
33
+ self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)
34
+
35
+ def forward(self, x, c):
36
+ gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
37
+ gamma = gamma.transpose(1, 2)
38
+ beta = beta.transpose(1, 2)
39
+ # print(gamma.shape, beta.shape)
40
+ return gamma * x + beta
41
+
42
+ # raw wav to mel spec
43
+
44
+
45
+ class MelSpec(nn.Module):
46
+ def __init__(
47
+ self,
48
+ filter_length=1024,
49
+ hop_length=256,
50
+ win_length=1024,
51
+ n_mel_channels=100,
52
+ target_sample_rate=24_000,
53
+ normalize=False,
54
+ power=1,
55
+ norm=None,
56
+ center=True,
57
+ ):
58
+ super().__init__()
59
+ self.n_mel_channels = n_mel_channels
60
+
61
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
62
+ sample_rate=target_sample_rate,
63
+ n_fft=filter_length,
64
+ win_length=win_length,
65
+ hop_length=hop_length,
66
+ n_mels=n_mel_channels,
67
+ power=power,
68
+ center=center,
69
+ normalized=normalize,
70
+ norm=norm,
71
+ )
72
+
73
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
74
+
75
+ def forward(self, inp):
76
+ if len(inp.shape) == 3:
77
+ inp = inp.squeeze(1) # 'b 1 nw -> b nw'
78
+
79
+ assert len(inp.shape) == 2
80
+
81
+ if self.dummy.device != inp.device:
82
+ self.to(inp.device)
83
+
84
+ mel = self.mel_stft(inp)
85
+ mel = mel.clamp(min=1e-5).log()
86
+ return mel
87
+
88
+
89
+ # sinusoidal position embedding
90
+
91
+
92
+ class SinusPositionEmbedding(nn.Module):
93
+ def __init__(self, dim):
94
+ super().__init__()
95
+ self.dim = dim
96
+
97
+ def forward(self, x, scale=1000):
98
+ device = x.device
99
+ half_dim = self.dim // 2
100
+ emb = math.log(10000) / (half_dim - 1)
101
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
102
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
103
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
104
+ return emb
105
+
106
+
107
+ # convolutional position embedding
108
+
109
+
110
+ class ConvPositionEmbedding(nn.Module):
111
+ def __init__(self, dim, kernel_size=31, groups=16):
112
+ super().__init__()
113
+ assert kernel_size % 2 != 0
114
+ self.conv1d = nn.Sequential(
115
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
116
+ nn.Mish(),
117
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
118
+ nn.Mish(),
119
+ )
120
+
121
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
122
+ if mask is not None:
123
+ mask = mask[..., None]
124
+ x = x.masked_fill(~mask, 0.0)
125
+
126
+ x = x.permute(0, 2, 1)
127
+ x = self.conv1d(x)
128
+ out = x.permute(0, 2, 1)
129
+
130
+ if mask is not None:
131
+ out = out.masked_fill(~mask, 0.0)
132
+
133
+ return out
134
+
135
+
136
+ # rotary positional embedding related
137
+
138
+
139
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
140
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
141
+ # has some connection to NTK literature
142
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
143
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
144
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
145
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
146
+ t = torch.arange(end, device=freqs.device) # type: ignore
147
+ freqs = torch.outer(t, freqs).float() # type: ignore
148
+ freqs_cos = torch.cos(freqs) # real part
149
+ freqs_sin = torch.sin(freqs) # imaginary part
150
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
151
+
152
+
153
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
154
+ # length = length if isinstance(length, int) else length.max()
155
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
156
+ pos = (
157
+ start.unsqueeze(1)
158
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
159
+ )
160
+ # avoid extra long error.
161
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
162
+ return pos
163
+
164
+
165
+ # Global Response Normalization layer (Instance Normalization ?)
166
+
167
+
168
+ class GRN(nn.Module):
169
+ def __init__(self, dim):
170
+ super().__init__()
171
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
172
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
173
+
174
+ def forward(self, x):
175
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
176
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
177
+ return self.gamma * (x * Nx) + self.beta + x
178
+
179
+
180
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
181
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
182
+
183
+
184
+ class ConvNeXtV2Block(nn.Module):
185
+ def __init__(
186
+ self,
187
+ dim: int,
188
+ intermediate_dim: int,
189
+ dilation: int = 1,
190
+ ):
191
+ super().__init__()
192
+ padding = (dilation * (7 - 1)) // 2
193
+ self.dwconv = nn.Conv1d(
194
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
195
+ ) # depthwise conv
196
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
197
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
198
+ self.act = nn.GELU()
199
+ self.grn = GRN(intermediate_dim)
200
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
201
+
202
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
203
+ residual = x
204
+ x = x.transpose(1, 2) # b n d -> b d n
205
+ x = self.dwconv(x)
206
+ x = x.transpose(1, 2) # b d n -> b n d
207
+ x = self.norm(x)
208
+ x = self.pwconv1(x)
209
+ x = self.act(x)
210
+ x = self.grn(x)
211
+ x = self.pwconv2(x)
212
+ return residual + x
213
+
214
+
215
+ # AdaLayerNormZero
216
+ # return with modulated x for attn input, and params for later mlp modulation
217
+
218
+
219
+ class AdaLayerNormZero(nn.Module):
220
+ def __init__(self, dim):
221
+ super().__init__()
222
+
223
+ self.silu = nn.SiLU()
224
+ self.linear = nn.Linear(dim, dim * 6)
225
+
226
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
227
+
228
+ def forward(self, x, emb=None):
229
+ emb = self.linear(self.silu(emb))
230
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
231
+
232
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
233
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
234
+
235
+
236
+ # AdaLayerNormZero for final layer
237
+ # return only with modulated x for attn input, cuz no more mlp modulation
238
+
239
+
240
+ class AdaLayerNormZero_Final(nn.Module):
241
+ def __init__(self, dim, cond_dim):
242
+ super().__init__()
243
+
244
+ self.silu = nn.SiLU()
245
+ self.linear = nn.Linear(cond_dim, dim * 2)
246
+
247
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
248
+
249
+ def forward(self, x, emb):
250
+ emb = self.linear(self.silu(emb))
251
+ scale, shift = torch.chunk(emb, 2, dim=1)
252
+
253
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
254
+ return x
255
+
256
+
257
+ # FeedForward
258
+
259
+
260
+ class FeedForward(nn.Module):
261
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
262
+ super().__init__()
263
+ inner_dim = int(dim * mult)
264
+ dim_out = dim_out if dim_out is not None else dim
265
+
266
+ activation = nn.GELU(approximate=approximate)
267
+ #activation = nn.SiLU()
268
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
269
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
270
+
271
+ def forward(self, x):
272
+ return self.ff(x)
273
+
274
+
275
+ # Attention with possible joint part
276
+ # modified from diffusers/src/diffusers/models/attention_processor.py
277
+
278
+
279
+ class Attention(nn.Module):
280
+ def __init__(
281
+ self,
282
+ processor: JointAttnProcessor | AttnProcessor,
283
+ dim: int,
284
+ heads: int = 8,
285
+ dim_head: int = 64,
286
+ dropout: float = 0.0,
287
+ context_dim: Optional[int] = None, # if not None -> joint attention
288
+ context_pre_only=None,
289
+ ):
290
+ super().__init__()
291
+
292
+ if not hasattr(F, "scaled_dot_product_attention"):
293
+ raise ImportError("Attention requires PyTorch 2.0; please upgrade PyTorch to use it.")
294
+
295
+ self.processor = processor
296
+
297
+ self.dim = dim
298
+ self.heads = heads
299
+ self.inner_dim = dim_head * heads
300
+ self.dropout = dropout
301
+
302
+ self.context_dim = context_dim
303
+ self.context_pre_only = context_pre_only
304
+
305
+ self.to_q = nn.Linear(dim, self.inner_dim)
306
+ self.to_k = nn.Linear(dim, self.inner_dim)
307
+ self.to_v = nn.Linear(dim, self.inner_dim)
308
+
309
+ if self.context_dim is not None:
310
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
311
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
312
+ if self.context_pre_only is not None:
313
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
314
+
315
+ self.to_out = nn.ModuleList([])
316
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
317
+ self.to_out.append(nn.Dropout(dropout))
318
+
319
+ if self.context_pre_only is not None and not self.context_pre_only:
320
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
321
+
322
+ def forward(
323
+ self,
324
+ x: float["b n d"], # noised input x # noqa: F722
325
+ c: float["b n d"] = None, # context c # noqa: F722
326
+ mask: bool["b n"] | None = None, # noqa: F722
327
+ rope=None, # rotary position embedding for x
328
+ c_rope=None, # rotary position embedding for c
329
+ ) -> torch.Tensor:
330
+ if c is not None:
331
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
332
+ else:
333
+ return self.processor(self, x, mask=mask, rope=rope)
334
+
335
+
336
+ # Attention processor
337
+
338
+
339
+ class AttnProcessor:
340
+ def __init__(self):
341
+ pass
342
+
343
+ def __call__(
344
+ self,
345
+ attn: Attention,
346
+ x: float["b n d"], # noised input x # noqa: F722
347
+ mask: bool["b n"] | None = None, # noqa: F722
348
+ rope=None, # rotary position embedding
349
+ ) -> torch.FloatTensor:
350
+ batch_size = x.shape[0]
351
+
352
+ # `sample` projections.
353
+ query = attn.to_q(x)
354
+ key = attn.to_k(x)
355
+ value = attn.to_v(x)
356
+
357
+ # apply rotary position embedding
358
+ if rope is not None:
359
+ freqs, xpos_scale = rope
360
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
361
+
362
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
363
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
364
+
365
+ # attention
366
+ inner_dim = key.shape[-1]
367
+ head_dim = inner_dim // attn.heads
368
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
369
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
370
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
371
+
372
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
373
+ if mask is not None:
374
+ attn_mask = mask
375
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
376
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
377
+ else:
378
+ attn_mask = None
379
+
380
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
381
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
382
+ x = x.to(query.dtype)
383
+
384
+ # linear proj
385
+ x = attn.to_out[0](x)
386
+ # dropout
387
+ x = attn.to_out[1](x)
388
+
389
+ if mask is not None:
390
+ mask = mask.unsqueeze(-1)
391
+ x = x.masked_fill(~mask, 0.0)
392
+
393
+ return x
394
+
395
+
396
+ # Joint Attention processor for MM-DiT
397
+ # modified from diffusers/src/diffusers/models/attention_processor.py
398
+
399
+
400
+ class JointAttnProcessor:
401
+ def __init__(self):
402
+ pass
403
+
404
+ def __call__(
405
+ self,
406
+ attn: Attention,
407
+ x: float["b n d"], # noised input x # noqa: F722
408
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
409
+ mask: bool["b n"] | None = None, # noqa: F722
410
+ rope=None, # rotary position embedding for x
411
+ c_rope=None, # rotary position embedding for c
412
+ ) -> torch.FloatTensor:
413
+ residual = x
414
+
415
+ batch_size = c.shape[0]
416
+
417
+ # `sample` projections.
418
+ query = attn.to_q(x)
419
+ key = attn.to_k(x)
420
+ value = attn.to_v(x)
421
+
422
+ # `context` projections.
423
+ c_query = attn.to_q_c(c)
424
+ c_key = attn.to_k_c(c)
425
+ c_value = attn.to_v_c(c)
426
+
427
+ # apply rope for context and noised input independently
428
+ if rope is not None:
429
+ freqs, xpos_scale = rope
430
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
431
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
432
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
433
+ if c_rope is not None:
434
+ freqs, xpos_scale = c_rope
435
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
436
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
437
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
438
+
439
+ # attention
440
+ query = torch.cat([query, c_query], dim=1)
441
+ key = torch.cat([key, c_key], dim=1)
442
+ value = torch.cat([value, c_value], dim=1)
443
+
444
+ inner_dim = key.shape[-1]
445
+ head_dim = inner_dim // attn.heads
446
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
447
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
448
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
449
+
450
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
451
+ if mask is not None:
452
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
453
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
454
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
455
+ else:
456
+ attn_mask = None
457
+
458
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
459
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
460
+ x = x.to(query.dtype)
461
+
462
+ # Split the attention outputs.
463
+ x, c = (
464
+ x[:, : residual.shape[1]],
465
+ x[:, residual.shape[1] :],
466
+ )
467
+
468
+ # linear proj
469
+ x = attn.to_out[0](x)
470
+ # dropout
471
+ x = attn.to_out[1](x)
472
+ if not attn.context_pre_only:
473
+ c = attn.to_out_c(c)
474
+
475
+ if mask is not None:
476
+ mask = mask.unsqueeze(-1)
477
+ x = x.masked_fill(~mask, 0.0)
478
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
479
+
480
+ return x, c
481
+
482
+
483
+ # DiT Block
484
+
485
+
486
+ class DiTBlock(nn.Module):
487
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, use_style_prompt=False):
488
+ super().__init__()
489
+
490
+ self.attn_norm = AdaLayerNormZero(dim)
491
+ self.attn = Attention(
492
+ processor=AttnProcessor(),
493
+ dim=dim,
494
+ heads=heads,
495
+ dim_head=dim_head,
496
+ dropout=dropout,
497
+ )
498
+
499
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
500
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
501
+
502
+ self.use_style_prompt = use_style_prompt
503
+ if use_style_prompt:
504
+ #self.film = FiLMLayer(dim, dim)
505
+ self.prompt_norm = AdaLayerNormZero_Final(dim)
506
+
507
+ def forward(self, x, t, c=None, mask=None, rope=None): # x: noised input, t: time embedding
508
+ if c is not None and self.use_style_prompt:
509
+ #x = self.film(x, c)
510
+ x = self.prompt_norm(x, c)
511
+
512
+ # pre-norm & modulation for attention input
513
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
514
+
515
+ # attention
516
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
517
+
518
+ # process attention output for input x
519
+ x = x + gate_msa.unsqueeze(1) * attn_output
520
+
521
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
522
+ ff_output = self.ff(norm)
523
+ x = x + gate_mlp.unsqueeze(1) * ff_output
524
+
525
+ return x
526
+
527
+
528
+ # MMDiT Block https://arxiv.org/abs/2403.03206
529
+
530
+
531
+ class MMDiTBlock(nn.Module):
532
+ r"""
533
+ modified from diffusers/src/diffusers/models/attention.py
534
+
535
+ notes.
536
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
537
+ _x: noised input related. (right part)
538
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
539
+ """
540
+
541
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
542
+ super().__init__()
543
+
544
+ self.context_pre_only = context_pre_only
545
+
546
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
547
+ self.attn_norm_x = AdaLayerNormZero(dim)
548
+ self.attn = Attention(
549
+ processor=JointAttnProcessor(),
550
+ dim=dim,
551
+ heads=heads,
552
+ dim_head=dim_head,
553
+ dropout=dropout,
554
+ context_dim=dim,
555
+ context_pre_only=context_pre_only,
556
+ )
557
+
558
+ if not context_pre_only:
559
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
560
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
561
+ else:
562
+ self.ff_norm_c = None
563
+ self.ff_c = None
564
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
565
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
566
+
567
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
568
+ # pre-norm & modulation for attention input
569
+ if self.context_pre_only:
570
+ norm_c = self.attn_norm_c(c, t)
571
+ else:
572
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
573
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
574
+
575
+ # attention
576
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
577
+
578
+ # process attention output for context c
579
+ if self.context_pre_only:
580
+ c = None
581
+ else: # if not last layer
582
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
583
+
584
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
585
+ c_ff_output = self.ff_c(norm_c)
586
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
587
+
588
+ # process attention output for input x
589
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
590
+
591
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
592
+ x_ff_output = self.ff_x(norm_x)
593
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
594
+
595
+ return c, x
596
+
597
+
598
+ # time step conditioning embedding
599
+
600
+
601
+ class TimestepEmbedding(nn.Module):
602
+ def __init__(self, dim, freq_embed_dim=256):
603
+ super().__init__()
604
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
605
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
606
+
607
+ def forward(self, timestep: float["b"]): # noqa: F821
608
+ time_hidden = self.time_embed(timestep)
609
+ time_hidden = time_hidden.to(timestep.dtype)
610
+ time = self.time_mlp(time_hidden) # b d
611
+ return time
diffrhythm/model/trainer.py ADDED
@@ -0,0 +1,350 @@
+ from __future__ import annotations
+
+ import os
+ import gc
+ from tqdm import tqdm
+ import wandb
+
+ import torch
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR, ConstantLR
+
+ from accelerate import Accelerator
+ from accelerate.utils import DistributedDataParallelKwargs
+ from diffrhythm.dataset.custom_dataset_align2f5 import LanceDiffusionDataset
+
+ from torch.utils.data import DataLoader, DistributedSampler
+
+ from ema_pytorch import EMA
+
+ from diffrhythm.model import CFM
+ from diffrhythm.model.utils import exists, default
+
+ import time
+
+ # from apex.optimizers.fused_adam import FusedAdam
+
+ # trainer
+
+
+ class Trainer:
+     def __init__(
+         self,
+         model: CFM,
+         args,
+         epochs,
+         learning_rate,
+         # dataloader,
+         num_warmup_updates=20000,
+         save_per_updates=1000,
+         checkpoint_path=None,
+         batch_size=32,
+         batch_size_type: str = "sample",
+         max_samples=32,
+         grad_accumulation_steps=1,
+         max_grad_norm=1.0,
+         noise_scheduler: str | None = None,
+         duration_predictor: torch.nn.Module | None = None,
+         wandb_project="test_e2-tts",
+         wandb_run_name="test_run",
+         wandb_resume_id: str = None,
+         last_per_steps=None,
+         accelerate_kwargs: dict = dict(),
+         ema_kwargs: dict = dict(),
+         bnb_optimizer: bool = False,
+         reset_lr: bool = False,
+         use_style_prompt: bool = False,
+         grad_ckpt: bool = False
+     ):
+         self.args = args
+
+         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
+
+         logger = "wandb" if wandb.api.api_key else None
+         # logger = None
+         print(f"Using logger: {logger}")
+         # print("-----------1-------------")
+         import tbe.common
+         # print("-----------2-------------")
+         self.accelerator = Accelerator(
+             log_with=logger,
+             kwargs_handlers=[ddp_kwargs],
+             gradient_accumulation_steps=grad_accumulation_steps,
+             **accelerate_kwargs,
+         )
+         # print("-----------3-------------")
+
+         if logger == "wandb":
+             if exists(wandb_resume_id):
+                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
+             else:
+                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
+             self.accelerator.init_trackers(
+                 project_name=wandb_project,
+                 init_kwargs=init_kwargs,
+                 config={
+                     "epochs": epochs,
+                     "learning_rate": learning_rate,
+                     "num_warmup_updates": num_warmup_updates,
+                     "batch_size": batch_size,
+                     "batch_size_type": batch_size_type,
+                     "max_samples": max_samples,
+                     "grad_accumulation_steps": grad_accumulation_steps,
+                     "max_grad_norm": max_grad_norm,
+                     "gpus": self.accelerator.num_processes,
+                     "noise_scheduler": noise_scheduler,
+                 },
+             )
+
+         self.precision = self.accelerator.state.mixed_precision
+         self.precision = self.precision.replace("no", "fp32")
+         print(f"Training precision: {self.precision}")
+
+         self.model = model
+         # self.model = torch.compile(model)
+
+         # self.dataloader = dataloader
+
+         if self.is_main:
+             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
+
+             self.ema_model.to(self.accelerator.device)
+             if self.accelerator.state.distributed_type in ["DEEPSPEED", "FSDP"]:
+                 self.ema_model.half()
+
+         self.epochs = epochs
+         self.num_warmup_updates = num_warmup_updates
+         self.save_per_updates = save_per_updates
+         self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
+         self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
+
+         self.max_samples = max_samples
+         self.grad_accumulation_steps = grad_accumulation_steps
+         self.max_grad_norm = max_grad_norm
+
+         self.noise_scheduler = noise_scheduler
+
+         self.duration_predictor = duration_predictor
+
+         self.reset_lr = reset_lr
+
+         self.use_style_prompt = use_style_prompt
+
+         self.grad_ckpt = grad_ckpt
+
+         if bnb_optimizer:
+             import bitsandbytes as bnb
+
+             self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
+         else:
+             self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+             # self.optimizer = FusedAdam(model.parameters(), lr=learning_rate)
+
+         # self.model = torch.compile(self.model)
+         if self.accelerator.state.distributed_type == "DEEPSPEED":
+             self.accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = batch_size
+
+         self.get_dataloader()
+         self.get_scheduler()
+         # self.get_constant_scheduler()
+
+         self.model, self.optimizer, self.scheduler, self.train_dataloader = self.accelerator.prepare(self.model, self.optimizer, self.scheduler, self.train_dataloader)
+
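+     # LR schedule used by get_scheduler below: linear warmup from ~0 to the peak
+     # learning rate over num_warmup_updates * num_processes steps, then a linear
+     # decay back towards 0 for the remaining steps.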
+     def get_scheduler(self):
+         warmup_steps = (
+             self.num_warmup_updates * self.accelerator.num_processes
+         )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
+         total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
+         decay_steps = total_steps - warmup_steps
+         warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
+         decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
+         # constant_scheduler = ConstantLR(self.optimizer, factor=1, total_iters=decay_steps)
+         self.scheduler = SequentialLR(
+             self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
+         )
+
+     def get_constant_scheduler(self):
+         total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
+         self.scheduler = ConstantLR(self.optimizer, factor=1, total_iters=total_steps)
+
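+     # get_dataloader builds the Lance-backed training dataset and wraps it in a
+     # DataLoader that uses the dataset's own custom_collate_fn for batching.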
+     def get_dataloader(self):
+         prompt_path = self.args.prompt_path.split('|')
+         lrc_path = self.args.lrc_path.split('|')
+         latent_path = self.args.latent_path.split('|')
+         ldd = LanceDiffusionDataset(*LanceDiffusionDataset.init_data(self.args.dataset_path),
+                                     max_frames=self.args.max_frames, min_frames=self.args.min_frames,
+                                     align_lyrics=self.args.align_lyrics, lyrics_slice=self.args.lyrics_slice,
+                                     use_style_prompt=self.args.use_style_prompt, parse_lyrics=self.args.parse_lyrics,
+                                     lyrics_shift=self.args.lyrics_shift, downsample_rate=self.args.downsample_rate,
+                                     skip_empty_lyrics=self.args.skip_empty_lyrics, tokenizer_type=self.args.tokenizer_type, precision=self.precision,
+                                     start_time=time.time(), pure_prob=self.args.pure_prob)
+
+         # start_time = time.time()
+         self.train_dataloader = DataLoader(
+             dataset=ldd,
+             batch_size=self.args.batch_size,  # samples per batch
+             shuffle=True,  # reshuffle the data every epoch
+             num_workers=4,  # worker subprocesses used for data loading
+             pin_memory=True,  # speeds up host-to-GPU transfers
+             collate_fn=ldd.custom_collate_fn,
+             persistent_workers=True
+         )
+
+
+     @property
+     def is_main(self):
+         return self.accelerator.is_main_process
+
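+     # Checkpoints bundle model, optimizer, EMA and scheduler state plus the current
+     # step; only the main process writes them to disk.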
+     def save_checkpoint(self, step, last=False):
+         self.accelerator.wait_for_everyone()
+         if self.is_main:
+             checkpoint = dict(
+                 model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
+                 optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
+                 ema_model_state_dict=self.ema_model.state_dict(),
+                 scheduler_state_dict=self.scheduler.state_dict(),
+                 step=step,
+             )
+             if not os.path.exists(self.checkpoint_path):
+                 os.makedirs(self.checkpoint_path)
+             if last:
+                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
+                 print(f"Saved last checkpoint at step {step}")
+             else:
+                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
+
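+     # Resume logic: prefer model_last.pt, otherwise the highest-numbered checkpoint.
+     # State dicts are loaded non-strictly, keeping only entries whose shapes match
+     # the current model, so checkpoints from slightly different configurations can
+     # still be partially reused.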
+     def load_checkpoint(self):
+         if (
+             not exists(self.checkpoint_path)
+             or not os.path.exists(self.checkpoint_path)
+             or not os.listdir(self.checkpoint_path)
+         ):
+             return 0
+
+         self.accelerator.wait_for_everyone()
+         if "model_last.pt" in os.listdir(self.checkpoint_path):
+             latest_checkpoint = "model_last.pt"
+         else:
+             latest_checkpoint = sorted(
+                 [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
+                 key=lambda x: int("".join(filter(str.isdigit, x))),
+             )[-1]
+
+         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
+
+         # 1. drop ema_model entries that do not match the current EMA state dict
+         if self.is_main:
+             ema_dict = self.ema_model.state_dict()
+             ema_checkpoint_dict = checkpoint["ema_model_state_dict"]
+
+             filtered_ema_dict = {
+                 k: v for k, v in ema_checkpoint_dict.items()
+                 if k in ema_dict and ema_dict[k].shape == v.shape  # only load parameters whose shapes match
+             }
+
+             print(f"Loading {len(filtered_ema_dict)} / {len(ema_checkpoint_dict)} ema_model params")
+             self.ema_model.load_state_dict(filtered_ema_dict, strict=False)
+
+         # 2. drop model entries that do not match the current model state dict
+         model_dict = self.accelerator.unwrap_model(self.model).state_dict()
+         checkpoint_model_dict = checkpoint["model_state_dict"]
+
+         filtered_model_dict = {
+             k: v for k, v in checkpoint_model_dict.items()
+             if k in model_dict and model_dict[k].shape == v.shape  # only load parameters whose shapes match
+         }
+
+         print(f"Loading {len(filtered_model_dict)} / {len(checkpoint_model_dict)} model params")
+         self.accelerator.unwrap_model(self.model).load_state_dict(filtered_model_dict, strict=False)
+
+         # 3. restore the scheduler state and step count (optimizer state is not reloaded here)
+         if "step" in checkpoint:
+             if self.scheduler and not self.reset_lr:
+                 self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+             step = checkpoint["step"]
+         else:
+             step = 0
+
+         del checkpoint
+         gc.collect()
+         print("Checkpoint loaded at step", step)
+         return step
+
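+     # Training loop: optionally skips already-seen batches when resuming mid-epoch,
+     # accumulates gradients via accelerator.accumulate, clips the gradient norm,
+     # updates the EMA copy on the main process, and saves checkpoints periodically.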
+     def train(self, resumable_with_seed: int = None):
+         train_dataloader = self.train_dataloader
+
+         start_step = self.load_checkpoint()
+         global_step = start_step
+
+         if exists(resumable_with_seed) and resumable_with_seed > 0:  # guard: resumable_with_seed may be None
+             orig_epoch_step = len(train_dataloader)
+             skipped_epoch = int(start_step // orig_epoch_step)
+             skipped_batch = start_step % orig_epoch_step
+             skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
+         else:
+             skipped_epoch = 0
+
+         for epoch in range(skipped_epoch, self.epochs):
+             self.model.train()
+             if exists(resumable_with_seed) and resumable_with_seed > 0 and epoch == skipped_epoch:
+                 progress_bar = tqdm(
+                     skipped_dataloader,
+                     desc=f"Epoch {epoch+1}/{self.epochs}",
+                     unit="step",
+                     disable=not self.accelerator.is_local_main_process,
+                     initial=skipped_batch,
+                     total=orig_epoch_step,
+                     smoothing=0.15
+                 )
+             else:
+                 progress_bar = tqdm(
+                     train_dataloader,
+                     desc=f"Epoch {epoch+1}/{self.epochs}",
+                     unit="step",
+                     disable=not self.accelerator.is_local_main_process,
+                     smoothing=0.15
+                 )
+
+             for batch in progress_bar:
+                 with self.accelerator.accumulate(self.model):
+                     text_inputs = batch["lrc"]
+                     mel_spec = batch["latent"].permute(0, 2, 1)
+                     mel_lengths = batch["latent_lengths"]
+                     style_prompt = batch["prompt"]
+                     style_prompt_lens = batch["prompt_lengths"]
+                     start_time = batch["start_time"]
+
+                     loss, cond, pred = self.model(
+                         mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler,
+                         style_prompt=style_prompt if self.use_style_prompt else None,
+                         style_prompt_lens=style_prompt_lens if self.use_style_prompt else None,
+                         grad_ckpt=self.grad_ckpt, start_time=start_time
+                     )
+                     self.accelerator.backward(loss)
+
+                     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
+                         self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                     self.optimizer.step()
+                     self.scheduler.step()
+                     self.optimizer.zero_grad()
+
+                 if self.is_main:
+                     self.ema_model.update()
+
+                 global_step += 1
+
+                 if self.accelerator.is_local_main_process:
+                     self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+
+                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())
+
+                 if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
+                     self.save_checkpoint(global_step)
+
+                 if global_step % self.last_per_steps == 0:
+                     self.save_checkpoint(global_step, last=True)
+
+         self.save_checkpoint(global_step, last=True)
+
+         self.accelerator.end_training()
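For reference, get_scheduler above produces a linear warmup to the peak learning rate followed by a linear decay back towards zero. A tiny standalone sketch of that curve (illustrative only; the one-parameter model and the 10/90 step split are made up):

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

opt = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
warmup = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=10)
decay = LinearLR(opt, start_factor=1.0, end_factor=1e-8, total_iters=90)
sched = SequentialLR(opt, schedulers=[warmup, decay], milestones=[10])

lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(max(lrs))  # peaks near 1e-4 at the warmup/decay boundary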
diffrhythm/model/utils.py ADDED
@@ -0,0 +1,182 @@
+ from __future__ import annotations
+
+ import os
+ import random
+ from collections import defaultdict
+ from importlib.resources import files
+
+ import jieba  # added: required by convert_char_to_pinyin below
+ from pypinyin import lazy_pinyin, Style  # added: required by convert_char_to_pinyin below
+
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+
+
+ # seed everything
+
+
+ def seed_everything(seed=0):
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ # helpers
+
+
+ def exists(v):
+     return v is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ # tensor helpers
+
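+ # lens_to_mask: per-example lengths (b,) -> boolean padding mask (b, n),
+ # True on valid positions and False on padding.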
+
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
+     if not exists(length):
+         length = t.amax()
+
+     seq = torch.arange(length, device=t.device)
+     return seq[None, :] < t[:, None]
+
+
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
+     max_seq_len = 2048
+     seq = torch.arange(max_seq_len, device=start.device).long()
+     start_mask = seq[None, :] >= start[:, None]
+     end_mask = seq[None, :] < end[:, None]
+     return start_mask & end_mask
+
+
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
+     lengths = (frac_lengths * seq_len).long()
+     max_start = seq_len - lengths
+
+     rand = torch.rand_like(frac_lengths)
+     start = (max_start * rand).long().clamp(min=0)
+     end = start + lengths
+
+     return mask_from_start_end_indices(seq_len, start, end)
+
+
+ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
+     if not exists(mask):
+         return t.mean(dim=1)
+
+     t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
+     num = t.sum(dim=1)
+     den = mask.float().sum(dim=1)
+
+     return num / den.clamp(min=1.0)
+
+
+ # simple utf-8 tokenizer, since paper went character based
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
+     list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
+     text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
+     return text
+
+
+ # char tokenizer, based on custom dataset's extracted .txt file
+ def list_str_to_idx(
+     text: list[str] | list[list[str]],
+     vocab_char_map: dict[str, int],  # {char: idx}
+     padding_value=-1,
+ ) -> int["b nt"]:  # noqa: F722
+     list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
+     text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
+     return text
+
+
+ # Get tokenizer
+
+
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
+     """
+     tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
+                 - "char" for char-wise tokenizer, need .txt vocab_file
+                 - "byte" for utf-8 tokenizer
+                 - "custom" if you're directly passing in a path to the vocab.txt you want to use
+     vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
+                 - if use "char", derived from unfiltered character & symbol counts of custom dataset
+                 - if use "byte", set to 256 (unicode byte range)
+     """
+     if tokenizer in ["pinyin", "char"]:
+         tokenizer_path = os.path.join(files("diffrhythm").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
+         with open(tokenizer_path, "r", encoding="utf-8") as f:
+             vocab_char_map = {}
+             for i, char in enumerate(f):
+                 vocab_char_map[char[:-1]] = i
+         vocab_size = len(vocab_char_map)
+         assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
+
+     elif tokenizer == "byte":
+         vocab_char_map = None
+         vocab_size = 256
+
+     elif tokenizer == "custom":
+         with open(dataset_name, "r", encoding="utf-8") as f:
+             vocab_char_map = {}
+             for i, char in enumerate(f):
+                 vocab_char_map[char[:-1]] = i
+         vocab_size = len(vocab_char_map)
+
+     return vocab_char_map, vocab_size
+
+
+ # convert char to pinyin
+
+
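+ # Segments text with jieba and classifies each segment by comparing its UTF-8 byte
+ # length to its character count (pure ASCII, pure Chinese at 3 bytes/char, or mixed);
+ # Chinese characters are converted to tone-numbered pinyin (Style.TONE3).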
+ def convert_char_to_pinyin(text_list, polyphone=True):
+     final_text_list = []
+     god_knows_why_en_testset_contains_zh_quote = str.maketrans(
+         {"“": '"', "”": '"', "‘": "'", "’": "'"}
+     )  # in case librispeech (orig no-pc) test-clean
+     custom_trans = str.maketrans({";": ","})  # add custom trans here, to address oov
+     for text in text_list:
+         char_list = []
+         text = text.translate(god_knows_why_en_testset_contains_zh_quote)
+         text = text.translate(custom_trans)
+         for seg in jieba.cut(text):
+             seg_byte_len = len(bytes(seg, "UTF-8"))
+             if seg_byte_len == len(seg):  # if pure alphabets and symbols
+                 if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
+                     char_list.append(" ")
+                 char_list.extend(seg)
+             elif polyphone and seg_byte_len == 3 * len(seg):  # if pure chinese characters
+                 seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
+                 for c in seg:
+                     if c not in "。,、;:?!《》【】—…":
+                         char_list.append(" ")
+                     char_list.append(c)
+             else:  # if mixed chinese characters, alphabets and symbols
+                 for c in seg:
+                     if ord(c) < 256:
+                         char_list.extend(c)
+                     else:
+                         if c not in "。,、;:?!《》【】—…":
+                             char_list.append(" ")
+                             char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
+                         else:  # if is zh punc
+                             char_list.append(c)
+         final_text_list.append(char_list)
+
+     return final_text_list
+
+
+ # filter func for dirty data with many repetitions
+
+
+ def repetition_found(text, length=2, tolerance=10):
+     pattern_count = defaultdict(int)
+     for i in range(len(text) - length + 1):
+         pattern = text[i : i + length]
+         pattern_count[pattern] += 1
+     for pattern, count in pattern_count.items():
+         if count > tolerance:
+             return True
+     return False
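A quick usage sketch for the mask helpers above (illustrative only, assuming the diffrhythm package from this commit is importable):

import torch
from diffrhythm.model.utils import lens_to_mask, mask_from_frac_lengths

lens = torch.tensor([2, 4])
print(lens_to_mask(lens, 4))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

# a random contiguous span covering ~50% of each sequence; note the helper
# always builds the mask over a fixed maximum of 2048 positions
print(mask_from_frac_lengths(lens, torch.tensor([0.5, 0.5])).shape)  # torch.Size([2, 2048])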
prompt/gift_of_the_world.wav ADDED
Binary file (960 kB).

prompt/little_happiness.wav ADDED
Binary file (960 kB).

prompt/little_talks.wav ADDED
Binary file (960 kB).

prompt/ltwyl.wav ADDED
Binary file (882 kB).

prompt/most_beautiful_expectation.wav ADDED
Binary file (960 kB).