"""
Conditional flow matching (CFM) wrapper around a transformer backbone.

ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

from typing import Callable
from random import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from torchdiffeq import odeint

from diffrhythm.model.modules import MelSpec
from diffrhythm.model.utils import (
    default,
    exists,
    list_str_to_idx,
    list_str_to_tensor,
    lens_to_mask,
    mask_from_frac_lengths,
)


def custom_mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], device, max_seq_len):  # noqa: F722 F821
    """Build a boolean mask that is True on the half-open range ``[start, end)``.

    Args:
        seq_len: Unused; kept so existing callers keep working.
        start: 1-D tensor of inclusive start indices, one per row.
        end: 1-D tensor of exclusive end indices, one per row.
        device: Device on which the position index is created.
        max_seq_len: Number of positions (columns) in the returned mask.

    Returns:
        Boolean tensor with one row per entry of ``start``/``end`` and
        ``max_seq_len`` columns.
    """
    positions = torch.arange(max_seq_len, device=device).long()
    start_mask = positions[None, :] >= start[:, None]
    end_mask = positions[None, :] < end[:, None]
    return start_mask & end_mask


class CFM(nn.Module):
    """Flow-matching training objective and ODE sampler around ``transformer``."""

    def __init__(
        self,
        transformer: nn.Module,
        sigma=0.0,
        odeint_kwargs: dict | None = None,
        odeint_options: dict | None = None,
        audio_drop_prob=0.3,
        cond_drop_prob=0.2,
        style_drop_prob=0.1,
        lrc_drop_prob=0.1,
        num_channels=None,
        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
        vocab_char_map: dict[str, int] | None = None,
        use_style_prompt: bool = False
    ):
        """Store the backbone and the training / sampling hyper-parameters.

        Args:
            transformer: Backbone module; must expose a ``dim`` attribute.
            sigma: Stored as ``self.sigma``; not read elsewhere in this file.
            odeint_kwargs: Keyword arguments forwarded to ``odeint`` in
                ``sample``. ``None`` selects ``{"method": "euler"}``.
            odeint_options: Stored as ``self.odeint_options``; not read
                elsewhere in this file. ``None`` selects ``{"min_step": 0.05}``.
            audio_drop_prob: Probability of setting ``drop_audio_cond`` in
                ``forward``.
            cond_drop_prob: Stored only; not read elsewhere in this file.
            style_drop_prob: Probability of setting ``drop_prompt`` in
                ``forward``.
            lrc_drop_prob: Probability of setting ``drop_text`` in ``forward``.
            num_channels: Feature dimension of the latents drawn in ``sample``.
            frac_lengths_mask: ``(low, high)`` bounds of the uniform
                distribution used to draw span fractions in ``forward``.
            vocab_char_map: Optional vocabulary used to tokenize ``list[str]``
                text in ``sample``.
            use_style_prompt: Stored only; not read elsewhere in this file.
        """
        super().__init__()

        self.frac_lengths_mask = frac_lengths_mask

        self.num_channels = num_channels

        # classifier-free guidance
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob
        self.style_drop_prob = style_drop_prob
        self.lrc_drop_prob = lrc_drop_prob

        print(f"audio drop prob -> {self.audio_drop_prob}; style_drop_prob -> {self.style_drop_prob}; lrc_drop_prob: {self.lrc_drop_prob}")

        # transformer
        self.transformer = transformer
        self.dim = transformer.dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        # Defaults are built per instance so that no dict is shared between
        # CFM objects (a dict literal in the signature would be shared).
        # Other solver names seen in earlier revisions of this code:
        # 'midpoint', 'adaptive_heun', 'dopri5'.
        if odeint_kwargs is None:
            odeint_kwargs = dict(method="euler")
        self.odeint_kwargs = odeint_kwargs

        if odeint_options is None:
            odeint_options = dict(min_step=0.05)
        self.odeint_options = odeint_options

        # vocab map for tokenization
        self.vocab_char_map = vocab_char_map

        self.use_style_prompt = use_style_prompt

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        cond: float["b n d"] | float["b nw"],  # noqa: F722
        text: int["b nt"] | list[str],  # noqa: F722
        duration: int | int["b"],  # noqa: F821
        *,
        style_prompt=None,
        style_prompt_lens=None,
        negative_style_prompt=None,
        lens: int["b"] | None = None,  # noqa: F821
        steps=32,
        cfg_strength=4.0,
        sway_sampling_coef=None,
        seed: int | None = None,
        max_duration=4096,
        vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,  # noqa: F722
        no_ref_audio=False,
        duplicate_test=False,
        t_inter=0.1,
        edit_mask=None,
        start_time=None,
        latent_pred_start_frame=0,
        latent_pred_end_frame=2048,
    ):
        """Generate latents by integrating the learned flow with guidance.

        The module is put in eval mode. Conditioning inside the span
        ``[latent_pred_start_frame, duration)`` is zeroed before being fed to
        the transformer, and the generated values replace ``cond`` only inside
        that span in the returned tensor.

        Args:
            cond: Conditioning features (3-D) or raw wave (2-D). See the
                NOTE(review) in the body: the raw-wave path looks non-functional.
            text: Token ids, or a list of strings that is tokenized here.
            duration: Target length; an ``int`` is broadcast to the batch.
            style_prompt: Style conditioning for the guided branch.
            style_prompt_lens: Accepted for interface compatibility; unused.
            negative_style_prompt: Style conditioning for the unguided branch.
            lens: Accepted for interface compatibility; no longer read (it
                previously fed only an unused mask).
            steps: Number of time points passed to ``torch.linspace``.
            cfg_strength: Scale of the classifier-free-guidance correction.
            sway_sampling_coef: If given, warps the time grid by
                ``t + coef * (cos(pi / 2 * t) - 1 + t)``.
            seed: If given, the torch RNG is re-seeded before drawing the
                noise of every batch item.
            max_duration: Upper clamp applied to ``duration``.
            vocoder: Optional callable applied to the channel-first output.
            no_ref_audio: If true, ``cond`` is zeroed for the final merge.
            duplicate_test: Debug path that starts integration at ``t_inter``
                from a blend of noise and a padded copy of ``cond``.
            t_inter: Start time used when ``duplicate_test`` is set.
            edit_mask: Accepted for interface compatibility; no longer read
                (it previously fed only an unused mask).
            start_time: Forwarded to ``transformer.forward_timestep_invariant``.
            latent_pred_start_frame: First frame of the generated span.
            latent_pred_end_frame: Ignored; the span end is always ``duration``.

        Returns:
            Tuple ``(out, trajectory)``: the final sample (vocoded if a
            vocoder is given) and the full ``odeint`` result.
        """
        self.eval()

        if next(self.parameters()).dtype == torch.float16:
            cond = cond.half()

        # Truncate conditioning that is longer than the requested duration.
        # NOTE(review): this 3-D slice runs before the ndim check below, so a
        # 2-D raw wave longer than `duration` raises here.
        if cond.shape[1] > duration:
            cond = cond[:, :duration, :]

        # raw wave
        # NOTE(review): `self.mel_spec` is never assigned in __init__, so this
        # branch appears unreachable without an AttributeError — confirm
        # whether raw-wave input is still meant to be supported.
        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = cond.permute(0, 2, 1)
            assert cond.shape[-1] == self.num_channels

        batch, cond_seq_len, device = *cond.shape[:2], cond.device

        # text
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # Span to generate: [latent_pred_start_frame, duration). The
        # latent_pred_end_frame argument is deliberately not consulted.
        span_start = torch.tensor([latent_pred_start_frame]).to(cond.device)
        span_end = torch.tensor([duration]).to(cond.device)
        fixed_span_mask = custom_mask_from_start_end_indices(
            cond_seq_len, span_start, span_end, device=cond.device, max_seq_len=duration
        )
        fixed_span_mask = fixed_span_mask.unsqueeze(-1)

        # Hide the conditioning inside the span that will be generated.
        step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)

        duration = duration.clamp(max=max_duration)
        max_duration = duration.amax()

        # duplicate test corner for inner time step observation
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)

        # test for no ref audio (step_cond was already built from the real cond)
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        # Text / start-time embeddings do not depend on the ODE time, so they
        # are computed once for the text-conditioned and text-dropped branches.
        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(
            text, step_cond.shape[1], drop_text=False, start_time=start_time
        )
        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(
            text, step_cond.shape[1], drop_text=True, start_time=start_time
        )

        # Stack both branches along the batch axis for a single forward pass.
        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
        step_cond = torch.cat([step_cond, step_cond], 0)
        # NOTE(review): both style prompts default to None, which torch.cat
        # rejects — callers must pass both.
        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)

        def fn(t, x):
            # Duplicate the state so both guidance branches share one call.
            x = torch.cat([x, x], 0)
            pred = self.transformer(
                x=x,
                text_embed=text_embed,
                text_residuals=text_residuals,
                cond=step_cond,
                time=t,
                drop_audio_cond=True,
                drop_prompt=False,
                style_prompt=style_prompt,
                start_time=start_time_embed
            )

            positive_pred, negative_pred = pred.chunk(2, 0)
            cfg_pred = positive_pred + (positive_pred - negative_pred) * cfg_strength
            return cfg_pred

        # noise input
        # Re-seeding per item keeps each item's noise independent of the batch
        # size it is sampled in.
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        t_start = 0

        # duplicate test corner for inner time step observation
        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)

        out = trajectory[-1]
        # Keep generated values inside the span, original cond elsewhere.
        out = torch.where(fixed_span_mask, out, cond)

        if exists(vocoder):
            out = out.permute(0, 2, 1)
            out = vocoder(out)

        return out, trajectory

    def forward(
        self,
        inp: float["b n d"] | float["b nw"],  # mel or raw wave  # noqa: F722
        text: int["b nt"] | list[str],  # noqa: F722
        style_prompt=None,
        style_prompt_lens=None,
        lens: int["b"] | None = None,  # noqa: F821
        noise_scheduler: str | None = None,
        grad_ckpt=False,
        start_time=None,
    ):
        """Compute the flow-matching loss for one training batch.

        Args:
            inp: Target features ``x1``; the first two dims are read as
                ``(batch, seq_len)``.
            text: Text conditioning, forwarded to the transformer.
            style_prompt: Forwarded to the transformer.
            style_prompt_lens: Forwarded to the transformer.
            lens: Per-item lengths; defaults to ``seq_len`` for every item.
            noise_scheduler: Unused (marked TODO in the original code).
            grad_ckpt: Forwarded to the transformer.
            start_time: Forwarded to the transformer.

        Returns:
            Tuple ``(loss, cond, pred)``: the mean squared error between
            ``pred`` and ``x1 - x0`` over the masked positions, the
            conditioning tensor fed to the transformer, and the prediction.
        """
        batch, seq_len = inp.shape[:2]
        device = self.device

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device)

        mask = lens_to_mask(lens, length=seq_len)

        # A random span is still drawn so the RNG stream is consumed exactly
        # as before, but its result is replaced by `mask` right below.
        frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):
            rand_span_mask = mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # time step: sigmoid of a standard normal draw, one value per item
        time = torch.normal(mean=0, std=1, size=(batch,), device=self.device)
        time = F.sigmoid(time)
        # TODO. noise_scheduler

        # sample xt (φ_t(x) in the paper): linear interpolation noise -> data
        t = time.unsqueeze(-1).unsqueeze(-1)
        x_t = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only predict what is within the mask span for infilling
        cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)

        # transformer and cfg training with a drop rate
        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
        drop_text = random() < self.lrc_drop_prob
        drop_prompt = random() < self.style_drop_prob

        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
        pred = self.transformer(
            x=x_t,
            cond=cond,
            text=text,
            time=time,
            drop_audio_cond=drop_audio_cond,
            drop_text=drop_text,
            drop_prompt=drop_prompt,
            style_prompt=style_prompt,
            style_prompt_lens=style_prompt_lens,
            grad_ckpt=grad_ckpt,
            start_time=start_time
        )

        # flow matching loss, restricted to the masked positions
        loss = F.mse_loss(pred, flow, reduction="none")
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred