KingNish committed on
Commit
bd3b355
·
verified ·
1 Parent(s): 2bd0633

Upload ./vocos/spectral_ops.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vocos/spectral_ops.py +192 -0
vocos/spectral_ops.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+ from torch import nn, view_as_real, view_as_complex
5
+
6
+
7
class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    A Hann window of length `win_length` is registered as the `window` buffer. The "same" code path multiplies
    the `n_fft`-long inverse-FFT frames by that window, so it requires `n_fft == win_length`.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If `padding` is neither "center" nor "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                            N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.

        Raises:
            ValueError: If `self.padding` is invalid, if `spec` is not 3-dimensional ("same" padding only),
                or if the summed squared window is not strictly above 1e-11 everywhere in the kept region.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        if self.padding != "same":
            raise ValueError("Padding must be 'center' or 'same'.")

        # Explicit raise instead of `assert`: asserts are stripped when Python runs with -O.
        if spec.dim() != 3:
            raise ValueError("Expected a 3D tensor as input")
        num_frames = spec.shape[-1]
        pad = (self.win_length - self.hop_length) // 2

        # Inverse FFT of every frame, followed by the synthesis window
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]

        # Overlap and Add
        output_size = (num_frames - 1) * self.hop_length + self.win_length
        # Explicit end index: a `pad:-pad` slice would select nothing when pad == 0.
        end = output_size - pad
        fold_kwargs = dict(
            output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )
        y = torch.nn.functional.fold(frames, **fold_kwargs)[:, 0, 0, pad:end]

        # Window envelope: overlap-add of the squared window, trimmed like the signal
        window_sq = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(window_sq, **fold_kwargs)[0, 0, 0, pad:end]

        # Normalize; refuse to divide by a (near-)zero envelope
        if not bool((window_envelope > 1e-11).all()):
            raise ValueError(
                "Window envelope is (near) zero inside the kept region (NOLA constraint violated); "
                "check win_length and hop_length."
            )
        return y / window_envelope
76
+
77
+
78
class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Frames of `frame_len` samples are taken with a hop of `frame_len // 2`, multiplied by a cosine window,
    and transformed with an FFT wrapped in pre-/post-twiddle factors. The twiddles are stored as real-valued
    buffers and converted back to complex in `forward`.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If `padding` is neither "center" nor "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        # Imported from `scipy.signal.windows`: the old `scipy.signal.cosine` alias is deprecated and is no
        # longer available in recent SciPy releases. Function scope keeps the module-level imports untouched.
        from scipy.signal.windows import cosine

        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.

        Raises:
            ValueError: If `self.padding` is neither "center" nor "same".
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        # Slice into half-overlapping frames and apply the analysis window
        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        # Pre-twiddle, FFT, keep the first N bins, post-twiddle
        X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)
131
+
132
+
133
class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Each frame of N coefficients is extended to 2N values, passed through an inverse FFT wrapped in
    pre-/post-twiddle factors, multiplied by a cosine window, and overlap-added with a hop of
    `frame_len // 2`. The twiddles are stored as real-valued buffers and converted back to complex in
    `forward`.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If `padding` is neither "center" nor "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        # Imported from `scipy.signal.windows`: the old `scipy.signal.cosine` alias is deprecated and is no
        # longer available in recent SciPy releases. Function scope keeps the module-level imports untouched.
        from scipy.signal.windows import cosine

        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.

        Raises:
            ValueError: If `self.padding` is neither "center" nor "same".
        """
        B, L, N = X.shape
        # Extend each frame to 2N bins: the coefficients followed by their negated, conjugated mirror image
        Y = torch.cat([X, -1 * torch.conj(torch.flip(X, dims=(-1,)))], dim=-1)
        y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
        y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
        result = y * self.window.expand(y.shape)

        # Overlap-add the windowed frames with a hop of half a frame
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        # Explicit end index: a `pad:-pad` slice would select nothing when pad == 0.
        audio = audio[:, pad:audio.shape[-1] - pad]
        return audio