Spaces:
Running
on
Zero
Running
on
Zero
import copy | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers.models.llama.modeling_llama import LlamaDecoderLayer | |
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer | |
from omni_speech.constants import IGNORE_INDEX | |
import os
import sys

# Autograd anomaly detection is a debugging aid that slows every backward
# pass considerably. It stays enabled by default to preserve existing
# behavior; set DETECT_ANOMALY=0 to turn it off.
if os.environ.get('DETECT_ANOMALY', '1') != '0':
    torch.autograd.set_detect_anomaly(True)

# CosyVoice is optional: only SpeechGeneratorCosyVoice needs it. The checkout
# location can be overridden with the COSYVOICE_PATH environment variable.
COSYVOICE_PATH = os.environ.get('COSYVOICE_PATH', '/mnt/lzy/LLaMA-Omni/CosyVoice/')
sys.path.append(COSYVOICE_PATH)
try:
    from cosyvoice.cli.cosyvoice import CosyVoice
except Exception as exc:  # best effort: this module stays importable without CosyVoice
    print(f'CosyVoice not found ({type(exc).__name__}: {exc})')

# Kernel size of the optional temporal Conv1d in SpeechGeneratorCTCQwen;
# a positive value enables it, the default -1 disables it.
if 'SPEECH_GEN_CONV_KERNEL' in os.environ:
    SPEECH_GEN_CONV_KERNEL = int(os.environ['SPEECH_GEN_CONV_KERNEL'])
    print(f'Using SPEECH_GEN_CONV_KERNEL={SPEECH_GEN_CONV_KERNEL}')
else:
    SPEECH_GEN_CONV_KERNEL = -1

# Setting DISTILL_EMBEDDING to any value (even "0") selects the
# embedding-distillation branch of SpeechGeneratorCosyVoice.forward.
DISTILL_EMBEDDING = 'DISTILL_EMBEDDING' in os.environ
if DISTILL_EMBEDDING:
    print('DISTILL_EMBEDDING is set.')
def lengths_to_padding_mask(lens):
    """Build a boolean padding mask from a vector of sequence lengths.

    Returns a ``(lens.size(0), max(lens))`` bool tensor on ``lens.device``.
    Row ``b`` is False for its first ``lens[b]`` columns (real data) and
    True for the remaining columns (padding).
    """
    batch_size = lens.size(0)
    width = torch.max(lens).item()
    columns = torch.arange(width, device=lens.device)
    # (1, width) >= (batch, 1) broadcasts to (batch, width).
    return columns[None, :] >= lens.view(batch_size, 1)
