# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Sinusoidal positional encoding, used here to embed scalar diffusion steps.
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: (B,) scalar positions (e.g. diffusion timesteps) -> (B, dim)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
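
# A minimal, illustrative usage sketch (not part of the original file); the
# batch size and dim below are arbitrary.
def _demo_sinusoidal_pos_emb():
    pos_emb = SinusoidalPosEmb(dim=1024)
    t = torch.randint(0, 1000, (8,))  # one diffusion timestep per batch item
    e = pos_emb(t)  # (8, 1024): [sin | cos] features per timestep
    assert e.shape == (8, 1024)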

class LlamaAdaptiveRMSNorm(nn.Module):
    """RMSNorm whose per-channel scale is predicted from a condition vector
    (adaptive layer norm). The zero/one init below makes it start out as
    plain RMSNorm."""

    def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
        super().__init__()
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        nn.init.zeros_(self.to_weight.weight)
        nn.init.ones_(self.to_weight.bias)
        self.variance_epsilon = eps
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, hidden_states, cond_embedding):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.to_weight(cond_embedding)
        if len(weight.shape) == 2:
            # (B, C) condition -> broadcast over the time axis
            weight = weight.unsqueeze(1)

        return (weight * hidden_states).to(input_dtype)
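
# Illustrative sketch (not part of the original file): the condition vector
# produces a per-channel scale that is broadcast over time. Shapes below are
# arbitrary.
def _demo_adaptive_rmsnorm():
    norm = LlamaAdaptiveRMSNorm(hidden_size=1024, dim_cond=1024)
    h = torch.randn(2, 50, 1024)  # (B, T, C) hidden states
    c = torch.randn(2, 1024)  # (B, C) condition, e.g. a diffusion-step embedding
    out = norm(h, cond_embedding=c)  # (B, C) weight is unsqueezed to broadcast over T
    assert out.shape == h.shape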

class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override the layer norms with adaptive (condition-dependent) RMSNorm."""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        self.layer_idx = layer_idx
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )

    # add `cond_embedding` to the forward signature
    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            cond_embedding (`torch.FloatTensor`): condition of shape `(batch, embed_dim)` that drives the
                adaptive layer norms.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
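
# Illustrative sketch (not part of the original file). The three-way unpack of
# self_attn in the forward above matches the transformers release this module
# was written against; newer releases changed the attention API, so treat this
# as a sketch under that version assumption.
def _demo_nar_decoder_layer():
    cfg = LlamaConfig(
        hidden_size=1024,
        num_attention_heads=16,
        max_position_embeddings=4096,
        intermediate_size=4096,
    )
    layer = LlamaNARDecoderLayer(cfg, layer_idx=0)
    h = torch.randn(2, 50, 1024)
    cond = torch.randn(2, 1024)
    position_ids = torch.arange(50).unsqueeze(0)  # shared across the batch
    (out,) = layer(h, cond_embedding=cond, position_ids=position_ids)
    assert out.shape == h.shape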

class DiffLlama(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        # placeholder config (vocab_size=0, hidden_size=256, intermediate_size=1024,
        # num_hidden_layers=1, num_attention_heads=1); the real layers are rebuilt below
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        # self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0)

        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )

        self.post_init()
        # self.reset_parameters()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # Create a non-causal (fully bidirectional) mask:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )
            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
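
    # Worked example (illustrative, not in the original code): a padding mask
    # [[1, 1, 0]] expands to a (1, 1, 3, 3) additive float mask whose last
    # source column is torch.finfo(dtype).min in every row; no causal triangle
    # is applied, so all valid frames attend to each other bidirectionally.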
    def forward(
        self,
        x,
        diffusion_step,
        cond,
        x_mask,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # retrieve some shape info
        batch_size, seq_length, _ = x.shape

        # condition MLP
        cond_embedding = self.cond_mlp(cond)  # (B, T, C)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)

        # add the frame-level condition to the input; the step embedding is fed
        # to the adaptive layer norms instead
        x = x + cond_embedding
        inputs_embeds = x
        attention_mask = x_mask

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers; the aggregates below are kept for API parity with
        # LlamaModel but are not returned
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states
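
# End-to-end sketch (not part of the original file): denoise a latent sequence
# conditioned on frame-aligned features and a scalar diffusion step. Sizes are
# arbitrary, and the same transformers-version caveat as above applies.
def _demo_diff_llama():
    model = DiffLlama(hidden_size=64, num_heads=4, num_layers=2)
    B, T, C = 2, 25, 64
    x = torch.randn(B, T, C)  # noisy latents
    cond = torch.randn(B, T, C)  # frame-level condition, same length as x
    step = torch.randint(0, 1000, (B,))  # one diffusion timestep per item
    x_mask = torch.ones(B, T)  # 1 = valid frame
    out = model(x, diffusion_step=step, cond=cond, x_mask=x_mask)
    assert out.shape == (B, T, C)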

class DiffLlamaPrefix(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        # placeholder config; the real layers are rebuilt below
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )

        self.embed_tokens = None

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # Create a non-causal (fully bidirectional) mask:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )
            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
    def forward(
        self,
        x,
        diffusion_step,
        x_mask,
        phone_embedding: Optional[torch.LongTensor] = None,
        phone_mask: Optional[torch.FloatTensor] = None,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # project the phone embedding and prepend it as a prefix, so
        # inputs_embeds is (B, T_phone + T_x, C)
        phone_embedding = self.cond_mlp(phone_embedding)  # (B, T, C)
        phone_length = phone_embedding.shape[1]
        inputs_embeds = torch.cat([phone_embedding, x], dim=1)
        attention_mask = torch.cat([phone_mask, x_mask], dim=1)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)

        # retrieve some shape info
        batch_size, seq_length, _ = inputs_embeds.shape

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers; the aggregates below are kept for API parity with
        # LlamaModel but are not returned
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        # strip the phone prefix and return hidden states for the latent frames only
        return hidden_states[:, phone_length:]
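
# Sketch (not part of the original file): the projected phone embedding acts as
# a bidirectional prefix, and only the hidden states of the latent frames are
# returned. Sizes are arbitrary; the transformers-version caveat applies.
def _demo_diff_llama_prefix():
    model = DiffLlamaPrefix(hidden_size=64, num_heads=4, num_layers=2)
    B, T_phone, T_x, C = 2, 10, 25, 64
    x = torch.randn(B, T_x, C)  # noisy latents
    phone = torch.randn(B, T_phone, C)  # dense phone embeddings
    step = torch.randint(0, 1000, (B,))
    out = model(
        x,
        diffusion_step=step,
        x_mask=torch.ones(B, T_x),
        phone_embedding=phone,
        phone_mask=torch.ones(B, T_phone),
    )
    assert out.shape == (B, T_x, C)  # phone prefix stripped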