Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
12.4 kB
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import torch.nn as nn
import math
from einops import rearrange
from models.tts.maskgct.llama_nar import DiffLlamaPrefix
def top_k(logits, thres=0.9):
    """Keep only the ``k`` largest logits along the last axis.

    ``k = ceil((1 - thres) * V)`` where ``V = logits.shape[-1]``, clamped to
    ``[1, V]``.  All other entries are set to ``-inf`` so that a subsequent
    softmax / Gumbel-max sample can only pick one of the kept entries.

    Args:
        logits: tensor of any rank; filtering happens along the last axis.
        thres: fraction of the vocabulary to drop (e.g. 0.98 keeps ~2%).

    Returns:
        Tensor shaped like ``logits`` with non-top-k entries set to ``-inf``.
    """
    vocab_size = logits.shape[-1]
    # Clamp so thres >= 1 never yields k == 0 (every entry -inf) and
    # thres < 0 never asks topk for more entries than exist.
    k = min(max(math.ceil((1 - thres) * vocab_size), 1), vocab_size)
    val, ind = logits.topk(k, dim=-1)
    filtered = torch.full_like(logits, float("-inf"))
    # Scatter along the same axis topk used (was hard-coded to dim 2).
    filtered.scatter_(-1, ind, val)
    return filtered
def log(t, eps=1e-10):
    """Return the natural log of ``t + eps``.

    The small ``eps`` shift keeps zeros from producing ``-inf`` for dtypes
    that can represent ``eps`` (e.g. float32).
    """
    shifted = t + eps
    return shifted.log()
def gumbel_noise(t):
    """Sample standard Gumbel noise with the same shape, dtype and device as ``t``.

    Uses the inverse-CDF form ``-log(-log(u))`` with ``u ~ U[0, 1)``; the
    module-level ``log`` helper adds a small epsilon so ``u == 0`` stays finite.
    """
    u = torch.empty_like(t).uniform_(0, 1)
    inner = -log(u)
    return -log(inner)
