crypto-code committed on
Commit 3b03b8f · 1 Parent(s): 38ff6b6

Update llama/m2ugen.py

Files changed (1)
  1. llama/m2ugen.py +12 -13
llama/m2ugen.py CHANGED
@@ -332,7 +332,7 @@ class M2UGen(nn.Module):
                 sub_x = all_layer_hidden_states.mean(-2).unsqueeze(0)
                 aggoutputs += sub_x
             aggoutputs /= len(all_inputs)
-            sub_x = self.mu_mert_agg(aggoutputs.to(self.device)).squeeze()
+            sub_x = self.mu_mert_agg(aggoutputs.to("cuda:0")).squeeze()
             del aggoutputs
             xs.append(sub_x)
         x = torch.stack(xs, dim=0)
@@ -345,7 +345,7 @@ class M2UGen(nn.Module):
             with torch.no_grad():
                 outputs = self.vit_model(**inputs)
                 last_hidden_states = outputs.last_hidden_state
-            sub_x = self.iu_vit_agg(last_hidden_states.to(self.device)).squeeze()
+            sub_x = self.iu_vit_agg(last_hidden_states.to("cuda:0")).squeeze()
             xs.append(sub_x)
         return torch.stack(xs, dim=0)
 
@@ -356,7 +356,7 @@ class M2UGen(nn.Module):
             with torch.no_grad():
                 outputs = self.vivit_model(**inputs)
                 last_hidden_states = outputs.last_hidden_state
-            sub_x = self.iu_vivit_agg(last_hidden_states.to(self.device)).squeeze()
+            sub_x = self.iu_vivit_agg(last_hidden_states.to("cuda:0")).squeeze()
             xs.append(sub_x)
         return torch.stack(xs, dim=0)
 
@@ -489,20 +489,21 @@ class M2UGen(nn.Module):
     @torch.inference_mode()
     def forward_inference(self, tokens, start_pos: int, audio_feats=None, image_feats=None, video_feats=None):
         _bsz, seqlen = tokens.shape
-        h = self.llama.tok_embeddings(tokens)
-        freqs_cis = self.llama.freqs_cis.to(h.device)
+        h = self.llama.tok_embeddings(tokens).to("cuda:1")
+        freqs_cis = self.llama.freqs_cis.to("cuda:1")
         freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
 
-        feats = torch.zeros((1, 1, 4096)).to(self.device)
+        feats = torch.zeros((1, 1, 4096)).to("cuda:0")
         if audio_feats is not None:
             feats += audio_feats
         if video_feats is not None:
             feats += video_feats
         if image_feats is not None:
             feats += image_feats
-
+        feats = feats.to("cuda:1")
+
         mask = None
-        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
+        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:1")
         mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
 
         music_output_embedding = []
@@ -603,11 +604,10 @@ class M2UGen(nn.Module):
     @torch.inference_mode()
     def generate_music(self, embeddings, audio_length_in_s, music_caption):
         gen_prefix = ''.join([f'[AUD{i}]' for i in range(len(self.audio_tokens))])
-        gen_prefx_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(
-            self.device)
+        gen_prefx_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda:1")
         gen_prefix_embs = self.llama.tok_embeddings(gen_prefx_ids)
         if self.music_decoder == "audioldm2":
-            gen_emb = self.output_projector(embeddings.float().to("cuda"), gen_prefix_embs).squeeze(dim=0) / 10
+            gen_emb = self.output_projector(embeddings.float().to("cuda:1"), gen_prefix_embs).squeeze(dim=0) / 10
             prompt_embeds, generated_prompt_embeds = gen_emb[:, :128 * 1024], gen_emb[:, 128 * 1024:]
             prompt_embeds = prompt_embeds.reshape(prompt_embeds.shape[0], 128, 1024)
             generated_prompt_embeds = generated_prompt_embeds.reshape(generated_prompt_embeds.shape[0], 8, 768)
@@ -623,8 +623,7 @@ class M2UGen(nn.Module):
             print("Generating Music...")
             gen_emb = 0.1 * self.output_projector(embeddings.float().to("cuda"), gen_prefix_embs) / 10
             gen_inputs = self.generation_processor(text=music_caption, padding='max_length',
-                                                   max_length=128, truncation=True, return_tensors="pt").to(
-                self.device)
+                                                   max_length=128, truncation=True, return_tensors="pt").to("cuda:1")
             #gen_emb = self.generation_model.generate(**gen_inputs, guidance_scale=3.5, encoder_only=True)
             audio_outputs = self.generation_model.generate(**gen_inputs, guidance_scale=3.5,
                                                            max_new_tokens=int(256 / 5 * audio_length_in_s))
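
Every hunk above applies the same change: the earlier self.device / h.device lookups are replaced with hard-coded device strings, pinning the encoder-side aggregation (mu_mert_agg, iu_vit_agg, iu_vivit_agg) to "cuda:0" and everything feeding the LLaMA decoder and music generator to "cuda:1", with feats explicitly moved across before the decoder pass. The sketch below is a minimal illustration of that two-GPU handoff, not the actual M2UGen code: encoder and decoder are stand-in nn.Linear layers, and it assumes a machine with at least two CUDA devices.

import torch
import torch.nn as nn

# Stand-in stages: the real model uses MERT/ViT/ViViT encoders and a LLaMA decoder.
encoder = nn.Linear(1024, 4096).to("cuda:0")   # encoder stage pinned to GPU 0
decoder = nn.Linear(4096, 4096).to("cuda:1")   # decoder stage pinned to GPU 1

@torch.inference_mode()
def forward_two_gpu(x):
    # 1. Run the modality encoder on cuda:0.
    feats = encoder(x.to("cuda:0"))
    # 2. Hand the activations over to the decoder's device,
    #    mirroring `feats = feats.to("cuda:1")` in the commit.
    feats = feats.to("cuda:1")
    # 3. Continue the forward pass on cuda:1.
    return decoder(feats)

if torch.cuda.device_count() >= 2:
    out = forward_two_gpu(torch.randn(1, 1, 1024))
    print(out.device)  # cuda:1

Hard-coding device strings this way trades portability (single-GPU or CPU runs will now fail on these lines) for an explicit split of encoder and decoder memory across two cards.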