Spaces: Running on Zero

cpu test

- app.py +24 -9
- diffrhythm/infer/infer.py +6 -5

app.py
CHANGED
@@ -22,16 +22,13 @@ from diffrhythm.infer.infer_utils import (
 )
 from diffrhythm.infer.infer import inference
 
-device='cuda'
+device='cpu'
 cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
-def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
-
-
-    # print(lrc_list)
-
-    # return "./gift_of_the_world.wav"
+def infer_music(lrc, ref_audio_path, steps, sway_sampling_coef_bool, max_frames=2048, device='cpu'):
+
+    sway_sampling_coef = -1 if sway_sampling_coef_bool else None
     lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
     style_prompt = get_style_prompt(muq, ref_audio_path)
     negative_style_prompt = get_negative_style_prompt(device)
@@ -43,6 +40,8 @@ def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
         duration=max_frames,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
+        steps=steps,
+        sway_sampling_coef=sway_sampling_coef,
         start_time=start_time
     )
     return generated_song
@@ -150,6 +149,22 @@ with gr.Blocks(css=css) as demo:
             audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")
 
         with gr.Column():
+            steps = gr.Slider(
+                minimum=10,
+                maximum=40,
+                value=32,
+                step=1,
+                label="Diffusion Steps",
+                interactive=True,
+                elem_id="step_slider"
+            )
+            sway_sampling_coef_bool = gr.Radio(
+                choices=[("False", False), ("True", True)],
+                label="Use sway_sampling_coef",
+                value=False,
+                interactive=True,
+                elem_classes="horizontal-radio"
+            )
             lyrics_btn = gr.Button("Submit", variant="primary")
             audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
 
@@ -210,7 +225,7 @@ with gr.Blocks(css=css) as demo:
             [01:24.20]Your laughter spins aurora threads
             [01:28.65]Weaving dawn through featherbed"""]
         ],
-        inputs=[lrc],
+        inputs=[lrc],
         label="Lrc Examples",
         examples_per_page=2
     )
@@ -306,7 +321,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt],
+        inputs=[lrc, audio_prompt, steps, sway_sampling_coef_bool],
         outputs=audio_output
     )
 
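Note: Gradio passes `inputs` to the callback positionally, which is why the new `steps` and `sway_sampling_coef_bool` components are listed before the defaulted `max_frames`/`device` parameters of `infer_music`. A minimal sketch of the same wiring pattern, with a hypothetical stub (`infer_music_stub`) standing in for the real model call, assuming Gradio's tuple-choice Radio semantics that the diff itself uses:

import gradio as gr

def infer_music_stub(lrc, ref_audio_path, steps, sway_sampling_coef_bool):
    # Mirror the diff: the Radio yields a bool, mapped to the sampler's
    # coefficient (-1 enables sway sampling, None keeps the schedule uniform).
    sway_sampling_coef = -1 if sway_sampling_coef_bool else None
    return f"steps={steps}, sway_sampling_coef={sway_sampling_coef}"

with gr.Blocks() as demo:
    lrc = gr.Textbox(label="Lrc")
    audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")
    steps = gr.Slider(minimum=10, maximum=40, value=32, step=1, label="Diffusion Steps")
    sway_sampling_coef_bool = gr.Radio(
        choices=[("False", False), ("True", True)], value=False,
        label="Use sway_sampling_coef")
    result = gr.Textbox(label="Result")
    btn = gr.Button("Submit")
    # The order of inputs must match the callback's positional parameters.
    btn.click(fn=infer_music_stub,
              inputs=[lrc, audio_prompt, steps, sway_sampling_coef_bool],
              outputs=result)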
diffrhythm/infer/infer.py
CHANGED
@@ -72,7 +72,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
             y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, start_time):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time):
     # import pdb; pdb.set_trace()
     with torch.inference_mode():
         generated, _ = cfm_model.sample(
@@ -81,8 +81,9 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             duration=duration,
             style_prompt=style_prompt,
             negative_style_prompt=negative_style_prompt,
-            steps=32,
+            steps=steps,
             cfg_strength=4.0,
+            sway_sampling_coef=sway_sampling_coef,
             start_time=start_time
         )
 
@@ -100,10 +101,10 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lrc-path', type=str, default="
-    parser.add_argument('--ref-audio-path', type=str, default="
+    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc") # lyrics of target song
+    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3") # reference audio as style prompt for target song
     parser.add_argument('--audio-length', type=int, default=95) # length of target song
-    parser.add_argument('--output-dir', type=str, default="
+    parser.add_argument('--output-dir', type=str, default="example/output")
     args = parser.parse_args()
 
     device = 'cuda'
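For background on the new sway_sampling_coef argument: in F5-TTS-style flow-matching samplers (which cfm_model.sample appears to follow, though this diff does not show its internals), a negative coefficient warps the uniform timestep grid toward t = 0, so more of the step budget is spent early in generation. A sketch of that warp under that assumption; the helper name and exact formula are illustrative, not code from this repo:

import torch
from typing import Optional

def sway_timesteps(steps: int, sway_sampling_coef: Optional[float]) -> torch.Tensor:
    """Uniform t in [0, 1], optionally warped by sway sampling.

    With coef = -1 (what the UI's "True" maps to), points crowd toward t = 0,
    spending more sampling steps early; None leaves the schedule uniform,
    matching the diff's "False" branch.
    """
    t = torch.linspace(0, 1, steps + 1)
    if sway_sampling_coef is not None:
        # F5-TTS-style warp: t' = t + s * (cos(pi/2 * t) - 1 + t)
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    return t

With the new argparse defaults, running python diffrhythm/infer/infer.py with no flags should pick up example/eg.lrc and example/eg.mp3 and write output under example/output.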