maskgct / models /tts /valle_v2 /valle_nar.py
Hecheng0625's picture
Upload 409 files
c968fc3 verified
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Number of quantizers (codebook levels) in total; currently assumes the first
# level (index 0) is produced by the AR stage.
NUM_QUANTIZERS = 8
# First quantization level predicted by the NAR model in this file.
START_QUANTIZATION_LAYER = 1
# Last quantization level predicted by the NAR model; inclusive, the code below
# always iterates over range(START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1).
END_QUANTIZATION_LAYER = 7
class LlamaAdaptiveRMSNorm(nn.Module):
    """RMS normalization whose per-channel gain is predicted from a condition.

    The input is normalized by its root-mean-square over the last dimension and
    then scaled by ``to_weight(cond_embedding)`` instead of a learned constant
    gain vector.

    Args:
        hidden_size: Size of the normalized (last) dimension.
        eps: Value added to the variance before the reciprocal square root.
        dim_cond: Size of the condition embedding fed to ``to_weight``.
    """

    def __init__(self, hidden_size=1024, eps=1e-9, dim_cond=1024):
        super().__init__()
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        # Only the projection weight is re-initialized here; the bias keeps the
        # nn.Linear default (a zero-weight / unit-bias init is not enabled).
        nn.init.normal_(self.to_weight.weight, mean=0.0, std=0.02)
        self.variance_epsilon = eps
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, hidden_states, cond_embedding):
        """Normalize ``hidden_states`` and scale it with the conditional gain.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.
            cond_embedding: Condition embedding whose last dimension is
                ``dim_cond``. Either a shared vector ``[dim_cond]`` or one
                vector per sample ``[B, dim_cond]`` for ``[B, T, H]`` inputs.

        Returns:
            Tensor with the shape and dtype of ``hidden_states``.
        """
        input_dtype = hidden_states.dtype
        # Statistics are computed in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.to_weight(cond_embedding)
        # A per-sample gain [B, H] would otherwise be aligned with the time
        # axis of [B, T, H]; add a singleton time axis so it broadcasts over T.
        if weight.dim() == 2 and hidden_states.dim() == 3:
            weight = weight.unsqueeze(1)

        return (weight * hidden_states).to(input_dtype)
class LlamaNARDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose two norms are condition-dependent adaptive RMS norms."""

    def __init__(self, config: LlamaConfig):
        """Override to adaptive layer norm"""
        # The parent constructor builds attention, mlp, etc.
        super().__init__(config=config, layer_idx=0)
        width = config.hidden_size
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            width, eps=config.rms_norm_eps, dim_cond=width
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            width, eps=config.rms_norm_eps, dim_cond=width
        )

    # Same as the parent forward, with an extra `cond_embedding` argument.
    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run one pre-norm transformer block conditioned on ``cond_embedding``.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            cond_embedding (`torch.Tensor`): condition passed to both adaptive norms.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            use_cache (`bool`, *optional*):
                If set to `True`, the present key/value states are appended to the returned tuple.

        Returns:
            Tuple ``(hidden_states, [attn_weights], [present_key_value])``; the
            optional entries appear only when the matching flag is set.
        """
        # ---- self-attention sub-block ----
        skip = hidden_states
        normed = self.input_layernorm(hidden_states, cond_embedding=cond_embedding)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = skip + attn_out

        # ---- feed-forward sub-block ----
        skip = hidden_states
        normed = self.post_attention_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )
        hidden_states = skip + self.mlp(normed)

        result = [hidden_states]
        if output_attentions:
            result.append(self_attn_weights)
        if use_cache:
            result.append(present_key_value)
        return tuple(result)
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
class MultiEmbedding(nn.Module):
    """Embedding for multiple quantization layers, summing up the embeddings of each layer.

    Args:
        num_embeddings: Vocabulary size of every per-level embedding table.
        embedding_dim: Width of the embedding vectors.
        num_quantization_layers: Number of per-level tables to create.
    """

    def __init__(
        self,
        num_embeddings=1034,
        embedding_dim=1024,
        num_quantization_layers=NUM_QUANTIZERS,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )
        # initialize embeddings
        for table in self.embeddings:
            table.weight.data.normal_(mean=0.0, std=0.02)
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, input_ids):
        """Input: [num_quant, B, T] -> Output: [B, T, H]

        Level ``i`` of ``input_ids`` is looked up in table ``i`` and the results
        are summed. Fewer levels than tables is allowed (only the first
        ``num_quant`` tables are used).

        Raises:
            IndexError: if ``input_ids`` has more levels than there are tables.
        """
        num_quant, B, T = input_ids.shape
        if num_quant > len(self.embeddings):
            raise IndexError(
                f"input has {num_quant} quantization levels but only "
                f"{len(self.embeddings)} embedding tables exist"
            )
        first_table = self.embeddings[0]
        # Allocate the accumulator in the embedding dtype so half/double models
        # do not silently get a float32 result.
        summed_embeddings = torch.zeros(
            B,
            T,
            first_table.embedding_dim,
            device=input_ids.device,
            dtype=first_table.weight.dtype,
        )
        # zip stops after num_quant levels (validated above).
        for table, level_ids in zip(self.embeddings, input_ids):
            summed_embeddings += table(level_ids)
        return summed_embeddings
