lengyue233 committed on
Commit
0a3525d
·
verified ·
1 Parent(s): c8f7e84

Init hf space integration

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +2 -0
  2. app.py +317 -0
  3. fish_speech/callbacks/__init__.py +3 -0
  4. fish_speech/callbacks/grad_norm.py +113 -0
  5. fish_speech/configs/base.yaml +86 -0
  6. fish_speech/configs/model/dual_ar_2_codebook_large.yaml +9 -0
  7. fish_speech/configs/model/dual_ar_2_codebook_medium.yaml +9 -0
  8. fish_speech/configs/model/dual_ar_2_codebook_small.yaml +13 -0
  9. fish_speech/configs/model/naive_2_codebook_small.yaml +12 -0
  10. fish_speech/configs/text2semantic_finetune.yaml +79 -0
  11. fish_speech/configs/text2semantic_finetune_lora.yaml +13 -0
  12. fish_speech/configs/text2semantic_pretrain.yaml +74 -0
  13. fish_speech/configs/text2semantic_sft.yaml +87 -0
  14. fish_speech/configs/vqgan_finetune.yaml +135 -0
  15. fish_speech/configs/vqgan_pretrain.yaml +139 -0
  16. fish_speech/datasets/protos/text-data.proto +24 -0
  17. fish_speech/datasets/protos/text_data_pb2.py +33 -0
  18. fish_speech/datasets/protos/text_data_stream.py +36 -0
  19. fish_speech/datasets/text.py +661 -0
  20. fish_speech/datasets/vqgan.py +145 -0
  21. fish_speech/models/text2semantic/__init__.py +3 -0
  22. fish_speech/models/text2semantic/lit_module.py +344 -0
  23. fish_speech/models/text2semantic/llama.py +595 -0
  24. fish_speech/models/vqgan/__init__.py +3 -0
  25. fish_speech/models/vqgan/lit_module.py +442 -0
  26. fish_speech/models/vqgan/modules/discriminator.py +44 -0
  27. fish_speech/models/vqgan/modules/firefly.py +538 -0
  28. fish_speech/models/vqgan/modules/fsq.py +139 -0
  29. fish_speech/models/vqgan/modules/reference.py +113 -0
  30. fish_speech/models/vqgan/modules/wavenet.py +225 -0
  31. fish_speech/models/vqgan/spectrogram.py +122 -0
  32. fish_speech/models/vqgan/utils.py +94 -0
  33. fish_speech/scheduler.py +22 -0
  34. fish_speech/text/__init__.py +3 -0
  35. fish_speech/text/clean.py +73 -0
  36. fish_speech/train.py +135 -0
  37. fish_speech/utils/__init__.py +21 -0
  38. fish_speech/utils/braceexpand.py +217 -0
  39. fish_speech/utils/file.py +119 -0
  40. fish_speech/utils/instantiators.py +50 -0
  41. fish_speech/utils/logger.py +55 -0
  42. fish_speech/utils/logging_utils.py +48 -0
  43. fish_speech/utils/rich_utils.py +96 -0
  44. fish_speech/utils/utils.py +114 -0
  45. packages.txt +10 -0
  46. pyrightconfig.json +6 -0
  47. requirements.txt +24 -0
  48. setup.sh +18 -0
  49. tools/extract_model.py +21 -0
  50. tools/llama/build_dataset.py +165 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ checkpoints
app.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import subprocess as sp

# Run the bootstrap script before the remaining imports are evaluated.
# It is invoked through bash with an argument list: a bare "setup.sh" passed
# to a shell is looked up on PATH (which normally excludes the working
# directory) and would also require the executable bit to be set.
# NOTE(review): assumes the process is started from the directory that
# contains setup.sh — confirm for the deployment environment.
sp.check_call(["bash", "setup.sh"])
3
+
4
+ import html
5
+ import os
6
+ from argparse import ArgumentParser
7
+ from io import BytesIO
8
+ from pathlib import Path
9
+
10
+ import gradio as gr
11
+ import librosa
12
+ import spaces
13
+ import torch
14
+ from loguru import logger
15
+ from torchaudio import functional as AF
16
+ from transformers import AutoTokenizer
17
+
18
+ from tools.llama.generate import generate_long
19
+ from tools.llama.generate import load_model as load_llama_model
20
+ from tools.vqgan.inference import load_model as load_vqgan_model
21
+
22
+ # Make einx happy
23
+ os.environ["EINX_FILTER_TRACEBACK"] = "false"
24
+
25
+
26
+ HEADER_MD = """# Fish Speech
27
+
28
+ A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
29
+ 由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
30
+
31
+ You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
32
+ 你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
33
+
34
+ Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.
35
+ 相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.
36
+
37
+ We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
38
+ 我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
39
+ """
40
+
41
+ TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
42
+
43
+
44
def build_html_error_message(error):
    """
    Render an error as a red, bold HTML snippet.

    Args:
        error: the message to display. Non-string values (for example an
            exception instance) are converted with ``str()`` first.

    Returns:
        str: a ``<div>`` element whose text content is HTML-escaped.
    """

    # Escape so that characters such as "<" and "&" are displayed literally
    # instead of being interpreted as markup.
    escaped = html.escape(str(error))

    return f"""
    <div style="color: red; font-weight: bold;">
        {escaped}
    </div>
    """
50
+
51
+
52
@spaces.GPU
def inference(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_k,
    top_p,
    repetition_penalty,
    temperature,
    speaker=None,
):
    """
    Synthesize speech for ``text`` and return it in the shape the UI expects.

    Relies on the module-level ``args``, ``llama_model``, ``llama_tokenizer``,
    ``decode_one_token`` and ``vqgan_model`` objects that the ``__main__``
    block creates before the app is launched.

    Args:
        text: text to synthesize.
        enable_reference_audio: whether a reference prompt should be used.
        reference_audio: path of the reference audio file, or ``None``.
        reference_text: transcript that accompanies the reference audio.
        max_new_tokens: forwarded to ``generate_long``.
        chunk_length: forwarded to ``generate_long``; iterative prompting is
            enabled only when it is greater than 0.
        top_k: values <= 0 are forwarded as ``None``.
        top_p: forwarded to ``generate_long``.
        repetition_penalty: forwarded to ``generate_long``.
        temperature: forwarded to ``generate_long``.
        speaker: optional speaker name; falsy values are forwarded as ``None``.

    Returns:
        tuple: ``((sampling_rate, audio), None)`` on success, where ``audio``
        is a NumPy array, or ``(None, error_html)`` when the input is rejected.
    """
    # A reference prompt is only usable when it is enabled AND a file was
    # actually provided; otherwise generation runs without a prompt instead of
    # failing on an undefined ``prompt_tokens``.
    use_reference = bool(enable_reference_audio) and reference_audio is not None

    if use_reference and len(reference_text) > 100:
        return None, build_html_error_message(
            "Ref text is too long, please keep it under 100 characters."
        )

    # A non-positive --max-gradio-length disables the input length limit.
    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
        return None, build_html_error_message(
            f"Text is too long, please keep it under {args.max_gradio_length} characters."
        )

    # Parse reference audio aka prompt
    prompt_tokens = None
    prompt_text = None
    if use_reference:
        # Load as mono at the VQ-GAN sampling rate and add two leading
        # singleton dimensions before encoding.
        reference_audio_content, _ = librosa.load(
            reference_audio, sr=vqgan_model.sampling_rate, mono=True
        )
        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
            None, None, :
        ]

        logger.info(
            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
        )
        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
        prompt_text = reference_text

    # LLAMA Inference
    result = generate_long(
        model=llama_model,
        tokenizer=llama_tokenizer,
        device=vqgan_model.device,
        decode_one_token=decode_one_token,
        max_new_tokens=max_new_tokens,
        text=text,
        top_k=int(top_k) if top_k > 0 else None,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=args.compile,
        iterative_prompt=chunk_length > 0,
        chunk_length=chunk_length,
        max_length=args.max_length,
        speaker=speaker if speaker else None,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_text,
    )

    # Only the first item produced by generate_long is used.
    codes = next(result)

    # VQGAN Inference
    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
    fake_audios = vqgan_model.decode(
        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
    )[0, 0]

    # Convert to float32 on the CPU so the result can be returned as NumPy.
    fake_audios = fake_audios.float().cpu().numpy()

    return (vqgan_model.sampling_rate, fake_audios), None
124
+
125
+
126
def build_app():
    """
    Build the Gradio UI and wire the generate button to ``inference``.

    Returns:
        gr.Blocks: the assembled app; the caller is responsible for launching it.
    """
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label="Input Text / 输入文本",
                    placeholder=TEXTBOX_PLACEHOLDER,
                    lines=15,
                )

                with gr.Row():
                    with gr.Tab(label="Advanced Config / 高级参数"):
                        chunk_length = gr.Slider(
                            label="Iterative Prompt Length, 0 means off / 迭代提示长度,0 表示关闭",
                            minimum=0,
                            maximum=100,
                            value=30,
                            step=8,
                        )

                        # The slider range is 128-512, so 0 ("no limit") can
                        # not be selected here; the label therefore does not
                        # advertise it.
                        max_new_tokens = gr.Slider(
                            label="Maximum tokens per batch / 每批最大令牌数",
                            minimum=128,
                            maximum=512,
                            value=512,
                            step=8,
                        )

                        top_k = gr.Slider(
                            label="Top-K", minimum=0, maximum=5, value=0, step=1
                        )

                        top_p = gr.Slider(
                            label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
                        )

                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            minimum=0,
                            maximum=2,
                            value=1.5,
                            step=0.01,
                        )

                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0,
                            maximum=2,
                            value=0.7,
                            step=0.01,
                        )

                        # NOTE: no speaker input is exposed in the UI, so
                        # inference() runs with its default speaker=None.

                    with gr.Tab(label="Reference Audio / 参考音频"):
                        gr.Markdown(
                            "5 to 10 seconds of reference audio, useful for specifying speaker. \n5 到 10 秒的参考音频,适用于指定音色。"
                        )

                        enable_reference_audio = gr.Checkbox(
                            label="Enable Reference Audio / 启用参考音频",
                        )
                        reference_audio = gr.Audio(
                            label="Reference Audio / 参考音频",
                            value="docs/assets/audios/0_input.wav",
                            type="filepath",
                        )
                        reference_text = gr.Textbox(
                            label="Reference Text / 参考文本",
                            placeholder="参考文本",
                            lines=1,
                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                        )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(label="Error Message / 错误信息")
                with gr.Row():
                    audio = gr.Audio(label="Generated Audio / 音频", type="numpy")

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 Generate / 合成", variant="primary"
                        )

        # Submit: the input order must match the positional parameters of
        # inference(); its two return values fill the audio and error outputs.
        generate.click(
            inference,
            [
                text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_k,
                top_p,
                repetition_penalty,
                temperature,
            ],
            [audio, error],
        )

    return app
246
+
247
+
248
def parse_args(argv=None):
    """
    Parse the command-line options of the web UI.

    Args:
        argv (list[str], optional): arguments to parse. When omitted,
            ``sys.argv[1:]`` is used, which is the original behavior.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default=Path("checkpoints/text2semantic-medium-v1-2k.pth"),
        help="Path of the Llama (text2semantic) checkpoint.",
    )
    parser.add_argument(
        "--llama-config-name",
        type=str,
        default="dual_ar_2_codebook_medium",
        help="Name of the Llama model config.",
    )
    parser.add_argument(
        "--vqgan-checkpoint-path",
        type=Path,
        default=Path("checkpoints/vq-gan-group-fsq-2x1024.pth"),
        help="Path of the VQ-GAN checkpoint.",
    )
    parser.add_argument(
        "--vqgan-config-name",
        type=str,
        default="vqgan_pretrain",
        help="Name of the VQ-GAN config.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="fishaudio/fish-speech-1",
        help="Tokenizer name or path passed to AutoTokenizer.from_pretrained.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device the models are loaded on."
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use float16 instead of bfloat16 for the Llama model.",
    )
    parser.add_argument("--max-length", type=int, default=2048)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument(
        "--max-gradio-length",
        type=int,
        default=1024,
        help="Maximum length of the input text in characters; <= 0 disables the limit.",
    )

    return parser.parse_args(argv)
272
+
273
+
274
if __name__ == "__main__":
    # Script entry point: load both models, run one warm-up synthesis, then
    # start the web UI. The names assigned here (args, llama_model,
    # decode_one_token, llama_tokenizer, vqgan_model) are the module globals
    # that inference() reads.
    args = parse_args()

    # float16 when --half is given, bfloat16 otherwise
    if args.half:
        args.precision = torch.half
    else:
        args.precision = torch.bfloat16

    logger.info("Loading Llama model...")
    llama_model, decode_one_token = load_llama_model(
        config_name=args.llama_config_name,
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        max_length=args.max_length,
        compile=args.compile,
    )
    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info("Llama model loaded, loading VQ-GAN model...")

    vqgan_model = load_vqgan_model(
        config_name=args.vqgan_config_name,
        checkpoint_path=args.vqgan_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # One throw-away synthesis without reference audio: it verifies that the
    # models work end to end and absorbs the first-call latency.
    inference(
        text="Hello, world!",
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
        max_new_tokens=0,
        chunk_length=0,
        top_k=0,  # 0 means no limit
        top_p=0.7,
        repetition_penalty=1.5,
        temperature=0.7,
        speaker=None,
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=False)
fish_speech/callbacks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .grad_norm import GradNormMonitor
2
+
3
+ __all__ = ["GradNormMonitor"]
fish_speech/callbacks/grad_norm.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import lightning.pytorch as pl
4
+ import torch
5
+ from lightning import LightningModule, Trainer
6
+ from lightning.pytorch.callbacks import Callback
7
+ from torch import Tensor, nn
8
+ from torch.utils._foreach_utils import (
9
+ _group_tensors_by_device_and_dtype,
10
+ _has_foreach_support,
11
+ )
12
+
13
+
14
+ @torch.no_grad()
15
+ def grad_norm(
16
+ parameters: Union[Tensor, list[Tensor]],
17
+ norm_type: float = 2.0,
18
+ ) -> float:
19
+ """
20
+ Returns the norm of the gradients of the given parameters.
21
+
22
+ Args:
23
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
24
+ single Tensor that will have gradients normalized
25
+ norm_type (float): type of the used p-norm.
26
+
27
+ Returns:
28
+ Total norm of the parameter gradients (viewed as a single vector).
29
+ """ # noqa: E501
30
+
31
+ if isinstance(parameters, Tensor):
32
+ parameters = [parameters]
33
+
34
+ grads = [p.grad for p in parameters if p.grad is not None]
35
+ if len(grads) == 0:
36
+ return None
37
+
38
+ first_device = grads[0].device
39
+ grouped_grads: dict[
40
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
41
+ ] = _group_tensors_by_device_and_dtype(
42
+ [[g.detach() for g in grads]]
43
+ ) # type: ignore[assignment]
44
+
45
+ norms = []
46
+ for (device, _), ([grads], _) in grouped_grads.items():
47
+ if _has_foreach_support(grads, device=device):
48
+ norms.extend(torch._foreach_norm(grads, norm_type))
49
+ else:
50
+ norms.extend([torch.norm(g, norm_type) for g in grads])
51
+
52
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
53
+
54
+
55
class GradNormMonitor(Callback):
    """
    Lightning callback that logs gradient norms after every backward pass.
    """

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" or "epoch".
            sub_module (str or list[str], optional): attribute name(s) on the
                LightningModule whose gradient norms are logged separately.
                When omitted, a single norm over the whole module is logged.
        """
        super().__init__()

        self.sub_module = sub_module
        self.logging_interval = logging_interval
        self.norm_type = norm_type

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """
        Logs the gradient norm of the model, or of each configured sub-module.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule
        """

        if self.sub_module is None:
            # Whole-model norm, logged under "train/grad_norm".
            self.log_sub_module_grad_norm(model, model, "")
            return

        # A single name is treated like a one-element list.
        if isinstance(self.sub_module, str):
            names = [self.sub_module]
        else:
            names = self.sub_module

        for name in names:
            # Logged under "train/<name>/grad_norm".
            self.log_sub_module_grad_norm(model, getattr(model, name), f"/{name}")

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        """
        Logs the gradient norm of ``model`` through ``lightning_model``.

        Args:
            lightning_model (LightningModule): module whose ``log`` is called
            model (nn.Module): module whose parameter gradients are measured
            path (str): infix for the metric key ``train{path}/grad_norm``
        """
        total = grad_norm(model.parameters(), self.norm_type)

        # Nothing is logged when no parameter has a gradient.
        if total is not None:
            per_step = self.logging_interval == "step"
            lightning_model.log(
                f"train{path}/grad_norm",
                total,
                on_step=per_step,
                on_epoch=not per_step,
            )
