# Page header captured with this source: "Spaces: Runtime error".
import os
from dataclasses import dataclass, field
from typing import Optional, List

# Where the model checkpoint lives locally: <CKPT_LOCAL_DIR>/<CKPT_NAME>.
CKPT_LOCAL_DIR = "model_ckpts"
CKPT_NAME = "model.pt"
CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)

# Remote repository id associated with the checkpoint
# (presumably the source it is downloaded from — confirm with the loader).
CKPT_REPO = "xcczach/mini-omni"
@dataclass
class VocabConfig:
    """Token-ID layout for the text and audio vocabularies.

    Text and audio IDs are numbered independently.

    * Text special tokens (``eot`` .. ``asr``) are the IDs immediately after
      the ``text_vocabsize`` regular text IDs.
    * Audio special tokens (``eoa`` .. ``split``) are the IDs immediately after
      the ``audio_vocabsize`` regular audio IDs.
    * Each space is padded by its ``*_specialtokens`` reserve, and
      ``total_audio_vocabsize`` reserves one padded audio block per
      ``code_layer``.

    With the defaults, ``total_vocabsize`` (181120) equals
    ``padded_text_vocabsize + total_audio_vocabsize`` (152000 + 4160 * 7).
    It is an independent init field and is not recomputed, so keep it in sync
    when overriding the other sizes.
    """

    text_vocabsize: int = 151936
    text_specialtokens: int = 64
    audio_vocabsize: int = 4096
    audio_specialtokens: int = 64
    total_vocabsize: int = 181120
    code_layer: int = 7

    # Derived values: excluded from __init__ and filled in by __post_init__.
    padded_text_vocabsize: int = field(init=False)
    padded_audio_vocabsize: int = field(init=False)
    total_audio_vocabsize: int = field(init=False)

    eot: int = field(init=False)  # end of text token
    pad_t: int = field(init=False)  # padding text token
    input_t: int = field(init=False)  # input text token
    answer_t: int = field(init=False)  # answer text token
    asr: int = field(init=False)  # ASR token

    eoa: int = field(init=False)  # end of audio token
    pad_a: int = field(init=False)  # padding audio token
    input_a: int = field(init=False)  # input audio token
    answer_a: int = field(init=False)  # answer audio token
    split: int = field(init=False)  # split token

    def __post_init__(self):
        """Compute padded sizes and the special-token IDs from the base sizes."""
        self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
        self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
        self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer

        # Text special tokens occupy the first slots after the regular text IDs.
        self.eot = self.text_vocabsize
        self.pad_t = self.text_vocabsize + 1
        self.input_t = self.text_vocabsize + 2
        self.answer_t = self.text_vocabsize + 3
        self.asr = self.text_vocabsize + 4

        # Audio special tokens occupy the first slots after the regular audio IDs.
        self.eoa = self.audio_vocabsize
        self.pad_a = self.audio_vocabsize + 1
        self.input_a = self.audio_vocabsize + 2
        self.answer_a = self.audio_vocabsize + 3
        self.split = self.audio_vocabsize + 4
@dataclass
class TTSAdapterConfig:
    """Transformer hyper-parameters for the optional TTS adapter.

    Field names follow the LitGPT ``Config`` convention (``mlp_class_name``,
    ``norm_class_name``, ``n_query_groups``, ...) — presumably consumed by a
    LitGPT-style block; confirm against the adapter implementation.

    ``__post_init__`` sets ``rope_n_elem`` as a plain attribute (not a
    dataclass field), and fills in ``head_size`` when it is ``None``.
    """

    add_qkv_bias: Optional[bool] = True
    bias: bool = False
    gelu_approximate: Optional[str] = None
    head_size: Optional[int] = 64  # None -> derived as n_embd // n_head
    intermediate_size: Optional[int] = 4864
    lm_head_bias: bool = False
    mlp_class_name: str = "GptNeoxMLP"
    n_layer: int = 6
    n_head: int = 14
    n_embd: int = 896
    n_query_groups: Optional[int] = 2
    norm_class_name: str = "RMSNorm"
    norm_eps: float = 1e-6
    parallel_residual: bool = False
    rotary_percentage: float = 1
    shared_attention_norm: bool = False

    def __post_init__(self):
        """Derive ``head_size`` (if unset) and ``rope_n_elem``.

        Raises:
            ValueError: if ``head_size`` is ``None`` and ``n_embd`` is not
                divisible by ``n_head``.
        """
        # head_size is Optional; without this fallback a None value would make
        # the rope_n_elem computation below fail with a TypeError.
        if self.head_size is None:
            if self.n_embd % self.n_head != 0:
                raise ValueError(
                    f"n_embd ({self.n_embd}) must be divisible by n_head "
                    f"({self.n_head}) to derive head_size"
                )
            self.head_size = self.n_embd // self.n_head
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
@dataclass
class ModelConfig:
    """Model-side settings.

    Covers the LLM backbone, the speech encoder and its projector, the
    vocabulary layout, the audio codec decoder, and the optional TTS adapter.
    Nested configs use ``default_factory`` so every instance gets its own copy.
    """

    # Model construction entry point, written as "path/to/file.py:function"
    # (format inferred from the default value).
    file: str = "model/slam_model_s2s.py:model_factory"

    # LLM backbone.
    llm_name: str = "qwen2-0.5b"
    llm_path: str = "Qwen/Qwen2-0.5B"
    llm_type: str = "decoder_only"
    llm_dim: int = 896

    # Speech encoder and its projector.
    encoder_name: Optional[str] = "whisper"
    encoder_ds_rate: int = 2  # presumably the encoder down-sampling rate — confirm
    encoder_path: Optional[str] = "small"
    encoder_dim: int = 768
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5  # presumably the projector down-sampling rate — confirm
    modal: str = "audio"
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether input is normalized, used for models such as wavlm"},
    )
    encoder_type: str = field(
        default="finetune",
        metadata={
            "help": "whether model is only pretrained or finetuned, used for models such as hubert"
        },
    )

    # Token-ID layout shared with the data pipeline.
    vocab_config: VocabConfig = field(default_factory=VocabConfig)

    # Audio codec decoder.
    codec_decode: bool = True
    codec_decoder_type: str = "SNAC"
    codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"

    # Optional TTS adapter and its transformer hyper-parameters.
    tts_adapter: bool = False
    tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
