ing0 committed on
Commit ccebb03 · 1 Parent(s): 4e97955
Files changed (2)
  1. app.py +24 -9
  2. diffrhythm/infer/infer.py +6 -5
app.py CHANGED
@@ -22,16 +22,13 @@ from diffrhythm.infer.infer_utils import (
 )
 from diffrhythm.infer.infer import inference
 
-device='cuda'
+device='cpu'
 cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
-def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
-
-    # lrc_list = lrc.split("\n")
-    # print(lrc_list)
-
-    # return "./gift_of_the_world.wav"
+def infer_music(lrc, ref_audio_path, steps, sway_sampling_coef_bool, max_frames=2048, device='cpu'):
+
+    sway_sampling_coef = -1 if sway_sampling_coef_bool else None
     lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
     style_prompt = get_style_prompt(muq, ref_audio_path)
     negative_style_prompt = get_negative_style_prompt(device)
@@ -43,6 +40,8 @@ def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
         duration=max_frames,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
+        steps=steps,
+        sway_sampling_coef=sway_sampling_coef,
         start_time=start_time
     )
     return generated_song
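The new toggle maps True to a sway coefficient of -1 and False to None (no warp) before the value reaches inference(). In F5-TTS-style flow-matching samplers, which DiffRhythm's CFM appears to inherit (an assumption here, the sampler itself is not part of this diff), a negative coefficient bends the uniform ODE timestep grid toward t=0, spending more of the fixed step budget early, where the trajectory changes fastest. A minimal sketch of that warp:

import torch

def sway_timesteps(steps, sway_sampling_coef=None):
    # Sketch of an F5-TTS-style sway-warped ODE schedule (assumed convention,
    # not code from this repo).
    t = torch.linspace(0.0, 1.0, steps + 1)  # uniform grid on [0, 1]
    if sway_sampling_coef is not None:
        # s = -1 gives t' = 1 - cos(pi*t/2): monotone, denser near t = 0.
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    return t

print(sway_timesteps(8))      # evenly spaced
print(sway_timesteps(8, -1))  # clustered near t = 0

Under this convention, s = -1 puts half of a 32-step grid below t ≈ 0.3.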
@@ -150,6 +149,22 @@ with gr.Blocks(css=css) as demo:
             audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")
 
         with gr.Column():
+            steps = gr.Slider(
+                minimum=10,
+                maximum=40,
+                value=32,
+                step=1,
+                label="Diffusion Steps",
+                interactive=True,
+                elem_id="step_slider"
+            )
+            sway_sampling_coef_bool = gr.Radio(
+                choices=[("False", False), ("True", True)],
+                label="Use sway_sampling_coef",
+                value=False,
+                interactive=True,
+                elem_classes="horizontal-radio"
+            )
             lyrics_btn = gr.Button("Submit", variant="primary")
             audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
 
@@ -210,7 +225,7 @@ with gr.Blocks(css=css) as demo:
             [01:24.20]Your laughter spins aurora threads
             [01:28.65]Weaving dawn through featherbed"""]
         ],
-        inputs=[lrc],  # bind only to the lyrics input
+        inputs=[lrc],
         label="Lrc Examples",
         examples_per_page=2
     )
@@ -306,7 +321,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt],
+        inputs=[lrc, audio_prompt, steps, sway_sampling_coef_bool],
         outputs=audio_output
     )
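Gradio passes component values into the handler positionally, in `inputs` order, so [lrc, audio_prompt, steps, sway_sampling_coef_bool] must line up with infer_music's signature. A self-contained toy of the same pattern (hypothetical handler and labels, not the app's real ones):

import gradio as gr

def run(lyrics, steps, use_sway):
    # Stand-in for infer_music: arguments arrive in `inputs` order.
    return f"steps={steps}, sway_sampling_coef={-1 if use_sway else None}"

with gr.Blocks() as demo:
    lyrics = gr.Textbox(label="Lrc")
    steps = gr.Slider(minimum=10, maximum=40, value=32, step=1,
                      label="Diffusion Steps")
    use_sway = gr.Radio(choices=[("False", False), ("True", True)],
                        value=False, label="Use sway_sampling_coef")
    out = gr.Textbox(label="Result")
    gr.Button("Submit").click(fn=run, inputs=[lyrics, steps, use_sway],
                              outputs=out)

# demo.launch()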
 
 
diffrhythm/infer/infer.py CHANGED
@@ -72,7 +72,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
         y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, start_time, steps):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time):
     # import pdb; pdb.set_trace()
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
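Note the parameter reorder: start_time now trails the new steps and sway_sampling_coef. This only stays compatible because the caller updated in app.py passes everything by keyword; keyword call sites are order-independent, as this toy illustration (hypothetical function, for illustration only) shows:

def sample_api(cond, steps, sway_sampling_coef, start_time):
    return (cond, steps, sway_sampling_coef, start_time)

# Keyword arguments survive any reordering of the signature:
assert sample_api(cond="c", start_time=0.0, steps=32,
                  sway_sampling_coef=-1) == ("c", 32, -1, 0.0)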
@@ -81,8 +81,9 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             duration=duration,
             style_prompt=style_prompt,
             negative_style_prompt=negative_style_prompt,
-            steps=32,
+            steps=steps,
             cfg_strength=4.0,
+            sway_sampling_coef=sway_sampling_coef,
             start_time=start_time
         )
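steps becomes caller-controlled while cfg_strength stays pinned at 4.0. For context: classifier-free guidance blends the conditional prediction with an unconditional one (here, the negative style prompt branch) at every step. A generic sketch of the standard formula, assuming DiffRhythm follows the usual convention (its sampler is not shown in this diff):

import torch

def cfg_blend(v_cond, v_uncond, cfg_strength=4.0):
    # Classifier-free guidance: extrapolate away from the unconditional
    # prediction; larger cfg_strength means stronger prompt adherence.
    return v_uncond + cfg_strength * (v_cond - v_uncond)

v_c, v_u = torch.randn(2, 8), torch.randn(2, 8)
print(cfg_blend(v_c, v_u).shape)  # torch.Size([2, 8])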
 
@@ -100,10 +101,10 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lrc-path', type=str, default="/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/eg.lrc") # lyrics of target song
-    parser.add_argument('--ref-audio-path', type=str, default="/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/eg.mp3") # reference audio as style prompt for target song
+    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc") # lyrics of target song
+    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3") # reference audio as style prompt for target song
     parser.add_argument('--audio-length', type=int, default=95) # length of target song
-    parser.add_argument('--output-dir', type=str, default="/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/output")
+    parser.add_argument('--output-dir', type=str, default="example/output")
     args = parser.parse_args()
 
     device = 'cuda'
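The defaults drop one machine's absolute /home/... layout for paths relative to the working directory, so the example only resolves when launched from the folder containing example/. A hypothetical alternative (not part of this commit) that is robust to the launch directory:

from pathlib import Path

HERE = Path(__file__).resolve().parent  # folder containing this script

# e.g. parser.add_argument('--lrc-path', type=str,
#                          default=str(HERE / "example" / "eg.lrc"))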
 