cocktailpeanut commited on
Commit
c7f3c38
1 Parent(s): a7f6109
Files changed (1) hide show
  1. app2.py +6 -3
app2.py CHANGED
@@ -47,7 +47,9 @@ speed = 1.0
47
  # fix_duration = 27 # None or float (duration in seconds)
48
  fix_duration = None
49
 
50
- def load_model(exp_name, model_cls, model_cfg, ckpt_step):
 
 
51
  checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
52
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
53
  model = CFM(
@@ -77,8 +79,9 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
77
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
78
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
79
 
80
- F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
81
- E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
82
 
83
  def chunk_text(text, max_chars=200):
84
  chunks = []
 
47
  # fix_duration = 27 # None or float (duration in seconds)
48
  fix_duration = None
49
 
50
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
51
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
52
+
53
  checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
54
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
55
  model = CFM(
 
79
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
81
 
82
+ F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
83
+ E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
84
+
85
 
86
  def chunk_text(text, max_chars=200):
87
  chunks = []