fish_speech/configs/base.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base configuration for training a model
2
+ paths:
3
+ run_dir: results/${project}
4
+ ckpt_dir: ${paths.run_dir}/checkpoints
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${paths.run_dir}
9
+
10
+ # Lightning Trainer
11
+ trainer:
12
+ _target_: lightning.pytorch.trainer.Trainer
13
+
14
+ default_root_dir: ${paths.run_dir}
15
+ accelerator: gpu
16
+ num_nodes: 1
17
+ devices: auto
18
+ strategy:
19
+ _target_: lightning.pytorch.strategies.DDPStrategy
20
+
21
+ precision: bf16-mixed
22
+
23
+ # disable validation by epoch end
24
+ check_val_every_n_epoch: null
25
+ val_check_interval: 5000
26
+ max_steps: 100_000
27
+
28
+ # Use torch.backends.cudnn.benchmark to speed up training
29
+ benchmark: true
30
+
31
+ # Callbacks
32
+ callbacks:
33
+ model_checkpoint:
34
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
35
+ dirpath: ${paths.ckpt_dir}
36
+ filename: "step_{step:09d}"
37
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
38
+ save_top_k: 5 # save 5 latest checkpoints
39
+ monitor: step # use step to monitor checkpoints
40
+ mode: max # save the latest checkpoint with the highest global_step
41
+ every_n_epochs: null # don't save checkpoints by epoch end
42
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
43
+ auto_insert_metric_name: false
44
+
45
+ model_summary:
46
+ _target_: lightning.pytorch.callbacks.ModelSummary
47
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
48
+
49
+ learning_rate_monitor:
50
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
51
+ logging_interval: step
52
+ log_momentum: false
53
+
54
+ grad_norm_monitor:
55
+ _target_: fish_speech.callbacks.GradNormMonitor
56
+ norm_type: 2
57
+ logging_interval: step
58
+
59
+ # Logger
60
+ logger:
61
+ tensorboard:
62
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
63
+ save_dir: "${paths.run_dir}/tensorboard/"
64
+ name: null
65
+ log_graph: false
66
+ default_hp_metric: true
67
+ prefix: ""
68
+
69
+ # wandb:
70
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
71
+ # # name: "" # name of the run (normally generated by wandb)
72
+ # save_dir: "${paths.run_dir}"
73
+ # offline: False
74
+ # id: null # pass correct id to resume experiment!
75
+ # anonymous: null # enable anonymous logging
76
+ # project: "fish-speech"
77
+ # log_model: False # upload lightning ckpts
78
+ # prefix: "" # a string to put at the beginning of metric keys
79
+ # # entity: "" # set to name of your wandb team
80
+ # group: ""
81
+ # tags: ["vq", "hq", "finetune"]
82
+ # job_type: ""
83
+
84
+ # Loop
85
+ train: true
86
+ test: false
fish_speech/configs/model/dual_ar_2_codebook_large.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - dual_ar_2_codebook_small
3
+ - _self_
4
+
5
+ config:
6
+ n_layer: 30
7
+ n_fast_layer: 6
8
+ n_head: 24
9
+ dim: 1536
fish_speech/configs/model/dual_ar_2_codebook_medium.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - dual_ar_2_codebook_small
3
+ - _self_
4
+
5
+ config:
6
+ n_layer: 24
7
+ n_fast_layer: 6
8
+ n_head: 16
9
+ dim: 1024
fish_speech/configs/model/dual_ar_2_codebook_small.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: fish_speech.models.text2semantic.llama.DualARTransformer
2
+ config:
3
+ _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
4
+ max_seq_len: ${max_length}
5
+ vocab_size: 264 # pad 262 to 8x
6
+ n_layer: 12
7
+ n_fast_layer: 4
8
+ n_head: 12
9
+ dim: 768
10
+ rope_base: 10000
11
+ norm_eps: 1e-5
12
+ num_codebooks: 2 # input/output codebook size
13
+ codebook_size: 1032 # codebook size 1024 + 2 special tokens
fish_speech/configs/model/naive_2_codebook_small.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: fish_speech.models.text2semantic.llama.NaiveTransformer
2
+ config:
3
+ _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
4
+ max_seq_len: ${max_length}
5
+ vocab_size: 36408
6
+ n_layer: 12
7
+ n_head: 12
8
+ dim: 768
9
+ rope_base: 10000
10
+ norm_eps: 1e-5
11
+ num_codebooks: 2 # input/output codebook size
12
+ codebook_size: 1032 # codebook size 1024 + 2 special tokens
fish_speech/configs/text2semantic_finetune.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - model@model.model: dual_ar_2_codebook_small
4
+ - _self_
5
+
6
+ project: text2semantic_finetune_dual_ar
7
+ max_length: 2048
8
+ ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth
9
+ resume_weights_only: true
10
+
11
+ # Lightning Trainer
12
+ trainer:
13
+ accumulate_grad_batches: 1
14
+ gradient_clip_val: 1.0
15
+ gradient_clip_algorithm: 'norm'
16
+ max_steps: 1000
17
+ precision: bf16-true
18
+ limit_val_batches: 10
19
+ val_check_interval: 100
20
+
21
+ # Dataset Configuration
22
+ tokenizer:
23
+ _target_: transformers.AutoTokenizer.from_pretrained
24
+ pretrained_model_name_or_path: fishaudio/fish-speech-1
25
+
26
+ # Dataset Configuration
27
+ train_dataset:
28
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
29
+ proto_files:
30
+ - data/protos
31
+ tokenizer: ${tokenizer}
32
+ max_length: ${max_length}
33
+ num_codebooks: ${model.model.config.num_codebooks}
34
+ use_speaker: false
35
+
36
+ val_dataset:
37
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
38
+ proto_files:
39
+ - data/protos
40
+ tokenizer: ${tokenizer}
41
+ max_length: ${max_length}
42
+ num_codebooks: ${model.model.config.num_codebooks}
43
+ use_speaker: false
44
+
45
+ data:
46
+ _target_: fish_speech.datasets.text.TextDataModule
47
+ train_dataset: ${train_dataset}
48
+ val_dataset: ${val_dataset}
49
+ num_workers: 4
50
+ batch_size: 8
51
+ tokenizer: ${tokenizer}
52
+ max_length: ${max_length}
53
+
54
+ # Model Configuration
55
+ model:
56
+ _target_: fish_speech.models.text2semantic.TextToSemantic
57
+ model: {}
58
+
59
+ optimizer:
60
+ _target_: torch.optim.AdamW
61
+ _partial_: true
62
+ lr: 1e-5
63
+ weight_decay: 0
64
+ betas: [0.9, 0.95]
65
+ eps: 1e-5
66
+
67
+ lr_scheduler:
68
+ _target_: torch.optim.lr_scheduler.LambdaLR
69
+ _partial_: true
70
+ lr_lambda:
71
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
72
+ _partial_: true
73
+ num_warmup_steps: 100
74
+ num_training_steps: ${trainer.max_steps}
75
+
76
+ # Callbacks
77
+ callbacks:
78
+ model_checkpoint:
79
+ every_n_train_steps: 100
fish_speech/configs/text2semantic_finetune_lora.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - text2semantic_finetune
3
+ - _self_
4
+
5
+ project: text2semantic_finetune_dual_ar_lora
6
+
7
+ # Model Configuration
8
+ model:
9
+ save_lora_only: true
10
+ lora_config:
11
+ _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
12
+ r: 8
13
+ lora_alpha: 16
fish_speech/configs/text2semantic_pretrain.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - model@model.model: dual_ar_2_codebook_small
4
+ - _self_
5
+
6
+ project: text2semantic_pretrain_dual_ar_debug
7
+ max_length: 2048
8
+
9
+ # Lightning Trainer
10
+ trainer:
11
+ accumulate_grad_batches: 1
12
+ gradient_clip_val: 1.0
13
+ gradient_clip_algorithm: 'norm'
14
+ max_steps: 1_000_000
15
+ precision: bf16-true
16
+ limit_val_batches: 10
17
+
18
+ # Dataset Configuration
19
+ tokenizer:
20
+ _target_: transformers.AutoTokenizer.from_pretrained
21
+ pretrained_model_name_or_path: fishaudio/fish-speech-1
22
+
23
+ # Dataset Configuration
24
+ train_dataset:
25
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
26
+ proto_files:
27
+ - data/protos/train
28
+ tokenizer: ${tokenizer}
29
+ max_length: ${max_length}
30
+ num_codebooks: ${model.model.config.num_codebooks}
31
+ use_speaker: false
32
+ interactive_prob: 0.5
33
+
34
+ val_dataset:
35
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
36
+ proto_files:
37
+ - data/protos/test
38
+ tokenizer: ${tokenizer}
39
+ max_length: ${max_length}
40
+ num_codebooks: ${model.model.config.num_codebooks}
41
+ use_speaker: false
42
+ interactive_prob: 0.5
43
+
44
+ data:
45
+ _target_: fish_speech.datasets.text.TextDataModule
46
+ train_dataset: ${train_dataset}
47
+ val_dataset: ${val_dataset}
48
+ num_workers: 4
49
+ batch_size: 8
50
+ tokenizer: ${tokenizer}
51
+ max_length: ${max_length}
52
+
53
+ # Model Configuration
54
+ model:
55
+ _target_: fish_speech.models.text2semantic.TextToSemantic
56
+ model: {}
57
+
58
+ optimizer:
59
+ _target_: torch.optim.AdamW
60
+ _partial_: true
61
+ lr: 3e-4
62
+ weight_decay: 0.01
63
+ betas: [0.9, 0.95]
64
+ eps: 1e-5
65
+
66
+ lr_scheduler:
67
+ _target_: torch.optim.lr_scheduler.LambdaLR
68
+ _partial_: true
69
+ lr_lambda:
70
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
71
+ _partial_: true
72
+ num_warmup_steps: 2000
73
+ num_training_steps: ${trainer.max_steps}
74
+ final_lr_ratio: 0.1
fish_speech/configs/text2semantic_sft.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - model@model.model: dual_ar_8_codebook_small
4
+ - _self_
5
+
6
+ project: text2semantic_sft_medium_dual_ar
7
+ max_length: 4096
8
+ ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
9
+ resume_weights_only: true
10
+
11
+ # Lightning Trainer
12
+ trainer:
13
+ accumulate_grad_batches: 1
14
+ gradient_clip_val: 1.0
15
+ gradient_clip_algorithm: 'norm'
16
+ max_steps: 10_000
17
+ precision: bf16-true
18
+ limit_val_batches: 10
19
+ val_check_interval: 500
20
+
21
+ # Dataset Configuration
22
+ tokenizer:
23
+ _target_: transformers.AutoTokenizer.from_pretrained
24
+ pretrained_model_name_or_path: fishaudio/speech-lm-v1
25
+
26
+ # Dataset Configuration
27
+ train_dataset:
28
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
29
+ use_data_server: false
30
+ proto_files:
31
+ - data/protos/sft/train_Genshin.protos
32
+ - data/protos/sft/sft.protos
33
+ tokenizer: ${tokenizer}
34
+ max_length: ${max_length}
35
+ num_codebooks: ${model.model.config.num_codebooks}
36
+ use_speaker: false
37
+ phones_prob: 0.5
38
+ interactive_prob: 0.5
39
+
40
+ val_dataset:
41
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
42
+ use_data_server: false
43
+ proto_files:
44
+ - data/protos/sft/val_Genshin.protos
45
+ tokenizer: ${tokenizer}
46
+ max_length: ${max_length}
47
+ num_codebooks: ${model.model.config.num_codebooks}
48
+ use_speaker: false
49
+ phones_prob: 0.5
50
+ interactive_prob: 0.5
51
+
52
+ data:
53
+ _target_: fish_speech.datasets.text.TextDataModule
54
+ train_dataset: ${train_dataset}
55
+ val_dataset: ${val_dataset}
56
+ num_workers: 4
57
+ batch_size: 8
58
+ tokenizer: ${tokenizer}
59
+ max_length: ${max_length}
60
+
61
+ # Model Configuration
62
+ model:
63
+ _target_: fish_speech.models.text2semantic.TextToSemantic
64
+ model: {}
65
+
66
+ optimizer:
67
+ _target_: torch.optim.AdamW
68
+ _partial_: true
69
+ lr: 4e-5
70
+ weight_decay: 0
71
+ betas: [0.9, 0.95]
72
+ eps: 1e-5
73
+
74
+ lr_scheduler:
75
+ _target_: torch.optim.lr_scheduler.LambdaLR
76
+ _partial_: true
77
+ lr_lambda:
78
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
79
+ _partial_: true
80
+ num_warmup_steps: 100
81
+ num_training_steps: ${trainer.max_steps}
82
+ final_lr_ratio: 0
83
+
84
+ callbacks:
85
+ model_checkpoint:
86
+ every_n_train_steps: 1000
87
+ save_top_k: 10
fish_speech/configs/vqgan_finetune.yaml ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ project: vq-gan-finetune
6
+ ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
7
+ resume_weights_only: true
8
+
9
+ # Lightning Trainer
10
+ trainer:
11
+ accelerator: gpu
12
+ devices: auto
13
+ precision: bf16-mixed
14
+ max_steps: 100_000
15
+ val_check_interval: 5000
16
+ strategy: ddp_find_unused_parameters_true
17
+
18
+ sample_rate: 44100
19
+ hop_length: 512
20
+ num_mels: 128
21
+ n_fft: 2048
22
+ win_length: 2048
23
+ freeze_encoder: true
24
+
25
+ # Dataset Configuration
26
+ train_dataset:
27
+ _target_: fish_speech.datasets.vqgan.VQGANDataset
28
+ filelist: data/filelist.train.txt
29
+ sample_rate: ${sample_rate}
30
+ hop_length: ${hop_length}
31
+ slice_frames: 512
32
+
33
+ val_dataset:
34
+ _target_: fish_speech.datasets.vqgan.VQGANDataset
35
+ filelist: data/filelist.val.txt
36
+ sample_rate: ${sample_rate}
37
+ hop_length: ${hop_length}
38
+
39
+ data:
40
+ _target_: fish_speech.datasets.vqgan.VQGANDataModule
41
+ train_dataset: ${train_dataset}
42
+ val_dataset: ${val_dataset}
43
+ num_workers: 4
44
+ batch_size: 16
45
+ val_batch_size: 16
46
+
47
+ # Model Configuration
48
+ model:
49
+ _target_: fish_speech.models.vqgan.VQGAN
50
+
51
+ sampling_rate: ${sample_rate}
52
+ weight_adv: 0.2
53
+ weight_vq: 1.0
54
+ weight_mel: 1.0
55
+ freeze_encoder: false
56
+
57
+ encoder:
58
+ _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
59
+ input_channels: ${num_mels}
60
+ residual_channels: 768
61
+ residual_layers: 20
62
+ dilation_cycle: 4
63
+
64
+ quantizer:
65
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
66
+ input_dim: 768
67
+ n_codebooks: 1
68
+ n_groups: 2
69
+ levels: [8, 5, 5, 5]
70
+
71
+ decoder:
72
+ _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
73
+ output_channels: ${num_mels}
74
+ residual_channels: 768
75
+ residual_layers: 20
76
+ dilation_cycle: 4
77
+ condition_channels: 768
78
+
79
+ discriminator:
80
+ _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
81
+
82
+ vocoder:
83
+ _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
84
+ ckpt_path: null # You may download the pretrained vocoder and set the path here
85
+
86
+ encode_mel_transform:
87
+ _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
88
+ sample_rate: ${sample_rate}
89
+ n_fft: ${n_fft}
90
+ hop_length: ${hop_length}
91
+ win_length: ${win_length}
92
+ n_mels: ${num_mels}
93
+ f_min: 0.0
94
+ f_max: 8000.0
95
+
96
+ gt_mel_transform:
97
+ _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
98
+ sample_rate: ${sample_rate}
99
+ n_fft: ${n_fft}
100
+ hop_length: ${hop_length}
101
+ win_length: ${win_length}
102
+ n_mels: ${num_mels}
103
+
104
+ optimizer:
105
+ _target_: torch.optim.AdamW
106
+ _partial_: true
107
+ lr: 4e-5
108
+ betas: [0.8, 0.99]
109
+ eps: 1e-5
110
+ weight_decay: 0.01
111
+
112
+ lr_scheduler:
113
+ _target_: torch.optim.lr_scheduler.LambdaLR
114
+ _partial_: true
115
+ lr_lambda:
116
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
117
+ _partial_: true
118
+ num_warmup_steps: 100
119
+ num_training_steps: ${trainer.max_steps}
120
+ final_lr_ratio: 0
121
+
122
+ callbacks:
123
+ model_summary:
124
+ _target_: lightning.pytorch.callbacks.ModelSummary
125
+ max_depth: 1
126
+
127
+ model_checkpoint:
128
+ every_n_train_steps: ${trainer.val_check_interval}
129
+
130
+ grad_norm_monitor:
131
+ sub_module:
132
+ - encoder
133
+ - decoder
134
+ - quantizer
135
+ - discriminator
fish_speech/configs/vqgan_pretrain.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Hydra config for VQ-GAN pretraining. Inherits from `base` and overrides it
# with the values in this file (`_self_` is listed last).
defaults:
  - base
  - _self_

project: vq-gan-pretrain

# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  max_steps: 1_000_000
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true

# Audio / spectrogram settings; referenced below through ${...} interpolation
# by the datasets and the mel transforms.
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048

# Dataset Configuration
# Training concatenates two file lists; each item is cropped to
# `slice_frames` * `hop_length` samples by VQGANDataset.
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512

# Validation sets no `slice_frames`, so clips are not cropped.
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32

# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  # `input_dim` matches the encoder's `residual_channels` (768).
  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null # You may download the pretrained vocoder and set the path here

  # Mel transform with an explicit 0-8000 Hz band (f_min / f_max).
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  # Same transform without f_min / f_max overrides.
  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    betas: [0.8, 0.99]
    eps: 1e-5
    weight_decay: 0.01

  # The schedule length follows the trainer's max_steps.
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  # Checkpoint at the same cadence as validation.
  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator
fish_speech/datasets/protos/text-data.proto ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ syntax = "proto3";
2
+
3
+ package text_data;
4
+
5
// A flat sequence of unsigned integer codes.
// fish_speech/datasets/text.py treats each Semantics entry of a sentence as
// the code sequence of one codebook.
message Semantics {
  repeated uint32 values = 1;
}
8
+
9
// One utterance: candidate transcripts plus its semantic code sequences.
// Field number 2 is not used by this message; keep the numbering as-is,
// because field numbers are part of the wire format.
message Sentence {
  // Alternative texts for the utterance (the text dataset picks one at random).
  repeated string texts = 1;
  // One Semantics entry per codebook.
  repeated Semantics semantics = 3;
}
13
+
14
// A group of sentences stored as one record of the length-prefixed stream
// handled by text_data_stream.py.
// Field number 3 is not used by this message; keep the numbering as-is.
message TextData {
  string source = 1;
  // Used as the speaker label in prompts built by the text dataset.
  string name = 2;
  repeated Sentence sentences = 4;
}
19
+
20
// A subset of sentences drawn from one group, carrying the group's
// `source` and `name` alongside the selected sentences.
message SampledData {
  string source = 1;
  string name = 2;
  repeated Sentence samples = 3;
}
fish_speech/datasets/protos/text_data_pb2.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: text-data.proto
4
+ # Protobuf Python Version: 4.25.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+
11
+ # @@protoc_insertion_point(imports)
12
+
13
+ _sym_db = _symbol_database.Default()
14
+
15
+
16
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
17
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
18
+ )
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ _globals["_SEMANTICS"]._serialized_start = 30
26
+ _globals["_SEMANTICS"]._serialized_end = 57
27
+ _globals["_SENTENCE"]._serialized_start = 59
28
+ _globals["_SENTENCE"]._serialized_end = 125
29
+ _globals["_TEXTDATA"]._serialized_start = 127
30
+ _globals["_TEXTDATA"]._serialized_end = 207
31
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
32
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
33
+ # @@protoc_insertion_point(module_scope)
fish_speech/datasets/protos/text_data_stream.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import struct
2
+
3
+ from .text_data_pb2 import TextData
4
+
5
+
6
def read_pb_stream(f):
    """Iterate over the length-prefixed ``TextData`` records stored in *f*.

    Each record is a 4-byte header (``struct`` format ``"I"``, i.e. native
    byte order) holding the payload size, followed by that many bytes of
    serialized message. Iteration ends when reading the next header returns
    no bytes.

    Args:
        f: Binary file-like object exposing ``read(n)``.

    Yields:
        One parsed ``TextData`` message per record.
    """
    header = f.read(4)
    while len(header) != 0:
        (size,) = struct.unpack("I", header)
        payload = f.read(size)

        message = TextData()
        message.ParseFromString(payload)
        yield message

        header = f.read(4)