def _uniform_assignment(src_lens, tgt_lens): | |
tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device) | |
ratio = tgt_lens / src_lens | |
index_t = (tgt_indices / ratio.view(-1, 1)).long() | |
return index_t | |
class SpeechGeneratorCTC(nn.Module):
    """Non-autoregressive speech-unit generator trained with CTC (Llama layers).

    Per-token hidden states are uniformly upsampled along time, projected to the
    decoder width, run through a small stack of ``LlamaDecoderLayer``s and
    classified into ``unit_vocab_size + 1`` classes. The extra last class
    (index ``unit_vocab_size``) is the CTC blank.
    """

    def __init__(self, config):
        super().__init__()
        # ctc_decoder_config: comma-separated
        # "n_layers,n_dims,n_heads,n_inter_dims" with one wrapper character on
        # each side (presumably brackets) that [1:-1] strips.
        n_layers, n_dims, n_heads, n_inter_dims = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        # One extra output class: index unit_vocab_size is the CTC blank.
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)

    def upsample(self, reps, tgt_units=None):
        """Stretch a list of (T_i, D) tensors to a padded (B, L, D) batch.

        Each sample is stretched to ``T_i * upsample_factor`` frames, or to its
        number of non-IGNORE_INDEX target units if that is larger (CTC needs at
        least one frame per target unit).

        Returns:
            ``(copied_reps, attention_mask, position_ids)``. ``attention_mask``
            is True on real frames, and padded frames of ``copied_reps`` are 0.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # padding_mask.size(1) == max(up_lens); using it avoids iterating a
        # device tensor with the Python builtin max (one host sync per element).
        position_ids = (
            torch.arange(padding_mask.size(1), device=copied_reps.device)
            .unsqueeze(0)
            .expand(len(reps), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def _decode(self, hidden_states, attention_mask, position_ids):
        """Run the decoder stack and return float32 per-frame log-probabilities."""
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        return F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the CTC loss, normalized by the total number of target units.

        Args:
            tgt_reps: per-sample hidden states; rows whose ``labels`` entry is
                IGNORE_INDEX are dropped before upsampling.
            labels: per-sample label tensors aligned with ``tgt_reps``.
            tgt_units: (B, U) unit ids, IGNORE_INDEX marking non-targets.
        """
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        ctc_lprobs = self._decode(hidden_states, attention_mask, position_ids)
        ctc_lens = attention_mask.long().sum(dim=-1)
        tgt_valid = tgt_units.ne(IGNORE_INDEX)
        ctc_tgt_lens = tgt_valid.long().sum(dim=-1)
        # Row-major boolean selection yields the concatenated targets that
        # ctc_loss expects, and does not assume that the padding is trailing or
        # that U equals the longest target.
        ctc_tgt_flat = tgt_units[tgt_valid]
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size,
        )
        # Guard against a batch without any target unit (division by zero
        # would turn the loss and its gradients into inf/NaN).
        return ctc_loss / max(ctc_tgt_lens.sum().item(), 1)

    def predict(self, tgt_reps):
        """Greedy CTC frame labels for one sample; padded frames get the blank id."""
        hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
        hidden_states = self.input_proj(hidden_states)
        ctc_lprobs = self._decode(hidden_states, attention_mask, position_ids)
        return ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
