Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -119,6 +119,35 @@ def split_lyrics(lyrics: str):
     return structured_lyrics

 @spaces.GPU(duration=175)
+def requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
+    """
+    This function wraps the heavy GPU inference that uses torch.autocast and torch.inference_mode.
+    It calls model.generate with the appropriate parameters and returns the generated sequence.
+    """
+    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+        output_seq = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=mmtokenizer.eoa,
+            pad_token_id=mmtokenizer.eoa,
+            logits_processor=LogitsProcessorList([
+                BlockTokenRangeProcessor(0, 32002),
+                BlockTokenRangeProcessor(32016, 32016)
+            ]),
+            guidance_scale=guidance_scale,
+            use_cache=True
+        )
+    # If the output does not end with the EOS token, append it.
+    if output_seq[0][-1].item() != mmtokenizer.eoa:
+        tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+        output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+    return output_seq
+
 def generate_music(
     genre_txt=None,
     lyrics_txt=None,
@@ -171,7 +200,7 @@ def generate_music(

     for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        guidance_scale = 1.5 if i <= 1 else 1.2 #
+        guidance_scale = 1.5 if i <= 1 else 1.2 # Adjust guidance scale per segment
         if i == 0:
             continue
         if i == 1:
@@ -182,56 +211,30 @@ def generate_music(
                 raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                 raw_codes = raw_codes.transpose(0, 1)
                 raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                # Format audio prompt
                 code_ids = codectool.npy2ids(raw_codes[0])
-                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
-                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                    mmtokenizer.eoa]
-                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                    "[end_of_reference]")
+                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                 head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
             else:
                 head_id = mmtokenizer.tokenize(prompt_texts[0])
             prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
         else:
             prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-
+
         prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
         input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-
-        #
+
+        # Window slicing in case the sequence exceeds the model's context length
         max_context = 16384 - max_new_tokens - 1
         if input_ids.shape[-1] > max_context:
             print(
                 f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
             input_ids = input_ids[:, -(max_context):]
-
-
-
-
-        """
-        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-            output_seq = model.generate(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
-                do_sample=True,
-                top_p=top_p,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                guidance_scale=guidance_scale,
-                use_cache=True
-            )
-        if output_seq[0][-1].item() != mmtokenizer.eoa:
-            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-        return output_seq
-
-        output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
-
+
+        # Perform the GPU-heavy inference using the requires_cuda function.
+        output_seq = requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
+
         if i > 1:
             raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
         else:
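Background on the change: on a ZeroGPU Space, a GPU is attached only while a function decorated with @spaces.GPU is executing, so this commit isolates the heavy model.generate call in the decorated requires_cuda helper and leaves prompt assembly, context-window slicing, and post-processing in generate_music on the CPU. Below is a minimal sketch of that pattern, not the Space's actual code; the model object and every name other than spaces.GPU are placeholders.

import spaces
import torch

# `model` stands in for the causal LM the Space loads at startup (placeholder).

@spaces.GPU(duration=175)  # A GPU is allocated only for calls into this function.
def run_on_gpu(input_ids, max_new_tokens):
    # Heavy inference runs here, under inference_mode and fp16 autocast.
    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
        return model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)

def generate(prompt_ids):
    # CPU-side work (tokenizing, slicing, concatenation) stays outside the
    # decorated function so it does not count against the GPU duration budget.
    input_ids = torch.as_tensor(prompt_ids).unsqueeze(0)
    return run_on_gpu(input_ids, max_new_tokens=3000)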
119 |
return structured_lyrics
|
120 |
|
121 |
@spaces.GPU(duration=175)
|
122 |
+
def requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
|
123 |
+
"""
|
124 |
+
This function wraps the heavy GPU inference that uses torch.autocast and torch.inference_mode.
|
125 |
+
It calls model.generate with the appropriate parameters and returns the generated sequence.
|
126 |
+
"""
|
127 |
+
with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
|
128 |
+
output_seq = model.generate(
|
129 |
+
input_ids=input_ids,
|
130 |
+
max_new_tokens=max_new_tokens,
|
131 |
+
min_new_tokens=100, # Keep min_new_tokens to avoid short generations
|
132 |
+
do_sample=True,
|
133 |
+
top_p=top_p,
|
134 |
+
temperature=temperature,
|
135 |
+
repetition_penalty=repetition_penalty,
|
136 |
+
eos_token_id=mmtokenizer.eoa,
|
137 |
+
pad_token_id=mmtokenizer.eoa,
|
138 |
+
logits_processor=LogitsProcessorList([
|
139 |
+
BlockTokenRangeProcessor(0, 32002),
|
140 |
+
BlockTokenRangeProcessor(32016, 32016)
|
141 |
+
]),
|
142 |
+
guidance_scale=guidance_scale,
|
143 |
+
use_cache=True
|
144 |
+
)
|
145 |
+
# If the output does not end with the EOS token, append it.
|
146 |
+
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
147 |
+
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|
148 |
+
output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
|
149 |
+
return output_seq
|
150 |
+
|
151 |
def generate_music(
|
152 |
genre_txt=None,
|
153 |
lyrics_txt=None,
|
|
|
200 |
|
201 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
|
202 |
section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
|
203 |
+
guidance_scale = 1.5 if i <= 1 else 1.2 # Adjust guidance scale per segment
|
204 |
if i == 0:
|
205 |
continue
|
206 |
if i == 1:
|
|
|
211 |
raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
|
212 |
raw_codes = raw_codes.transpose(0, 1)
|
213 |
raw_codes = raw_codes.cpu().numpy().astype(np.int16)
|
|
|
214 |
code_ids = codectool.npy2ids(raw_codes[0])
|
215 |
+
audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
|
216 |
+
audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
|
217 |
+
sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
|
|
|
|
|
218 |
head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
|
219 |
else:
|
220 |
head_id = mmtokenizer.tokenize(prompt_texts[0])
|
221 |
prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
|
222 |
else:
|
223 |
prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
|
224 |
+
|
225 |
prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
|
226 |
input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
|
227 |
+
|
228 |
+
# Window slicing in case the sequence exceeds the model's context length
|
229 |
max_context = 16384 - max_new_tokens - 1
|
230 |
if input_ids.shape[-1] > max_context:
|
231 |
print(
|
232 |
f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
|
233 |
input_ids = input_ids[:, -(max_context):]
|
234 |
+
|
235 |
+
# Perform the GPU-heavy inference using the requires_cuda function.
|
236 |
+
output_seq = requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
|
237 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
if i > 1:
|
239 |
raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
|
240 |
else:
|
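The generate call also constrains sampling with a LogitsProcessorList built from two BlockTokenRangeProcessor instances. That class is defined elsewhere in app.py and is not part of this diff; as an assumption about its behavior, a processor of this kind typically masks a contiguous range of token ids so they can never be sampled, roughly as sketched below. The class name and the exact boundary convention here are illustrative, not the Space's implementation.

import torch
from transformers import LogitsProcessor

class BlockTokenRange(LogitsProcessor):
    """Illustrative: forbid sampling of token ids from start_id through end_id."""

    def __init__(self, start_id: int, end_id: int):
        self.start_id = start_id
        self.end_id = end_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Logits set to -inf get zero probability after softmax, so these ids are never chosen.
        scores[:, self.start_id:self.end_id + 1] = float("-inf")
        return scores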