import os
from dataclasses import dataclass, field
from typing import List, Optional

CKPT_NAME = "model.pt"
CKPT_LOCAL_DIR = "model_ckpts"
CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
CKPT_REPO = "xcczach/mini-omni"
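
# How these constants are presumably consumed (my assumption; the actual
# download code lives elsewhere in the repo): the app likely fetches the
# checkpoint from CKPT_REPO into CKPT_LOCAL_DIR before loading it, e.g. with
# huggingface_hub:
#
#   from huggingface_hub import hf_hub_download
#
#   def ensure_checkpoint() -> str:
#       # Download model.pt into model_ckpts/ and return its local path.
#       return hf_hub_download(
#           repo_id=CKPT_REPO, filename=CKPT_NAME, local_dir=CKPT_LOCAL_DIR
#       )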


@dataclass
class VocabConfig:
    text_vocabsize: int = 151936
    text_specialtokens: int = 64
    audio_vocabsize: int = 4096
    audio_specialtokens: int = 64
    total_vocabsize: int = 181120
    code_layer: int = 7
    padded_text_vocabsize: int = field(init=False)
    padded_audio_vocabsize: int = field(init=False)
    total_audio_vocabsize: int = field(init=False)
    eot: int = field(init=False)  # end of text token
    pad_t: int = field(init=False)  # padding text token
    input_t: int = field(init=False)  # input text token
    answer_t: int = field(init=False)  # answer text token
    asr: int = field(init=False)  # ASR token
    eoa: int = field(init=False)  # end of audio token
    pad_a: int = field(init=False)  # padding audio token
    input_a: int = field(init=False)  # input audio token
    answer_a: int = field(init=False)  # answer audio token
    split: int = field(init=False)  # split token

    def __post_init__(self):
        self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
        self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
        self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer
        self.eot = self.text_vocabsize
        self.pad_t = self.text_vocabsize + 1
        self.input_t = self.text_vocabsize + 2
        self.answer_t = self.text_vocabsize + 3
        self.asr = self.text_vocabsize + 4
        self.eoa = self.audio_vocabsize
        self.pad_a = self.audio_vocabsize + 1
        self.input_a = self.audio_vocabsize + 2
        self.answer_a = self.audio_vocabsize + 3
        self.split = self.audio_vocabsize + 4
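
# Sanity check of the vocabulary layout (my arithmetic from the defaults
# above, not an annotation from the original file):
#   padded_text_vocabsize  = 151936 + 64 = 152000
#   padded_audio_vocabsize = 4096 + 64   = 4160
#   total_audio_vocabsize  = 4160 * 7    = 29120
# and 152000 + 29120 = 181120, which matches the hard-coded total_vocabsize.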


@dataclass
class TTSAdapterConfig:
    add_qkv_bias: Optional[bool] = True
    bias: bool = False
    gelu_approximate: Optional[str] = None
    head_size: Optional[int] = 64
    intermediate_size: Optional[int] = 4864
    lm_head_bias: bool = False
    mlp_class_name: str = "GptNeoxMLP"
    n_layer: int = 6
    n_head: int = 14
    n_embd: int = 896
    n_query_groups: Optional[int] = 2
    norm_class_name: str = "RMSNorm"
    norm_eps: float = 1e-6
    parallel_residual: bool = False
    rotary_percentage: float = 1.0
    shared_attention_norm: bool = False

    def __post_init__(self):
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
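
# With the defaults above, rope_n_elem = int(1.0 * 64) = 64, i.e. rotary
# position embeddings cover every element of each attention head (my reading
# of the arithmetic, not a comment from the original).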


@dataclass
class ModelConfig:
    file: str = "model/slam_model_s2s.py:model_factory"
    llm_name: str = "qwen2-0.5b"
    llm_path: str = "Qwen/Qwen2-0.5B"
    llm_type: str = "decoder_only"
    llm_dim: int = 896
    encoder_name: Optional[str] = "whisper"
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = "small"
    encoder_dim: int = 768
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the input is normalized; used for models such as WavLM"},
    )
    encoder_type: str = field(
        default="finetune",
        metadata={
            "help": "whether the model is pretrained only or finetuned; used for models such as HuBERT"
        },
    )
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    codec_decode: bool = True
    codec_decoder_type: str = "SNAC"
    codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
    tts_adapter: bool = False
    tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
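
# Overriding a default without editing this file is plain dataclass usage,
# e.g. (an illustrative sketch with hypothetical values, not code from the
# original repo):
#
#   from dataclasses import replace
#
#   cfg = replace(ModelConfig(), encoder_path="large-v3", encoder_dim=1280)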


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # alternatives: None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False
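
# These fields mirror the constructor arguments of peft.LoraConfig, so the
# wiring presumably looks like the sketch below (an assumption about how the
# repo consumes this config; `peft` is not a dependency of this file):
#
#   from peft import LoraConfig
#
#   lora_cfg = LoraConfig(
#       r=peft_config.r,
#       lora_alpha=peft_config.lora_alpha,
#       target_modules=peft_config.target_modules,
#       bias=peft_config.bias,
#       task_type=peft_config.task_type,
#       lora_dropout=peft_config.lora_dropout,
#       inference_mode=peft_config.inference_mode,
#   )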


@dataclass
class TrainConfig:
    model_name: str = "s2s"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(
        default="custom", metadata={"help": "alternative: padding"}
    )
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    num_workers_dataloader: int = 2
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # used only with FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # used only with FSDP
    save_optimizer: bool = False  # used only with FSDP
    # enable SDPA from PyTorch Accelerated Transformers, which uses
    # Flash Attention and xFormers memory-efficient kernels
    use_fast_kernels: bool = False
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|S2S|>"
    freeze_llm: bool = field(
        default=True,
        metadata={
            "help": "whether to freeze the LLM during finetuning; should be True when using PEFT"
        },
    )
    freeze_encoder: bool = True
    train_embed_only: bool = False
    train_audio_embed_only: bool = False
    task_type: str = "s2s"
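
# Note on the schedule knobs (my gloss, not a comment from the original):
# the effective batch size per optimizer step is
#   batch_size_training * gradient_accumulation_steps * world_size
# so the defaults give 4 samples per step on a single GPU.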


@dataclass
class DataConfig:
    dataset: str = "speech_dataset_s2s"
    file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = True
    input_type: str = field(
        default="mel",
        metadata={"help": "use 'raw' when the input is a wav, 'mel' for Whisper"},
    )
    mel_size: int = field(
        default=80, metadata={"help": "80 for Whisper large v1 and v2, 128 for v3"}
    )
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the input is normalized; used for models such as WavLM"},
    )
    seed: int = 42
    manifest_format: str = field(
        default="datasets", metadata={"help": "alternative: jsonl"}
    )
    split_size: float = 0.1
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    load_from_cache_file: bool = False
    task_type: str = "s2s"
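
# split_size presumably feeds a datasets-style train/validation split when no
# explicit val_data_path is given, e.g. (a sketch under that assumption):
#
#   ds = ds.train_test_split(test_size=data_config.split_size, seed=data_config.seed)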


@dataclass
class DecodeConfig:
    do_sample: bool = False
    max_new_tokens: int = 300
    min_length: int = 10
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.9
    num_beams: int = 1
    num_return_sequences: int = 1
    num_samples: int = 1
    max_time: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = False
    no_repeat_ngram_size: int = 0
    bad_words_ids: List[List[int]] = field(default_factory=list)
    num_beam_groups: int = 1
    diversity_penalty: float = 0.0
    task_type: str = "s2s"
    decode_text_only: bool = False
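
# Most of these fields match keyword arguments of transformers'
# GenerationConfig / model.generate (do_sample, max_new_tokens, min_length,
# temperature, top_k, top_p, num_beams, num_return_sequences, max_time,
# repetition_penalty, length_penalty, early_stopping, no_repeat_ngram_size,
# bad_words_ids, num_beam_groups, diversity_penalty). num_samples, task_type,
# and decode_text_only look repo-specific and are presumably consumed by the
# custom s2s decode loop; that mapping is my assumption, not documented here.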


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # alternative: "FULL_SHARD" (ShardingStrategy.FULL_SHARD)
    sharding_strategy: str = "NO_SHARD"  # MZY: set NO_SHARD when using DDP
    # saves one file per rank and allows resuming with a different world size
    checkpoint_type: str = "SHARDED_STATE_DICT"
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
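
# The strategy string is presumably resolved to the torch FSDP enum by name,
# e.g. (a sketch, assuming the repo maps it this way):
#
#   from torch.distributed.fsdp import ShardingStrategy
#
#   strategy = ShardingStrategy[fsdp_config.sharding_strategy]  # e.g. NO_SHARD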


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
    log_interval: int = 10
    online_output_dir: Optional[str] = None


@dataclass
class InferenceConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    decode_config: DecodeConfig = field(default_factory=DecodeConfig)
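

# A small runnable smoke test (added as an illustration; not part of the
# original file): instantiating InferenceConfig exercises every dataclass
# above, including the derived vocabulary fields.
if __name__ == "__main__":
    cfg = InferenceConfig()
    vocab = cfg.model_config.vocab_config
    assert vocab.padded_text_vocabsize + vocab.total_audio_vocabsize == vocab.total_vocabsize
    print("total vocab:", vocab.total_vocabsize)  # 181120
    print("eot / pad_t:", vocab.eot, vocab.pad_t)  # 151936 151937
    print("ckpt path:", CKPT_PATH)  # model_ckpts/model.pt (on POSIX)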