class SpeechGeneratorCTCQwen(nn.Module):
    """CTC speech-unit generator built from ``Qwen2DecoderLayer``s.

    Same pipeline as the Llama variant, with a separate KV-head count. When
    SPEECH_GEN_CONV_KERNEL > 0, each sample's reps are padded with learned
    frames, projected and passed through a temporal ``Conv1d`` *before*
    upsampling. Otherwise the projection is applied after upsampling.
    """

    def __init__(self, config):
        super().__init__()
        # ctc_decoder_config: comma-separated
        # "n_layers,n_dims,n_heads,n_inter_dims,n_kv_heads" with one wrapper
        # character on each side (presumably brackets) that [1:-1] strips.
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_kv_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        # One extra output class: index unit_vocab_size is the CTC blank.
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
        if SPEECH_GEN_CONV_KERNEL > 0:
            # padding=0: an odd kernel keeps each sample's length unchanged
            # (T + 2 * (K // 2) - K + 1 == T); an even kernel adds one frame.
            self.temporal_conv = nn.Conv1d(n_dims, n_dims, SPEECH_GEN_CONV_KERNEL, padding=0)
            # The pad frames are concatenated with the raw (pre-projection)
            # reps, so they need input_proj's input width, config.hidden_size.
            # This matches the former (K // 2, n_dims) shape whenever
            # n_dims == hidden_size, the only case in which concatenation worked.
            pad_shape = (SPEECH_GEN_CONV_KERNEL // 2, config.hidden_size)
            self.learnable_pad_left = nn.Parameter(torch.zeros(pad_shape))
            self.learnable_pad_right = nn.Parameter(torch.zeros(pad_shape))

    def upsample(self, reps, tgt_units=None):
        """Stretch a list of (T_i, D) tensors to a padded (B, L, D) batch.

        Each sample is stretched to ``T_i * upsample_factor`` frames, or to its
        number of non-IGNORE_INDEX target units if that is larger.

        Returns:
            ``(copied_reps, attention_mask, position_ids)``; ``attention_mask``
            is True on real frames.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # padding_mask.size(1) == max(up_lens), without per-element host syncs.
        position_ids = (
            torch.arange(padding_mask.size(1), device=copied_reps.device)
            .unsqueeze(0)
            .expand(len(reps), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def _encode_rep(self, rep):
        """Prepare one sample's reps for upsampling.

        With the temporal conv enabled, pads with the learned frames, projects
        to the decoder width and convolves over time. Otherwise returns ``rep``
        unchanged, and ``_project`` handles the projection after upsampling.
        """
        if SPEECH_GEN_CONV_KERNEL <= 0:
            return rep
        rep = torch.cat([self.learnable_pad_left, rep, self.learnable_pad_right], dim=0)
        rep = self.input_proj(rep)[None]
        return self.temporal_conv(rep.transpose(1, 2)).transpose(1, 2)[0]

    def _project(self, hidden_states):
        """Apply input_proj unless _encode_rep already did (temporal-conv path)."""
        if SPEECH_GEN_CONV_KERNEL > 0:
            return hidden_states
        return self.input_proj(hidden_states)

    def _decode(self, hidden_states, attention_mask, position_ids):
        """Run the decoder stack and return float32 per-frame log-probabilities."""
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        return F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the CTC loss, normalized by the total number of target units.

        Args:
            tgt_reps: per-sample hidden states; rows whose ``labels`` entry is
                IGNORE_INDEX are dropped.
            labels: per-sample label tensors aligned with ``tgt_reps``.
            tgt_units: (B, U) unit ids, IGNORE_INDEX marking non-targets.
        """
        tgt_label_reps = [
            self._encode_rep(tgt_rep[label != IGNORE_INDEX])
            for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self._project(hidden_states)
        ctc_lprobs = self._decode(hidden_states, attention_mask, position_ids)
        ctc_lens = attention_mask.long().sum(dim=-1)
        tgt_valid = tgt_units.ne(IGNORE_INDEX)
        ctc_tgt_lens = tgt_valid.long().sum(dim=-1)
        # Concatenated targets in row-major order, as ctc_loss expects.
        ctc_tgt_flat = tgt_units[tgt_valid]
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size,
        )
        # Guard against a batch without any target unit.
        return ctc_loss / max(ctc_tgt_lens.sum().item(), 1)

    def predict(self, tgt_reps):
        """Greedy CTC frame labels for one sample; padded frames get the blank id.

        Goes through the same pre-upsampling path as ``forward``, so models
        trained with the temporal conv are decoded consistently.
        """
        hidden_states, attention_mask, position_ids = self.upsample([self._encode_rep(tgt_reps)])
        hidden_states = self._project(hidden_states)
        ctc_lprobs = self._decode(hidden_states, attention_mask, position_ids)
        return ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
class SpeechGeneratorCEQwen(nn.Module):
    """Speech-unit generator trained with a shifted cross-entropy objective.

    Per-token hidden states are stretched (``upsample_factor == 1``, so only up
    to the target unit count) and decoded by ``Qwen2DecoderLayer``s. The state
    at frame ``t`` is trained to predict unit ``t + 1``. This class has no
    ``predict`` method.
    """

    def __init__(self, config):
        super().__init__()
        # ctc_decoder_config: comma-separated
        # "n_layers,n_dims,n_heads,n_inter_dims,n_kv_heads" with one wrapper
        # character on each side (presumably brackets) that [1:-1] strips.
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_kv_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = 1
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)

    def upsample(self, reps, tgt_units=None):
        """Stretch a list of (T_i, D) tensors to a padded (B, L, D) batch.

        Each sample is stretched to ``T_i * upsample_factor`` frames, or to its
        number of non-IGNORE_INDEX target units if that is larger.

        Returns:
            ``(copied_reps, attention_mask, position_ids)``; ``attention_mask``
            is True on real frames.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # padding_mask.size(1) == max(up_lens), without per-element host syncs.
        position_ids = (
            torch.arange(padding_mask.size(1), device=copied_reps.device)
            .unsqueeze(0)
            .expand(len(reps), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the mean next-unit cross-entropy over non-IGNORE_INDEX units.

        Raises:
            ValueError: if the upsampled sequence does not line up with
                ``tgt_units`` (the shift requires equal batch and time sizes).
        """
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        # Frame t is scored against unit t + 1.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_states.size(-1))
        shift_labels = tgt_units[..., 1:].contiguous().reshape(-1)
        if shift_labels.size(0) != shift_hidden_states.size(0):
            raise ValueError(
                f'shape mismatch: hidden_states {tuple(hidden_states.shape[:-1])} '
                f'vs tgt_units {tuple(tgt_units.shape)}'
            )
        logits = self.output_proj(shift_hidden_states).float()
        # Explicit ignore_index: the default (-100) only matches IGNORE_INDEX
        # by convention.
        loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
        return loss_fct(logits, shift_labels)
