Spaces:
Running
on
Zero
Running
on
Zero
Upload ./vocos/spectral_ops.py with huggingface_hub
Browse files- vocos/spectral_ops.py +192 -0
vocos/spectral_ops.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
import torch
|
4 |
+
from torch import nn, view_as_real, view_as_complex
|
5 |
+
|
6 |
+
|
7 |
+
class ISTFT(nn.Module):
|
8 |
+
"""
|
9 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
10 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
11 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
12 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
13 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
n_fft (int): Size of Fourier transform.
|
17 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
18 |
+
win_length (int): The size of window frame and STFT filter.
|
19 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
|
23 |
+
super().__init__()
|
24 |
+
if padding not in ["center", "same"]:
|
25 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
26 |
+
self.padding = padding
|
27 |
+
self.n_fft = n_fft
|
28 |
+
self.hop_length = hop_length
|
29 |
+
self.win_length = win_length
|
30 |
+
window = torch.hann_window(win_length)
|
31 |
+
self.register_buffer("window", window)
|
32 |
+
|
33 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
34 |
+
"""
|
35 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
39 |
+
N is the number of frequency bins, and T is the number of time frames.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
43 |
+
"""
|
44 |
+
if self.padding == "center":
|
45 |
+
# Fallback to pytorch native implementation
|
46 |
+
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
|
47 |
+
elif self.padding == "same":
|
48 |
+
pad = (self.win_length - self.hop_length) // 2
|
49 |
+
else:
|
50 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
51 |
+
|
52 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
53 |
+
B, N, T = spec.shape
|
54 |
+
|
55 |
+
# Inverse FFT
|
56 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
57 |
+
ifft = ifft * self.window[None, :, None]
|
58 |
+
|
59 |
+
# Overlap and Add
|
60 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
61 |
+
y = torch.nn.functional.fold(
|
62 |
+
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
|
63 |
+
)[:, 0, 0, pad:-pad]
|
64 |
+
|
65 |
+
# Window envelope
|
66 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
67 |
+
window_envelope = torch.nn.functional.fold(
|
68 |
+
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
|
69 |
+
).squeeze()[pad:-pad]
|
70 |
+
|
71 |
+
# Normalize
|
72 |
+
assert (window_envelope > 1e-11).all()
|
73 |
+
y = y / window_envelope
|
74 |
+
|
75 |
+
return y
|
76 |
+
|
77 |
+
|
78 |
+
class MDCT(nn.Module):
|
79 |
+
"""
|
80 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
frame_len (int): Length of the MDCT frame.
|
84 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
88 |
+
super().__init__()
|
89 |
+
if padding not in ["center", "same"]:
|
90 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
91 |
+
self.padding = padding
|
92 |
+
self.frame_len = frame_len
|
93 |
+
N = frame_len // 2
|
94 |
+
n0 = (N + 1) / 2
|
95 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
96 |
+
self.register_buffer("window", window)
|
97 |
+
|
98 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
99 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
100 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
101 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
102 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
103 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
104 |
+
|
105 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
106 |
+
"""
|
107 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
111 |
+
and T is the length of the audio.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
115 |
+
and N is the number of frequency bins.
|
116 |
+
"""
|
117 |
+
if self.padding == "center":
|
118 |
+
audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
|
119 |
+
elif self.padding == "same":
|
120 |
+
# hop_length is 1/2 frame_len
|
121 |
+
audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
|
122 |
+
else:
|
123 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
124 |
+
|
125 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
126 |
+
N = self.frame_len // 2
|
127 |
+
x = x * self.window.expand(x.shape)
|
128 |
+
X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
|
129 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
130 |
+
return torch.real(res) * np.sqrt(2)
|
131 |
+
|
132 |
+
|
133 |
+
class IMDCT(nn.Module):
|
134 |
+
"""
|
135 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
frame_len (int): Length of the MDCT frame.
|
139 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
143 |
+
super().__init__()
|
144 |
+
if padding not in ["center", "same"]:
|
145 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
146 |
+
self.padding = padding
|
147 |
+
self.frame_len = frame_len
|
148 |
+
N = frame_len // 2
|
149 |
+
n0 = (N + 1) / 2
|
150 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
151 |
+
self.register_buffer("window", window)
|
152 |
+
|
153 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
154 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
155 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
156 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
157 |
+
|
158 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
159 |
+
"""
|
160 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
164 |
+
L is the number of frames, and N is the number of frequency bins.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
168 |
+
"""
|
169 |
+
B, L, N = X.shape
|
170 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
171 |
+
Y[..., :N] = X
|
172 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
173 |
+
y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
|
174 |
+
y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
|
175 |
+
result = y * self.window.expand(y.shape)
|
176 |
+
output_size = (1, (L + 1) * N)
|
177 |
+
audio = torch.nn.functional.fold(
|
178 |
+
result.transpose(1, 2),
|
179 |
+
output_size=output_size,
|
180 |
+
kernel_size=(1, self.frame_len),
|
181 |
+
stride=(1, self.frame_len // 2),
|
182 |
+
)[:, 0, 0, :]
|
183 |
+
|
184 |
+
if self.padding == "center":
|
185 |
+
pad = self.frame_len // 2
|
186 |
+
elif self.padding == "same":
|
187 |
+
pad = self.frame_len // 4
|
188 |
+
else:
|
189 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
190 |
+
|
191 |
+
audio = audio[:, pad:-pad]
|
192 |
+
return audio
|