16
+
17
+
18
def write_pb_stream(f, text_data):
    """Append *text_data* to *f* as one length-prefixed record.

    Args:
        f: Binary file-like object exposing ``write(bytes)``.
        text_data: Message exposing ``SerializeToString()``.
    """
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))  # 4-byte size, native byte order
    f.write(header)
    f.write(payload)
22
+
23
+
24
def pack_pb_stream(text_data):
    """Return *text_data* encoded as a single length-prefixed record.

    The result is the 4-byte size header (``struct`` format ``"I"``)
    immediately followed by the serialized message bytes.
    """
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    return b"".join((header, payload))
27
+
28
+
29
def split_pb_stream(f):
    """Split a length-prefixed stream into raw records without parsing them.

    Args:
        f: Binary file-like object exposing ``read(n)``.

    Yields:
        ``bytes`` for each record: the 4-byte size header followed by the
        payload bytes that were read for it.
    """
    # An empty read of the header marks the end of the stream.
    while head := f.read(4):
        (size,) = struct.unpack("I", head)
        yield head + f.read(size)
fish_speech/datasets/text.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from random import Random
6
+ from typing import Optional, Union
7
+
8
+ import grpc
9
+ import numpy as np
10
+ import pyarrow.parquet as pq
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from datasets.download.streaming_download_manager import xopen
14
+ from huggingface_hub import HfApi
15
+ from lightning import LightningDataModule
16
+ from torch.distributed import get_rank, get_world_size, is_initialized
17
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
18
+ from transformers import AutoTokenizer
19
+
20
+ from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
+ from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
+ from fish_speech.text.clean import clean_text
23
+ from fish_speech.utils import RankedLogger
24
+ from fish_speech.utils.braceexpand import braceexpand
25
+
26
+ log = RankedLogger(__name__, rank_zero_only=True)
27
+
28
+ CODEBOOK_PAD_TOKEN_ID = 0
29
+ CODEBOOK_EOS_TOKEN_ID = 1
30
+
31
+
32
def split_by_rank_worker(files):
    """Return the slice of *files* owned by the current rank and worker.

    The total number of consumers is the distributed world size (1 when
    ``torch.distributed`` is not initialized) multiplied by the number of
    DataLoader workers (when called inside a worker process). If there are
    fewer files than consumers, the list is repeated so that every consumer
    receives at least one file. The list is then strided first by rank and
    then by worker id.

    Args:
        files: Sequence of files (any sliceable, repeatable sequence).

    Returns:
        The sub-sequence for this rank/worker. When neither distributed
        training nor DataLoader workers are active, *files* itself is
        returned. An empty input is returned unchanged.
    """
    # Nothing to shard. This also avoids dividing by len(files) == 0 in the
    # repetition step below.
    if len(files) == 0:
        return files

    distributed = is_initialized()
    worker_info = get_worker_info()

    # We need to know the total number of devices to split the data properly
    total_devices = get_world_size() if distributed else 1
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if distributed:
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
57
+
58
+
59
class StreamTextDataset(IterableDataset):
    """Stream text from parquet shards of a Hugging Face dataset repository.

    Every text is tokenized, randomly cropped to ``max_length`` when longer,
    wrapped in BOS/EOS and yielded as a next-token-prediction pair. The text
    ids occupy row 0 of the returned tensors; four additional all-zero rows
    are appended as codebook placeholders and ignored in the labels (-100).
    """

    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        """
        Args:
            files: file name(s) inside ``repo``; brace patterns are expanded
            prefix: if given, every ``.parquet`` file of ``repo`` whose path
                starts with this prefix is used (takes precedence over ``files``)
            seed: seed of the deterministic shuffle of the file list
            parquet_batch_size: number of rows read per parquet batch
            repo: Hugging Face dataset repository id
            max_length: maximum sequence length of a sample
            tokenizer: tokenizer used to encode the text

        Raises:
            ValueError: if neither ``files`` nor ``prefix`` is given
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repo = repo
        self.parquet_batch_size = parquet_batch_size
        self.seed = seed

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is None:
            patterns = [files] if isinstance(files, str) else files
            files = list(chain.from_iterable(map(braceexpand, patterns)))
            log.info(f"Expanded {len(files)} files in {repo}")
        else:
            repo_files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                name
                for name in repo_files
                if name.startswith(prefix) and name.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")

        # Sort first so that the seeded shuffle gives the same order everywhere
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        """Yield samples from this rank/worker's share of the files."""
        shard = split_by_rank_worker(self.files)
        random.shuffle(shard)

        for filename in shard:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                # Best effort: log the failing file and move on to the next one
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        """Yield ``{"tokens", "labels"}`` dicts for every text in *filename*."""
        for record in self.parse_data_internal(filename):
            token_ids = self.tokenizer.encode(
                record["text"],
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Randomly crop texts that do not fit into max_length
            overflow = len(token_ids) - self.max_length
            if overflow > 0:
                start = random.randint(0, overflow)
                token_ids = token_ids[start : start + self.max_length - 1]

            token_ids = [
                self.tokenizer.bos_token_id,
                *token_ids,
                self.tokenizer.eos_token_id,
            ]

            # Row 0 holds the text ids, rows 1-4 are zero placeholders
            text_row = torch.tensor([token_ids], dtype=torch.long)
            codebook_rows = torch.zeros((4, len(token_ids)), dtype=torch.long)
            packed = torch.concat([text_row, codebook_rows], dim=0)

            # Shift by one position for next-token prediction
            tokens = packed[:, :-1]
            labels = packed.clone()[:, 1:]
            labels[1:] = -100  # remove all placeholders

            yield {"tokens": tokens, "labels": labels}

    def parse_data_internal(self, filename: str):
        """Yield ``{"text": str}`` rows of *filename*, shuffled within each batch."""
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            batches = pq.ParquetFile(stream).iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            )

            for batch in batches:
                # In-batch shuffling
                rows = [{"text": value.as_py()} for value in batch["text"]]
                random.shuffle(rows)
                yield from rows
158
+
159
+
160
class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same group (speaker)
       to form a longer sample
    2. Automatically normalize the text (via ``clean_text``)

    For interactive mode, every sentence is packed as its own
    user/assistant turn and the turns are concatenated; the speaker tag and
    the BOS token are only added to the first turn.

    For non-interactive mode, all sentences are joined into a single user
    prompt followed by one assistant turn with all semantic codes.

    The prompt template actually used is built in ``pack_sentences``:
    ``<|im_start|>user<|im_sep|>...<|im_end|><|im_start|>assistant<|im_sep|>``.

    The dataset is infinite: ``__iter__`` keeps yielding augmented samples.
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
        causual: bool = True,
        use_negative_samples: bool = False,
        num_codebooks: Optional[int] = None,
    ):
        """
        Args:
            proto_files: proto buf files (or directories / brace patterns) if using local data
            seed: random seed
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causual: sample consecutive sentences from a group (causal sampling);
                when disabled, sentences are drawn at random
            use_negative_samples: generate negative samples
            num_codebooks: number of codebooks, if None, it will be automatically detected
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causual = causual
        self.use_negative_samples = use_negative_samples
        self.num_codebooks = num_codebooks

        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        # Loaded lazily by init_mock_data_server() on the first sample_data() call
        self.groups = None

    def init_mock_data_server(self):
        """Load this rank/worker's share of the proto files into ``self.groups``.

        Does nothing when the groups are already loaded. Also computes
        ``self.group_weights`` (number of sentences per group).

        Raises:
            ValueError: if an expanded path is neither a file nor a directory
        """
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    # Directories are searched recursively for both extensions
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        # Sort, then shuffle with a fixed seed, so the order is reproducible
        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)
        # Groups are later drawn proportionally to their number of sentences
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        # NOTE(review): augment() can return None (group without samples), and
        # that None is yielded as-is — confirm downstream consumers handle it.
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        """Clean *sentence* and return ``(cleaned_sentence, number_of_tokens)``."""
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        """Draw one group and return a ``SampledData`` with up to
        ``max_length // 20`` of its sentences."""
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causual:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            # NOTE(review): random.choices draws with replacement, so the same
            # sentence can appear more than once — confirm this is intended.
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        """Build one training sample.

        Returns:
            ``{"tokens", "labels"}`` (plus ``negative_tokens`` /
            ``negative_labels`` when ``use_negative_samples`` is set), or
            ``None`` when the sampled group has no sentences.
        """
        # Random sample based on speaker using a truncated normal distribution
        # (mean max_length // 2, std max_length // 4, bounded to [10, max_length])
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        all_tokens, all_labels = [], []
        # Consume sentences until the sampled token budget is used up
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            # Budget counts text tokens plus the length of the first codebook
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if (self.use_speaker and idx == 0) else None,
                    add_bos=idx == 0,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            # Non-interactive: pack all collected sentences as a single turn
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if self.use_speaker else None,
                add_bos=True,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        # Verify bos token
        assert tokens[0, 0] == self.tokenizer.bos_token_id

        data = {"tokens": tokens, "labels": labels}

        if self.use_negative_samples:
            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
            data.update(negative_samples)

        return data

    def generate_negative_samples(self, all_tokens, all_labels):
        """Corrupt every packed segment and return the concatenated result.

        For each segment a random span inside the semantic part is either
        repeated, dropped ("lost") or shuffled ("noise").

        Returns:
            ``{"negative_tokens": tensor, "negative_labels": tensor}``
        """
        new_tokens, new_labels = [], []

        for tokens, labels in zip(all_tokens, all_labels):
            # If all codebooks are not -100, we find where it starts
            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
            assert (labels[1:, start:] != -100).all()  # This shouldn't happen

            mode = random.choice(["repeat", "lost", "noise"])
            begin = random.randint(start, labels.size(1) - 1)
            end = random.randint(begin, labels.size(1) - 1)

            if mode == "repeat":
                # Duplicate the span [begin, end)
                tokens = torch.cat(
                    [
                        tokens[:, :begin],
                        tokens[:, begin:end],
                        tokens[:, begin:end],
                        tokens[:, end:],
                    ],
                    dim=1,
                )
                labels = torch.cat(
                    [
                        labels[:, :begin],
                        labels[:, begin:end],
                        labels[:, begin:end],
                        labels[:, end:],
                    ],
                    dim=1,
                )
            elif mode == "lost":
                # Remove the span [begin, end)
                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
            elif mode == "noise":
                middle_tokens, middle_labels = (
                    tokens[:, begin:end],
                    labels[:, begin:end],
                )
                # Tokens and labels are permuted independently of each other
                random_order0 = torch.randperm(middle_tokens.size(1))
                random_order1 = torch.randperm(middle_tokens.size(1))
                middle_tokens = middle_tokens[:, random_order0]
                middle_labels = middle_labels[:, random_order1]
                tokens = torch.cat(
                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
                )
                labels = torch.cat(
                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
                )

            new_tokens.append(tokens)
            new_labels.append(labels)

        tokens = torch.cat(new_tokens, dim=1)
        labels = torch.cat(new_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        return {"negative_tokens": tokens, "negative_labels": labels}

    # NOTE(review): `semantics=list` makes the builtin `list` type the default
    # value; it looks like a typo for the annotation `semantics: list`. Every
    # call in this class passes `semantics` explicitly.
    def pack_sentences(
        self,
        sentences: list[str],
        semantics=list,
        speaker: Optional[str] = None,
        add_bos: bool = True,
    ):
        """Pack text and semantic codes into shifted ``(tokens, labels)`` tensors.

        Row 0 holds the tokenizer ids (prompt, one ``<|semantic|>`` token per
        code position, then ``<|im_end|>`` and ``<|end_of_sequence|>``); the
        following rows hold one codebook each. Codes are offset by 2 because
        0 and 1 are reserved for codebook padding and EOS.

        Args:
            sentences: texts joined with spaces into the user prompt
            semantics: one entry per sentence, each a sequence of codebooks
                exposing ``.values``
            speaker: if given, ``[SPK: speaker]`` is prepended to the prompt
            add_bos: prepend the tokenizer BOS token

        Returns:
            ``(tokens, labels)`` where ``tokens`` drops the last position and
            ``labels`` drops the first one.
        """
        if speaker is not None:
            sentences = [f"[SPK: {speaker}]"] + sentences

        final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
        final_text = final_text + "<|im_start|>assistant<|im_sep|>"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        # The length of the first codebook defines the number of code positions
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        bos_bias = 1 if add_bos else 0

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(
                ["<|im_end|>", "<|end_of_sequence|>"]
            )
        )

        if add_bos:
            tokens = [self.tokenizer.bos_token_id] + tokens

        # Codebook bos/padding: 0, eos: 1
        codes = [
            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
            for _ in range(num_codebooks)
        ]
        for segment in semantics:
            # zip() stops at the shorter side, so extra codebooks are ignored
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            # Two EOS codes align with the two closing text tokens
            book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, : (prompt_length + bos_bias)] = -100

        # Shift by one position for next-token prediction
        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
        assert labels[0, -1] == self.tokenizer.eos_token_id
        assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()

        return tokens, labels
