mrfakename committed
Commit c0fb8c8 · verified · 1 Parent(s): 1bcb8fe

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo.

src/f5_tts/api.py CHANGED
@@ -15,6 +15,7 @@ from f5_tts.infer.utils_infer import (
     preprocess_ref_audio_text,
     remove_silence_for_generated_wav,
     save_spectrogram,
+    transcribe,
     target_sample_rate,
 )
 from f5_tts.model import DiT, UNetT
@@ -82,6 +83,9 @@ class F5TTS:
             model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
         )

+    def transcribe(self, ref_audio, language=None):
+        return transcribe(ref_audio, language)
+
     def export_wav(self, wav, file_wave, remove_silence=False):
         sf.write(file_wave, wav, self.target_sample_rate)

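The new F5TTS.transcribe method is a thin wrapper over the module-level transcribe helper (inside the method body the bare name resolves to the import, not the method, so the call does not recurse). A minimal usage sketch, assuming the default constructor and a hypothetical local clip ref.wav:

    # Sketch only: "ref.wav" is a hypothetical reference clip; the ASR model
    # is downloaded on first use.
    from f5_tts.api import F5TTS

    tts = F5TTS()
    ref_text = tts.transcribe("ref.wav")                    # language auto-detected
    ref_text_zh = tts.transcribe("ref.wav", language="zh")  # explicit language hint
    print(ref_text)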
src/f5_tts/infer/utils_infer.py CHANGED
@@ -150,6 +150,22 @@ def initialize_asr_pipeline(device=device, dtype=None):
     )


+# transcribe
+
+
+def transcribe(ref_audio, language=None):
+    global asr_pipe
+    if asr_pipe is None:
+        initialize_asr_pipeline(device=device)
+    return asr_pipe(
+        ref_audio,
+        chunk_length_s=30,
+        batch_size=128,
+        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
+        return_timestamps=False,
+    )["text"].strip()
+
+
 # load model checkpoint for inference


@@ -306,17 +322,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
             show_info("Using cached reference text...")
             ref_text = _ref_audio_cache[audio_hash]
         else:
-            global asr_pipe
-            if asr_pipe is None:
-                initialize_asr_pipeline(device=device)
             show_info("No reference text provided, transcribing reference audio...")
-            ref_text = asr_pipe(
-                ref_audio,
-                chunk_length_s=30,
-                batch_size=128,
-                generate_kwargs={"task": "transcribe"},
-                return_timestamps=False,
-            )["text"].strip()
+            ref_text = transcribe(ref_audio)
             # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
             _ref_audio_cache[audio_hash] = ref_text
     else:
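The helper lazily builds the global asr_pipe on first call and only injects a "language" key into generate_kwargs when a hint is passed. A short sketch of calling it directly (the audio path is hypothetical):

    from f5_tts.infer.utils_infer import transcribe

    # First call triggers initialize_asr_pipeline() and caches the pipeline in
    # the module-global asr_pipe; later calls reuse it.
    text = transcribe("ref.wav")                    # language auto-detected
    text_fr = transcribe("ref.wav", language="fr")  # hint forwarded via generate_kwargs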
src/f5_tts/train/finetune_cli.py CHANGED
@@ -13,6 +13,9 @@ from importlib.resources import files
 target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
+win_length = 1024
+n_fft = 1024
+mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'


 # -------------------------- Argument Parsing --------------------------- #
@@ -40,7 +43,7 @@ def parse_args():
     parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
-    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
+    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
     parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
     parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
     parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
@@ -121,11 +124,15 @@ def main():
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

     print("\nvocab : ", vocab_size)
+    print("\nvocoder : ", mel_spec_type)

     mel_spec_kwargs = dict(
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
+        n_fft=n_fft,
         hop_length=hop_length,
+        win_length=win_length,
+        n_mel_channels=n_mel_channels,
+        target_sample_rate=target_sample_rate,
+        mel_spec_type=mel_spec_type,
     )

     model = CFM(
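The new constants pin the STFT front end explicitly instead of leaving it to downstream defaults. A quick standalone sanity check of what they imply (a sketch, not project code):

    target_sample_rate = 24000
    hop_length = 256
    win_length = 1024
    n_fft = 1024

    frames_per_second = target_sample_rate / hop_length  # 93.75 mel frames per second
    window_ms = 1000 * win_length / target_sample_rate   # ~42.7 ms analysis window
    assert n_fft >= win_length                           # FFT size must cover the window
    print(f"{frames_per_second} frames/s, {window_ms:.1f} ms window")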
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -26,12 +26,13 @@ from datasets import Dataset as Dataset_
 from datasets.arrow_writer import ArrowWriter
 from safetensors.torch import save_file
 from scipy.io import wavfile
-from transformers import pipeline
 from cached_path import cached_path
 from f5_tts.api import F5TTS
 from f5_tts.model.utils import convert_char_to_pinyin
+from f5_tts.infer.utils_infer import transcribe
 from importlib.resources import files

+
 training_process = None
 system = platform.system()
 python_executable = sys.executable or "python"
@@ -47,8 +48,6 @@ file_train = "src/f5_tts/train/finetune_cli.py"

 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

-pipe = None
-

 # Save settings from a JSON file
 def save_settings(
@@ -70,6 +69,7 @@ def save_settings(
     tokenizer_file,
     mixed_precision,
     logger,
+    ch_8bit_adam,
 ):
     path_project = os.path.join(path_project_ckpts, project_name)
     os.makedirs(path_project, exist_ok=True)
@@ -93,6 +93,7 @@
         "tokenizer_file": tokenizer_file,
         "mixed_precision": mixed_precision,
         "logger": logger,
+        "bnb_optimizer": ch_8bit_adam,
     }
     with open(file_setting, "w") as f:
         json.dump(settings, f, indent=4)
@@ -124,6 +125,7 @@ def load_settings(project_name):
         "tokenizer_file": "",
         "mixed_precision": "none",
         "logger": "wandb",
+        "bnb_optimizer": False,
     }
     return (
         settings["exp_name"],
@@ -143,12 +145,15 @@ def load_settings(project_name):
         settings["tokenizer_file"],
         settings["mixed_precision"],
         settings["logger"],
+        settings["bnb_optimizer"],
     )

     with open(file_setting, "r") as f:
         settings = json.load(f)
     if "logger" not in settings:
         settings["logger"] = "wandb"
+    if "bnb_optimizer" not in settings:
+        settings["bnb_optimizer"] = False
     return (
         settings["exp_name"],
         settings["learning_rate"],
@@ -167,6 +172,7 @@ def load_settings(project_name):
         settings["tokenizer_file"],
         settings["mixed_precision"],
         settings["logger"],
+        settings["bnb_optimizer"],
     )


@@ -381,18 +387,17 @@ def start_training(
     mixed_precision="fp16",
     stream=False,
     logger="wandb",
+    ch_8bit_adam=False,
 ):
-    global training_process, tts_api, stop_signal, pipe
+    global training_process, tts_api, stop_signal

-    if tts_api is not None or pipe is not None:
+    if tts_api is not None:
         if tts_api is not None:
             del tts_api
-        if pipe is not None:
-            del pipe
+
         gc.collect()
         torch.cuda.empty_cache()
         tts_api = None
-        pipe = None

     path_project = os.path.join(path_data, dataset_name)

@@ -447,11 +452,10 @@
         f"--dataset_name {dataset_name}"
     )

-    if finetune:
-        cmd += f" --finetune {finetune}"
+    cmd += f" --finetune {finetune}"

     if file_checkpoint_train != "":
-        cmd += f" --file_checkpoint_train {file_checkpoint_train}"
+        cmd += f" --pretrain {file_checkpoint_train}"

     if tokenizer_file != "":
         cmd += f" --tokenizer_path {tokenizer_file}"
@@ -460,7 +464,10 @@

     cmd += f" --log_samples True --logger {logger} "

-    print(cmd)
+    if ch_8bit_adam:
+        cmd += " --bnb_optimizer True "
+
+    print("run command : \n" + cmd + "\n")

     save_settings(
         dataset_name,
@@ -481,6 +488,7 @@
         tokenizer_file,
         mixed_precision,
         logger,
+        ch_8bit_adam,
     )

     try:
@@ -641,27 +649,6 @@ def create_data_project(name, tokenizer_type):
     return gr.update(choices=project_list, value=name)


-def transcribe(file_audio, language="english"):
-    global pipe
-
-    if pipe is None:
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-
-    text_transcribe = pipe(
-        file_audio,
-        chunk_length_s=30,
-        batch_size=128,
-        generate_kwargs={"task": "transcribe", "language": language},
-        return_timestamps=False,
-    )["text"].strip()
-    return text_transcribe
-
-
 def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
     path_project = os.path.join(path_data, name_project)
     path_dataset = os.path.join(path_project, "dataset")
@@ -758,11 +745,9 @@ def get_correct_audio_path(
     # Case 2: If it has a supported extension but is not a full path
     elif has_supported_extension(audio_input) and not os.path.isabs(audio_input):
         file_audio = os.path.join(base_path, audio_input)
-        print("2")

     # Case 3: If only the name is given (no extension and not a full path)
     elif not has_supported_extension(audio_input) and not os.path.isabs(audio_input):
-        print("3")
         for ext in supported_formats:
             potential_file = os.path.join(base_path, f"{audio_input}.{ext}")
             if os.path.exists(potential_file):
@@ -816,9 +801,12 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
             continue

         if duration < 1 or duration > 25:
-            error_files.append([file_audio, "duration < 1 or > 25 "])
+            if duration > 25:
+                error_files.append([file_audio, "duration > 25 sec"])
+            if duration < 1:
+                error_files.append([file_audio, "duration < 1 sec "])
             continue
-        if len(text) < 4:
+        if len(text) < 3:
             error_files.append([file_audio, "very small text len 3"])
             continue

@@ -1208,7 +1196,9 @@ def get_random_sample_infer(project_name):
     )


-def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema):
+def infer(
+    project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
+):
     global last_checkpoint, last_device, tts_api, last_ema

     if not os.path.isfile(file_checkpoint):
@@ -1238,8 +1228,17 @@ def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe
     print("update >> ", device_test, file_checkpoint, use_ema)

     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-        tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
-        return f.name, tts_api.device
+        tts_api.infer(
+            gen_text=gen_text.lower().strip(),
+            ref_text=ref_text.lower().strip(),
+            ref_file=ref_audio,
+            nfe_step=nfe_step,
+            file_wave=f.name,
+            speed=speed,
+            seed=seed,
+            remove_silence=remove_silence,
+        )
+        return f.name, tts_api.device, str(tts_api.seed)


 def check_finetune(finetune):
@@ -1506,6 +1505,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
     ```"""
     )
     ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
+
     bt_prepare = bt_create = gr.Button("Prepare")
     txt_info_prepare = gr.Text(label="Info", value="")
     txt_vocab_prepare = gr.Text(label="Vocab", value="")
@@ -1560,6 +1560,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         last_per_steps = gr.Number(label="Last per Steps", value=100)

     with gr.Row():
+        ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
         mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
         cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
         start_button = gr.Button("Start Training")
@@ -1584,6 +1585,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         tokenizer_filev,
         mixed_precisionv,
         cd_loggerv,
+        ch_8bit_adamv,
     ) = load_settings(projects_selelect)
     exp_name.value = exp_namev
     learning_rate.value = learning_ratev
@@ -1602,6 +1604,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
     tokenizer_file.value = tokenizer_filev
     mixed_precision.value = mixed_precisionv
     cd_logger.value = cd_loggerv
+    ch_8bit_adam.value = ch_8bit_adamv

     ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
     txt_info_train = gr.Text(label="Info", value="")
@@ -1660,6 +1663,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             mixed_precision,
             ch_stream,
             cd_logger,
+            ch_8bit_adam,
         ],
         outputs=[txt_info_train, start_button, stop_button],
     )
@@ -1732,12 +1736,17 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle

     with gr.TabItem("Test Model"):
         gr.Markdown("""```plaintext
-SOS: Check the use_ema setting (True or False) for your model to see what works best for you.
+SOS: Check the use_ema setting (True or False) for your model to see what works best for you. Use seed -1 for random.
 ```""")
         exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
         list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)

-        nfe_step = gr.Number(label="NFE Step", value=32)
+        with gr.Row():
+            nfe_step = gr.Number(label="NFE Step", value=32)
+            speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
+            seed = gr.Number(label="Seed", value=-1, minimum=-1)
+            remove_silence = gr.Checkbox(label="Remove Silence")
+
         ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
         with gr.Row():
             cm_checkpoint = gr.Dropdown(
@@ -1757,14 +1766,27 @@ SOS: Check the use_ema setting (True or False) for your model to see what works

         with gr.Row():
             txt_info_gpu = gr.Textbox("", label="Device")
+            seed_info = gr.Text(label="Seed :")
             check_button_infer = gr.Button("Infer")

         gen_audio = gr.Audio(label="Audio Gen", type="filepath")

         check_button_infer.click(
             fn=infer,
-            inputs=[cm_project, cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema],
-            outputs=[gen_audio, txt_info_gpu],
+            inputs=[
+                cm_project,
+                cm_checkpoint,
+                exp_name,
+                ref_text,
+                ref_audio,
+                gen_text,
+                nfe_step,
+                ch_use_ema,
+                speed,
+                seed,
+                remove_silence,
+            ],
+            outputs=[gen_audio, txt_info_gpu, seed_info],
         )

         bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
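With the widened signature, the Test Model handler threads speed, seed, and silence removal through to tts_api.infer and reports the seed actually used. A hedged sketch of driving it directly (all values hypothetical; seed=-1 requests a random seed):

    wav_path, device_used, seed_used = infer(
        project="my_project",
        file_checkpoint="ckpts/my_project/model_last.pt",
        exp_name="F5-TTS",
        ref_text="the reference transcript",
        ref_audio="ref.wav",
        gen_text="text to synthesize",
        nfe_step=32,
        use_ema=True,
        speed=1.0,
        seed=-1,                # -1 -> draw a random seed
        remove_silence=False,
    )
    print(wav_path, device_used, seed_used)  # seed_used echoes str(tts_api.seed)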