ing0 commited on
Commit
d0bd81d
·
1 Parent(s): 6f225e3
Files changed (4) hide show
  1. README.md +5 -2
  2. app.py +67 -40
  3. diffrhythm/infer/infer_utils.py +11 -12
  4. src/DiffRhythm.jpg +0 -0
README.md CHANGED
@@ -1,7 +1,10 @@
1
  ---
2
- title: DiffRhythm
3
  emoji: 🎶
4
- colorFrom: green
 
 
 
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.20.0
 
1
  ---
2
+ title: Di♪♪Rhythm
3
  emoji: 🎶
4
+ tags:
5
+ - music generation
6
+ - diffusion models
7
+ colorFrom: red
8
  colorTo: purple
9
  sdk: gradio
10
  sdk_version: 5.20.0
app.py CHANGED
@@ -14,6 +14,7 @@ from tqdm import tqdm
14
  import random
15
  import numpy as np
16
  import sys
 
17
  from diffrhythm.infer.infer_utils import (
18
  get_reference_latent,
19
  get_lrc_token,
@@ -23,14 +24,17 @@ from diffrhythm.infer.infer_utils import (
23
  )
24
  from diffrhythm.infer.infer import inference
25
 
26
-
27
  device='cuda'
28
  cfm, tokenizer, muq, vae = prepare_model(device)
29
  cfm = torch.compile(cfm)
30
 
31
  @spaces.GPU
32
- def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048, device='cuda'):
33
 
 
 
 
34
  sway_sampling_coef = -1 if steps < 32 else None
35
  lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
36
  style_prompt = get_style_prompt(muq, ref_audio_path)
@@ -115,7 +119,7 @@ def R1_infer2(tags_lyrics, lyrics_input):
115
  css = """
116
  /* 固定文本域高度并强制滚动条 */
117
  .lyrics-scroll-box textarea {
118
- height: 300px !important; /* 固定高度 */
119
  max-height: 500px !important; /* 最大高度 */
120
  overflow-y: auto !important; /* 垂直滚动 */
121
  white-space: pre-wrap; /* 保留换行 */
@@ -131,26 +135,36 @@ css = """
131
  }
132
 
133
  """
 
 
 
 
134
 
135
  with gr.Blocks(css=css) as demo:
136
  # gr.Markdown("<h1 style='text-align: center'>DiffRhythm (谛韵)</h1>")
137
- gr.HTML("""
138
-
139
- <div style="font-size: 2em; font-weight: bold; text-align: center; margin-bottom: 5px">
140
- DiffRhythm (谛韵)
141
- </div>
142
- <div style="display:flex; justify-content: center; column-gap:4px;">
143
- <a href="https://arxiv.org/abs/2503.01183">
144
- <img src='https://img.shields.io/badge/Arxiv-Paper-blue'>
145
- </a>
146
- <a href="https://github.com/ASLP-lab/DiffRhythm">
147
- <img src='https://img.shields.io/badge/GitHub-Repo-green'>
148
- </a>
149
- <a href="https://aslp-lab.github.io/DiffRhythm.github.io/">
150
- <img src='https://img.shields.io/badge/Project-Page-brown'>
151
- </a>
152
- </div>
153
- """)
 
 
 
 
 
 
154
 
155
  with gr.Tabs() as tabs:
156
 
@@ -158,7 +172,18 @@ with gr.Blocks(css=css) as demo:
158
  with gr.Tab("Music Generate", id=0):
159
  with gr.Row():
160
  with gr.Column():
161
- with gr.Accordion("Best Practices Guide", open=False):
 
 
 
 
 
 
 
 
 
 
 
162
  gr.Markdown("""
163
  1. **Lyrics Format Requirements**
164
  - Each line must follow: `[mm:ss.xx]Lyric content`
@@ -173,24 +198,27 @@ with gr.Blocks(css=css) as demo:
173
  - Total timestamps should not exceed 01:35.00 (95 seconds)
174
 
175
  3. **Audio Prompt Requirements**
176
- - Reference audio should be ≥10 seconds for optimal results
 
177
  - Shorter clips may lead to incoherent generation
 
 
 
 
178
  """)
179
- lrc = gr.Textbox(
180
- label="Lrc",
181
- placeholder="Input the full lyrics",
182
- lines=12,
183
- max_lines=50,
184
- elem_classes="lyrics-scroll-box",
185
- value="""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""
186
- )
187
- audio_prompt = gr.Audio(label="Audio Prompt", type="filepath", value="./src/prompt/default.wav")
188
-
189
- with gr.Column():
190
 
191
- lyrics_btn = gr.Button("Submit", variant="primary")
192
  audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
193
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
 
 
 
 
194
  steps = gr.Slider(
195
  minimum=10,
196
  maximum=100,
@@ -201,7 +229,6 @@ with gr.Blocks(css=css) as demo:
201
  elem_id="step_slider"
202
  )
203
  file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
204
-
205
 
206
 
