# Source: Ola/ola/model/speech_generator/speech_generator.py
# Last commit: 1938217 ("update space") by dongyh20 — 22.5 kB
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from omni_speech.constants import IGNORE_INDEX
# NOTE(review): anomaly detection is switched on globally at import time.
# PyTorch documents it as a debugging-only mode that slows execution; it is
# kept here to preserve this module's import-time behavior, but consider
# removing it (or gating it behind an env var) for real training runs.
torch.autograd.set_detect_anomaly(True)

# CosyVoice is an optional dependency loaded from a local checkout. Only
# SpeechGeneratorCosyVoice uses it, so a failed import is reported and
# tolerated instead of aborting the whole module import.
try:
    import sys
    sys.path.append('/mnt/lzy/LLaMA-Omni/CosyVoice/')
    from cosyvoice.cli.cosyvoice import CosyVoice
except Exception:
    # `except Exception` rather than a bare `except:` so Ctrl-C / SystemExit
    # raised during the import still propagate. Binding the name keeps later
    # references from failing with a NameError.
    CosyVoice = None
    print('CosyVoice not found')

import os

# Kernel size of the optional temporal Conv1d used by SpeechGeneratorCTCQwen;
# values <= 0 (the default -1 when the variable is unset) mean no conv is built.
if 'SPEECH_GEN_CONV_KERNEL' in os.environ:
    SPEECH_GEN_CONV_KERNEL = int(os.environ['SPEECH_GEN_CONV_KERNEL'])
    print(f'Using SPEECH_GEN_CONV_KERNEL={SPEECH_GEN_CONV_KERNEL}')
else:
    SPEECH_GEN_CONV_KERNEL = -1

# The mere presence of the variable (whatever its value, even "0") enables
# the embedding-distillation branch of SpeechGeneratorCosyVoice.
DISTILL_EMBEDDING = 'DISTILL_EMBEDDING' in os.environ
if DISTILL_EMBEDDING:
    print('DISTILL_EMBEDDING is set.')
def lengths_to_padding_mask(lens):
    """Return a boolean mask that is True at padded positions.

    Args:
        lens: 1-D integer tensor with one sequence length per batch row.

    Returns:
        Bool tensor of shape ``(len(lens), max(lens))`` on ``lens.device``
        whose entry ``[b, t]`` is True exactly when ``t >= lens[b]``.
    """
    longest = torch.max(lens).item()
    columns = torch.arange(longest).to(lens.device)
    # (1, longest) vs (batch, 1) broadcasts to (batch, longest).
    return columns.view(1, longest) >= lens.view(lens.size(0), 1)
