# mini-omni-s2s / s2s_config.py
from dataclasses import dataclass, field
from typing import Optional, List
import os

CKPT_NAME = "model.pt"
CKPT_LOCAL_DIR = "model_ckpts"
CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
CKPT_REPO = "xcczach/mini-omni"
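
# Illustrative only: a minimal sketch of how these constants might be used to
# fetch the checkpoint with huggingface_hub (assumed available; this helper and
# its name are hypothetical, not part of the original config).
def _download_checkpoint() -> str:
    from huggingface_hub import hf_hub_download

    if os.path.exists(CKPT_PATH):
        return CKPT_PATH
    # Downloads model.pt from the Hub repo into model_ckpts/ and returns its path.
    return hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_NAME, local_dir=CKPT_LOCAL_DIR)
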
@dataclass
class VocabConfig:
text_vocabsize: int = 151936
text_specialtokens: int = 64
audio_vocabsize: int = 4096
audio_specialtokens: int = 64
total_vocabsize: int = 181120
code_layer: int = 7
padded_text_vocabsize: int = field(init=False)
padded_audio_vocabsize: int = field(init=False)
total_audio_vocabsize: int = field(init=False)
eot: int = field(init=False) # end of text token
pad_t: int = field(init=False) # padding text token
input_t: int = field(init=False) # input text token
answer_t: int = field(init=False) # answer text token
asr: int = field(init=False) # ASR token
eoa: int = field(init=False) # end of audio token
pad_a: int = field(init=False) # padding audio token
input_a: int = field(init=False) # input audio token
answer_a: int = field(init=False) # answer audio token
split: int = field(init=False) # split token

    def __post_init__(self):
self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer
self.eot = self.text_vocabsize
self.pad_t = self.text_vocabsize + 1
self.input_t = self.text_vocabsize + 2
self.answer_t = self.text_vocabsize + 3
self.asr = self.text_vocabsize + 4
self.eoa = self.audio_vocabsize
self.pad_a = self.audio_vocabsize + 1
self.input_a = self.audio_vocabsize + 2
self.answer_a = self.audio_vocabsize + 3
self.split = self.audio_vocabsize + 4
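
# Sanity check of the arithmetic above (illustrative helper, not part of the
# original config): padded_text_vocabsize = 151936 + 64 = 152000,
# padded_audio_vocabsize = 4096 + 64 = 4160, total_audio_vocabsize
# = 4160 * 7 = 29120, and 152000 + 29120 = 181120 == total_vocabsize.
def _check_vocab_config() -> None:
    cfg = VocabConfig()
    assert cfg.padded_text_vocabsize == 152000
    assert cfg.total_audio_vocabsize == 29120
    assert cfg.padded_text_vocabsize + cfg.total_audio_vocabsize == cfg.total_vocabsize
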
@dataclass
class TTSAdapterConfig:
add_qkv_bias: Optional[bool] = True
bias: bool = False
gelu_approximate: Optional[str] = None
head_size: Optional[int] = 64
intermediate_size: Optional[int] = 4864
lm_head_bias: bool = False
mlp_class_name: str = "GptNeoxMLP"
n_layer: int = 6
n_head: int = 14
n_embd: int = 896
n_query_groups: Optional[int] = 2
norm_class_name: str = "RMSNorm"
norm_eps: float = 1e-6
parallel_residual: bool = False
    rotary_percentage: float = 1.0
    shared_attention_norm: bool = False
    rope_n_elem: int = field(init=False)

    def __post_init__(self):
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
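
# Consistency check (illustrative helper, not part of the original config):
# with the defaults, n_head * head_size = 14 * 64 = 896 == n_embd, and
# rope_n_elem = int(1.0 * 64) = 64, i.e. RoPE covers the full head dimension.
def _check_tts_adapter_config() -> None:
    cfg = TTSAdapterConfig()
    assert cfg.n_head * cfg.head_size == cfg.n_embd
    assert cfg.rope_n_elem == cfg.head_size
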
@dataclass
class ModelConfig:
file: str = "model/slam_model_s2s.py:model_factory"
llm_name: str = "qwen2-0.5b"
llm_path: str = "Qwen/Qwen2-0.5B"
llm_type: str = "decoder_only"
llm_dim: int = 896
encoder_name: Optional[str] = "whisper"
encoder_ds_rate: int = 2
encoder_path: Optional[str] = "small"
encoder_dim: int = 768
encoder_projector: str = "linear"
encoder_projector_ds_rate: int = 5
modal: str = "audio"
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the input is normalized; used for models such as WavLM"},
    )
    encoder_type: str = field(
        default="finetune",
        metadata={
            "help": "whether the model is only pretrained or also finetuned; used for models such as HuBERT"
        },
    )
vocab_config: VocabConfig = field(default_factory=VocabConfig)
codec_decode: bool = True
codec_decoder_type: str = "SNAC"
codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
tts_adapter: bool = False
tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
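
# Note on the default_factory fields above (illustrative helper, not part of
# the original config): each ModelConfig gets its own VocabConfig and
# TTSAdapterConfig instance rather than a shared mutable default.
def _check_model_config_isolation() -> None:
    a, b = ModelConfig(), ModelConfig()
    assert a.vocab_config is not b.vocab_config
    assert a.vocab_config.total_vocabsize == b.vocab_config.total_vocabsize == 181120
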
@dataclass
class PeftConfig:
    peft_method: str = "lora"  # alternatives: None, llama_adapter, prefix
r: int = 8
lora_alpha: int = 32
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
bias: str = "none"
task_type: str = "CAUSAL_LM"
lora_dropout: float = 0.05
inference_mode: bool = False
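
# Illustrative sketch (an assumption, not the repo's actual wiring): these
# fields map one-to-one onto peft's LoraConfig, assuming peft is installed.
def _to_peft_lora_config(cfg: PeftConfig):
    from peft import LoraConfig

    return LoraConfig(
        r=cfg.r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.target_modules,
        lora_dropout=cfg.lora_dropout,
        bias=cfg.bias,
        task_type=cfg.task_type,
        inference_mode=cfg.inference_mode,
    )
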
@dataclass
class TrainConfig:
model_name: str = "s2s"
enable_ddp: bool = False
enable_deepspeed: bool = False
enable_fsdp: bool = False
low_cpu_fsdp: bool = False
run_validation: bool = True
batch_size_training: int = 4
    batching_strategy: str = field(
        default="custom", metadata={"help": "alternative: padding"}
    )
context_length: int = 4096
gradient_accumulation_steps: int = 1
num_epochs: int = 1
num_workers_dataloader: int = 2
warmup_steps: int = 1000
total_steps: int = 100000
validation_interval: int = 1000
lr: float = 1e-4
weight_decay: float = 0.0
gamma: float = 0.85
seed: int = 42
use_fp16: bool = False
mixed_precision: bool = True
val_batch_size: int = 1
use_peft: bool = False
peft_config: PeftConfig = field(default_factory=PeftConfig)
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
quantization: bool = False
one_gpu: bool = False
save_model: bool = True
    dist_checkpoint_root_folder: str = (
        "PATH/to/save/FSDP/model"  # used only when training with FSDP
    )
    dist_checkpoint_folder: str = "fine-tuned"  # used only when training with FSDP
    save_optimizer: bool = False  # used only when training with FSDP
    use_fast_kernels: bool = (
        False  # enable SDPA from PyTorch Accelerated Transformers; uses Flash Attention and xFormers memory-efficient kernels
    )
run_test_during_validation: bool = False
run_test_during_validation_file: str = "test.wav"
run_test_during_validation_prompt: str = "<|S2S|>"
    freeze_llm: bool = field(
        default=True,
        metadata={
            "help": "whether to freeze the LLM during finetuning; should be True when using PEFT"
        },
    )
freeze_encoder: bool = True
train_embed_only: bool = False
train_audio_embed_only: bool = False
task_type: str = "s2s"
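
# Schedule arithmetic for the defaults above (illustrative helper, not part of
# the original config): the effective batch size per optimizer step is
# batch_size_training * gradient_accumulation_steps = 4 * 1 = 4 per process.
def _effective_batch_size(cfg: TrainConfig, world_size: int = 1) -> int:
    # world_size > 1 applies when DDP/FSDP/DeepSpeed is enabled.
    return cfg.batch_size_training * cfg.gradient_accumulation_steps * world_size
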
@dataclass
class DataConfig:
dataset: str = "speech_dataset_s2s"
file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
train_split: str = "train"
test_split: str = "validation"
prompt: Optional[str] = None
data_path: Optional[str] = None
max_words: Optional[int] = None
max_mel: Optional[float] = None
fix_length_audio: int = -1
inference_mode: bool = True
    input_type: str = field(
        default="mel",
        metadata={"help": "use 'raw' when the input is a waveform, 'mel' for Whisper"},
    )
    mel_size: int = field(
        default=80, metadata={"help": "80 for Whisper large v1 and v2, 128 for v3"}
    )
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the input is normalized; used for models such as WavLM"},
    )
seed: int = 42
manifest_format: str = field(
default="datasets", metadata={"help": "alternative: jsonl"}
)
split_size: float = 0.1
vocab_config: VocabConfig = field(default_factory=VocabConfig)
load_from_cache_file: bool = False
task_type: str = "s2s"
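
# Illustrative sketch (an assumption, not the repo's actual loader): with
# manifest_format="datasets", split_size reads naturally as the held-out
# validation fraction, e.g. via datasets.Dataset.train_test_split.
def _example_split(dataset, cfg: DataConfig):
    # Returns a DatasetDict with "train" and "test" keys.
    return dataset.train_test_split(test_size=cfg.split_size, seed=cfg.seed)
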
@dataclass
class DecodeConfig:
do_sample: bool = False
max_new_tokens: int = 300
min_length: int = 10
temperature: float = 1.0
top_k: int = 50
top_p: float = 0.9
num_beams: int = 1
num_return_sequences: int = 1
num_samples: int = 1
max_time: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
no_repeat_ngram_size: int = 0
bad_words_ids: List = field(default_factory=list)
num_beam_groups: int = 1
diversity_penalty: float = 0.0
task_type: str = "s2s"
decode_text_only: bool = False
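
# Illustrative sketch (an assumption, not the repo's actual decode path): most
# fields line up one-to-one with kwargs of transformers' model.generate.
# bad_words_ids and max_time are omitted because their empty/zero defaults
# mean "unset" rather than literal values.
def _to_generate_kwargs(cfg: DecodeConfig) -> dict:
    return {
        "do_sample": cfg.do_sample,
        "max_new_tokens": cfg.max_new_tokens,
        "min_length": cfg.min_length,
        "temperature": cfg.temperature,
        "top_k": cfg.top_k,
        "top_p": cfg.top_p,
        "num_beams": cfg.num_beams,
        "num_return_sequences": cfg.num_return_sequences,
        "repetition_penalty": cfg.repetition_penalty,
        "length_penalty": cfg.length_penalty,
        "early_stopping": cfg.early_stopping,
        "no_repeat_ngram_size": cfg.no_repeat_ngram_size,
        "num_beam_groups": cfg.num_beam_groups,
        "diversity_penalty": cfg.diversity_penalty,
    }
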
@dataclass
class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
    # alternative: "FULL_SHARD" (ShardingStrategy.FULL_SHARD)
    sharding_strategy: str = "NO_SHARD"  # use NO_SHARD when running under DDP
    checkpoint_type: str = (
        "SHARDED_STATE_DICT"  # saves one file per rank and allows resizing the world size on load
    )
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
pure_bf16: bool = False
optimizer: str = "AdamW"
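
# Illustrative sketch (an assumption, not part of the original config): the
# sharding_strategy string resolves to the torch FSDP enum by name.
def _resolve_sharding_strategy(cfg: FSDPConfig):
    from torch.distributed.fsdp import ShardingStrategy

    # e.g. "NO_SHARD" -> ShardingStrategy.NO_SHARD
    return ShardingStrategy[cfg.sharding_strategy]
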
@dataclass
class LogConfig:
use_wandb: bool = False
wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
wandb_entity_name: str = "project_name"
wandb_project_name: str = "project_name"
wandb_exp_name: str = "exp_name"
log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
log_interval: int = 10
online_output_dir: Optional[str] = None
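
# Illustrative sketch (an assumption, not the repo's actual logging setup):
# how these fields might feed wandb.init when use_wandb is True.
def _init_wandb(cfg: LogConfig):
    import wandb

    return wandb.init(
        entity=cfg.wandb_entity_name,
        project=cfg.wandb_project_name,
        name=cfg.wandb_exp_name,
        dir=cfg.wandb_dir,
    )
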
@dataclass
class InferenceConfig:
dataset_config: DataConfig = field(default_factory=DataConfig)
model_config: ModelConfig = field(default_factory=ModelConfig)
train_config: TrainConfig = field(default_factory=TrainConfig)
decode_config: DecodeConfig = field(default_factory=DecodeConfig)
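
# Usage sketch (illustrative helper, not part of the original config):
# InferenceConfig bundles all sub-configs. Note that DataConfig and
# ModelConfig each carry their own VocabConfig, so overrides must be applied
# to both to keep them consistent.
def _example_inference_config() -> InferenceConfig:
    cfg = InferenceConfig()
    assert (
        cfg.dataset_config.vocab_config.total_vocabsize
        == cfg.model_config.vocab_config.total_vocabsize
    )
    return cfg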