504
+
505
+
506
@dataclass
class TextDataCollator:
    """Collate per-sample ``tokens`` / ``labels`` tensors into one batch.

    Each sample is a dict of 2-D tensors whose second dimension is the
    sequence length. Samples are truncated to ``max_length`` and right-padded
    to the longest sequence of the batch.

    When the samples carry ``negative_tokens`` / ``negative_labels``, the
    batch contains all positive samples first, followed by all negative
    samples in the same order.
    """

    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        """Build the batch dict for a list of sample dicts.

        Args:
            examples: list of dicts with ``tokens`` and ``labels`` (and
                optionally ``negative_tokens`` and ``negative_labels``).

        Returns:
            The dict produced by :meth:`batchify`.
        """
        # ``examples`` is a list of dicts, so the key must be looked up on the
        # samples themselves; testing membership on the list never matches.
        has_negatives = bool(examples) and all(
            "negative_tokens" in example for example in examples
        )

        if has_negatives:
            positive_examples = [
                {"tokens": example["tokens"], "labels": example["labels"]}
                for example in examples
            ]
            negative_examples = [
                {
                    "tokens": example["negative_tokens"],
                    "labels": example["negative_labels"],
                }
                for example in examples
            ]
            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        """Truncate, pad and stack the given samples.

        Args:
            examples: list of sample dicts.
            tokens_key: key of the input tensor in every sample.
            labels_key: key of the label tensor in every sample.

        Returns:
            dict with
            ``inputs``: stacked tokens; row 0 is padded with the tokenizer
            EOS id, the remaining rows with ``CODEBOOK_PAD_TOKEN_ID``;
            ``attention_masks``: bool tensor that is ``True`` at padded
            positions and ``False`` at real ones;
            ``labels``: stacked labels, padded with -100.
        """
        tokens, attention_masks, labels = [], [], []

        # Pad to the longest sample, capped by max_length
        longest = max(
            (example[tokens_key].size(1) for example in examples), default=0
        )
        max_tokens_length = min(longest, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            tokens_length = _tokens.size(1)

            # True marks padding
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                # Codebook rows use their own padding id instead of text EOS
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
578
+
579
+
580
class InterleaveDataset(IterableDataset):
    """Endlessly interleave several datasets by sampling one per step.

    At every step a dataset index is drawn according to ``probabilities``
    and the next item of that dataset is yielded. A dataset that runs out is
    restarted from the beginning.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        """
        Args:
            datasets: datasets to draw from
            probabilities: sampling probability of each dataset
            seed: seed of the generator that picks the dataset
        """
        super().__init__()

        self.seed = seed
        self.probabilities = probabilities
        self.datasets = datasets

    def __iter__(self):
        # A fresh generator per iteration: the pick sequence only depends on seed
        rng = np.random.default_rng(self.seed)
        iterators = list(map(iter, self.datasets))

        while True:
            # Random choice one
            picked = rng.choice(len(self.datasets), p=self.probabilities)

            try:
                item = next(iterators[picked])
            except StopIteration:
                # Exhausted, create a new iterator
                iterators[picked] = iter(self.datasets[picked])
                item = next(iterators[picked])

            yield item
608
+
609
+
610
class TextDataModule(LightningDataModule):
    """Lightning data module wrapping a train and a validation text dataset.

    Both loaders share the batch size, the number of workers and a
    ``TextDataCollator`` built from ``tokenizer`` and ``max_length``.
    """

    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        """
        Args:
            train_dataset: dataset served by ``train_dataloader``
            val_dataset: dataset served by ``val_dataloader``
            batch_size: batch size of both loaders
            tokenizer: tokenizer handed to the collator
            max_length: maximum sequence length handed to the collator
            num_workers: number of DataLoader workers of both loaders
        """
        super().__init__()

        self.train_dataset, self.val_dataset = train_dataset, val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def _build_loader(self, dataset):
        # Shared construction of the train / validation loaders
        collator = TextDataCollator(self.tokenizer, self.max_length)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=collator,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the DataLoader over ``train_dataset``."""
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        """Return the DataLoader over ``val_dataset``."""
        return self._build_loader(self.val_dataset)
644
+
645
+
646
if __name__ == "__main__":
    from tqdm import tqdm

    # Manual smoke test: build one augmented sample and print its text row.
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    ds = AutoAugTextDataset(
        ["data/protos"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        use_negative_samples=False,
    )

    # The dataset is infinite, so only the first sample is consumed
    first = next(iter(ds))
    print(ds.tokenizer.decode(first["tokens"][0], skip_special_tokens=False))
fish_speech/datasets/vqgan.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ from lightning import LightningDataModule
9
+ from torch.utils.data import DataLoader, Dataset
10
+
11
+ from fish_speech.utils import RankedLogger
12
+
13
+ logger = RankedLogger(__name__, rank_zero_only=False)
14
+
15
+
16
class VQGANDataset(Dataset):
    """Map-style dataset of mono audio clips listed in a file list.

    The file list contains one path per line, relative to the directory of
    the file list itself; blank lines are ignored.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        """
        Args:
            filelist: path of the text file listing the audio files
            sample_rate: sample rate the audio is loaded at
            hop_length: number of samples per frame, used to size the crop
            slice_frames: if set, clips longer than
                ``slice_frames * hop_length`` samples are randomly cropped to
                that length
        """
        super().__init__()

        list_path = Path(filelist)
        entries = (line.strip() for line in list_path.read_text().splitlines())
        self.files = [list_path.parent / entry for entry in entries if entry]

        self.slice_frames = slice_frames
        self.hop_length = hop_length
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        """Load clip *idx*; return ``{"audio": tensor}`` or ``None`` if it is empty."""
        audio, _ = librosa.load(self.files[idx], sr=self.sample_rate, mono=True)

        # Slice audio and features
        if self.slice_frames is not None:
            window = self.slice_frames * self.hop_length
            if audio.shape[0] > window:
                start = np.random.randint(0, audio.shape[0] - window)
                audio = audio[start : start + window]

        if len(audio) == 0:
            return None

        # Only rescale when the peak exceeds full scale
        peak = np.abs(audio).max()
        if peak > 1.0:
            audio = audio / peak

        return {"audio": torch.from_numpy(audio)}

    def __getitem__(self, idx):
        """Like :meth:`get_item`, but log any failure and return ``None``."""
        try:
            item = self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            item = None

        return item
76
+
77
+
78
@dataclass
class VQGANCollator:
    """Collate audio samples into a zero-padded batch.

    ``None`` entries (samples that failed to load) are dropped. Every
    remaining waveform is right-padded with zeros to the longest one.
    """

    def __call__(self, batch):
        """
        Args:
            batch: list of ``{"audio": 1-D tensor}`` dicts or ``None``

        Returns:
            dict with ``audios`` (stacked, padded waveforms) and
            ``audio_lengths`` (original length of every waveform)
        """
        waveforms = [item["audio"] for item in batch if item is not None]

        audio_lengths = torch.tensor([len(wave) for wave in waveforms])
        audio_maxlen = audio_lengths.max()

        padded = [
            torch.nn.functional.pad(wave, (0, audio_maxlen - len(wave)))
            for wave in waveforms
        ]

        return {
            "audios": torch.stack(padded),
            "audio_lengths": audio_lengths,
        }
97
+
98
+
99
class VQGANDataModule(LightningDataModule):
    """Lightning data module serving VQ-GAN audio batches.

    The training loader shuffles; the validation loader does not and may use
    its own batch size.
    """

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        """
        Args:
            train_dataset: dataset served by ``train_dataloader``
            val_dataset: dataset served by ``val_dataloader``
            batch_size: training batch size
            num_workers: number of DataLoader workers of both loaders
            val_batch_size: validation batch size; falls back to
                ``batch_size`` when not set
        """
        super().__init__()

        self.train_dataset, self.val_dataset = train_dataset, val_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size

    def train_dataloader(self):
        """Return the shuffled DataLoader over ``train_dataset``."""
        loader_kwargs = dict(
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
        )
        return DataLoader(self.train_dataset, **loader_kwargs)

    def val_dataloader(self):
        """Return the DataLoader over ``val_dataset``."""
        loader_kwargs = dict(
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
        )
        return DataLoader(self.val_dataset, **loader_kwargs)
132
+
133
+
134
if __name__ == "__main__":
    # Manual smoke test: load one batch and print what the collator returns.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # VQGANCollator only produces "audios" and "audio_lengths"; there are
        # no "features" / "feature_lengths" entries to print.
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
fish_speech/models/text2semantic/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .lit_module import TextToSemantic
2
+
3
+ __all__ = ["TextToSemantic"]
fish_speech/models/text2semantic/lit_module.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Optional
3
+
4
+ import lightning as L
5
+ import loralib as lora
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from lightning.pytorch.utilities.types import OptimizerLRScheduler
9
+
10
+ import fish_speech.utils as utils
11
+ from fish_speech.models.text2semantic.llama import NaiveTransformer
12
+
13
+ log = utils.RankedLogger(__name__, rank_zero_only=True)
14
+
15
+
16
@dataclass
class LoraConfig:
    """LoRA hyper-parameters consumed by ``TextToSemantic.setup_lora``."""

    # Forwarded as ``r`` to both lora.Embedding and lora.Linear.
    r: int
    # Forwarded as ``lora_alpha`` to both lora.Embedding and lora.Linear.
    lora_alpha: float
    # Forwarded as ``lora_dropout`` to lora.Linear only (not to the embedding).
    lora_dropout: float = 0.0
21
+
22
+
23
class TextToSemantic(L.LightningModule):
    """Lightning module that trains a text-to-semantic transformer.

    Supports optional LoRA fine-tuning (layers are swapped for ``loralib``
    equivalents) and an optional DPO-style preference term that is computed
    without a reference model.
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
        lora_config: Optional[LoraConfig] = None,
        save_lora_only: bool = False,
        use_dpo: bool = False,
        dpo_beta: float = 0.2,
    ):
        """Store the model and training configuration.

        Args:
            model: Transformer to train.
            optimizer: Callable that receives parameter groups and returns an
                optimizer.
            lr_scheduler: Callable that receives the optimizer and returns a
                scheduler.
            lora_config: When given, ``setup_lora`` is run immediately.
            save_lora_only: When LoRA is active, strip non-LoRA entries from
                saved checkpoints.
            use_dpo: Treat each batch as ``[positive half | negative half]``
                and add a DPO-style loss.
            dpo_beta: Scale applied to the log-probability gap in that loss.
        """
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
        self.lora_config = lora_config
        self.save_lora_only = save_lora_only
        self.use_dpo = use_dpo  # We don't support reference model yet
        self.dpo_beta = dpo_beta

        if self.lora_config is not None:
            self.setup_lora()

    def setup_lora(self):
        """Swap the embedding and linear layers for LoRA layers.

        The replacement layers are newly constructed here; this method does
        not copy the previous weights into them.
        """
        # Replace the embedding layer with a LoRA layer
        self.model.embeddings = lora.Embedding(
            num_embeddings=self.model.embeddings.num_embeddings,
            embedding_dim=self.model.embeddings.embedding_dim,
            padding_idx=self.model.embeddings.padding_idx,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.lora_alpha,
        )

        # (owner module, attribute name) pairs of every linear layer to wrap,
        # starting with the output head.
        linears = [(self.model, "output")]

        # Replace all linear layers with LoRA layers
        for layer in self.model.layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

        if hasattr(self.model, "fast_layers"):
            # Dual-AR model
            linears.extend([(self.model, "fast_output")])

            for layer in self.model.fast_layers:
                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
                linears.extend(
                    [
                        (layer.feed_forward, "w1"),
                        (layer.feed_forward, "w2"),
                        (layer.feed_forward, "w3"),
                    ]
                )

        for module, attr_name in linears:
            original = getattr(module, attr_name)
            updated_linear = lora.Linear(
                in_features=original.in_features,
                out_features=original.out_features,
                # ``bias`` must be a bool flag. Passing the bias tensor itself
                # (as before) raises for any layer that actually has a bias.
                bias=original.bias is not None,
                r=self.lora_config.r,
                lora_alpha=self.lora_config.lora_alpha,
                lora_dropout=self.lora_config.lora_dropout,
            )
            setattr(module, attr_name, updated_linear)

        # Mark only the LoRA layers as trainable
        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        """Drop non-LoRA weights from ``checkpoint`` when LoRA-only saving is on."""
        if self.lora_config is None or self.save_lora_only is False:
            return

        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer (two parameter groups) and a per-step scheduler.

        Parameters whose name contains ``.bias``, ``norm.weight`` or
        ``.embeddings.`` are placed in a group with ``weight_decay=0.0``;
        everything else keeps the optimizer's default weight decay.
        """
        # Get weight decay parameters
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for group in optimizer.param_groups:
            log.info(
                f"Set weight decay: {group['weight_decay']} for {len(group['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Adapted from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        # Reduce over every non-batch axis so each sample yields one value, as
        # documented. Summing only the last axis left a (batch, sequence)
        # tensor for the 3-D codebook labels used by ``_step``. Plain
        # (batch, sequence) labels behave exactly as before.
        reduce_dims = tuple(range(1, labels.dim())) or (-1,)
        masked_logps = per_token_logps * loss_mask

        if average_log_prob:
            return masked_logps.sum(reduce_dims) / loss_mask.sum(reduce_dims)
        return masked_logps.sum(reduce_dims)

    def _step(self, batch, batch_idx, stage: str):
        """Run one train/val step: forward pass, losses and metric logging.

        Args:
            batch: Mapping with ``"inputs"``, ``"labels"`` and
                ``"attention_masks"``. Row 0 of ``labels`` along dim 1 is used
                as the token target and rows ``1 .. num_codebooks`` as the
                codebook targets.
            batch_idx: Unused.
            stage: Metric prefix; ``"train"`` logs per step, anything else
                logs per epoch.

        Returns:
            The total loss (token + codebook cross-entropy, plus the DPO term
            when ``use_dpo`` is set).
        """
        is_train = stage == "train"

        def log_metric(name, value, prog_bar=False):
            # Every metric shares the same step/epoch policy.
            self.log(
                f"{stage}/{name}",
                value,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=prog_bar,
                logger=True,
            )

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        if self.use_dpo:
            # First half is positive, second half is negative
            token_logits, negative_token_logits = token_logits.chunk(2)
            codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
            labels, negative_labels = labels.chunk(2)

        num_codebooks = self.model.config.num_codebooks

        # Cross-entropy on the text-token head
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # Cross-entropy on the codebook head; .mT swaps the last two axes so
        # the labels line up with the leading axes of the codebook logits.
        codebook_labels = labels[:, 1 : 1 + num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        # If we use dpo
        if self.use_dpo:
            negative_codebook_labels = negative_labels[:, 1 : 1 + num_codebooks].mT

            positive_codebook_logps = self.get_batch_logps(
                codebook_logits, codebook_labels
            )
            negative_codebook_logps = self.get_batch_logps(
                negative_codebook_logits, negative_codebook_labels
            )

            # TODO: implement the reference model, avoid screwing up the gradients
            dpo_loss = -F.logsigmoid(
                (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
            ).mean()

            chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
            rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
            chosen_rewards, rejected_rewards = (
                chosen_rewards.mean(),
                rejected_rewards.mean(),
            )

            loss = loss + dpo_loss

            log_metric("dpo_loss", dpo_loss)
            log_metric("chosen_rewards", chosen_rewards)
            log_metric("rejected_rewards", rejected_rewards)
            log_metric("reward_accuracy", reward_accuracy)

        log_metric("loss", loss, prog_bar=True)
        log_metric("base_loss", base_loss)
        log_metric("semantic_loss", semantic_loss)

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        log_metric("top_5_accuracy", accuracy, prog_bar=True)

        num_in_codebooks = self.model.config.num_in_codebooks
        if num_codebooks != num_in_codebooks:
            # Also report accuracy restricted to the first num_in_codebooks codebooks.
            accuracy = self.get_accuracy(
                codebook_logits[:, :, :num_in_codebooks],
                codebook_labels[:, :, :num_in_codebooks],
            )
            log_metric("top_5_accuracy_in", accuracy, prog_bar=True)

        return loss

    def get_accuracy(self, logits, labels):
        """Return top-5 accuracy over labels that are not ``-100``.

        The result is NaN when every label is ``-100`` (division by zero).
        """
        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        # Ignored positions never count as hits.
        correct[labels == -100] = 0
        correct = correct.sum()
        accuracy = correct / (labels != -100).sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        """Lightning hook: one training step."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: one validation step."""
        return self._step(batch, batch_idx, "val")
fish_speech/models/text2semantic/llama.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+ from torch import Tensor
9
+ from torch.nn import functional as F
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+
13
def find_multiple(n: int, k: int) -> int:
    """Return ``n`` rounded up to the nearest multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
17
+
18
+
19
@dataclass
class BaseModelArgs:
    """Hyper-parameters shared by the transformer variants in this module.

    Several fields are finalised in ``__post_init__``; see the comments on
    the individual fields.
    """

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None -> derived from ``dim`` in __post_init__
    intermediate_size: int = None
    # -1 -> same as ``n_head``
    n_local_heads: int = -1
    # Always overwritten with ``dim // n_head`` in __post_init__
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4
    # None -> same as ``num_codebooks``
    num_in_codebooks: Optional[int] = None
    codebook_padding_idx: int = 0

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    def __post_init__(self):
        """Fill in the derived defaults."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            # Two thirds of 4 * dim, rounded up to a multiple of 256.
            widened = 4 * self.dim
            self.intermediate_size = find_multiple(int(2 * widened / 3), 256)

        if self.num_in_codebooks is None:
            self.num_in_codebooks = self.num_codebooks

        self.head_dim = self.dim // self.n_head
52
+
53
+
54
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Configuration type for ``NaiveTransformer``; adds no fields to ``BaseModelArgs``."""

    pass
57
+
58
+
59
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Configuration type for ``DualARTransformer``."""

    # Number of TransformerBlock layers in the fast (codebook) transformer.
    n_fast_layer: int = 4
62
+
63
+
64
class KVCache(nn.Module):
    """Preallocated key/value buffers that are filled in place during decoding."""

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        """Allocate zeroed ``k_cache`` / ``v_cache`` buffers.

        Both buffers have shape ``(max_batch_size, n_heads, max_seq_len, head_dim)``.
        """
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write new keys/values at ``input_pos`` and return the full caches.

        Args:
            input_pos: Positions to overwrite, shape ``[S]``.
            k_val: New keys, shape ``[B, H, S, D]``.
            v_val: New values, same layout as ``k_val``.

        Returns:
            ``(k_cache, v_cache)`` — the whole buffers, not just the new slice.
        """
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        keys, values = self.k_cache, self.v_cache
        keys[:, :, input_pos] = k_val
        values[:, :, input_pos] = v_val

        return keys, values
83
+
84
+
85
@dataclass
class TransformerForwardResult:
    """Final output of the Naive / Dual-AR transformers."""

    # Logits from the text-token output head.
    token_logits: Tensor
    # Logits over codebook entries, one set per codebook.
    codebook_logits: Tensor
89
+
90
+
91
@dataclass
class BaseTransformerForwardResult:
    """Intermediate output of ``BaseTransformer`` consumed by its subclasses."""

    # Token logits: output head applied to the normalised hidden states.
    logits: Tensor
    # Hidden states of the last layer, taken before the final norm.
    hidden_states: Tensor
95
+
96
+
97
class BaseTransformer(nn.Module):
    """Slow transformer shared by the Naive and Dual-AR models.

    Embeds a text-token row plus ``num_in_codebooks`` codebook rows into one
    sequence, runs the transformer blocks and produces text-token logits
    together with the pre-norm hidden states.
    """

    def __init__(self, config: BaseModelArgs) -> None:
        """Build the embedding table, blocks, output head and static buffers."""
        super().__init__()
        self.config = config

        # Slow transformer
        # One table holds the text vocabulary followed by one slice of
        # codebook_size rows per input codebook.
        table_size = config.vocab_size + config.codebook_size * config.num_in_codebooks
        self.embeddings = nn.Embedding(table_size, config.dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        rotary_table = precompute_freqs_cis(
            config.max_seq_len,
            config.dim // config.n_head,
            config.rope_base,
        )
        self.register_buffer("freqs_cis", rotary_table, persistent=False)

        lower_triangle = torch.tril(
            torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", lower_triangle, persistent=False)

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Attach a ``KVCache`` to every block unless the existing ones suffice.

        ``max_seq_len`` is rounded up to a multiple of 8 before allocating.
        """
        caches_are_big_enough = (
            self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size
        )
        if caches_are_big_enough:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len, self.max_batch_size = max_seq_len, max_batch_size

        for block in self.layers:
            block.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        """Sum the text-token embedding with every input-codebook embedding.

        Row 0 along dim 1 holds text tokens; row ``i + 1`` holds codebook
        ``i``, which is looked up at offset ``vocab_size + i * codebook_size``.
        Codebook entries equal to ``codebook_padding_idx`` contribute zero.
        """
        cfg = self.config
        parts = [self.embeddings(x[:, 0])]

        for i in range(cfg.num_in_codebooks):
            codes = x[:, i + 1]
            emb = self.embeddings(codes + i * cfg.codebook_size + cfg.vocab_size)
            emb[codes == cfg.codebook_padding_idx] = 0
            parts.append(emb)

        return torch.stack(parts, dim=3).sum(dim=3)

    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> BaseTransformerForwardResult:
        """Full-sequence (training) forward pass.

        Args:
            inp: Token ids, shape ``(batch, num_codebooks + 1, seq_len)``.
            key_padding_mask: Optional boolean mask where ``True`` marks
                padded key positions.

        Returns:
            Token logits and the pre-norm hidden states.
        """
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention
        # That is, FALSE means masked out
        # To maintain consistency, key_padding_mask use TRUE to mask out
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        use_checkpointing = self.config.use_gradient_checkpointing and self.training
        for layer in self.layers:
            if use_checkpointing:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        token_logits = self.output(self.norm(x))

        return BaseTransformerForwardResult(logits=token_logits, hidden_states=x)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> BaseTransformerForwardResult:
        """Incremental (KV-cached) forward pass used for generation.

        ``setup_caches`` must have been called first. When more than one
        position is fed (prefill), only the last position is kept.
        """
        # This is used for generation, optimized for torch compile
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        # (B, N, Q, K)
        mask = self.causal_mask[None, None, input_pos, : self.max_seq_len]
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefill, we only calculate the logits of last token
        if x.size(1) > 1:
            x = x[:, -1:]

        # We got slow_out here
        token_logits = self.output(self.norm(x))

        return BaseTransformerForwardResult(logits=token_logits, hidden_states=x)
239
+
240
+
241
class NaiveTransformer(BaseTransformer):
    """Base transformer plus a single linear head that emits all codebooks at once."""

    def __init__(self, config: NaiveModelArgs) -> None:
        """Add the codebook norm and the joint codebook output head."""
        super().__init__(config)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        joint_width = config.codebook_size * config.num_codebooks
        self.codebook_output = nn.Linear(config.dim, joint_width, bias=False)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Turn base hidden states into per-codebook logits.

        The head output is split along its last axis into ``num_codebooks``
        groups, giving logits laid out as ``b n c d``.
        """
        # Codebook
        normed = self.codebook_norm(result.hidden_states)
        flat_logits = self.codebook_output(normed)
        per_codebook = rearrange(
            flat_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=result.logits,
            codebook_logits=per_codebook,
        )

    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Full-sequence forward pass followed by ``decode``."""
        return self.decode(super().forward(inp, key_padding_mask))

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """KV-cached forward pass followed by ``decode``."""
        return self.decode(super().forward_generate(x, input_pos))
278
+
279
+
280
class DualARTransformer(BaseTransformer):
    """Base (slow) transformer plus a small "fast" transformer over codebooks.

    For every sequence position the fast transformer runs over a short
    sequence of length ``num_codebooks``: the slow hidden state followed by
    codebook embeddings. It emits one set of codebook logits per step.
    """

    def __init__(self, config: DualARModelArgs) -> None:
        """Build the fast embedding table, blocks, norm and output head."""
        super().__init__(config)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(
            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
        )

        # The equivalent bs is so large that sdpa doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Set up slow-layer caches, then (re)create the fast-layer caches.

        The fast caches are rebuilt on every call, even when the parent
        returns early because its own caches are already large enough.
        """
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Full-sequence forward pass through the slow and fast transformers.

        Args:
            inp: Token ids; the parent documents the layout as
                ``(batch, num_codebooks + 1, seq_len)``.
            key_padding_mask: Forwarded to the slow transformer only.

        Returns:
            Token logits from the slow transformer and codebook logits
            rearranged to ``b s n d``. Positions dropped as padding keep
            all-zero logits.
        """
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last token and rotate left
        # Rows 1..-2 along dim 1 (skips row 0 and the final row), positions
        # 1.. along the sequence axis; one padding entry is appended so the
        # sequence length is unchanged.
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the slow hidden state as the first step of each fast sequence.
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        # A flattened position is dropped when all of its codebook ids are padding.
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        # Scatter the computed logits back into a zero buffer of the
        # pre-filter size so dropped positions line up again.
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """Run one KV-cached step of the fast transformer.

        ``x`` is reshaped to ``(1, 1, -1)``, so this path handles a single
        sequence and a single step per call.

        Returns:
            Codebook logits for that step.
        """
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits
400
+
401
+
402
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        """Create the attention / feed-forward sub-modules and their norms."""
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        """Apply normalised attention, then normalised feed-forward, residually."""
        attended = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        h = x + attended
        return h + self.feed_forward(self.ffn_norm(h))
416
+
417
+
418
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection, rotary embeddings and
    an optional KV cache.

    Keys/values may use fewer heads (``n_local_heads``) than queries; they are
    repeated to ``n_head`` before the attention product.
    """

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        """Create the projections and register the checkpoint-conversion hook.

        Args:
            config: Model configuration (reads ``dim``, ``n_head``,
                ``n_local_heads``, ``head_dim`` and ``dropout``).
            use_sdpa: Use ``F.scaled_dot_product_attention`` when True,
                otherwise the explicit implementation in this class.
        """
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse separate ``wq``/``wk``/``wv`` checkpoint weights into ``wqkv``."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute attention for ``x`` of shape ``(batch, seq_len, dim)``.

        Args:
            x: Input hidden states.
            freqs_cis: Rotary table rows for the positions in ``x``.
            mask: Attention mask passed to the attention kernel.
            input_pos: Cache positions; only used when a KV cache is attached.

        Returns:
            Output of the ``wo`` projection, same shape as ``x``.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand the key/value heads so they match the number of query heads.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        attention_fn = (
            F.scaled_dot_product_attention
            if self.use_sdpa
            else self.eq_scaled_dot_product_attention
        )
        y = attention_fn(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
        )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        """Explicit scaled dot-product attention.

        Args:
            query: Queries, ``(..., L, D)``.
            key: Keys, ``(..., S, D)``.
            value: Values, ``(..., S, Dv)``.
            attn_mask: Optional mask broadcastable to ``(..., L, S)``. Boolean
                masks use ``False`` for "masked out"; other dtypes are added
                to the attention scores.
            dropout_p: Dropout probability applied to the attention weights
                (callers pass 0.0 outside training).

        Returns:
            The attention output, ``(..., L, Dv)``.
        """
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error

        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            # Out-of-place ops let the (1, 1, L, S) bias broadcast against
            # masks with a real batch/head dimension. The previous in-place
            # versions raised for any mask whose batch size was > 1.
            if attn_mask.dtype == torch.bool:
                attn_bias = attn_bias.masked_fill(
                    attn_mask.logical_not(), float("-inf")
                )
            else:
                attn_bias = attn_bias + attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value
519
+
520
+
521
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, config: BaseModelArgs) -> None:
        """Create the three bias-free projections."""
        super().__init__()
        dim, hidden = config.dim, config.intermediate_size
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated projection to ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
530
+
531
+
532
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create a scale vector of ones with ``dim`` entries."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last axis (``eps`` added for stability)."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        """Normalise in float32, cast back to the dtype of ``x``, then scale."""
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
544
+
545
+
546
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Precompute the rotary-embedding table.

    Args:
        seq_len: Number of positions to tabulate.
        n_elem: Head dimension; one frequency is produced per pair of elements.
        base: Base of the geometric frequency progression.

    Returns:
        A bfloat16 tensor of shape ``(seq_len, n_elem // 2, 2)`` whose last
        axis holds the real and imaginary parts of the unit-magnitude
        rotation for each (position, frequency) pair.
    """
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base**exponents)

    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)

    rotations = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([rotations.real, rotations.imag], dim=-1)
    return table.to(dtype=torch.bfloat16)
