import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict, List, Tuple, Optional, Union
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
from transformers.cache_utils import DynamicCache

from vita.model.vita_tts.encoder.encoder import add_encoder_args
from vita.model.vita_tts.masks import *

IGNORE_ID = -1
class CrossEntropyLoss(torch.nn.Module):
    def __init__(self, ignore_index=-1):
        super(CrossEntropyLoss, self).__init__()
        self.criterion = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)

    def forward(self, logits, target, target_subsampling_factor=1):
        """
        logits: B*T1*D
        target: B*T2
        """
        logits = logits[:, :target.shape[1], :]
        logits = logits.transpose(1, 2)
        target = target.to(torch.long)
        loss = self.criterion(logits, target)
        return loss
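# Illustrative usage of the loss above (a minimal sketch; the shapes follow the
# docstring, and the concrete tensors here are made up for demonstration):
#   criterion = CrossEntropyLoss(ignore_index=-1)
#   logits = torch.randn(2, 10, 512)          # (B, T1, D), D = number of classes
#   target = torch.randint(0, 512, (2, 8))    # (B, T2) with T2 <= T1
#   loss = criterion(logits, target)          # summed CE over non-ignored positions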
class LLM2TTSCodecAR(torch.nn.Module):
    """E2E module.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (namespace): argument Namespace containing options
    """
    @staticmethod
    def add_arguments(parser):
        """Extend arguments for transducer."""
        group = parser.add_argument_group("TDNN model setting")
        group.add_argument('--encoder-pre-norm-type', default='ln',
                           type=str, help="Type of input norm.")
        group.add_argument('--encoder-drop-rate', default=0.0,
                           type=float, help="Dropout rate for output.")
        group.add_argument('--encoder-criterion', default='cross-entropy',
                           type=str, help="Criterion for output.")
        group.add_argument('--encoder-upsample-rate', default=1, type=int)
        group.add_argument('--kv-cache-prefix-finetune', default=0, type=int)
        group = add_encoder_args(group)
        return parser
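    # Hypothetical sketch of wiring the hook above into an argparse setup (the
    # parser itself is not part of this module; names below are assumptions):
    #   parser = argparse.ArgumentParser()
    #   parser = LLM2TTSCodecAR.add_arguments(parser)
    #   args = parser.parse_args()  # also carries the options added by add_encoder_args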
    def __init__(self, idim, odim, args):
        """Initialize transducer modules.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options
        """
        super(LLM2TTSCodecAR, self).__init__()
        self.idim = args.idim
        self.odim = args.odim
        self.encoder_pre_norm_type = args.encoder_pre_norm_type
        self.encoder_drop_rate = args.encoder_drop_rate
        self.encoder_criterion = args.encoder_criterion
        self.encoder_upsample_rate = args.encoder_upsample_rate
        self.reporter = None

        self.vocab_size = self.odim
        config = LlamaConfig(vocab_size=self.vocab_size + 4,
                             hidden_size=args.transformer_attention_dim,
                             intermediate_size=args.transformer_linear_units,
                             num_hidden_layers=args.transformer_num_blocks,
                             num_attention_heads=args.transformer_attention_heads,
                             max_position_embeddings=2048,
                             bos_token_id=self.vocab_size + 1,
                             eos_token_id=self.vocab_size + 2,
                             pad_token_id=self.vocab_size + 3,
                             attention_dropout=args.transformer_dropout_rate)
        self.embedding = nn.Embedding(self.vocab_size + 4, self.idim, padding_idx=self.vocab_size + 3)
        self.init_pre_nn(config)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.dropout = nn.Dropout(p=self.encoder_drop_rate)
        self.out_fnn = nn.Linear(args.encoder_output_dim, self.vocab_size + 4)

        self.kv_cache_prefix_finetune = args.kv_cache_prefix_finetune
        if self.kv_cache_prefix_finetune:
            self.init_kv_cache_prefix(config)
            # Freeze the shared embedding, decoder layers, norm, rotary embedding
            # and output projection when finetuning the prefix branch.
            self.embedding.eval()
            self.layers.eval()
            self.norm.eval()
            self.rotary_emb.eval()
            self.out_fnn.eval()
            for (name, param) in self.embedding.named_parameters():
                param.requires_grad = False
            for (name, param) in self.layers.named_parameters():
                param.requires_grad = False
            for (name, param) in self.norm.named_parameters():
                param.requires_grad = False
            for (name, param) in self.rotary_emb.named_parameters():
                param.requires_grad = False
            for (name, param) in self.out_fnn.named_parameters():
                param.requires_grad = False

        if self.encoder_criterion == 'ce':
            self.criterion = CrossEntropyLoss(ignore_index=self.vocab_size + 3)
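    # Token-id layout used throughout this module (derived from the constructor above):
    #   0 .. vocab_size - 1 : codec tokens
    #   vocab_size          : BOS marker embedded and prepended to the LLM hidden prefix
    #   vocab_size + 1      : SOS token that starts autoregressive decoding
    #   vocab_size + 2      : EOS token that terminates decoding
    #   vocab_size + 3      : padding token (also the embedding padding_idx and the CE ignore_index)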
    def init_kv_cache_prefix(self, config):
        self.layers_prefix = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb_prefix = LlamaRotaryEmbedding(config=config)
    def kv_cache_prefix_forward(self, prefix, prefix_lens, past_key_values):
        inputs_embeds = prefix
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb_prefix(hidden_states, position_ids)
        next_decoder_cache = None

        batch_size, max_len, _ = prefix.size()
        input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=prefix.device)
        for i in range(batch_size):
            input_mask[i, :prefix_lens[i], :prefix_lens[i]] = True
        attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min

        for decoder_layer in self.layers_prefix:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
            next_decoder_cache = layer_outputs[1]
        past_key_values = next_decoder_cache
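    # Note: kv_cache_prefix_forward populates `past_key_values` in place; each
    # LlamaDecoderLayer call with use_cache=True writes the prefix keys/values into
    # the DynamicCache, so the function has no return value.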
    def init_pre_nn(self, config):
        self.layers_pre_nn = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers // 2)]
        )
        self.rotary_emb_pre_nn = LlamaRotaryEmbedding(config=config)
    def pre_nn_forward(self, hidden, hidden_lens):
        inputs_embeds = hidden
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb_pre_nn(hidden_states, position_ids)
        next_decoder_cache = None

        batch_size, max_len, _ = hidden.size()
        input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=hidden.device)
        for i in range(batch_size):
            input_mask[i, :hidden_lens[i], :hidden_lens[i]] = True
        attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min

        for decoder_layer in self.layers_pre_nn:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
        return hidden_states
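    # Mask convention shared by pre_nn_forward, kv_cache_prefix_forward and forward:
    # a boolean "attention allowed" mask of shape (B, T_q, T_k) is built first, then
    # unsqueezed to (B, 1, T_q, T_k) and converted to the additive form expected by
    # LlamaDecoderLayer, with disallowed positions set to torch.finfo(dtype).min.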
    def forward(self, batch):
        llm_hidden = batch['x']
        llm_hidden_lens = batch['x_lens']
        y = batch['y']
        y[y == IGNORE_ID] = self.vocab_size + 3
        y_lens = batch['y_lens']

        past_key_values = DynamicCache.from_legacy_cache(None)
        if self.kv_cache_prefix_finetune:
            self.kv_cache_prefix_forward(batch['x_prefix'], batch['x_prefix_lens'], past_key_values)

        # text_ids: (batch_size, max_len)
        batch_size, max_len = y.size()

        # Create bos, sos and eos tokens
        bos_token = torch.full((batch_size, 1), self.vocab_size, dtype=torch.long, device=y.device)
        sos_token = torch.full((batch_size, 1), self.vocab_size + 1, dtype=torch.long, device=y.device)
        eos_token = torch.full((batch_size, 1), self.vocab_size + 2, dtype=torch.long, device=y.device)
        padding_token = torch.full((batch_size, 1), self.vocab_size + 3, dtype=torch.long, device=y.device)

        # Pass through pre_nn
        llm_hidden = self.pre_nn_forward(llm_hidden, llm_hidden_lens)
        # Concat bos embedding
        bos_emb = self.embedding(bos_token)
        llm_hidden = torch.cat([bos_emb, llm_hidden], dim=1)
        llm_hidden_lens = llm_hidden_lens + 1

        # Create input x with sos token at the beginning
        x = torch.cat([sos_token, y], dim=1)  # (batch_size, max_len + 1)
        # Create output y with eos token at the end
        y = torch.cat([y, padding_token], dim=1)
        eos_positions = (torch.arange(max_len + 1, device=y.device).expand(batch_size, max_len + 1)
                         == y_lens.unsqueeze(1))
        y = y.masked_scatter(eos_positions, eos_token.expand_as(y)[eos_positions])

        # Embed the input sequence
        x_emb = self.embedding(x)  # (batch_size, max_len + 1, d_model)

        # Compute masks
        if self.kv_cache_prefix_finetune:
            x_prefix = batch['x_prefix']
            x_prefix_lens = batch['x_prefix_lens']
            input_lens = llm_hidden.size(1) + max_len + 1
            input_mask = torch.zeros(batch_size, input_lens, x_prefix.size(1) + input_lens,
                                     dtype=torch.bool, device=x_emb.device)
            for i in range(batch_size):
                input_mask[i, :llm_hidden_lens[i], :x_prefix_lens[i]] = True
                input_mask[i, :llm_hidden_lens[i], x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :x_prefix_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1,
                           x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1,
                           x_prefix.size(1) + llm_hidden.size(1): x_prefix.size(1) + llm_hidden.size(1) + y_lens[i] + 1] \
                    = subsequent_mask(y_lens[i] + 1, x_emb.device)
        else:
            input_lens = llm_hidden.size(1) + max_len + 1
            input_mask = torch.zeros(batch_size, input_lens, input_lens, dtype=torch.bool, device=x_emb.device)
            for i in range(batch_size):
                input_mask[i, :llm_hidden_lens[i], :llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1,
                           llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1] \
                    = subsequent_mask(y_lens[i] + 1, x_emb.device)

        # Pass through the transformer
        inputs_embeds = torch.cat([llm_hidden, x_emb], 1)
        llm_hidden = self.dropout(llm_hidden)
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
        hidden_states = self.norm(hidden_states)
        encoder_out = hidden_states[:, llm_hidden.size(1):]

        # Project to vocabulary size
        logits = self.out_fnn(encoder_out)

        if self.encoder_criterion == 'ce':
            loss = self.criterion(logits, y)

        if self.training:
            self.reporter.log_loss('loss', float(loss))

        return loss
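    # Illustrative batch layout expected by forward() (a sketch; the key names come
    # from the code above, the concrete shapes are assumptions for the example):
    #   batch = {
    #       'x': llm_hidden,       # (B, T_hid, idim) hidden states from the LLM
    #       'x_lens': x_lens,      # (B,) valid lengths of x
    #       'y': codec_ids,        # (B, T_y) target codec tokens, padded with IGNORE_ID
    #       'y_lens': y_lens,      # (B,) valid lengths of y
    #       # 'x_prefix' / 'x_prefix_lens' are additionally required when
    #       # kv_cache_prefix_finetune is enabled
    #   }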
    def transformer_infer(self, inputs_embeds, cache_position, past_key_values):
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        next_decoder_cache = None
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
            next_decoder_cache = layer_outputs[1]
        return hidden_states
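    # transformer_infer passes attention_mask=None and relies on the DynamicCache for
    # incremental decoding: inference runs on a single sample, so no padding mask is
    # needed, and each generation step feeds one embedded token into the stack.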
    def infer(self, hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=1000):
        # Pass through pre_nn
        hidden = self.pre_nn_forward(hidden, [hidden.size(1)])
        # Concat bos embedding
        bos_emb = self.embedding(torch.full((1, 1), self.vocab_size, dtype=torch.long, device=hidden.device))
        hidden = torch.cat([bos_emb, hidden], dim=1)
        # Init past key values
        past_key_values = DynamicCache.from_legacy_cache(None)
        # Pass through the prefix nar decoder
        if prefix is not None and self.kv_cache_prefix_finetune:
            self.kv_cache_prefix_forward(prefix, [prefix.size(1)], past_key_values)
        inputs_embeds = hidden
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)

        # Init generated tokens
        cur_token = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
        generated_tokens = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)

        # Generate tokens
        for i in range(max_tokens):
            inputs_embeds = self.embedding(cur_token)
            past_seen_tokens = past_key_values.get_seq_length()
            if prefix is not None:
                past_seen_tokens = past_seen_tokens - prefix.size(1)
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                          device=inputs_embeds.device)
            hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
            hidden_states = self.norm(hidden_states)
            # Project to vocabulary size
            logits = self.out_fnn(hidden_states)
            # Apply a repetition penalty to tokens seen in the recent window
            # (.tolist() so that set() deduplicates token ids rather than tensor objects)
            if penalty_window_size > 0:
                for token in set(generated_tokens[0][-penalty_window_size:].tolist()):
                    logits[:, :, token] /= penalty
            # Top-k sampling
            output = logits.squeeze(0).squeeze(0)
            probs = torch.nn.functional.softmax(output, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, top_k)
            probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
            probs = probs / probs.sum()
            next_token_id = torch.multinomial(probs, 1).unsqueeze(0)
            generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
            cur_token = next_token_id
            # Stop on eos (the eos token itself is not yielded)
            if next_token_id == self.vocab_size + 2:
                break
            yield next_token_id
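# Illustrative streaming decode loop (a sketch; the argument values below are
# assumptions, not recommended settings):
#   model.eval()
#   codec_ids = []
#   with torch.no_grad():
#       for token in model.infer(llm_hidden, top_k=5, prefix=None,
#                                penalty_window_size=20, penalty=1.1):
#           codec_ids.append(int(token))  # each yielded token is a (1, 1) LongTensor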