# class SpeechGeneratorCosyVoice(nn.Module): | |
# def __init__(self, config): | |
# super().__init__() | |
# self.input_proj = nn.Sequential( | |
# nn.Linear(config.hidden_size, 1024), | |
# nn.GELU(), | |
# nn.Linear(1024, 512) | |
# ) | |
# self.cosyvoice1 = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, fp16=False) | |
# self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True) | |
# self.llm = self.cosyvoice1.model.llm | |
# if DISTILL_EMBEDDING: | |
# self.criterion = nn.CosineEmbeddingLoss() | |
# def forward(self, tgt_reps, labels, answer): | |
# tgt_label_reps = [] | |
# batch_speech_tokens = [] | |
# embeddings = [] | |
# target_embeddings = [] | |
# if DISTILL_EMBEDDING: | |
# for tgt_rep, label, ans in zip(tgt_reps, labels, answer): | |
# # make all label id in [151644,151645,198] to IGNORE_INDEX | |
# label[label == 151644] = IGNORE_INDEX | |
# label[label == 151645] = IGNORE_INDEX | |
# label[label == 198] = IGNORE_INDEX | |
# tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) | |
# normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True) | |
# tts_text_token_all = [] | |
# for norm_text in normalized_text: | |
# tts_text_token, tts_text_token_len = self.cosyvoice1.frontend._extract_text_token(norm_text) | |
# tts_text_token_all.append(tts_text_token) | |
# tts_text_token_all = torch.cat(tts_text_token_all, dim=0) | |
# target_embedding = self.cosyvoice1.model.llm.text_embedding(tts_text_token) | |
# target_embeddings.append(target_embedding) | |
# import pdb;pdb.set_trace() | |
# tgt_label_reps = torch.stack(tgt_label_reps) | |
# target_embeddings = torch.stack(target_embeddings).squeeze(1) | |
# hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512) | |
# target_embeddings = target_embeddings.reshape(-1, 512) | |
# loss = self.criterion(hidden_states, target_embeddings, torch.ones(hidden_states.size(0)).to(hidden_states.device)) | |
# else: | |
# for tgt_rep, label, ans in zip(tgt_reps, labels, answer): | |
# # make all label id in [151644,151645,198] to IGNORE_INDEX | |
# label[label == 151644] = IGNORE_INDEX | |
# label[label == 151645] = IGNORE_INDEX | |
# label[label == 198] = IGNORE_INDEX | |
# tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) | |
# speech_token = self.cosyvoice.inference_label(ans, '英文女', stream=False) | |
# speech_tokens = [] | |
# for i,j in enumerate(speech_token): | |
# speech_tokens.append(j['tts_speech_token'].squeeze(0)) | |
# speech_tokens.append(torch.tensor([0])) | |
# speech_tokens = torch.cat(speech_tokens, dim=0) | |
# if speech_tokens.size(0) > 1: | |
# speech_tokens = speech_tokens[:-1] | |
# batch_speech_tokens.append(speech_tokens) | |
# embedding = self.cosyvoice.frontend.frontend_embedding('英文女') | |
# embeddings.append(embedding['llm_embedding'].squeeze(0)) | |
# tgt_label_reps = torch.stack(tgt_label_reps) | |
# batch_speech_token = torch.stack(batch_speech_tokens) | |
# embeddings = torch.stack(embeddings) | |
# hidden_states = self.input_proj(tgt_label_reps) | |
# batch = {'text_feature': hidden_states, 'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(hidden_states.size(0)), | |
# 'speech_token': batch_speech_token, 'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(hidden_states.size(0)), | |
# 'embedding': embeddings} | |
# output = self.llm.forward_ours(batch, 'cuda') | |
# loss = output['loss'] | |
# return loss | |
class SpeechGeneratorCosyVoice(nn.Module):
    """Speech generator that delegates to a pretrained CosyVoice-300M-SFT model.

    NOTE(review): ``forward`` reads ``self.input_proj``, ``self.llm``,
    ``self.criterion`` and (distillation branch) ``self.cosyvoice1``, none of
    which ``__init__`` creates. They must be attached to the instance before
    ``forward`` is called, otherwise it raises ``AttributeError``.
    """

    # Label ids excluded from the text reps in addition to IGNORE_INDEX.
    # Presumably Qwen2's <|im_start|>, <|im_end|> and "\n" tokens - confirm.
    _EXCLUDED_TOKEN_IDS = (151644, 151645, 198)

    def __init__(self, config):
        super().__init__()
        self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)

    def _select_reps(self, tgt_rep, label):
        """Keep rows of ``tgt_rep`` whose label is not IGNORE_INDEX or an excluded id.

        ``label`` is left untouched, so the caller's tensor is not mutated.
        """
        keep = label != IGNORE_INDEX
        for token_id in self._EXCLUDED_TOKEN_IDS:
            keep = keep & (label != token_id)
        return tgt_rep[keep]

    def forward(self, tgt_reps, labels, answer):
        """Return the training loss for a batch of answers.

        With DISTILL_EMBEDDING set, the projected reps are pulled towards
        CosyVoice's text embeddings with ``self.criterion``. Otherwise speech
        tokens are synthesized for each answer and ``self.llm.forward_ours``
        computes the loss.

        NOTE(review): ``torch.stack`` below requires every sample to keep the
        same number of reps (and speech tokens); batches with ragged lengths
        fail.
        """
        tgt_label_reps = []
        if DISTILL_EMBEDDING:
            target_embeddings = []
            for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
                tgt_label_reps.append(self._select_reps(tgt_rep, label))
                normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True)
                tts_text_token_all = []
                for norm_text in normalized_text:
                    tts_text_token, _ = self.cosyvoice1.frontend._extract_text_token(norm_text)
                    tts_text_token_all.append(tts_text_token)
                tts_text_token_all = torch.cat(tts_text_token_all, dim=0)
                # NOTE(review): only the LAST sentence's tokens are embedded;
                # tts_text_token_all looks like the intended input, but the
                # right concatenation axis depends on _extract_text_token's
                # output shape - confirm before switching.
                target_embedding = self.cosyvoice1.model.llm.text_embedding(tts_text_token)
                target_embeddings.append(target_embedding)
            tgt_label_reps = torch.stack(tgt_label_reps)
            target_embeddings = torch.stack(target_embeddings).squeeze(1)
            hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512)
            target_embeddings = target_embeddings.reshape(-1, 512)
            targets = torch.ones(hidden_states.size(0), device=hidden_states.device)
            loss = self.criterion(hidden_states, target_embeddings, targets)
        else:
            batch_speech_tokens = []
            embeddings = []
            for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
                tgt_label_reps.append(self._select_reps(tgt_rep, label))
                # Speaker id '英文女' ("English female").
                speech_token = self.cosyvoice.inference_label(ans, '英文女', stream=False)
                speech_tokens = []
                for chunk in speech_token:
                    tokens = chunk['tts_speech_token'].squeeze(0)
                    speech_tokens.append(tokens)
                    # Token 0 separates consecutive segments; created on the
                    # tokens' device so torch.cat does not mix devices.
                    speech_tokens.append(torch.tensor([0], device=tokens.device))
                speech_tokens = torch.cat(speech_tokens, dim=0)
                if speech_tokens.size(0) > 1:
                    speech_tokens = speech_tokens[:-1]  # drop the trailing separator
                batch_speech_tokens.append(speech_tokens)
                embedding = self.cosyvoice.frontend.frontend_embedding('英文女')
                embeddings.append(embedding['llm_embedding'].squeeze(0))
            tgt_label_reps = torch.stack(tgt_label_reps)
            batch_speech_token = torch.stack(batch_speech_tokens)
            embeddings = torch.stack(embeddings)
            hidden_states = self.input_proj(tgt_label_reps)
            batch_size = hidden_states.size(0)
            batch = {
                'text_feature': hidden_states,
                'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(batch_size),
                'speech_token': batch_speech_token,
                'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(batch_size),
                'embedding': embeddings,
            }
            output = self.llm.forward_ours(batch, 'cuda')
            loss = output['loss']
        return loss