import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict, List, Tuple, Optional, Union

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
from transformers.cache_utils import DynamicCache

from vita.model.vita_tts.encoder.encoder import add_encoder_args
from vita.model.vita_tts.masks import *

IGNORE_ID = -1


class CrossEntropyLoss(torch.nn.Module):
    def __init__(self, ignore_index=-1):
        super(CrossEntropyLoss, self).__init__()
        self.criterion = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)

    def forward(self, logits, target, target_subsampling_factor=1):
        """
        logits: B*T1*D
        target: B*T2
        """
        logits = logits[:, :target.shape[1], :]
        logits = logits.transpose(1, 2)
        target = target.to(torch.long)
        loss = self.criterion(logits, target)
        return loss


class LLM2TTSCodecAR(torch.nn.Module):
    """E2E module.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (namespace): argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Extend arguments for transducer."""
        group = parser.add_argument_group("TDNN model setting")
        group.add_argument('--encoder-pre-norm-type', default='ln', type=str,
                           help="Type of input norm.")
        group.add_argument('--encoder-drop-rate', default=0.0, type=float,
                           help="Dropout rate for output.")
        group.add_argument('--encoder-criterion', default='cross-entropy', type=str,
                           help="Criterion for output")
        group.add_argument('--encoder-upsample-rate', default=1, type=int)
        group.add_argument('--kv-cache-prefix-finetune', default=0, type=int)
        group = add_encoder_args(group)
        return parser

    def __init__(self, idim, odim, args):
        """Initialize transducer modules.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options
        """
        super(LLM2TTSCodecAR, self).__init__()
        self.idim = args.idim
        self.odim = args.odim
        self.encoder_pre_norm_type = args.encoder_pre_norm_type
        self.encoder_drop_rate = args.encoder_drop_rate
        self.encoder_criterion = args.encoder_criterion
        self.encoder_upsample_rate = args.encoder_upsample_rate
        self.reporter = None

        # Four extra ids beyond the codec vocabulary: vocab_size marks the start of the
        # LLM-hidden prefix, vocab_size + 1 / + 2 are the codec SOS / EOS, vocab_size + 3 is PAD.
        self.vocab_size = self.odim
        config = LlamaConfig(vocab_size=self.vocab_size + 4,
                             hidden_size=args.transformer_attention_dim,
                             intermediate_size=args.transformer_linear_units,
                             num_hidden_layers=args.transformer_num_blocks,
                             num_attention_heads=args.transformer_attention_heads,
                             max_position_embeddings=2048,
                             bos_token_id=self.vocab_size + 1,
                             eos_token_id=self.vocab_size + 2,
                             pad_token_id=self.vocab_size + 3,
                             attention_dropout=args.transformer_dropout_rate)
        self.embedding = nn.Embedding(self.vocab_size + 4, self.idim, padding_idx=self.vocab_size + 3)
        self.init_pre_nn(config)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.dropout = nn.Dropout(p=self.encoder_drop_rate)
        self.out_fnn = nn.Linear(args.encoder_output_dim, self.vocab_size + 4)

        self.kv_cache_prefix_finetune = args.kv_cache_prefix_finetune
        if self.kv_cache_prefix_finetune:
            # Only the kv-cache prefix branch is trained; freeze the base AR decoder.
            self.init_kv_cache_prefix(config)
            self.embedding.eval()
            self.layers.eval()
            self.norm.eval()
            self.rotary_emb.eval()
            self.out_fnn.eval()
            for (name, param) in self.embedding.named_parameters():
                param.requires_grad = False
            for (name, param) in self.layers.named_parameters():
                param.requires_grad = False
            for (name, param) in self.norm.named_parameters():
                param.requires_grad = False
            for (name, param) in self.rotary_emb.named_parameters():
                param.requires_grad = False
            for (name, param) in self.out_fnn.named_parameters():
                param.requires_grad = False

        if self.encoder_criterion == 'ce':
            self.criterion = CrossEntropyLoss(ignore_index=self.vocab_size + 3)

    def init_kv_cache_prefix(self, config):
        self.layers_prefix = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb_prefix = LlamaRotaryEmbedding(config=config)

    def kv_cache_prefix_forward(self, prefix, prefix_lens, past_key_values):
        # Run the prefix through its own decoder stack so that past_key_values
        # (a DynamicCache, updated in place) holds the prefix keys/values.
        inputs_embeds = prefix
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb_prefix(hidden_states, position_ids)
        next_decoder_cache = None
        batch_size, max_len, _ = prefix.size()
        input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=prefix.device)
        for i in range(batch_size):
            input_mask[i, :prefix_lens[i], :prefix_lens[i]] = True
        attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
        for decoder_layer in self.layers_prefix:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
            next_decoder_cache = layer_outputs[1]
        past_key_values = next_decoder_cache

    def init_pre_nn(self, config):
        self.layers_pre_nn = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers // 2)]
        )
        self.rotary_emb_pre_nn = LlamaRotaryEmbedding(config=config)

    def pre_nn_forward(self, hidden, hidden_lens):
        # Non-autoregressive pass over the LLM hidden states: full attention within
        # each sequence's valid length, no kv cache.
        inputs_embeds = hidden
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb_pre_nn(hidden_states, position_ids)
        next_decoder_cache = None
        batch_size, max_len, _ = hidden.size()
        input_mask = torch.zeros(batch_size, max_len, max_len, dtype=torch.bool, device=hidden.device)
        for i in range(batch_size):
            input_mask[i, :hidden_lens[i], :hidden_lens[i]] = True
        attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
        for decoder_layer in self.layers_pre_nn:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
        return hidden_states

    def forward(self, batch):
        llm_hidden = batch['x']
        llm_hidden_lens = batch['x_lens']
        y = batch['y']
        y[y == IGNORE_ID] = self.vocab_size + 3
        y_lens = batch['y_lens']

        past_key_values = DynamicCache.from_legacy_cache(None)
        if self.kv_cache_prefix_finetune:
            self.kv_cache_prefix_forward(batch['x_prefix'], batch['x_prefix_lens'], past_key_values)

        # text_ids: (batch_size, max_len)
        batch_size, max_len = y.size()

        # Create bos, sos and eos tokens
        bos_token = torch.full((batch_size, 1), self.vocab_size, dtype=torch.long, device=y.device)
        sos_token = torch.full((batch_size, 1), self.vocab_size + 1, dtype=torch.long, device=y.device)
        eos_token = torch.full((batch_size, 1), self.vocab_size + 2, dtype=torch.long, device=y.device)
        padding_token = torch.full((batch_size, 1), self.vocab_size + 3, dtype=torch.long, device=y.device)

        # Pass through pre_nn
        llm_hidden = self.pre_nn_forward(llm_hidden, llm_hidden_lens)

        # Concat bos embedding
        bos_emb = self.embedding(bos_token)
        llm_hidden = torch.cat([bos_emb, llm_hidden], dim=1)
        llm_hidden_lens = llm_hidden_lens + 1

        # Create input x with sos token at the beginning
        x = torch.cat([sos_token, y], dim=1)  # (batch_size, max_len + 1)

        # Create output y with eos token at the end
        y = torch.cat([y, padding_token], dim=1)
        eos_positions = torch.arange(max_len + 1, device=y.device).expand(batch_size, max_len + 1) \
                        == y_lens.unsqueeze(1)
        y = y.masked_scatter(eos_positions, eos_token.expand_as(y)[eos_positions])

        # Embed the input sequence
        x_emb = self.embedding(x)  # (batch_size, max_len + 1, d_model)

        # Compute masks: the LLM-hidden block attends to the (optional) prefix and to itself;
        # the codec block attends to the prefix, the LLM-hidden block and, causally, to itself.
        if self.kv_cache_prefix_finetune:
            x_prefix = batch['x_prefix']
            x_prefix_lens = batch['x_prefix_lens']
            input_lens = llm_hidden.size(1) + max_len + 1
            input_mask = torch.zeros(batch_size, input_lens, x_prefix.size(1) + input_lens,
                                     dtype=torch.bool, device=x_emb.device)
            for i in range(batch_size):
                input_mask[i, :llm_hidden_lens[i], :x_prefix_lens[i]] = True
                input_mask[i, :llm_hidden_lens[i], x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :x_prefix_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1,
                           x_prefix.size(1): x_prefix.size(1) + llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1,
                           x_prefix.size(1) + llm_hidden.size(1): x_prefix.size(1) +
                           llm_hidden.size(1) + y_lens[i] + 1] \
                    = subsequent_mask(y_lens[i] + 1, x_emb.device)
        else:
            input_lens = llm_hidden.size(1) + max_len + 1
            input_mask = torch.zeros(batch_size, input_lens, input_lens, dtype=torch.bool, device=x_emb.device)
            for i in range(batch_size):
                input_mask[i, :llm_hidden_lens[i], :llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1, :llm_hidden_lens[i]] = True
                input_mask[i, llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1,
                           llm_hidden.size(1): llm_hidden.size(1) + y_lens[i] + 1] \
                    = subsequent_mask(y_lens[i] + 1, x_emb.device)

        # Pass through the transformer
        inputs_embeds = torch.cat([llm_hidden, x_emb], 1)
        llm_hidden = self.dropout(llm_hidden)
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        attention_mask = ~(input_mask.unsqueeze(1)) * torch.finfo(inputs_embeds.dtype).min
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
        hidden_states = self.norm(hidden_states)
        encoder_out = hidden_states[:, llm_hidden.size(1):]

        # Project to vocabulary size
        logits = self.out_fnn(encoder_out)

        if self.encoder_criterion == 'ce':
            loss = self.criterion(logits, y)

        if self.training:
            self.reporter.log_loss('loss', float(loss))

        return loss

    def transformer_infer(self, inputs_embeds, cache_position, past_key_values):
        position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        next_decoder_cache = None
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=None,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
            next_decoder_cache = layer_outputs[1]
        return hidden_states

    def infer(self, hidden, top_k, prefix, penalty_window_size, penalty, max_tokens=1000):
        # Pass through pre_nn
        hidden = self.pre_nn_forward(hidden, [hidden.size(1)])

        # Concat bos embedding
        bos_emb = self.embedding(torch.full((1, 1), self.vocab_size, dtype=torch.long, device=hidden.device))
        hidden = torch.cat([bos_emb, hidden], dim=1)

        # init past key values
        past_key_values = DynamicCache.from_legacy_cache(None)

        # Pass through the prefix nar decoder
        if prefix is not None and self.kv_cache_prefix_finetune:
            self.kv_cache_prefix_forward(prefix, [prefix.size(1)], past_key_values)

        inputs_embeds = hidden
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
        hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)

        # init generated tokens
        cur_token = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)
        generated_tokens = torch.full((1, 1), self.vocab_size + 1, dtype=torch.long, device=hidden.device)

        # generate tokens
        for i in range(max_tokens):
            inputs_embeds = self.embedding(cur_token)
            past_seen_tokens = past_key_values.get_seq_length()
            if prefix is not None:
                past_seen_tokens = past_seen_tokens - prefix.size(1)
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                          device=inputs_embeds.device)
            hidden_states = self.transformer_infer(inputs_embeds, cache_position, past_key_values)
            hidden_states = self.norm(hidden_states)

            # Project to vocabulary size
            logits = self.out_fnn(hidden_states)

            # apply penalty
            if penalty_window_size > 0:
                for token in set(generated_tokens[0][-penalty_window_size:]):
                    logits[:, :, token] /= penalty

            # top k sampling
            output = logits.squeeze(0).squeeze(0)
            probs = torch.nn.functional.softmax(output, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, top_k)
            probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
            probs = probs / probs.sum()
            next_token_id = torch.multinomial(probs, 1).unsqueeze(0)
            generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
            cur_token = next_token_id

            # eos
            if next_token_id == self.vocab_size + 2:
                break
            yield next_token_id
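

# ---------------------------------------------------------------------------
# Minimal usage sketch: builds the model with assumed toy hyperparameters and
# streams codec tokens from `infer`. The Namespace fields simply mirror the
# attributes read in __init__ above; the concrete values are placeholders and
# the real ones come from the training configuration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    toy_args = Namespace(
        idim=256, odim=1024,
        encoder_pre_norm_type='ln', encoder_drop_rate=0.0,
        encoder_criterion='ce', encoder_upsample_rate=1,
        kv_cache_prefix_finetune=0,
        transformer_attention_dim=256, transformer_linear_units=1024,
        transformer_num_blocks=4, transformer_attention_heads=4,
        transformer_dropout_rate=0.0,
        encoder_output_dim=256,
    )
    model = LLM2TTSCodecAR(toy_args.idim, toy_args.odim, toy_args).eval()

    # `llm_hidden` stands in for LLM hidden states with shape (batch=1, T, idim).
    llm_hidden = torch.randn(1, 10, toy_args.idim)
    with torch.no_grad():
        for token in model.infer(llm_hidden, top_k=5, prefix=None,
                                 penalty_window_size=0, penalty=1.0, max_tokens=20):
            print(int(token))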