from typing import Optional

import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from vocos.spectral_ops import IMDCT, ISTFT
from vocos.modules import symexp
class FourierHead(nn.Module):
    """Base class for inverse Fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
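
# A minimal subclassing sketch (hypothetical class, for illustration only; any
# concrete head just needs to map (B, L, H) features to a (B, T) waveform):
#
#     class ToyHead(FourierHead):
#         def __init__(self, dim: int, frame_len: int):
#             super().__init__()
#             self.out = nn.Linear(dim, frame_len)
#
#         def forward(self, x: torch.Tensor) -> torch.Tensor:
#             return self.out(x).flatten(1)  # toy "synthesis" without overlap-add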
class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of the Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        ckpt (str, optional): Path to a checkpoint from which to initialize the output projection,
            using its "head.out.weight" and "head.out.bias" entries. Defaults to None.
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same", ckpt: Optional[str] = None):
        super().__init__()
        out_dim = n_fft + 2
        self.out = nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
        # initialize the output projection from the checkpoint if provided
        if ckpt is not None:
            params = torch.load(ckpt, map_location="cpu")
            # copy head.out.weight and head.out.bias from the checkpoint into self.out
            with torch.no_grad():
                self.out.weight.copy_(params["head.out.weight"])
                self.out.bias.copy_(params["head.out.bias"])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # phase wrapping happens implicitly here: cos/sin map the unbounded phase
        # predictions onto the unit circle, yielding the real and imaginary parts
        x = torch.cos(p)
        y = torch.sin(p)
        # recomputing the phase via atan2(y, x) and exp(1j * phase) would produce
        # nothing new and only cost time, so build the complex spectrogram directly
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
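
# A minimal usage sketch for ISTFTHead (illustrative values, not a canonical
# config; hop_length must match the hop of the upstream feature extractor):
#
#     head = ISTFTHead(dim=512, n_fft=1024, hop_length=256)
#     feats = torch.randn(2, 100, 512)  # (B, L, H)
#     audio = head(feats)               # (B, T), T ~= L * hop_length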
class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with a symmetric exponential (symexp) activation.

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally initialize the last layer following the mel scale, so that
            # lower (perceptually more important) frequencies start with larger weights
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio
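
# A minimal usage sketch for IMDCTSymExpHead (illustrative values; passing
# sample_rate enables the mel-scale initialization of the output layer):
#
#     head = IMDCTSymExpHead(dim=512, mdct_frame_len=512, sample_rate=24000)
#     feats = torch.randn(2, 100, 512)  # (B, L, H)
#     audio = head(feats)               # (B, T), T ~= L * (mdct_frame_len // 2)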
class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p).

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio
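
# A minimal usage sketch for IMDCTCosHead (illustrative values; clip_audio=True
# bounds the output waveform to [-1.0, 1.0]):
#
#     head = IMDCTCosHead(dim=512, mdct_frame_len=512, clip_audio=True)
#     feats = torch.randn(2, 100, 512)  # (B, L, H)
#     audio = head(feats)               # (B, T), clipped to [-1.0, 1.0]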