class LlammaNARModel(LlamaModel):
    """Non-causal Llama backbone conditioned on a quantization-layer index.

    Differences from the parent visible in this class: decoder layers are
    ``LlamaNARDecoderLayer``; every norm is a ``LlamaAdaptiveRMSNorm`` driven by
    ``embed_cond(cond)``; the attention mask only hides padding (no causal
    part); and the forward pass consumes already-embedded inputs.
    """

    def __init__(self, config):
        """Adding adaptive layer norm, conditional embeddings, and multi-level input embeddings to the decoder layer"""
        super().__init__(config)
        self.layers = nn.ModuleList(
            [LlamaNARDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        # One condition vector per quantizer index (NUM_QUANTIZERS entries).
        self.embed_cond = nn.Embedding(NUM_QUANTIZERS, config.hidden_size)

        # LlamaNARDecoderLayer.__init__ already installs adaptive norms; they
        # are re-created here as in the original code, which keeps module
        # construction order (and therefore seeded initialization) unchanged.
        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
            )

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Create a non-causal additive mask.

        Expands ``attention_mask`` from ``[bsz, src_len]`` to
        ``[bsz, 1, tgt_len, src_len]`` with ``tgt_len = input_shape[-1]``.
        Positions where the mask is 1 become 0; positions where it is 0 become
        the most negative value of ``inputs_embeds.dtype``.
        ``past_key_values_length`` is accepted but not used.

        Returns:
            The expanded mask on ``inputs_embeds.device``, or ``None`` when
            ``attention_mask`` is ``None``.
        """
        if attention_mask is None:
            return None

        dtype = inputs_embeds.dtype
        bsz, src_len = attention_mask.size()
        tgt_len = input_shape[-1]
        expanded_mask = (
            attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )
        inverted_mask = 1.0 - expanded_mask
        additive_mask = inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )
        return additive_mask.to(inputs_embeds.device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # already embedded, [B, T, H]
        cond: torch.LongTensor = None,  # index for conditional embeddings
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the conditioned decoder stack.

        Args:
            input_ids: Despite the name, a 3-D tensor of input embeddings
                ``[B, T, H]``; it takes precedence over ``inputs_embeds``.
            cond: Index (or indices) looked up in ``embed_cond``.
            attention_mask: ``[B, T]`` mask with 1 for valid positions;
                defaults to all ones.
            position_ids: Optional positions; defaults to
                ``arange(past_len, past_len + T)``.
            past_key_values: Optional per-layer cache.
            inputs_embeds: Used only when ``input_ids`` is ``None``.
            use_cache, output_attentions, output_hidden_states, return_dict:
                Fall back to the corresponding ``self.config`` values when
                ``None``.

        Returns:
            ``BaseModelOutputWithPast`` or, when ``return_dict`` is false, a
            tuple of its non-``None`` fields.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
            NotImplementedError: if gradient checkpointing is enabled in
                training mode.
        """
        # The "ids" are embeddings already; fall back to inputs_embeds.
        if input_ids is not None:
            inputs_embeds = input_ids
        if inputs_embeds is None:
            raise ValueError("either input_ids or inputs_embeds must be provided")

        # retrieve some shape info
        batch_size, seq_length, _ = inputs_embeds.shape

        # embed cond
        cond_embedding = self.embed_cond(cond)

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if checkpointing:
                # The checkpointed path was never wired up to pass
                # cond_embedding, so it is rejected explicitly.
                raise NotImplementedError(
                    "gradient checkpointing is not supported by LlammaNARModel"
                )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cond_embedding=cond_embedding,  # using cond embed
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=cond_embedding)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.models.llama.modeling_llama import CrossEntropyLoss
from easydict import EasyDict as edict
class LlamaForNARModeling(LlamaPreTrainedModel):
    """NAR backbone plus one output projection per predicted quantization level."""

    def __init__(self, config):
        super().__init__(config)
        self.model = LlammaNARModel(config)

        head_count = END_QUANTIZATION_LAYER - START_QUANTIZATION_LAYER + 1
        heads = []
        for _ in range(head_count):
            heads.append(nn.Linear(config.hidden_size, config.vocab_size, bias=False))
        self.lm_head = nn.ModuleList(heads)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        cond: torch.LongTensor,  # added
        prediction_target: torch.LongTensor = None,  # added. No shifting. -100 means no loss
        input_ids: torch.LongTensor = None,  # expect an embedding, [B, T, H]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Prediction target: [B, T]

        Runs the backbone, projects its output with the head selected by
        ``cond - START_QUANTIZATION_LAYER`` and, when ``prediction_target`` is
        given, computes an unshifted cross-entropy loss over all positions.

        Returns:
            An ``edict`` with ``loss`` (``None`` without a target) and ``logits``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        backbone_out = self.model(
            cond=cond,  # added
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        features = backbone_out[0]
        head = self.lm_head[cond - START_QUANTIZATION_LAYER]
        logits = head(features)

        if prediction_target is None:
            loss = None
        else:
            # calculate loss if prediction_target is provided
            criterion = CrossEntropyLoss()
            flat_logits = logits.view(-1, logits.size(-1))
            flat_target = prediction_target.view(-1)
            loss = criterion(flat_logits, flat_target)

        return edict(
            loss=loss,
            logits=logits,
        )
class ValleNAR(nn.Module):
    """Non-autoregressive stage that predicts one quantization level at a time.

    The input sequence is the concatenation of phone embeddings, an acoustic
    prompt (all quantization levels summed) and the target segment (levels
    below the one being predicted, summed). The level being predicted is passed
    to the underlying ``LlamaForNARModeling`` as ``cond``.
    """

    def __init__(
        self,
        phone_vocab_size=256,
        target_vocab_size=1024,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=1024 + 256,
        bos_target_id=1282,
        eos_target_id=1283,
        bos_phone_id=1284,
        eos_phone_id=1285,
        bos_prompt_id=1286,
        eos_prompt_id=1287,
        use_input_embeds=False,
        emb_dim=256,
    ):
        super().__init__()
        self.config = LlamaConfig(
            vocab_size=phone_vocab_size + target_vocab_size + 10,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
            use_cache=False,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.bos_prompt_id = bos_prompt_id
        self.eos_prompt_id = eos_prompt_id
        self.model = LlamaForNARModeling(self.config)

        self.use_input_embeds = use_input_embeds

        self.phone_embedder = nn.Embedding(
            self.phone_vocab_size + 10, hidden_size
        )  # use phone_embedder to embed all eos, bos tokens
        self.prompt_embedder = MultiEmbedding(
            num_embeddings=self.target_vocab_size,
            embedding_dim=hidden_size,
            num_quantization_layers=NUM_QUANTIZERS,
        )
        self.phone_embedder.weight.data.normal_(mean=0.0, std=0.02)

        # Schedule used to pick the level to predict during training:
        # "uniform", "linear" or "cosine".
        self.mask_layer_schedule = "uniform"

        # Optional projection for an external input embedding; only created
        # when use_input_embeds is set.
        if self.use_input_embeds:
            self.emb_linear = nn.Linear(emb_dim, hidden_size)
            self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
            self.emb_linear.bias.data.zero_()

    def _sample_quantization_layer(self):
        """Draw the level to predict according to ``self.mask_layer_schedule``."""
        schedule = self.mask_layer_schedule
        levels = range(START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1)

        if schedule == "uniform":
            return np.random.randint(
                START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
            )
        if schedule == "linear":
            weights = torch.tensor([NUM_QUANTIZERS - i for i in levels])
        elif schedule == "cosine":
            weights = torch.tensor(
                [np.cos(i / NUM_QUANTIZERS * np.pi / 2) for i in levels]
            )
        else:
            raise ValueError(f"unknown mask_layer_schedule: {schedule!r}")

        weights = weights / weights.sum()
        mask_layer = (
            torch.multinomial(weights, 1, replacement=True) + START_QUANTIZATION_LAYER
        ).item()
        assert START_QUANTIZATION_LAYER <= mask_layer <= END_QUANTIZATION_LAYER
        return mask_layer

    @staticmethod
    def _randomly_set_elements(tensor, fraction, value):
        """Return a copy of ``tensor`` with roughly ``fraction`` of its elements set to ``value``.

        Args:
            tensor (torch.Tensor): The input tensor (left unmodified).
            fraction (float): Probability, between 0 and 1, that an element is replaced.
            value (float or int): The replacement value.
        """
        mask = torch.rand_like(tensor, dtype=torch.float32) < fraction
        result_tensor = tensor.clone()
        result_tensor[mask] = value
        return result_tensor

    @staticmethod
    def _masked_mean(hits, mask):
        """Mean of ``hits`` over the positions where ``mask`` is 1."""
        return (hits * mask).sum() / mask.sum()

    def forward(
        self,
        phone_ids,
        phone_mask,
        target_ids,
        target_mask,
        target_quantization_layer=None,
        prompt_len=None,
        dropout=0.0,
    ):
        """Predict one quantization level of the target segment.

        phone_ids: [B, T]
        phone_mask: [B, T]
        target_ids: [8,B,T]
        target_mask: [B, T]
        target_quantization_layer: level to predict; sampled from
            ``self.mask_layer_schedule`` when ``None``.
        prompt_len: number of leading target frames used as prompt. Must be
            given in eval mode and must be ``None`` in training mode, where it
            is drawn at random.
        dropout: rate of dropping out the target tokens

        Returns:
            The ``edict`` from the underlying model (``loss``, ``logits``)
            extended with ``target_quantization_layer``, ``top1_acc``,
            ``top5_acc`` and ``top10_acc`` (accuracies over the target part).
        """
        assert (
            target_ids < self.target_vocab_size
        ).all(), f"target_ids should be less than {self.target_vocab_size}"

        phone_ids = phone_ids + self.target_vocab_size
        phone_ids = phone_ids * phone_mask + (1 - phone_mask) * self.pad_token_id
        # add_phone_eos_bos_label() is currently not applied to the phones.

        # No loss is computed on the phone positions selected here
        # (value -100 where phone_mask is 0).
        phone_label = -100 * (1 - phone_mask)

        # get phone embedding
        phone_embedding = self.phone_embedder(
            phone_ids - self.target_vocab_size
        )  # [B, T, H]

        if prompt_len is not None:
            assert not self.training  # inference stage fix prompt len to input
            num_prompt_tokens = prompt_len
        else:
            assert self.training  # randomize prompt len in training
            total_len = target_ids.shape[-1]
            num_prompt_tokens = np.random.randint(
                min(total_len // 4, 5), total_len // 2
            )

        # extract 8-level prompts
        prompt_tokens = target_ids[:, :, :num_prompt_tokens]  # [Q, B, T]
        prompt_mask = torch.ones_like(prompt_tokens[0])
        prompt_label = -100 * prompt_mask
        # get prompt embedding
        prompt_embedding = self.prompt_embedder(prompt_tokens)  # [B, T, H]

        # randomly select a target qnt layer to predict
        if target_quantization_layer is None:
            target_quantization_layer = self._sample_quantization_layer()

        # Levels below the predicted one form the input of the target part.
        target_prompt_ids = target_ids[
            :target_quantization_layer, :, num_prompt_tokens:
        ]

        if dropout != 0.0:
            # NOTE(review): the placeholder id equals target_vocab_size, which
            # is also the size of the prompt_embedder tables, so it is out of
            # range for the lookup below. Left as is because enlarging the
            # tables would change parameter shapes - confirm intended usage.
            target_prompt_ids = self._randomly_set_elements(
                target_prompt_ids, dropout, self.target_vocab_size
            )

        target_embedding = self.prompt_embedder(target_prompt_ids)

        # mask of the target part
        target_mask = target_mask[:, num_prompt_tokens:]

        target_labels = target_ids[
            target_quantization_layer, :, num_prompt_tokens:
        ] * target_mask + (-100 * (1 - target_mask))

        # input embeddings
        input_embeddings = torch.cat(
            [phone_embedding, prompt_embedding, target_embedding], dim=1
        )
        input_mask = torch.cat([phone_mask, prompt_mask, target_mask], dim=1)  # [B, T]
        prediction_target = torch.cat(
            [phone_label, prompt_label, target_labels], dim=1
        )  # [B, T]

        out = self.model(
            cond=torch.tensor(
                target_quantization_layer,
                device=prediction_target.device,
                dtype=torch.long,
            ),
            input_ids=input_embeddings,
            prediction_target=prediction_target,
            attention_mask=input_mask,
            return_dict=True,
        )

        # Accuracies are measured on the trailing target part only.
        num_target = target_embedding.shape[1]
        logits = out.logits[:, -num_target:, :]
        targets = prediction_target[..., -num_target:]

        top1_hits = logits.argmax(-1) == targets
        top5_hits = (logits.topk(5, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
        top10_hits = (logits.topk(10, dim=-1).indices == targets.unsqueeze(-1)).any(-1)

        out.target_quantization_layer = target_quantization_layer
        out.top1_acc = self._masked_mean(top1_hits, target_mask)
        out.top5_acc = self._masked_mean(top5_hits, target_mask)
        out.top10_acc = self._masked_mean(top10_hits, target_mask)

        return out

    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
    ):
        """Wrap phone sequences with BOS/EOS and build an all-ignored label.

        Args:
            phone_ids: [B, T]
            phone_mask: [B, T]
            phone_eos_id: id written right after the last valid phone.
            phone_bos_id: id prepended to every sequence.
            pad_token_id: id written at the remaining padded positions.

        Returns:
            ``(phone_ids, phone_mask, phone_label)``, each two positions longer
            than the input along the last axis; ``phone_label`` is all -100.
        """
        phone_ids = phone_ids + self.target_vocab_size * phone_mask

        phone_ids = phone_ids * phone_mask
        phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
            1 - phone_mask, (0, 1), value=1
        )  # make pad token eos token, add eos token at the end
        # Left-padding shifts the mask by one so the first EOS stays unmasked.
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add eos mask
        phone_ids = phone_ids * phone_mask + pad_token_id * (
            1 - phone_mask
        )  # restore pad token ids
        phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id)  # add bos token
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add bos mask
        phone_label = -100 * torch.ones_like(
            phone_ids
        )  # loss for entire phone is not computed (passed to llama)
        return phone_ids, phone_mask, phone_label

    @torch.no_grad()
    def sample_hf(
        self,
        phone_ids,  # [B, T]
        prompt_ids,  # [8, B, T]
        first_stage_ids,  # [B, T]
        top_k=50,
        top_p=1,
        temperature=1.1,
        first_stage_ids_gt=None,  # [Q, B, T]
        first_stage_ids_gt_end_layer=None,  # 2 to 8
    ):
        """Greedily fill in the remaining quantization levels.

        phone_ids: [B, T]
        prompt_ids: [8, B, T]
        first_stage_ids: [B, T] result from first quant layer. Should be continuation of prompt_ids
        top_k, top_p, temperature: accepted for interface compatibility; the
            decoding below is a plain argmax and does not use them.
        first_stage_ids_gt: optional ground-truth levels [Q, B, T] copied into
            the generated segment for levels below
            ``first_stage_ids_gt_end_layer``; generation then starts at that
            level.

        Returns:
            Tensor [Q, B, gen_len] with all levels of the generated segment.
        """
        phone_mask = torch.ones_like(phone_ids, dtype=torch.long)

        assert prompt_ids.shape[-1] >= 5, "prompt_ids should have at least 5 tokens"
        target_ids = torch.cat(
            [prompt_ids, first_stage_ids.expand(prompt_ids.shape[0], -1, -1)], dim=-1
        )
        target_mask = torch.ones_like(target_ids[0], dtype=torch.long)

        if first_stage_ids_gt is not None:
            target_ids[
                :first_stage_ids_gt_end_layer, :, -first_stage_ids_gt.shape[-1] :
            ] = first_stage_ids_gt[:first_stage_ids_gt_end_layer]

        gen_len = first_stage_ids.shape[-1]

        start_qnt_layer = START_QUANTIZATION_LAYER
        if first_stage_ids_gt_end_layer is not None:
            start_qnt_layer = first_stage_ids_gt_end_layer

        for qnt_level in range(start_qnt_layer, END_QUANTIZATION_LAYER + 1):
            out = self.forward(
                phone_ids=phone_ids,
                phone_mask=phone_mask,
                target_ids=target_ids,
                target_mask=target_mask,
                target_quantization_layer=qnt_level,
                prompt_len=prompt_ids.shape[-1],
            )
            # Take the argmax per sample over the last gen_len positions;
            # flattening across the batch would only be correct for B == 1.
            gen_tokens = out.logits[:, -gen_len:, :].argmax(dim=-1)  # [B, gen_len]

            # overwrite the target_ids with the generated tokens
            target_ids[qnt_level, :, -gen_len:] = gen_tokens

        return target_ids[:, :, -gen_len:]
def test():
    """Smoke test: overfit a single random utterance, then sample from it.

    Trains a default ``ValleNAR`` for 200 Adam steps on one fixed example,
    then runs ``sample_hf`` with the first 240 frames as prompt and prints the
    sampled levels next to the reference for the last 10 frames. Runs on CUDA
    when available and on CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    total_len = 250
    prompt_len = 240
    gen_len = total_len - prompt_len

    model = ValleNAR().to(device)

    phone_ids = torch.LongTensor([1, 2, 3, 4, 5]).reshape(1, -1).to(device)
    phone_mask = torch.LongTensor([1, 1, 1, 1, 1]).reshape(1, -1).to(device)
    target_ids = torch.randint(
        high=1024, size=(8, 1, total_len), dtype=torch.long
    ).to(device)
    target_mask = torch.ones(1, total_len, dtype=torch.long).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    for i in range(200):
        optimizer.zero_grad()

        out = model(
            phone_ids=phone_ids,
            phone_mask=phone_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )
        loss = out.loss

        loss.backward()

        optimizer.step()

        print(f"iter={i}, {loss}.")

    target_ids_short = target_ids[:, :, :prompt_len]

    # sample_hf passes prompt_len to forward(), which requires eval mode.
    model.eval()
    sampled = model.sample_hf(
        phone_ids,
        prompt_ids=target_ids_short,
        first_stage_ids=target_ids[0, :, prompt_len:],
    )

    print(target_ids[:, :, -gen_len:])
    print(sampled)

    print((sampled == target_ids[:, :, -gen_len:]).all())


if __name__ == "__main__":
    test()