# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from models.tts.maskgct.llama_nar import DiffLlama


def top_k(logits, thres=0.9):
    # Keep only the top (1 - thres) fraction of logits; everything else is set to -inf.
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(2, ind, val)
    return probs


def log(t, eps=1e-10):
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1):
    # Categorical sampling via the Gumbel-max trick.
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)


class MaskGCT_S2A(nn.Module):
    def __init__(
        self,
        num_quantizer=12,
        hidden_size=1024,
        num_layers=16,
        num_heads=16,
        codebook_size=1024,
        cfg_scale=0.15,
        mask_layer_schedule="linear",
        cond_codebook_size=1024,
        cond_dim=1024,
        predict_layer_1=True,
        cfg=None,
    ):
        super().__init__()
        # Values from cfg override the keyword defaults when provided.
        num_quantizer = (
            cfg.num_quantizer
            if cfg is not None and hasattr(cfg, "num_quantizer")
            else num_quantizer
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        num_layers = (
            cfg.num_layers
            if cfg is not None and hasattr(cfg, "num_layers")
            else num_layers
        )
        num_heads = (
            cfg.num_heads
            if cfg is not None and hasattr(cfg, "num_heads")
            else num_heads
        )
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        cfg_scale = (
            cfg.cfg_scale
            if cfg is not None and hasattr(cfg, "cfg_scale")
            else cfg_scale
        )
        mask_layer_schedule = (
            cfg.mask_layer_schedule
            if cfg is not None and hasattr(cfg, "mask_layer_schedule")
            else mask_layer_schedule
        )
        cond_codebook_size = (
            cfg.cond_codebook_size
            if cfg is not None and hasattr(cfg, "cond_codebook_size")
            else cond_codebook_size
        )
        cond_dim = (
            cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
        )
        predict_layer_1 = (
            cfg.predict_layer_1
            if cfg is not None and hasattr(cfg, "predict_layer_1")
            else predict_layer_1
        )

        self.num_quantizer = num_quantizer
        self.hidden_size = hidden_size
        self.codebook_size = codebook_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.cfg_scale = cfg_scale
        self.mask_layer_schedule = mask_layer_schedule
        self.cond_codebook_size = cond_codebook_size
        self.cond_dim = cond_dim
        self.predict_layer_1 = predict_layer_1

        self.layer_emb = nn.Embedding(self.num_quantizer, self.hidden_size)
        self.mask_emb = nn.Embedding(1, self.hidden_size)

        # One token embedding table and one output head per quantizer layer.
        self.token_emb = torch.nn.ModuleList(
            [
                nn.Embedding(self.codebook_size, self.hidden_size)
                for _ in range(self.num_quantizer)
            ]
        )
        self.to_logits = torch.nn.ModuleList(
            [
                nn.Linear(self.hidden_size, self.codebook_size)
                for _ in range(self.num_quantizer)
            ]
        )

        self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)

        self.reset_parameters()

        self.diff_estimator = DiffLlama(
            hidden_size=hidden_size,
            num_heads=self.num_heads,
            num_layers=num_layers,
        )

    def mask_prob(self, t):
        # Sine masking schedule: fraction of tokens masked at diffusion time t.
        return torch.sin(t * np.pi / 2).to(t.device)
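    # Worked example of the schedule above (illustrative values, not from the
    # original source): at t = 0.5 the masking ratio is sin(0.25 * pi) ~= 0.707,
    # and at t = 1.0 it is sin(0.5 * pi) = 1.0, so the sequence is fully masked
    # at the start of reverse diffusion and progressively unmasked as t -> 0.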
    def mask_layer(self, t):
        # Sample which quantizer layer is masked for this training step.
        if self.mask_layer_schedule == "uniform":
            if self.predict_layer_1:
                mask_layer = torch.randint(0, self.num_quantizer, (1,)).to(t.device)
            else:
                mask_layer = torch.randint(1, self.num_quantizer, (1,)).to(t.device)
        elif self.mask_layer_schedule == "cosine":
            if self.predict_layer_1:
                weights = torch.tensor(
                    [
                        np.cos(i / self.num_quantizer * np.pi / 2)
                        for i in range(self.num_quantizer)
                    ]
                )
            else:
                weights = torch.tensor(
                    [0]
                    + [
                        np.cos((i - 1) / self.num_quantizer * np.pi / 2)
                        for i in range(1, self.num_quantizer)
                    ]
                )
            mask_layer = torch.multinomial(weights, 1).to(t.device)
        elif self.mask_layer_schedule == "linear":
            if self.predict_layer_1:
                weights = torch.tensor(
                    [self.num_quantizer - i for i in range(self.num_quantizer)]
                )
            else:
                weights = torch.tensor(
                    [0]
                    + [
                        self.num_quantizer - (i - 1)
                        for i in range(1, self.num_quantizer)
                    ]
                )
            weights = weights / weights.sum()
            mask_layer = torch.multinomial(weights, 1).to(t.device)

        new_t = t

        return mask_layer, new_t

    def forward_diffusion(self, x0, t):
        # x0: (B, T, num_quantizer)
        mask_layer, new_t = self.mask_layer(t)  # (1,)
        mask_prob = self.mask_prob(new_t)  # (B,)
        mask_token = self.mask_emb(torch.zeros_like(mask_layer))  # (1, hidden_size)

        xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)

        cfg_scale = self.cfg_scale

        # sample the prompt length; with probability cfg_scale the prompt is
        # dropped (length 0) so the model also learns the unconditional case
        if torch.rand(1) > cfg_scale:
            prompt_len = torch.randint(
                min(x0.shape[1] // 4, 5), x0.shape[1] // 2, (x0.shape[0],)
            ).to(
                x0.device
            )  # (B,)
        else:
            prompt_len = torch.zeros(x0.shape[0]).to(x0)  # (B,)

        # build the prompt indicator
        is_prompt = torch.zeros_like(x0[:, :, 0])  # (B, T)
        col_indices = (
            torch.arange(is_prompt.shape[1])
            .repeat(is_prompt.shape[0], 1)
            .to(prompt_len)
        )  # (B, T)
        is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1  # (B, T) 1 if prompt

        for idx, token_emb_idx in enumerate(self.token_emb):
            if idx < mask_layer:
                # layers below the masked one are fully observed
                xt = xt + token_emb_idx(x0[:, :, idx])  # (B, T, hidden_size)

            elif idx == mask_layer:
                mask = torch.bernoulli(
                    torch.ones_like(x0[:, :, idx]) * mask_prob[..., None]
                )  # mask if 1, not mask if 0
                # the prompt part is never masked
                mask[is_prompt.bool()] = 0
                # ensure at least one token is masked per sample
                mask_num = mask.sum(dim=1, keepdim=False)
                all_zero_mask = (mask_num == 0).bool()
                row_indices_to_modify = torch.nonzero(all_zero_mask)
                # mask the first non-prompt token if no token was masked
                # (may mask padding when indices are random)
                mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1

                mask = mask[..., None]  # (B, T, 1)
                xt = (
                    xt
                    + mask * mask_token[:, None, :]
                    + (1 - mask) * token_emb_idx(x0[:, :, idx])
                )  # (B, T, hidden_size)

            else:
                # layers above the masked one: keep the prompt part, replace the
                # rest with the mask embedding
                xt = (
                    xt
                    + token_emb_idx(x0[:, :, idx]) * is_prompt[..., None]
                    + mask_token * (1 - is_prompt[..., None])
                )

        return xt, new_t, mask_layer, mask, prompt_len, mask_prob

    def loss_t(self, x0, x_mask, t, cond=None):
        xt, new_t, mask_layer, mask, prompt_len, mask_prob = self.forward_diffusion(
            x0, t
        )
        # xt: (B, T, hidden_size)
        # new_t: (B,)
        # mask_layer: (1,)
        # mask: (B, T, 1) mask if 1, not mask if 0
        # prompt_len: (B,)
        # mask_prob: (B,)

        mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(1)  # (1, 1, hidden_size)
        cond = cond + mask_layer_cond  # (B, T, hidden_size)

        embeds = self.diff_estimator(xt, new_t, cond, x_mask)  # (B, T, hidden_size)
        logits = self.to_logits[mask_layer.item()](embeds)  # (B, T, codebook_size)

        # final mask used for loss calculation
        final_mask = mask * x_mask[..., None]  # (B, T, 1)

        return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
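    # Hedged sketch (not taken from this file): one plausible way a training
    # step could consume the outputs of compute_loss / forward below, assuming
    # a standard masked cross-entropy objective restricted to masked positions
    # (F is torch.nn.functional):
    #
    #   logits, mask_layer, final_mask, x0, _, _ = model(x0, x_mask, cond_code)
    #   target = x0[:, :, mask_layer.item()]                       # (B, T)
    #   ce = F.cross_entropy(
    #       logits.transpose(1, 2), target, reduction="none"
    #   )                                                          # (B, T)
    #   loss = (ce * final_mask.squeeze(-1)).sum() / final_mask.sum()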
    def compute_loss(self, x0, x_mask, cond=None):
        # x0: (B, T, num_quantizer)
        # x_mask: (B, T), mask is 0 for padding
        t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
        t = torch.clamp(t, 1e-5, 1.0)
        return self.loss_t(x0, x_mask, t, cond)

    def reset_parameters(self):
        def _reset_parameters(m):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)

                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)

            elif isinstance(
                m, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)

    @torch.no_grad()
    def reverse_diffusion(
        self,
        cond,
        prompt,
        x_mask=None,
        prompt_mask=None,
        temp=1.5,
        filter_thres=0.98,
        max_layer=None,
        gt_code=None,
        n_timesteps=[10, 4, 4, 4, 4, 4, 4, 4],
        cfg=1.0,
        rescale_cfg=1.0,
    ):
        # each quantizer layer has its own number of sampling steps
        assert len(n_timesteps) == self.num_quantizer

        prompt_code = prompt  # (B, prompt_len, num_quantizer)
        prompt_len = prompt_code.shape[1]
        target_len = cond.shape[1] - prompt_len

        if x_mask is None:
            x_mask = torch.ones(cond.shape[0], target_len).to(cond.device)  # (B, T)
        if prompt_mask is None:
            prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
                cond.device
            )  # (B, prompt_len)

        cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
            x_mask.device
        )  # (B, T, hidden_size)

        bsz, seq_len, _ = cum.shape

        choice_temp = 1.0
        start_temp = temp  # temperature for sampling
        start_choice_temp = choice_temp  # temperature for choosing mask tokens

        if max_layer is None:
            max_layer = self.num_quantizer

        xt = torch.LongTensor(bsz, seq_len, max_layer).to(x_mask.device)

        if gt_code is not None:
            # layers that are already known are copied over and accumulated
            gt_layer = gt_code.shape[-1]
            xt[:, :, :gt_layer] = gt_code
            for i in range(gt_layer):
                cum += self.token_emb[i](xt[:, :, i])
        else:
            gt_layer = 0

        for mask_layer in range(gt_layer, max_layer):
            steps = n_timesteps[mask_layer]
            to_logits = self.to_logits[mask_layer]
            token_emb = self.token_emb[mask_layer]
            mask_layer = torch.tensor(mask_layer).to(x_mask.device).long().unsqueeze(0)
            mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(
                1
            )  # (1,) -> (1, 1, hidden_size)
            temp_cond = cond + mask_layer_cond  # (B, T, hidden_size)

            mask_token = self.mask_emb(torch.zeros_like(mask_layer))  # (1, hidden_size)

            mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device)  # (B, T, 1)
            seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
            h = 1.0 / steps

            # prompt_code: (B, prompt_len, num_quantizer)
            cur_prompt = 0
            for idx, emb in enumerate(self.token_emb):
                cur_prompt = cur_prompt + emb(
                    prompt_code[:, :, idx]
                )  # (B, prompt_len, hidden_size)

            t_list = [1.0 - i * h for i in range(steps)]
            t_list.append(0.0)
            for i in range(steps):
                t = t_list[i] * torch.ones(bsz).to(x_mask.device)

                token = token_emb(seq)  # (B, T, hidden_size)
                cur = cum + mask * mask_token[:, None, :] + (~mask) * token
                cur = cur + mask_token[:, None, :] * (max_layer - 1 - mask_layer)
                xt_input = torch.cat([cur_prompt, cur], dim=1)  # (B, T, hidden_size)
                xt_mask = torch.cat(
                    [prompt_mask, x_mask], dim=1
                )  # (B, T), mask is 0 for padding

                embeds = self.diff_estimator(xt_input, t, temp_cond, xt_mask)
                embeds = embeds[:, prompt_len:, :]

                # classifier-free guidance with rescaling
                if cfg > 0:
                    mask_embeds = self.diff_estimator(
                        cur, t, temp_cond[:, prompt_len:, :], x_mask
                    )
                    pos_emb_std = embeds.std()  # std(g_cond)
                    embeds = embeds + cfg * (embeds - mask_embeds)  # g_cfg
                    rescale_embeds = embeds * pos_emb_std / embeds.std()  # g_final
                    embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds

                logits = to_logits(embeds)  # (B, T, codebook_size)
                annealing_scale = t_list[i]

                choice_temp = start_choice_temp * annealing_scale
                temp = start_temp * annealing_scale
                logits = top_k(logits, filter_thres)

                if i == steps - 1:
                    # greedy decoding on the final step
                    if steps == 1:
                        temp = 0.2
                        sampled_ids = gumbel_sample(
                            logits, temperature=max(temp, 1e-3)
                        )
                    else:
                        sampled_ids = logits.argmax(dim=-1)
                else:
                    # stochastic sampling
                    sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))

                seq = torch.where(mask.squeeze(-1), sampled_ids, seq)

                # confidence scores of the sampled tokens decide what is re-masked
                scores = logits.softmax(dim=-1)
                scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
                scores = rearrange(scores, "b n 1 -> b n")
                scores = choice_temp * gumbel_noise(scores) + scores
                scores = 1 - scores

                next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)

                next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()

                if next_mask_num == 0:
                    break
                scores = scores.masked_fill(
                    ~mask.squeeze(-1), -torch.finfo(scores.dtype).max
                )

                mask_indices = scores.topk(next_mask_num, dim=-1).indices
                mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
                    1, mask_indices, True
                )
                seq = seq.masked_fill(mask, 0)

                mask = mask.unsqueeze(-1)

            cum = cum + token_emb(seq)
            xt[..., mask_layer.squeeze(0).item()] = seq

        return xt

    def forward(self, x0, x_mask, cond_code=None):
        # x0: (B, T, num_quantizer)
        # x_mask: (B, T), mask is 0 for padding
        # cond_code: semantic tokens (B, T)
        cond = self.cond_emb(cond_code)
        logits, mask_layer, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
            x0,
            x_mask,
            cond,
        )
        return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
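

# ---------------------------------------------------------------------------
# Hedged usage sketch: exercises the training-style forward pass and the
# iterative reverse-diffusion decoder with random codes. The model size,
# shapes, vocabulary sizes, and per-layer step counts below are illustrative
# assumptions, not values from the MaskGCT recipe.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # small configuration so the sketch runs quickly (assumed, not canonical)
    model = MaskGCT_S2A(num_quantizer=4, hidden_size=256, num_layers=4, num_heads=4)

    B, T, prompt_T = 2, 48, 16

    # training-style forward pass: acoustic codes plus semantic condition codes
    x0 = torch.randint(0, model.codebook_size, (B, T, model.num_quantizer))
    x_mask = torch.ones(B, T)
    cond_code = torch.randint(0, model.cond_codebook_size, (B, T))
    logits, mask_layer, final_mask, _, _, _ = model(x0, x_mask, cond_code)
    print(logits.shape, mask_layer.item(), final_mask.shape)

    # inference-style reverse diffusion: condition covers prompt + target frames
    cond = model.cond_emb(
        torch.randint(0, model.cond_codebook_size, (B, prompt_T + T))
    )
    prompt = torch.randint(0, model.codebook_size, (B, prompt_T, model.num_quantizer))
    codes = model.reverse_diffusion(
        cond, prompt, n_timesteps=[4] * model.num_quantizer, cfg=1.0, rescale_cfg=0.75
    )
    print(codes.shape)  # expected: (B, T, num_quantizer)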