KingNish committed on
Commit ec39241 (verified)
1 Parent(s): dce5b4e

Update app.py

Files changed (1):
  1. app.py  +40 -37
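
The diff below factors the GPU-bound model.generate call out of generate_music and into a standalone requires_cuda function that carries the @spaces.GPU decorator, so a ZeroGPU Space only holds a GPU while that function runs. As a rough, self-contained sketch of that decorator pattern (the checkpoint and helper names here are placeholders, not this Space's actual objects):

```python
# Illustrative sketch of the @spaces.GPU pattern this commit adopts; not the app's code.
import spaces
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; app.py loads its own YuE-style model elsewhere.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)

@spaces.GPU(duration=175)  # a GPU is attached only while this function executes
def run_on_gpu(input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # All CUDA work stays inside the decorated function.
    model.to("cuda")
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
        return model.generate(input_ids=input_ids.to("cuda"), max_new_tokens=max_new_tokens)

def generate(prompt_ids: list[int]) -> torch.Tensor:
    # CPU-side prompt assembly happens outside the GPU-decorated call.
    input_ids = torch.as_tensor(prompt_ids).unsqueeze(0)
    return run_on_gpu(input_ids, max_new_tokens=100)
```

Keeping only the decorated helper on the GPU is what lets the duration=175 budget cover just the generation step rather than the whole generate_music pipeline.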
app.py CHANGED
@@ -119,6 +119,35 @@ def split_lyrics(lyrics: str):
     return structured_lyrics
 
 @spaces.GPU(duration=175)
+def requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
+    """
+    This function wraps the heavy GPU inference that uses torch.autocast and torch.inference_mode.
+    It calls model.generate with the appropriate parameters and returns the generated sequence.
+    """
+    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+        output_seq = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
+            do_sample=True,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=mmtokenizer.eoa,
+            pad_token_id=mmtokenizer.eoa,
+            logits_processor=LogitsProcessorList([
+                BlockTokenRangeProcessor(0, 32002),
+                BlockTokenRangeProcessor(32016, 32016)
+            ]),
+            guidance_scale=guidance_scale,
+            use_cache=True
+        )
+    # If the output does not end with the EOS token, append it.
+    if output_seq[0][-1].item() != mmtokenizer.eoa:
+        tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+        output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+    return output_seq
+
 def generate_music(
     genre_txt=None,
     lyrics_txt=None,
@@ -171,7 +200,7 @@ def generate_music(
 
     for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        guidance_scale = 1.5 if i <= 1 else 1.2  # Guidance scale adjusted based on segment index
+        guidance_scale = 1.5 if i <= 1 else 1.2  # Adjust guidance scale per segment
         if i == 0:
             continue
         if i == 1:
@@ -182,56 +211,30 @@
                 raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
                 raw_codes = raw_codes.transpose(0, 1)
                 raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                # Format audio prompt
                 code_ids = codectool.npy2ids(raw_codes[0])
-                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
-                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                    mmtokenizer.eoa]
-                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                    "[end_of_reference]")
+                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                 head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
             else:
                 head_id = mmtokenizer.tokenize(prompt_texts[0])
             prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
         else:
             prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-
+
         prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
         input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-
-        # Use window slicing in case output sequence exceeds the context of model
+
+        # Window slicing in case the sequence exceeds the model's context length
         max_context = 16384 - max_new_tokens - 1
         if input_ids.shape[-1] > max_context:
             print(
                 f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
             input_ids = input_ids[:, -(max_context):]
-
-        def model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale):
-            """
-            Performs model inference to generate music tokens.
-            """
-            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-                output_seq = model.generate(
-                    input_ids=input_ids,
-                    max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
-                    do_sample=True,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
-                    eos_token_id=mmtokenizer.eoa,
-                    pad_token_id=mmtokenizer.eoa,
-                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                    guidance_scale=guidance_scale,
-                    use_cache=True
-                )
-                if output_seq[0][-1].item() != mmtokenizer.eoa:
-                    tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-                    output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-            return output_seq
-
-        output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
-
+
+        # Perform the GPU-heavy inference using the requires_cuda function.
+        output_seq = requires_cuda(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
+
         if i > 1:
             raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
         else:
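
Note: the window-slicing guard above (max_context = 16384 - max_new_tokens - 1) is untouched by this commit but feeds directly into the new requires_cuda call. A minimal standalone sketch of that arithmetic, with a hypothetical clip_to_context helper used only for illustration:

```python
# Minimal sketch of the context-window guard used in the loop above (illustrative only;
# clip_to_context is a hypothetical name, not a function in app.py).
import torch

def clip_to_context(input_ids: torch.Tensor, max_new_tokens: int, context_len: int = 16384) -> torch.Tensor:
    # Reserve room for the new tokens plus one EOS token, then keep only the most
    # recent tokens if the running sequence is longer than that budget.
    max_context = context_len - max_new_tokens - 1
    if input_ids.shape[-1] > max_context:
        input_ids = input_ids[:, -max_context:]
    return input_ids

# Example: a 20,000-token running output with 3,000 new tokens requested
ids = torch.zeros(1, 20000, dtype=torch.long)
print(clip_to_context(ids, max_new_tokens=3000).shape)  # torch.Size([1, 13383])
```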