207
  gr.Examples(
@@ -245,8 +272,8 @@ with gr.Blocks(css=css) as demo:
245
 
246
  with gr.Group():
247
  gr.Markdown("### Method 1: Generate from Theme")
248
- theme = gr.Textbox(label="theme", placeholder="Enter song theme, e.g. Love and Heartbreak")
249
- tags_gen = gr.Textbox(label="tags", placeholder="Example: male pop confidence healing")
250
  language = gr.Radio(["zh", "en"], label="Language", value="en")
251
  gen_from_theme_btn = gr.Button("Generate LRC (From Theme)", variant="primary")
252
 
@@ -269,10 +296,10 @@ with gr.Blocks(css=css) as demo:
269
 
270
  with gr.Group(visible=True):
271
  gr.Markdown("### Method 2: Add Timestamps to Lyrics")
272
- tags_lyrics = gr.Textbox(label="tags", placeholder="Example: female ballad piano slow")
273
  lyrics_input = gr.Textbox(
274
  label="Raw Lyrics (without timestamps)",
275
- placeholder="Enter plain lyrics (without timestamps), e.g.:\nYesterday\nAll my troubles...",
276
  lines=10,
277
  max_lines=50,
278
  elem_classes="lyrics-scroll-box"
@@ -326,7 +353,7 @@ with gr.Blocks(css=css) as demo:
326
 
327
  lyrics_btn.click(
328
  fn=infer_music,
329
- inputs=[lrc, audio_prompt, steps, file_type],
330
  outputs=audio_output
331
  )
332
 
 
14
  import random
15
  import numpy as np
16
  import sys
17
+ import base64
18
  from diffrhythm.infer.infer_utils import (
19
  get_reference_latent,
20
  get_lrc_token,
 
24
  )
25
  from diffrhythm.infer.infer import inference
26
 
27
+ MAX_SEED = np.iinfo(np.int32).max
28
  device='cuda'
29
  cfm, tokenizer, muq, vae = prepare_model(device)
30
  cfm = torch.compile(cfm)
31
 
32
  @spaces.GPU
33
+ def infer_music(lrc, ref_audio_path, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
34
 
35
+ if randomize_seed:
36
+ seed = random.randint(0, MAX_SEED)
37
+ torch.manual_seed(seed)
38
  sway_sampling_coef = -1 if steps < 32 else None
39
  lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
40
  style_prompt = get_style_prompt(muq, ref_audio_path)
 
119
  css = """
120
  /* 固定文本域高度并强制滚动条 */
121
  .lyrics-scroll-box textarea {
122
+ height: 405px !important; /* 固定高度 */
123
  max-height: 500px !important; /* 最大高度 */
124
  overflow-y: auto !important; /* 垂直滚动 */
125
  white-space: pre-wrap; /* 保留换行 */
 
135
  }
136
 
137
  """
138
+ def image_to_base64(image_path):
139
+ with open(image_path, "rb") as f:
140
+ return f"data:image/png;base64,{base64.b64encode(f.read()).decode('utf-8')}"
141
+
142
 
143
  with gr.Blocks(css=css) as demo:
144
  # gr.Markdown("<h1 style='text-align: center'>DiffRhythm (谛韵)</h1>")
145
+ gr.HTML(f"""
146
+ <div style="display: flex; align-items: center;">
147
+ <img src='{image_to_base64("./src/DiffRhythm.jpg")}'
148
+ style='width: 200px; height: 40%; display: block; margin: 0 auto 20px;'>
149
+ </div>
150
+
151
+ <div style="flex: 1; text-align: center;">
152
+ <div style="font-size: 2em; font-weight: bold; text-align: center; margin-bottom: 5px">
153
+ Di♪♪Rhythm (谛韵)
154
+ </div>
155
+ <div style="display:flex; justify-content: center; column-gap:4px;">
156
+ <a href="https://arxiv.org/abs/2503.01183">
157
+ <img src='https://img.shields.io/badge/Arxiv-Paper-blue'>
158
+ </a>
159
+ <a href="https://github.com/ASLP-lab/DiffRhythm">
160
+ <img src='https://img.shields.io/badge/GitHub-Repo-green'>
161
+ </a>
162
+ <a href="https://aslp-lab.github.io/DiffRhythm.github.io/">
163
+ <img src='https://img.shields.io/badge/Project-Page-brown'>
164
+ </a>
165
+ </div>
166
+ </div>
167
+ """)
168
 
169
  with gr.Tabs() as tabs:
170
 
 
172
  with gr.Tab("Music Generate", id=0):
173
  with gr.Row():
174
  with gr.Column():
175
+ lrc = gr.Textbox(
176
+ label="Lrc",
177
+ placeholder="Input the full lyrics",
178
+ lines=12,
179
+ max_lines=50,
180
+ elem_classes="lyrics-scroll-box",
181
+ value="""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""
182
+ )
183
+ audio_prompt = gr.Audio(label="Audio Prompt", type="filepath", value="./src/prompt/default.wav")
184
+
185
+ with gr.Column():
186
+ with gr.Accordion("Best Practices Guide", open=True):
187
  gr.Markdown("""
188
  1. **Lyrics Format Requirements**
189
  - Each line must follow: `[mm:ss.xx]Lyric content`
 
198
  - Total timestamps should not exceed 01:35.00 (95 seconds)
199
 
200
  3. **Audio Prompt Requirements**
201
+ - Reference audio should be ≥ 1 second, audio >10 seconds will be randomly clipped into 10 seconds
202
+ - For optimal results, the 10-second clips should be carefully selected
203
  - Shorter clips may lead to incoherent generation
204
+
205
+ 4. **Supported Languages**
206
+ - Chinese and English
207
+ - More languages comming soon
208
  """)
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ lyrics_btn = gr.Button("Generate", variant="primary")
211
  audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
212
  with gr.Accordion("Advanced Settings", open=False):
213
+ seed = gr.Slider(
214
+ label="Seed",
215
+ minimum=0,
216
+ maximum=MAX_SEED,
217
+ step=1,
218
+ value=0,
219
+ )
220
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
221
+
222
  steps = gr.Slider(
223
  minimum=10,
224
  maximum=100,
 
229
  elem_id="step_slider"
230
  )
231
  file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
 
232
 
233
 
234
  gr.Examples(
 
272
 
273
  with gr.Group():
274
  gr.Markdown("### Method 1: Generate from Theme")
275
+ theme = gr.Textbox(label="theme", placeholder="Enter song theme, e.g: Love and Heartbreak")
276
+ tags_gen = gr.Textbox(label="tags", placeholder="Enter song tags, e.g: pop confidence healing")
277
  language = gr.Radio(["zh", "en"], label="Language", value="en")
278
  gen_from_theme_btn = gr.Button("Generate LRC (From Theme)", variant="primary")
279
 
 
296
 
297
  with gr.Group(visible=True):
298
  gr.Markdown("### Method 2: Add Timestamps to Lyrics")
299
+ tags_lyrics = gr.Textbox(label="tags", placeholder="Enter song tags, e.g: ballad piano slow")
300
  lyrics_input = gr.Textbox(
301
  label="Raw Lyrics (without timestamps)",
302
+ placeholder="Enter plain lyrics (without timestamps), e.g:\nYesterday\nAll my troubles...",
303
  lines=10,
304
  max_lines=50,
305
  elem_classes="lyrics-scroll-box"
 
353
 
354
  lyrics_btn.click(
355
  fn=infer_music,
356
+ inputs=[lrc, audio_prompt, seed, randomize_seed, steps, file_type],
357
  outputs=audio_output
358
  )
359
 
diffrhythm/infer/infer_utils.py CHANGED
@@ -58,25 +58,24 @@ def get_style_prompt(model, wav_path):
58
  if ext == '.mp3':
59
  meta = MP3(wav_path)
60
  audio_len = meta.info.length
61
- src_sr = meta.info.sample_rate
62
- elif ext == '.wav':
63
- audio, sr = librosa.load(wav_path, sr=None)
64
- audio_len = librosa.get_duration(y=audio, sr=sr)
65
- src_sr = sr
66
  else:
67
  raise ValueError("Unsupported file format: {}".format(ext))
68
 
69
- assert(audio_len >= 10)
70
 
71
- mid_time = audio_len // 2
72
- start_time = mid_time - 5
73
- wav, sr = librosa.load(wav_path, sr=None, offset=start_time, duration=10)
 
74
 
75
- resampled_wav = librosa.resample(wav, orig_sr=src_sr, target_sr=24000)
76
- resampled_wav = torch.tensor(resampled_wav).unsqueeze(0).to(model.device)
 
77
 
78
  with torch.no_grad():
79
- audio_emb = mulan(wavs = resampled_wav) # [1, 512]
80
 
81
  audio_emb = audio_emb
82
  audio_emb = audio_emb.half()
 
58
  if ext == '.mp3':
59
  meta = MP3(wav_path)
60
  audio_len = meta.info.length
61
+ elif ext in ['.wav', '.flac']:
62
+ audio_len = librosa.get_duration(path=wav_path)
 
 
 
63
  else:
64
  raise ValueError("Unsupported file format: {}".format(ext))
65
 
66
+ assert audio_len >= 1, "Input audio length shorter than 1 second"
67
 
68
+ if audio_len >= 10:
69
+ mid_time = audio_len // 2
70
+ start_time = mid_time - 5
71
+ wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)
72
 
73
+ else:
74
+ wav, _ = librosa.load(wav_path, sr=24000)
75
+ wav = torch.tensor(wav).unsqueeze(0).to(model.device)
76
 
77
  with torch.no_grad():
78
+ audio_emb = mulan(wavs = wav) # [1, 512]
79
 
80
  audio_emb = audio_emb
81
  audio_emb = audio_emb.half()
src/DiffRhythm.jpg CHANGED