555
+
556
+
557
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive element pairs of ``x`` by the given table.

    Args:
        x: Tensor whose dim 1 is the sequence axis and whose last axis is
            split into pairs (the view below implies a 4-D layout such as
            ``(batch, seq, heads, head_dim)``).
        freqs_cis: Table with one ``(a, b)`` entry per position and pair; each
            pair ``(x0, x1)`` becomes ``(x0*a - x1*b, x1*a + x0*b)``.

    Returns:
        The rotated tensor, same shape and dtype as ``x`` (the arithmetic is
        done in float32).
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    first, second = pairs[..., 0], pairs[..., 1]
    real, imag = table[..., 0], table[..., 1]

    rotated = torch.stack(
        [
            first * real - second * imag,
            second * real + first * imag,
        ],
        -1,
    )

    return rotated.flatten(3).type_as(x)
570
+
571
+
572
if __name__ == "__main__":
    # Manual smoke test: build a small Dual-AR model on CUDA in bfloat16 and
    # run one padded batch through it, printing the output shapes.
    smoke_config = DualARModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_layer=12,
        n_fast_layer=4,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=128,
        num_codebooks=4,
    )

    model = DualARTransformer(smoke_config).cuda().bfloat16()
    param_count = sum(p.numel() for p in model.parameters())
    print("Total params:", param_count / 1024 / 1024)

    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    # True marks padding: the first sample keeps only its first two positions.
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True

    result = model(inputs, key_padding_mask=key_padding_mask)
    print(result.token_logits.shape)
    print(result.codebook_logits.shape)
fish_speech/models/vqgan/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .lit_module import VQGAN
2
+
3
+ __all__ = ["VQGAN"]
fish_speech/models/vqgan/lit_module.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import math
3
+ from typing import Any, Callable
4
+
5
+ import lightning as L
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import wandb
9
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
10
+ from matplotlib import pyplot as plt
11
+ from torch import nn
12
+
13
+ from fish_speech.models.vqgan.modules.discriminator import Discriminator
14
+ from fish_speech.models.vqgan.modules.wavenet import WaveNet
15
+ from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
16
+
17
+
18
class VQGAN(L.LightningModule):
    """Mel-spectrogram VQ-GAN trained with manual optimisation.

    An encoder maps mel spectrograms to features, a quantizer discretises
    them, and a decoder regenerates a mel spectrogram from noise conditioned
    on the quantised features plus a projected scalar "quality" value. A
    discriminator provides a least-squares adversarial signal. The vocoder is
    frozen and only used to render audio for validation logging and
    ``decode(..., return_audios=True)``.
    """

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        encoder: WaveNet,
        quantizer: nn.Module,
        decoder: WaveNet,
        discriminator: Discriminator,
        vocoder: nn.Module,
        encode_mel_transform: nn.Module,
        gt_mel_transform: nn.Module,
        weight_adv: float = 1.0,
        weight_vq: float = 1.0,
        weight_mel: float = 1.0,
        sampling_rate: int = 44100,
        freeze_encoder: bool = False,
    ):
        """Store the sub-modules, builders and loss weights.

        Args:
            optimizer: Callable that builds an optimizer from parameters.
            lr_scheduler: Callable that builds a scheduler from an optimizer.
            encoder: Mel encoder.
            quantizer: Quantizer; its output must expose ``.z`` and it must
                provide ``encode``, ``decode`` and ``downsample_factor``.
            decoder: Mel decoder taking ``(noise, condition=...)``.
            discriminator: Discriminator applied to mel spectrograms.
            vocoder: Mel-to-audio model; frozen and excluded from checkpoints.
            encode_mel_transform: Audio-to-mel transform feeding the encoder.
            gt_mel_transform: Audio-to-mel transform producing the targets.
            weight_adv: Weight of the adversarial loss.
            weight_vq: Weight of the quantizer loss.
            weight_mel: Weight of the mel reconstruction loss.
            sampling_rate: Sample rate used when logging audio.
            freeze_encoder: If true, freeze the encoder and the quantizer.
        """
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Modules
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.vocoder = vocoder
        self.discriminator = discriminator
        self.encode_mel_transform = encode_mel_transform
        self.gt_mel_transform = gt_mel_transform

        # A simple linear layer to project quality to condition channels
        self.quality_projection = nn.Linear(1, 768)

        # Freeze vocoder
        for param in self.vocoder.parameters():
            param.requires_grad = False

        # Loss weights
        self.weight_adv = weight_adv
        self.weight_vq = weight_vq
        self.weight_mel = weight_mel

        # Other parameters
        self.sampling_rate = sampling_rate

        # Disable strict loading
        self.strict_loading = False

        # If encoder is frozen
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

            for param in self.quantizer.parameters():
                param.requires_grad = False

        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

    def on_save_checkpoint(self, checkpoint):
        """Drop every state-dict entry whose name contains ``"vocoder"``."""
        # Do not save vocoder
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "vocoder" in name:
                state_dict.pop(name)

    def configure_optimizers(self):
        """Build one optimizer/scheduler pair for the generator side and one
        for the discriminator; both schedulers are stepped per training step.
        """
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.encoder.parameters(),
                self.quantizer.parameters(),
                self.decoder.parameters(),
                self.quality_projection.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
        """Run one discriminator update followed by one generator update.

        ``batch`` must provide ``"audios"`` and ``"audio_lengths"``.
        """
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]  # add a channel axis

        with torch.no_grad():
            encoded_mels = self.encode_mel_transform(audios)
            gt_mels = self.gt_mel_transform(audios)
            # Scalar per sample: count of mel bins whose time-average exceeds
            # -8, shifted by 90 and scaled by 1/10.
            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
            quality = quality.unsqueeze(-1)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize
        vq_result = self.quantizer(encoded_features)
        # Quantizers that expose no ``loss`` attribute contribute 0.0.
        # (Previously this queried the string literal "vq_result", so the
        # quantizer loss was always silently discarded.)
        loss_vq = getattr(vq_result, "loss", 0.0)
        vq_recon_features = vq_result.z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features + self.quality_projection(quality)[:, :, None]
        )

        # VQ Decode
        gen_mel = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )

        # Discriminator
        real_logits = self.discriminator(gt_mels)
        # Detached so the discriminator update does not reach the generator.
        fake_logits = self.discriminator(gen_mel.detach())
        d_mask = F.interpolate(
            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
        )

        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
        loss_fake = avg_with_mask(fake_logits**2, d_mask)

        loss_d = loss_real + loss_fake

        self.log(
            "train/discriminator/loss",
            loss_d,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        # Discriminator backward
        optim_d.zero_grad()
        self.manual_backward(loss_d)
        self.clip_gradients(
            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_d.step()

        # Mel Loss, applying l1, using a weighted sum
        mel_distance = (gen_mel - gt_mels).abs()
        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
        loss_mel_mid_freq = avg_with_mask(
            mel_distance[:, 40:70, :], mel_masks_float_conv
        )
        loss_mel_high_freq = avg_with_mask(
            mel_distance[:, 70:, :], mel_masks_float_conv
        )
        loss_mel = (
            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
        )

        # Adversarial Loss (discriminator re-run after its update, with grad)
        fake_logits = self.discriminator(gen_mel)
        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)

        # Total loss
        loss = (
            self.weight_vq * loss_vq
            + self.weight_mel * loss_mel
            + self.weight_adv * loss_adv
        )

        # Log losses; only the total is shown on the progress bar.
        for tag, value, on_bar in (
            ("train/generator/loss", loss, True),
            ("train/generator/loss_vq", loss_vq, False),
            ("train/generator/loss_mel", loss_mel, False),
            ("train/generator/loss_adv", loss_adv, False),
        ):
            self.log(
                tag,
                value,
                on_step=True,
                on_epoch=False,
                prog_bar=on_bar,
                logger=True,
            )

        # Generator backward
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
        """Log the masked L1 mel loss; for the first batch also log up to five
        mel plots and audio clips to the W&B / TensorBoard logger.
        """
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        encoded_mels = self.encode_mel_transform(audios)
        gt_mels = self.gt_mel_transform(audios)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize; the quality condition is fixed to 2 outside training.
        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features
            + self.quality_projection(
                torch.ones(
                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
                )
                * 2
            )[:, :, None]
        )

        # VQ Decode
        gen_aux_mels = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )
        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)

        self.log(
            "val/loss_mel",
            loss_mel,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # only log the first batch
        if batch_idx != 0:
            return

        # Vocoder output is needed for logging only, so it is rendered after
        # the early return instead of on every validation batch.
        recon_audios = self.vocoder(gt_mels)
        gen_aux_audios = self.vocoder(gen_aux_mels)

        for idx, (
            gt_mel,
            gen_aux_mel,
            audio,
            gen_aux_audio,
            recon_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                gen_aux_mels,
                audios.cpu().float(),
                gen_aux_audios.cpu().float(),
                recon_audios.cpu().float(),
                audio_lengths,
            )
        ):
            if idx > 4:
                break

            mel_len = audio_len // self.gt_mel_transform.hop_length

            image_mels = plot_mel(
                [
                    gt_mel[:, :mel_len],
                    gen_aux_mel[:, :mel_len],
                ],
                [
                    "Ground-Truth",
                    "Auxiliary",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_aux_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="aux",
                            ),
                            wandb.Audio(
                                recon_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="recon",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gen",
                    gen_aux_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/recon",
                    recon_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            # Release the figure once it has been logged.
            plt.close(image_mels)

    def encode(self, audios, audio_lengths):
        """Encode audio into quantizer indices.

        Returns:
            ``(indices, feature_lengths)`` where ``feature_lengths`` is the mel
            length divided by the product of the quantizer's downsample
            factors.
        """
        audios = audios.float()

        mels = self.encode_mel_transform(audios)
        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths, return_audios=False):
        """Decode quantizer indices into a mel spectrogram.

        Args:
            indices: Quantizer indices; ``indices.shape[2]`` is taken as the
                feature-frame count.
            feature_lengths: Valid feature length per sample.
            return_audios: If true, run the vocoder and return its output
                instead of the mel spectrogram.
        """
        factor = math.prod(self.quantizer.downsample_factor)
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        # Same fixed quality condition (2) as used in validation.
        z = (
            z
            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
                :, :, None
            ]
        )

        gen_mel = (
            self.decoder(
                torch.randn_like(z) * mel_masks_float_conv,
                condition=z,
            )
            * mel_masks_float_conv
        )

        if return_audios:
            return self.vocoder(gen_mel)

        return gen_mel
fish_speech/models/vqgan/modules/discriminator.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn.utils.parametrizations import weight_norm
4
+
5
+
6
+ class Discriminator(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ blocks = []
11
+ convs = [
12
+ (1, 64, (3, 9), 1, (1, 4)),
13
+ (64, 128, (3, 9), (1, 2), (1, 4)),
14
+ (128, 256, (3, 9), (1, 2), (1, 4)),
15
+ (256, 512, (3, 9), (1, 2), (1, 4)),
16
+ (512, 1024, (3, 3), 1, (1, 1)),
17
+ (1024, 1, (3, 3), 1, (1, 1)),
18
+ ]
19
+
20
+ for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
21
+ convs
22
+ ):
23
+ blocks.append(
24
+ weight_norm(
25
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
26
+ )
27
+ )
28
+
29
+ if idx != len(convs) - 1:
30
+ blocks.append(nn.SiLU(inplace=True))
31
+
32
+ self.blocks = nn.Sequential(*blocks)
33
+
34
+ def forward(self, x):
35
+ return self.blocks(x[:, None])[:, 0]
36
+
37
+
38
if __name__ == "__main__":
    # Smoke test: report the parameter count in millions, then run a random
    # input through the discriminator and print the result.
    model = Discriminator()
    n_params = sum(p.numel() for p in model.parameters())
    print(n_params / 1_000_000)

    sample = torch.randn(1, 128, 1024)
    logits = model(sample)
    print(logits.shape)
    print(logits)
fish_speech/models/vqgan/modules/firefly.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # An inference-only version of the FireflyGAN model
2
+
3
+ from functools import partial
4
+ from math import prod
5
+ from typing import Callable
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torch.nn import Conv1d
12
+ from torch.nn.utils.parametrizations import weight_norm
13
+ from torch.nn.utils.parametrize import remove_parametrizations
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+
17
def init_weights(m, mean=0.0, std=0.01):
    """Fill ``m.weight`` from N(mean, std) when the class name contains "Conv".

    Modules whose class name does not contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
21
+
22
+
23
def get_padding(kernel_size, dilation=1):
    """Return ``dilation * (kernel_size - 1) // 2``, the symmetric padding
    that preserves length for an odd kernel at stride 1."""
    span = dilation * (kernel_size - 1)
    return span // 2
25
+
26
+
27
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, plain conv) pairs.

    Each pair computes ``conv2(silu(conv1(silu(x)))) + x``. The first
    convolution of pair ``i`` uses ``dilation[i]``; the second always uses
    dilation 1. All convolutions are weight-normalised and keep the channel
    count.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def build_conv(rate):
            # Stride-1 convolution padded via get_padding for this dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # Exactly the first three dilation entries are used.
        rates = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([build_conv(rate) for rate in rates])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([build_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = dilated(F.silu(x))
            branch = plain(F.silu(branch))
            x = branch + x
        return x

    def remove_parametrizations(self):
        """Strip the weight-norm parametrization from every convolution."""
        for conv in [*self.convs1, *self.convs2]:
            remove_parametrizations(conv, tensor_name="weight")
117
+
118
+
119
class ParralelBlock(nn.Module):
    """Runs several ``ResBlock1`` branches on the same input and averages them.

    One branch is built per ``(kernel_size, dilations)`` pair.
    """

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        branches = [
            ResBlock1(channels, kernel, dilations)
            for kernel, dilations in zip(kernel_sizes, dilation_sizes)
        ]
        self.blocks = nn.ModuleList(branches)

    def forward(self, x):
        outputs = [branch(x) for branch in self.blocks]
        stacked = torch.stack(outputs, dim=0)
        return stacked.mean(dim=0)

    def remove_parametrizations(self):
        """Strip weight norm from every branch."""
        for branch in self.blocks:
            branch.remove_parametrizations()
140
+
141
+
142
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style upsampling generator.

    A pre-convolution is followed by one transposed-convolution upsampling
    stage per entry in ``upsample_rates`` (each halving the channel count and
    followed by a ``ParralelBlock``), then an activation, a post-convolution
    to one channel, and ``tanh``.

    Args:
        hop_length: Must equal the product of ``upsample_rates``.
        upsample_rates: Stride of each upsampling stage.
        upsample_kernel_sizes: Kernel size of each upsampling stage.
        resblock_kernel_sizes: Kernel sizes of the residual branches.
        resblock_dilation_sizes: Dilations of the residual branches.
        num_mels: Number of input channels.
        upsample_initial_channel: Channel count after the pre-convolution.
        use_template: If true, a ``template`` signal is downsampled by
            ``noise_convs`` and added after every upsampling stage.
        pre_conv_kernel_size: Kernel size of the pre-convolution.
        post_conv_kernel_size: Kernel size of the post-convolution.
        post_activation: Factory for the activation before the
            post-convolution.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        # forward() consults this flag in training mode. It was never
        # initialised before, so any training-mode call raised
        # AttributeError. Off by default; callers may set it to True.
        self.checkpointing = False

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                # Remaining upsampling factor after this stage; math.prod
                # keeps it a plain Python int for the conv arguments.
                stride_f0 = prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Upsample ``x`` to a single-channel signal squashed by ``tanh``.

        Raises:
            ValueError: If the generator was built with ``use_template=True``
                and no ``template`` is supplied.
        """
        if self.use_template and template is None:
            raise ValueError("template is required when use_template=True")

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            if self.training and self.checkpointing:
                # Trade compute for memory by recomputing the residual stage
                # during the backward pass.
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_parametrizations(self):
        """Strip weight norm from the upsamplers, residual blocks and the
        pre-/post-convolutions."""
        for up in self.ups:
            remove_parametrizations(up, tensor_name="weight")
        for block in self.resblocks:
            block.remove_parametrizations()
        remove_parametrizations(self.conv_pre, tensor_name="weight")
        remove_parametrizations(self.conv_post, tensor_name="weight")