def _uniform_assignment(src_lens, tgt_lens):
tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device)
ratio = tgt_lens / src_lens
index_t = (tgt_indices / ratio.view(-1, 1)).long()
return index_t
class SpeechGeneratorCTC(nn.Module):
    """Non-autoregressive speech-unit decoder trained with CTC (Llama layers).

    Selected hidden states are stretched in time by ``ctc_upsample_factor``,
    projected to the decoder width, run through a stack of
    ``LlamaDecoderLayer``s and scored by a linear head over
    ``unit_vocab_size + 1`` classes. Class ``unit_vocab_size`` is the CTC blank.

    ``config.ctc_decoder_config`` is a comma-separated string wrapped in one
    leading and one trailing character (both stripped), holding
    ``n_layers, n_dims, n_heads, n_inter_dims``.
    """

    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        # The decoder layers get a copy of the LM config resized to the CTC
        # decoder's dimensions (plain multi-head attention: kv heads == heads).
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_heads
        _config.intermediate_size = n_inter_dims
        # NOTE(review): the layers receive the 2-D boolean keep-mask from
        # `upsample` as attention_mask; this presumably relies on the
        # flash-attention-2 code path — confirm before changing it.
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        # +1 output class: index unit_vocab_size is the CTC blank.
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)

    def upsample(self, reps, tgt_units=None):
        """Stretch each sequence in ``reps`` and pad them to a common length.

        A sequence of length ``L`` becomes ``L * upsample_factor`` frames, or
        its target-unit count when ``tgt_units`` is given and that is larger.
        Frames repeat source rows as chosen by ``_uniform_assignment``.

        Args:
            reps: list of per-sample tensors, presumably ``(L_i, D)``.
            tgt_units: optional ``(B, T)`` unit ids padded with IGNORE_INDEX.

        Returns:
            ``(hidden, keep_mask, position_ids)``. ``hidden`` is ``(B, T_up, D)``
            with zeros at padded frames. ``keep_mask`` is a bool ``(B, T_up)``
            mask, True at real frames. ``position_ids`` is ``(B, T_up)`` and
            holds ``0..T_up-1``.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            # Never produce fewer frames than target units, so the CTC input is
            # never shorter than its target.
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        # Padded frames point at row 0 so the gather stays in range; they are
        # zeroed right after.
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # Built on the target device from the padded width, instead of
        # iterating the lengths tensor element by element with Python's max().
        position_ids = (
            torch.arange(padding_mask.size(1), device=copied_reps.device)
            .unsqueeze(0)
            .expand(len(reps), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def _run_layers(self, hidden_states, attention_mask, position_ids):
        """Run the decoder stack and return logits over ``unit_vocab_size + 1`` classes."""
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        return self.output_proj(hidden_states)

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the summed CTC loss divided by the total number of target units.

        Args:
            tgt_reps: per-sample hidden states; each is indexed with the boolean
                mask ``label != IGNORE_INDEX`` of the matching label.
            labels: per-sample label ids aligned with ``tgt_reps``.
            tgt_units: ``(B, T)`` target unit ids padded with IGNORE_INDEX.
        """
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        ctc_logits = self._run_layers(hidden_states, attention_mask, position_ids)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_lens = attention_mask.long().sum(dim=-1)
        # Take the valid units from their own mask. A prefix mask rebuilt from
        # the lengths has width max(lengths) and breaks (or selects padding)
        # when tgt_units carries extra padding columns.
        ctc_tgt_mask = tgt_units.ne(IGNORE_INDEX)
        ctc_tgt_lens = ctc_tgt_mask.long().sum(dim=-1)
        ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),  # F.ctc_loss expects (T, B, C)
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size,
        )
        # Per-unit average; max(..., 1) avoids dividing by zero when every
        # target sequence is empty.
        return ctc_loss / max(ctc_tgt_lens.sum().item(), 1)

    def predict(self, tgt_reps):
        """Greedy frame-level decoding of a single sequence.

        Args:
            tgt_reps: one sequence of hidden states, presumably ``(L, hidden_size)``.

        Returns:
            ``(1, L * upsample_factor)`` argmax class ids. Blanks and repeats are
            not removed here; frames outside the keep-mask are set to the blank id.
        """
        hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
        hidden_states = self.input_proj(hidden_states)
        ctc_logits = self._run_layers(hidden_states, attention_mask, position_ids)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        return ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
