# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel import torch import torch.nn.functional as F import numpy as np import os import torch.nn as nn from typing import List, Optional, Tuple, Union import math from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.llama.modeling_llama import BaseModelOutputWithPast # sinusoidal positional encoding class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] * 1.0 emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb class LlamaAdaptiveRMSNorm(nn.Module): def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024): super().__init__() self.to_weight = nn.Linear(dim_cond, hidden_size) nn.init.zeros_(self.to_weight.weight) nn.init.ones_(self.to_weight.bias) self.variance_epsilon = eps self._is_hf_initialized = True # disable automatic init def forward(self, hidden_states, cond_embedding): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) weight = self.to_weight(cond_embedding) if len(weight.shape) == 2: weight = weight.unsqueeze(1) return (weight * hidden_states).to(input_dtype) class LlamaNARDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): """Override to adaptive layer norm""" super().__init__(config, layer_idx) # init attention, mlp, etc. self.input_layernorm = LlamaAdaptiveRMSNorm( config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size ) self.post_attention_layernorm = LlamaAdaptiveRMSNorm( config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size ) # add `cond` in forward function def forward( self, hidden_states: torch.Tensor, cond_embedding: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states hidden_states = self.input_layernorm( hidden_states, cond_embedding=cond_embedding ) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm( hidden_states, cond_embedding=cond_embedding ) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs def __init__(self, config: LlamaConfig, layer_idx: int): """Override to adaptive layer norm""" super().__init__(config, layer_idx) # init attention, mlp, etc. self.layer_idx = layer_idx self.input_layernorm = LlamaAdaptiveRMSNorm( config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size ) self.post_attention_layernorm = LlamaAdaptiveRMSNorm( config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size ) def forward( self, hidden_states: torch.Tensor, cond_embedding: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states hidden_states = self.input_layernorm( hidden_states, cond_embedding=cond_embedding ) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm( hidden_states, cond_embedding=cond_embedding ) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs class DiffLlama(LlamaModel): def __init__( self, hidden_size=1024, num_heads=16, num_layers=16, config=LlamaConfig(0, 256, 1024, 1, 1), ): super().__init__(config) self.layers = nn.ModuleList( [ LlamaNARDecoderLayer( LlamaConfig( hidden_size=hidden_size, num_attention_heads=num_heads, max_position_embeddings=4096, intermediate_size=hidden_size * 4, ), layer_idx=i, ) for i in range(num_layers) ] ) self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size) self.diff_step_embedding = SinusoidalPosEmb(hidden_size) self.diff_step_mlp = nn.Sequential( nn.Linear(hidden_size, hidden_size * 4), nn.SiLU(), nn.Linear(hidden_size * 4, hidden_size), ) # self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0) self.cond_mlp = nn.Sequential( nn.Linear(hidden_size, hidden_size * 4), nn.SiLU(), nn.Linear(hidden_size * 4, hidden_size), ) for layer in self.layers: layer.input_layernorm = LlamaAdaptiveRMSNorm( hidden_size, dim_cond=hidden_size ) layer.post_attention_layernorm = LlamaAdaptiveRMSNorm( hidden_size, dim_cond=hidden_size ) self.post_init() # self.reset_parameters() def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # create noncausal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None def _expand_mask( mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None ): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = ( mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) ) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask( attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ).to(inputs_embeds.device) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask def forward( self, x, diffusion_step, cond, x_mask, input_ids: torch.LongTensor = None, # [num_quant, B, T] attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: # retrieve some shape info batch_size, seq_length, _ = x.shape # condtion mlp cond_embedding = self.cond_mlp(cond) # (B, T, C) # diffusion step embedding diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device) diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C) x = x + cond_embedding inputs_embeds = x attention_mask = x_mask output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device, ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = ( past_key_values[idx] if past_key_values is not None else None ) if self.gradient_checkpointing and self.training: raise NotImplementedError else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cond_embedding=diffusion_step, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None return hidden_states class DiffLlamaPrefix(LlamaModel): def __init__( self, hidden_size=1024, num_heads=16, num_layers=16, config=LlamaConfig(0, 256, 1024, 1, 1), ): super().__init__(config) self.layers = nn.ModuleList( [ LlamaNARDecoderLayer( LlamaConfig( hidden_size=hidden_size, num_attention_heads=num_heads, max_position_embeddings=4096, intermediate_size=hidden_size * 4, ), layer_idx=i, ) for i in range(num_layers) ] ) self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size) self.diff_step_embedding = SinusoidalPosEmb(hidden_size) self.diff_step_mlp = nn.Sequential( nn.Linear(hidden_size, hidden_size * 4), nn.SiLU(), nn.Linear(hidden_size * 4, hidden_size), ) self.cond_mlp = nn.Sequential( nn.Linear(hidden_size, hidden_size * 4), nn.SiLU(), nn.Linear(hidden_size * 4, hidden_size), ) for layer in self.layers: layer.input_layernorm = LlamaAdaptiveRMSNorm( hidden_size, dim_cond=hidden_size ) layer.post_attention_layernorm = LlamaAdaptiveRMSNorm( hidden_size, dim_cond=hidden_size ) self.embed_tokens = None self.post_init() def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # create noncausal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None def _expand_mask( mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None ): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = ( mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) ) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask( attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ).to(inputs_embeds.device) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask def forward( self, x, diffusion_step, x_mask, phone_embedding: Optional[torch.LongTensor] = None, phone_mask: Optional[torch.FloatTensor] = None, input_ids: torch.LongTensor = None, # [num_quant, B, T] attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: # retrieve some shape info phone_embedding = self.cond_mlp(phone_embedding) # (B, T, C) phone_length = phone_embedding.shape[1] inputs_embeds = torch.cat([phone_embedding, x], dim=1) attention_mask = torch.cat([phone_mask, x_mask], dim=1) # diffusion step embedding diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device) diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C) batch_size, seq_length, _ = inputs_embeds.shape output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device, ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = ( past_key_values[idx] if past_key_values is not None else None ) if self.gradient_checkpointing and self.training: raise NotImplementedError else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cond_embedding=diffusion_step, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None return hidden_states[ :, phone_length:, ]