264
+
265
+
266
+ # DropPath copied from timm library
267
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic depth: randomly zero whole samples of ``x``.

    When ``training`` is false or ``drop_prob`` is 0, ``x`` is returned
    unchanged. Otherwise one Bernoulli(``1 - drop_prob``) value is drawn per
    entry along dim 0 and broadcast over the remaining dims; with
    ``scale_by_keep`` the kept samples are divided by the keep probability.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcastable against any tensor rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)

    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)

    return x * mask
290
+
291
+
292
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` that follows ``self.training``."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        rounded = round(self.drop_prob, 3)
        return f"drop_prob={rounded:0.3f}"
305
+
306
+
307
class LayerNorm(nn.Module):
    r"""Layer normalisation for channels-last or channels-first inputs.

    With ``data_format="channels_last"`` (the default) the last axis is
    normalised via ``F.layer_norm``. With ``"channels_first"`` the statistics
    are taken over axis 1 and the affine parameters are broadcast with one
    trailing axis.
    """

    _SUPPORTED_FORMATS = ("channels_last", "channels_first")

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in self._SUPPORTED_FORMATS:
            raise NotImplementedError

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        if self.data_format == "channels_first":
            # Biased mean/variance over the channel axis.
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            x = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None] * x + self.bias[:, None]
335
+
336
+
337
+ # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
338
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block.

    Depthwise conv -> permute to (N, L, C) -> LayerNorm -> Linear -> GELU ->
    Linear -> optional layer scale -> permute back -> drop path -> optional
    residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            <= 0 disables layer scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # The padding was already sized for a dilated kernel, but the
        # dilation itself was never handed to the convolution, so any
        # dilation != 1 changed the sequence length and broke the residual
        # add. With the default dilation=1 nothing changes.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
            dilation=dilation,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        """Apply the block; ``apply_residual=False`` returns only the branch."""
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = residual + x

        return x
404
+
405
+
406
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages joined by channel-changing 1x1 convolutions.

    Stage ``i`` holds ``depths[i]`` blocks of width ``dims[i]``. The stem is a
    padded convolution plus channels-first LayerNorm; every later transition
    is LayerNorm followed by a kernel-size-1 convolution. The drop-path rate
    grows linearly from 0 to ``drop_path_rate`` across all blocks.

    Args:
        input_channels: Channel count of the input.
        depths: Number of blocks per stage (any sequence of ints).
        dims: Channel width per stage; must have the same length as
            ``depths``.
        drop_path_rate: Maximum stochastic-depth rate.
        layer_scale_init_value: Passed to every block.
        kernel_size: Kernel size of the stem and the depthwise convolutions.
    """

    def __init__(
        self,
        input_channels: int = 3,
        # Tuples instead of lists: defaults must not be shared mutable objects.
        depths: tuple[int, ...] = (3, 3, 9, 3),
        dims: tuple[int, ...] = (96, 192, 384, 768),
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                padding_mode="zeros",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)

        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.downsample_layers.append(mid_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # ``cur`` is the index of the stage's first block within dp_rates.
        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    ConvNeXtBlock(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        kernel_size=kernel_size,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv/linear layers."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Layers built with bias=False have no bias tensor to reset.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run every (transition, stage) pair in order, then the final norm."""
        for transition, stage in zip(self.downsample_layers, self.stages):
            x = transition(x)
            x = stage(x)

        return self.norm(x)
475
+
476
+
477
class FireflyBase(nn.Module):
    """Inference vocoder: ConvNeXt backbone followed by a HiFi-GAN head.

    Weights come from ``ckpt_path`` when given; otherwise, if ``pretrained``
    is true, they are downloaded from the release URL below, after which the
    head's weight-norm parametrizations are removed.
    """

    def __init__(self, ckpt_path: str = None, pretrained: bool = True):
        super().__init__()

        self.backbone = ConvNeXtEncoder(
            input_channels=128,
            depths=[3, 3, 9, 3],
            dims=[128, 256, 384, 512],
            drop_path_rate=0.2,
            kernel_size=7,
        )

        self.head = HiFiGANGenerator(
            hop_length=512,
            upsample_rates=[8, 8, 2, 2, 2],
            upsample_kernel_sizes=[16, 16, 4, 4, 4],
            resblock_kernel_sizes=[3, 7, 11],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            num_mels=512,
            upsample_initial_channel=512,
            use_template=False,
            pre_conv_kernel_size=13,
            post_conv_kernel_size=13,
        )

        if ckpt_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
            return

        if not pretrained:
            return

        weights = torch.hub.load_state_dict_from_url(
            "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
            map_location="cpu",
        )

        # Unwrap a Lightning-style checkpoint if present.
        if "state_dict" in weights:
            weights = weights["state_dict"]

        # Keep only generator entries and drop their prefix.
        if any("generator." in key for key in weights):
            weights = {
                key.replace("generator.", ""): value
                for key, value in weights.items()
                if "generator." in key
            }

        self.load_state_dict(weights, strict=True)
        self.head.remove_parametrizations()

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        audio = self.head(self.backbone(x))
        # Guarantee a channel axis on 2-D outputs.
        if audio.ndim == 2:
            audio = audio[:, None, :]
        return audio
530
+
531
+
532
if __name__ == "__main__":
    # Smoke test: build the model with its default arguments, push a random
    # input through it in eval mode and print the output shape.
    vocoder = FireflyBase()
    vocoder.eval()

    dummy_mel = torch.randn(1, 128, 128)
    with torch.no_grad():
        audio = vocoder(dummy_mel)

    print(audio.shape)
fish_speech/models/vqgan/modules/fsq.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import GroupedResidualFSQ
8
+
9
+ from .firefly import ConvNeXtBlock
10
+
11
+
12
@dataclass
class FSQResult:
    """Container for the outputs of a finite-scalar-quantization pass."""

    # Quantized representation.
    z: torch.Tensor
    # Code indices produced by the quantizer.
    codes: torch.Tensor
    # Values fed into the quantizer (pre-quantization).
    latents: torch.Tensor
17
+
18
+
19
+ class DownsampleFiniteScalarQuantize(nn.Module):
20
+ def __init__(
21
+ self,
22
+ input_dim: int = 512,
23
+ n_codebooks: int = 9,
24
+ n_groups: int = 1,
25
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
+ downsample_factor: tuple[int] = (2, 2),
27
+ downsample_dims: tuple[int] | None = None,
28
+ ):
29
+ super().__init__()
30
+
31
+ if downsample_dims is None:
32
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
33
+
34
+ all_dims = (input_dim,) + tuple(downsample_dims)
35
+
36
+ self.residual_fsq = GroupedResidualFSQ(
37
+ dim=all_dims[-1],
38
+ levels=levels,
39
+ num_quantizers=n_codebooks,
40
+ groups=n_groups,
41
+ )
42
+
43
+ self.downsample_factor = downsample_factor
44
+ self.downsample_dims = downsample_dims
45
+
46
+ self.downsample = nn.Sequential(
47
+ *[
48
+ nn.Sequential(
49
+ nn.Conv1d(
50
+ all_dims[idx],
51
+ all_dims[idx + 1],
52
+ kernel_size=factor,
53
+ stride=factor,
54
+ ),
55
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
56
+ )
57
+ for idx, factor in enumerate(downsample_factor)
58
+ ]
59
+ )
60
+
61
+ self.upsample = nn.Sequential(
62
+ *[
63
+ nn.Sequential(
64
+ nn.ConvTranspose1d(
65
+ all_dims[idx + 1],
66
+ all_dims[idx],
67
+ kernel_size=factor,
68
+ stride=factor,
69
+ ),
70
+ ConvNeXtBlock(dim=all_dims[idx]),
71
+ )
72
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
73
+ ]
74
+ )
75
+
76
+ self.apply(self._init_weights)
77
+
78
+ def _init_weights(self, m):
79
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
80
+ nn.init.trunc_normal_(m.weight, std=0.02)
81
+ nn.init.constant_(m.bias, 0)
82
+
83
+ def forward(self, z) -> FSQResult:
84
+ original_shape = z.shape
85
+ z = self.downsample(z)
86
+ quantized, indices = self.residual_fsq(z.mT)
87
+ result = FSQResult(
88
+ z=quantized.mT,
89
+ codes=indices.mT,
90
+ latents=z,
91
+ )
92
+ result.z = self.upsample(result.z)
93
+
94
+ # Pad or crop z to match original shape
95
+ diff = original_shape[-1] - result.z.shape[-1]
96
+ left = diff // 2
97
+ right = diff - left
98
+
99
+ if diff > 0:
100
+ result.z = F.pad(result.z, (left, right))
101
+ elif diff < 0:
102
+ result.z = result.z[..., left:-right]
103
+
104
+ return result
105
+
106
+ def encode(self, z):
107
+ z = self.downsample(z)
108
+ _, indices = self.residual_fsq(z.mT)
109
+ indices = rearrange(indices, "g b l r -> b (g r) l")
110
+ return indices
111
+
112
+ def decode(self, indices: torch.Tensor):
113
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
114
+ z_q = self.residual_fsq.get_output_from_indices(indices)
115
+ z_q = self.upsample(z_q.mT)
116
+ return z_q
117
+
118
+ # def from_latents(self, latents: torch.Tensor):
119
+ # z_q, z_p, codes = super().from_latents(latents)
120
+ # z_q = self.upsample(z_q)
121
+ # return z_q, z_p, codes
122
+
123
+
124
if __name__ == "__main__":
    # Smoke test: one codebook, two 2x downsampling stages.
    quantizer = DownsampleFiniteScalarQuantize(
        n_codebooks=1,
        downsample_factor=(2, 2),
    )
    features = torch.randn(16, 512, 80)

    out = quantizer(features)
    print(quantizer)
    print(out.latents.shape, out.codes.shape, out.z.shape)
134
+
135
+ # y = rvq.from_codes(result.codes)
136
+ # print(y[0].shape)
137
+
138
+ # y = rvq.from_latents(result.latents)
139
+ # print(y[0].shape)
fish_speech/models/vqgan/modules/reference.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from .wavenet import WaveNet
8
+
9
+
10
class ReferenceEncoder(WaveNet):
    """WaveNet encoder whose output is pooled by learned latent queries.

    The WaveNet features are attended to by ``latent_len`` learned query
    vectors; the attended latents pass through a residual MLP and an output
    projection, and are averaged over the latent axis.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )

        width = residual_channels
        self.head_dim = width // num_heads
        self.num_heads = num_heads

        # Learned query vectors shared across the batch.
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, width))

        # Attention projections; q/k are layer-normalised per head.
        self.q = nn.Linear(width, width, bias=True)
        self.kv = nn.Linear(width, width * 2, bias=True)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(width, width)
        self.proj_drop = nn.Dropout(0.1)

        self.norm = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, width * 4),
            nn.SiLU(),
            nn.Linear(width * 4, width),
        )
        self.output_projection_attn = nn.Linear(width, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Truncated-normal weights and zero biases for every ``nn.Linear``."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, attn_mask=None):
        """Encode ``x`` and pool it into one vector per batch item.

        ``attn_mask``, when given, must be a boolean tensor of shape (B, N)
        matching the encoded sequence; it is broadcast over heads and latents.
        """
        feats = super().forward(x).mT
        B, N, C = feats.shape

        # Calculate mask
        if attn_mask is not None:
            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool

            attn_mask = attn_mask[:, None, None, :].expand(
                B, self.num_heads, self.latent_len, N
            )

        queries = self.q(self.latent.expand(B, -1, -1))
        queries = queries.reshape(
            B, self.latent_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        keys_values = self.kv(feats).reshape(B, N, 2, self.num_heads, self.head_dim)
        keys, values = keys_values.permute(2, 0, 3, 1, 4).unbind(0)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        pooled = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask
        )

        pooled = pooled.transpose(1, 2).reshape(B, self.latent_len, C)
        pooled = self.proj_drop(self.proj(pooled))

        pooled = pooled + self.mlp(self.norm(pooled))
        return self.output_projection_attn(pooled).mean(1)
96
+
97
+
98
if __name__ == "__main__":
    # Smoke test under CPU bfloat16 autocast: forward pass plus backward.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        encoder = ReferenceEncoder(
            input_channels=128,
            output_channels=64,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )
        inputs = torch.randn(4, 128, 64)
        valid = torch.ones(4, 64, dtype=torch.bool)
        pooled = encoder(inputs, valid)
        print(pooled.shape)
        loss = F.mse_loss(pooled, torch.randn(4, 64))
        loss.backward()
fish_speech/models/vqgan/modules/wavenet.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+
9
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        softplus = F.softplus(x)
        return x * torch.tanh(softplus)
12
+
13
+
14
class DiffusionEmbedding(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of step values.

    The output concatenates ``sin`` and ``cos`` of the input scaled by
    ``dim // 2`` geometrically spaced frequencies.
    """

    def __init__(self, d_denoiser):
        super().__init__()
        self.dim = d_denoiser

    def forward(self, x):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
29
+
30
+
31
class LinearNorm(nn.Module):
    """Linear projection with Xavier-uniform weights and an optional zero bias."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        return self.linear(x)
45
+
46
+
47
class ConvNorm(nn.Module):
    """1-D convolution with Kaiming-normal weights.

    When ``padding`` is omitted the kernel size must be odd, and the padding is
    chosen so that a stride-1 convolution keeps the sequence length.
    ``w_init_gain`` is accepted but not used by this implementation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super().__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        return self.conv(signal)
82
+
83
+
84
class ResidualBlock(nn.Module):
    """Gated dilated-convolution block that emits a residual and a skip output.

    The diffusion-step and condition projections are only created when
    ``condition_channels`` is given.
    """

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super().__init__()
        doubled = 2 * residual_channels

        self.conv_layer = ConvNorm(
            residual_channels,
            doubled,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, doubled, kernel_size=1
            )

        self.output_projection = ConvNorm(residual_channels, doubled, kernel_size=1)

    def forward(self, x, condition=None, diffusion_step=None):
        """Return ``((x + residual) / sqrt(2), skip)``."""
        hidden = x

        if diffusion_step is not None:
            step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
            hidden = hidden + step

        hidden = self.conv_layer(hidden)

        if condition is not None:
            hidden = hidden + self.condition_projection(condition)

        # First half of the channels gates, second half is the filter.
        gate, filt = torch.chunk(hidden, 2, dim=1)
        hidden = torch.sigmoid(gate) * torch.tanh(filt)

        residual, skip = torch.chunk(self.output_projection(hidden), 2, dim=1)

        return (x + residual) / math.sqrt(2.0), skip
136
+
137
+
138
class WaveNet(nn.Module):
    """Stack of gated residual blocks whose skip outputs are summed.

    Args:
        input_channels: Input channel count; a 1x1 projection is added when it
            differs from ``residual_channels``.
        output_channels: Output channel count; a 1x1 projection is added when it
            differs from ``residual_channels``.
        residual_channels: Width of the residual stack.
        residual_layers: Number of residual blocks.
        dilation_cycle: Dilation is ``2 ** (i % dilation_cycle)``; a falsy
            value fixes the dilation at 1.
        is_diffusion: Build the step embedding and its MLP.
        condition_channels: Channel count of the optional conditioning input.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection
        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )
        else:
            self.input_projection = None

        self.input_channels = (
            residual_channels if input_channels is None else input_channels
        )

        # Residual layers
        def dilation_for(layer_idx):
            return 2 ** (layer_idx % dilation_cycle) if dilation_cycle else 1

        blocks = []
        for layer_idx in range(residual_layers):
            blocks.append(
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=dilation_for(layer_idx),
                    condition_channels=condition_channels,
                )
            )
        self.residual_layers = nn.ModuleList(blocks)

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection
        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )
        else:
            self.output_projection = None

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for Conv1d/Linear layers."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        """Run the stack; ``t`` and ``condition`` are forwarded to every block."""
        if self.input_projection is not None:
            x = F.silu(self.input_projection(x))

        if t is not None:
            t = self.mlp(self.diffusion_embedding(t))

        skips = []
        for block in self.residual_layers:
            x, skip_out = block(x, condition, t)
            skips.append(skip_out)

        # Sum the skips, scaled by 1/sqrt(number of layers).
        x = torch.sum(torch.stack(skips), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = self.output_projection(F.silu(x))

        return x
fish_speech/models/vqgan/spectrogram.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio.functional as F
3
+ from torch import Tensor, nn
4
+ from torchaudio.transforms import MelScale
5
+
6
+
7
class LinearSpectrogram(nn.Module):
    """STFT spectrogram with explicit reflect padding.

    With ``mode="pow2_sqrt"`` the output is ``sqrt(re**2 + im**2 + 1e-6)``;
    for any other mode the real/imaginary view of the STFT is returned.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Non-persistent: the window is not written to the state dict.
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        # A 3-D input has its axis 1 squeezed away.
        if y.ndim == 3:
            y = y.squeeze(1)

        # Split (win_length - hop_length) samples of padding across both ends.
        overlap = self.win_length - self.hop_length
        pad_left = overlap // 2
        pad_right = (overlap + 1) // 2
        padded = torch.nn.functional.pad(
            y.unsqueeze(1), (pad_left, pad_right), mode="reflect"
        ).squeeze(1)

        stft = torch.stft(
            padded,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        spec = torch.view_as_real(stft)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec
58
+
59
+
60
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram built on ``LinearSpectrogram``.

    ``f_max`` falls back to ``sample_rate // 2`` when not given (or falsy).
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or float(sample_rate // 2)

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

        # Slaney-style mel filterbank, kept as a non-persistent buffer.
        mel_filters = F.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer("fb", mel_filters, persistent=False)

    def compress(self, x: Tensor) -> Tensor:
        """Natural log of ``x`` after clamping it to at least 1e-5."""
        floored = torch.clamp(x, min=1e-5)
        return torch.log(floored)

    def decompress(self, x: Tensor) -> Tensor:
        """Exponentiate ``x`` (inverse of the log; the clamp is not undone)."""
        return torch.exp(x)

    def apply_mel_scale(self, x: Tensor) -> Tensor:
        """Project the frequency axis (second to last) onto the mel filterbank."""
        frames_first = x.transpose(-1, -2)
        return torch.matmul(frames_first, self.fb).transpose(-1, -2)

    def forward(
        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
    ) -> Tensor:
        """Compute the log-mel spectrogram.

        Audio given with a different ``sample_rate`` is resampled first. With
        ``return_linear=True`` a ``(log_mel, log_linear)`` pair is returned.
        """
        if sample_rate is not None and sample_rate != self.sample_rate:
            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)

        linear = self.spectrogram(x)
        log_mel = self.compress(self.apply_mel_scale(linear))

        if return_linear:
            return log_mel, self.compress(linear)

        return log_mel
fish_speech/models/vqgan/utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import torch
3
+ from matplotlib import pyplot as plt
4
+
5
+ matplotlib.use("Agg")
6
+
7
+
8
def convert_pad_shape(pad_shape):
    """Flatten a list of per-dimension pad pairs, last dimension first."""
    flattened = []
    for pair in pad_shape[::-1]:
        for amount in pair:
            flattened.append(amount)
    return flattened
12
+
13
+
14
def sequence_mask(length, max_length=None):
    """Boolean mask that is True where the position index is below ``length``.

    ``max_length`` defaults to the largest value in ``length``.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
19
+
20
+
21
def init_weights(m, mean=0.0, std=0.01):
    """Draw ``m.weight`` from N(mean, std) when the class name contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
25
+
26
+
27
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
29
+
30
+
31
def plot_mel(data, titles=None):
    """Draw each item of ``data`` as an image in its own row; return the figure.

    Tensors are converted to float NumPy arrays on the CPU before plotting.
    """
    fig, axes = plt.subplots(len(data), 1, squeeze=False)

    if titles is None:
        titles = [None] * len(data)

    plt.tight_layout()

    for row, mel in enumerate(data):
        if isinstance(mel, torch.Tensor):
            mel = mel.float().detach().cpu().numpy()

        ax = axes[row][0]
        ax.imshow(mel, origin="lower")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[row], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
53
+
54
+
55
def slice_segments(x, ids_str, segment_size=4):
    """Cut a ``segment_size`` window from each batch item along the last axis.

    ``ids_str[i]`` is the start index of the window for batch item ``i``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for row in range(x.size(0)):
        begin = ids_str[row]
        segments[row] = x[row, :, begin : begin + segment_size]

    return segments
63
+
64
+
65
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut one randomly positioned window per batch item from ``x``.

    Args:
        x: Tensor with three axes; windows are taken along the last one.
        x_lengths: Valid length per batch item; defaults to the full length.
        segment_size: Window length.

    Returns:
        ``(segments, ids_str)``: the windows and their start indices.
    """
    b, _, t = x.size()

    # torch.clamp needs a tensor; the previous default left a plain int here
    # and raised TypeError whenever x_lengths was omitted.
    lengths = torch.as_tensor(t if x_lengths is None else x_lengths, device=x.device)

    ids_str_max = torch.clamp(lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
73
+
74
+
75
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
    # Gated activation: tanh of the first n_channels[0] channels multiplied by
    # the sigmoid of the remaining channels (split along dim 1).
    # NOTE(review): despite "add" in the name, no addition happens here -
    # callers presumably sum their inputs before calling; confirm.
    n_channels_int = n_channels[0]
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act

    return acts
83
+
84
+
85
def avg_with_mask(x, mask):
    """Average of ``x`` weighted by a float ``mask``.

    A 2-D mask gains a singleton axis 1; a mask whose axis 1 has size one is
    expanded to the shape of ``x``.
    """
    assert mask.dtype == torch.float, "Mask should be float"

    weights = mask.unsqueeze(1) if mask.ndim == 2 else mask

    if weights.shape[1] == 1:
        weights = weights.expand_as(x)

    return (x * weights).sum() / weights.sum()
fish_speech/scheduler.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+
4
def get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    final_lr_ratio: float = 0.0,
):
    """Learning-rate multiplier: linear warmup followed by cosine decay.

    During warmup the multiplier rises linearly from 0 towards 1. Afterwards
    it follows a cosine over the remaining steps, floored at ``final_lr_ratio``.
    """
    in_warmup = current_step < num_warmup_steps
    if in_warmup:
        return float(current_step) / float(max(1, num_warmup_steps))

    decay_steps = max(1, num_training_steps - num_warmup_steps)
    progress = float(current_step - num_warmup_steps) / float(decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))

    return max(final_lr_ratio, cosine)
fish_speech/text/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .clean import clean_text
2
+
3
+ __all__ = ["clean_text"]
fish_speech/text/clean.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import re
3
+
4
# Inclusive (start, end) Unicode code-point ranges kept for each language tag.
LANGUAGE_UNICODE_RANGE_MAP = {
    "ZH": [(0x4E00, 0x9FFF)],
    "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)],
    "EN": [(0x0000, 0x007F)],
}
9
+
10
+ SYMBOLS_MAPPING = {
11
+ ":": ",",
12
+ ";": ",",
13
+ ",": ",",
14
+ "。": ".",
15
+ "!": "!",
16
+ "?": "?",
17
+ "\n": ".",
18
+ "·": ",",
19
+ "、": ",",
20
+ "...": "…",
21
+ "$": ".",
22
+ "“": "'",
23
+ "”": "'",
24
+ "‘": "'",
25
+ "’": "'",
26
+ "(": "'",
27
+ ")": "'",
28
+ "(": "'",
29
+ ")": "'",
30
+ "《": "'",
31
+ "》": "'",
32
+ "【": "'",
33
+ "】": "'",
34
+ "[": "'",
35
+ "]": "'",
36
+ "—": "-",
37
+ "~": "-",
38
+ "~": "-",
39
+ "・": "-",
40
+ "「": "'",
41
+ "」": "'",
42
+ ";": ",",
43
+ ":": ",",
44
+ }
45
+
46
# Alternation of every (escaped) SYMBOLS_MAPPING key.
REPLACE_SYMBOL_REGEX = re.compile(
    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
)
# All (start, end) ranges from every language, flattened into one list.
# NOTE: despite the name these are Unicode code points, not UTF-8 bytes.
ALL_KNOWN_UTF8_RANGE = list(
    itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values())
)
# Negated character class: matches any character outside the known ranges.
REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
    "[^"
    + "".join(
        f"{re.escape(chr(start))}-{re.escape(chr(end))}"
        for start, end in ALL_KNOWN_UTF8_RANGE
    )
    + "]"
)
60
+
61
+
62
def clean_text(text):
    """Normalise punctuation and drop characters outside the known ranges.

    ``<p:...>`` tags are temporarily rewritten to ``<PPP...PPP>`` so that the
    symbol replacement does not touch their colon, and are restored at the end.
    """

    def mapped_symbol(match):
        return SYMBOLS_MAPPING[match.group()]

    # Trim surrounding whitespace
    cleaned = text.strip()
    # Shield <p:(.*?)> tags as <PPP(.*?)PPP>
    cleaned = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", cleaned)
    # Map symbols to their replacements, then drop unknown characters
    cleaned = REPLACE_SYMBOL_REGEX.sub(mapped_symbol, cleaned)
    cleaned = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", cleaned)
    # Restore <PPP(.*?)PPP> back to <p:(.*?)>
    cleaned = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", cleaned)

    return cleaned
fish_speech/train.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+
4
+ import hydra
5
+ import lightning as L
6
+ import pyrootutils
7
+ import torch
8
+ from lightning import Callback, LightningDataModule, LightningModule, Trainer
9
+ from lightning.pytorch.loggers import Logger
10
+ from omegaconf import DictConfig, OmegaConf
11
+
12
# Remove SLURM variables from the process environment.
# NOTE(review): presumably this stops Lightning from auto-detecting a SLURM
# cluster environment - confirm.
os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)

# Locate the project root via the ".project-root" marker file and add it to
# the Python path.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver
# NOTE(review): this exposes Python's built-in eval() to config interpolation;
# only load configuration from trusted sources.
OmegaConf.register_new_resolver("eval", eval)

# NOTE(review): imported after setup_root, presumably because the package is
# resolved through the path added above - confirm before moving.
import fish_speech.utils as utils

log = utils.RankedLogger(__name__, rank_zero_only=True)
29
+
30
+
31
@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Train the model, optionally testing it afterwards.

    Wrapped in ``utils.task_wrapper``, which controls behaviour on failure
    (useful for multiruns, saving info about the crash, etc.).

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated
        objects.
    """

    # Seed the random number generators. Compare against None so that an
    # explicit seed of 0 is honoured instead of being treated as "no seed".
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=False)

    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False

        # A checkpoint already present in the run's checkpoint directory
        # takes precedence over the configured ckpt_path.
        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
        if resume_ckpt_path is not None:
            ckpt_path = resume_ckpt_path
            auto_resume = True

        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")

        # resume weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            # strict=False: the returned value lists missing/unexpected keys.
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Error loading state dict: {err}")
            # Weights are already loaded; do not let the trainer restore the
            # full training state from the checkpoint.
            ckpt_path = None

        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")

        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
124
+
125
+
126
# NOTE(review): "llama_pretrain.yaml" does not appear among the config files
# visible in this commit's file listing - verify the default config name.
@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run ``train`` with the composed configuration.

    The result of ``train`` is discarded, so this returns ``None``.
    """
    # train the model
    train(cfg)


if __name__ == "__main__":
    main()
fish_speech/utils/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .braceexpand import braceexpand
2
+ from .file import get_latest_checkpoint
3
+ from .instantiators import instantiate_callbacks, instantiate_loggers
4
+ from .logger import RankedLogger
5
+ from .logging_utils import log_hyperparameters
6
+ from .rich_utils import enforce_tags, print_config_tree
7
+ from .utils import extras, get_metric_value, task_wrapper
8
+
9
+ __all__ = [
10
+ "enforce_tags",
11
+ "extras",
12
+ "get_metric_value",
13
+ "RankedLogger",
14
+ "instantiate_callbacks",
15
+ "instantiate_loggers",
16
+ "log_hyperparameters",
17
+ "print_config_tree",
18
+ "task_wrapper",
19
+ "braceexpand",
20
+ "get_latest_checkpoint",
21
+ ]
fish_speech/utils/braceexpand.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Bash-style brace expansion
3
+ Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
4
+ License: MIT
5
+ """
6
+
7
+ import re
8
+ import string
9
+ from itertools import chain, product
10
+ from typing import Iterable, Iterator, Optional
11
+
12
+ __all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
13
+
14
+
15
class UnbalancedBracesError(ValueError):
    """Raised when a pattern's opening and closing braces do not balance."""

    pass
17
+
18
+
19
# All ASCII uppercase letters followed by all ASCII lowercase letters.
alphabet = string.ascii_uppercase + string.ascii_lowercase

# "<int>..<int>" with an optional "..<step>" suffix (a sign on the step is
# accepted but not captured).
int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
# "<letter>..<letter>" with an optional "..<step>" suffix.
char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
# A backslash followed by any single character, i.e. an escape sequence.
escape_re = re.compile(r"\\(.)")
24
+
25
+
26
def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
    """braceexpand(pattern) -> iterator over generated strings

    Returns an iterator over the strings resulting from brace expansion
    of pattern. This function implements Brace Expansion as described in
    bash(1), with the following limitations:

    * A pattern containing unbalanced braces will raise an
      UnbalancedBracesError exception. In bash, unbalanced braces will either
      be partly expanded or ignored.

    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
      include the characters '[]^_`' between 'Z' and 'a'.

    When escape is True (the default), characters in pattern can be
    prefixed with a backslash to cause them not to be interpreted as
    special characters for brace expansion (such as '{', '}', ',').
    To pass through a literal backslash, double it ('\\\\').

    When escape is False, backslashes in pattern have no special
    meaning and will be preserved in the output.

    Examples:

    >>> from braceexpand import braceexpand

    # Integer range
    >>> list(braceexpand('item{1..3}'))
    ['item1', 'item2', 'item3']

    # Character range
    >>> list(braceexpand('{a..c}'))
    ['a', 'b', 'c']

    # Sequence
    >>> list(braceexpand('index.html{,.backup}'))
    ['index.html', 'index.html.backup']

    # Nested patterns
    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']

    # Prefixing an integer with zero causes all numbers to be padded to
    # the same width.
    >>> list(braceexpand('{07..10}'))
    ['07', '08', '09', '10']

    # An optional increment can be specified for ranges.
    >>> list(braceexpand('{a..g..2}'))
    ['a', 'c', 'e', 'g']

    # Ranges can go in both directions.
    >>> list(braceexpand('{4..1}'))
    ['4', '3', '2', '1']

    # Numbers can be negative
    >>> list(braceexpand('{2..-1}'))
    ['2', '1', '0', '-1']

    # Unbalanced braces raise an exception.
    >>> list(braceexpand('{1{2,3}'))
    Traceback (most recent call last):
        ...
    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'

    # By default, the backslash is the escape character.
    >>> list(braceexpand(r'{1\\{2,3}'))
    ['1{2', '3']

    # Setting 'escape' to False disables backslash escaping.
    >>> list(braceexpand(r'\\{1,2}', escape=False))
    ['\\\\1', '\\\\2']

    """
    # When escaping is on, strip the escaping backslash from every result.
    return (
        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
    )
103
+
104
+
105
def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
    """Expand every top-level brace group found in ``pattern``.

    The pattern is scanned once and cut into literal runs and brace
    expressions.  Each piece becomes an iterable of alternatives, and the
    result lazily yields the concatenation of every combination of them.

    Args:
        pattern: Text that may contain brace expressions.
        escape: When true, a backslash makes the scanner skip the character
            that follows it.

    Returns:
        A generator of expanded strings.

    Raises:
        UnbalancedBracesError: If the brace depth is not zero once the whole
            pattern has been scanned.
    """
    pieces: list[Iterable[str]] = []
    depth = 0
    seg_start = 0
    idx = 0
    length = len(pattern)

    while idx < length:
        ch = pattern[idx]

        if escape and ch == "\\":
            # Jump over the backslash and the character it protects.
            idx += 2
            continue

        if ch == "{":
            if depth == 0 and idx > seg_start:
                # Flush the literal text that precedes this top-level group.
                pieces.append([pattern[seg_start:idx]])
                seg_start = idx
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                inner = pattern[seg_start + 1 : idx]
                expanded = parse_expression(inner, escape)
                if expanded is None:
                    # Neither a range nor a sequence: keep the braces as
                    # literals but still expand whatever is nested inside.
                    pieces.extend([["{"], parse_pattern(inner, escape), ["}"]])
                else:
                    pieces.append(expanded)
                seg_start = idx + 1  # resume after the closing brace
        idx += 1

    if depth != 0:
        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)

    if seg_start < idx:
        # Trailing literal text after the last group.
        pieces.append([pattern[seg_start:]])

    return ("".join(combo) for combo in product(*pieces))
142
+
143
+
144
def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
    """Expand the text found between one pair of braces.

    The expression is tried as an integer range first, then as a character
    range; anything else is handed to ``parse_sequence``.

    Args:
        expr: Brace content without the surrounding braces.
        escape: Forwarded to ``parse_sequence``.

    Returns:
        An iterable of alternatives, or whatever ``parse_sequence`` returns
        (``None`` when the expression has no top-level comma).
    """
    range_builders = (
        (int_range_re, make_int_range),
        (char_range_re, make_char_range),
    )
    for regex, build in range_builders:
        matched = regex.match(expr)
        if matched:
            return build(*matched.groups())

    return parse_sequence(expr, escape)
154
+
155
+
156
def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
    """Split ``seq`` on its top-level commas and expand every part.

    Args:
        seq: Brace content that may hold comma-separated alternatives.
        escape: When true, a backslash makes the scanner skip the character
            that follows it.

    Returns:
        A chain over the expansions of every comma-separated part, or
        ``None`` when no comma was found outside nested braces.

    Raises:
        UnbalancedBracesError: If the brace depth is not zero after the scan.
    """
    parts: list[Iterable[str]] = []
    depth = 0
    seg_start = 0
    idx = 0

    while idx < len(seq):
        ch = seq[idx]

        if escape and ch == "\\":
            idx += 2
            continue

        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(parse_pattern(seq[seg_start:idx], escape))
            seg_start = idx + 1  # continue after the comma
        idx += 1

    if depth != 0:
        raise UnbalancedBracesError
    if not parts:
        return None

    # Whatever follows the last comma is a part too (possibly empty).
    parts.append(parse_pattern(seq[seg_start:], escape))
    return chain(*parts)
185
+
186
+
187
def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
    """Yield the integers from ``left`` to ``right`` (inclusive) as strings.

    Args:
        left: First value, as written in the pattern.
        right: Last value, as written in the pattern.
        incr: Optional step; an empty value or a step of zero means 1.

    Returns:
        A generator of formatted integers.  The sequence counts down when
        ``left`` is not smaller than ``right``.  If either bound is written
        with a leading zero (and is not just ``0`` / ``-0``), every value is
        zero-padded to the length of the longer bound.
    """
    zero_padded = False
    for bound in (left, right):
        if bound in ("0", "-0"):
            continue
        if bound.startswith(("0", "-0")):
            zero_padded = True
    width = max(len(left), len(right)) if zero_padded else 0

    stride = 1
    if incr:
        stride = int(incr) or 1

    first = int(left)
    last = int(right)
    if first < last:
        numbers = range(first, last + 1, stride)
    else:
        numbers = range(first, last - 1, -stride)

    template = "%0{}d".format(width)
    return (template % number for number in numbers)
198
+
199
+
200
def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
    """Return the slice of ``alphabet`` running from ``left`` to ``right``.

    Args:
        left: First character of the range.
        right: Last character of the range (included).
        incr: Optional step; an empty value or a step of zero means 1.

    Returns:
        The selected characters, in descending order when ``left`` does not
        come before ``right`` in ``alphabet``.
    """
    stride = (int(incr) or 1) if incr else 1
    first = alphabet.index(left)
    last = alphabet.index(right)

    if first < last:
        return alphabet[first : last + 1 : stride]

    # A stop of ``0 - 1`` would mean "last element"; substituting
    # ``-len(alphabet)`` makes the reversed slice run through index 0.
    stop = (last or -len(alphabet)) - 1
    return alphabet[first:stop:-stride]
209
+
210
+
211
if __name__ == "__main__":
    # Run the module's doctests and report failure through the exit status.
    import doctest
    import sys

    outcome = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
    if outcome[0]:
        sys.exit(1)
fish_speech/utils/file.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ from loguru import logger
7
+ from natsort import natsorted
8
+
9
+ AUDIO_EXTENSIONS = {
10
+ ".mp3",
11
+ ".wav",
12
+ ".flac",
13
+ ".ogg",
14
+ ".m4a",
15
+ ".wma",
16
+ ".aac",
17
+ ".aiff",
18
+ ".aif",
19
+ ".aifc",
20
+ }
21
+
22
+
23
def list_files(
    path: Union[Path, str],
    extensions: set[str] = None,
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): Extensions to filter (each one is matched
            as the glob pattern ``*<ext>``). Defaults to None, which lists
            every regular file.
        recursive (bool, optional): Whether to search recursively. Defaults to False.
        sort (bool, optional): Whether to sort the files. Defaults to True.

    Returns:
        list: List of files.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """

    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    # ``Path`` has no ``iglob`` method; ``rglob`` walks sub-directories while
    # ``glob`` stays in the directory itself, which is what ``recursive`` asks for.
    find = path.rglob if recursive else path.glob

    if extensions is None:
        # No filter requested: return every regular file instead of failing
        # while iterating over ``None``.
        files = [file for file in find("*") if file.is_file()]
    else:
        files = [file for ext in extensions for file in find(f"*{ext}")]

    if sort:
        files = natsorted(files)

    return files