@dataclass
class PeftConfig:
    """Parameter-efficient fine-tuning settings (LoRA by default)."""

    # Method name; the original notes None, "llama_adapter" and "prefix" as
    # other options.
    peft_method: str = "lora"
    r: int = 8
    lora_alpha: int = 32
    # default_factory gives each instance its own list instead of a shared one.
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False
@dataclass
class TrainConfig:
    """Training-loop, distributed-training, checkpointing and freezing options."""

    model_name: str = "s2s"

    # Distributed backends.
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False

    # Batching and schedule.
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(
        default="custom", metadata={"help": "alternative: padding"}
    )
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    num_workers_dataloader: int = 2
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000

    # Optimizer.
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42

    # Precision.
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    # Parameter-efficient fine-tuning.
    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"  # placeholder; set per run
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False

    # Checkpointing (the dist_* and save_optimizer options apply when using FSDP).
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"
    dist_checkpoint_folder: str = "fine-tuned"
    save_optimizer: bool = False

    # Enable SDPA from PyTorch Accelerated Transformers, which can use
    # Flash Attention and xFormers memory-efficient kernels.
    use_fast_kernels: bool = False

    # Optional test decoding during validation.
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|S2S|>"

    # Module freezing.
    freeze_llm: bool = field(
        default=True,
        metadata={
            "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
        },
    )
    freeze_encoder: bool = True
    train_embed_only: bool = False
    train_audio_embed_only: bool = False
    task_type: str = "s2s"
@dataclass
class DataConfig:
    """Dataset loading and audio-feature settings."""

    # Dataset builder, written as "path/to/file.py:function"
    # (format inferred from the default value).
    dataset: str = "speech_dataset_s2s"
    file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"

    # Data sources and splits.
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None

    # Length limits and audio features.
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1  # presumably -1 disables fixed-length audio — confirm
    inference_mode: bool = True
    input_type: str = field(
        default="mel",
        metadata={"help": "Use raw when input is wav, mel when for whisper"},
    )
    mel_size: int = field(
        default=80, metadata={"help": "80 for whisper large v1 and v2, 128 for v3"}
    )
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether input is normalized, used for models such as wavlm"},
    )
    seed: int = 42
    manifest_format: str = field(
        default="datasets", metadata={"help": "alternative: jsonl"}
    )
    split_size: float = 0.1

    # Token-ID layout; should match ModelConfig.vocab_config.
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    load_from_cache_file: bool = False
    task_type: str = "s2s"
@dataclass
class DecodeConfig:
    """Generation / decoding options.

    Most field names appear to mirror Hugging Face generation arguments —
    confirm how the decoder consumes them.
    """

    # Sampling.
    do_sample: bool = False
    max_new_tokens: int = 300
    min_length: int = 10
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.9

    # Beam search and output count.
    num_beams: int = 1
    num_return_sequences: int = 1
    num_samples: int = 1
    max_time: float = 0.0

    # Penalties and constraints.
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = False
    no_repeat_ngram_size: int = 0
    # default_factory gives each instance its own list instead of a shared one.
    bad_words_ids: List = field(default_factory=list)
    num_beam_groups: int = 1
    diversity_penalty: float = 0.0

    task_type: str = "s2s"
    decode_text_only: bool = False
@dataclass
class FSDPConfig:
    """Fully Sharded Data Parallel (FSDP) options.

    Strategy names are stored as strings; presumably they are mapped to the
    matching torch ``ShardingStrategy`` / ``StateDictType`` members elsewhere
    — confirm with the training code.
    """

    mixed_precision: bool = True
    use_fp16: bool = False
    # NO_SHARD is required when combining with DDP; FULL_SHARD is the
    # sharded alternative.
    sharding_strategy: str = "NO_SHARD"
    # SHARDED_STATE_DICT saves one file per rank and allows changing the world
    # size on reload; FULL_STATE_DICT is the alternative.
    checkpoint_type: str = "SHARDED_STATE_DICT"
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
@dataclass
class LogConfig:
    """Logging and Weights & Biases settings."""

    use_wandb: bool = False
    # NOTE(review): the default directories below are author-specific absolute
    # paths; override them for any other environment.
    wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
    log_interval: int = 10
    online_output_dir: Optional[str] = None
class InferenceConfig: | |
dataset_config: DataConfig = field(default_factory=DataConfig) | |
model_config: ModelConfig = field(default_factory=ModelConfig) | |
train_config: TrainConfig = field(default_factory=TrainConfig) | |
decode_config: DecodeConfig = field(default_factory=DecodeConfig) | |