def gumbel_sample(t, temperature=1.0, dim=-1):
    """Draw indices along ``dim`` using the Gumbel-max trick.

    ``t`` is divided by ``temperature`` (floored at 1e-10 to avoid division by
    zero), Gumbel noise is added, and the argmax over ``dim`` is returned.
    """
    scaled = t / max(temperature, 1e-10)
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim=dim)
class MaskGCT_T2S(nn.Module):
    """Text-to-semantic masked generative model built on ``DiffLlamaPrefix``.

    Training (``forward`` -> ``compute_loss`` -> ``loss_t``) masks semantic
    tokens according to ``mask_prob(t)`` and predicts them from phone
    conditioning.  Inference (``reverse_diffusion``) starts from an
    all-masked sequence and iteratively unmasks it.

    Each constructor keyword may be overridden by an attribute of the same
    name on ``cfg`` when ``cfg`` is given and has that attribute.
    """

    def __init__(
        self,
        hidden_size=1024,
        num_layers=16,
        num_heads=16,
        cfg_scale=0.2,
        cond_codebook_size=8192,
        cond_dim=1024,
        cfg=None,
    ):
        super().__init__()

        # Attributes on ``cfg`` take precedence over the keyword defaults.
        # getattr(..., default) covers both ``cfg is None`` and a missing
        # attribute, exactly like the previous hasattr() checks.
        hidden_size = getattr(cfg, "hidden_size", hidden_size)
        num_layers = getattr(cfg, "num_layers", num_layers)
        num_heads = getattr(cfg, "num_heads", num_heads)
        # Probability of training a sample with no prompt (see forward_diffusion).
        cfg_scale = getattr(cfg, "cfg_scale", cfg_scale)
        cond_codebook_size = getattr(cfg, "cond_codebook_size", cond_codebook_size)
        # Stored only; not used elsewhere in this class.
        cond_dim = getattr(cfg, "cond_dim", cond_dim)

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.cfg_scale = cfg_scale
        self.cond_codebook_size = cond_codebook_size
        self.cond_dim = cond_dim

        self.mask_emb = nn.Embedding(1, self.hidden_size)
        self.to_logit = nn.Linear(self.hidden_size, self.cond_codebook_size)
        self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
        # Phone vocabulary is fixed at 1024 ids; id 1023 is the padding id.
        self.phone_emb = nn.Embedding(1024, hidden_size, padding_idx=1023)

        # NOTE: runs before ``diff_estimator`` is created, so only the layers
        # above are re-initialised by reset_parameters().
        self.reset_parameters()

        self.diff_estimator = DiffLlamaPrefix(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_layers=num_layers,
        )

    def mask_prob(self, t):
        """Masking probability at diffusion time ``t``: ``sin(t * pi / 2)``.

        Maps t=0 to 0 and t=1 to 1.
        """
        return torch.sin(t * np.pi / 2).to(t.device)

    def forward_diffusion(self, x0, t):
        """Corrupt ``x0`` by masking tokens according to ``mask_prob(t)``.

        Args:
            x0: semantic token ids, (B, T).
            t: diffusion time per sample, (B,).

        Returns:
            xt: (B, T, hidden_size); mask embedding at masked positions,
                token embedding elsewhere.
            new_t: ``t`` unchanged.
            mask: (B, T, 1) float, 1 where masked.
            prompt_len: (B,) length of the unmasked leading prompt.
            mask_prob: (B,) masking probability, floored at 0.2.
        """
        new_t = t
        mask_prob = self.mask_prob(new_t)  # (B,)
        # Floor the per-sample masking probability at 0.2.
        mask_prob = torch.where(
            mask_prob < 0.2, torch.ones_like(mask_prob) * 0.2, mask_prob
        )
        mask_token = self.mask_emb(
            torch.LongTensor([0]).to(x0.device)
        )  # (1, hidden_size)

        xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)

        bsz, seq_len = x0.shape
        # With probability (1 - cfg_scale) keep a random prefix (< 40% of T)
        # unmasked as prompt, so the maskable suffix covers > 60% of T;
        # otherwise train without prompt (the unconditional case for CFG).
        if torch.rand(1) > self.cfg_scale:
            low = min(seq_len // 4, 5)
            # torch.randint requires low < high; for short sequences
            # int(0.4 * T) can be <= low (e.g. T=4), which used to raise.
            # Keeping high >= low + 1 also keeps prompt_len < T.
            high = max(int(seq_len * 0.4), low + 1)
            prompt_len = torch.randint(low, high, (bsz,)).to(x0.device)  # (B,)
        else:
            prompt_len = torch.zeros(bsz).to(x0)  # (B,)

        # is_prompt[b, j] = 1 for j < prompt_len[b]
        is_prompt = torch.zeros_like(x0[:, :])  # (B, T)
        col_indices = (
            torch.arange(is_prompt.shape[1])
            .repeat(is_prompt.shape[0], 1)
            .to(prompt_len)
        )  # (B, T)
        is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1

        # Mask each position independently with probability mask_prob[b];
        # prompt positions are never masked.
        mask = torch.bernoulli(torch.ones_like(x0[:, :]) * mask_prob[..., None])
        mask[is_prompt.bool()] = 0
        # Guarantee at least one masked token per row: rows where sampling
        # masked nothing get the first post-prompt position masked.
        mask_num = mask.sum(dim=1)
        all_zero_mask = (mask_num == 0).bool()
        row_indices_to_modify = torch.nonzero(all_zero_mask)
        mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1

        mask = mask[..., None]  # (B, T, 1)
        xt = (
            xt + mask * mask_token[:, None, :] + (1 - mask) * self.cond_emb(x0[:, :])
        )  # (B, T, hidden_size)

        return xt, new_t, mask, prompt_len, mask_prob

    def loss_t(self, x0, x_mask, t, phone_embedding=None, phone_mask=None):
        """Mask ``x0`` at times ``t`` and predict logits for every position.

        The loss itself is not computed here; the caller presumably applies
        ``final_mask`` to a cross-entropy between ``logits`` and ``x0``.

        Returns:
            logits: (B, T, cond_codebook_size).
            final_mask: (B, T, 1), 1 at masked, non-padding positions.
            x0: target ids, unchanged.
            prompt_len: (B,) prompt lengths used for masking.
            mask_prob: (B,) masking probabilities used.
        """
        xt, new_t, mask, prompt_len, mask_prob = self.forward_diffusion(x0, t)
        # xt: (B, T, hidden_size); mask: (B, T, 1), 1 where masked

        embeds = self.diff_estimator(
            xt, new_t, x_mask, phone_embedding=phone_embedding, phone_mask=phone_mask
        )  # (B, T, hidden_size)
        logits = self.to_logit(embeds)  # (B, T, codebook_size)

        # Only masked, non-padding positions contribute to the loss.
        final_mask = mask * x_mask[..., None]  # (B, T, 1)

        return logits, final_mask, x0, prompt_len, mask_prob

    def compute_loss(self, x0, x_mask, phone_embedding=None, phone_mask=None):
        """Sample ``t ~ U(0, 1)`` per example (clamped to [1e-5, 1]) and call ``loss_t``.

        Args:
            x0: (B, T) semantic token ids.
            x_mask: (B, T), 0 for padding.
        """
        t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
        t = torch.clamp(t, 1e-5, 1.0)
        return self.loss_t(x0, x_mask, t, phone_embedding, phone_mask)

    def reset_parameters(self):
        """Initialise weights of the submodules registered so far.

        Linear/Conv/Embedding/MultiheadAttention weights are drawn from
        N(0, 0.02); biases are zeroed; embedding padding rows are zeroed.
        """

        def _reset_parameters(m):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)

                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)

            elif (
                isinstance(m, nn.Conv1d)
                or isinstance(m, nn.ConvTranspose1d)
                or isinstance(m, nn.Conv2d)
                or isinstance(m, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)

    @torch.no_grad()
    def reverse_diffusion(
        self,
        prompt,
        target_len,
        phone_id,
        prompt_mask=None,
        temp=0.9,
        filter_thres=0.98,
        n_timesteps=40,
        cfg=1.0,
        rescale_cfg=1.0,
    ):
        """Generate ``target_len`` semantic tokens by iterative unmasking.

        Args:
            prompt: prompt semantic token ids, (B, prompt_len).
            target_len: number of tokens to generate.
            phone_id: phone ids, (B, phone_len).
            prompt_mask: (B, prompt_len), 0 for padding; all ones if None.
            temp: initial sampling temperature, scaled by t each step.
            filter_thres: threshold passed to ``top_k``.
            n_timesteps: number of unmasking steps.
            cfg: classifier-free guidance strength; <= 0 disables CFG.
            rescale_cfg: blend weight between std-rescaled and raw guided
                embeddings.

        Returns:
            (B, target_len) tensor of generated token ids.
        """
        phone_embedding = self.phone_emb(phone_id)
        prompt_code = prompt  # (B, prompt_len)
        prompt_len = prompt_code.shape[1]
        device = prompt_code.device

        x_mask = torch.ones(prompt_code.shape[0], target_len).to(device)  # (B, T)
        phone_mask = torch.ones_like(phone_id)

        if prompt_mask is None:
            prompt_mask = torch.ones(prompt_code.shape[0], prompt_len).to(
                device
            )  # (B, prompt_len)

        cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
            device
        )  # (B, T, hidden_size)
        bsz, seq_len, _ = cum.shape

        choice_temp = 1.0
        start_temp = temp  # temperature for sampling
        start_choice_temp = choice_temp  # temperature for choosing tokens to re-mask

        steps = n_timesteps
        to_logit = self.to_logit
        cond_emb = self.cond_emb
        mask_token = self.mask_emb(torch.LongTensor([0]).to(device))
        # Start fully masked; ``seq`` holds current ids (0 at masked positions).
        mask = torch.full((bsz, seq_len, 1), True).to(device)  # (B, T, 1)
        seq = torch.full((bsz, seq_len), 0).to(device)
        h = 1.0 / steps

        cur_prompt = cond_emb(prompt_code)

        t_list = [1.0 - i * h for i in range(steps)]
        t_list.append(0.0)
        for i in range(steps):
            t = t_list[i] * torch.ones(bsz).to(device)
            token = cond_emb(seq)  # (B, T, hidden_size)
            # Masked positions get the mask embedding, the rest their token.
            cur = cum + mask * mask_token[:, None, :] + (~mask) * token

            xt_input = torch.cat([cur_prompt, cur], dim=1)  # (B, P + T, hidden_size)
            xt_mask = torch.cat(
                [prompt_mask, x_mask], dim=1
            )  # (B, P + T), mask is 0 for padding

            embeds = self.diff_estimator(
                xt_input,
                t,
                xt_mask,
                phone_embedding=phone_embedding,
                phone_mask=phone_mask,
            )
            embeds = embeds[:, prompt_len:, :]  # drop the prompt positions

            # Classifier-free guidance: the unconditional branch sees neither
            # the prompt nor the phones (zero-length phone embedding and mask).
            if cfg > 0:
                mask_embeds = self.diff_estimator(
                    cur,
                    t,
                    x_mask,
                    phone_embedding=phone_embedding[:, phone_embedding.shape[1] :, :],
                    # Slice the mask like the embedding so both are empty;
                    # slicing by prompt_len left a non-empty mask whenever
                    # phone_len > prompt_len.
                    phone_mask=phone_mask[:, phone_mask.shape[1] :],
                )
                pos_emb_std = embeds.std()  # std(g_cond)
                embeds = embeds + cfg * (embeds - mask_embeds)  # g_cfg
                rescale_embeds = embeds * pos_emb_std / embeds.std()  # g_final
                embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds

            logits = to_logit(embeds)  # (B, T, codebook_size)

            # Both temperatures scale linearly with the current t.
            annealing_scale = t_list[i]
            choice_temp = start_choice_temp * annealing_scale
            temp = start_temp * annealing_scale
            logits = top_k(logits, filter_thres)

            if i == steps - 1:
                # Last step is greedy, unless it is also the only step.
                if steps == 1:
                    temp = 0.2
                    sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
                else:
                    sampled_ids = logits.argmax(dim=-1)
            else:
                sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))

            # Only currently-masked positions take the new samples.
            seq = torch.where(mask.squeeze(-1), sampled_ids, seq)

            # Confidence of each sampled token plus annealed Gumbel noise;
            # inverted so the least confident tokens have the highest score.
            scores = logits.softmax(dim=-1)
            scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
            scores = rearrange(scores, "b n 1 -> b n")
            scores = choice_temp * gumbel_noise(scores) + scores
            scores = 1 - scores

            next_t = t_list[i + 1] * torch.ones(bsz).to(device)
            next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()
            if next_mask_num == 0:
                break

            # Already-unmasked positions get the lowest score so they are
            # chosen for re-masking last.
            scores = scores.masked_fill(
                ~mask.squeeze(-1), -torch.finfo(scores.dtype).max
            )
            mask_indices = scores.topk(next_mask_num, dim=-1).indices
            mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
                1, mask_indices, True
            )
            seq = seq.masked_fill(mask, 0)
            mask = mask.unsqueeze(-1)

        return seq

    def forward(self, x0, x_mask, phone_id=None, phone_mask=None):
        """Training entry point: embed phones, then run ``compute_loss``.

        Args:
            x0: (B, T) semantic token ids.
            x_mask: (B, T), 0 for padding.
            phone_id: phone ids; despite the ``None`` default it is passed
                straight to ``self.phone_emb``, so ``None`` will fail there.
            phone_mask: mask for the phone sequence, forwarded unchanged.

        Returns:
            ``(logits, final_mask, x0, prompt_len, mask_prob)`` from ``loss_t``.
        """
        phone_embedding = self.phone_emb(phone_id)
        logits, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
            x0, x_mask, phone_embedding, phone_mask=phone_mask
        )
        return logits, final_mask, x0, prompt_len, mask_prob