53
+
54
+
55
+ def get_latest_checkpoint(path: Path | str) -> Path | None:
56
+ # Find the latest checkpoint
57
+ ckpt_dir = Path(path)
58
+
59
+ if ckpt_dir.exists() is False:
60
+ return None
61
+
62
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
63
+ if len(ckpts) == 0:
64
+ return None
65
+
66
+ return ckpts[-1]
67
+
68
+
69
+ def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
70
+ """
71
+ Load a Bert-VITS2 style filelist.
72
+ """
73
+
74
+ files = set()
75
+ results = []
76
+ count_duplicated, count_not_found = 0, 0
77
+
78
+ LANGUAGE_TO_LANGUAGES = {
79
+ "zh": ["zh", "en"],
80
+ "jp": ["jp", "en"],
81
+ "en": ["en"],
82
+ }
83
+
84
+ with open(path, "r", encoding="utf-8") as f:
85
+ for line in f.readlines():
86
+ splits = line.strip().split("|", maxsplit=3)
87
+ if len(splits) != 4:
88
+ logger.warning(f"Invalid line: {line}")
89
+ continue
90
+
91
+ filename, speaker, language, text = splits
92
+ file = Path(filename)
93
+ language = language.strip().lower()
94
+
95
+ if language == "ja":
96
+ language = "jp"
97
+
98
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
99
+ languages = LANGUAGE_TO_LANGUAGES[language]
100
+
101
+ if file in files:
102
+ logger.warning(f"Duplicated file: {file}")
103
+ count_duplicated += 1
104
+ continue
105
+
106
+ if not file.exists():
107
+ logger.warning(f"File not found: {file}")
108
+ count_not_found += 1
109
+ continue
110
+
111
+ results.append((file, speaker, languages, text))
112
+
113
+ if count_duplicated > 0:
114
+ logger.warning(f"Total duplicated files: {count_duplicated}")
115
+
116
+ if count_not_found > 0:
117
+ logger.warning(f"Total files not found: {count_not_found}")
118
+
119
+ return results
fish_speech/utils/instantiators.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import hydra
4
+ from omegaconf import DictConfig
5
+ from pytorch_lightning import Callback
6
+ from pytorch_lightning.loggers import Logger
7
+
8
+ from .logger import RankedLogger
9
+
10
+ log = RankedLogger(__name__, rank_zero_only=True)
11
+
12
+
13
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build every callback described in ``callbacks_cfg``.

    Entries that are not a ``DictConfig`` holding a ``_target_`` key are
    ignored.  A falsy config yields an empty list after a warning.

    Raises:
        TypeError: If a non-empty config is not a ``DictConfig``.
    """
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    built: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        if not (isinstance(entry, DictConfig) and "_target_" in entry):
            continue
        log.info(f"Instantiating callback <{entry._target_}>")
        built.append(hydra.utils.instantiate(entry))

    return built
31
+
32
+
33
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build every logger described in ``logger_cfg``.

    Entries that are not a ``DictConfig`` holding a ``_target_`` key are
    ignored.  A falsy config yields an empty list after a warning.

    Raises:
        TypeError: If a non-empty config is not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    built: List[Logger] = []
    for _name, entry in logger_cfg.items():
        if not (isinstance(entry, DictConfig) and "_target_" in entry):
            continue
        log.info(f"Instantiating logger <{entry._target_}>")
        built.append(hydra.utils.instantiate(entry))

    return built
fish_speech/utils/logger.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Mapping, Optional
3
+
4
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
5
+
6
+
7
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = True,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Wrap ``logging.getLogger(name)`` so that messages carry the process rank.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to emit records on the rank zero process only. Default is `True`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        super().__init__(logger=logging.getLogger(name), extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
    ) -> None:
        """Prefix ``msg`` with the current process rank and forward it to the wrapped logger.

        With ``self.rank_zero_only`` set, the record is emitted on rank 0 only.
        Otherwise it is emitted on every rank, or just on ``rank`` when one is given.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        :raises RuntimeError: If ``rank_zero_only.rank`` has not been set.
        """
        # NOTE(review): `rank` is the third positional parameter, so a
        # positional %-format argument would bind to it - confirm callers
        # only pass pre-formatted messages.
        if not self.isEnabledFor(level):
            return

        msg, kwargs = self.process(msg, kwargs)

        # Here `rank_zero_only` is the imported module-level object, not the
        # boolean stored on the instance.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError(
                "The `rank_zero_only.rank` needs to be set before use"
            )
        msg = rank_prefixed_message(msg, current_rank)

        if self.rank_zero_only:
            should_emit = current_rank == 0
        else:
            should_emit = rank is None or current_rank == rank

        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
fish_speech/utils/logging_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lightning.pytorch.utilities import rank_zero_only
2
+
3
+ from fish_speech.utils import logger as log
4
+
5
+
6
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Collect selected config sections and send them to every trainer logger.

    ``object_dict`` must provide the keys ``"cfg"``, ``"model"`` and
    ``"trainer"``.  Besides the config sections, the total, trainable and
    non-trainable parameter counts of the model are recorded.  Nothing is
    logged when the trainer has no logger.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    def count_params(keep) -> int:
        # Sum the element counts of the parameters accepted by `keep`.
        return sum(p.numel() for p in model.parameters() if keep(p))

    hparams = {
        "model": cfg["model"],
        # save number of model parameters
        "model/params/total": count_params(lambda p: True),
        "model/params/trainable": count_params(lambda p: p.requires_grad),
        "model/params/non_trainable": count_params(lambda p: not p.requires_grad),
        "data": cfg["data"],
        "trainer": cfg["trainer"],
    }

    # Optional sections: stored as None when missing from the config.
    for key in ("callbacks", "extras", "task_name", "tags", "ckpt_path", "seed"):
        hparams[key] = cfg.get(key)

    # send hparams to all loggers
    for trainer_logger in trainer.loggers:
        trainer_logger.log_hyperparams(hparams)
fish_speech/utils/rich_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Sequence
3
+
4
+ import rich
5
+ import rich.syntax
6
+ import rich.tree
7
+ from hydra.core.hydra_config import HydraConfig
8
+ from lightning.pytorch.utilities import rank_zero_only
9
+ from omegaconf import DictConfig, OmegaConf, open_dict
10
+ from rich.prompt import Prompt
11
+
12
+ from fish_speech.utils import logger as log
13
+
14
+
15
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Render ``cfg`` as a Rich tree, one branch per top-level field.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Fields printed first, in this order; the remaining fields follow.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
        save_to_file (bool, optional): Whether to also write the tree to ``config_tree.log`` under ``cfg.paths.output_dir``.
    """  # noqa: E501

    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # Requested fields first; warn about the ones the config does not have.
    ordered_fields = []
    for name in print_order:
        if name in cfg:
            ordered_fields.append(name)
        else:
            log.warning(
                f"Field '{name}' not found in config. "
                + f"Skipping '{name}' config printing..."
            )

    # Then every remaining field, in config order.
    for name in cfg:
        if name not in ordered_fields:
            ordered_fields.append(name)

    for name in ordered_fields:
        branch = tree.add(name, style=style, guide_style=style)
        section = cfg[name]
        rendered = (
            OmegaConf.to_yaml(section, resolve=resolve)
            if isinstance(section, DictConfig)
            else str(section)
        )
        branch.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(tree)

    if save_to_file:
        log_path = Path(cfg.paths.output_dir, "config_tree.log")
        with open(log_path, "w") as file:
            rich.print(tree, file=file)
