wetdog committed on
Commit
c52280c
1 Parent(s): a2aa500

Initial demo

LICENSE ADDED
@@ -0,0 +1,25 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 MediaLab, Department of Electrical & Electronic Engineering, Stellenbosch University
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software, and you spend at least 10 seconds
14
+ thinking about whether the idea of copyright for Software actually makes sense
15
+ the first time you download the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+
25
+
app.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torchaudio
3
+ import spaces
4
+ from typing import List
5
+ import soundfile as sf
+ import tempfile
+ import gradio as gr
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ knn_vc = torch.hub.load('bshall/knn-vc', 'knn_vc', prematched=True, trust_repo=True, pretrained=True, device=device)
9
+
10
+
11
+ def convert_voice(src_wav_path:str, ref_wav_paths, top_k:int):
12
+
13
+ query_seq = knn_vc.get_features(src_wav_path)
14
+ matching_set = knn_vc.get_matching_set([ref_wav_paths])
15
+ out_wav = knn_vc.match(query_seq, matching_set, topk=int(top_k))
16
+
17
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as converted_file:
18
+ sf.write(converted_file.name, out_wav, 16000, "PCM_24")
19
+
20
+ return converted_file.name
21
+
22
+
23
+ title = """
24
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
25
+ <div
26
+ style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
27
+ > <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
28
+ KNN Voice Conversion
29
+ </h1> </div>
30
+ </div>
31
+ """
32
+
33
+ description = """
34
+ Voice Conversion With Just k-Nearest Neighbors. The source and reference utterance(s) are encoded into self-supervised features using WavLM.
35
+ Each source feature is assigned to the mean of the k closest features from the reference.
36
+ The resulting feature sequence is then vocoded with HiFi-GAN to arrive at the converted waveform output.
37
+ """
38
+
39
+ article = """
40
+ If the model contributes to your research, please cite the following work:
41
+
42
+ Baas, M., van Niekerk, B., & Kamper, H. (2023). Voice conversion with just nearest neighbors. arXiv preprint arXiv:2305.18975.
43
+
44
+ Demo contributed by [@wetdog](https://github.com/wetdog)
45
+ """
46
+ demo = gr.Blocks()
47
+ with demo:
48
+ gr.Markdown(title)
49
+ gr.Markdown(description)
50
+ gr.Interface(
51
+ fn=convert_voice,
52
+ inputs=[
53
+ gr.Audio(type='filepath', label="Source audio"),
54
+ gr.Audio(type='filepath', label="Reference audio"),
55
+ gr.Slider(
56
+ 3,
57
+ 10,
58
+ value=4,
59
+ step=1,
60
+ label="Top-k",
61
+ info="The default of 4 gives good results, but feel free to adjust the kNN top-k.",
62
+ )],
63
+ outputs=[gr.Audio(type='filepath')],
64
+ allow_flagging="never",)
65
+ gr.Markdown(article)
66
+
67
+ demo.queue(max_size=10)
68
+ demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)
69
+
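Note: the `description` above summarises the conversion rule. A tiny, self-contained illustration of that kNN regression step (toy 2-D features instead of 1024-D WavLM features, cosine distance as in matcher.py below) might look like this sketch:

import torch

query = torch.tensor([[1.0, 0.0]])                                  # one source feature
pool = torch.tensor([[0.9, 0.1], [0.8, -0.2], [-1.0, 0.0]])         # three reference features
dists = 1 - torch.nn.functional.cosine_similarity(query[:, None], pool[None], dim=-1)
best = dists.topk(k=2, largest=False, dim=-1)                       # the k nearest reference frames
converted = pool[best.indices].mean(dim=1)                          # each query frame becomes their mean
print(converted)                                                    # lies between the two closest reference vectors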
hifigan/config_v1_wavlm.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 16,
5
+ "learning_rate": 0.0002,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.999,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [10,8,2,2],
12
+ "upsample_kernel_sizes": [20,16,4,4],
13
+ "upsample_initial_channel": 512,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "hubert_dim": 1024,
18
+ "hifi_dim": 512,
19
+
20
+ "segment_size": 7040,
21
+ "num_mels": 80,
22
+ "num_freq": 1025,
23
+ "n_fft": 1024,
24
+ "hop_size": 320,
25
+ "win_size": 1024,
26
+
27
+ "sampling_rate": 16000,
28
+
29
+ "fmin": 0,
30
+ "fmax": 8000,
31
+ "fmax_for_loss": null,
32
+
33
+ "num_workers": 4,
34
+
35
+ "dist_config": {
36
+ "dist_backend": "nccl",
37
+ "dist_url": "tcp://localhost:54321",
38
+ "world_size": 1
39
+ }
40
+ }
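For orientation: with sampling_rate 16000 and hop_size 320, one feature frame spans 20 ms (50 frames per second), and the upsample_rates multiply out to 10*8*2*2 = 320, so the vocoder emits exactly hop_size samples per input frame. A small consistency check under that reading, loading the config the same way hubconf.py does:

import json
from hifigan.utils import AttrDict

with open("hifigan/config_v1_wavlm.json") as f:
    h = AttrDict(json.loads(f.read()))

total_upsample = 1
for r in h.upsample_rates:
    total_upsample *= r

assert total_upsample == h.hop_size == 320     # 320 samples generated per feature frame
print(h.hop_size / h.sampling_rate * 1000)     # 20.0 ms per frame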
hifigan/meldataset.py ADDED
@@ -0,0 +1,208 @@
1
+ import math
2
+ import os
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.data
12
+ import torchaudio
13
+ from librosa.filters import mel as librosa_mel_fn
14
+ from librosa.util import normalize
15
+ from scipy.io.wavfile import read
16
+
17
+
18
+ def load_wav(full_path):
19
+ #sampling_rate, data = read(full_path)
20
+ #return data, sampling_rate
21
+ data, sampling_rate = librosa.load(full_path, sr=None)
22
+ return data, sampling_rate
23
+
24
+
25
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
26
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
27
+
28
+
29
+ def dynamic_range_decompression(x, C=1):
30
+ return np.exp(x) / C
31
+
32
+
33
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
34
+ return torch.log(torch.clamp(x, min=clip_val) * C)
35
+
36
+
37
+ def dynamic_range_decompression_torch(x, C=1):
38
+ return torch.exp(x) / C
39
+
40
+
41
+ def spectral_normalize_torch(magnitudes):
42
+ output = dynamic_range_compression_torch(magnitudes)
43
+ return output
44
+
45
+
46
+ def spectral_de_normalize_torch(magnitudes):
47
+ output = dynamic_range_decompression_torch(magnitudes)
48
+ return output
49
+
50
+
51
+ mel_basis = {}
52
+ hann_window = {}
53
+
54
+ class LogMelSpectrogram(torch.nn.Module):
55
+ def __init__(self, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
56
+ super().__init__()
57
+ self.melspctrogram = torchaudio.transforms.MelSpectrogram(
58
+ sample_rate=sampling_rate,
59
+ n_fft=n_fft,
60
+ win_length=win_size,
61
+ hop_length=hop_size,
62
+ center=center,
63
+ power=1.0,
64
+ norm="slaney",
65
+ onesided=True,
66
+ n_mels=num_mels,
67
+ mel_scale="slaney",
68
+ f_min=fmin,
69
+ f_max=fmax
70
+ )
71
+ self.n_fft = n_fft
72
+ self.hop_size = hop_size
73
+
74
+ def forward(self, wav):
75
+ wav = F.pad(wav, ((self.n_fft - self.hop_size) // 2, (self.n_fft - self.hop_size) // 2), "reflect")
76
+ mel = self.melspctrogram(wav)
77
+ logmel = torch.log(torch.clamp(mel, min=1e-5))
78
+ return logmel
79
+
80
+
81
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
82
+ if torch.min(y) < -1.:
83
+ print('min value is ', torch.min(y))
84
+ if torch.max(y) > 1.:
85
+ print('max value is ', torch.max(y))
86
+
87
+ global mel_basis, hann_window
88
+ if str(fmax)+'_'+str(y.device) not in mel_basis:
89
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
90
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
91
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
92
+
93
+ # print("Padding by", int((n_fft - hop_size)/2), y.shape)
94
+ # pre-padding
95
+ n_pad = hop_size - ( y.shape[1] % hop_size )
96
+ y = F.pad(y.unsqueeze(1), (0, n_pad), mode='reflect').squeeze(1)
97
+ # print("intermediate:", y.shape)
98
+
99
+ y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
100
+ y = y.squeeze(1)
101
+
102
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
103
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
104
+ spec = spec.abs().clamp_(3e-5)
105
+ # print("Post: ", y.shape, spec.shape)
106
+
107
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
108
+ spec = spectral_normalize_torch(spec)
109
+
110
+ return spec
111
+
112
+
113
+ def get_dataset_filelist(a):
114
+ train_df = pd.read_csv(a.input_training_file)
115
+ valid_df = pd.read_csv(a.input_validation_file)
116
+ return train_df, valid_df
117
+
118
+
119
+ class MelDataset(torch.utils.data.Dataset):
120
+ def __init__(self, training_files, segment_size, n_fft, num_mels,
121
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
122
+ device=None, fmax_loss=None, fine_tuning=False, audio_root_path=None, feat_root_path=None, use_alt_melcalc=False):
123
+ self.audio_files = training_files
124
+ if shuffle:
125
+ self.audio_files = self.audio_files.sample(frac=1, random_state=1234)
126
+ self.segment_size = segment_size
127
+ self.sampling_rate = sampling_rate
128
+ self.split = split
129
+ self.n_fft = n_fft
130
+ self.num_mels = num_mels
131
+ self.hop_size = hop_size
132
+ self.win_size = win_size
133
+ self.fmin = fmin
134
+ self.fmax = fmax
135
+ self.fmax_loss = fmax_loss
136
+ self.cached_wav = None
137
+ self.n_cache_reuse = n_cache_reuse
138
+ self._cache_ref_count = 0
139
+ self.device = device
140
+ self.fine_tuning = fine_tuning
141
+ self.audio_root_path = Path(audio_root_path)
142
+ self.feat_root_path = Path(feat_root_path)
143
+ self.alt_melspec = LogMelSpectrogram(n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
144
+ self.use_alt_melcalc = use_alt_melcalc
145
+
146
+ def __getitem__(self, index):
147
+ row = self.audio_files.iloc[index]
148
+ if self._cache_ref_count == 0:
149
+ audio, sampling_rate = load_wav(self.audio_root_path/row.audio_path)
150
+ if not self.fine_tuning:
151
+ audio = normalize(audio) * 0.95
152
+ self.cached_wav = audio
153
+ if sampling_rate != self.sampling_rate:
154
+ raise ValueError("{} SR doesn't match target {} SR".format(
155
+ sampling_rate, self.sampling_rate))
156
+ self._cache_ref_count = self.n_cache_reuse
157
+ else:
158
+ audio = self.cached_wav
159
+ self._cache_ref_count -= 1
160
+
161
+ audio = torch.tensor(audio, dtype=torch.float32)
162
+ audio = audio.unsqueeze(0)
163
+
164
+ if not self.fine_tuning:
165
+ if self.split:
166
+ if audio.size(1) >= self.segment_size:
167
+ max_audio_start = audio.size(1) - self.segment_size
168
+ audio_start = random.randint(0, max_audio_start)
169
+ audio = audio[:, audio_start:audio_start+self.segment_size]
170
+ else:
171
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
172
+
173
+ if self.use_alt_melcalc:
174
+ mel = self.alt_melspec(audio)
175
+ else:
176
+ mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
177
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
178
+ center=False)
179
+
180
+ mel = mel.permute(0, 2, 1) # (1, dim, seq_len) --> (1, seq_len, dim)
181
+ else:
182
+ mel = torch.load(self.feat_root_path/row.feat_path, map_location='cpu').float()
183
+
184
+ if len(mel.shape) < 3:
185
+ mel = mel.unsqueeze(0) # (1, seq_len, dim)
186
+
187
+ if self.split:
188
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
189
+
190
+ if audio.size(1) >= self.segment_size:
191
+ mel_start = random.randint(0, mel.size(1) - frames_per_seg - 1)
192
+ mel = mel[:, mel_start:mel_start + frames_per_seg, :]
193
+ audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
194
+ else:
195
+ mel = torch.nn.functional.pad(mel, (0, 0, 0, frames_per_seg - mel.size(2)), 'constant')
196
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
197
+
198
+
199
+ if self.use_alt_melcalc:
200
+ mel_loss = self.alt_melspec(audio)
201
+ else:
202
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
203
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
204
+ center=False)
205
+ return (mel.squeeze(), audio.squeeze(0), str(row.audio_path), mel_loss.squeeze())
206
+
207
+ def __len__(self):
208
+ return len(self.audio_files)
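Note: get_dataset_filelist reads the train/validation lists with pandas, and __getitem__ accesses row.audio_path (plus row.feat_path when fine_tuning is set), so the CSVs are expected to carry at least those two columns, with paths relative to audio_root_path / feat_root_path. A hypothetical minimal file list (the LibriSpeech-style paths and splits/ directory are placeholders):

import pandas as pd

# write a toy training list compatible with MelDataset / get_dataset_filelist
pd.DataFrame({
    "audio_path": ["train-clean-100/19/198/19-198-0000.flac"],
    "feat_path":  ["train-clean-100/19/198/19-198-0000.pt"],
}).to_csv("splits/train.csv", index=False)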
hifigan/models.py ADDED
@@ -0,0 +1,289 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ LRELU_SLOPE = 0.1
9
+
10
+
11
+ class ResBlock1(torch.nn.Module):
12
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
13
+ super(ResBlock1, self).__init__()
14
+ self.h = h
15
+ self.convs1 = nn.ModuleList([
16
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
17
+ padding=get_padding(kernel_size, dilation[0]))),
18
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
19
+ padding=get_padding(kernel_size, dilation[1]))),
20
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
21
+ padding=get_padding(kernel_size, dilation[2])))
22
+ ])
23
+ self.convs1.apply(init_weights)
24
+
25
+ self.convs2 = nn.ModuleList([
26
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
27
+ padding=get_padding(kernel_size, 1))),
28
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
29
+ padding=get_padding(kernel_size, 1))),
30
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
31
+ padding=get_padding(kernel_size, 1)))
32
+ ])
33
+ self.convs2.apply(init_weights)
34
+
35
+ def forward(self, x):
36
+ for c1, c2 in zip(self.convs1, self.convs2):
37
+ xt = F.leaky_relu(x, LRELU_SLOPE)
38
+ xt = c1(xt)
39
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
40
+ xt = c2(xt)
41
+ x = xt + x
42
+ return x
43
+
44
+ def remove_weight_norm(self):
45
+ for l in self.convs1:
46
+ remove_weight_norm(l)
47
+ for l in self.convs2:
48
+ remove_weight_norm(l)
49
+
50
+
51
+ class ResBlock2(torch.nn.Module):
52
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
53
+ super(ResBlock2, self).__init__()
54
+ self.h = h
55
+ self.convs = nn.ModuleList([
56
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
57
+ padding=get_padding(kernel_size, dilation[0]))),
58
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
59
+ padding=get_padding(kernel_size, dilation[1])))
60
+ ])
61
+ self.convs.apply(init_weights)
62
+
63
+ def forward(self, x):
64
+ for c in self.convs:
65
+ xt = F.leaky_relu(x, LRELU_SLOPE)
66
+ xt = c(xt)
67
+ x = xt + x
68
+ return x
69
+
70
+ def remove_weight_norm(self):
71
+ for l in self.convs:
72
+ remove_weight_norm(l)
73
+
74
+
75
+ class Generator(torch.nn.Module):
76
+ def __init__(self, h):
77
+ super(Generator, self).__init__()
78
+ self.h = h
79
+ self.lin_pre = nn.Linear(h.hubert_dim, h.hifi_dim)
80
+ self.num_kernels = len(h.resblock_kernel_sizes)
81
+ self.num_upsamples = len(h.upsample_rates)
82
+ self.conv_pre = weight_norm(Conv1d(h.hifi_dim, h.upsample_initial_channel, 7, 1, padding=3))
83
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
84
+
85
+ self.ups = nn.ModuleList()
86
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
87
+
88
+ self.ups.append(weight_norm(
89
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
90
+ k, u, padding=(k-u)//2)))
91
+
92
+ self.resblocks = nn.ModuleList()
93
+ for i in range(len(self.ups)):
94
+ ch = h.upsample_initial_channel//(2**(i+1))
95
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
96
+ self.resblocks.append(resblock(h, ch, k, d))
97
+
98
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
99
+ self.ups.apply(init_weights)
100
+ self.conv_post.apply(init_weights)
101
+
102
+ def forward(self, x):
103
+ """ `x` as (bs, seq_len, dim), regular hifi assumes input of shape (bs, n_mels, seq_len) """
104
+ x = self.lin_pre(x)
105
+ x = x.permute(0, 2, 1) # (bs, seq_len, dim) --> (bs, dim, seq_len)
106
+
107
+ x = self.conv_pre(x)
108
+ for i in range(self.num_upsamples):
109
+ x = F.leaky_relu(x, LRELU_SLOPE)
110
+ x = self.ups[i](x)
111
+ xs = None
112
+ for j in range(self.num_kernels):
113
+ if xs is None:
114
+ xs = self.resblocks[i*self.num_kernels+j](x)
115
+ else:
116
+ xs += self.resblocks[i*self.num_kernels+j](x)
117
+ x = xs / self.num_kernels
118
+ x = F.leaky_relu(x)
119
+ x = self.conv_post(x)
120
+ x = torch.tanh(x)
121
+
122
+ return x
123
+
124
+ def remove_weight_norm(self):
125
+ print('Removing weight norm...')
126
+ for l in self.ups:
127
+ remove_weight_norm(l)
128
+ for l in self.resblocks:
129
+ l.remove_weight_norm()
130
+ remove_weight_norm(self.conv_pre)
131
+ remove_weight_norm(self.conv_post)
132
+
133
+
134
+ class DiscriminatorP(torch.nn.Module):
135
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
136
+ super(DiscriminatorP, self).__init__()
137
+ self.period = period
138
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
139
+ self.convs = nn.ModuleList([
140
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
141
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
142
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
143
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
144
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
145
+ ])
146
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
147
+
148
+ def forward(self, x):
149
+ fmap = []
150
+
151
+ # 1d to 2d
152
+ b, c, t = x.shape
153
+ if t % self.period != 0: # pad first
154
+ n_pad = self.period - (t % self.period)
155
+ x = F.pad(x, (0, n_pad), "reflect")
156
+ t = t + n_pad
157
+ x = x.view(b, c, t // self.period, self.period)
158
+
159
+ for l in self.convs:
160
+ x = l(x)
161
+ x = F.leaky_relu(x, LRELU_SLOPE)
162
+ fmap.append(x)
163
+ x = self.conv_post(x)
164
+ fmap.append(x)
165
+ x = torch.flatten(x, 1, -1)
166
+
167
+ return x, fmap
168
+
169
+
170
+ class MultiPeriodDiscriminator(torch.nn.Module):
171
+ def __init__(self):
172
+ super(MultiPeriodDiscriminator, self).__init__()
173
+ self.discriminators = nn.ModuleList([
174
+ DiscriminatorP(2),
175
+ DiscriminatorP(3),
176
+ DiscriminatorP(5),
177
+ DiscriminatorP(7),
178
+ DiscriminatorP(11),
179
+ ])
180
+
181
+ def forward(self, y, y_hat):
182
+ y_d_rs = []
183
+ y_d_gs = []
184
+ fmap_rs = []
185
+ fmap_gs = []
186
+ for i, d in enumerate(self.discriminators):
187
+ y_d_r, fmap_r = d(y)
188
+ y_d_g, fmap_g = d(y_hat)
189
+ y_d_rs.append(y_d_r)
190
+ fmap_rs.append(fmap_r)
191
+ y_d_gs.append(y_d_g)
192
+ fmap_gs.append(fmap_g)
193
+
194
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
195
+
196
+
197
+ class DiscriminatorS(torch.nn.Module):
198
+ def __init__(self, use_spectral_norm=False):
199
+ super(DiscriminatorS, self).__init__()
200
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
201
+ self.convs = nn.ModuleList([
202
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
203
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
204
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
205
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
206
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
207
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
208
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
209
+ ])
210
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
211
+
212
+ def forward(self, x):
213
+ fmap = []
214
+ for l in self.convs:
215
+ x = l(x)
216
+ x = F.leaky_relu(x, LRELU_SLOPE)
217
+ fmap.append(x)
218
+ x = self.conv_post(x)
219
+ fmap.append(x)
220
+ x = torch.flatten(x, 1, -1)
221
+
222
+ return x, fmap
223
+
224
+
225
+ class MultiScaleDiscriminator(torch.nn.Module):
226
+ def __init__(self):
227
+ super(MultiScaleDiscriminator, self).__init__()
228
+ self.discriminators = nn.ModuleList([
229
+ DiscriminatorS(use_spectral_norm=True),
230
+ DiscriminatorS(),
231
+ DiscriminatorS(),
232
+ ])
233
+ self.meanpools = nn.ModuleList([
234
+ AvgPool1d(4, 2, padding=2),
235
+ AvgPool1d(4, 2, padding=2)
236
+ ])
237
+
238
+ def forward(self, y, y_hat):
239
+ y_d_rs = []
240
+ y_d_gs = []
241
+ fmap_rs = []
242
+ fmap_gs = []
243
+ for i, d in enumerate(self.discriminators):
244
+ if i != 0:
245
+ y = self.meanpools[i-1](y)
246
+ y_hat = self.meanpools[i-1](y_hat)
247
+ y_d_r, fmap_r = d(y)
248
+ y_d_g, fmap_g = d(y_hat)
249
+ y_d_rs.append(y_d_r)
250
+ fmap_rs.append(fmap_r)
251
+ y_d_gs.append(y_d_g)
252
+ fmap_gs.append(fmap_g)
253
+
254
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
255
+
256
+
257
+ def feature_loss(fmap_r, fmap_g):
258
+ loss = 0
259
+ for dr, dg in zip(fmap_r, fmap_g):
260
+ for rl, gl in zip(dr, dg):
261
+ loss += torch.mean(torch.abs(rl - gl))
262
+
263
+ return loss*2
264
+
265
+
266
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
267
+ loss = 0
268
+ r_losses = []
269
+ g_losses = []
270
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
271
+ r_loss = torch.mean((1-dr)**2)
272
+ g_loss = torch.mean(dg**2)
273
+ loss += (r_loss + g_loss)
274
+ r_losses.append(r_loss.item())
275
+ g_losses.append(g_loss.item())
276
+
277
+ return loss, r_losses, g_losses
278
+
279
+
280
+ def generator_loss(disc_outputs):
281
+ loss = 0
282
+ gen_losses = []
283
+ for dg in disc_outputs:
284
+ l = torch.mean((1-dg)**2)
285
+ gen_losses.append(l)
286
+ loss += l
287
+
288
+ return loss, gen_losses
289
+
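The Generator above first maps hubert_dim=1024 features to hifi_dim=512 with lin_pre and then upsamples by 10*8*2*2 = 320, so each feature frame should become hop_size waveform samples. A rough shape check of that expectation (a sketch, run from the repository root):

import json
import torch
from hifigan.models import Generator
from hifigan.utils import AttrDict

with open("hifigan/config_v1_wavlm.json") as f:
    h = AttrDict(json.loads(f.read()))

g = Generator(h).eval()
feats = torch.randn(1, 100, h.hubert_dim)     # 100 frames of WavLM-sized features
with torch.no_grad():
    wav = g(feats)
print(wav.shape)                              # expected: (1, 1, 100 * 320) = (1, 1, 32000)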
hifigan/train.py ADDED
@@ -0,0 +1,335 @@
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import torch
8
+ import torch.multiprocessing as mp
9
+ import torch.nn.functional as F
10
+ from fastprogress import master_bar, progress_bar
11
+ from torch.cuda.amp.grad_scaler import GradScaler
12
+ from torch.distributed import init_process_group
13
+ from torch.nn.parallel import DistributedDataParallel
14
+ from torch.utils.data import DataLoader, DistributedSampler
15
+ from torch.utils.tensorboard import SummaryWriter
16
+
17
+ from .meldataset import (LogMelSpectrogram, MelDataset, get_dataset_filelist,
18
+ mel_spectrogram)
19
+ from .models import (Generator, MultiPeriodDiscriminator,
20
+ MultiScaleDiscriminator, discriminator_loss, feature_loss,
21
+ generator_loss)
22
+ from .utils import (AttrDict, build_env, load_checkpoint, plot_spectrogram,
23
+ save_checkpoint, scan_checkpoint)
24
+
25
+ torch.backends.cudnn.benchmark = True
26
+ USE_ALT_MELCALC = True
27
+
28
+
29
+ def train(rank, a, h):
30
+ if h.num_gpus > 1:
31
+ init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
32
+ world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
33
+
34
+ torch.cuda.manual_seed(h.seed)
35
+ device = torch.device('cuda:{:d}'.format(rank))
36
+
37
+ generator = Generator(h).to(device)
38
+ mpd = MultiPeriodDiscriminator().to(device)
39
+ msd = MultiScaleDiscriminator().to(device)
40
+
41
+ if rank == 0:
42
+ print(generator)
43
+ os.makedirs(a.checkpoint_path, exist_ok=True)
44
+ print("checkpoints directory : ", a.checkpoint_path)
45
+
46
+ if os.path.isdir(a.checkpoint_path):
47
+ cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
48
+ cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
49
+
50
+ steps = 0
51
+ if cp_g is None or cp_do is None:
52
+ state_dict_do = None
53
+ last_epoch = -1
54
+ else:
55
+ state_dict_g = load_checkpoint(cp_g, device)
56
+ state_dict_do = load_checkpoint(cp_do, device)
57
+ generator.load_state_dict(state_dict_g['generator'])
58
+ mpd.load_state_dict(state_dict_do['mpd'])
59
+ msd.load_state_dict(state_dict_do['msd'])
60
+ steps = state_dict_do['steps'] + 1
61
+ last_epoch = state_dict_do['epoch']
62
+ print(f"Restored checkpoint from {cp_g} and {cp_do}")
63
+
64
+ if h.num_gpus > 1:
65
+ print("Multi-gpu detected")
66
+ generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
67
+ mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
68
+ msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
69
+
70
+ optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
71
+ optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
72
+ h.learning_rate, betas=[h.adam_b1, h.adam_b2])
73
+
74
+ if state_dict_do is not None:
75
+ optim_g.load_state_dict(state_dict_do['optim_g'])
76
+ optim_d.load_state_dict(state_dict_do['optim_d'])
77
+
78
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
79
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
80
+ if a.fp16:
81
+ scaler_g = GradScaler()
82
+ scaler_d = GradScaler()
83
+
84
+ train_df, valid_df = get_dataset_filelist(a)
85
+
86
+ trainset = MelDataset(train_df, h.segment_size, h.n_fft, h.num_mels,
87
+ h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
88
+ shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
89
+ fine_tuning=a.fine_tuning,
90
+ audio_root_path=a.audio_root_path, feat_root_path=a.feature_root_path,
91
+ use_alt_melcalc=USE_ALT_MELCALC)
92
+
93
+ train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
94
+
95
+ train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
96
+ sampler=train_sampler,
97
+ batch_size=h.batch_size,
98
+ pin_memory=True,
99
+ persistent_workers=True,
100
+ drop_last=True)
101
+
102
+ alt_melspec = LogMelSpectrogram(h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax).to(device)
103
+
104
+ if rank == 0:
105
+ validset = MelDataset(valid_df, h.segment_size, h.n_fft, h.num_mels,
106
+ h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
107
+ fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
108
+ audio_root_path=a.audio_root_path, feat_root_path=a.feature_root_path,
109
+ use_alt_melcalc=USE_ALT_MELCALC)
110
+ validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
111
+ sampler=None,
112
+ batch_size=1,
113
+ pin_memory=True,
114
+ persistent_workers=True,
115
+ drop_last=True)
116
+
117
+ sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
118
+
119
+ generator.train()
120
+ mpd.train()
121
+ msd.train()
122
+
123
+ if rank == 0: mb = master_bar(range(max(0, last_epoch), a.training_epochs))
124
+ else: mb = range(max(0, last_epoch), a.training_epochs)
125
+
126
+ for epoch in mb:
127
+ if rank == 0:
128
+ start = time.time()
129
+ mb.write("Epoch: {}".format(epoch+1))
130
+
131
+ if h.num_gpus > 1:
132
+ train_sampler.set_epoch(epoch)
133
+
134
+ if rank == 0: pb = progress_bar(enumerate(train_loader), total=len(train_loader), parent=mb)
135
+ else: pb = enumerate(train_loader)
136
+
137
+
138
+ for i, batch in pb:
139
+ if rank == 0:
140
+ start_b = time.time()
141
+ x, y, _, y_mel = batch
142
+ x = x.to(device, non_blocking=True)
143
+ y = y.to(device, non_blocking=True)
144
+ y_mel = y_mel.to(device, non_blocking=True)
145
+ y = y.unsqueeze(1)
146
+
147
+ with torch.cuda.amp.autocast(enabled=a.fp16):
148
+ y_g_hat = generator(x)
149
+ if USE_ALT_MELCALC:
150
+ y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
151
+ else:
152
+ y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
153
+ h.fmin, h.fmax_for_loss)
154
+ # print(x.shape, y_g_hat.shape, y_g_hat_mel.shape, y_mel.shape, y.shape)
155
+ optim_d.zero_grad()
156
+
157
+ with torch.cuda.amp.autocast(enabled=a.fp16):
158
+ # MPD
159
+ y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
160
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
161
+
162
+ # MSD
163
+ y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
164
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
165
+
166
+ loss_disc_all = loss_disc_s + loss_disc_f
167
+
168
+ if a.fp16:
169
+ scaler_d.scale(loss_disc_all).backward()
170
+ scaler_d.step(optim_d)
171
+ scaler_d.update()
172
+ else:
173
+ loss_disc_all.backward()
174
+ optim_d.step()
175
+
176
+ # Generator
177
+ optim_g.zero_grad()
178
+
179
+ with torch.cuda.amp.autocast(enabled=a.fp16):
180
+ # L1 Mel-Spectrogram Loss
181
+ loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
182
+
183
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
184
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
185
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
186
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
187
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
188
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
189
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
190
+
191
+ if a.fp16:
192
+ scaler_g.scale(loss_gen_all).backward()
193
+ scaler_g.step(optim_g)
194
+ scaler_g.update()
195
+ else:
196
+ loss_gen_all.backward()
197
+ optim_g.step()
198
+
199
+ if rank == 0:
200
+ # STDOUT logging
201
+ if steps % a.stdout_interval == 0:
202
+ with torch.no_grad():
203
+ mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
204
+
205
+ mb.write('Steps : {:,d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'. \
206
+ format(steps, loss_gen_all, mel_error, time.time() - start_b, torch.cuda.max_memory_allocated()/1e9))
207
+ mb.child.comment = "Steps : {:,d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}". \
208
+ format(steps, loss_gen_all, mel_error)
209
+
210
+
211
+ # checkpointing
212
+ if steps % a.checkpoint_interval == 0 and steps != 0:
213
+ checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
214
+ save_checkpoint(checkpoint_path,
215
+ {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
216
+ checkpoint_path = "{}/do_{:08d}.pt".format(a.checkpoint_path, steps)
217
+ save_checkpoint(checkpoint_path,
218
+ {'mpd': (mpd.module if h.num_gpus > 1
219
+ else mpd).state_dict(),
220
+ 'msd': (msd.module if h.num_gpus > 1
221
+ else msd).state_dict(),
222
+ 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
223
+ 'epoch': epoch})
224
+
225
+ # Tensorboard summary logging
226
+ if steps % a.summary_interval == 0:
227
+ sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
228
+ sw.add_scalar("training/mel_spec_error", mel_error, steps)
229
+ sw.add_scalar("training/disc_loss_total", loss_disc_all, steps)
230
+
231
+ # Validation
232
+ if steps % a.validation_interval == 0: # and steps != 0:
233
+ generator.eval()
234
+ torch.cuda.empty_cache()
235
+ val_err_tot = 0
236
+ with torch.no_grad():
237
+ for j, batch in progress_bar(enumerate(validation_loader), total=len(validation_loader), parent=mb):
238
+ x, y, _, y_mel = batch
239
+ y_g_hat = generator(x.to(device))
240
+ y_mel = y_mel.to(device, non_blocking=True)
241
+ if USE_ALT_MELCALC:
242
+ y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
243
+ if y_g_hat_mel.shape[-1] != y_mel.shape[-1]:
244
+ # pad it
245
+ n_pad = h.hop_size
246
+ y_g_hat = F.pad(y_g_hat, (n_pad//2, n_pad - n_pad//2))
247
+ y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
248
+ else:
249
+ y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
250
+ h.hop_size, h.win_size,
251
+ h.fmin, h.fmax_for_loss)
252
+ #print('valid', x.shape, y_g_hat.shape, y_g_hat_mel.shape, y_mel.shape, y.shape)
253
+ val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
254
+
255
+ if j <= 4:
256
+ if steps == 0:
257
+ sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
258
+ sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
259
+
260
+ sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
261
+ if USE_ALT_MELCALC:
262
+ y_hat_spec = alt_melspec(y_g_hat.squeeze(1))
263
+ else:
264
+ y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
265
+ h.hop_size, h.win_size,
266
+ h.fmin, h.fmax_for_loss)
267
+
268
+ sw.add_figure('generated/y_hat_spec_{}'.format(j),
269
+ plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
270
+
271
+ val_err = val_err_tot / (j+1)
272
+ sw.add_scalar("validation/mel_spec_error", val_err, steps)
273
+ mb.write(f"validation run complete at {steps:,d} steps. validation mel spec error: {val_err:5.4f}")
274
+
275
+ generator.train()
276
+ sw.add_scalar("memory/max_allocated_gb", torch.cuda.max_memory_allocated()/1e9, steps)
277
+ sw.add_scalar("memory/max_reserved_gb", torch.cuda.max_memory_reserved()/1e9, steps)
278
+ torch.cuda.reset_peak_memory_stats()
279
+ torch.cuda.reset_accumulated_memory_stats()
280
+
281
+ steps += 1
282
+
283
+ scheduler_g.step()
284
+ scheduler_d.step()
285
+
286
+ if rank == 0:
287
+ print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
288
+
289
+
290
+ def main():
291
+ print('Initializing Training Process..')
292
+
293
+ parser = argparse.ArgumentParser()
294
+
295
+ parser.add_argument('--group_name', default=None)
296
+ parser.add_argument('--audio_root_path', required=True)
297
+ parser.add_argument('--feature_root_path', required=True)
298
+ parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
299
+ parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
300
+ parser.add_argument('--checkpoint_path', default='cp_hifigan')
301
+ parser.add_argument('--config', default='')
302
+ parser.add_argument('--training_epochs', default=1500, type=int)
303
+ parser.add_argument('--stdout_interval', default=5, type=int)
304
+ parser.add_argument('--checkpoint_interval', default=5000, type=int)
305
+ parser.add_argument('--summary_interval', default=25, type=int)
306
+ parser.add_argument('--validation_interval', default=1000, type=int)
307
+ parser.add_argument('--fp16', default=False, type=bool)
308
+ parser.add_argument('--fine_tuning', action='store_true')
309
+
310
+ a = parser.parse_args()
311
+ print(a)
312
+ with open(a.config) as f:
313
+ data = f.read()
314
+
315
+ json_config = json.loads(data)
316
+ h = AttrDict(json_config)
317
+ build_env(a.config, 'config.json', a.checkpoint_path)
318
+
319
+ torch.manual_seed(h.seed)
320
+ if torch.cuda.is_available():
321
+ torch.cuda.manual_seed(h.seed)
322
+ h.num_gpus = torch.cuda.device_count()
323
+ h.batch_size = int(h.batch_size / h.num_gpus)
324
+ print('Batch size per GPU :', h.batch_size)
325
+ else:
326
+ pass
327
+
328
+ if h.num_gpus > 1:
329
+ mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
330
+ else:
331
+ train(0, a, h)
332
+
333
+
334
+ if __name__ == '__main__':
335
+ main()
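Since train.py uses package-relative imports (from .meldataset ...), it is presumably launched as the module hifigan.train from the repository root. A hypothetical invocation built only from the argparse flags defined above (all paths are placeholders):

import subprocess

subprocess.run([
    "python", "-m", "hifigan.train",
    "--audio_root_path", "/data/LibriSpeech",        # placeholder: root of the 16 kHz audio
    "--feature_root_path", "/data/wavlm_feats",      # placeholder: output of prematch_dataset.py
    "--input_training_file", "splits/train.csv",     # CSV with audio_path / feat_path columns
    "--input_validation_file", "splits/valid.csv",
    "--checkpoint_path", "checkpoints/hifigan_wavlm",
    "--config", "hifigan/config_v1_wavlm.json",
    "--fine_tuning",                                 # train on the precomputed (prematched) features
], check=True)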
hifigan/utils.py ADDED
@@ -0,0 +1,73 @@
1
+ import glob
2
+ import os
3
+ import shutil
4
+
5
+ import torch
6
+ from torch.nn.utils import weight_norm
7
+ import json
8
+
9
+
10
+ def plot_spectrogram(spectrogram):
11
+ import matplotlib.pylab as plt
12
+ import matplotlib
13
+ matplotlib.use("Agg")
14
+ fig, ax = plt.subplots(figsize=(10, 2))
15
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
16
+ interpolation='none')
17
+ plt.colorbar(im, ax=ax)
18
+
19
+ fig.canvas.draw()
20
+ plt.close()
21
+
22
+ return fig
23
+
24
+
25
+ def init_weights(m, mean=0.0, std=0.01):
26
+ classname = m.__class__.__name__
27
+ if classname.find("Conv") != -1:
28
+ m.weight.data.normal_(mean, std)
29
+
30
+
31
+ def apply_weight_norm(m):
32
+ classname = m.__class__.__name__
33
+ if classname.find("Conv") != -1:
34
+ weight_norm(m)
35
+
36
+
37
+ def get_padding(kernel_size, dilation=1):
38
+ return int((kernel_size*dilation - dilation)/2)
39
+
40
+
41
+ def load_checkpoint(filepath, device):
42
+ assert os.path.isfile(filepath)
43
+ print("Loading '{}'".format(filepath))
44
+ checkpoint_dict = torch.load(filepath, map_location=device)
45
+ print("Complete.")
46
+ return checkpoint_dict
47
+
48
+
49
+ def save_checkpoint(filepath, obj):
50
+ print("Saving checkpoint to {}".format(filepath))
51
+ torch.save(obj, filepath)
52
+ print("Complete.")
53
+
54
+
55
+ def scan_checkpoint(cp_dir, prefix):
56
+ pattern = os.path.join(cp_dir, prefix + '*')
57
+ cp_list = glob.glob(pattern)
58
+ if len(cp_list) == 0:
59
+ return None
60
+ return sorted(cp_list)[-1]
61
+
62
+
63
+ class AttrDict(dict):
64
+ def __init__(self, *args, **kwargs):
65
+ super(AttrDict, self).__init__(*args, **kwargs)
66
+ self.__dict__ = self
67
+
68
+
69
+ def build_env(config, config_name, path):
70
+ t_path = os.path.join(path, config_name)
71
+ if config != t_path:
72
+ os.makedirs(path, exist_ok=True)
73
+ shutil.copyfile(config, os.path.join(path, config_name))
hubconf.py ADDED
@@ -0,0 +1,75 @@
1
+ dependencies = ['torch', 'torchaudio', 'numpy']
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import logging
8
+ import json
9
+ from pathlib import Path
10
+
11
+
12
+ from wavlm.WavLM import WavLM, WavLMConfig
13
+ from hifigan.models import Generator as HiFiGAN
14
+ from hifigan.utils import AttrDict
15
+ from matcher import KNeighborsVC
16
+
17
+
18
+ def knn_vc(pretrained=True, progress=True, prematched=True, device='cuda') -> KNeighborsVC:
19
+ """ Load kNN-VC (WavLM encoder and HiFiGAN decoder). Optionally use vocoder trained on `prematched` data. """
20
+ hifigan, hifigan_cfg = hifigan_wavlm(pretrained, progress, prematched, device)
21
+ wavlm = wavlm_large(pretrained, progress, device)
22
+ knnvc = KNeighborsVC(wavlm, hifigan, hifigan_cfg, device)
23
+ return knnvc
24
+
25
+
26
+ def hifigan_wavlm(pretrained=True, progress=True, prematched=True, device='cuda') -> HiFiGAN:
27
+ """ Load pretrained hifigan trained to vocode wavlm features. Optionally use weights trained on `prematched` data. """
28
+ cp = Path(__file__).parent.absolute()
29
+
30
+ with open(cp/'hifigan'/'config_v1_wavlm.json') as f:
31
+ data = f.read()
32
+ json_config = json.loads(data)
33
+ h = AttrDict(json_config)
34
+ device = torch.device(device)
35
+
36
+ generator = HiFiGAN(h).to(device)
37
+
38
+ if pretrained:
39
+ if prematched:
40
+ url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"
41
+ else:
42
+ url = "https://github.com/bshall/knn-vc/releases/download/v0.1/g_02500000.pt"
43
+ state_dict_g = torch.hub.load_state_dict_from_url(
44
+ url,
45
+ map_location=device,
46
+ progress=progress
47
+ )
48
+ generator.load_state_dict(state_dict_g['generator'])
49
+ generator.eval()
50
+ generator.remove_weight_norm()
51
+ print(f"[HiFiGAN] Generator loaded with {sum([p.numel() for p in generator.parameters()]):,d} parameters.")
52
+ return generator, h
53
+
54
+
55
+ def wavlm_large(pretrained=True, progress=True, device='cuda') -> WavLM:
56
+ """Load the WavLM large checkpoint from the original paper. See https://github.com/microsoft/unilm/tree/master/wavlm for details. """
57
+ if torch.cuda.is_available() == False:
58
+ if str(device) != 'cpu':
59
+ logging.warning(f"Overriding device {device} to cpu since no GPU is available.")
60
+ device = 'cpu'
61
+ checkpoint = torch.hub.load_state_dict_from_url(
62
+ "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt",
63
+ map_location=device,
64
+ progress=progress
65
+ )
66
+
67
+ cfg = WavLMConfig(checkpoint['cfg'])
68
+ device = torch.device(device)
69
+ model = WavLM(cfg)
70
+ if pretrained:
71
+ model.load_state_dict(checkpoint['model'])
72
+ model = model.to(device)
73
+ model.eval()
74
+ print(f"WavLM-Large loaded with {sum([p.numel() for p in model.parameters()]):,d} parameters.")
75
+ return model
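Because hubconf.py ships with this Space, the entry points can also be used without the GitHub fetch that app.py performs, either by importing them directly (prematch_dataset.py already does `from hubconf import wavlm_large`) or via torch.hub with source='local'. A brief sketch; the checkpoint weights are still downloaded from the release URLs above:

import torch
from hubconf import knn_vc

device = "cuda" if torch.cuda.is_available() else "cpu"
model = knn_vc(pretrained=True, progress=True, prematched=True, device=device)

# equivalent, letting torch.hub resolve the local hubconf.py in the current directory:
# model = torch.hub.load(".", "knn_vc", source="local", prematched=True, pretrained=True, device=device)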
knnvc_utils.py ADDED
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+
3
+ def generate_matrix_from_index(A, len=25):
4
+ matrix = np.zeros(len, dtype=float)
5
+ matrix[A] = 1
6
+ return matrix
7
+
8
+
9
+ def retrieve_index_from_matrix(matrix):
10
+ A = np.where(matrix == 1)[0]
11
+ return A
12
+
13
+ if __name__ == '__main__':
14
+ # Generating a matrix from index A
15
+ A = 6
16
+ matrix = generate_matrix_from_index(A)
17
+ print("Generated Matrix:")
18
+ print(matrix)
19
+
20
+ # Retrieving index A from the matrix
21
+ retrieved_A = retrieve_index_from_matrix(matrix)
22
+ print("Retrieved Index A:")
23
+ print(retrieved_A)
matcher.py ADDED
@@ -0,0 +1,172 @@
1
+
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.transforms as T
9
+ from hifigan.models import Generator as HiFiGAN
10
+ from hifigan.utils import AttrDict
11
+ from torch import Tensor
12
+ from torchaudio.sox_effects import apply_effects_tensor
13
+ from wavlm.WavLM import WavLM
14
+ from knnvc_utils import generate_matrix_from_index
15
+
16
+
17
+ SPEAKER_INFORMATION_LAYER = 6
18
+ SPEAKER_INFORMATION_WEIGHTS = generate_matrix_from_index(SPEAKER_INFORMATION_LAYER)
19
+
20
+
21
+ def fast_cosine_dist(source_feats: Tensor, matching_pool: Tensor, device: str = 'cpu') -> Tensor:
22
+ """ Like torch.cdist, but fixed dim=-1 and for cosine distance."""
23
+ source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
24
+ matching_norms = torch.norm(matching_pool, p=2, dim=-1)
25
+ dotprod = -torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
26
+ dotprod /= 2
27
+
28
+ dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) )
29
+ return dists
30
+
31
+
32
+ class KNeighborsVC(nn.Module):
33
+
34
+ def __init__(self,
35
+ wavlm: WavLM,
36
+ hifigan: HiFiGAN,
37
+ hifigan_cfg: AttrDict,
38
+ device='cuda'
39
+ ) -> None:
40
+ """ kNN-VC matcher.
41
+ Arguments:
42
+ - `wavlm` : trained WavLM model
43
+ - `hifigan`: trained hifigan model
44
+ - `hifigan_cfg`: hifigan config to use for vocoding.
45
+ """
46
+ super().__init__()
47
+ # set which features to extract from wavlm
48
+ self.weighting = torch.tensor(SPEAKER_INFORMATION_WEIGHTS, device=device)[:, None]
49
+ # load hifigan
50
+ self.hifigan = hifigan.eval()
51
+ self.h = hifigan_cfg
52
+ # store wavlm
53
+ self.wavlm = wavlm.eval()
54
+ self.device = torch.device(device)
55
+ self.sr = self.h.sampling_rate
56
+ self.hop_length = 320
57
+
58
+ def get_matching_set(self, wavs: list[Path] | list[Tensor], weights=None, vad_trigger_level=7) -> Tensor:
59
+ """ Get concatenated wavlm features for the matching set using all waveforms in `wavs`,
60
+ specified as either a list of paths or list of loaded waveform tensors of
61
+ shape (channels, T), assumed to be of 16kHz sample rate.
62
+ Optionally specify custom WavLM feature weighting with `weights`.
63
+ """
64
+ feats = []
65
+ for p in wavs:
66
+ feats.append(self.get_features(p, weights=self.weighting if weights is None else weights, vad_trigger_level=vad_trigger_level))
67
+
68
+ feats = torch.concat(feats, dim=0).cpu()
69
+ return feats
70
+
71
+
72
+ @torch.inference_mode()
73
+ def vocode(self, c: Tensor) -> Tensor:
74
+ """ Vocode features with hifigan. `c` is of shape (bs, seq_len, c_dim) """
75
+ y_g_hat = self.hifigan(c)
76
+ y_g_hat = y_g_hat.squeeze(1)
77
+ return y_g_hat
78
+
79
+
80
+ @torch.inference_mode()
81
+ def get_features(self, path, weights=None, vad_trigger_level=0):
82
+ """Returns features of `path` waveform as a tensor of shape (seq_len, dim), optionally perform VAD trimming
83
+ on start/end with `vad_trigger_level`.
84
+ """
85
+ # load audio
86
+ if weights is None: weights = self.weighting
87
+ if isinstance(path, (str, Path)):
88
+ x, sr = torchaudio.load(path, normalize=True)
89
+ else:
90
+ x: Tensor = path
91
+ sr = self.sr
92
+ if x.dim() == 1: x = x[None]
93
+
94
+ if not sr == self.sr :
95
+ print(f"resample {sr} to {self.sr} in {path}")
96
+ x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=self.sr)
97
+ sr = self.sr
98
+
99
+ # trim silence from front and back
100
+ if vad_trigger_level > 1e-3:
101
+ transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
102
+ x_front_trim = transform(x)
103
+ # original way, disabled because it lacks windows support
104
+ #waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
105
+ waveform_reversed = torch.flip(x_front_trim, (-1,))
106
+ waveform_reversed_front_trim = transform(waveform_reversed)
107
+ waveform_end_trim = torch.flip(waveform_reversed_front_trim, (-1,))
108
+ #waveform_end_trim, sr = apply_effects_tensor(
109
+ # waveform_reversed_front_trim, sr, [["reverse"]]
110
+ #)
111
+ x = waveform_end_trim
112
+
113
+ # extract the representation of each layer
114
+ wav_input_16khz = x.to(self.device)
115
+ if torch.allclose(weights, self.weighting):
116
+ # use fastpath
117
+ features = self.wavlm.extract_features(wav_input_16khz, output_layer=SPEAKER_INFORMATION_LAYER, ret_layer_results=False)[0]
118
+ features = features.squeeze(0)
119
+ else:
120
+ # use slower weighted
121
+ rep, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
122
+ features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
123
+ # save full sequence
124
+ features = ( features*weights[:, None] ).sum(dim=0) # (seq_len, dim)
125
+
126
+ return features
127
+
128
+
129
+ @torch.inference_mode()
130
+ def match(self, query_seq: Tensor, matching_set: Tensor, synth_set: Tensor = None,
131
+ topk: int = 4, tgt_loudness_db: float | None = -16,
132
+ target_duration: float | None = None, device: str | None = None) -> Tensor:
133
+ """ Given `query_seq`, `matching_set`, and `synth_set` tensors of shape (N, dim), perform kNN regression matching
134
+ with k=`topk`. Inputs:
135
+ - `query_seq`: Tensor (N1, dim) of the input/source query features.
136
+ - `matching_set`: Tensor (N2, dim) of the matching set used as the 'training set' for the kNN algorithm.
137
+ - `synth_set`: optional Tensor (N2, dim) corresponding to the matching set. We use the matching set to assign each query
138
+ vector to a vector in the matching set, and then use the corresponding vector from the synth set during HiFiGAN synthesis.
139
+ By default, and for best performance, this should be identical to the matching set.
140
+ - `topk`: k in the kNN -- the number of nearest neighbors to average over.
141
+ - `tgt_loudness_db`: float db used to normalize the output volume. Set to None to disable.
142
+ - `target_duration`: if set to a float, interpolate resulting waveform duration to be equal to this value in seconds.
143
+ - `device`: if None, uses default device at initialization. Otherwise uses specified device
144
+ Returns:
145
+ - converted waveform of shape (T,)
146
+ """
147
+ device = torch.device(device) if device is not None else self.device
148
+ if synth_set is None: synth_set = matching_set.to(device)
149
+ else: synth_set = synth_set.to(device)
150
+ matching_set = matching_set.to(device)
151
+ query_seq = query_seq.to(device)
152
+
153
+ if target_duration is not None:
154
+ target_samples = int(target_duration*self.sr)
155
+ scale_factor = (target_samples/self.hop_length) / query_seq.shape[0] # n_targ_feats / n_input_feats
156
+ query_seq = F.interpolate(query_seq.T[None], scale_factor=scale_factor, mode='linear')[0].T
157
+
158
+ dists = fast_cosine_dist(query_seq, matching_set, device=device)
159
+ best = dists.topk(k=topk, largest=False, dim=-1)
160
+ out_feats = synth_set[best.indices].mean(dim=1)
161
+
162
+ prediction = self.vocode(out_feats[None].to(device)).cpu().squeeze()
163
+
164
+ # normalization
165
+ if tgt_loudness_db is not None:
166
+ src_loudness = torchaudio.functional.loudness(prediction[None], self.h.sampling_rate)
167
+ tgt_loudness = tgt_loudness_db
168
+ pred_wav = torchaudio.functional.gain(prediction, tgt_loudness - src_loudness)
169
+ else: pred_wav = prediction
170
+ return pred_wav
171
+
172
+
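Putting KNeighborsVC together end to end (this mirrors what app.py does, plus saving the result; the wav paths are placeholders and get_features resamples anything that is not already 16 kHz):

import torch
import torchaudio
from hubconf import knn_vc

device = "cuda" if torch.cuda.is_available() else "cpu"
model = knn_vc(pretrained=True, prematched=True, device=device)

query = model.get_features("source.wav")                    # (N1, 1024) layer-6 WavLM features
pool = model.get_matching_set(["ref_1.wav", "ref_2.wav"])   # (N2, 1024) concatenated reference features
wav = model.match(query, pool, topk=4)                      # (T,) converted waveform at 16 kHz
torchaudio.save("converted.wav", wav[None], model.sr)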
prematch_dataset.py ADDED
@@ -0,0 +1,172 @@
1
+ import argparse
2
+ import gc
3
+ import os
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torchaudio
14
+ from fastprogress.fastprogress import master_bar, progress_bar
15
+ from torch import Tensor
16
+
17
+ from hubconf import wavlm_large
18
+
19
+ DOWNSAMPLE_FACTOR = 320
20
+
21
+ global feature_cache
22
+ feature_cache = {}
23
+ global synthesis_cache
24
+ synthesis_cache = {}
25
+
26
+ def make_librispeech_df(root_path: Path) -> pd.DataFrame:
27
+ all_files = []
28
+ folders = ['train-clean-100', 'dev-clean']
29
+ print(f"[LIBRISPEECH] Computing folders {folders}")
30
+ for f in folders:
31
+ all_files.extend(list((root_path/f).rglob('**/*.flac')))
32
+ speakers = ['ls-' + f.stem.split('-')[0] for f in all_files]
33
+ df = pd.DataFrame({'path': all_files, 'speaker': speakers})
34
+ return df
35
+
36
+
37
+ def main(args):
38
+ device = torch.device(args.device)
39
+ SYNTH_WEIGHTINGS = F.one_hot(torch.tensor(args.synthesis_layer), num_classes=25).float().to(device)[:, None]
40
+ MATCH_WEIGHTINGS = F.one_hot(torch.tensor(args.matching_layer), num_classes=25).float().to(device)[:, None]
41
+
42
+ print(f"Matching weightings: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weightings: {SYNTH_WEIGHTINGS.squeeze()}")
43
+ ls_df = make_librispeech_df(Path(args.librispeech_path))
44
+
45
+ print(f"Loading wavlm.")
46
+ wavlm = wavlm_large(pretrained=True, progress=True, device=args.device)
47
+
48
+ np.random.seed(args.seed)
49
+ torch.manual_seed(args.seed)
50
+ extract(ls_df, wavlm, args.device, Path(args.librispeech_path), Path(args.out_path), SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS)
51
+ print("All done!", flush=True)
52
+
53
+
54
+ def path2pools(path: Path, wavlm: nn.Module, match_weights: Tensor, synth_weights: Tensor, device):
55
+ """Given a waveform `path`, compute the matching and synthesis pools from the speaker's other utterances."""
56
+
57
+ uttrs_from_same_spk = sorted(list(path.parent.rglob('**/*.flac')))
58
+ uttrs_from_same_spk.remove(path)
59
+ matching_pool = []
60
+ synth_pool = []
61
+ for pth in uttrs_from_same_spk:
62
+ if pth in feature_cache and pth in synthesis_cache:
63
+ matching_feats = feature_cache[pth].float() # (seq_len, dim)
64
+ synth_feats = synthesis_cache[pth].float() # (seq_len, dim)
65
+ else:
66
+ feats = get_full_features(pth, wavlm, device)
67
+ matching_feats = ( feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)
68
+ synth_feats = ( feats*synth_weights[:, None] ).sum(dim=0) # (seq_len, dim)
69
+ feature_cache[pth] = matching_feats.half().cpu()
70
+ synthesis_cache[pth] = synth_feats.half().cpu()
71
+
72
+ matching_pool.append(matching_feats.cpu())
73
+ synth_pool.append(synth_feats.cpu())
74
+ matching_pool = torch.concat(matching_pool, dim=0)
75
+ synth_pool = torch.concat(synth_pool, dim=0)
76
+ return matching_pool, synth_pool # (N, dim)
77
+
78
+
79
+ @torch.inference_mode()
80
+ def get_full_features(path, wavlm, device):
81
+
82
+ x, sr = torchaudio.load(path)
83
+ assert sr == 16000
84
+ # This does not work i.t.o the hifigan training.
85
+ # x = F.pad(x, (DOWNSAMPLE_FACTOR//2, DOWNSAMPLE_FACTOR - DOWNSAMPLE_FACTOR//2), value=0)
86
+ # This does.
87
+ n_pad = DOWNSAMPLE_FACTOR - (x.shape[-1] % DOWNSAMPLE_FACTOR)
88
+ x = F.pad(x, (0, n_pad), value=0)
89
+
90
+ # extract the representation of each layer
91
+ wav_input_16khz = x.to(device)
92
+ rep, layer_results = wavlm.extract_features(wav_input_16khz, output_layer=wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
93
+ features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
94
+
95
+ return features
96
+
97
+
98
+ def fast_cosine_dist(source_feats, matching_pool):
99
+ source_norms = torch.norm(source_feats, p=2, dim=-1)
100
+ matching_norms = torch.norm(matching_pool, p=2, dim=-1)
101
+ dotprod = -torch.cdist(source_feats[None], matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
102
+ dotprod /= 2
103
+
104
+ dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) )
105
+ return dists
106
+
107
+
108
+ @torch.inference_mode()
109
+ def extract(df: pd.DataFrame, wavlm: nn.Module, device, ls_path: Path, out_path: Path, synth_weights: Tensor, match_weights: Tensor):
110
+
111
+ pb = progress_bar(df.iterrows(), total=len(df))
112
+
113
+ for i, row in pb:
114
+ rel_path = Path(row.path).relative_to(ls_path)
115
+ targ_path = (out_path/rel_path).with_suffix('.pt')
116
+ if args.resume:
117
+ if targ_path.is_file(): continue
118
+ # if targ_path.is_file(): continue
119
+ os.makedirs(targ_path.parent, exist_ok=True)
120
+
121
+ if Path(row.path) in feature_cache:
122
+ source_feats = feature_cache[Path(row.path)].float()
123
+ else:
124
+ source_feats = get_full_features(row.path, wavlm, device)
125
+ source_feats = ( source_feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)
126
+
127
+ matching_pool, synth_pool = path2pools(row.path, wavlm, match_weights, synth_weights, device)
128
+
129
+ if not args.prematch:
130
+ out_feats = source_feats.cpu()
131
+ else:
132
+ dists = fast_cosine_dist(source_feats.cpu(), matching_pool.cpu()).cpu()
133
+ best = dists.topk(k=args.topk, dim=-1, largest=False) # (src_len, 4)
134
+ out_feats = synth_pool[best.indices].mean(dim=1) # (N, dim)
135
+
136
+ # save matched sequence
137
+ if i < 3: print("Feature has shape: ", out_feats.shape, flush=True)
138
+ # 3. save
139
+ torch.save(out_feats.cpu().half(), str(targ_path))
140
+ if hasattr(pb, 'child'):
141
+ pb.child.comment = str(rel_path)
142
+ pb.child.wait_for = min(pb.child.wait_for, 10)
143
+ pb.main_bar.comment = str(rel_path)
144
+ else:
145
+ pb.wait_for = min(pb.wait_for, 10)
146
+ pb.comment = str(rel_path)
147
+
148
+
149
+ if i % 1000 == 0:
150
+ print(f"Done {i:,d}/{len(df):,d}", flush=True)
151
+ feature_cache.clear()
152
+ synthesis_cache.clear()
153
+ gc.collect()
154
+ time.sleep(4)
155
+
156
+
157
+ if __name__ == '__main__':
158
+ parser = argparse.ArgumentParser(description="Compute matched WavLM features for a LibriSpeech dataset")
159
+
160
+ parser.add_argument('--librispeech_path', required=True, type=str)
161
+ parser.add_argument('--seed', default=123, type=int)
162
+ parser.add_argument('--out_path', required=True, type=str)
163
+ parser.add_argument('--device', default='cuda', type=str)
164
+ parser.add_argument('--topk', type=int, default=4)
165
+ parser.add_argument('--matching_layer', type=int, default=6)
166
+ parser.add_argument('--synthesis_layer', type=int, default=6)
167
+ parser.add_argument('--prematch', action='store_true', help='replace each source frame with the mean of its k nearest neighbours before saving')
168
+ parser.add_argument('--resume', action='store_true')
169
+
170
+ args = parser.parse_args()
171
+ main(args)
172
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch
2
+ torchaudio
3
+ soundfile
4
+ gradio
5
+ spaces
wavlm/WavLM.py ADDED
@@ -0,0 +1,743 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from .modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from poisson distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if true, use a recursive algorithm that prevents mask spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ int, # np.int was removed in NumPy >= 1.24
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
160
+
161
+
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
166
+
167
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
168
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
170
+ self.activation_fn: str = "gelu" # activation function to use
171
+
172
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
173
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
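+ # (total stride = 5 * 2**4 * 2**2 = 320 samples, i.e. one feature frame per 20 ms of 16 kHz audio)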
174
+ self.conv_bias: bool = False # include bias in conv encoder
175
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
176
+
177
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
178
+
179
+ # dropouts
180
+ self.dropout: float = 0.1 # dropout probability for the transformer
181
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
182
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
183
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
184
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
185
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
186
+
187
+ # masking
188
+ self.mask_length: int = 10 # mask length
189
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
190
+ self.mask_selection: str = "static" # how to choose mask length
191
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
192
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
193
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
194
+
195
+ # channel masking
196
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
197
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
198
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
199
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
200
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
201
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
202
+
203
+ # positional embeddings
204
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
205
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
206
+
207
+ # relative position embedding
208
+ self.relative_position_embedding: bool = False # apply relative position embedding
209
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
210
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
211
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
212
+
213
+ if cfg is not None:
214
+ self.update(cfg)
215
+
216
+ def update(self, cfg: dict):
217
+ self.__dict__.update(cfg)
218
+
219
+
220
+ class WavLM(nn.Module):
221
+ def __init__(
222
+ self,
223
+ cfg: WavLMConfig,
224
+ ) -> None:
225
+ super().__init__()
226
+ logger.info(f"WavLM Config: {cfg.__dict__}")
227
+
228
+ self.cfg = cfg
229
+ feature_enc_layers = eval(cfg.conv_feature_layers)
230
+ self.embed = feature_enc_layers[-1][0]
231
+
232
+ self.feature_extractor = ConvFeatureExtractionModel(
233
+ conv_layers=feature_enc_layers,
234
+ dropout=0.0,
235
+ mode=cfg.extractor_mode,
236
+ conv_bias=cfg.conv_bias,
237
+ )
238
+
239
+ self.post_extract_proj = (
240
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
241
+ if self.embed != cfg.encoder_embed_dim
242
+ else None
243
+ )
244
+
245
+ self.mask_prob = cfg.mask_prob
246
+ self.mask_selection = cfg.mask_selection
247
+ self.mask_other = cfg.mask_other
248
+ self.mask_length = cfg.mask_length
249
+ self.no_mask_overlap = cfg.no_mask_overlap
250
+ self.mask_min_space = cfg.mask_min_space
251
+
252
+ self.mask_channel_prob = cfg.mask_channel_prob
253
+ self.mask_channel_selection = cfg.mask_channel_selection
254
+ self.mask_channel_other = cfg.mask_channel_other
255
+ self.mask_channel_length = cfg.mask_channel_length
256
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
257
+ self.mask_channel_min_space = cfg.mask_channel_min_space
258
+
259
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
260
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
261
+
262
+ self.feature_grad_mult = cfg.feature_grad_mult
263
+
264
+ self.mask_emb = nn.Parameter(
265
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
266
+ )
267
+
268
+ self.encoder = TransformerEncoder(cfg)
269
+ self.layer_norm = LayerNorm(self.embed)
270
+
271
+ def apply_mask(self, x, padding_mask):
272
+ B, T, C = x.shape
273
+ if self.mask_prob > 0:
274
+ mask_indices = compute_mask_indices(
275
+ (B, T),
276
+ padding_mask,
277
+ self.mask_prob,
278
+ self.mask_length,
279
+ self.mask_selection,
280
+ self.mask_other,
281
+ min_masks=2,
282
+ no_overlap=self.no_mask_overlap,
283
+ min_space=self.mask_min_space,
284
+ )
285
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
286
+ x[mask_indices] = self.mask_emb
287
+ else:
288
+ mask_indices = None
289
+
290
+ if self.mask_channel_prob > 0:
291
+ mask_channel_indices = compute_mask_indices(
292
+ (B, C),
293
+ None,
294
+ self.mask_channel_prob,
295
+ self.mask_channel_length,
296
+ self.mask_channel_selection,
297
+ self.mask_channel_other,
298
+ no_overlap=self.no_mask_channel_overlap,
299
+ min_space=self.mask_channel_min_space,
300
+ )
301
+ mask_channel_indices = (
302
+ torch.from_numpy(mask_channel_indices)
303
+ .to(x.device)
304
+ .unsqueeze(1)
305
+ .expand(-1, T, -1)
306
+ )
307
+ x[mask_channel_indices] = 0
308
+
309
+ return x, mask_indices
310
+
311
+ def forward_padding_mask(
312
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
313
+ ) -> torch.Tensor:
314
+ extra = padding_mask.size(1) % features.size(1)
315
+ if extra > 0:
316
+ padding_mask = padding_mask[:, :-extra]
317
+ padding_mask = padding_mask.view(
318
+ padding_mask.size(0), features.size(1), -1
319
+ )
320
+ padding_mask = padding_mask.all(-1)
321
+ return padding_mask
322
+
323
+ def extract_features(
324
+ self,
325
+ source: torch.Tensor,
326
+ padding_mask: Optional[torch.Tensor] = None,
327
+ mask: bool = False,
328
+ ret_conv: bool = False,
329
+ output_layer: Optional[int] = None,
330
+ ret_layer_results: bool = False,
331
+ ):
332
+
333
+ if self.feature_grad_mult > 0:
334
+ features = self.feature_extractor(source)
335
+ if self.feature_grad_mult != 1.0:
336
+ features = GradMultiply.apply(features, self.feature_grad_mult)
337
+ else:
338
+ with torch.no_grad():
339
+ features = self.feature_extractor(source)
340
+
341
+ features = features.transpose(1, 2)
342
+ features = self.layer_norm(features)
343
+
344
+ if padding_mask is not None:
345
+ padding_mask = self.forward_padding_mask(features, padding_mask)
346
+
347
+ if self.post_extract_proj is not None:
348
+ features = self.post_extract_proj(features)
349
+
350
+ features = self.dropout_input(features)
351
+
352
+ if mask:
353
+ x, mask_indices = self.apply_mask(
354
+ features, padding_mask
355
+ )
356
+ else:
357
+ x = features
358
+
359
+ # feature: (B, T, D), float
360
+ # target: (B, T), long
361
+ # x: (B, T, D), float
362
+ # padding_mask: (B, T), bool
363
+ # mask_indices: (B, T), bool
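+ # output_layer is 1-indexed (layer 1 = first transformer block); the encoder expects a 0-indexed target layer, hence the -1 below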
364
+ x, layer_results = self.encoder(
365
+ x,
366
+ padding_mask=padding_mask,
367
+ layer=None if output_layer is None else output_layer - 1
368
+ )
369
+
370
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
371
+
372
+ feature = res["features"] if ret_conv else res["x"]
373
+ if ret_layer_results:
374
+ feature = (feature, res["layer_results"])
375
+ return feature, res["padding_mask"]
376
+
377
+
378
+ class ConvFeatureExtractionModel(nn.Module):
379
+ def __init__(
380
+ self,
381
+ conv_layers: List[Tuple[int, int, int]],
382
+ dropout: float = 0.0,
383
+ mode: str = "default",
384
+ conv_bias: bool = False,
385
+ conv_type: str = "default"
386
+ ):
387
+ super().__init__()
388
+
389
+ assert mode in {"default", "layer_norm"}
390
+
391
+ def block(
392
+ n_in,
393
+ n_out,
394
+ k,
395
+ stride,
396
+ is_layer_norm=False,
397
+ is_group_norm=False,
398
+ conv_bias=False,
399
+ ):
400
+ def make_conv():
401
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
402
+ nn.init.kaiming_normal_(conv.weight)
403
+ return conv
404
+
405
+ assert not (is_layer_norm and is_group_norm), "layer norm and group norm are exclusive"
408
+
409
+ if is_layer_norm:
410
+ return nn.Sequential(
411
+ make_conv(),
412
+ nn.Dropout(p=dropout),
413
+ nn.Sequential(
414
+ TransposeLast(),
415
+ Fp32LayerNorm(dim, elementwise_affine=True),
416
+ TransposeLast(),
417
+ ),
418
+ nn.GELU(),
419
+ )
420
+ elif is_group_norm:
421
+ return nn.Sequential(
422
+ make_conv(),
423
+ nn.Dropout(p=dropout),
424
+ Fp32GroupNorm(dim, dim, affine=True),
425
+ nn.GELU(),
426
+ )
427
+ else:
428
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
429
+
430
+ self.conv_type = conv_type
431
+ if self.conv_type == "default":
432
+ in_d = 1
433
+ self.conv_layers = nn.ModuleList()
434
+ for i, cl in enumerate(conv_layers):
435
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
436
+ (dim, k, stride) = cl
437
+
438
+ self.conv_layers.append(
439
+ block(
440
+ in_d,
441
+ dim,
442
+ k,
443
+ stride,
444
+ is_layer_norm=mode == "layer_norm",
445
+ is_group_norm=mode == "default" and i == 0,
446
+ conv_bias=conv_bias,
447
+ )
448
+ )
449
+ in_d = dim
450
+ elif self.conv_type == "conv2d":
451
+ in_d = 1
452
+ self.conv_layers = nn.ModuleList()
453
+ for i, cl in enumerate(conv_layers):
454
+ assert len(cl) == 3
455
+ (dim, k, stride) = cl
456
+
457
+ self.conv_layers.append(
458
+ torch.nn.Conv2d(in_d, dim, k, stride)
459
+ )
460
+ self.conv_layers.append(torch.nn.ReLU())
461
+ in_d = dim
462
+ elif self.conv_type == "custom":
463
+ in_d = 1
464
+ idim = 80
465
+ self.conv_layers = nn.ModuleList()
466
+ for i, cl in enumerate(conv_layers):
467
+ assert len(cl) == 3
468
+ (dim, k, stride) = cl
469
+ self.conv_layers.append(
470
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
471
+ )
472
+ self.conv_layers.append(
473
+ torch.nn.LayerNorm([dim, idim])
474
+ )
475
+ self.conv_layers.append(torch.nn.ReLU())
476
+ in_d = dim
477
+ if (i + 1) % 2 == 0:
478
+ self.conv_layers.append(
479
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
480
+ )
481
+ idim = int(math.ceil(idim / 2))
482
+ else:
483
+ pass
484
+
485
+ def forward(self, x, mask=None):
486
+
487
+ # BxT -> BxCxT
488
+ x = x.unsqueeze(1)
489
+ if self.conv_type == "custom":
490
+ for conv in self.conv_layers:
491
+ if isinstance(conv, nn.LayerNorm):
492
+ x = x.transpose(1, 2)
493
+ x = conv(x).transpose(1, 2)
494
+ else:
495
+ x = conv(x)
496
+ x = x.transpose(2, 3).contiguous()
497
+ x = x.view(x.size(0), -1, x.size(-1))
498
+ else:
499
+ for conv in self.conv_layers:
500
+ x = conv(x)
501
+ if self.conv_type == "conv2d":
502
+ b, c, t, f = x.size()
503
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
504
+ return x
505
+
506
+
507
+ class TransformerEncoder(nn.Module):
508
+ def __init__(self, args):
509
+ super().__init__()
510
+
511
+ self.dropout = args.dropout
512
+ self.embedding_dim = args.encoder_embed_dim
513
+
514
+ self.pos_conv = nn.Conv1d(
515
+ self.embedding_dim,
516
+ self.embedding_dim,
517
+ kernel_size=args.conv_pos,
518
+ padding=args.conv_pos // 2,
519
+ groups=args.conv_pos_groups,
520
+ )
521
+ dropout = 0
522
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
523
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
524
+ nn.init.constant_(self.pos_conv.bias, 0)
525
+
526
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
527
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
528
+
529
+ if hasattr(args, "relative_position_embedding"):
530
+ self.relative_position_embedding = args.relative_position_embedding
531
+ self.num_buckets = args.num_buckets
532
+ self.max_distance = args.max_distance
533
+ else:
534
+ self.relative_position_embedding = False
535
+ self.num_buckets = 0
536
+ self.max_distance = 0
537
+
538
+ self.layers = nn.ModuleList(
539
+ [
540
+ TransformerSentenceEncoderLayer(
541
+ embedding_dim=self.embedding_dim,
542
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
543
+ num_attention_heads=args.encoder_attention_heads,
544
+ dropout=self.dropout,
545
+ attention_dropout=args.attention_dropout,
546
+ activation_dropout=args.activation_dropout,
547
+ activation_fn=args.activation_fn,
548
+ layer_norm_first=args.layer_norm_first,
549
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
550
+ num_buckets=self.num_buckets,
551
+ max_distance=self.max_distance,
552
+ gru_rel_pos=args.gru_rel_pos,
553
+ )
554
+ for i in range(args.encoder_layers)
555
+ ]
556
+ )
557
+
558
+ self.layer_norm_first = args.layer_norm_first
559
+ self.layer_norm = LayerNorm(self.embedding_dim)
560
+ self.layerdrop = args.encoder_layerdrop
561
+
562
+ self.apply(init_bert_params)
563
+
564
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
565
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
566
+
567
+ if self.layer_norm_first and layer is None:
568
+ x = self.layer_norm(x)
569
+
570
+ return x, layer_results
571
+
572
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
573
+
574
+ if padding_mask is not None:
575
+ x[padding_mask] = 0
576
+
577
+ x_conv = self.pos_conv(x.transpose(1, 2))
578
+ x_conv = x_conv.transpose(1, 2)
579
+ x += x_conv
580
+
581
+ if not self.layer_norm_first:
582
+ x = self.layer_norm(x)
583
+
584
+ x = F.dropout(x, p=self.dropout, training=self.training)
585
+
586
+ # B x T x C -> T x B x C
587
+ x = x.transpose(0, 1)
588
+
589
+ layer_results = []
590
+ z = None
591
+ if tgt_layer is not None:
592
+ layer_results.append((x, z))
593
+ r = None
594
+ pos_bias = None
595
+ for i, layer in enumerate(self.layers):
596
+ dropout_probability = np.random.random()
597
+ if not self.training or (dropout_probability > self.layerdrop):
598
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
599
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
600
+ if tgt_layer is not None:
601
+ layer_results.append((x, z))
602
+ if i == tgt_layer:
603
+ r = x
604
+ break
605
+
606
+ if r is not None:
607
+ x = r
608
+
609
+ # T x B x C -> B x T x C
610
+ x = x.transpose(0, 1)
611
+
612
+ return x, layer_results
613
+
614
+
615
+ class TransformerSentenceEncoderLayer(nn.Module):
616
+ """
617
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
618
+ models.
619
+ """
620
+
621
+ def __init__(
622
+ self,
623
+ embedding_dim: float = 768,
624
+ ffn_embedding_dim: float = 3072,
625
+ num_attention_heads: float = 8,
626
+ dropout: float = 0.1,
627
+ attention_dropout: float = 0.1,
628
+ activation_dropout: float = 0.1,
629
+ activation_fn: str = "relu",
630
+ layer_norm_first: bool = False,
631
+ has_relative_attention_bias: bool = False,
632
+ num_buckets: int = 0,
633
+ max_distance: int = 0,
634
+ rescale_init: bool = False,
635
+ gru_rel_pos: bool = False,
636
+ ) -> None:
637
+
638
+ super().__init__()
639
+ # Initialize parameters
640
+ self.embedding_dim = embedding_dim
641
+ self.dropout = dropout
642
+ self.activation_dropout = activation_dropout
643
+
644
+ # Initialize blocks
645
+ self.activation_name = activation_fn
646
+ self.activation_fn = get_activation_fn(activation_fn)
647
+ self.self_attn = MultiheadAttention(
648
+ self.embedding_dim,
649
+ num_attention_heads,
650
+ dropout=attention_dropout,
651
+ self_attention=True,
652
+ has_relative_attention_bias=has_relative_attention_bias,
653
+ num_buckets=num_buckets,
654
+ max_distance=max_distance,
655
+ rescale_init=rescale_init,
656
+ gru_rel_pos=gru_rel_pos,
657
+ )
658
+
659
+ self.dropout1 = nn.Dropout(dropout)
660
+ self.dropout2 = nn.Dropout(self.activation_dropout)
661
+ self.dropout3 = nn.Dropout(dropout)
662
+
663
+ self.layer_norm_first = layer_norm_first
664
+
665
+ # layer norm associated with the self attention layer
666
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
667
+
668
+ if self.activation_name == "glu":
669
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
670
+ else:
671
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
672
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
673
+
674
+ # layer norm associated with the position wise feed-forward NN
675
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
676
+
677
+ def forward(
678
+ self,
679
+ x: torch.Tensor,
680
+ self_attn_mask: torch.Tensor = None,
681
+ self_attn_padding_mask: torch.Tensor = None,
682
+ need_weights: bool = False,
683
+ pos_bias=None
684
+ ):
685
+ """
686
+ LayerNorm is applied either before or after the self-attention/ffn
687
+ modules similar to the original Transformer implementation.
688
+ """
689
+ residual = x
690
+
691
+ if self.layer_norm_first:
692
+ x = self.self_attn_layer_norm(x)
693
+ x, attn, pos_bias = self.self_attn(
694
+ query=x,
695
+ key=x,
696
+ value=x,
697
+ key_padding_mask=self_attn_padding_mask,
698
+ need_weights=False,
699
+ attn_mask=self_attn_mask,
700
+ position_bias=pos_bias
701
+ )
702
+ x = self.dropout1(x)
703
+ x = residual + x
704
+
705
+ residual = x
706
+ x = self.final_layer_norm(x)
707
+ if self.activation_name == "glu":
708
+ x = self.fc1(x)
709
+ else:
710
+ x = self.activation_fn(self.fc1(x))
711
+ x = self.dropout2(x)
712
+ x = self.fc2(x)
713
+ x = self.dropout3(x)
714
+ x = residual + x
715
+ else:
716
+ x, attn, pos_bias = self.self_attn(
717
+ query=x,
718
+ key=x,
719
+ value=x,
720
+ key_padding_mask=self_attn_padding_mask,
721
+ need_weights=need_weights,
722
+ attn_mask=self_attn_mask,
723
+ position_bias=pos_bias
724
+ )
725
+
726
+ x = self.dropout1(x)
727
+ x = residual + x
728
+
729
+ x = self.self_attn_layer_norm(x)
730
+
731
+ residual = x
732
+ if self.activation_name == "glu":
733
+ x = self.fc1(x)
734
+ else:
735
+ x = self.activation_fn(self.fc1(x))
736
+ x = self.dropout2(x)
737
+ x = self.fc2(x)
738
+ x = self.dropout3(x)
739
+ x = residual + x
740
+ x = self.final_layer_norm(x)
741
+
742
+ return x, attn, pos_bias
743
+
wavlm/modules.py ADDED
@@ -0,0 +1,827 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
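+ # identity in the forward pass; scales the incoming gradient by `scale` in the backward pass (used for feature_grad_mult in WavLM)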
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
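+ # 'same' padding of k // 2 yields one extra output frame for even kernels (k - 1 extra for causal convs); trim so the output length matches the input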
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct an MultiHeadedAttention object."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
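+ # compute in float32 for numerical stability under mixed precision, then cast back to the input dtype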
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bias will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
301
+
302
+
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
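+ # T5-style bucketing: small offsets get their own bucket, larger offsets share logarithmically spaced buckets capped at max_distance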
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: False).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert (src_len, bsz) == value.shape[:2]
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
+ if self.gru_rel_pos:
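+ # gated relative position bias: a gate computed from the queries rescales the shared bucketed bias per head and time step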
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+ attn_mask = torch.cat(
691
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
+ )
693
+ if key_padding_mask is not None:
694
+ key_padding_mask = torch.cat(
695
+ [
696
+ key_padding_mask,
697
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
698
+ key_padding_mask
699
+ ),
700
+ ],
701
+ dim=1,
702
+ )
703
+
704
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
705
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
+
707
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
+
709
+ if attn_mask is not None:
710
+ attn_mask = attn_mask.unsqueeze(0)
711
+ attn_weights += attn_mask
712
+
713
+ if key_padding_mask is not None:
714
+ # don't attend to padding symbols
715
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
+ if not is_tpu:
717
+ attn_weights = attn_weights.masked_fill(
718
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
+ float("-inf"),
720
+ )
721
+ else:
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
+ attn_weights = attn_weights.transpose(0, 2)
725
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
+
727
+ if before_softmax:
728
+ return attn_weights, v, position_bias
729
+
730
+ if position_bias is not None:
731
+ if self.gru_rel_pos == 1:
732
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
+ _B, _H, _L, __ = query_layer.size()
734
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
+
739
+ position_bias = position_bias.view(attn_weights.size())
740
+
741
+ attn_weights = attn_weights + position_bias
742
+
743
+ attn_weights_float = F.softmax(
744
+ attn_weights, dim=-1
745
+ )
746
+ attn_weights = attn_weights_float.type_as(attn_weights)
747
+ attn_probs = self.dropout_module(attn_weights)
748
+
749
+ assert v is not None
750
+ attn = torch.bmm(attn_probs, v)
751
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
+ attn = self.out_proj(attn)
754
+ attn_weights: Optional[Tensor] = None
755
+ if need_weights:
756
+ attn_weights = attn_weights_float.view(
757
+ bsz, self.num_heads, tgt_len, src_len
758
+ ).transpose(1, 0)
759
+ if not need_head_weights:
760
+ # average attention weights over heads
761
+ attn_weights = attn_weights.mean(dim=0)
762
+
763
+ return attn, attn_weights, position_bias
764
+
765
+ @staticmethod
766
+ def _append_prev_key_padding_mask(
767
+ key_padding_mask: Optional[Tensor],
768
+ prev_key_padding_mask: Optional[Tensor],
769
+ batch_size: int,
770
+ src_len: int,
771
+ static_kv: bool,
772
+ ) -> Optional[Tensor]:
773
+ # saved key padding masks have shape (bsz, seq_len)
774
+ if prev_key_padding_mask is not None and static_kv:
775
+ new_key_padding_mask = prev_key_padding_mask
776
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
+ new_key_padding_mask = torch.cat(
778
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
+ )
780
+ # During incremental decoding, as the padding token enters and
781
+ # leaves the frame, there will be a time when prev or current
782
+ # is None
783
+ elif prev_key_padding_mask is not None:
784
+ if src_len > prev_key_padding_mask.size(1):
785
+ filler = torch.zeros(
786
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
787
+ device=prev_key_padding_mask.device,
788
+ )
789
+ new_key_padding_mask = torch.cat(
790
+ [prev_key_padding_mask.float(), filler.float()], dim=1
791
+ )
792
+ else:
793
+ new_key_padding_mask = prev_key_padding_mask.float()
794
+ elif key_padding_mask is not None:
795
+ if src_len > key_padding_mask.size(1):
796
+ filler = torch.zeros(
797
+ (batch_size, src_len - key_padding_mask.size(1)),
798
+ device=key_padding_mask.device,
799
+ )
800
+ new_key_padding_mask = torch.cat(
801
+ [filler.float(), key_padding_mask.float()], dim=1
802
+ )
803
+ else:
804
+ new_key_padding_mask = key_padding_mask.float()
805
+ else:
806
+ new_key_padding_mask = prev_key_padding_mask
807
+ return new_key_padding_mask
808
+
809
+ def _get_input_buffer(
810
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
+ ) -> Dict[str, Optional[Tensor]]:
812
+ result = self.get_incremental_state(incremental_state, "attn_state")
813
+ if result is not None:
814
+ return result
815
+ else:
816
+ empty_result: Dict[str, Optional[Tensor]] = {}
817
+ return empty_result
818
+
819
+ def _set_input_buffer(
820
+ self,
821
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
+ buffer: Dict[str, Optional[Tensor]],
823
+ ):
824
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
+
826
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
+ return attn_weights