import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Union
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
from transformers.cache_utils import DynamicCache
from vita.model.vita_tts.encoder.encoder import add_encoder_args
from vita.model.vita_tts.masks import *
IGNORE_ID = -1
class CrossEntropyLoss(torch.nn.Module):
def __init__(self, ignore_index=-1):
super(CrossEntropyLoss, self).__init__()
self.criterion = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)
def forward(self, logits, target, target_subsampling_factor=1):
"""
logits: B*T1*D
target: B*T2
"""
logits = logits[:, :target.shape[1], :]
logits = logits.transpose(1, 2)
target = target.to(torch.long)
loss = self.criterion(logits, target)
return loss
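# Shape sketch (illustrative numbers, not taken from the model): with logits of shape
# (B, T1, D) = (2, 6, 36) and target of shape (B, T2) = (2, 5), the logits are truncated
# to the first 5 frames and the summed CE over non-ignored positions is returned:
#   CrossEntropyLoss(ignore_index=35)(torch.randn(2, 6, 36), torch.zeros(2, 5, dtype=torch.long))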
class LLM2TTSCodecAR(torch.nn.Module):
"""E2E module.
Args:
idim (int): dimension of inputs
odim (int): dimension of outputs
args (namespace): argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Extend arguments for transducer."""
group = parser.add_argument_group("TDNN model setting")
group.add_argument('--encoder-pre-norm-type',
default='ln', type=str, help="Type of input norm.")
group.add_argument('--encoder-drop-rate', default=0.0,
type=float, help="Dropout rate for output.")
group.add_argument('--encoder-criterion', default='cross-entropy',
type=str, help="Criterion for output")
group.add_argument('--encoder-upsample-rate', default=1, type=int)
group.add_argument('--kv-cache-prefix-finetune', default=0, type=int)
group = add_encoder_args(group)
return parser
def __init__(self, idim, odim, args):
"""Initialize transducer modules.
Args:
idim (int): dimension of inputs
odim (int): dimension of outputs
args (Namespace): argument Namespace containing options
"""
super(LLM2TTSCodecAR, self).__init__()
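        # Note: idim/odim below are read from args; the positional idim/odim arguments are unused.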
self.idim = args.idim
self.odim = args.odim
self.encoder_pre_norm_type = args.encoder_pre_norm_type
self.encoder_drop_rate = args.encoder_drop_rate
self.encoder_criterion = args.encoder_criterion
self.encoder_upsample_rate = args.encoder_upsample_rate
self.reporter = None
self.vocab_size = self.odim
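        # Four special ids are appended to the codec vocabulary:
        # bos = vocab_size, sos = vocab_size + 1, eos = vocab_size + 2, pad = vocab_size + 3.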
config = LlamaConfig(vocab_size=self.vocab_size + 4, hidden_size=args.transformer_attention_dim,
intermediate_size=args.transformer_linear_units,
num_hidden_layers=args.transformer_num_blocks,
num_attention_heads=args.transformer_attention_heads, max_position_embeddings=2048,
bos_token_id=self.vocab_size + 1,
eos_token_id=self.vocab_size + 2, pad_token_id=self.vocab_size + 3,
attention_dropout=args.transformer_dropout_rate)
self.embedding = nn.Embedding(self.vocab_size + 4, self.idim, padding_idx=self.vocab_size + 3)
self.init_pre_nn(config)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.dropout = nn.Dropout(p=self.encoder_drop_rate)
self.out_fnn = nn.Linear(args.encoder_output_dim, self.vocab_size + 4)
self.kv_cache_prefix_finetune = args.kv_cache_prefix_finetune
if self.kv_cache_prefix_finetune:
self.init_kv_cache_prefix(config)
            # Freeze (and keep in eval mode) the shared decoder weights; only the
            # kv-cache prefix branch created above remains trainable.
            for module in (self.embedding, self.layers, self.norm, self.rotary_emb, self.out_fnn):
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
        if self.encoder_criterion in ('ce', 'cross-entropy'):
            self.criterion = CrossEntropyLoss(ignore_index=self.vocab_size + 3)
def init_kv_cache_prefix(self, config):
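        # A parallel (trainable) decoder stack whose only job is to encode the prompt prefix
        # into the shared DynamicCache consumed by the frozen AR decoder.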
self.layers_prefix = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.rotary_emb_prefix = LlamaRotaryEmbedding(config=config)
def kv_cache_prefix_forward(self, prefix, prefix_lens, past_key_values):
inputs_embeds = prefix
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + \
inputs_embeds.shape[1], device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb_prefix(hidden_states, position_ids)
next_decoder_cache = None
batch_size, max_len, _ = prefix.size()
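        # Full (non-causal) self-attention over the valid prefix positions of each sample;
        # padded positions are masked out.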
input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=prefix.device)
for i in range(batch_size):
input_mask[i, :prefix_lens[i], :prefix_lens[i]] = True
attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
for decoder_layer in self.layers_prefix:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=None,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
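            # layer_outputs[1] is the same DynamicCache passed in as past_key_values; it is
            # updated in place, so the caller's cache already holds the prefix keys/values.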
next_decoder_cache = layer_outputs[1]
past_key_values = next_decoder_cache
def init_pre_nn(self, config):
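        # "Pre-NN": the first num_hidden_layers // 2 decoder layers, run non-causally over the
        # LLM hidden states before autoregressive decoding.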
self.layers_pre_nn = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers // 2)]
)
self.rotary_emb_pre_nn = LlamaRotaryEmbedding(config=config)
def pre_nn_forward(self, hidden, hidden_lens):
inputs_embeds = hidden
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + \
inputs_embeds.shape[1], device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb_pre_nn(hidden_states, position_ids)
batch_size, max_len, _ = hidden.size()
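        # Bidirectional self-attention over the valid hidden positions of each sample.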
input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=hidden.device)
for i in range(batch_size):
input_mask[i, :hidden_lens[i], :hidden_lens[i]] = True
attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
for decoder_layer in self.layers_pre_nn:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
cache_position=None,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
return hidden_states
def forward(self, batch):
llm_hidden = batch['x']
llm_hidden_lens = batch['x_lens']
y = batch['y']
y[y == IGNORE_ID] = self.vocab_size + 3
y_lens = batch['y_lens']
past_key_values = DynamicCache.from_legacy_cache(None)
if self.kv_cache_prefix_finetune:
self.kv_cache_prefix_forward(batch['x_prefix'], batch['x_prefix_lens'], past_key_values)
        # y: (batch_size, max_len) codec token ids
batch_size, max_len = y.size()
        # Create bos, sos, eos and padding tokens
bos_token = torch.full((batch_size, 1), self.vocab_size, dtype=torch.long, device=y.device)
sos_token = torch.full((batch_size, 1), self.vocab_size + 1, dtype=torch.long, device=y.device)
eos_token = torch.full((batch_size, 1), self.vocab_size + 2, dtype=torch.long, device=y.device)
padding_token = torch.full((batch_size, 1), self.vocab_size + 3, dtype=torch.long, device=y.device)
# Pass through pre_nn
llm_hidden = self.pre_nn_forward(llm_hidden, llm_hidden_lens)
# Concat bos embedding
bos_emb = self.embedding(bos_token)
llm_hidden = torch.cat([bos_emb, llm_hidden], dim=1)
llm_hidden_lens = llm_hidden_lens + 1
# Create input x with sos token at the beginning
x = torch.cat([sos_token, y], dim=1) # (batch_size, max_len + 1)
        # Create the target y: append a padding frame, then write eos at each sequence end (position y_lens[i])
y = torch.cat([y, padding_token], dim=1)
eos_positions = torch.arange(max_len + 1, device=y.device).expand(batch_size, max_len + 1) \
== y_lens.unsqueeze(1)
y = y.masked_scatter(eos_positions, eos_token.expand_as(y)[eos_positions])
# Embed the input sequence
x_emb = self.embedding(x) # (batch_size, max_len + 1, d_model)
        # Build the attention mask: the LLM-hidden prefix attends bidirectionally to itself
        # (and to the kv-cache prefix when present), while the token stream attends to that
        # prefix plus a causal (subsequent) mask over itself
if self.kv_cache_prefix_finetune:
x_prefix = batch['x_prefix']
x_prefix_lens = batch['x_prefix_lens']
input_lens = llm_hidden.size(1) + max_len + 1
input_mask = torch.zeros(batch_size, input_lens, x_prefix.size(1) + input_lens, \
dtype=torch.bool, device=x_emb.device)
for i in range(batch_size):
input_mask[i, :llm_hidden_lens[i], :x_prefix_lens[i]] = True
input_mask[i, :llm_hidden_lens[i], x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :x_prefix_lens[i]] = True
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
x_prefix.size(1) + llm_hidden.size(1): x_prefix.size(1) + \
llm_hidden.size(1) + y_lens[i] + 1] \
= subsequent_mask(y_lens[i] + 1, x_emb.device)
else:
input_lens = llm_hidden.size(1) + max_len + 1
input_mask = torch.zeros(batch_size, input_lens, input_lens, dtype=torch.bool, device=x_emb.device)
for i in range(batch_size):
input_mask[i, :llm_hidden_lens[i], :llm_hidden_lens[i]] = True
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :llm_hidden_lens[i]] = True
input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, \
llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1] \
= subsequent_mask(y_lens[i] + 1, x_emb.device)
# Pass through the transformer
        # Apply dropout to the LLM hidden prefix before concatenating it with the token embeddings
        llm_hidden = self.dropout(llm_hidden)
        inputs_embeds = torch.cat([llm_hidden, x_emb], 1)
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
for decoder_layer in self.layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=None,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
encoder_out = hidden_states[:, llm_hidden.size(1):]
# Project to vocabulary size
logits = self.out_fnn(encoder_out)
        if self.encoder_criterion in ('ce', 'cross-entropy'):
            loss = self.criterion(logits, y)
        if self.training and self.reporter is not None:
            self.reporter.log_loss('loss', float(loss))
        return loss
def transformer_infer(self, inputs_embeds, cache_position, past_key_values):
position_ids = cache_position.unsqueeze(0)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
next_decoder_cache = None
for decoder_layer in self.layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=None,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
next_decoder_cache = layer_outputs[1]
return hidden_states
def infer(self, hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=1000):
# Pass through pre_nn
hidden = self.pre_nn_forward(hidden, [hidden.size(1)])
# Concat bos embedding
bos_emb = self.embedding(torch.full((1, 1), self.vocab_size, dtype=torch.long, device=hidden.device))
hidden = torch.cat([bos_emb, hidden], dim=1)
# init past key values
past_key_values = DynamicCache.from_legacy_cache(None)
# Pass through the prefix nar decoder
if prefix is not None and self.kv_cache_prefix_finetune:
self.kv_cache_prefix_forward(prefix, [prefix.size(1)], past_key_values)
inputs_embeds = hidden
past_seen_tokens = 0
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
device=inputs_embeds.device)
hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
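        # Only the KV cache filled by this prompt pass is needed; hidden_states is recomputed
        # token by token in the loop below.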
# init generated tokens
cur_token = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
generated_tokens = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
# generate tokens
for i in range(max_tokens):
inputs_embeds = self.embedding(cur_token)
past_seen_tokens = past_key_values.get_seq_length()
            if prefix is not None and self.kv_cache_prefix_finetune:
past_seen_tokens = past_seen_tokens - prefix.size(1)
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], \
device=inputs_embeds.device)
hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
hidden_states = self.norm(hidden_states)
# Project to vocabulary size
logits = self.out_fnn(hidden_states)
# apply penalty
if penalty_window_size > 0:
                for token in set(generated_tokens[0, -penalty_window_size:].tolist()):
logits[:, :, token] /= penalty
# top k sampling
output = logits.squeeze(0).squeeze(0)
probs = torch.nn.functional.softmax(output, dim=-1)
top_k_probs, top_k_indices = torch.topk(probs, top_k)
probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
probs = probs / probs.sum()
next_token_id = torch.multinomial(probs, 1).unsqueeze(0)
generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
cur_token = next_token_id
            # stop on eos; the eos token itself is not yielded
if next_token_id == self.vocab_size + 2:
break
yield next_token_id
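# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the released model).
# It builds a tiny LLM2TTSCodecAR with made-up hyper-parameters, runs one forward
# pass and a short sampling loop. All sizes below are assumptions, chosen so that
# idim == transformer_attention_dim == encoder_output_dim, which the module expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        idim=64, odim=32,                      # input feature dim / codec vocabulary size (toy values)
        encoder_pre_norm_type='ln', encoder_drop_rate=0.0,
        encoder_criterion='ce', encoder_upsample_rate=1,
        kv_cache_prefix_finetune=0,
        transformer_attention_dim=64, transformer_linear_units=128,
        transformer_num_blocks=2, transformer_attention_heads=2,
        transformer_dropout_rate=0.0,
        encoder_output_dim=64,
    )
    model = LLM2TTSCodecAR(args.idim, args.odim, args).eval()

    # One toy batch: 2 sequences of LLM hidden states and codec targets.
    batch = {
        'x': torch.randn(2, 10, args.idim),
        'x_lens': torch.tensor([10, 8]),
        'y': torch.randint(0, args.odim, (2, 12)),
        'y_lens': torch.tensor([12, 9]),
    }
    with torch.no_grad():
        loss = model(batch)
    print('loss:', float(loss))

    # Streaming inference: sample up to 20 codec tokens from a single prompt.
    with torch.no_grad():
        for tok in model.infer(torch.randn(1, 10, args.idim), top_k=5, prefix=None,
                               penalty_window_size=0, penalty=1.0, max_tokens=20):
            print(int(tok))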