75
+
76
+
77
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    Args:
        cfg (DictConfig): Configuration; ``cfg.tags`` is filled in when empty.
        save_to_file (bool, optional): Whether to write the tags to ``tags.log`` under ``cfg.paths.output_dir``.

    Raises:
        ValueError: If tags are missing and the Hydra job config has an ``id`` key (multirun).
    """  # noqa: E501

    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        answer = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip before filtering so whitespace-only entries ("a, ,b") are
        # dropped instead of surviving as empty tags.
        stripped = (part.strip() for part in answer.split(","))
        tags = [tag for tag in stripped if tag]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
fish_speech/utils/utils.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from importlib.util import find_spec
3
+ from typing import Callable
4
+
5
+ from omegaconf import DictConfig
6
+
7
+ from .logger import RankedLogger
8
+ from .rich_utils import enforce_tags, print_config_tree
9
+
10
+ log = RankedLogger(__name__, rank_zero_only=True)
11
+
12
+
13
def extras(cfg: DictConfig) -> None:
    """Run the optional start-up utilities enabled under ``cfg.extras``.

    Utilities:
    - ``ignore_warnings``: silence python warnings
    - ``enforce_tags``: make sure tags are set (prompting if needed)
    - ``print_config``: pretty print the config tree with Rich
    """

    # Nothing to do without an `extras` section.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    def enabled(option: str) -> bool:
        # Look the flag up at the moment it is needed.
        return bool(cfg.extras.get(option))

    if enabled("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    if enabled("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    if enabled("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
41
+
42
+
43
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls what happens around the task function.

    The wrapped function:
    - logs the exception (with traceback) when the task fails, then re-raises it
    - always logs the run's output directory
    - always finishes an active wandb run, when wandb is installed

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:

        ...

        return metric_dict, object_dict
    ```
    """  # noqa: E501

    def wrap(cfg: DictConfig):
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception as ex:
            # Record the traceback through the logger.
            log.exception("")

            # Some hyperparameter combinations might be invalid or cause
            # out-of-memory errors; with search plugins like Optuna you may
            # want to stop re-raising here to avoid a multirun failure.
            raise ex
        else:
            return metric_dict, object_dict
        finally:
            # Runs after both success and failure.
            log.info(f"Output dir: {cfg.paths.run_dir}")

            # Only touch wandb when it is installed; close an active run so
            # that a multirun can continue.
            if find_spec("wandb"):
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

    return wrap
95
+
96
+
97
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
    """Look up ``metric_name`` in ``metric_dict`` and return its ``.item()`` value.

    Returns ``None`` (after an info log) when ``metric_name`` is falsy.

    Raises:
        Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """

    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        metric_value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
        return metric_value

    raise Exception(
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
packages.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ git
2
+ curl
3
+ build-essential
4
+ ffmpeg
5
+ libsm6
6
+ libxext6
7
+ libjpeg-dev
8
+ zlib1g-dev
9
+ protobuf-compiler
10
+ cmake
pyrightconfig.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "exclude": [
3
+ "data",
4
+ "filelists"
5
+ ]
6
+ }
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ transformers>=4.35.2
4
+ datasets>=2.14.5
5
+ lightning>=2.1.0
6
+ hydra-core>=1.3.2
7
+ tensorboard>=2.14.1
8
+ natsort>=8.4.0
9
+ einops>=0.7.0
10
+ librosa>=0.10.1
11
+ rich>=13.5.3
12
+ gradio>=4.0.0
13
+ wandb>=0.15.11
14
+ grpcio>=1.58.0
15
+ kui>=1.6.0
16
+ zibai-server>=0.9.0
17
+ loguru>=0.6.0
18
+ loralib>=0.1.2
19
+ natsort>=8.4.0
20
+ pyrootutils>=1.0.4
21
+ vector_quantize_pytorch>=1.14.7
22
+ samplerate>=0.2.1
23
+ resampy>=0.4.3
24
+ spaces>=0.26.1
setup.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
set -e

mkdir -p checkpoints

# download_checkpoint NAME URL
# Fetch URL into checkpoints/NAME unless that file already exists.
# The transfer goes to a temporary ".part" file first, so a failed or
# interrupted download is never mistaken for a finished checkpoint later.
download_checkpoint() {
    local name="$1"
    local url="$2"
    local target="checkpoints/${name}"

    if [ -e "${target}" ]; then
        echo "${target} already exists"
        return 0
    fi

    if [ -z "${url}" ]; then
        echo "No download URL configured for ${name}" >&2
        return 1
    fi

    echo "Downloading ${name}"
    if ! wget -O "${target}.part" "${url}"; then
        rm -f "${target}.part"
        return 1
    fi
    mv "${target}.part" "${target}"
}

# The URLs come from the environment (CKPT_SEMANTIC / CKPT_VQGAN).
download_checkpoint text2semantic-medium-v1-2k.pth "${CKPT_SEMANTIC}"
download_checkpoint vq-gan-group-fsq-2x1024.pth "${CKPT_VQGAN}"
tools/extract_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import torch
3
+ from loguru import logger
4
+
5
+
6
@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    # Extract the "state_dict" entry of a checkpoint and save it on its own.
    # Refuses to run when source and destination are the same path.
    same_location = model_path == output_path
    if same_location:
        logger.error("Model path and output path are the same")
        return

    logger.info(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path, map_location="cpu")
    weights = checkpoint["state_dict"]
    torch.save(weights, output_path)
    logger.info(f"Model saved to {output_path}")
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
tools/llama/build_dataset.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import os
3
+ import re
4
+ from collections import defaultdict
5
+ from functools import partial
6
+ from multiprocessing import Pool
7
+ from pathlib import Path
8
+
9
+ import click
10
+ import numpy as np
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+
14
+ from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
+ from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
+ from fish_speech.utils.file import load_filelist
17
+
18
+ # To avoid CPU overload
19
+ os.environ["MKL_NUM_THREADS"] = "1"
20
+ os.environ["OMP_NUM_THREADS"] = "1"
21
+
22
+
23
def task_generator_folder(root: Path, text_extension: str):
    """Yield one ``(folder, entries, "folder")`` task per directory under ``root``.

    Every ``*.npy`` file found recursively is paired with the text read from
    the sibling file(s) that carry ``text_extension`` (a single suffix or an
    iterable of suffixes).  Files whose text cannot be read are logged and
    left out.  ``entries`` is a list of ``(npy_path, texts)`` tuples.
    """
    npy_files = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    single_suffix = isinstance(text_extension, str)
    by_folder = defaultdict(list)

    for npy_file in tqdm(npy_files, desc=f"Grouping {root}"):
        folder = str(npy_file.parent)

        try:
            suffixes = [text_extension] if single_suffix else text_extension
            texts = [npy_file.with_suffix(suffix).read_text() for suffix in suffixes]
        except Exception as e:
            logger.error(f"Failed to read text {npy_file}: {e}")
            continue

        by_folder[folder].append((npy_file, texts))

    logger.info(
        f"Found {len(by_folder)} groups in {root}, {list(by_folder.keys())[:5]}..."
    )

    yield from ((folder, entries, "folder") for folder, entries in by_folder.items())
47
+
48
+
49
def task_generator_filelist(filelist):
    """Yield one ``(speaker, entries, "filelist")`` task per speaker in ``filelist``.

    ``entries`` is a list of ``(path, [text])`` tuples built from the rows
    returned by ``load_filelist``; the language field of each row is ignored.
    """
    by_speaker = defaultdict(list)
    for row in load_filelist(filelist):
        filename, speaker, _languages, text = row
        by_speaker[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")

    yield from ((speaker, entries, "filelist") for speaker, entries in by_speaker.items())
57
+
58
+
59
def run_task(task):
    """Turn one ``(name, subset, source)`` task into a packed ``TextData`` message.

    ``subset`` holds ``(file, texts)`` pairs.  For each pair the ``.npy``
    sibling of ``file`` is loaded and combined with the cleaned texts into a
    ``Sentence``.  Pairs whose ``.npy`` file is missing or cannot be loaded
    are logged and skipped.  Returns the result of ``pack_pb_stream``.
    """
    name, subset, source = task

    def clean(text):
        # Simple cleaning: replace { xxx } and < xxx > with a space, then
        # collapse every whitespace run into one space.
        for pattern in (r"\{.*?\}", r"<.*?>", r"\s+"):
            text = re.sub(pattern, " ", text)
        return text

    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = [clean(text) for text in texts]

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        # One Semantics message per item of the loaded data.
        codes = [Semantics(values=row) for row in semantics]
        sentences.append(Sentence(texts=new_texts, semantics=codes))

    packed_input = TextData(source=source, name=name, sentences=sentences)
    return pack_pb_stream(packed_input)
105
+
106
+
107
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    # Build the tasks from every --input (a folder is scanned, anything else
    # is read as a filelist), pack them in a worker pool and append the
    # packed bytes to numbered "<index>.protos" shard files under --output.
    # A new shard is started once the current one exceeds --shard-size MiB.
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    shard_limit = shard_size * 1024 * 1024
    dataset_fp = None
    tar_idx = 0  # number of shards completed so far
    written_size = 0

    try:
        with Pool(num_workers) as p:
            for result in tqdm(p.imap_unordered(run_task, generator_fn)):
                if dataset_fp is None:
                    dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

                dataset_fp.write(result)
                written_size += len(result)

                if written_size > shard_limit:
                    dataset_fp.close()
                    dataset_fp = None
                    written_size = 0
                    # Count the shard before reporting it, so the message
                    # reflects the number of shards actually finished.
                    tar_idx += 1
                    logger.info(f"Finished writing {tar_idx} shards to {output}")
    finally:
        # Close a partially filled shard even when a worker raised; it only
        # counts as a shard if a file is really open.
        if dataset_fp is not None:
            dataset_fp.close()
            tar_idx += 1

    logger.info(f"Finished writing {tar_idx} shards to {output}")
162
+
163
+
164
+ if __name__ == "__main__":
165
+ main()