class SpeechGeneratorCTCQwen(nn.Module):
    """Non-autoregressive speech-unit decoder trained with CTC (Qwen2 layers).

    Same pipeline as the Llama variant: stretch, project, decoder stack,
    linear head over ``unit_vocab_size + 1`` classes, where class
    ``unit_vocab_size`` is the CTC blank. The difference is that
    ``ctc_decoder_config`` also carries ``n_kv_heads``.

    When ``SPEECH_GEN_CONV_KERNEL > 0``, each sequence is processed before
    upsampling: it is padded with learnable rows, projected, and passed through
    a temporal Conv1d. ``input_proj`` is then not applied again after
    upsampling. ``forward`` and ``predict`` share this path.
    """

    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        # Copy of the LM config resized to the CTC decoder's dimensions.
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_kv_heads
        _config.intermediate_size = n_inter_dims
        # NOTE(review): the layers receive the 2-D boolean keep-mask from
        # `upsample`; this presumably relies on the flash-attention-2 path.
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        # +1 output class: index unit_vocab_size is the CTC blank.
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
        if SPEECH_GEN_CONV_KERNEL > 0:
            self.temporal_conv = nn.Conv1d(n_dims, n_dims, SPEECH_GEN_CONV_KERNEL, padding=0)
            # K // 2 learned rows on each side. With padding=0 the conv output
            # then has the input length for odd K (one extra frame for even K).
            self.learnable_pad_left = nn.Parameter(torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, n_dims))
            self.learnable_pad_right = nn.Parameter(torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, n_dims))

    def _prepare_rep(self, rep):
        """Apply the optional pre-upsampling conv path to one sequence.

        Returns ``rep`` unchanged when the temporal conv is disabled; otherwise
        it returns the padded, projected and convolved sequence.
        """
        if SPEECH_GEN_CONV_KERNEL <= 0:
            return rep
        # NOTE(review): the pad rows have n_dims columns but are concatenated
        # with rows of config.hidden_size columns *before* input_proj, so this
        # only works when the two widths are equal. Confirm whether the padding
        # was meant to happen after the projection.
        rep = torch.cat([self.learnable_pad_left, rep, self.learnable_pad_right], dim=0)
        rep = self.input_proj(rep)[None]
        return self.temporal_conv(rep.transpose(1, 2)).transpose(1, 2)[0]

    def upsample(self, reps, tgt_units=None):
        """Stretch each sequence in ``reps`` and pad them to a common length.

        A sequence of length ``L`` becomes ``L * upsample_factor`` frames, or
        its target-unit count when ``tgt_units`` is given and that is larger.

        Returns:
            ``(hidden, keep_mask, position_ids)``. ``hidden`` is ``(B, T_up, D)``
            with zeros at padded frames. ``keep_mask`` is a bool mask, True at
            real frames. ``position_ids`` is ``(B, T_up)`` and holds ``0..T_up-1``.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            # Never produce fewer frames than target units.
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        # Padded frames point at row 0 so the gather stays in range; zeroed after.
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # Built on-device from the padded width instead of Python's max() over
        # the lengths tensor.
        position_ids = (
            torch.arange(padding_mask.size(1), device=copied_reps.device)
            .unsqueeze(0)
            .expand(len(reps), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def _encode(self, reps, tgt_units=None):
        """Prepare, upsample and project ``reps``; returns ``(hidden, keep_mask, position_ids)``."""
        reps = [self._prepare_rep(rep) for rep in reps]
        hidden_states, attention_mask, position_ids = self.upsample(reps, tgt_units)
        # Without the conv, the projection happens here. The check is `<= 0`
        # so that a kernel of 0 (no conv built) is still projected.
        if SPEECH_GEN_CONV_KERNEL <= 0:
            hidden_states = self.input_proj(hidden_states)
        return hidden_states, attention_mask, position_ids

    def _run_layers(self, hidden_states, attention_mask, position_ids):
        """Run the decoder stack and return logits over ``unit_vocab_size + 1`` classes."""
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        return self.output_proj(hidden_states)

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the summed CTC loss divided by the total number of target units.

        Args:
            tgt_reps: per-sample hidden states; each is indexed with the boolean
                mask ``label != IGNORE_INDEX`` of the matching label.
            labels: per-sample label ids aligned with ``tgt_reps``.
            tgt_units: ``(B, T)`` target unit ids padded with IGNORE_INDEX.
        """
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self._encode(tgt_label_reps, tgt_units)
        ctc_logits = self._run_layers(hidden_states, attention_mask, position_ids)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_lens = attention_mask.long().sum(dim=-1)
        # Valid units come from their own mask, so extra padding columns in
        # tgt_units are handled correctly.
        ctc_tgt_mask = tgt_units.ne(IGNORE_INDEX)
        ctc_tgt_lens = ctc_tgt_mask.long().sum(dim=-1)
        ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),  # F.ctc_loss expects (T, B, C)
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size,
        )
        # max(..., 1) avoids dividing by zero when every target is empty.
        return ctc_loss / max(ctc_tgt_lens.sum().item(), 1)

    def predict(self, tgt_reps):
        """Greedy frame-level decoding of a single sequence.

        Uses the same pre-upsampling path as ``forward``, including the
        temporal conv when it is enabled, so inference matches training.

        Returns:
            ``(1, T_up)`` argmax class ids. Blanks and repeats are not removed
            here; frames outside the keep-mask are set to the blank id.
        """
        hidden_states, attention_mask, position_ids = self._encode([tgt_reps])
        ctc_logits = self._run_layers(hidden_states, attention_mask, position_ids)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        return ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
