Commit 3b03b8f · 1 Parent(s): 38ff6b6
Update llama/m2ugen.py

llama/m2ugen.py CHANGED (+12 -13)
@@ -332,7 +332,7 @@ class M2UGen(nn.Module):
                 sub_x = all_layer_hidden_states.mean(-2).unsqueeze(0)
                 aggoutputs += sub_x
             aggoutputs /= len(all_inputs)
-            sub_x = self.mu_mert_agg(aggoutputs.to(
+            sub_x = self.mu_mert_agg(aggoutputs.to("cuda:0")).squeeze()
             del aggoutputs
             xs.append(sub_x)
         x = torch.stack(xs, dim=0)
@@ -345,7 +345,7 @@ class M2UGen(nn.Module):
             with torch.no_grad():
                 outputs = self.vit_model(**inputs)
                 last_hidden_states = outputs.last_hidden_state
-            sub_x = self.iu_vit_agg(last_hidden_states.to(
+            sub_x = self.iu_vit_agg(last_hidden_states.to("cuda:0")).squeeze()
             xs.append(sub_x)
         return torch.stack(xs, dim=0)

@@ -356,7 +356,7 @@ class M2UGen(nn.Module):
             with torch.no_grad():
                 outputs = self.vivit_model(**inputs)
                 last_hidden_states = outputs.last_hidden_state
-            sub_x = self.iu_vivit_agg(last_hidden_states.to(
+            sub_x = self.iu_vivit_agg(last_hidden_states.to("cuda:0")).squeeze()
             xs.append(sub_x)
         return torch.stack(xs, dim=0)

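All three hunks above make the same change: the per-modality aggregation heads (mu_mert_agg, iu_vit_agg, iu_vivit_agg) now receive inputs moved explicitly to "cuda:0", which suggests those heads are pinned to that device. Below is a minimal sketch of the audio-side pattern only, with an illustrative agg_head standing in for mu_mert_agg and assumed sizes (25 MERT layers, 1024-dim features); it is not the repo's exact module.

import torch
import torch.nn as nn

# agg_head is a hypothetical stand-in for self.mu_mert_agg, pinned to cuda:0
# as the diff assumes; 25 layers x 1024 dims are assumed MERT sizes.
mert_layers, feat_dim = 25, 1024
agg_head = nn.Conv1d(mert_layers, 1, kernel_size=1).to("cuda:0")

def aggregate_chunks(all_layer_states):
    # each element: (mert_layers, time, feat_dim) hidden states for one audio chunk
    agg = torch.zeros(1, mert_layers, feat_dim)
    for states in all_layer_states:
        agg += states.mean(-2).unsqueeze(0)   # average over time, keep the layer axis
    agg /= len(all_layer_states)
    # move the accumulated features to the GPU that owns the aggregation head
    return agg_head(agg.to("cuda:0")).squeeze()

# usage (requires a CUDA device): two fake chunks of layer-wise features
chunks = [torch.randn(mert_layers, 150, feat_dim) for _ in range(2)]
print(aggregate_chunks(chunks).shape)   # torch.Size([1024])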
@@ -489,20 +489,21 @@ class M2UGen(nn.Module):
     @torch.inference_mode()
     def forward_inference(self, tokens, start_pos: int, audio_feats=None, image_feats=None, video_feats=None):
         _bsz, seqlen = tokens.shape
-        h = self.llama.tok_embeddings(tokens)
-        freqs_cis = self.llama.freqs_cis.to(
+        h = self.llama.tok_embeddings(tokens).to("cuda:1")
+        freqs_cis = self.llama.freqs_cis.to("cuda:1")
         freqs_cis = freqs_cis[start_pos:start_pos + seqlen]

-        feats = torch.zeros((1, 1, 4096)).to(
+        feats = torch.zeros((1, 1, 4096)).to("cuda:0")
         if audio_feats is not None:
             feats += audio_feats
         if video_feats is not None:
             feats += video_feats
         if image_feats is not None:
             feats += image_feats
-
+        feats = feats.to("cuda:1")
+
         mask = None
-        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=
+        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:1")
         mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

         music_output_embedding = []
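The forward_inference hunk splits work across two GPUs: token embeddings and rotary frequencies are moved to "cuda:1" (where the LLaMA layers presumably sit), the summed modality features are first built on "cuda:0" and then hopped over, and the causal attention mask is allocated directly on "cuda:1". A rough sketch of just that input-preparation step, under the same two-GPU assumption (the model itself is omitted):

import torch

def prepare_inputs(seqlen, start_pos, audio_feats=None, image_feats=None, video_feats=None):
    # sum whichever modality features are present, on the encoder GPU
    feats = torch.zeros((1, 1, 4096), device="cuda:0")
    for f in (audio_feats, video_feats, image_feats):
        if f is not None:
            feats = feats + f
    feats = feats.to("cuda:1")   # hand off to the GPU running the transformer

    # causal mask: -inf strictly above the diagonal, offset by the cache position
    mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:1")
    mask = torch.triu(mask, diagonal=start_pos + 1)
    return feats, mask

# usage (requires two CUDA devices): an 8-token prompt decoded from position 0
feats, mask = prepare_inputs(seqlen=8, start_pos=0)
print(mask[0, 0])   # zeros on/below the diagonal, -inf above: each token sees only itself and earlier tokens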
@@ -603,11 +604,10 @@ class M2UGen(nn.Module):
     @torch.inference_mode()
     def generate_music(self, embeddings, audio_length_in_s, music_caption):
         gen_prefix = ''.join([f'[AUD{i}]' for i in range(len(self.audio_tokens))])
-        gen_prefx_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(
-            self.device)
+        gen_prefx_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda:1")
         gen_prefix_embs = self.llama.tok_embeddings(gen_prefx_ids)
         if self.music_decoder == "audioldm2":
-            gen_emb = self.output_projector(embeddings.float().to("cuda"), gen_prefix_embs).squeeze(dim=0) / 10
+            gen_emb = self.output_projector(embeddings.float().to("cuda:1"), gen_prefix_embs).squeeze(dim=0) / 10
             prompt_embeds, generated_prompt_embeds = gen_emb[:, :128 * 1024], gen_emb[:, 128 * 1024:]
             prompt_embeds = prompt_embeds.reshape(prompt_embeds.shape[0], 128, 1024)
             generated_prompt_embeds = generated_prompt_embeds.reshape(generated_prompt_embeds.shape[0], 8, 768)
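In the AudioLDM2 branch, each flat projector output is unpacked into the two conditioning tensors the AudioLDM2 pipeline accepts: prompt_embeds (128 x 1024) and generated_prompt_embeds (8 x 768). A worked example of the split, with a random stand-in for the projector output:

import torch

# flat width = 128*1024 + 8*768 = 137216, per the reshape targets in the diff
n = 1                                   # leading dimension of the projector output (illustrative)
gen_emb = torch.randn(n, 128 * 1024 + 8 * 768)

prompt_embeds = gen_emb[:, :128 * 1024].reshape(n, 128, 1024)
generated_prompt_embeds = gen_emb[:, 128 * 1024:].reshape(n, 8, 768)

print(prompt_embeds.shape, generated_prompt_embeds.shape)
# torch.Size([1, 128, 1024]) torch.Size([1, 8, 768])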
@@ -623,8 +623,7 @@ class M2UGen(nn.Module):
             print("Generating Music...")
             gen_emb = 0.1 * self.output_projector(embeddings.float().to("cuda"), gen_prefix_embs) / 10
             gen_inputs = self.generation_processor(text=music_caption, padding='max_length',
-                                                   max_length=128, truncation=True, return_tensors="pt").to(
-                self.device)
+                                                   max_length=128, truncation=True, return_tensors="pt").to("cuda:1")
             #gen_emb = self.generation_model.generate(**gen_inputs, guidance_scale=3.5, encoder_only=True)
             audio_outputs = self.generation_model.generate(**gen_inputs, guidance_scale=3.5,
                                                            max_new_tokens=int(256 / 5 * audio_length_in_s))
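For the MusicGen branch, the requested duration is converted into a decoding budget of 256 tokens per 5 seconds of audio, i.e. 51.2 tokens per second (roughly MusicGen's 50 Hz EnCodec frame rate). A small worked example of that arithmetic:

def max_new_tokens_for(audio_length_in_s: float) -> int:
    # 256 tokens correspond to 5 seconds of audio in the diff above
    return int(256 / 5 * audio_length_in_s)

for seconds in (5, 10, 30):
    print(seconds, "s ->", max_new_tokens_for(seconds), "tokens")
# 5 s -> 256 tokens, 10 s -> 512 tokens, 30 s -> 1536 tokens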