diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..701e8751212e3f017443e43e871f6b809d631426
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+checkpoints
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c0dc91bab552e19c9df5603c7d2468ce0e3ce0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,317 @@
+import subprocess as sp
+sp.check_call("setup.sh", shell=True)
+
+import html
+import os
+from argparse import ArgumentParser
+from io import BytesIO
+from pathlib import Path
+
+import gradio as gr
+import librosa
+import spaces
+import torch
+from loguru import logger
+from torchaudio import functional as AF
+from transformers import AutoTokenizer
+
+from tools.llama.generate import generate_long
+from tools.llama.generate import load_model as load_llama_model
+from tools.vqgan.inference import load_model as load_vqgan_model
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+HEADER_MD = """# Fish Speech
+
+A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
+由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
+
+You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co./fishaudio/fish-speech-1).
+你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co./fishaudio/fish-speech-1) 找到模型.
+
+Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.
+相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.
+
+We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
+我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
+"""
+
+TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(error)}
+
+ """
+
+
+@spaces.GPU
+def inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_k,
+ top_p,
+ repetition_penalty,
+ temperature,
+ speaker=None,
+):
+ if len(reference_text) > 100:
+ return None, "Ref text is too long, please keep it under 100 characters."
+
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+ return None, "Text is too long, please keep it under 1000 characters."
+
+ # Parse reference audio aka prompt
+ if enable_reference_audio and reference_audio is not None:
+ # reference_audio_sr, reference_audio_content = reference_audio
+ reference_audio_content, _ = librosa.load(
+ reference_audio, sr=vqgan_model.sampling_rate, mono=True
+ )
+ audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
+ None, None, :
+ ]
+
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor(
+ [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
+ )
+ prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+
+ # LLAMA Inference
+ result = generate_long(
+ model=llama_model,
+ tokenizer=llama_tokenizer,
+ device=vqgan_model.device,
+ decode_one_token=decode_one_token,
+ max_new_tokens=max_new_tokens,
+ text=text,
+ top_k=int(top_k) if top_k > 0 else None,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=args.compile,
+ iterative_prompt=chunk_length > 0,
+ chunk_length=chunk_length,
+ max_length=args.max_length,
+ speaker=speaker if speaker else None,
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
+ prompt_text=reference_text if enable_reference_audio else None,
+ )
+
+ codes = next(result)
+
+ # VQGAN Inference
+ feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
+ fake_audios = vqgan_model.decode(
+ indices=codes[None], feature_lengths=feature_lengths, return_audios=True
+ )[0, 0]
+
+ fake_audios = fake_audios.float().cpu().numpy()
+
+ return (vqgan_model.sampling_rate, fake_audios), None
+
+
+def build_app():
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label="Input Text / 输入文本",
+ placeholder=TEXTBOX_PLACEHOLDER,
+ lines=15,
+ )
+
+ with gr.Row():
+ with gr.Tab(label="Advanced Config / 高级参数"):
+ chunk_length = gr.Slider(
+ label="Iterative Prompt Length, 0 means off / 迭代提示长度,0 表示关闭",
+ minimum=0,
+ maximum=100,
+ value=30,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label="Maximum tokens per batch, 0 means no limit / 每批最大令牌数,0 表示无限制",
+ minimum=128,
+ maximum=512,
+ value=512, # 0 means no limit
+ step=8,
+ )
+
+ top_k = gr.Slider(
+ label="Top-K", minimum=0, maximum=5, value=0, step=1
+ )
+
+ top_p = gr.Slider(
+ label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
+ )
+
+ repetition_penalty = gr.Slider(
+ label="Repetition Penalty",
+ minimum=0,
+ maximum=2,
+ value=1.5,
+ step=0.01,
+ )
+
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0,
+ maximum=2,
+ value=0.7,
+ step=0.01,
+ )
+
+ # speaker = gr.Textbox(
+ # label="Speaker / 说话人",
+ # placeholder="Type name of the speaker / 输入说话人的名称",
+ # lines=1,
+ # )
+
+ with gr.Tab(label="Reference Audio / 参考音频"):
+ gr.Markdown(
+ "5 to 10 seconds of reference audio, useful for specifying speaker. \n5 到 10 秒的参考音频,适用于指定音色。"
+ )
+
+ enable_reference_audio = gr.Checkbox(
+ label="Enable Reference Audio / 启用参考音频",
+ )
+ reference_audio = gr.Audio(
+ label="Reference Audio / 参考音频",
+ value="docs/assets/audios/0_input.wav",
+ type="filepath",
+ )
+ reference_text = gr.Textbox(
+ label="Reference Text / 参考文本",
+ placeholder="参考文本",
+ lines=1,
+ value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ )
+
+ with gr.Column(scale=3):
+ with gr.Row():
+ error = gr.HTML(label="Error Message / 错误信息")
+ with gr.Row():
+ audio = gr.Audio(label="Generated Audio / 音频", type="numpy")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001F3A7 Generate / 合成", variant="primary"
+ )
+
+ # # Submit
+ generate.click(
+ inference,
+ [
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_k,
+ top_p,
+ repetition_penalty,
+ temperature,
+ # speaker,
+ ],
+ [audio, error],
+ )
+
+ return app
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/text2semantic-medium-v1-2k.pth",
+ )
+ parser.add_argument(
+ "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
+ )
+ parser.add_argument(
+ "--vqgan-checkpoint-path",
+ type=Path,
+ default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+ )
+ parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
+ parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--max-length", type=int, default=2048)
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=1024)
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ logger.info("Loading Llama model...")
+ llama_model, decode_one_token = load_llama_model(
+ config_name=args.llama_config_name,
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ max_length=args.max_length,
+ compile=args.compile,
+ )
+ llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ vqgan_model = load_vqgan_model(
+ config_name=args.vqgan_config_name,
+ checkpoint_path=args.vqgan_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("VQ-GAN model loaded, warming up...")
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ inference(
+ text="Hello, world!",
+ enable_reference_audio=False,
+ reference_audio=None,
+ reference_text="",
+ max_new_tokens=0,
+ chunk_length=0,
+ top_k=0, # 0 means no limit
+ top_p=0.7,
+ repetition_penalty=1.5,
+ temperature=0.7,
+ speaker=None,
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ app = build_app()
+ app.launch(show_api=False)
diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3
--- /dev/null
+++ b/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]
diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a
--- /dev/null
+++ b/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+ _group_tensors_by_device_and_dtype,
+ _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+ parameters: Union[Tensor, list[Tensor]],
+ norm_type: float = 2.0,
+) -> float:
+ """
+ Returns the norm of the gradients of the given parameters.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ norm_type (float): type of the used p-norm.
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """ # noqa: E501
+
+ if isinstance(parameters, Tensor):
+ parameters = [parameters]
+
+ grads = [p.grad for p in parameters if p.grad is not None]
+ if len(grads) == 0:
+ return None
+
+ first_device = grads[0].device
+ grouped_grads: dict[
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
+ ] = _group_tensors_by_device_and_dtype(
+ [[g.detach() for g in grads]]
+ ) # type: ignore[assignment]
+
+ norms = []
+ for (device, _), ([grads], _) in grouped_grads.items():
+ if _has_foreach_support(grads, device=device):
+ norms.extend(torch._foreach_norm(grads, norm_type))
+ else:
+ norms.extend([torch.norm(g, norm_type) for g in grads])
+
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+ """
+ Callback that computes the gradient norm of the model parameters.
+ """
+
+ def __init__(
+ self,
+ norm_type: float = 2.0,
+ logging_interval: str = "step",
+ sub_module: Optional[Union[str, list[str]]] = None,
+ ) -> None:
+ """
+ Args:
+ norm_type (float): type of the used p-norm.
+ logging_interval (str): "step" or "epoch".
+ """
+ super().__init__()
+
+ self.norm_type = norm_type
+ self.logging_interval = logging_interval
+ self.sub_module = sub_module
+
+ def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+ """
+ Computes the gradient norm of the model parameters and logs it to the logger.
+
+ Args:
+ trainer (Trainer): The trainer object
+ model (LightningModule): The current lightningModule
+ """
+
+ lightning_model = model
+
+ if self.sub_module is None:
+ return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+ sub_modules = self.sub_module
+ if isinstance(sub_modules, str):
+ sub_modules = [sub_modules]
+
+ for sub_module in sub_modules:
+ self.log_sub_module_grad_norm(
+ lightning_model, getattr(model, sub_module), f"/{sub_module}"
+ )
+
+ def log_sub_module_grad_norm(
+ self, lightning_model: LightningModule, model: nn.Module, path: str
+ ) -> None:
+ grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+ if grad_norm_val is None:
+ return
+
+ on_step = self.logging_interval == "step"
+ lightning_model.log(
+ f"train{path}/grad_norm",
+ grad_norm_val,
+ on_step=on_step,
+ on_epoch=not on_step,
+ )
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6c416af60c813a4dbef322e4611db1e9f46a86c4
--- /dev/null
+++ b/fish_speech/configs/base.yaml
@@ -0,0 +1,86 @@
+# Base configuration for training a model
+paths:
+ run_dir: results/${project}
+ ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+ run:
+ dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.run_dir}
+ accelerator: gpu
+ num_nodes: 1
+ devices: auto
+ strategy:
+ _target_: lightning.pytorch.strategies.DDPStrategy
+
+ precision: bf16-mixed
+
+ # disable validation by epoch end
+ check_val_every_n_epoch: null
+ val_check_interval: 5000
+ max_steps: 100_000
+
+ # Use torch.backends.cudnn.benchmark to speed up training
+ benchmark: true
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: ${paths.ckpt_dir}
+ filename: "step_{step:09d}"
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 5 # save 5 latest checkpoints
+ monitor: step # use step to monitor checkpoints
+ mode: max # save the latest checkpoint with the highest global_step
+ every_n_epochs: null # don't save checkpoints by epoch end
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
+ auto_insert_metric_name: false
+
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+ learning_rate_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
+
+ grad_norm_monitor:
+ _target_: fish_speech.callbacks.GradNormMonitor
+ norm_type: 2
+ logging_interval: step
+
+# Logger
+logger:
+ tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.run_dir}/tensorboard/"
+ name: null
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+
+ # wandb:
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # # name: "" # name of the run (normally generated by wandb)
+ # save_dir: "${paths.run_dir}"
+ # offline: False
+ # id: null # pass correct id to resume experiment!
+ # anonymous: null # enable anonymous logging
+ # project: "fish-speech"
+ # log_model: False # upload lightning ckpts
+ # prefix: "" # a string to put at the beginning of metric keys
+ # # entity: "" # set to name of your wandb team
+ # group: ""
+ # tags: ["vq", "hq", "finetune"]
+ # job_type: ""
+
+# Loop
+train: true
+test: false
diff --git a/fish_speech/configs/model/dual_ar_2_codebook_large.yaml b/fish_speech/configs/model/dual_ar_2_codebook_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4504d8bba62009710ec7761b4dcc87f3172be01
--- /dev/null
+++ b/fish_speech/configs/model/dual_ar_2_codebook_large.yaml
@@ -0,0 +1,9 @@
+defaults:
+ - dual_ar_2_codebook_small
+ - _self_
+
+config:
+ n_layer: 30
+ n_fast_layer: 6
+ n_head: 24
+ dim: 1536
diff --git a/fish_speech/configs/model/dual_ar_2_codebook_medium.yaml b/fish_speech/configs/model/dual_ar_2_codebook_medium.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0ad6c2a10ac82452e33685d08c188d6c1e735678
--- /dev/null
+++ b/fish_speech/configs/model/dual_ar_2_codebook_medium.yaml
@@ -0,0 +1,9 @@
+defaults:
+ - dual_ar_2_codebook_small
+ - _self_
+
+config:
+ n_layer: 24
+ n_fast_layer: 6
+ n_head: 16
+ dim: 1024
diff --git a/fish_speech/configs/model/dual_ar_2_codebook_small.yaml b/fish_speech/configs/model/dual_ar_2_codebook_small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a56974083d1d247629b2d83cb81a33838c70c22d
--- /dev/null
+++ b/fish_speech/configs/model/dual_ar_2_codebook_small.yaml
@@ -0,0 +1,13 @@
+_target_: fish_speech.models.text2semantic.llama.DualARTransformer
+config:
+ _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
+ max_seq_len: ${max_length}
+ vocab_size: 264 # pad 262 to 8x
+ n_layer: 12
+ n_fast_layer: 4
+ n_head: 12
+ dim: 768
+ rope_base: 10000
+ norm_eps: 1e-5
+ num_codebooks: 2 # input/output codebook size
+ codebook_size: 1032 # codebook size 1024 + 2 special tokens
diff --git a/fish_speech/configs/model/naive_2_codebook_small.yaml b/fish_speech/configs/model/naive_2_codebook_small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..16d1737c90c9c4f88587202ea7e7bb5bd741b30f
--- /dev/null
+++ b/fish_speech/configs/model/naive_2_codebook_small.yaml
@@ -0,0 +1,12 @@
+_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
+config:
+ _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
+ max_seq_len: ${max_length}
+ vocab_size: 36408
+ n_layer: 12
+ n_head: 12
+ dim: 768
+ rope_base: 10000
+ norm_eps: 1e-5
+ num_codebooks: 2 # input/output codebook size
+ codebook_size: 1032 # codebook size 1024 + 2 special tokens
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7165839864ffbfbff3b21901ae1dfdf707ef66c3
--- /dev/null
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -0,0 +1,79 @@
+defaults:
+ - base
+ - model@model.model: dual_ar_2_codebook_small
+ - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 2048
+ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: 'norm'
+ max_steps: 1000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 100
+
+# Dataset Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: fishaudio/fish-speech-1
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+ num_codebooks: ${model.model.config.num_codebooks}
+ use_speaker: false
+
+val_dataset:
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+ num_codebooks: ${model.model.config.num_codebooks}
+ use_speaker: false
+
+data:
+ _target_: fish_speech.datasets.text.TextDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.TextToSemantic
+ model: {}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-5
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 100
+ num_training_steps: ${trainer.max_steps}
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: 100
diff --git a/fish_speech/configs/text2semantic_finetune_lora.yaml b/fish_speech/configs/text2semantic_finetune_lora.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..36da21ceac8fd9e0e684b5863d8200156b53e0de
--- /dev/null
+++ b/fish_speech/configs/text2semantic_finetune_lora.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - text2semantic_finetune
+ - _self_
+
+project: text2semantic_finetune_dual_ar_lora
+
+# Model Configuration
+model:
+ save_lora_only: true
+ lora_config:
+ _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
+ r: 8
+ lora_alpha: 16
diff --git a/fish_speech/configs/text2semantic_pretrain.yaml b/fish_speech/configs/text2semantic_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..98983f4c2417028f980d24402aaecae39049d295
--- /dev/null
+++ b/fish_speech/configs/text2semantic_pretrain.yaml
@@ -0,0 +1,74 @@
+defaults:
+ - base
+ - model@model.model: dual_ar_2_codebook_small
+ - _self_
+
+project: text2semantic_pretrain_dual_ar_debug
+max_length: 2048
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: 'norm'
+ max_steps: 1_000_000
+ precision: bf16-true
+ limit_val_batches: 10
+
+# Dataset Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: fishaudio/fish-speech-1
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
+ proto_files:
+ - data/protos/train
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+ num_codebooks: ${model.model.config.num_codebooks}
+ use_speaker: false
+ interactive_prob: 0.5
+
+val_dataset:
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
+ proto_files:
+ - data/protos/test
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+ num_codebooks: ${model.model.config.num_codebooks}
+ use_speaker: false
+ interactive_prob: 0.5
+
+data:
+ _target_: fish_speech.datasets.text.TextDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.TextToSemantic
+ model: {}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 3e-4
+ weight_decay: 0.01
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 2000
+ num_training_steps: ${trainer.max_steps}
+ final_lr_ratio: 0.1
diff --git a/fish_speech/configs/text2semantic_sft.yaml b/fish_speech/configs/text2semantic_sft.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a3cf2675fc1d5f8feb3355b7e9c914a8f6254f0
--- /dev/null
+++ b/fish_speech/configs/text2semantic_sft.yaml
@@ -0,0 +1,87 @@
+defaults:
+ - base
+ - model@model.model: dual_ar_8_codebook_small
+ - _self_
+
+project: text2semantic_sft_medium_dual_ar
+max_length: 4096
+ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: 'norm'
+ max_steps: 10_000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 500
+
+# Dataset Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: fishaudio/speech-lm-v1
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
+ use_data_server: false
+ proto_files:
+ - data/protos/sft/train_Genshin.protos
+ - data/protos/sft/sft.protos
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+ num_codebooks: ${model.model.config.num_codebooks}
+ use_speaker: false
+ phones_prob: 0.5
+ interactive_prob: 0.5
+
+val_dataset:
+ _target_: fish_speech.datasets.text.AutoAugTextDataset
+ use_data_server: false
+ proto_files:
+ - data/protos/sft/val_Genshin.protos
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+ num_codebooks: ${model.model.config.num_codebooks}
+ use_speaker: false
+ phones_prob: 0.5
+ interactive_prob: 0.5
+
+data:
+ _target_: fish_speech.datasets.text.TextDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.TextToSemantic
+ model: {}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 4e-5
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 100
+ num_training_steps: ${trainer.max_steps}
+ final_lr_ratio: 0
+
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: 1000
+ save_top_k: 10
diff --git a/fish_speech/configs/vqgan_finetune.yaml b/fish_speech/configs/vqgan_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..138ab975b718e6cd93c92cd91c5b0bd3c6ad206e
--- /dev/null
+++ b/fish_speech/configs/vqgan_finetune.yaml
@@ -0,0 +1,135 @@
+defaults:
+ - base
+ - _self_
+
+project: vq-gan-finetune
+ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+ accelerator: gpu
+ devices: auto
+ precision: bf16-mixed
+ max_steps: 100_000
+ val_check_interval: 5000
+ strategy: ddp_find_unused_parameters_true
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+freeze_encoder: true
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.vqgan.VQGANDataset
+ filelist: data/filelist.train.txt
+ sample_rate: ${sample_rate}
+ hop_length: ${hop_length}
+ slice_frames: 512
+
+val_dataset:
+ _target_: fish_speech.datasets.vqgan.VQGANDataset
+ filelist: data/filelist.val.txt
+ sample_rate: ${sample_rate}
+ hop_length: ${hop_length}
+
+data:
+ _target_: fish_speech.datasets.vqgan.VQGANDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 16
+ val_batch_size: 16
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.vqgan.VQGAN
+
+ sampling_rate: ${sample_rate}
+ weight_adv: 0.2
+ weight_vq: 1.0
+ weight_mel: 1.0
+ freeze_encoder: false
+
+ encoder:
+ _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+ input_channels: ${num_mels}
+ residual_channels: 768
+ residual_layers: 20
+ dilation_cycle: 4
+
+ quantizer:
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+ input_dim: 768
+ n_codebooks: 1
+ n_groups: 2
+ levels: [8, 5, 5, 5]
+
+ decoder:
+ _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+ output_channels: ${num_mels}
+ residual_channels: 768
+ residual_layers: 20
+ dilation_cycle: 4
+ condition_channels: 768
+
+ discriminator:
+ _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
+
+ vocoder:
+ _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
+ ckpt_path: null # You may download the pretrained vocoder and set the path here
+
+ encode_mel_transform:
+ _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+ sample_rate: ${sample_rate}
+ n_fft: ${n_fft}
+ hop_length: ${hop_length}
+ win_length: ${win_length}
+ n_mels: ${num_mels}
+ f_min: 0.0
+ f_max: 8000.0
+
+ gt_mel_transform:
+ _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+ sample_rate: ${sample_rate}
+ n_fft: ${n_fft}
+ hop_length: ${hop_length}
+ win_length: ${win_length}
+ n_mels: ${num_mels}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 4e-5
+ betas: [0.8, 0.99]
+ eps: 1e-5
+ weight_decay: 0.01
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 100
+ num_training_steps: ${trainer.max_steps}
+ final_lr_ratio: 0
+
+callbacks:
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 1
+
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
+
+ grad_norm_monitor:
+ sub_module:
+ - encoder
+ - decoder
+ - quantizer
+ - discriminator
diff --git a/fish_speech/configs/vqgan_pretrain.yaml b/fish_speech/configs/vqgan_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..271b97dd0ddb69e343557751d74a529005e35667
--- /dev/null
+++ b/fish_speech/configs/vqgan_pretrain.yaml
@@ -0,0 +1,139 @@
+defaults:
+ - base
+ - _self_
+
+project: vq-gan-pretrain
+
+# Lightning Trainer
+trainer:
+ accelerator: gpu
+ devices: auto
+ precision: bf16-mixed
+ max_steps: 1_000_000
+ val_check_interval: 5000
+ strategy: ddp_find_unused_parameters_true
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+
+# Dataset Configuration
+train_dataset:
+ _target_: torch.utils.data.ConcatDataset
+ datasets:
+ - _target_: fish_speech.datasets.vqgan.VQGANDataset
+ filelist: data/gigaspeech/vq_train_filelist.txt
+ sample_rate: ${sample_rate}
+ hop_length: ${hop_length}
+ slice_frames: 512
+ - _target_: fish_speech.datasets.vqgan.VQGANDataset
+ filelist: data/sft/vq_train_filelist.txt
+ sample_rate: ${sample_rate}
+ hop_length: ${hop_length}
+ slice_frames: 512
+
+val_dataset:
+ _target_: fish_speech.datasets.vqgan.VQGANDataset
+ filelist: data/sft/vq_val_filelist.txt
+ sample_rate: ${sample_rate}
+ hop_length: ${hop_length}
+
+data:
+ _target_: fish_speech.datasets.vqgan.VQGANDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 32
+ val_batch_size: 32
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.vqgan.VQGAN
+
+ sampling_rate: ${sample_rate}
+ weight_adv: 0.2
+ weight_vq: 1.0
+ weight_mel: 1.0
+ freeze_encoder: false
+
+ encoder:
+ _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+ input_channels: ${num_mels}
+ residual_channels: 768
+ residual_layers: 20
+ dilation_cycle: 4
+
+ quantizer:
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+ input_dim: 768
+ n_codebooks: 1
+ n_groups: 2
+ levels: [8, 5, 5, 5]
+
+ decoder:
+ _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+ output_channels: ${num_mels}
+ residual_channels: 768
+ residual_layers: 20
+ dilation_cycle: 4
+ condition_channels: 768
+
+ discriminator:
+ _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
+
+ vocoder:
+ _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
+ ckpt_path: null # You may download the pretrained vocoder and set the path here
+
+ encode_mel_transform:
+ _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+ sample_rate: ${sample_rate}
+ n_fft: ${n_fft}
+ hop_length: ${hop_length}
+ win_length: ${win_length}
+ n_mels: ${num_mels}
+ f_min: 0.0
+ f_max: 8000.0
+
+ gt_mel_transform:
+ _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+ sample_rate: ${sample_rate}
+ n_fft: ${n_fft}
+ hop_length: ${hop_length}
+ win_length: ${win_length}
+ n_mels: ${num_mels}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-4
+ betas: [0.8, 0.99]
+ eps: 1e-5
+ weight_decay: 0.01
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 100
+ num_training_steps: ${trainer.max_steps}
+ final_lr_ratio: 0
+
+callbacks:
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 1
+
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
+
+ grad_norm_monitor:
+ sub_module:
+ - encoder
+ - decoder
+ - quantizer
+ - discriminator
diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto
new file mode 100644
index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379
--- /dev/null
+++ b/fish_speech/datasets/protos/text-data.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package text_data;
+
+message Semantics {
+ repeated uint32 values = 1;
+}
+
+message Sentence {
+ repeated string texts = 1;
+ repeated Semantics semantics = 3;
+}
+
+message TextData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence sentences = 4;
+}
+
+message SampledData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence samples = 3;
+}
diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals["_SEMANTICS"]._serialized_start = 30
+ _globals["_SEMANTICS"]._serialized_end = 57
+ _globals["_SENTENCE"]._serialized_start = 59
+ _globals["_SENTENCE"]._serialized_end = 125
+ _globals["_TEXTDATA"]._serialized_start = 127
+ _globals["_TEXTDATA"]._serialized_end = 207
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)
diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+ while True:
+ buf = f.read(4)
+ if len(buf) == 0:
+ break
+ size = struct.unpack("I", buf)[0]
+ buf = f.read(size)
+ text_data = TextData()
+ text_data.ParseFromString(buf)
+ yield text_data
+
+
+def write_pb_stream(f, text_data):
+ buf = text_data.SerializeToString()
+ f.write(struct.pack("I", len(buf)))
+ f.write(buf)
+
+
+def pack_pb_stream(text_data):
+ buf = text_data.SerializeToString()
+ return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+ while True:
+ head = f.read(4)
+ if len(head) == 0:
+ break
+ size = struct.unpack("I", head)[0]
+ buf = f.read(size)
+ yield head + buf
diff --git a/fish_speech/datasets/text.py b/fish_speech/datasets/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b3bfaa5582930d4f2c3950aa9e03ea4219017b
--- /dev/null
+++ b/fish_speech/datasets/text.py
@@ -0,0 +1,661 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import grpc
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from transformers import AutoTokenizer
+
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+CODEBOOK_PAD_TOKEN_ID = 0
+CODEBOOK_EOS_TOKEN_ID = 1
+
+
+def split_by_rank_worker(files):
+ # We need to know the total number of devices
+ # to split the data properly
+
+ total_devices = 1
+ if is_initialized():
+ total_devices = get_world_size()
+
+ worker_info = get_worker_info()
+ if worker_info is not None:
+ total_devices *= worker_info.num_workers
+
+ if len(files) < total_devices:
+ # Repeat the files N times to match the number of devices
+ files = files * (total_devices // len(files) + 1)
+
+ # DDP
+ if is_initialized():
+ files = files[get_rank() :: get_world_size()]
+
+ # Split by worker
+ if worker_info is not None:
+ files = files[worker_info.id :: worker_info.num_workers]
+
+ return files
+
+
+class StreamTextDataset(IterableDataset):
+ def __init__(
+ self,
+ files: Optional[Union[list[str], str]] = None,
+ prefix: Optional[str] = None,
+ seed: int = 42,
+ parquet_batch_size: int = 10000,
+ repo: str = "uonlp/CulturaX",
+ max_length: int = 1024,
+ tokenizer: AutoTokenizer = None,
+ ):
+ super().__init__()
+
+ self.seed = seed
+ self.parquet_batch_size = parquet_batch_size
+ self.repo = repo
+ self.max_length = max_length
+ self.tokenizer = tokenizer
+
+ if files is None and prefix is None:
+ raise ValueError("Either files or prefix must be specified")
+
+ if prefix is not None:
+ files = HfApi().list_repo_files(repo, repo_type="dataset")
+ files = [
+ f for f in files if f.startswith(prefix) and f.endswith(".parquet")
+ ]
+ log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
+ else:
+ if isinstance(files, str):
+ files = [files]
+
+ files = list(chain.from_iterable(map(braceexpand, files)))
+ log.info(f"Expanded {len(files)} files in {repo}")
+
+ # Get sharded files
+ self.files = sorted(files)
+ Random(seed).shuffle(self.files)
+
+ def __iter__(self):
+ files = split_by_rank_worker(self.files)
+ random.shuffle(files)
+
+ for filename in files:
+ try:
+ yield from self.parse_data(filename)
+ except Exception as e:
+ log.exception(f"Failed to parse {filename}: {e}")
+
+ def parse_data(self, filename: str):
+ for data in self.parse_data_internal(filename):
+ text = data["text"]
+
+ # encode
+ tokens = self.tokenizer.encode(
+ text,
+ add_special_tokens=False,
+ truncation=False,
+ max_length=10**6,
+ )
+
+ # Random choice self.max_length
+ if len(tokens) > self.max_length:
+ start = random.randint(0, len(tokens) - self.max_length)
+ tokens = tokens[start : start + self.max_length - 1]
+
+ tokens = (
+ [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
+ )
+ # Pad dims
+ placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
+
+ tokens = torch.concat(
+ [
+ torch.tensor([tokens], dtype=torch.long),
+ placeholder_multi_codebook,
+ ],
+ dim=0,
+ )
+ labels = tokens.clone()
+ tokens = tokens[:, :-1]
+ labels = labels[:, 1:]
+ labels[1:] = -100 # remove all placeholders
+
+ yield {"tokens": tokens, "labels": labels}
+
+ def parse_data_internal(self, filename: str):
+ url = f"https://huggingface.co./datasets/{self.repo}/resolve/main/{filename}"
+
+ with xopen(url, mode="rb") as stream:
+ parquet_file = pq.ParquetFile(stream)
+
+ for batch in parquet_file.iter_batches(
+ batch_size=self.parquet_batch_size, columns=["text"]
+ ):
+ # In-batch shuffling
+ texts = [{"text": text.as_py()} for text in batch["text"]]
+ random.shuffle(texts)
+ yield from texts
+
+
+class AutoAugTextDataset(IterableDataset):
+ """
+ Auto Augment Dataset by Speaker
+
+ 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+ 2. Automatically normalize the text
+
+ For interactive mode, we use the following format (multiple sequences):
+ [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+
+ For non-interactive mode, we use the following format (one long sequence):
+ [INST] text [/INST] ...
+ """
+
+ def __init__(
+ self,
+ proto_files: list[str],
+ seed: int = 42,
+ interactive_prob: float = 0.5,
+ max_length: int = 1024,
+ tokenizer: AutoTokenizer = None,
+ use_speaker: bool = True,
+ causual: bool = True,
+ use_negative_samples: bool = False,
+ num_codebooks: Optional[int] = None,
+ ):
+ """
+ Args:
+ proto_files: proto buf files if using local data
+ seed: random seed
+ interactive_prob: probability to use interactive mode
+ max_length: max length of the text
+ tokenizer: tokenizer
+ use_speaker: include speaker information in the prompt
+ causual: use causual sampling when using local data, disable will lead to random sampling
+ use_negative_samples: generate negative samples
+ num_codebooks: number of codebooks, if None, it will be automatically detected
+ """
+
+ super().__init__()
+
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+ self.seed = seed
+ self.max_length = max_length
+ self.tokenizer = tokenizer
+ self.interactive_prob = interactive_prob
+ self.use_speaker = use_speaker
+ self.proto_files = proto_files
+ self.causual = causual
+ self.use_negative_samples = use_negative_samples
+ self.num_codebooks = num_codebooks
+
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+ self.groups = None
+
+ def init_mock_data_server(self):
+ if self.groups is not None:
+ return
+
+ # Expand the proto files
+ expanded_proto_files = []
+ for filename in self.proto_files:
+ for i in braceexpand(filename):
+ i = Path(i)
+ if i.is_file():
+ expanded_proto_files.append(i)
+ elif i.is_dir():
+ expanded_proto_files.extend(i.rglob("*.proto"))
+ expanded_proto_files.extend(i.rglob("*.protos"))
+ else:
+ raise ValueError(f"{i} is not a file or directory")
+
+ expanded_proto_files = sorted(expanded_proto_files)
+ Random(self.seed).shuffle(expanded_proto_files)
+
+ self.groups = []
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
+ log.info(
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+ )
+
+ count = 0
+ for filename in shard_proto_files:
+ with open(filename, "rb") as f:
+ for text_data in read_pb_stream(f):
+ self.groups.append(text_data)
+ count += 1
+
+ log.info(f"Read total {count} groups of data")
+
+ # Shuffle the lines
+ Random(self.seed).shuffle(self.groups)
+ self.group_weights = [len(i.sentences) for i in self.groups]
+
+ def __iter__(self):
+ while True:
+ yield self.augment()
+
+ def tokenize_sentence(self, sentence: str):
+ sentence = clean_text(sentence)
+ tokens = self.tokenizer.encode(
+ f"{sentence}",
+ max_length=10**6,
+ add_special_tokens=False,
+ truncation=False,
+ )
+ return sentence, len(tokens)
+
+ def sample_data(self):
+ if self.groups is None:
+ self.init_mock_data_server()
+
+ # Shuffle unique lines, estimate that each sample is at least 20 tokens
+ num_samples = self.max_length // 20
+
+ # choice group based on their number of samples
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
+
+ if self.causual:
+ # Sample in order
+ if num_samples >= len(group.sentences):
+ samples = group.sentences
+ else:
+ begin = random.randint(0, len(group.sentences) - num_samples)
+ samples = group.sentences[begin : begin + num_samples]
+ else:
+ samples = random.choices(
+ group.sentences, k=min(num_samples, len(group.sentences))
+ )
+
+ return SampledData(
+ source=group.source,
+ name=group.name,
+ samples=samples,
+ )
+
+ def augment(self):
+ # Random sample based on speaker using a truncated normal distribution
+ a = torch.tensor([0], dtype=torch.float32)
+ torch.nn.init.trunc_normal_(
+ a,
+ mean=self.max_length // 2,
+ std=self.max_length // 4,
+ a=10,
+ b=self.max_length,
+ )
+ remaining_tokens = a.long().item() - 4
+
+ final_text, final_semantic = [], []
+ response = self.sample_data()
+ if len(response.samples) == 0:
+ # Invalid group
+ return None
+
+ samples = list(response.samples)
+ idx = 0
+ use_interactive = random.random() < self.interactive_prob
+
+ all_tokens, all_labels = [], []
+ while remaining_tokens > 0 and len(samples) > 0:
+ sentence = samples.pop(0)
+
+ text = random.choice(sentence.texts)
+ text, length = self.tokenize_sentence(text)
+ remaining_tokens -= length + len(sentence.semantics[0].values)
+
+ if use_interactive is False:
+ final_text.append(text)
+ final_semantic.append(sentence.semantics)
+ else:
+ # For interactive mode, we only apply speaker for the first sentence
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+ tokens, labels = self.pack_sentences(
+ sentences=[text],
+ semantics=[sentence.semantics],
+ speaker=response.name if (self.use_speaker and idx == 0) else None,
+ add_bos=idx == 0,
+ )
+
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ idx += 1
+
+ if use_interactive is False:
+ tokens, labels = self.pack_sentences(
+ final_text,
+ semantics=final_semantic,
+ speaker=response.name if self.use_speaker else None,
+ add_bos=True,
+ )
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ tokens = torch.cat(all_tokens, dim=1)
+ labels = torch.cat(all_labels, dim=1)
+
+ # Verify that the length is correct
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+ # Verify bos token
+ assert tokens[0, 0] == self.tokenizer.bos_token_id
+
+ data = {"tokens": tokens, "labels": labels}
+
+ if self.use_negative_samples:
+ negative_samples = self.generate_negative_samples(all_tokens, all_labels)
+ data.update(negative_samples)
+
+ return data
+
+ def generate_negative_samples(self, all_tokens, all_labels):
+ new_tokens, new_labels = [], []
+
+ for tokens, labels in zip(all_tokens, all_labels):
+ # If all codebooks are not -100, we find where it starts
+ start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
+ assert (labels[1:, start:] != -100).all() # This shouldn't happen
+
+ mode = random.choice(["repeat", "lost", "noise"])
+ begin = random.randint(start, labels.size(1) - 1)
+ end = random.randint(begin, labels.size(1) - 1)
+
+ if mode == "repeat":
+ tokens = torch.cat(
+ [
+ tokens[:, :begin],
+ tokens[:, begin:end],
+ tokens[:, begin:end],
+ tokens[:, end:],
+ ],
+ dim=1,
+ )
+ labels = torch.cat(
+ [
+ labels[:, :begin],
+ labels[:, begin:end],
+ labels[:, begin:end],
+ labels[:, end:],
+ ],
+ dim=1,
+ )
+ elif mode == "lost":
+ tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
+ labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
+ elif mode == "noise":
+ middle_tokens, middle_labels = (
+ tokens[:, begin:end],
+ labels[:, begin:end],
+ )
+ random_order0 = torch.randperm(middle_tokens.size(1))
+ random_order1 = torch.randperm(middle_tokens.size(1))
+ middle_tokens = middle_tokens[:, random_order0]
+ middle_labels = middle_labels[:, random_order1]
+ tokens = torch.cat(
+ [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
+ )
+ labels = torch.cat(
+ [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
+ )
+
+ new_tokens.append(tokens)
+ new_labels.append(labels)
+
+ tokens = torch.cat(new_tokens, dim=1)
+ labels = torch.cat(new_labels, dim=1)
+
+ # Verify that the length is correct
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+ return {"negative_tokens": tokens, "negative_labels": labels}
+
+ def pack_sentences(
+ self,
+ sentences: list[str],
+ semantics=list,
+ speaker: Optional[str] = None,
+ add_bos: bool = True,
+ ):
+ if speaker is not None:
+ sentences = [f"[SPK: {speaker}]"] + sentences
+
+ final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
+ final_text = final_text + "<|im_start|>assistant<|im_sep|>"
+
+ encoded = self.tokenizer.encode(
+ final_text,
+ add_special_tokens=False,
+ truncation=False,
+ max_length=10**6,
+ )
+ semantic_length = sum([len(i[0].values) for i in semantics])
+ prompt_length = len(encoded)
+ num_codebooks = (
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+ )
+
+ bos_bias = 1 if add_bos else 0
+
+ # Pack the tokens and semantics (add and to semantic tokens)
+ tokens = (
+ encoded
+ + [self.semantic_token_id] * semantic_length
+ + self.tokenizer.convert_tokens_to_ids(
+ ["<|im_end|>", "<|end_of_sequence|>"]
+ )
+ )
+
+ if add_bos:
+ tokens = [self.tokenizer.bos_token_id] + tokens
+
+ # Codebook bos/padding: 0, eos: 1
+ codes = [
+ [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
+ for _ in range(num_codebooks)
+ ]
+ for segment in semantics:
+ for book_idx, book in zip(range(num_codebooks), segment):
+ for j in book.values:
+ codes[book_idx].append(int(j) + 2)
+
+ for book in codes:
+ book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)
+
+ tokens = [tokens] + codes
+
+ tokens = torch.tensor(tokens, dtype=torch.long)
+ labels = tokens.clone()
+
+ # Mask out the tokens for semantic, predict semantic tokens only
+ # Since we don't mask out the input tokens, the language modeling still works
+ labels[1:, : (prompt_length + bos_bias)] = -100
+
+ tokens = tokens[:, :-1]
+ labels = labels[:, 1:]
+
+ # Verify the padding is correct, and the last token is eos
+ assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
+ assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
+ assert labels[0, -1] == self.tokenizer.eos_token_id
+ assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()
+
+ return tokens, labels
+
+
+@dataclass
+class TextDataCollator:
+ tokenizer: AutoTokenizer
+ max_length: int = 1024
+
+ def __call__(self, examples):
+ if "negative_tokens" in examples:
+ positive_examples = []
+ negative_examples = []
+
+ for i in examples:
+ positive_examples.append(
+ {
+ "tokens": i["tokens"],
+ "labels": i["labels"],
+ }
+ )
+ negative_examples.append(
+ {
+ "tokens": i["negative_tokens"],
+ "labels": i["negative_labels"],
+ }
+ )
+
+ examples = positive_examples + negative_examples
+
+ return self.batchify(examples)
+
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
+ tokens, attention_masks, labels = [], [], []
+
+ # Calculate the max length
+ max_tokens_length = 0
+ for example in examples:
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+ max_tokens_length = min(max_tokens_length, self.max_length)
+
+ for example in examples:
+ _tokens = example[tokens_key][:, :max_tokens_length]
+ _labels = example[labels_key][:, :max_tokens_length]
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
+ tokens_length = _tokens.size(1)
+ _attention_mask[:tokens_length] = False
+
+ assert tokens_length == _labels.size(
+ 1
+ ), f"{tokens_length} != {_labels.size(1)}"
+
+ if tokens_length < max_tokens_length:
+ _tokens = F.pad(
+ _tokens,
+ (0, max_tokens_length - tokens_length),
+ value=self.tokenizer.eos_token_id,
+ )
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
+ _labels = F.pad(
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
+ )
+
+ tokens.append(_tokens)
+ attention_masks.append(_attention_mask)
+ labels.append(_labels)
+
+ tokens = torch.stack(tokens, dim=0)
+ attention_masks = torch.stack(attention_masks, dim=0)
+ labels = torch.stack(labels, dim=0)
+
+ return {
+ "inputs": tokens,
+ "attention_masks": attention_masks,
+ "labels": labels,
+ }
+
+
+class InterleaveDataset(IterableDataset):
+ def __init__(
+ self,
+ datasets: list[IterableDataset],
+ probabilities: list[float],
+ seed: int = 42,
+ ):
+ super().__init__()
+
+ self.datasets = datasets
+ self.probabilities = probabilities
+ self.seed = seed
+
+ def __iter__(self):
+ rng = np.random.default_rng(self.seed)
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+ while True:
+ # Random choice one
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+ dataset_iterator = dataset_iterators[dataset_idx]
+
+ try:
+ yield next(dataset_iterator)
+ except StopIteration:
+ # Exhausted, create a new iterator
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+ yield next(dataset_iterators[dataset_idx])
+
+
+class TextDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+ val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+ batch_size: int = 32,
+ tokenizer: AutoTokenizer = None,
+ max_length: int = 1024,
+ num_workers: int = 4,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ )
+
+
+if __name__ == "__main__":
+ from tqdm import tqdm
+
+ ds = AutoAugTextDataset(
+ ["data/protos"],
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+ use_speaker=False,
+ interactive_prob=1.0,
+ use_negative_samples=False,
+ )
+
+ for i in ds:
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+ # i["labels"][0][i["labels"][0] == -100] = 0
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+ break
diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..f61e2c513470015501dc2d8fbfc1392bd4353e75
--- /dev/null
+++ b/fish_speech/datasets/vqgan.py
@@ -0,0 +1,145 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
+class VQGANDataset(Dataset):
+ def __init__(
+ self,
+ filelist: str,
+ sample_rate: int = 32000,
+ hop_length: int = 640,
+ slice_frames: Optional[int] = None,
+ ):
+ super().__init__()
+
+ filelist = Path(filelist)
+ root = filelist.parent
+
+ self.files = [
+ root / line.strip()
+ for line in filelist.read_text().splitlines()
+ if line.strip()
+ ]
+ self.sample_rate = sample_rate
+ self.hop_length = hop_length
+ self.slice_frames = slice_frames
+
+ def __len__(self):
+ return len(self.files)
+
+ def get_item(self, idx):
+ file = self.files[idx]
+
+ audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+
+ # Slice audio and features
+ if (
+ self.slice_frames is not None
+ and audio.shape[0] > self.slice_frames * self.hop_length
+ ):
+ start = np.random.randint(
+ 0, audio.shape[0] - self.slice_frames * self.hop_length
+ )
+ audio = audio[start : start + self.slice_frames * self.hop_length]
+
+ if len(audio) == 0:
+ return None
+
+ max_value = np.abs(audio).max()
+ if max_value > 1.0:
+ audio = audio / max_value
+
+ return {
+ "audio": torch.from_numpy(audio),
+ }
+
+ def __getitem__(self, idx):
+ try:
+ return self.get_item(idx)
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ logger.error(f"Error loading {self.files[idx]}: {e}")
+ return None
+
+
+@dataclass
+class VQGANCollator:
+ def __call__(self, batch):
+ batch = [x for x in batch if x is not None]
+
+ audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+ audio_maxlen = audio_lengths.max()
+
+ # Rounds up to nearest multiple of 2 (audio_lengths)
+ audios = []
+ for x in batch:
+ audios.append(
+ torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+ )
+
+ return {
+ "audios": torch.stack(audios),
+ "audio_lengths": audio_lengths,
+ }
+
+
+class VQGANDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: VQGANDataset,
+ val_dataset: VQGANDataset,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ val_batch_size: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.val_batch_size = val_batch_size or batch_size
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ shuffle=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.val_batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ )
+
+
+if __name__ == "__main__":
+ dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+ dataloader = DataLoader(
+ dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+ )
+
+ for batch in dataloader:
+ print(batch["audios"].shape)
+ print(batch["features"].shape)
+ print(batch["audio_lengths"])
+ print(batch["feature_lengths"])
+ break
diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cceba5d6d39f149dfc77688663a4a200b74d3c5c
--- /dev/null
+++ b/fish_speech/models/text2semantic/__init__.py
@@ -0,0 +1,3 @@
+from .lit_module import TextToSemantic
+
+__all__ = ["TextToSemantic"]
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..de759bc31585e5fc1acf402228b57eee21b90332
--- /dev/null
+++ b/fish_speech/models/text2semantic/lit_module.py
@@ -0,0 +1,344 @@
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import lightning as L
+import loralib as lora
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@dataclass
+class LoraConfig:
+ r: int
+ lora_alpha: float
+ lora_dropout: float = 0.0
+
+
+class TextToSemantic(L.LightningModule):
+ def __init__(
+ self,
+ model: NaiveTransformer,
+ optimizer: Any,
+ lr_scheduler: Any,
+ lora_config: Optional[LoraConfig] = None,
+ save_lora_only: bool = False,
+ use_dpo: bool = False,
+ dpo_beta: float = 0.2,
+ ):
+ super().__init__()
+
+ self.model = model
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+ self.lora_config = lora_config
+ self.save_lora_only = save_lora_only
+ self.use_dpo = use_dpo # We don't support reference model yet
+ self.dpo_beta = dpo_beta
+
+ if self.lora_config is not None:
+ self.setup_lora()
+
+ def setup_lora(self):
+ # Replace the embedding layer with a LoRA layer
+ self.model.embeddings = lora.Embedding(
+ num_embeddings=self.model.embeddings.num_embeddings,
+ embedding_dim=self.model.embeddings.embedding_dim,
+ padding_idx=self.model.embeddings.padding_idx,
+ r=self.lora_config.r,
+ lora_alpha=self.lora_config.lora_alpha,
+ )
+
+ # Replace output layer with a LoRA layer
+ linears = [(self.model, "output")]
+
+ # Replace all linear layers with LoRA layers
+ for layer in self.model.layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ if hasattr(self.model, "fast_layers"):
+ # Dual-AR model
+ linears.extend([(self.model, "fast_output")])
+
+ for layer in self.model.fast_layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ for module, layer in linears:
+ updated_linear = lora.Linear(
+ in_features=getattr(module, layer).in_features,
+ out_features=getattr(module, layer).out_features,
+ bias=getattr(module, layer).bias,
+ r=self.lora_config.r,
+ lora_alpha=self.lora_config.lora_alpha,
+ lora_dropout=self.lora_config.lora_dropout,
+ )
+ setattr(module, layer, updated_linear)
+
+ # Mark only the LoRA layers as trainable
+ lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
+
+ def forward(self, x):
+ return self.model(x)
+
+ def on_save_checkpoint(self, checkpoint):
+ if self.lora_config is None or self.save_lora_only is False:
+ return
+
+ # Save only LoRA parameters
+ state_dict = checkpoint["state_dict"]
+ for name in list(state_dict.keys()):
+ if "lora" not in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ # Get weight decay parameters
+ weight_decay_parameters, other_parameters = [], []
+ for name, param in self.named_parameters():
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+ other_parameters.append(param)
+ else:
+ weight_decay_parameters.append(param)
+
+ optimizer = self.optimizer_builder(
+ [
+ {"params": weight_decay_parameters},
+ {"params": other_parameters, "weight_decay": 0.0},
+ ]
+ )
+
+ # Print the parameters and their weight decay
+ for i in optimizer.param_groups:
+ log.info(
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+ )
+
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ },
+ }
+
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
+ def get_batch_logps(
+ self,
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ assert logits.shape[:-1] == labels.shape
+
+ labels = labels.clone()
+ loss_mask = labels != -100
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == -100] = 0
+
+ per_token_logps = torch.gather(
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
+
+ def _step(self, batch, batch_idx, stage: str):
+ is_train = stage == "train"
+
+ # Do positive and negative samples in the same batch to speed up training
+ labels = batch["labels"]
+ outputs = self.model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+ if self.use_dpo:
+ # Firtst half is positive, second half is negative
+ token_logits, negative_token_logits = token_logits.chunk(2)
+ codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
+ labels, negative_labels = labels.chunk(2)
+
+ # Generate labels
+ base_loss = F.cross_entropy(
+ token_logits.reshape(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ )
+
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ )
+
+ loss = base_loss + semantic_loss
+
+ # If we use dpo
+ if self.use_dpo:
+ negative_codebook_labels = negative_labels[
+ :, 1 : 1 + self.model.config.num_codebooks
+ ].mT
+
+ positive_codebook_logps = self.get_batch_logps(
+ codebook_logits, codebook_labels
+ )
+ negative_codebook_logps = self.get_batch_logps(
+ negative_codebook_logits, negative_codebook_labels
+ )
+
+ # TODO: implement the reference model, avoid screwing up the gradients
+ dpo_loss = -F.logsigmoid(
+ (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
+ ).mean()
+
+ chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
+ rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
+ reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
+ chosen_rewards, rejected_rewards = (
+ chosen_rewards.mean(),
+ rejected_rewards.mean(),
+ )
+
+ loss = loss + dpo_loss
+
+ self.log(
+ f"{stage}/dpo_loss",
+ dpo_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ )
+
+ self.log(
+ f"{stage}/chosen_rewards",
+ chosen_rewards,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ )
+
+ self.log(
+ f"{stage}/rejected_rewards",
+ rejected_rewards,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ )
+
+ self.log(
+ f"{stage}/reward_accuracy",
+ reward_accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ )
+
+ self.log(
+ f"{stage}/loss",
+ loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ )
+
+ self.log(
+ f"{stage}/base_loss",
+ base_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ )
+
+ self.log(
+ f"{stage}/semantic_loss",
+ semantic_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ )
+
+ # Top-5 accuracy
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
+ self.log(
+ f"{stage}/top_5_accuracy",
+ accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ )
+
+ if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
+ accuracy = self.get_accuracy(
+ codebook_logits[:, :, : self.model.config.num_in_codebooks],
+ codebook_labels[:, :, : self.model.config.num_in_codebooks],
+ )
+
+ self.log(
+ f"{stage}/top_5_accuracy_in",
+ accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ )
+
+ return loss
+
+ def get_accuracy(self, logits, labels):
+ _, indices = logits.topk(5, dim=-1)
+ correct = indices.eq(labels.unsqueeze(-1))
+ correct[labels == -100] = 0
+ correct = correct.sum()
+ accuracy = correct / (labels != -100).sum()
+
+ return accuracy
+
+ def training_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "train")
+
+ def validation_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "val")
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..0999ecb3ef067880577e5ecad80f0e56970dcbeb
--- /dev/null
+++ b/fish_speech/models/text2semantic/llama.py
@@ -0,0 +1,595 @@
+import math
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import Tensor
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class BaseModelArgs:
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ max_seq_len: int = 2048
+ dropout: float = 0.0
+
+ # Codebook configs
+ codebook_size: int = 160
+ num_codebooks: int = 4
+ num_in_codebooks: Optional[int] = None
+ codebook_padding_idx: int = 0
+
+ # Gradient checkpointing
+ use_gradient_checkpointing: bool = True
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ if self.num_in_codebooks is None:
+ self.num_in_codebooks = self.num_codebooks
+ self.head_dim = self.dim // self.n_head
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+ pass
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+ n_fast_layer: int = 4
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+ token_logits: Tensor
+ codebook_logits: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+ logits: Tensor
+ hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size + config.codebook_size * config.num_in_codebooks,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.dim // config.n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def embed(self, x: Tensor) -> Tensor:
+ vocab_embeds = [self.embeddings(x[:, 0])]
+ for i in range(self.config.num_in_codebooks):
+ emb = self.embeddings(
+ x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
+ )
+ emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
+ vocab_embeds.append(emb)
+
+ x = torch.stack(vocab_embeds, dim=3)
+ x = x.sum(dim=3)
+
+ return x
+
+ def forward(
+ self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+ ) -> BaseTransformerForwardResult:
+ # x: (batch, num_codebooks + 1, seq_len)
+ seq_len = inp.size(2)
+
+ # Here we want to merge the embeddings of the codebooks
+ x = self.embed(inp)
+
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[:seq_len]
+
+ # Not that the causal mask here follows the definition of scaled_dot_product_attention
+ # That is, FALSE means masked out
+ # To maintain consistency, key_padding_mask use TRUE to mask out
+ if key_padding_mask is not None:
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> BaseTransformerForwardResult:
+ # This is used for generation, optimized for torch compile
+ assert (
+ self.max_seq_len != -1 and self.max_batch_size != -1
+ ), "Please call setup_caches before forward_generate"
+
+ x = self.embed(x)
+
+ mask = self.causal_mask[
+ None, None, input_pos, : self.max_seq_len
+ ] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+ # If prefill, we only calculate the logits of last token
+ if x.size(1) > 1:
+ x = x[:, -1:]
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+
+class NaiveTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs) -> None:
+ super().__init__(config)
+
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.codebook_output = nn.Linear(
+ config.dim,
+ config.codebook_size * config.num_codebooks,
+ bias=False,
+ )
+
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+ token_logits = result.logits
+ x = result.hidden_states
+
+ # Codebook
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
+ codebook_logits = rearrange(
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward(
+ self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ result = super().forward(inp, key_padding_mask)
+ return self.decode(result)
+
+ def forward_generate(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ result = super().forward_generate(x, input_pos)
+ return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+ def __init__(self, config: DualARModelArgs) -> None:
+ super().__init__(config)
+
+ # Fast transformer
+ self.fast_embeddings = nn.Embedding(
+ config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
+ )
+
+ # The equivalent bs is so large that sdpa doesn't work
+ self.fast_layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+ )
+ self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.fast_output = nn.Linear(
+ config.dim,
+ config.codebook_size,
+ bias=False,
+ )
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+ head_dim = self.config.dim // self.config.n_head
+
+ # Fast transformer
+ # The max seq len here is the number of codebooks
+ for b in self.fast_layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ self.config.num_codebooks,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def forward(
+ self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ parent_result = super().forward(inp, key_padding_mask)
+ token_logits = parent_result.logits
+ x = parent_result.hidden_states
+
+ # Fast transformer
+ fast_seq_len = self.config.num_codebooks
+ fast_mask = self.causal_mask[
+ None, None, :fast_seq_len, :fast_seq_len
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
+ # Drop the last token and rotate left
+ codebooks = inp[:, 1:-1, 1:]
+ codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
+ codebook_embeddings = self.fast_embeddings(codebooks)
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
+ b, s = x.size(0), x.size(2)
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
+
+ # Remove padded part
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
+ codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
+ x_bs, x_len = x.size(0), x.size(1)
+ x = x[~codebook_mask]
+
+ for layer in self.fast_layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+ else:
+ x = layer(x, fast_freqs_cis, fast_mask)
+
+ # unflatten the batch and num_codebooks
+ fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ # Re-pad the codebook_logits
+ buffer = torch.zeros(
+ x_bs,
+ x_len,
+ codebook_logits.size(-1),
+ device=codebook_logits.device,
+ dtype=codebook_logits.dtype,
+ )
+ buffer[~codebook_mask] = codebook_logits
+ codebook_logits = buffer
+
+ assert codebook_logits.shape[1] == self.config.num_codebooks
+ codebook_logits = rearrange(
+ codebook_logits,
+ "(b s) n d -> b s n d",
+ b=b,
+ s=s,
+ n=self.config.num_codebooks,
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward_generate_fast(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> Tensor:
+ # Fast transformer
+ x = x.view(1, 1, -1)
+
+ fast_mask = self.causal_mask[
+ None, None, input_pos, : self.config.num_codebooks
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.fast_layers:
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
+
+ # unflatten the batch and num_codebooks
+ fast_out = self.fast_norm(x) # only take the last token
+ codebook_logits = self.fast_output(fast_out)
+
+ return codebook_logits
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+ super().__init__()
+ self.attention = Attention(config, use_sdpa=use_sdpa)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.dropout = config.dropout
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.use_sdpa = use_sdpa
+ self._register_load_state_dict_pre_hook(self.load_hook)
+
+ def load_hook(self, state_dict, prefix, *args):
+ if prefix + "wq.weight" in state_dict:
+ wq = state_dict.pop(prefix + "wq.weight")
+ wk = state_dict.pop(prefix + "wk.weight")
+ wv = state_dict.pop(prefix + "wv.weight")
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.use_sdpa:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+ else:
+ y = self.eq_scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ return self.wo(y)
+
+ def eq_scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ ) -> torch.Tensor:
+ # This is a standard scaled dot product attention
+ # It's low efficient, but it doesn't raise cuda error
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
+
+
+if __name__ == "__main__":
+ args = DualARModelArgs(
+ max_seq_len=4096,
+ vocab_size=32312,
+ n_layer=12,
+ n_fast_layer=4,
+ n_head=12,
+ dim=768,
+ rope_base=10000,
+ norm_eps=1e-5,
+ codebook_size=128,
+ num_codebooks=4,
+ )
+
+ model = DualARTransformer(args)
+ model = model.cuda().bfloat16()
+ print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
+
+ inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
+ key_padding_mask = torch.zeros(2, 128).bool().cuda()
+ key_padding_mask[0, 2:] = True
+ x1 = model(inputs, key_padding_mask=key_padding_mask)
+ print(x1.token_logits.shape)
+ print(x1.codebook_logits.shape)
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..401c6df468c7aa51be1ecaa71ac71513958ae055
--- /dev/null
+++ b/fish_speech/models/vqgan/__init__.py
@@ -0,0 +1,3 @@
+from .lit_module import VQGAN
+
+__all__ = ["VQGAN"]
diff --git a/fish_speech/models/vqgan/lit_module.py b/fish_speech/models/vqgan/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd0733ba748ab69bb539eb6b596b36a365ac460f
--- /dev/null
+++ b/fish_speech/models/vqgan/lit_module.py
@@ -0,0 +1,442 @@
+import itertools
+import math
+from typing import Any, Callable
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+import wandb
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from matplotlib import pyplot as plt
+from torch import nn
+
+from fish_speech.models.vqgan.modules.discriminator import Discriminator
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
+
+
+class VQGAN(L.LightningModule):
+ def __init__(
+ self,
+ optimizer: Callable,
+ lr_scheduler: Callable,
+ encoder: WaveNet,
+ quantizer: nn.Module,
+ decoder: WaveNet,
+ discriminator: Discriminator,
+ vocoder: nn.Module,
+ encode_mel_transform: nn.Module,
+ gt_mel_transform: nn.Module,
+ weight_adv: float = 1.0,
+ weight_vq: float = 1.0,
+ weight_mel: float = 1.0,
+ sampling_rate: int = 44100,
+ freeze_encoder: bool = False,
+ ):
+ super().__init__()
+
+ # Model parameters
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+
+ # Modules
+ self.encoder = encoder
+ self.quantizer = quantizer
+ self.decoder = decoder
+ self.vocoder = vocoder
+ self.discriminator = discriminator
+ self.encode_mel_transform = encode_mel_transform
+ self.gt_mel_transform = gt_mel_transform
+
+ # A simple linear layer to project quality to condition channels
+ self.quality_projection = nn.Linear(1, 768)
+
+ # Freeze vocoder
+ for param in self.vocoder.parameters():
+ param.requires_grad = False
+
+ # Loss weights
+ self.weight_adv = weight_adv
+ self.weight_vq = weight_vq
+ self.weight_mel = weight_mel
+
+ # Other parameters
+ self.sampling_rate = sampling_rate
+
+ # Disable strict loading
+ self.strict_loading = False
+
+ # If encoder is frozen
+ if freeze_encoder:
+ for param in self.encoder.parameters():
+ param.requires_grad = False
+
+ for param in self.quantizer.parameters():
+ param.requires_grad = False
+
+ self.automatic_optimization = False
+
+ def on_save_checkpoint(self, checkpoint):
+ # Do not save vocoder
+ state_dict = checkpoint["state_dict"]
+ for name in list(state_dict.keys()):
+ if "vocoder" in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self):
+ optimizer_generator = self.optimizer_builder(
+ itertools.chain(
+ self.encoder.parameters(),
+ self.quantizer.parameters(),
+ self.decoder.parameters(),
+ self.quality_projection.parameters(),
+ )
+ )
+ optimizer_discriminator = self.optimizer_builder(
+ self.discriminator.parameters()
+ )
+
+ lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
+ lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
+
+ return (
+ {
+ "optimizer": optimizer_generator,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler_generator,
+ "interval": "step",
+ "name": "optimizer/generator",
+ },
+ },
+ {
+ "optimizer": optimizer_discriminator,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler_discriminator,
+ "interval": "step",
+ "name": "optimizer/discriminator",
+ },
+ },
+ )
+
+ def training_step(self, batch, batch_idx):
+ optim_g, optim_d = self.optimizers()
+
+ audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+ audios = audios.float()
+ audios = audios[:, None, :]
+
+ with torch.no_grad():
+ encoded_mels = self.encode_mel_transform(audios)
+ gt_mels = self.gt_mel_transform(audios)
+ quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
+ quality = quality.unsqueeze(-1)
+
+ mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ gt_mels = gt_mels * mel_masks_float_conv
+ encoded_mels = encoded_mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
+
+ # Quantize
+ vq_result = self.quantizer(encoded_features)
+ loss_vq = getattr("vq_result", "loss", 0.0)
+ vq_recon_features = vq_result.z * mel_masks_float_conv
+ vq_recon_features = (
+ vq_recon_features + self.quality_projection(quality)[:, :, None]
+ )
+
+ # VQ Decode
+ gen_mel = (
+ self.decoder(
+ torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+ condition=vq_recon_features,
+ )
+ * mel_masks_float_conv
+ )
+
+ # Discriminator
+ real_logits = self.discriminator(gt_mels)
+ fake_logits = self.discriminator(gen_mel.detach())
+ d_mask = F.interpolate(
+ mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
+ )
+
+ loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
+ loss_fake = avg_with_mask(fake_logits**2, d_mask)
+
+ loss_d = loss_real + loss_fake
+
+ self.log(
+ "train/discriminator/loss",
+ loss_d,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=True,
+ logger=True,
+ )
+
+ # Discriminator backward
+ optim_d.zero_grad()
+ self.manual_backward(loss_d)
+ self.clip_gradients(
+ optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+ )
+ optim_d.step()
+
+ # Mel Loss, applying l1, using a weighted sum
+ mel_distance = (
+ gen_mel - gt_mels
+ ).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
+ loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
+ loss_mel_mid_freq = avg_with_mask(
+ mel_distance[:, 40:70, :], mel_masks_float_conv
+ )
+ loss_mel_high_freq = avg_with_mask(
+ mel_distance[:, 70:, :], mel_masks_float_conv
+ )
+ loss_mel = (
+ loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
+ )
+
+ # Adversarial Loss
+ fake_logits = self.discriminator(gen_mel)
+ loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
+
+ # Total loss
+ loss = (
+ self.weight_vq * loss_vq
+ + self.weight_mel * loss_mel
+ + self.weight_adv * loss_adv
+ )
+
+ # Log losses
+ self.log(
+ "train/generator/loss",
+ loss,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=True,
+ logger=True,
+ )
+ self.log(
+ "train/generator/loss_vq",
+ loss_vq,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=False,
+ logger=True,
+ )
+ self.log(
+ "train/generator/loss_mel",
+ loss_mel,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=False,
+ logger=True,
+ )
+ self.log(
+ "train/generator/loss_adv",
+ loss_adv,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=False,
+ logger=True,
+ )
+
+ # Generator backward
+ optim_g.zero_grad()
+ self.manual_backward(loss)
+ self.clip_gradients(
+ optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+ )
+ optim_g.step()
+
+ scheduler_g, scheduler_d = self.lr_schedulers()
+ scheduler_g.step()
+ scheduler_d.step()
+
+ def validation_step(self, batch: Any, batch_idx: int):
+ audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+ audios = audios.float()
+ audios = audios[:, None, :]
+
+ encoded_mels = self.encode_mel_transform(audios)
+ gt_mels = self.gt_mel_transform(audios)
+
+ mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ gt_mels = gt_mels * mel_masks_float_conv
+ encoded_mels = encoded_mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
+
+ # Quantize
+ vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
+ vq_recon_features = (
+ vq_recon_features
+ + self.quality_projection(
+ torch.ones(
+ vq_recon_features.shape[0], 1, device=vq_recon_features.device
+ )
+ * 2
+ )[:, :, None]
+ )
+
+ # VQ Decode
+ gen_aux_mels = (
+ self.decoder(
+ torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+ condition=vq_recon_features,
+ )
+ * mel_masks_float_conv
+ )
+ loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
+
+ self.log(
+ "val/loss_mel",
+ loss_mel,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=False,
+ logger=True,
+ sync_dist=True,
+ )
+
+ recon_audios = self.vocoder(gt_mels)
+ gen_aux_audios = self.vocoder(gen_aux_mels)
+
+ # only log the first batch
+ if batch_idx != 0:
+ return
+
+ for idx, (
+ gt_mel,
+ gen_aux_mel,
+ audio,
+ gen_aux_audio,
+ recon_audio,
+ audio_len,
+ ) in enumerate(
+ zip(
+ gt_mels,
+ gen_aux_mels,
+ audios.cpu().float(),
+ gen_aux_audios.cpu().float(),
+ recon_audios.cpu().float(),
+ audio_lengths,
+ )
+ ):
+ if idx > 4:
+ break
+
+ mel_len = audio_len // self.gt_mel_transform.hop_length
+
+ image_mels = plot_mel(
+ [
+ gt_mel[:, :mel_len],
+ gen_aux_mel[:, :mel_len],
+ ],
+ [
+ "Ground-Truth",
+ "Auxiliary",
+ ],
+ )
+
+ if isinstance(self.logger, WandbLogger):
+ self.logger.experiment.log(
+ {
+ "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+ "wavs": [
+ wandb.Audio(
+ audio[0, :audio_len],
+ sample_rate=self.sampling_rate,
+ caption="gt",
+ ),
+ wandb.Audio(
+ gen_aux_audio[0, :audio_len],
+ sample_rate=self.sampling_rate,
+ caption="aux",
+ ),
+ wandb.Audio(
+ recon_audio[0, :audio_len],
+ sample_rate=self.sampling_rate,
+ caption="recon",
+ ),
+ ],
+ },
+ )
+
+ if isinstance(self.logger, TensorBoardLogger):
+ self.logger.experiment.add_figure(
+ f"sample-{idx}/mels",
+ image_mels,
+ global_step=self.global_step,
+ )
+ self.logger.experiment.add_audio(
+ f"sample-{idx}/wavs/gt",
+ audio[0, :audio_len],
+ self.global_step,
+ sample_rate=self.sampling_rate,
+ )
+ self.logger.experiment.add_audio(
+ f"sample-{idx}/wavs/gen",
+ gen_aux_audio[0, :audio_len],
+ self.global_step,
+ sample_rate=self.sampling_rate,
+ )
+ self.logger.experiment.add_audio(
+ f"sample-{idx}/wavs/recon",
+ recon_audio[0, :audio_len],
+ self.global_step,
+ sample_rate=self.sampling_rate,
+ )
+
+ plt.close(image_mels)
+
+ def encode(self, audios, audio_lengths):
+ audios = audios.float()
+
+ mels = self.encode_mel_transform(audios)
+ mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ mels = mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.encoder(mels) * mel_masks_float_conv
+ feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+ return self.quantizer.encode(encoded_features), feature_lengths
+
+ def decode(self, indices, feature_lengths, return_audios=False):
+ factor = math.prod(self.quantizer.downsample_factor)
+ mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
+ z = (
+ z
+ + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
+ :, :, None
+ ]
+ )
+
+ gen_mel = (
+ self.decoder(
+ torch.randn_like(z) * mel_masks_float_conv,
+ condition=z,
+ )
+ * mel_masks_float_conv
+ )
+
+ if return_audios:
+ return self.vocoder(gen_mel)
+
+ return gen_mel
diff --git a/fish_speech/models/vqgan/modules/discriminator.py b/fish_speech/models/vqgan/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..69c7df41033f2cde22583468731f56b49eb594b7
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/discriminator.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class Discriminator(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ blocks = []
+ convs = [
+ (1, 64, (3, 9), 1, (1, 4)),
+ (64, 128, (3, 9), (1, 2), (1, 4)),
+ (128, 256, (3, 9), (1, 2), (1, 4)),
+ (256, 512, (3, 9), (1, 2), (1, 4)),
+ (512, 1024, (3, 3), 1, (1, 1)),
+ (1024, 1, (3, 3), 1, (1, 1)),
+ ]
+
+ for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
+ convs
+ ):
+ blocks.append(
+ weight_norm(
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+ )
+ )
+
+ if idx != len(convs) - 1:
+ blocks.append(nn.SiLU(inplace=True))
+
+ self.blocks = nn.Sequential(*blocks)
+
+ def forward(self, x):
+ return self.blocks(x[:, None])[:, 0]
+
+
+if __name__ == "__main__":
+ model = Discriminator()
+ print(sum(p.numel() for p in model.parameters()) / 1_000_000)
+ x = torch.randn(1, 128, 1024)
+ y = model(x)
+ print(y.shape)
+ print(y)
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
new file mode 100644
index 0000000000000000000000000000000000000000..d762f99780ff72845c647b0ab0da0461160dc91b
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/firefly.py
@@ -0,0 +1,538 @@
+# A inference only version of the FireflyGAN model
+
+from functools import partial
+from math import prod
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.utils.checkpoint import checkpoint
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.silu(x)
+ xt = c1(xt)
+ xt = F.silu(xt)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_parametrizations(self):
+ for conv in self.convs1:
+ remove_parametrizations(conv, tensor_name="weight")
+ for conv in self.convs2:
+ remove_parametrizations(conv, tensor_name="weight")
+
+
+class ParralelBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ kernel_sizes: tuple[int] = (3, 7, 11),
+ dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ ):
+ super().__init__()
+
+ assert len(kernel_sizes) == len(dilation_sizes)
+
+ self.blocks = nn.ModuleList()
+ for k, d in zip(kernel_sizes, dilation_sizes):
+ self.blocks.append(ResBlock1(channels, k, d))
+
+ def forward(self, x):
+ return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
+
+ def remove_parametrizations(self):
+ for block in self.blocks:
+ block.remove_parametrizations()
+
+
+class HiFiGANGenerator(nn.Module):
+ def __init__(
+ self,
+ *,
+ hop_length: int = 512,
+ upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+ upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+ resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+ resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ num_mels: int = 128,
+ upsample_initial_channel: int = 512,
+ use_template: bool = True,
+ pre_conv_kernel_size: int = 7,
+ post_conv_kernel_size: int = 7,
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
+ ):
+ super().__init__()
+
+ assert (
+ prod(upsample_rates) == hop_length
+ ), f"hop_length must be {prod(upsample_rates)}"
+
+ self.conv_pre = weight_norm(
+ nn.Conv1d(
+ num_mels,
+ upsample_initial_channel,
+ pre_conv_kernel_size,
+ 1,
+ padding=get_padding(pre_conv_kernel_size),
+ )
+ )
+
+ self.num_upsamples = len(upsample_rates)
+ self.num_kernels = len(resblock_kernel_sizes)
+
+ self.noise_convs = nn.ModuleList()
+ self.use_template = use_template
+ self.ups = nn.ModuleList()
+
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
+ self.ups.append(
+ weight_norm(
+ nn.ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ if not use_template:
+ continue
+
+ if i + 1 < len(upsample_rates):
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
+ self.noise_convs.append(
+ Conv1d(
+ 1,
+ c_cur,
+ kernel_size=stride_f0 * 2,
+ stride=stride_f0,
+ padding=stride_f0 // 2,
+ )
+ )
+ else:
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ self.resblocks.append(
+ ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+ )
+
+ self.activation_post = post_activation()
+ self.conv_post = weight_norm(
+ nn.Conv1d(
+ ch,
+ 1,
+ post_conv_kernel_size,
+ 1,
+ padding=get_padding(post_conv_kernel_size),
+ )
+ )
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x, template=None):
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ x = F.silu(x, inplace=True)
+ x = self.ups[i](x)
+
+ if self.use_template:
+ x = x + self.noise_convs[i](template)
+
+ if self.training and self.checkpointing:
+ x = checkpoint(
+ self.resblocks[i],
+ x,
+ use_reentrant=False,
+ )
+ else:
+ x = self.resblocks[i](x)
+
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_parametrizations(self):
+ for up in self.ups:
+ remove_parametrizations(up, tensor_name="weight")
+ for block in self.resblocks:
+ block.remove_parametrizations()
+ remove_parametrizations(self.conv_pre, tensor_name="weight")
+ remove_parametrizations(self.conv_post, tensor_name="weight")
+
+
+# DropPath copied from timm library
+def drop_path(
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """ # noqa: E501
+
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """ # noqa: E501
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None] * x + self.bias[:, None]
+ return x
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
+class ConvNeXtBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
+ dilation (int): Dilation for depthwise conv. Default: 1.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim: int,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ mlp_ratio: float = 4.0,
+ kernel_size: int = 7,
+ dilation: int = 1,
+ ):
+ super().__init__()
+
+ self.dwconv = nn.Conv1d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=int(dilation * (kernel_size - 1) / 2),
+ groups=dim,
+ ) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, int(mlp_ratio * dim)
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x, apply_residual: bool = True):
+ input = x
+
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
+ x = self.drop_path(x)
+
+ if apply_residual:
+ x = input + x
+
+ return x
+
+
+class ConvNeXtEncoder(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 3,
+ depths: list[int] = [3, 3, 9, 3],
+ dims: list[int] = [96, 192, 384, 768],
+ drop_path_rate: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ kernel_size: int = 7,
+ ):
+ super().__init__()
+ assert len(depths) == len(dims)
+
+ self.downsample_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ nn.Conv1d(
+ input_channels,
+ dims[0],
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ padding_mode="zeros",
+ ),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+ )
+ self.downsample_layers.append(stem)
+
+ for i in range(len(depths) - 1):
+ mid_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+ )
+ self.downsample_layers.append(mid_layer)
+
+ self.stages = nn.ModuleList()
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+ cur = 0
+ for i in range(len(depths)):
+ stage = nn.Sequential(
+ *[
+ ConvNeXtBlock(
+ dim=dims[i],
+ drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value,
+ kernel_size=kernel_size,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ for i in range(len(self.downsample_layers)):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+
+ return self.norm(x)
+
+
+class FireflyBase(nn.Module):
+ def __init__(self, ckpt_path: str = None, pretrained: bool = True):
+ super().__init__()
+
+ self.backbone = ConvNeXtEncoder(
+ input_channels=128,
+ depths=[3, 3, 9, 3],
+ dims=[128, 256, 384, 512],
+ drop_path_rate=0.2,
+ kernel_size=7,
+ )
+
+ self.head = HiFiGANGenerator(
+ hop_length=512,
+ upsample_rates=[8, 8, 2, 2, 2],
+ upsample_kernel_sizes=[16, 16, 4, 4, 4],
+ resblock_kernel_sizes=[3, 7, 11],
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ num_mels=512,
+ upsample_initial_channel=512,
+ use_template=False,
+ pre_conv_kernel_size=13,
+ post_conv_kernel_size=13,
+ )
+
+ if ckpt_path is not None:
+ self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
+ elif pretrained:
+ state_dict = torch.hub.load_state_dict_from_url(
+ "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
+ map_location="cpu",
+ )
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator." in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ self.load_state_dict(state_dict, strict=True)
+ self.head.remove_parametrizations()
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.backbone(x)
+ x = self.head(x)
+ if x.ndim == 2:
+ x = x[:, None, :]
+ return x
+
+
+if __name__ == "__main__":
+ model = FireflyBase()
+ model.eval()
+ x = torch.randn(1, 128, 128)
+ with torch.no_grad():
+ y = model(x)
+ print(y.shape)
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c33319e0d20c8dd8d88e808ba3cc852f96ff3d2b
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/fsq.py
@@ -0,0 +1,139 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
+from .firefly import ConvNeXtBlock
+
+
+@dataclass
+class FSQResult:
+ z: torch.Tensor
+ codes: torch.Tensor
+ latents: torch.Tensor
+
+
+class DownsampleFiniteScalarQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ n_groups: int = 1,
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
+ downsample_factor: tuple[int] = (2, 2),
+ downsample_dims: tuple[int] | None = None,
+ ):
+ super().__init__()
+
+ if downsample_dims is None:
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+ all_dims = (input_dim,) + tuple(downsample_dims)
+
+ self.residual_fsq = GroupedResidualFSQ(
+ dim=all_dims[-1],
+ levels=levels,
+ num_quantizers=n_codebooks,
+ groups=n_groups,
+ )
+
+ self.downsample_factor = downsample_factor
+ self.downsample_dims = downsample_dims
+
+ self.downsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.Conv1d(
+ all_dims[idx],
+ all_dims[idx + 1],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
+ )
+ for idx, factor in enumerate(downsample_factor)
+ ]
+ )
+
+ self.upsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.ConvTranspose1d(
+ all_dims[idx + 1],
+ all_dims[idx],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx]),
+ )
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
+ ]
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, z) -> FSQResult:
+ original_shape = z.shape
+ z = self.downsample(z)
+ quantized, indices = self.residual_fsq(z.mT)
+ result = FSQResult(
+ z=quantized.mT,
+ codes=indices.mT,
+ latents=z,
+ )
+ result.z = self.upsample(result.z)
+
+ # Pad or crop z to match original shape
+ diff = original_shape[-1] - result.z.shape[-1]
+ left = diff // 2
+ right = diff - left
+
+ if diff > 0:
+ result.z = F.pad(result.z, (left, right))
+ elif diff < 0:
+ result.z = result.z[..., left:-right]
+
+ return result
+
+ def encode(self, z):
+ z = self.downsample(z)
+ _, indices = self.residual_fsq(z.mT)
+ indices = rearrange(indices, "g b l r -> b (g r) l")
+ return indices
+
+ def decode(self, indices: torch.Tensor):
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+ z_q = self.residual_fsq.get_output_from_indices(indices)
+ z_q = self.upsample(z_q.mT)
+ return z_q
+
+ # def from_latents(self, latents: torch.Tensor):
+ # z_q, z_p, codes = super().from_latents(latents)
+ # z_q = self.upsample(z_q)
+ # return z_q, z_p, codes
+
+
+if __name__ == "__main__":
+ rvq = DownsampleFiniteScalarQuantize(
+ n_codebooks=1,
+ downsample_factor=(2, 2),
+ )
+ x = torch.randn(16, 512, 80)
+
+ result = rvq(x)
+ print(rvq)
+ print(result.latents.shape, result.codes.shape, result.z.shape)
+
+ # y = rvq.from_codes(result.codes)
+ # print(y[0].shape)
+
+ # y = rvq.from_latents(result.latents)
+ # print(y[0].shape)
diff --git a/fish_speech/models/vqgan/modules/reference.py b/fish_speech/models/vqgan/modules/reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..034d5c5e3572bd3828649fc0f82a1856ccc6b9e1
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/reference.py
@@ -0,0 +1,113 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .wavenet import WaveNet
+
+
+class ReferenceEncoder(WaveNet):
+ def __init__(
+ self,
+ input_channels: Optional[int] = None,
+ output_channels: Optional[int] = None,
+ residual_channels: int = 512,
+ residual_layers: int = 20,
+ dilation_cycle: Optional[int] = 4,
+ num_heads: int = 8,
+ latent_len: int = 4,
+ ):
+ super().__init__(
+ input_channels=input_channels,
+ residual_channels=residual_channels,
+ residual_layers=residual_layers,
+ dilation_cycle=dilation_cycle,
+ )
+
+ self.head_dim = residual_channels // num_heads
+ self.num_heads = num_heads
+
+ self.latent_len = latent_len
+ self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
+
+ self.q = nn.Linear(residual_channels, residual_channels, bias=True)
+ self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
+ self.q_norm = nn.LayerNorm(self.head_dim)
+ self.k_norm = nn.LayerNorm(self.head_dim)
+ self.proj = nn.Linear(residual_channels, residual_channels)
+ self.proj_drop = nn.Dropout(0.1)
+
+ self.norm = nn.LayerNorm(residual_channels)
+ self.mlp = nn.Sequential(
+ nn.Linear(residual_channels, residual_channels * 4),
+ nn.SiLU(),
+ nn.Linear(residual_channels * 4, residual_channels),
+ )
+ self.output_projection_attn = nn.Linear(residual_channels, output_channels)
+
+ torch.nn.init.trunc_normal_(self.latent, std=0.02)
+ self.apply(self.init_weights)
+
+ def init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ torch.nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, attn_mask=None):
+ x = super().forward(x).mT
+ B, N, C = x.shape
+
+ # Calculate mask
+ if attn_mask is not None:
+ assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
+
+ attn_mask = attn_mask[:, None, None, :].expand(
+ B, self.num_heads, self.latent_len, N
+ )
+
+ q_latent = self.latent.expand(B, -1, -1)
+ q = (
+ self.q(q_latent)
+ .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv = (
+ self.kv(x)
+ .reshape(B, N, 2, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ k, v = kv.unbind(0)
+
+ q, k = self.q_norm(q), self.k_norm(k)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+
+ x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ x = x + self.mlp(self.norm(x))
+ x = self.output_projection_attn(x)
+ x = x.mean(1)
+
+ return x
+
+
+if __name__ == "__main__":
+ with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+ model = ReferenceEncoder(
+ input_channels=128,
+ output_channels=64,
+ residual_channels=384,
+ residual_layers=20,
+ dilation_cycle=4,
+ num_heads=8,
+ )
+ x = torch.randn(4, 128, 64)
+ mask = torch.ones(4, 64, dtype=torch.bool)
+ y = model(x, mask)
+ print(y.shape)
+ loss = F.mse_loss(y, torch.randn(4, 64))
+ loss.backward()
diff --git a/fish_speech/models/vqgan/modules/wavenet.py b/fish_speech/models/vqgan/modules/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7cc011c3e071067ff36e1aba12c05cff81d94f6
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/wavenet.py
@@ -0,0 +1,225 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Mish(nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+class DiffusionEmbedding(nn.Module):
+ """Diffusion Step Embedding"""
+
+ def __init__(self, d_denoiser):
+ super(DiffusionEmbedding, self).__init__()
+ self.dim = d_denoiser
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class LinearNorm(nn.Module):
+ """LinearNorm Projection"""
+
+ def __init__(self, in_features, out_features, bias=False):
+ super(LinearNorm, self).__init__()
+ self.linear = nn.Linear(in_features, out_features, bias)
+
+ nn.init.xavier_uniform_(self.linear.weight)
+ if bias:
+ nn.init.constant_(self.linear.bias, 0.0)
+
+ def forward(self, x):
+ x = self.linear(x)
+ return x
+
+
+class ConvNorm(nn.Module):
+ """1D Convolution"""
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=None,
+ dilation=1,
+ bias=True,
+ w_init_gain="linear",
+ ):
+ super(ConvNorm, self).__init__()
+
+ if padding is None:
+ assert kernel_size % 2 == 1
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+ nn.init.kaiming_normal_(self.conv.weight)
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+
+ return conv_signal
+
+
+class ResidualBlock(nn.Module):
+ """Residual Block"""
+
+ def __init__(
+ self,
+ residual_channels,
+ use_linear_bias=False,
+ dilation=1,
+ condition_channels=None,
+ ):
+ super(ResidualBlock, self).__init__()
+ self.conv_layer = ConvNorm(
+ residual_channels,
+ 2 * residual_channels,
+ kernel_size=3,
+ stride=1,
+ padding=dilation,
+ dilation=dilation,
+ )
+
+ if condition_channels is not None:
+ self.diffusion_projection = LinearNorm(
+ residual_channels, residual_channels, use_linear_bias
+ )
+ self.condition_projection = ConvNorm(
+ condition_channels, 2 * residual_channels, kernel_size=1
+ )
+
+ self.output_projection = ConvNorm(
+ residual_channels, 2 * residual_channels, kernel_size=1
+ )
+
+ def forward(self, x, condition=None, diffusion_step=None):
+ y = x
+
+ if diffusion_step is not None:
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+ y = y + diffusion_step
+
+ y = self.conv_layer(y)
+
+ if condition is not None:
+ condition = self.condition_projection(condition)
+ y = y + condition
+
+ gate, filter = torch.chunk(y, 2, dim=1)
+ y = torch.sigmoid(gate) * torch.tanh(filter)
+
+ y = self.output_projection(y)
+ residual, skip = torch.chunk(y, 2, dim=1)
+
+ return (x + residual) / math.sqrt(2.0), skip
+
+
+class WaveNet(nn.Module):
+ def __init__(
+ self,
+ input_channels: Optional[int] = None,
+ output_channels: Optional[int] = None,
+ residual_channels: int = 512,
+ residual_layers: int = 20,
+ dilation_cycle: Optional[int] = 4,
+ is_diffusion: bool = False,
+ condition_channels: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # Input projection
+ self.input_projection = None
+ if input_channels is not None and input_channels != residual_channels:
+ self.input_projection = ConvNorm(
+ input_channels, residual_channels, kernel_size=1
+ )
+
+ if input_channels is None:
+ input_channels = residual_channels
+
+ self.input_channels = input_channels
+
+ # Residual layers
+ self.residual_layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ residual_channels=residual_channels,
+ use_linear_bias=False,
+ dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+ condition_channels=condition_channels,
+ )
+ for i in range(residual_layers)
+ ]
+ )
+
+ # Skip projection
+ self.skip_projection = ConvNorm(
+ residual_channels, residual_channels, kernel_size=1
+ )
+
+ # Output projection
+ self.output_projection = None
+ if output_channels is not None and output_channels != residual_channels:
+ self.output_projection = ConvNorm(
+ residual_channels, output_channels, kernel_size=1
+ )
+
+ if is_diffusion:
+ self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+ self.mlp = nn.Sequential(
+ LinearNorm(residual_channels, residual_channels * 4, False),
+ Mish(),
+ LinearNorm(residual_channels * 4, residual_channels, False),
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if getattr(m, "bias", None) is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, t=None, condition=None):
+ if self.input_projection is not None:
+ x = self.input_projection(x)
+ x = F.silu(x)
+
+ if t is not None:
+ t = self.diffusion_embedding(t)
+ t = self.mlp(t)
+
+ skip = []
+ for layer in self.residual_layers:
+ x, skip_connection = layer(x, condition, t)
+ skip.append(skip_connection)
+
+ x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+ x = self.skip_projection(x)
+
+ if self.output_projection is not None:
+ x = F.silu(x)
+ x = self.output_projection(x)
+
+ return x
diff --git a/fish_speech/models/vqgan/spectrogram.py b/fish_speech/models/vqgan/spectrogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841
--- /dev/null
+++ b/fish_speech/models/vqgan/spectrogram.py
@@ -0,0 +1,122 @@
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ center=False,
+ mode="pow2_sqrt",
+ ):
+ super().__init__()
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.mode = mode
+
+ self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+ def forward(self, y: Tensor) -> Tensor:
+ if y.ndim == 3:
+ y = y.squeeze(1)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ (self.win_length - self.hop_length) // 2,
+ (self.win_length - self.hop_length + 1) // 2,
+ ),
+ mode="reflect",
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+
+ spec = torch.view_as_real(spec)
+
+ if self.mode == "pow2_sqrt":
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ return spec
+
+
+class LogMelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ sample_rate=44100,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ n_mels=128,
+ center=False,
+ f_min=0.0,
+ f_max=None,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max or float(sample_rate // 2)
+
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+ fb = F.melscale_fbanks(
+ n_freqs=self.n_fft // 2 + 1,
+ f_min=self.f_min,
+ f_max=self.f_max,
+ n_mels=self.n_mels,
+ sample_rate=self.sample_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ )
+ self.register_buffer(
+ "fb",
+ fb,
+ persistent=False,
+ )
+
+ def compress(self, x: Tensor) -> Tensor:
+ return torch.log(torch.clamp(x, min=1e-5))
+
+ def decompress(self, x: Tensor) -> Tensor:
+ return torch.exp(x)
+
+ def apply_mel_scale(self, x: Tensor) -> Tensor:
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ def forward(
+ self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+ ) -> Tensor:
+ if sample_rate is not None and sample_rate != self.sample_rate:
+ x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+ linear = self.spectrogram(x)
+ x = self.apply_mel_scale(linear)
+ x = self.compress(x)
+
+ if return_linear:
+ return x, self.compress(linear)
+
+ return x
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac
--- /dev/null
+++ b/fish_speech/models/vqgan/utils.py
@@ -0,0 +1,94 @@
+import matplotlib
+import torch
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def plot_mel(data, titles=None):
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+ if titles is None:
+ titles = [None for i in range(len(data))]
+
+ plt.tight_layout()
+
+ for i in range(len(data)):
+ mel = data[i]
+
+ if isinstance(mel, torch.Tensor):
+ mel = mel.float().detach().cpu().numpy()
+
+ axes[i][0].imshow(mel, origin="lower")
+ axes[i][0].set_aspect(2.5, adjustable="box")
+ axes[i][0].set_ylim(0, mel.shape[0])
+ axes[i][0].set_title(titles[i], fontsize="medium")
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+ axes[i][0].set_anchor("W")
+
+ return fig
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
+ n_channels_int = n_channels[0]
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+
+ return acts
+
+
+def avg_with_mask(x, mask):
+ assert mask.dtype == torch.float, "Mask should be float"
+
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(1)
+
+ if mask.shape[1] == 1:
+ mask = mask.expand_as(x)
+
+ return (x * mask).sum() / mask.sum()
diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..08469aed3d283a64bfdc219173b28c262ad2398d
--- /dev/null
+++ b/fish_speech/scheduler.py
@@ -0,0 +1,22 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ final_lr_ratio: float = 0.0,
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ progress = float(current_step - num_warmup_steps) / float(
+ max(1, num_training_steps - num_warmup_steps)
+ )
+
+ return max(
+ final_lr_ratio,
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+ )
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad704d79a881c50d68b4a0fc1490104ee933dec1
--- /dev/null
+++ b/fish_speech/text/__init__.py
@@ -0,0 +1,3 @@
+from .clean import clean_text
+
+__all__ = ["clean_text"]
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2c8e72ce626fe1c02ff8c40493aa486a5afaf2a
--- /dev/null
+++ b/fish_speech/text/clean.py
@@ -0,0 +1,73 @@
+import itertools
+import re
+
+LANGUAGE_UNICODE_RANGE_MAP = {
+ "ZH": [(0x4E00, 0x9FFF)],
+ "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)],
+ "EN": [(0x0000, 0x007F)],
+}
+
+SYMBOLS_MAPPING = {
+ ":": ",",
+ ";": ",",
+ ",": ",",
+ "。": ".",
+ "!": "!",
+ "?": "?",
+ "\n": ".",
+ "·": ",",
+ "、": ",",
+ "...": "…",
+ "$": ".",
+ "“": "'",
+ "”": "'",
+ "‘": "'",
+ "’": "'",
+ "(": "'",
+ ")": "'",
+ "(": "'",
+ ")": "'",
+ "《": "'",
+ "》": "'",
+ "【": "'",
+ "】": "'",
+ "[": "'",
+ "]": "'",
+ "—": "-",
+ "~": "-",
+ "~": "-",
+ "・": "-",
+ "「": "'",
+ "」": "'",
+ ";": ",",
+ ":": ",",
+}
+
+REPLACE_SYMBOL_REGEX = re.compile(
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+ALL_KNOWN_UTF8_RANGE = list(
+ itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values())
+)
+REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
+ "[^"
+ + "".join(
+ f"{re.escape(chr(start))}-{re.escape(chr(end))}"
+ for start, end in ALL_KNOWN_UTF8_RANGE
+ )
+ + "]"
+)
+
+
+def clean_text(text):
+ # Clean the text
+ text = text.strip()
+ # Replace with
+ text = re.sub(r"", r"", text)
+ # Replace all chinese symbols with their english counterparts
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+ text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
+ # Replace with
+ text = re.sub(r"", r"", text)
+
+ return text
diff --git a/fish_speech/train.py b/fish_speech/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..528ef9c95b2e96b3415c1160e45edd63933e0d3c
--- /dev/null
+++ b/fish_speech/train.py
@@ -0,0 +1,135 @@
+import os
+from typing import Optional
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig, OmegaConf
+
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """ # noqa: E501
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=False)
+
+ if cfg.get("deterministic"):
+ torch.use_deterministic_algorithms(True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.trainer, callbacks=callbacks, logger=logger
+ )
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(object_dict)
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+
+ ckpt_path = cfg.get("ckpt_path")
+ auto_resume = False
+
+ resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+ if resume_ckpt_path is not None:
+ ckpt_path = resume_ckpt_path
+ auto_resume = True
+
+ if ckpt_path is not None:
+ log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+ # resume weights only is disabled for auto-resume
+ if cfg.get("resume_weights_only") and auto_resume is False:
+ log.info("Resuming weights only!")
+ ckpt = torch.load(ckpt_path, map_location=model.device)
+ if "state_dict" in ckpt:
+ ckpt = ckpt["state_dict"]
+ err = model.load_state_dict(ckpt, strict=False)
+ log.info(f"Error loading state dict: {err}")
+ ckpt_path = None
+
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = cfg.get("ckpt_path")
+
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
+@hydra.main(
+ version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+ # train the model
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e0106d7f479b6f3b4ba19117c010bd87f39c2b
--- /dev/null
+++ b/fish_speech/utils/__init__.py
@@ -0,0 +1,21 @@
+from .braceexpand import braceexpand
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, task_wrapper
+
+__all__ = [
+ "enforce_tags",
+ "extras",
+ "get_metric_value",
+ "RankedLogger",
+ "instantiate_callbacks",
+ "instantiate_loggers",
+ "log_hyperparameters",
+ "print_config_tree",
+ "task_wrapper",
+ "braceexpand",
+ "get_latest_checkpoint",
+]
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974
--- /dev/null
+++ b/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+ pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+ """braceexpand(pattern) -> iterator over generated strings
+
+ Returns an iterator over the strings resulting from brace expansion
+ of pattern. This function implements Brace Expansion as described in
+ bash(1), with the following limitations:
+
+ * A pattern containing unbalanced braces will raise an
+ UnbalancedBracesError exception. In bash, unbalanced braces will either
+ be partly expanded or ignored.
+
+ * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+ include the characters '[]^_`' between 'Z' and 'a'.
+
+ When escape is True (the default), characters in pattern can be
+ prefixed with a backslash to cause them not to be interpreted as
+ special characters for brace expansion (such as '{', '}', ',').
+ To pass through a a literal backslash, double it ('\\\\').
+
+ When escape is False, backslashes in pattern have no special
+ meaning and will be preserved in the output.
+
+ Examples:
+
+ >>> from braceexpand import braceexpand
+
+ # Integer range
+ >>> list(braceexpand('item{1..3}'))
+ ['item1', 'item2', 'item3']
+
+ # Character range
+ >>> list(braceexpand('{a..c}'))
+ ['a', 'b', 'c']
+
+ # Sequence
+ >>> list(braceexpand('index.html{,.backup}'))
+ ['index.html', 'index.html.backup']
+
+ # Nested patterns
+ >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+ ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+ # Prefixing an integer with zero causes all numbers to be padded to
+ # the same width.
+ >>> list(braceexpand('{07..10}'))
+ ['07', '08', '09', '10']
+
+ # An optional increment can be specified for ranges.
+ >>> list(braceexpand('{a..g..2}'))
+ ['a', 'c', 'e', 'g']
+
+ # Ranges can go in both directions.
+ >>> list(braceexpand('{4..1}'))
+ ['4', '3', '2', '1']
+
+ # Numbers can be negative
+ >>> list(braceexpand('{2..-1}'))
+ ['2', '1', '0', '-1']
+
+ # Unbalanced braces raise an exception.
+ >>> list(braceexpand('{1{2,3}'))
+ Traceback (most recent call last):
+ ...
+ UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+ # By default, the backslash is the escape character.
+ >>> list(braceexpand(r'{1\\{2,3}'))
+ ['1{2', '3']
+
+ # Setting 'escape' to False disables backslash escaping.
+ >>> list(braceexpand(r'\\{1,2}', escape=False))
+ ['\\\\1', '\\\\2']
+
+ """
+ return (
+ escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+ )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'pattern:', pattern
+ while pos < len(pattern):
+ if escape and pattern[pos] == "\\":
+ pos += 2
+ continue
+ elif pattern[pos] == "{":
+ if bracketdepth == 0 and pos > start:
+ # print 'literal:', pattern[start:pos]
+ items.append([pattern[start:pos]])
+ start = pos
+ bracketdepth += 1
+ elif pattern[pos] == "}":
+ bracketdepth -= 1
+ if bracketdepth == 0:
+ # print 'expression:', pattern[start+1:pos]
+ expr = pattern[start + 1 : pos]
+ item = parse_expression(expr, escape)
+ if item is None: # not a range or sequence
+ items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+ else:
+ items.append(item)
+ start = pos + 1 # skip the closing brace
+ pos += 1
+
+ if bracketdepth != 0: # unbalanced braces
+ raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+ if start < pos:
+ items.append([pattern[start:]])
+
+ return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+ int_range_match = int_range_re.match(expr)
+ if int_range_match:
+ return make_int_range(*int_range_match.groups())
+
+ char_range_match = char_range_re.match(expr)
+ if char_range_match:
+ return make_char_range(*char_range_match.groups())
+
+ return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+ # sequence -> chain(*sequence_items)
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'sequence:', seq
+ while pos < len(seq):
+ if escape and seq[pos] == "\\":
+ pos += 2
+ continue
+ elif seq[pos] == "{":
+ bracketdepth += 1
+ elif seq[pos] == "}":
+ bracketdepth -= 1
+ elif seq[pos] == "," and bracketdepth == 0:
+ items.append(parse_pattern(seq[start:pos], escape))
+ start = pos + 1 # skip the comma
+ pos += 1
+
+ if bracketdepth != 0:
+ raise UnbalancedBracesError
+ if not items:
+ return None
+
+ # part after the last comma (may be the empty string)
+ items.append(parse_pattern(seq[start:], escape))
+ return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+ if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+ padding = max(len(left), len(right))
+ else:
+ padding = 0
+ step = (int(incr) or 1) if incr else 1
+ start = int(left)
+ end = int(right)
+ r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+ fmt = "%0{}d".format(padding)
+ return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+ step = (int(incr) or 1) if incr else 1
+ start = alphabet.index(left)
+ end = alphabet.index(right)
+ if start < end:
+ return alphabet[start : end + 1 : step]
+ else:
+ end = end or -len(alphabet)
+ return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+ import doctest
+ import sys
+
+ failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+ if failed:
+ sys.exit(1)
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..11ae5975158c94a58931bf9ea3b19f2d6d04c1cb
--- /dev/null
+++ b/fish_speech/utils/file.py
@@ -0,0 +1,119 @@
+import os
+from glob import glob
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+ ".mp3",
+ ".wav",
+ ".flac",
+ ".ogg",
+ ".m4a",
+ ".wma",
+ ".aac",
+ ".aiff",
+ ".aif",
+ ".aifc",
+}
+
+
+def list_files(
+ path: Union[Path, str],
+ extensions: set[str] = None,
+ recursive: bool = False,
+ sort: bool = True,
+) -> list[Path]:
+ """List files in a directory.
+
+ Args:
+ path (Path): Path to the directory.
+ extensions (set, optional): Extensions to filter. Defaults to None.
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
+ sort (bool, optional): Whether to sort the files. Defaults to True.
+
+ Returns:
+ list: List of files.
+ """
+
+ if isinstance(path, str):
+ path = Path(path)
+
+ if not path.exists():
+ raise FileNotFoundError(f"Directory {path} does not exist.")
+
+ files = [file for ext in extensions for file in path.iglob(f"**/*{ext}")]
+
+ if sort:
+ files = natsorted(files)
+
+ return files
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+ # Find the latest checkpoint
+ ckpt_dir = Path(path)
+
+ if ckpt_dir.exists() is False:
+ return None
+
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+ if len(ckpts) == 0:
+ return None
+
+ return ckpts[-1]
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+ """
+ Load a Bert-VITS2 style filelist.
+ """
+
+ files = set()
+ results = []
+ count_duplicated, count_not_found = 0, 0
+
+ LANGUAGE_TO_LANGUAGES = {
+ "zh": ["zh", "en"],
+ "jp": ["jp", "en"],
+ "en": ["en"],
+ }
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ splits = line.strip().split("|", maxsplit=3)
+ if len(splits) != 4:
+ logger.warning(f"Invalid line: {line}")
+ continue
+
+ filename, speaker, language, text = splits
+ file = Path(filename)
+ language = language.strip().lower()
+
+ if language == "ja":
+ language = "jp"
+
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+ languages = LANGUAGE_TO_LANGUAGES[language]
+
+ if file in files:
+ logger.warning(f"Duplicated file: {file}")
+ count_duplicated += 1
+ continue
+
+ if not file.exists():
+ logger.warning(f"File not found: {file}")
+ count_not_found += 1
+ continue
+
+ results.append((file, speaker, languages, text))
+
+ if count_duplicated > 0:
+ logger.warning(f"Total duplicated files: {count_duplicated}")
+
+ if count_not_found > 0:
+ logger.warning(f"Total files not found: {count_not_found}")
+
+ return results
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e
--- /dev/null
+++ b/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc
--- /dev/null
+++ b/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = True,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+ :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(
+ self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+ ) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError(
+ "The `rank_zero_only.rank` needs to be set before use"
+ )
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429
--- /dev/null
+++ b/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f77c72c270632779e7e9762919b615dcf3594e
--- /dev/null
+++ b/fish_speech/utils/rich_utils.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """ # noqa: E501
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ queue.append(field) if field in cfg else log.warning(
+ f"Field '{field}' not found in config. "
+ + f"Skipping '{field}' config printing..."
+ )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c546bfa1eddd2ac6bf484cce1ec06da1d33fb121
--- /dev/null
+++ b/fish_speech/utils/utils.py
@@ -0,0 +1,114 @@
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+ ...
+
+ return metric_dict, object_dict
+ ```
+ """ # noqa: E501
+
+ def wrap(cfg: DictConfig):
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or
+ # cause out-of-memory errors so when using hparam search
+ # plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.run_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3d15867494d7267d6019817bb9d578590d384aaf
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,10 @@
+git
+curl
+build-essential
+ffmpeg
+libsm6
+libxext6
+libjpeg-dev
+zlib1g-dev
+protobuf-compiler
+cmake
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad1493530f7f6d8fa476dbe0b76e6239fce2d7e7
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,6 @@
+{
+ "exclude": [
+ "data",
+ "filelists"
+ ]
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1dc0981130b5b890c884679c625ac3a7e300289b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+torch
+torchaudio
+transformers>=4.35.2
+datasets>=2.14.5
+lightning>=2.1.0
+hydra-core>=1.3.2
+tensorboard>=2.14.1
+natsort>=8.4.0
+einops>=0.7.0
+librosa>=0.10.1
+rich>=13.5.3
+gradio>=4.0.0
+wandb>=0.15.11
+grpcio>=1.58.0
+kui>=1.6.0
+zibai-server>=0.9.0
+loguru>=0.6.0
+loralib>=0.1.2
+natsort>=8.4.0
+pyrootutils>=1.0.4
+vector_quantize_pytorch>=1.14.7
+samplerate>=0.2.1
+resampy>=0.4.3
+spaces>=0.26.1"
diff --git a/setup.sh b/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c5031db6557d33a38c68909f753e06ddc3ccb001
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+set -e
+
+mkdir -p checkpoints
+
+if [ -e checkpoints/text2semantic-medium-v1-2k.pth ]; then
+ echo "checkpoints/text2semantic-medium-v1-2k.pth already exists"
+else
+ echo "Downloading text2semantic-medium-v1-2k.pth"
+ wget -O checkpoints/text2semantic-medium-v1-2k.pth $CKPT_SEMANTIC
+fi
+
+if [ -e checkpoints/vq-gan-group-fsq-2x1024.pth ]; then
+ echo "checkpoints/vq-gan-group-fsq-2x1024.pth already exists"
+else
+ echo "Downloading vq-gan-group-fsq-2x1024.pth"
+ wget -O checkpoints/vq-gan-group-fsq-2x1024.pth $CKPT_VQGAN
+fi
diff --git a/tools/extract_model.py b/tools/extract_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..97fe62507b7282890319d8dc1eaa3cbca0e1f60a
--- /dev/null
+++ b/tools/extract_model.py
@@ -0,0 +1,21 @@
+import click
+import torch
+from loguru import logger
+
+
+@click.command()
+@click.argument("model_path")
+@click.argument("output_path")
+def main(model_path, output_path):
+ if model_path == output_path:
+ logger.error("Model path and output path are the same")
+ return
+
+ logger.info(f"Loading model from {model_path}")
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ torch.save(state_dict, output_path)
+ logger.info(f"Model saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2122f673aa007dc86b0432eeb5228360f5b067cd
--- /dev/null
+++ b/tools/llama/build_dataset.py
@@ -0,0 +1,165 @@
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from fish_speech.utils.file import load_filelist
+
+# To avoid CPU overload
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+ files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+ files = sorted(files)
+
+ grouped_files = defaultdict(list)
+ for file in tqdm(files, desc=f"Grouping {root}"):
+ p = str(file.parent)
+
+ try:
+ if isinstance(text_extension, str):
+ texts = [file.with_suffix(text_extension).read_text()]
+ else:
+ texts = [file.with_suffix(ext).read_text() for ext in text_extension]
+ except Exception as e:
+ logger.error(f"Failed to read text {file}: {e}")
+ continue
+
+ grouped_files[p].append((file, texts))
+
+ logger.info(
+ f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+ )
+ for name, subset in grouped_files.items():
+ yield name, subset, "folder"
+
+
+def task_generator_filelist(filelist):
+ grouped_files = defaultdict(list)
+ for filename, speaker, _, text in load_filelist(filelist):
+ grouped_files[speaker].append((Path(filename), [text]))
+
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+ for speaker, values in grouped_files.items():
+ yield speaker, values, "filelist"
+
+
+def run_task(task):
+ name, subset, source = task
+
+ # Parse the files
+ sentences = []
+ for file in subset:
+ file, texts = file
+
+ np_file = file.with_suffix(".npy")
+ if np_file.exists() is False:
+ logger.warning(f"Can't find {np_file}")
+ continue
+
+ new_texts = []
+
+ for text in texts:
+ # Simple cleaning: replace { xxx } and < xxx > with space
+ text = re.sub(r"\{.*?\}", " ", text)
+ text = re.sub(r"<.*?>", " ", text)
+ text = re.sub(r"\s+", " ", text)
+ new_texts.append(text)
+
+ try:
+ semantics = np.load(np_file)
+ except Exception as e:
+ logger.error(f"Failed to parse {file}: {e}")
+ continue
+
+ if isinstance(semantics, np.ndarray):
+ semantics = semantics.tolist()
+
+ sentences.append(
+ Sentence(
+ texts=new_texts,
+ semantics=[Semantics(values=s) for s in semantics],
+ )
+ )
+
+ # Pack the sentences
+ return pack_pb_stream(
+ TextData(
+ source=source,
+ name=name,
+ sentences=sentences,
+ )
+ )
+
+
+@click.command()
+@click.option(
+ "--input",
+ type=click.Path(path_type=Path),
+ required=True,
+ help="A folder containing the dataset or a filelist",
+ multiple=True,
+)
+@click.option(
+ "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+ "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+ generator_fns = []
+
+ for f in input:
+ assert f.exists(), f"{f} not found"
+
+ if f.is_dir():
+ generator_fn = task_generator_folder(f, text_extension)
+ else:
+ generator_fn = task_generator_filelist(f)
+
+ generator_fns.append(generator_fn)
+
+ generator_fn = itertools.chain(*generator_fns)
+ output.mkdir(parents=True, exist_ok=True)
+
+ dataset_fp = None
+ tar_idx = 0
+ written_size = 0
+
+ with Pool(num_workers) as p:
+ for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+ if dataset_fp is None:
+ dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+ dataset_fp.write(result)
+ written_size += len(result)
+
+ if written_size > shard_size * 1024 * 1024:
+ logger.info(f"Finished writing {tar_idx} shards to {output}")
+ dataset_fp.close()
+ dataset_fp = None
+ written_size = 0
+ tar_idx += 1
+
+ if dataset_fp is not None:
+ dataset_fp.close()
+
+ logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/generate.py b/tools/llama/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a742551611aaa885310a3a26656054f33824ff
--- /dev/null
+++ b/tools/llama/generate.py
@@ -0,0 +1,674 @@
+import os
+import time
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import click
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
+from fish_speech.text.clean import clean_text
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+ # Experimental feature to reduce compilation times, will be on by default in future
+ torch._inductor.config.fx_graph_cache = True
+
+
+from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
+
+
+def multinomial_sample_one_no_sync(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: Optional[int] = None,
+ repetition_penalty: float = 1.0,
+):
+ if previous_tokens is not None and repetition_penalty != 1.0:
+ previous_tokens = previous_tokens.long()
+ score = torch.gather(logits, dim=0, index=previous_tokens)
+ score = torch.where(
+ score < 0, score * repetition_penalty, score / repetition_penalty
+ )
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
+
+ if top_p is not None and top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+ )
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
+ )
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+ logits = logits / max(temperature, 1e-5)
+
+ if top_k is not None:
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+ pivot = v.select(-1, -1).unsqueeze(-1)
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
+
+
+def sample(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ probs = logits_to_probs(
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+ )
+ idx_next = multinomial_sample_one_no_sync(probs)
+ return idx_next, probs
+
+
+def decode_one_token_ar(
+ model: DualARTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+ codebooks = [
+ sample(
+ x.logits,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs,
+ )[0]
+ ]
+ x = x.hidden_states
+
+ # Cleanup the cache
+ for layer in model.fast_layers:
+ layer.attention.kv_cache.k_cache.fill_(0)
+ layer.attention.kv_cache.v_cache.fill_(0)
+
+ for codebook_idx in range(model.config.num_codebooks):
+ input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+ logits = model.forward_generate_fast(x, input_pos)
+ a = sample(
+ logits,
+ previous_tokens=(
+ previous_tokens[codebook_idx + 1]
+ if previous_tokens is not None
+ else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ x = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ return torch.stack(codebooks, dim=0)
+
+
+def decode_one_token_naive(
+ model: NaiveTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+
+ codebooks = [
+ sample(
+ x.token_logits,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs,
+ )[0]
+ ]
+
+ for i in range(model.config.num_codebooks):
+ codebooks.append(
+ sample(
+ x.codebook_logits[:, :, i],
+ previous_tokens=(
+ previous_tokens[i + 1] if previous_tokens is not None else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ )
+
+ return torch.stack(codebooks, dim=0)
+
+
+def decode_n_tokens(
+ model: NaiveTransformer,
+ cur_token: torch.Tensor,
+ input_pos: torch.Tensor,
+ num_new_tokens: int,
+ eos_token_id: int = 2,
+ im_end_id: int = 4,
+ decode_one_token=decode_one_token_naive,
+ **sampling_kwargs,
+):
+ previous_tokens = torch.zeros(
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
+ dtype=torch.int,
+ device=cur_token.device,
+ )
+
+ for i in tqdm(range(num_new_tokens)):
+ # We need to get windowed repeat penalty
+ win_size = 16
+ if i < win_size:
+ window = previous_tokens[:, :win_size]
+ else:
+ window = previous_tokens[:, i - win_size : i]
+
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ ): # Actually better for Inductor to codegen attention here
+ next_token = decode_one_token(
+ model=model,
+ x=cur_token,
+ input_pos=input_pos,
+ previous_tokens=window,
+ **sampling_kwargs,
+ )
+
+ input_pos += 1
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+ previous_tokens[:, i : i + 1] = next_token.view(
+ model.config.num_codebooks + 1, -1
+ )
+
+ if (
+ cur_token[0, 0, -1] == eos_token_id
+ or cur_token[0, 0, -1] == im_end_id
+ or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
+ ):
+ break
+
+ return previous_tokens[:, : i + 1]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate(
+ *,
+ model: NaiveTransformer,
+ prompt: torch.Tensor,
+ max_new_tokens: int,
+ eos_token_id: int = 2,
+ im_end_id: int = 4,
+ decode_one_token=decode_one_token_naive,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ """
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+ """
+
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ T = prompt.size(1)
+
+ if max_new_tokens:
+ if T + max_new_tokens > model.config.max_seq_len:
+ max_new_tokens = model.config.max_seq_len - T
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+ T_new = T + max_new_tokens
+ else:
+ T_new = model.config.max_seq_len
+ max_new_tokens = T_new - T
+
+ device, dtype = prompt.device, prompt.dtype
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
+ )
+
+ codebook_dim = 1 + model.config.num_codebooks
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
+ empty[:, :T] = prompt
+ seq = empty
+ input_pos = torch.arange(0, T, device=device)
+
+ # Use non-accelerated version for now, to avoid compilation overhead
+ prefill_decode = (
+ decode_one_token_naive
+ if isinstance(model, NaiveTransformer)
+ else decode_one_token_ar
+ )
+ next_token = prefill_decode(
+ model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+ )
+ seq[:, T : T + 1] = next_token
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+ x = decode_n_tokens(
+ model,
+ next_token.view(1, codebook_dim, -1),
+ input_pos,
+ max_new_tokens - 1,
+ eos_token_id=eos_token_id,
+ im_end_id=im_end_id,
+ decode_one_token=decode_one_token,
+ **sampling_kwargs,
+ )
+ # x = torch.cat(generated_tokens, dim=1)
+ seq = seq[:, : T + 1 + x.size(1)]
+ seq[:, T + 1 :] = x
+
+ return seq
+
+
+def encode_tokens(
+ tokenizer,
+ string,
+ bos=True,
+ device="cuda",
+ prompt_tokens=None,
+ speaker=None,
+ num_codebooks=4,
+):
+ string = clean_text(string)
+
+ if speaker is not None:
+ string = f"[SPK: {speaker}] {string}"
+
+ string = (
+ f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
+ )
+ if bos:
+ string = f"<|begin_of_sequence|>{string}"
+
+ new_tokens = tokenizer.encode(
+ string,
+ add_special_tokens=False,
+ max_length=10**6,
+ truncation=False,
+ )
+ tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
+
+ # Codebooks
+ zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+ prompt = torch.cat((tokens, zeros), dim=0)
+
+ if prompt_tokens is None:
+ return prompt
+
+ # Get prompt tokens
+ if prompt_tokens.ndim == 3:
+ assert (
+ prompt_tokens.shape[0] == 1
+ ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
+ prompt_tokens = prompt_tokens[0]
+
+ assert prompt_tokens.ndim == 2
+ data = prompt_tokens + 2
+
+ if prompt_tokens.shape[0] > num_codebooks:
+ logger.warning(
+ f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+ )
+ data = data[:num_codebooks]
+
+ # Since 1.0, we use <|semantic|>
+ s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
+ main_token_ids = torch.tensor(
+ [[s0_token_id] * data.size(1)],
+ dtype=torch.int,
+ device=device,
+ )
+
+ data = torch.cat((main_token_ids, data), dim=0)
+ prompt = torch.cat((prompt, data), dim=1)
+
+ return prompt
+
+
+def load_model(
+ config_name, checkpoint_path, device, precision, max_length, compile=False
+):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
+ cfg = compose(
+ config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
+ )
+
+ model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
+
+ if "int8" in str(checkpoint_path):
+ logger.info("Using int8 weight-only quantization!")
+ from quantize import WeightOnlyInt8QuantHandler
+
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ model = simple_quantizer.convert_for_runtime()
+
+ if "int4" in str(checkpoint_path):
+ logger.info("Using int4 quantization!")
+ path_comps = checkpoint_path.name.split(".")
+ assert path_comps[-2].startswith("g")
+ groupsize = int(path_comps[-2][1:])
+ from quantize import WeightOnlyInt4QuantHandler
+
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ model = simple_quantizer.convert_for_runtime()
+
+ checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
+ if "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ if any(k.startswith("model.") for k in checkpoint):
+ checkpoint = {
+ k.replace("model.", ""): v
+ for k, v in checkpoint.items()
+ if k.startswith("model.")
+ }
+
+ model.load_state_dict(checkpoint, assign=True)
+
+ model = model.to(device=device, dtype=precision)
+ logger.info("Restored model from checkpoint")
+
+ if isinstance(model, DualARTransformer):
+ decode_one_token = decode_one_token_ar
+ logger.info("Using DualARTransformer")
+ else:
+ decode_one_token = decode_one_token_naive
+ logger.info("Using NaiveTransformer")
+
+ if compile:
+ logger.info("Compiling function...")
+ decode_one_token = torch.compile(
+ decode_one_token, mode="reduce-overhead", fullgraph=True
+ )
+
+ return model.eval(), decode_one_token
+
+
+def split_text(text, min_length):
+ text = clean_text(text)
+ segments = []
+ curr = ""
+ for char in text:
+ curr += char
+ if char not in [".", ",", "!", "?"]:
+ continue
+
+ if len(curr) >= min_length:
+ segments.append(curr)
+ curr = ""
+
+ if curr:
+ segments.append(curr)
+
+ return segments
+
+
+def generate_long(
+ *,
+ model,
+ tokenizer: callable,
+ device: str | torch.device,
+ decode_one_token: callable,
+ text: str,
+ num_samples: int = 1,
+ max_new_tokens: int = 0,
+ top_k: int = None,
+ top_p: int = 0.7,
+ repetition_penalty: float = 1.5,
+ temperature: float = 0.7,
+ compile: bool = False,
+ iterative_prompt: bool = True,
+ max_length: int = 2048,
+ chunk_length: int = 30,
+ speaker: Optional[str] = None,
+ prompt_text: Optional[str] = None,
+ prompt_tokens: Optional[torch.Tensor] = None,
+):
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+ use_prompt = prompt_text is not None and prompt_tokens is not None
+ encoded = []
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
+ for idx, text in enumerate(texts):
+ encoded.append(
+ encode_tokens(
+ tokenizer,
+ string=text,
+ bos=idx == 0 and not use_prompt,
+ device=device,
+ speaker=None,
+ num_codebooks=model.config.num_codebooks,
+ )
+ )
+ logger.info(f"Encoded text: {text}")
+
+ if use_prompt:
+ encoded_prompt = encode_tokens(
+ tokenizer,
+ prompt_text,
+ prompt_tokens=prompt_tokens,
+ bos=True,
+ device=device,
+ speaker=speaker,
+ num_codebooks=model.config.num_codebooks,
+ )
+
+ encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
+
+ for sample_idx in range(num_samples):
+ torch.cuda.synchronize()
+ global_encoded = []
+ all_codes = []
+ seg_idx = 0
+
+ while seg_idx < len(encoded):
+ logger.info(
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+ )
+
+ seg = encoded[seg_idx]
+ global_encoded.append(seg)
+
+ lengths = reversed([seg.size(1) for seg in global_encoded])
+
+ # Pick last 2000 tokens
+ count = 0
+ for i, length in enumerate(lengths):
+ count += length
+ if count + length > max_length - 1024:
+ break
+
+ if i != 0 and i % 2 == 0:
+ i -= 1
+
+ # Rotate the list, always make sure first segment is included to avoid drift
+ if i < len(global_encoded) - 2:
+ partial_encoded = global_encoded[:2] + global_encoded[-i:]
+ else:
+ partial_encoded = global_encoded
+
+ cat_encoded = torch.cat(partial_encoded, dim=1)
+ prompt_length = cat_encoded.size(1)
+
+ t0 = time.perf_counter()
+ y = generate(
+ model=model,
+ prompt=cat_encoded,
+ max_new_tokens=max_new_tokens,
+ eos_token_id=tokenizer.eos_token_id,
+ im_end_id=im_end_id,
+ decode_one_token=decode_one_token,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ )
+
+ if sample_idx == 0 and seg_idx == 0 and compile:
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+ torch.cuda.synchronize()
+ t = time.perf_counter() - t0
+
+ tokens_generated = y.size(1) - prompt_length
+ tokens_sec = tokens_generated / t
+ logger.info(
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+ )
+ logger.info(
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+ )
+ logger.info(
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+ )
+
+ # Put the generated tokens
+ # since there is and tokens, we remove last 2 tokens
+ codes = y[1:, prompt_length:-2].clone()
+
+ codes = codes - 2
+ if not (codes >= 0).all():
+ global_encoded.pop()
+ logger.warning(f"Negative code found: {codes}, retrying ...")
+ continue
+
+ decoded = y[:, prompt_length:-1].clone()
+ if decoded[0, -1] != im_end_id: #
+ val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
+ decoded = torch.cat(
+ (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
+ )
+
+ # But for global encoding, we should keep the token
+ global_encoded.append(decoded)
+ all_codes.append(codes)
+ seg_idx += 1
+
+ codes = torch.cat(all_codes, dim=1)
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
+
+ yield codes
+
+
+@click.command()
+@click.option(
+ "--text",
+ type=str,
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None)
+@click.option(
+ "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-k", type=int, default=None)
+@click.option("--top-p", type=float, default=0.7)
+@click.option("--repetition-penalty", type=float, default=1.5)
+@click.option("--temperature", type=float, default=0.7)
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="results/text2semantic_400m_finetune/step_000002000.pth",
+)
+@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
+@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--speaker", type=str, default=None)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--max-length", type=int, default=2048)
+@click.option("--chunk-length", type=int, default=30)
+def main(
+ text: str,
+ prompt_text: Optional[str],
+ prompt_tokens: Optional[Path],
+ num_samples: int,
+ max_new_tokens: int,
+ top_k: int,
+ top_p: int,
+ repetition_penalty: float,
+ temperature: float,
+ checkpoint_path: Path,
+ config_name: str,
+ tokenizer: str,
+ compile: bool,
+ seed: int,
+ speaker: Optional[str],
+ half: bool,
+ iterative_prompt: bool,
+ max_length: int,
+ chunk_length: int,
+) -> None:
+ device = "cuda"
+
+ precision = torch.half if half else torch.bfloat16
+
+ logger.info("Loading model ...")
+ t0 = time.time()
+ model, decode_one_token = load_model(
+ config_name, checkpoint_path, device, precision, max_length, compile=compile
+ )
+ torch.cuda.synchronize()
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+ prompt_tokens = (
+ torch.from_numpy(np.load(prompt_tokens)).to(device)
+ if prompt_tokens is not None
+ else None
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ generator = generate_long(
+ model=model,
+ device=device,
+ decode_one_token=decode_one_token,
+ text=text,
+ num_samples=num_samples,
+ max_new_tokens=max_new_tokens,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ tokenizer=tokenizer,
+ compile=compile,
+ speaker=speaker,
+ iterative_prompt=iterative_prompt,
+ max_length=max_length,
+ chunk_length=chunk_length,
+ prompt_text=prompt_text,
+ prompt_tokens=prompt_tokens,
+ )
+
+ for idx, codes in enumerate(generator):
+ np.save(f"codes_{idx}.npy", codes.cpu().numpy())
+ logger.info(f"Saved codes to codes_{idx}.npy")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..fadee00b2bcb1e2dabd0e37d2d93a7e71097071c
--- /dev/null
+++ b/tools/llama/quantize.py
@@ -0,0 +1,515 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import time
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
+
+##### Quantization Primitives ######
+
+
+def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+ # assumes symmetric quantization
+ # assumes axis == 0
+ # assumes dense memory format
+ # TODO(future): relax ^ as needed
+
+ # default setup for affine quantization of activations
+ eps = torch.finfo(torch.float32).eps
+
+ # get min and max
+ min_val, max_val = torch.aminmax(x, dim=1)
+
+ # calculate scales and zero_points based on min and max
+ # reference: https://fburl.com/code/srbiybme
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+ device = min_val_neg.device
+
+ # reference: https://fburl.com/code/4wll53rk
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
+ # ensure scales is the same dtype as the original tensor
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+ # quantize based on qmin/qmax/scales/zp
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
+ x_div = x / scales.unsqueeze(-1)
+ x_round = torch.round(x_div)
+ x_zp = x_round + zero_points.unsqueeze(-1)
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+ return quant, scales, zero_points
+
+
+def get_group_qparams(w, n_bit=4, groupsize=128):
+ # needed for GPTQ with padding
+ if groupsize > w.shape[-1]:
+ groupsize = w.shape[-1]
+ assert groupsize > 1
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ max_val = to_quant.amax(dim=1, keepdim=True)
+ min_val = to_quant.amin(dim=1, keepdim=True)
+ max_int = 2**n_bit - 1
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
+ zeros = min_val + scales * (2 ** (n_bit - 1))
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
+ torch.bfloat16
+ ).reshape(w.shape[0], -1)
+
+
+def pack_scales_and_zeros(scales, zeros):
+ assert scales.shape == zeros.shape
+ assert scales.dtype == torch.bfloat16
+ assert zeros.dtype == torch.bfloat16
+ return (
+ torch.cat(
+ [
+ scales.reshape(scales.size(0), scales.size(1), 1),
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
+ ],
+ 2,
+ )
+ .transpose(0, 1)
+ .contiguous()
+ )
+
+
+def unpack_scales_and_zeros(scales_and_zeros):
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+ assert scales_and_zeros.dtype == torch.float
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+ assert groupsize > 1
+ # needed for GPTQ single column quantize
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w.shape[-1]
+
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+ min_val = zeros - scales * (2 ** (n_bit - 1))
+ max_int = 2**n_bit - 1
+ min_int = 0
+ w_int32 = (
+ to_quant.sub(min_val)
+ .div(scales)
+ .round()
+ .clamp_(min_int, max_int)
+ .to(torch.int32)
+ .reshape_as(w)
+ )
+
+ return w_int32
+
+
+def group_quantize_tensor(w, n_bit=4, groupsize=128):
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+ return w_int32, scales_and_zeros
+
+
+def group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit=4, groupsize=128
+):
+ assert groupsize > 1
+ # needed for GPTQ single column dequantize
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w_int32.shape[-1]
+ assert w_int32.shape[-1] % groupsize == 0
+ assert w_int32.dim() == 2
+
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+
+ w_dq = (
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+ )
+ return w_dq
+
+
+def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+ return group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit, groupsize
+ )
+
+
+class QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ def create_quantized_state_dict(self) -> "StateDict":
+ pass
+
+ def convert_for_runtime(self) -> "nn.Module":
+ pass
+
+
+##### Weight-only int8 per-channel quantized code ######
+
+
+def replace_linear_weight_only_int8_per_channel(module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt8Linear(child.in_features, child.out_features),
+ )
+ else:
+ replace_linear_weight_only_int8_per_channel(child)
+
+
+class WeightOnlyInt8QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ int8_weight, scales, _ = dynamically_quantize_per_channel(
+ mod.weight.float(), -128, 127, torch.int8
+ )
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_weight_only_int8_per_channel(self.mod)
+ return self.mod
+
+
+class WeightOnlyInt8Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.register_buffer(
+ "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+ )
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+
+
+##### weight only int4 per channel groupwise quantized code ######
+
+
+def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
+ weight_int32, scales_and_zeros = group_quantize_tensor(
+ weight_bf16, n_bit=4, groupsize=groupsize
+ )
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+ weight_int32, inner_k_tiles
+ )
+ return weight_int4pack, scales_and_zeros
+
+
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+ origin_x_size = x.size()
+ x = x.reshape(-1, origin_x_size[-1])
+ c = torch.ops.aten._weight_int4pack_mm(
+ x, weight_int4pack, groupsize, scales_and_zeros
+ )
+ new_shape = origin_x_size[:-1] + (out_features,)
+ c = c.reshape(new_shape)
+ return c
+
+
+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=False,
+ ),
+ )
+ elif padding:
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=True,
+ ),
+ )
+ else:
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+
+
+class WeightOnlyInt4QuantHandler:
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+ self.mod = mod
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+ self.padding = padding
+ assert groupsize in [32, 64, 128, 256]
+ assert inner_k_tiles in [2, 4, 8]
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ assert not mod.bias
+ out_features = mod.out_features
+ in_features = mod.in_features
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+ weight = mod.weight.data
+ if not _check_linear_int4_k(
+ in_features, self.groupsize, self.inner_k_tiles
+ ):
+ if self.padding:
+ import torch.nn.functional as F
+
+ print(
+ f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+ )
+ padded_in_features = find_multiple(in_features, 1024)
+ weight = F.pad(
+ weight, pad=(0, padded_in_features - in_features)
+ )
+ else:
+ print(
+ f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+ )
+ continue
+ (
+ weight_int4pack,
+ scales_and_zeros,
+ ) = prepare_int4_weight_and_scales_and_zeros(
+ weight.to(torch.bfloat16).to("cuda"),
+ self.groupsize,
+ self.inner_k_tiles,
+ )
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+ return self.mod
+
+
+class WeightOnlyInt4Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias=True,
+ device=None,
+ dtype=None,
+ groupsize: int = 128,
+ inner_k_tiles: int = 8,
+ padding: bool = True,
+ ) -> None:
+ super().__init__()
+ self.padding = padding
+ if padding:
+ self.origin_in_features = in_features
+ in_features = find_multiple(in_features, 1024)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ assert not bias, "require bias=False"
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ assert (
+ in_features % (inner_k_tiles * 16) == 0
+ ), "require in_features % (innerKTiles * 16) == 0"
+ self.register_buffer(
+ "weight",
+ torch.empty(
+ (
+ out_features // 8,
+ in_features // (inner_k_tiles * 16),
+ 32,
+ inner_k_tiles // 2,
+ ),
+ dtype=torch.int32,
+ ),
+ )
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty(
+ (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
+ ),
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ input = input.to(torch.bfloat16)
+ if self.padding:
+ import torch.nn.functional as F
+
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+ return linear_forward_int4(
+ input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+ )
+
+
+def quantize(
+ checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
+ mode: str = "int8",
+ # following arguments only available when setting int4 quantization.
+ groupsize: int = 128,
+) -> None:
+ assert checkpoint_path.is_file(), checkpoint_path
+
+ device = "cpu"
+ precision = torch.bfloat16
+
+ print("Loading model ...")
+ t0 = time.time()
+
+ with torch.device("meta"):
+ model = Transformer(
+ ModelArgs(
+ max_seq_len=4096,
+ vocab_size=36408,
+ n_layer=24,
+ n_head=16,
+ dim=1024,
+ rope_base=10000,
+ norm_eps=1e-5,
+ num_codebooks=4, # single codebook
+ codebook_size=168, # codebook size 160 + 2 special tokens
+ )
+ )
+
+ checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+ if "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+ checkpoint = {
+ k.replace("model.", ""): v
+ for k, v in checkpoint.items()
+ if k.startswith("model.")
+ }
+ model.load_state_dict(checkpoint, assign=True)
+ model = model.to(dtype=precision, device=device)
+
+ if mode == "int8":
+ print(
+ "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
+ )
+ quant_handler = WeightOnlyInt8QuantHandler(model)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path.parent
+ base_name = checkpoint_path.stem
+ suffix = checkpoint_path.suffix
+ quantize_path = dir_name / f"{base_name}.int8{suffix}"
+
+ elif mode == "int4":
+ print(
+ "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
+ )
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path.parent
+ base_name = checkpoint_path.name
+ suffix = checkpoint_path.suffix
+ quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
+
+ else:
+ raise ValueError(
+ f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
+ )
+
+ print(f"Writing quantized weights to {quantize_path}")
+ quantize_path.unlink(missing_ok=True) # remove existing file if one already there
+ torch.save(quantized_state_dict, quantize_path)
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Quantize a model.")
+ parser.add_argument(
+ "--checkpoint_path",
+ type=Path,
+ default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
+ help="Path to the model checkpoint to be quantized.",
+ )
+ parser.add_argument(
+ "--mode",
+ "-q",
+ type=str,
+ default="int8",
+ choices=["int8", "int4"],
+ help="type of quantization to perform",
+ )
+ parser.add_argument(
+ "--groupsize", type=int, default=32, help="Group size for int4 quantization."
+ )
+
+ args = parser.parse_args()
+ quantize(args.checkpoint_path, args.mode, args.groupsize)
diff --git a/tools/llama/rebuild_tokenizer.py b/tools/llama/rebuild_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea64fa6788833000c8dc41e3d570dd5b250fb14b
--- /dev/null
+++ b/tools/llama/rebuild_tokenizer.py
@@ -0,0 +1,57 @@
+from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+# Initialize a tokenizer
+tokenizer = Tokenizer(models.BPE())
+
+# Customize pre-tokenization and decoding
+tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+tokenizer.decoder = decoders.ByteLevel()
+tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+# Don't train the tokenizer
+trainer = trainers.BpeTrainer(
+ vocab_size=0,
+ min_frequency=2,
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
+ special_tokens=[
+ "<|begin_of_sequence|>",
+ "<|end_of_sequence|>",
+ "<|im_start|>",
+ "<|im_sep|>", # system, user, assistant, etc.
+ "<|im_end|>",
+ "<|semantic|>", # audio features
+ "<|pad|>",
+ ],
+)
+
+# <|im_start|>user<|im_sep|>...<|im_end|>
+# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
+tokenizer.train_from_iterator([], trainer=trainer)
+
+print(len(tokenizer.get_vocab()))
+x = tokenizer.encode(
+ "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
+).ids
+print(x, len(x))
+print(tokenizer.decode(x, skip_special_tokens=True))
+
+
+tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=tokenizer,
+ pad_token="<|pad|>",
+ bos_token="<|begin_of_sequence|>",
+ eos_token="<|end_of_sequence|>",
+)
+
+# Try tokenizing a new sequence
+sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
+encoded = tokenizer(sequence).input_ids
+
+print("Test encoding....")
+print(f"\tSentence: {sequence}")
+print(f"\tEncoded: {encoded}")
+print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
+print(f"\tDecoded: {tokenizer.decode(encoded)}")
+
+tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
diff --git a/tools/merge_asr_files.py b/tools/merge_asr_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86d29a7a220aafc92cf8cf5ea9689f027b2287c
--- /dev/null
+++ b/tools/merge_asr_files.py
@@ -0,0 +1,55 @@
+import os
+from pathlib import Path
+
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
+def merge_and_delete_files(save_dir, original_files):
+ save_path = Path(save_dir)
+ audio_slice_files = list_files(
+ path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
+ )
+ audio_files = {}
+ label_files = {}
+ for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
+ rel_path = Path(file_path).relative_to(save_path)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+ if file_path.suffix == ".wav":
+ prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+ if prefix == rel_path.parent / file_path.stem:
+ continue
+ audio = AudioSegment.from_wav(file_path)
+ if prefix in audio_files.keys():
+ audio_files[prefix] = audio_files[prefix] + audio
+ else:
+ audio_files[prefix] = audio
+
+ elif file_path.suffix == ".lab":
+ prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+ if prefix == rel_path.parent / file_path.stem:
+ continue
+ with open(file_path, "r", encoding="utf-8") as f:
+ label = f.read()
+ if prefix in label_files.keys():
+ label_files[prefix] = label_files[prefix] + ", " + label
+ else:
+ label_files[prefix] = label
+
+ for prefix, audio in audio_files.items():
+ output_audio_path = save_path / f"{prefix}.wav"
+ audio.export(output_audio_path, format="wav")
+
+ for prefix, label in label_files.items():
+ output_label_path = save_path / f"{prefix}.lab"
+ with open(output_label_path, "w", encoding="utf-8") as f:
+ f.write(label)
+
+ for file_path in original_files:
+ os.remove(file_path)
+
+
+if __name__ == "__main__":
+ merge_and_delete_files("/made/by/spicysama/laziman", [__file__])
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..73a079ed931551407da3e479e63857843502c687
--- /dev/null
+++ b/tools/vqgan/create_train_split.py
@@ -0,0 +1,54 @@
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+@click.option("--val-ratio", type=float, default=None)
+@click.option("--val-count", type=int, default=None)
+@click.option("--filelist", default=None, type=Path)
+def main(root, val_ratio, val_count, filelist):
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+ logger.info(f"Found {len(files)} files")
+ files = [str(file.relative_to(root)) for file in tqdm(files)]
+
+ Random(42).shuffle(files)
+
+ if val_count is None and val_ratio is None:
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
+ val_size = min(100, math.ceil(len(files) * 0.2))
+ elif val_count is not None and val_ratio is not None:
+ logger.error("Cannot specify both val_count and val_ratio")
+ return
+ elif val_count is not None:
+ if val_count < 1 or val_count > len(files):
+ logger.error("val_count must be between 1 and number of files")
+ return
+ val_size = val_count
+ else:
+ val_size = math.ceil(len(files) * val_ratio)
+
+ logger.info(f"Using {val_size} files for validation")
+
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(files[val_size:]))
+
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(files[:val_size]))
+
+ logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8771faa533ee878a3453f3dfb87472ccc769e3e
--- /dev/null
+++ b/tools/vqgan/extract_vq.py
@@ -0,0 +1,213 @@
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file is used to convert the audio files to text files using the Whisper model.
+# It's mainly used to generate the training data for the VQ model.
+
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+ "{level: <8} | "
+ "{name}:{function}:{line} | "
+ "{extra[rank]} - {message}"
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model(
+ config_name: str = "vqgan_pretrain",
+ checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
+):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model: LightningModule = instantiate(cfg.model)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=model.device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.cuda()
+
+ logger.info(f"Loaded model")
+ return model
+
+
+@torch.inference_mode()
+def process_batch(files: list[Path], model) -> float:
+ wavs = []
+ audio_lengths = []
+ new_files = []
+ max_length = total_time = 0
+
+ for file in files:
+ try:
+ wav, sr = torchaudio.load(
+ str(file), backend="sox"
+ ) # Need to install libsox-dev
+ except Exception as e:
+ logger.error(f"Error reading {file}: {e}")
+ continue
+
+ if wav.shape[0] > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
+ total_time += len(wav) / model.sampling_rate
+ max_length = max(max_length, len(wav))
+
+ wavs.append(wav)
+ audio_lengths.append(len(wav))
+ new_files.append(file)
+
+ files = new_files
+
+ # Pad to max length
+ for i, wav in enumerate(wavs):
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+ audios = torch.stack(wavs, dim=0)[:, None]
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+ # Calculate lengths
+ indices, feature_lengths = model.encode(audios, audio_lengths)
+
+ # Save to disk
+ outputs = indices.cpu().numpy()
+
+ for file, length, feature, audio_length in zip(
+ files, feature_lengths, outputs, audio_lengths
+ ):
+ feature = feature[:, :length]
+
+ # (T,)
+ with open(file.with_suffix(".npy"), "wb") as f:
+ np.save(f, feature)
+
+ return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+@click.option("--config-name", default="vqgan_pretrain")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
+)
+@click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
+def main(
+ folder: str,
+ num_workers: int,
+ config_name: str,
+ checkpoint_path: str,
+ batch_size: int,
+ filelist: Path,
+):
+ if num_workers > 1 and WORLD_SIZE != num_workers:
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+ logger.info(f"Spawning {num_workers} workers")
+
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is None:
+ visible_devices = list(range(torch.cuda.device_count()))
+ else:
+ visible_devices = visible_devices.split(",")
+
+ processes = []
+ for i in range(num_workers):
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+ env["SLURM_PROCID"] = str(i)
+ env["SLURM_NTASKS"] = str(num_workers)
+
+ processes.append(
+ sp.Popen(
+ [sys.executable] + sys.argv.copy(),
+ env=env,
+ )
+ )
+
+ for p in processes:
+ p.wait()
+
+ logger.info(f"All workers finished")
+ return
+
+ # This is a worker
+ logger.info(f"Starting worker")
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+ print(f"Found {len(files)} files")
+ # files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+ total_files = len(files)
+ files = files[RANK::WORLD_SIZE]
+ logger.info(f"Processing {len(files)}/{total_files} files")
+
+ # Batch processing
+ total_time = 0
+ begin_time = time.time()
+ processed_files = 0
+ model = get_model(config_name, checkpoint_path)
+
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+ batch = files[idx : idx + batch_size]
+ batch_time = process_batch(batch, model)
+
+ total_time += batch_time
+ processed_files += len(batch)
+
+ if (n_batch + 1) % 10 == 0:
+ eta = (
+ (time.time() - begin_time)
+ / processed_files
+ * (len(files) - processed_files)
+ )
+ logger.info(
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+ + f"ETA: {timedelta(seconds=round(eta))}s"
+ )
+
+ logger.info(
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25c00913e6248f2f6138bca3d2bed8b571084cd
--- /dev/null
+++ b/tools/vqgan/inference.py
@@ -0,0 +1,115 @@
+from pathlib import Path
+
+import click
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model: LightningModule = instantiate(cfg.model)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=model.device,
+ )
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+ logger.info("Restored model from checkpoint")
+
+ return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+ "--input-path",
+ "-i",
+ default="test.wav",
+ type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", "-cfg", default="vqgan_pretrain")
+@click.option(
+ "--checkpoint-path",
+ "-ckpt",
+ default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+)
+@click.option(
+ "--device",
+ "-d",
+ default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+ model = load_model(config_name, checkpoint_path, device=device)
+
+ if input_path.suffix in AUDIO_EXTENSIONS:
+ logger.info(f"Processing in-place reconstruction of {input_path}")
+ # Load audio
+ audio, _ = librosa.load(
+ input_path,
+ sr=model.sampling_rate,
+ mono=True,
+ )
+ audios = torch.from_numpy(audio).to(model.device)[None, None, :]
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor(
+ [audios.shape[2]], device=model.device, dtype=torch.long
+ )
+ indices = model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Generated indices of shape {indices.shape}")
+
+ # Save indices
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+ elif input_path.suffix == ".npy":
+ logger.info(f"Processing precomputed indices from {input_path}")
+ indices = np.load(input_path)
+ indices = torch.from_numpy(indices).to(model.device).long()
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+ else:
+ raise ValueError(f"Unknown input type: {input_path}")
+
+ # Restore
+ feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
+ fake_audios = model.decode(
+ indices=indices[None], feature_lengths=feature_lengths, return_audios=True
+ )
+ audio_time = fake_audios.shape[-1] / model.sampling_rate
+
+ logger.info(
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+ )
+
+ # Save audio
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
+ sf.write(output_path, fake_audio, model.sampling_rate)
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d303607b3fb3af971af977465806a68687e343e2
--- /dev/null
+++ b/tools/whisper_asr.py
@@ -0,0 +1,113 @@
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+from pathlib import Path
+
+import click
+import librosa
+import soundfile as sf
+import whisper
+from loguru import logger
+from merge_asr_files import merge_and_delete_files
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.option("--model-size", default="large", help="Size of the Whisper model")
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option(
+ "--sample-rate",
+ default=None,
+ type=int,
+ help="Output sample rate, default to input sample rate",
+)
+@click.option("--device", default="cuda", help="Device to use")
+@click.option("--language", default="ZH", help="Language of the transcription")
+def main(model_size, audio_dir, save_dir, sample_rate, device, language):
+ logger.info("Loading / Downloading OpenAI Whisper model...")
+ model = whisper.load_model(
+ name=model_size,
+ device=device,
+ download_root=str(Path(".cache/whisper").resolve()),
+ )
+ logger.info("Model loaded.")
+
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+ original_files = []
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and (
+ save_path / rel_path.parent / f"{rel_path.stem}.lab"
+ ).exists():
+ continue
+
+ audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
+ transcription = model.transcribe(str(file_path), language=language)
+
+ for segment in transcription.get("segments", []):
+ id, text, start, end = (
+ segment["id"],
+ segment["text"],
+ segment["start"],
+ segment["end"],
+ )
+
+ extract = audio[..., int(start * sr) : int(end * sr)]
+ audio_save_path = (
+ save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}"
+ )
+ sf.write(
+ audio_save_path,
+ extract,
+ samplerate=sr,
+ )
+ original_files.append(audio_save_path)
+
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab"
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(text)
+ original_files.append(transcript_save_path)
+
+ merge_and_delete_files(save_dir, original_files)
+
+
+if __name__ == "__main__":
+ main()