class SpeechGeneratorCEQwen(nn.Module):
    """Speech-unit decoder (Qwen2 layers) trained with shifted cross-entropy.

    Selected hidden states are aligned to the unit sequence without extra
    upsampling (``upsample_factor`` is 1; sequences are only stretched up to
    the target length), projected, and run through a Qwen2 decoder stack.
    The hidden state at frame ``t`` is then scored against unit ``t + 1``.
    """

    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = list(map(int, config.ctc_decoder_config[1:-1].split(",")))
        # Copy of the LM config resized to the decoder's dimensions.
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_kv_heads
        _config.intermediate_size = n_inter_dims
        # NOTE(review): the layers receive the 2-D boolean keep-mask from
        # `upsample`; this presumably relies on the flash-attention-2 path.
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = 1
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)

    def upsample(self, reps, tgt_units=None):
        """Stretch each sequence in ``reps`` and pad them to a common length.

        Returns:
            ``(hidden, keep_mask, position_ids)``. ``hidden`` is ``(B, T_up, D)``
            with zeros at padded frames. ``keep_mask`` is a bool mask, True at
            real frames. ``position_ids`` is ``(B, T_up)`` and holds ``0..T_up-1``.
        """
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        # Padded frames point at row 0 so the gather stays in range; zeroed after.
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), reps.size(-1)),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        # Built on-device from the padded width instead of Python's max() over
        # the lengths tensor.
        position_ids = (
            torch.arange(padding_mask.size(1), device=copied_reps.device)
            .unsqueeze(0)
            .expand(len(reps), -1)
        )
        return copied_reps, ~padding_mask, position_ids

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the mean next-unit cross-entropy loss.

        Args:
            tgt_reps: per-sample hidden states; each is indexed with the boolean
                mask ``label != IGNORE_INDEX`` of the matching label.
            labels: per-sample label ids aligned with ``tgt_reps``.
            tgt_units: ``(B, T)`` target unit ids padded with IGNORE_INDEX. ``T``
                must equal the padded upsampled length.

        Raises:
            ValueError: if the shifted hidden states and shifted units differ in
                size.
        """
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        # Shift by one: frame t is scored against unit t + 1.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_states.size(-1))
        shift_labels = tgt_units[..., 1:].contiguous().reshape(-1)
        # A real check rather than `assert`, which is stripped under `python -O`.
        if shift_labels.size(0) != shift_hidden_states.size(0):
            raise ValueError(
                f"tgt_units length does not match the upsampled sequence length: "
                f"{shift_labels.size(0)} shifted labels vs "
                f"{shift_hidden_states.size(0)} shifted hidden states"
            )
        logits = self.output_proj(shift_hidden_states).float()
        # Use the file-wide IGNORE_INDEX explicitly instead of relying on it
        # matching CrossEntropyLoss's default ignore_index.
        loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
        return loss_fct(logits, shift_labels)
class SpeechGeneratorCosyVoice(nn.Module):
    """Speech generator built on a pretrained CosyVoice-300M-SFT model.

    Two CosyVoice instances are loaded from the same directory:

    - ``cosyvoice1`` (no JIT, fp16=False): its frontend tokenizes text and its
      ``model.llm`` is exposed as ``self.llm``.
    - ``cosyvoice`` (JIT, fp16=True): synthesizes target speech tokens and
      speaker embeddings from the answer text.

    NOTE(review): requires the optional ``cosyvoice`` package and the
    pretrained model directory relative to the working directory.
    """

    # Label ids dropped from the selected hidden states in addition to
    # IGNORE_INDEX. NOTE(review): these look like Qwen2 chat-template tokens
    # (<|im_start|>, <|im_end|>, "\n") — confirm against the tokenizer.
    _EXTRA_IGNORED_LABEL_IDS = (151644, 151645, 198)

    def __init__(self, config):
        super().__init__()
        # forward() relies on input_proj, cosyvoice1, llm and, with
        # DISTILL_EMBEDDING, criterion. The previous __init__ never created
        # them, so every forward call raised AttributeError.
        # Maps LM hidden states to CosyVoice's 512-dim text-feature width.
        self.input_proj = nn.Sequential(
            nn.Linear(config.hidden_size, 1024),
            nn.GELU(),
            nn.Linear(1024, 512),
        )
        self.cosyvoice1 = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, fp16=False)
        self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
        self.llm = self.cosyvoice1.model.llm
        if DISTILL_EMBEDDING:
            self.criterion = nn.CosineEmbeddingLoss()

    def _keep_mask(self, label):
        """Bool mask of label positions to keep (not IGNORE_INDEX, not an extra-ignored id).

        Builds the mask without writing into ``label``, so the caller's labels
        tensor is left unmodified.
        """
        keep = label != IGNORE_INDEX
        for token_id in self._EXTRA_IGNORED_LABEL_IDS:
            keep = keep & (label != token_id)
        return keep

    def forward(self, tgt_reps, labels, answer):
        """Return the training loss for a batch.

        Args:
            tgt_reps: per-sample hidden states aligned with ``labels``.
            labels: per-sample label ids; not modified.
            answer: per-sample answer texts.

        With DISTILL_EMBEDDING this is a cosine-embedding loss against CosyVoice
        text embeddings. Otherwise it is the ``loss`` returned by the CosyVoice
        LLM's ``forward_ours`` on speech tokens synthesized from the answers.
        """
        if DISTILL_EMBEDDING:
            return self._distill_loss(tgt_reps, labels, answer)
        return self._speech_token_loss(tgt_reps, labels, answer)

    def _distill_loss(self, tgt_reps, labels, answer):
        """Cosine-embedding loss between projected hidden states and CosyVoice text embeddings."""
        tgt_label_reps = []
        target_embeddings = []
        for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
            tgt_label_reps.append(tgt_rep[self._keep_mask(label)])
            normalized_text = self.cosyvoice1.frontend.text_normalize(ans, split=True)
            tts_text_token_all = []
            for norm_text in normalized_text:
                tts_text_token, _ = self.cosyvoice1.frontend._extract_text_token(norm_text)
                tts_text_token_all.append(tts_text_token)
            tts_text_token_all = torch.cat(tts_text_token_all, dim=0)
            # NOTE(review): only the last segment's tokens (tts_text_token) are
            # embedded; tts_text_token_all is built but unused — confirm intent.
            target_embeddings.append(self.cosyvoice1.model.llm.text_embedding(tts_text_token))
        # torch.stack needs every sample to contribute the same number of rows.
        tgt_label_reps = torch.stack(tgt_label_reps)
        target_embeddings = torch.stack(target_embeddings).squeeze(1)
        hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512)
        target_embeddings = target_embeddings.reshape(-1, 512)
        # Target +1 for every row: maximize cosine similarity to the embedding.
        return self.criterion(
            hidden_states,
            target_embeddings,
            torch.ones(hidden_states.size(0)).to(hidden_states.device),
        )

    def _speech_token_loss(self, tgt_reps, labels, answer):
        """Loss of the CosyVoice LLM on speech tokens synthesized from each answer."""
        tgt_label_reps = []
        batch_speech_tokens = []
        embeddings = []
        for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
            tgt_label_reps.append(tgt_rep[self._keep_mask(label)])
            # '英文女' is the built-in SFT speaker id ("English female").
            speech_token = self.cosyvoice.inference_label(ans, '英文女', stream=False)
            # Join the per-segment token chunks with a 0 separator, then drop
            # the trailing separator.
            speech_tokens = []
            for chunk in speech_token:
                speech_tokens.append(chunk['tts_speech_token'].squeeze(0))
                speech_tokens.append(torch.tensor([0]))
            speech_tokens = torch.cat(speech_tokens, dim=0)
            if speech_tokens.size(0) > 1:
                speech_tokens = speech_tokens[:-1]
            batch_speech_tokens.append(speech_tokens)
            embedding = self.cosyvoice.frontend.frontend_embedding('英文女')
            embeddings.append(embedding['llm_embedding'].squeeze(0))
        # torch.stack needs equal lengths across the batch.
        tgt_label_reps = torch.stack(tgt_label_reps)
        batch_speech_token = torch.stack(batch_speech_tokens)
        embeddings = torch.stack(embeddings)
        hidden_states = self.input_proj(tgt_label_reps)
        batch_size = hidden_states.size(0)
        batch = {
            'text_feature': hidden_states,
            'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(batch_size),
            'speech_token': batch_speech_token,
            'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(batch_size),
            'embedding': embeddings,
        }
        # NOTE(review): the device is hard-coded to 'cuda'.
        output = self.llm.forward_ours(batch, 'cuda')
        return output['loss']