ing0 committed
Commit 4e97955 · Parent: b96e750
diffrhythm/infer/infer.py CHANGED
@@ -72,7 +72,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
             y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, start_time):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, start_time, steps):
     # import pdb; pdb.set_trace()
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
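The hunk is truncated before the body of cfm_model.sample(, so how the new steps argument is consumed is not shown here; presumably it sets the number of sampling steps in the CFM solver. A minimal sketch of calling the updated signature (the placeholder inputs below are assumptions for illustration, not from this commit):

import torch
from diffrhythm.infer.infer import inference
from diffrhythm.infer.infer_utils import prepare_model, get_reference_latent, get_negative_style_prompt

device = "cuda" if torch.cuda.is_available() else "cpu"
cfm, tokenizer, muq, vae = prepare_model(device)

max_frames = 2048  # assumed frame budget; not specified in this diff
output = inference(
    cfm_model=cfm,
    vae_model=vae,
    cond=get_reference_latent(device, max_frames),
    text=lyrics_tokens,            # tokenized lyrics, prepared elsewhere in the pipeline
    duration=max_frames,           # assumed to be measured in latent frames
    style_prompt=style_embedding,  # MuQ-MuLan style embedding, prepared elsewhere
    negative_style_prompt=get_negative_style_prompt(device),
    start_time=0.0,                # assumed
    steps=32,                      # new in this commit: sampling step count
)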
diffrhythm/infer/infer_utils.py CHANGED
@@ -6,14 +6,14 @@ from muq import MuQMuLan
 from mutagen.mp3 import MP3
 import os
 import numpy as np
-
+from huggingface_hub import hf_hub_download
 from diffrhythm.model import DiT, CFM
 
 
 def prepare_model(device):
     # prepare cfm model
-    dit_ckpt_path = "/home/node59_tmpdata3/hkchen/music_opensource/dit_model_dpo_normal.pt"
-    dit_config_path = "/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/config/diffrhythm-1b.json"
+    dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
+    dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
     with open(dit_config_path) as f:
         model_config = json.load(f)
     dit_model_cls = DiT
@@ -33,7 +33,8 @@ def prepare_model(device):
     muq = muq.to(device).eval()
 
     # prepare vae
-    vae = torch.jit.load("/home/node59_tmpdata3/hkchen/F5-TTS-V0/infer/vae_infer.pt").to(device)
+    vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
+    vae = torch.jit.load(vae_ckpt_path).to(device)
 
     return cfm, tokenizer, muq, vae
 
@@ -43,7 +44,7 @@ def get_reference_latent(device, max_frames):
     return torch.zeros(1, max_frames, 64).to(device)
 
 def get_negative_style_prompt(device):
-    file_path = "/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/vocal.npy"
+    file_path = "./prompt/negative_prompt.npy"
     vocal_stlye = np.load(file_path)
 
     vocal_stlye = torch.from_numpy(vocal_stlye).to(device) # [1, 512]
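Replacing the absolute /home paths with hf_hub_download makes the checkpoints resolve on any machine: the file is fetched from the Hub on first use, cached locally (by default under ~/.cache/huggingface/hub), and later calls return the cached path without re-downloading. A standalone check of the new VAE path, with repo_id and filename taken from the diff above (map_location="cpu" is an added precaution for machines without a GPU, not part of this commit):

import torch
from huggingface_hub import hf_hub_download

vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
print(vae_ckpt_path)  # local cache path returned by the Hub client
vae = torch.jit.load(vae_ckpt_path, map_location="cpu").eval()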
prompt/negative_prompt.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cb7d74eb7a8eda12acb8247b21d373928301db8a8cb0db480d341799fed3ce5
+size 2176
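The commit checks in only the Git LFS pointer; the real 2176-byte payload is fetched by git lfs pull (or automatically on clone when git-lfs is installed). That size is consistent with a float32 array of shape [1, 512] (512 * 4 = 2048 data bytes plus a 128-byte .npy header), matching the [1, 512] comment in get_negative_style_prompt. A quick sanity check after fetching:

import numpy as np

negative_style = np.load("./prompt/negative_prompt.npy")
print(negative_style.shape, negative_style.dtype)  # expected (1, 512) float32; inferred from the size, not verified here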