Spaces:
Running
Running
jason-salt
commited on
Commit
•
b971d47
1
Parent(s):
53b664a
init
Browse files- .gitattributes +1 -0
- __pycache__/inference_tts_scale.cpython-310.pyc +0 -0
- data/__init__.py +0 -0
- data/__pycache__/__init__.cpython-310.pyc +0 -0
- data/__pycache__/tokenizer.cpython-310.pyc +0 -0
- data/gigaspeech.py +156 -0
- data/phonemize_encodec_encode_hf.py +206 -0
- data/tokenizer.py +149 -0
- demo/84_121550_000074_000000.wav +0 -0
- demo/generated_se/84_121550_000074_000000_new_seed1.wav +0 -0
- demo/generated_se/84_121550_000074_000000_orig.wav +0 -0
- demo/generated_tts/84_121550_000074_000000_concat_seed1.wav +0 -0
- demo/generated_tts/84_121550_000074_000000_gen_seed1.wav +0 -0
- demo/temp/84_121550_000074_000000.txt +1 -0
- demo/temp/84_121550_000074_000000.wav +0 -0
- demo/temp/mfa_alignments/84_121550_000074_000000.csv +109 -0
- gradio_app.py +528 -0
- inference_speech_editing_scale.py +226 -0
- inference_tts_scale.py +190 -0
- models/__pycache__/codebooks_patterns.cpython-310.pyc +0 -0
- models/__pycache__/voicecraft.cpython-310.pyc +0 -0
- models/codebooks_patterns.py +538 -0
- models/modules/__init__.py +0 -0
- models/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- models/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- models/modules/__pycache__/activation.cpython-310.pyc +0 -0
- models/modules/__pycache__/activation.cpython-39.pyc +0 -0
- models/modules/__pycache__/embedding.cpython-310.pyc +0 -0
- models/modules/__pycache__/embedding.cpython-39.pyc +0 -0
- models/modules/__pycache__/scaling.cpython-310.pyc +0 -0
- models/modules/__pycache__/scaling.cpython-39.pyc +0 -0
- models/modules/__pycache__/transformer.cpython-310.pyc +0 -0
- models/modules/__pycache__/transformer.cpython-39.pyc +0 -0
- models/modules/__pycache__/utils.cpython-310.pyc +0 -0
- models/modules/__pycache__/utils.cpython-39.pyc +0 -0
- models/modules/__pycache__/visualizer.cpython-39.pyc +0 -0
- models/modules/activation.py +653 -0
- models/modules/embedding.py +98 -0
- models/modules/sampling.py +63 -0
- models/modules/scaling.py +1406 -0
- models/modules/transformer.py +698 -0
- models/modules/utils.py +37 -0
- models/voicecraft.py +1406 -0
- pretrained_models/encodec_4cb2048_giga.th +3 -0
- pretrained_models/giga330M.pth +3 -0
- requirements.txt +9 -0
.gitattributes
CHANGED
@@ -20,6 +20,7 @@
|
|
20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.th filter=lfs diff=lfs merge=lfs -text
|
24 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
25 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
26 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
__pycache__/inference_tts_scale.cpython-310.pyc
ADDED
Binary file (6.8 kB). View file
|
|
data/__init__.py
ADDED
File without changes
|
data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (130 Bytes). View file
|
|
data/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (4.83 kB). View file
|
|
data/gigaspeech.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import copy
|
5 |
+
import logging
|
6 |
+
import shutil
|
7 |
+
|
8 |
+
class dataset(torch.utils.data.Dataset):
|
9 |
+
def __init__(self, args, split):
|
10 |
+
super().__init__()
|
11 |
+
self.args = args
|
12 |
+
self.split = split
|
13 |
+
assert self.split in ['train', 'validation', 'test']
|
14 |
+
manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")
|
15 |
+
|
16 |
+
with open(manifest_fn, "r") as rf:
|
17 |
+
data = [l.strip().split("\t") for l in rf.readlines()]
|
18 |
+
lengths_list = [int(item[-1]) for item in data]
|
19 |
+
self.data = []
|
20 |
+
self.lengths_list = []
|
21 |
+
for d, l in zip(data, lengths_list):
|
22 |
+
if l >= self.args.encodec_sr*self.args.audio_min_length:
|
23 |
+
if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
|
24 |
+
continue
|
25 |
+
self.data.append(d)
|
26 |
+
self.lengths_list.append(l)
|
27 |
+
logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")
|
28 |
+
|
29 |
+
# phoneme vocabulary
|
30 |
+
vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt")
|
31 |
+
shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
|
32 |
+
with open(vocab_fn, "r") as f:
|
33 |
+
temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
|
34 |
+
self.phn2num = {item[1]:int(item[0]) for item in temp}
|
35 |
+
|
36 |
+
self.symbol_set = set(["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"])
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.lengths_list)
|
40 |
+
|
41 |
+
def _load_phn_enc(self, index):
|
42 |
+
item = self.data[index]
|
43 |
+
pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
|
44 |
+
ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
|
45 |
+
try:
|
46 |
+
with open(pf, "r") as p, open(ef, "r") as e:
|
47 |
+
phns = [l.strip() for l in p.readlines()]
|
48 |
+
assert len(phns) == 1, phns
|
49 |
+
x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
|
50 |
+
encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
|
51 |
+
|
52 |
+
assert len(encos) == self.args.n_codebooks, ef
|
53 |
+
if self.args.special_first:
|
54 |
+
y = [[int(n)+self.args.n_special for n in l] for l in encos]
|
55 |
+
else:
|
56 |
+
y = [[int(n) for n in l] for l in encos]
|
57 |
+
except Exception as e:
|
58 |
+
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
|
59 |
+
logging.info(f"error message: {e}")
|
60 |
+
return [], [[]]
|
61 |
+
|
62 |
+
return x, y
|
63 |
+
|
64 |
+
def __getitem__(self, index):
|
65 |
+
x, y = self._load_phn_enc(index)
|
66 |
+
x_len, y_len = len(x), len(y[0])
|
67 |
+
|
68 |
+
if x_len == 0 or y_len == 0:
|
69 |
+
return {
|
70 |
+
"x": None,
|
71 |
+
"x_len": None,
|
72 |
+
"y": None,
|
73 |
+
"y_len": None,
|
74 |
+
"y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token
|
75 |
+
"extra_mask_start": None # this is only used in VE1
|
76 |
+
}
|
77 |
+
while y_len < self.args.encodec_sr*self.args.audio_min_length:
|
78 |
+
assert not self.args.dynamic_batching
|
79 |
+
index = random.choice(range(len(self))) # regenerate an index
|
80 |
+
x, y = self._load_phn_enc(index)
|
81 |
+
x_len, y_len = len(x), len(y[0])
|
82 |
+
if self.args.drop_long:
|
83 |
+
while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length:
|
84 |
+
index = random.choice(range(len(self))) # regenerate an index
|
85 |
+
x, y = self._load_phn_enc(index)
|
86 |
+
x_len, y_len = len(x), len(y[0])
|
87 |
+
|
88 |
+
### padding and cropping below ###
|
89 |
+
### padding and cropping below ###
|
90 |
+
# adjust the length of encodec codes, pad to max_len or randomly crop
|
91 |
+
orig_y_len = copy.copy(y_len)
|
92 |
+
max_len = int(self.args.audio_max_length * self.args.encodec_sr)
|
93 |
+
if y_len > max_len:
|
94 |
+
audio_start = random.choice(range(0, y_len-max_len))
|
95 |
+
for i in range(len(y)):
|
96 |
+
y[i] = y[i][audio_start:(audio_start+max_len)]
|
97 |
+
y_len = max_len
|
98 |
+
else:
|
99 |
+
audio_start = 0
|
100 |
+
if not self.args.dynamic_batching:
|
101 |
+
pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
|
102 |
+
for i in range(len(y)):
|
103 |
+
y[i] = y[i] + pad
|
104 |
+
|
105 |
+
# adjust text
|
106 |
+
# if audio is cropped, and text is longer than max, crop max based on how audio is cropped
|
107 |
+
if audio_start > 0 and len(x) > self.args.text_max_length: # if audio is longer than max and text is long than max, start text the way audio started
|
108 |
+
x = x[int(len(x)*audio_start/orig_y_len):]
|
109 |
+
if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end
|
110 |
+
x = x[:self.args.text_max_length]
|
111 |
+
|
112 |
+
x_len = len(x)
|
113 |
+
if x_len > self.args.text_max_length:
|
114 |
+
text_start = random.choice(range(0, x_len - self.args.text_max_length))
|
115 |
+
x = x[text_start:text_start+self.args.text_max_length]
|
116 |
+
x_len = self.args.text_max_length
|
117 |
+
elif self.args.pad_x and x_len <= self.args.text_max_length:
|
118 |
+
pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
|
119 |
+
x = x + pad
|
120 |
+
### padding and cropping above ###
|
121 |
+
### padding and cropping above ###
|
122 |
+
|
123 |
+
return {
|
124 |
+
"x": torch.LongTensor(x),
|
125 |
+
"x_len": x_len,
|
126 |
+
"y": torch.LongTensor(y),
|
127 |
+
"y_len": y_len
|
128 |
+
}
|
129 |
+
|
130 |
+
|
131 |
+
def collate(self, batch):
|
132 |
+
out = {key:[] for key in batch[0]}
|
133 |
+
for item in batch:
|
134 |
+
if item['x'] == None: # deal with load failure
|
135 |
+
continue
|
136 |
+
for key, val in item.items():
|
137 |
+
out[key].append(val)
|
138 |
+
res = {}
|
139 |
+
if self.args.pad_x:
|
140 |
+
res["x"] = torch.stack(out["x"], dim=0)
|
141 |
+
else:
|
142 |
+
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
|
143 |
+
res["x_lens"] = torch.LongTensor(out["x_len"])
|
144 |
+
if self.args.dynamic_batching:
|
145 |
+
if out['y'][0].ndim==2:
|
146 |
+
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
|
147 |
+
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
|
148 |
+
else:
|
149 |
+
assert out['y'][0].ndim==1, out['y'][0].shape
|
150 |
+
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
|
151 |
+
else:
|
152 |
+
res['y'] = torch.stack(out['y'], dim=0)
|
153 |
+
res["y_lens"] = torch.LongTensor(out["y_len"])
|
154 |
+
res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
|
155 |
+
res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
|
156 |
+
return res
|
data/phonemize_encodec_encode_hf.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
def parse_args():
|
3 |
+
parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
|
4 |
+
parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl. we use xl for VoiceCraft training, xs is good for debugging')
|
5 |
+
parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
|
6 |
+
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
|
7 |
+
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
|
8 |
+
parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
|
9 |
+
parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
|
10 |
+
parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
|
11 |
+
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
|
12 |
+
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
|
13 |
+
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
|
14 |
+
parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
|
15 |
+
parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine')
|
16 |
+
return parser.parse_args()
|
17 |
+
if __name__ == "__main__":
|
18 |
+
import logging
|
19 |
+
formatter = (
|
20 |
+
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
|
21 |
+
)
|
22 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
23 |
+
args = parse_args()
|
24 |
+
|
25 |
+
import os
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
import tqdm
|
29 |
+
import time
|
30 |
+
from datasets import load_dataset, DownloadConfig
|
31 |
+
|
32 |
+
from tokenizer import TextTokenizer, tokenize_text
|
33 |
+
|
34 |
+
# get the path
|
35 |
+
phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
|
36 |
+
codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
|
37 |
+
vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
|
38 |
+
os.makedirs(phn_save_root, exist_ok=True)
|
39 |
+
os.makedirs(codes_save_root, exist_ok=True)
|
40 |
+
|
41 |
+
|
42 |
+
def sort_by_audio_len(lens):
|
43 |
+
inds = np.argsort(lens).tolist()
|
44 |
+
logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
|
45 |
+
logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
|
46 |
+
logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
|
47 |
+
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
|
48 |
+
return inds[::-1]
|
49 |
+
|
50 |
+
def write_array_to_txt_file(array, filename):
|
51 |
+
with open(filename, 'w') as f:
|
52 |
+
for a in array[:-1]:
|
53 |
+
f.write(' '.join(map(str, a))+'\n')
|
54 |
+
f.write(' '.join(map(str, array[-1])))
|
55 |
+
|
56 |
+
|
57 |
+
### phonemization
|
58 |
+
# load tokenizer
|
59 |
+
# load the encodec model
|
60 |
+
from audiocraft.solvers import CompressionSolver
|
61 |
+
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
|
62 |
+
model = model.cuda()
|
63 |
+
model = model.eval()
|
64 |
+
text_tokenizer = TextTokenizer()
|
65 |
+
|
66 |
+
|
67 |
+
# https://github.com/SpeechColab/GigaSpeech
|
68 |
+
# there are only four different punctuations
|
69 |
+
# need to check whether there are other < started strings
|
70 |
+
punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
|
71 |
+
gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are savely keep as the original sym when using tokenize_text
|
72 |
+
punc2sym.update(gar2sym)
|
73 |
+
|
74 |
+
word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
|
75 |
+
forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
|
76 |
+
|
77 |
+
dc = DownloadConfig(cache_dir=args.download_to)
|
78 |
+
stime = time.time()
|
79 |
+
logging.info("loading the dataset...")
|
80 |
+
gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc)
|
81 |
+
logging.info(f"time spend on loading the dataset: {time.time() - stime:.2f} seconds")
|
82 |
+
|
83 |
+
splits = ['validation', 'test', 'train']
|
84 |
+
|
85 |
+
logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}")
|
86 |
+
logging.info(f"phonemizing...")
|
87 |
+
phn_vocab = set()
|
88 |
+
all_lens = []
|
89 |
+
|
90 |
+
# you will see a ton of [WARNING] words_mismatch.py:88......, it's not a issue
|
91 |
+
for split in tqdm.tqdm(splits):
|
92 |
+
skip = 0
|
93 |
+
logging.info(f"now processing split {split}...")
|
94 |
+
for item in tqdm.tqdm(gs[split]):
|
95 |
+
save_fn = os.path.join(phn_save_root, item['segment_id']+".txt")
|
96 |
+
text = item['text']
|
97 |
+
if sum(word in forbidden_words for word in text.split(" ")):
|
98 |
+
logging.info(f"skip {item['segment_id']}, because it contains forbiden words. It's transcript: {text}")
|
99 |
+
skip += 1
|
100 |
+
continue
|
101 |
+
for k, v in punc2sym.items():
|
102 |
+
text = text.replace(k, v)
|
103 |
+
phn = tokenize_text(text_tokenizer, text)
|
104 |
+
phn_seq = " ".join(phn)
|
105 |
+
for k, v in word2sym.items():
|
106 |
+
phn_seq = phn_seq.replace(k, v)
|
107 |
+
phn_vocab.update(phn_seq.split(" "))
|
108 |
+
all_lens.append(len(phn_seq.split(" ")))
|
109 |
+
with open(save_fn, "w") as f:
|
110 |
+
f.write(phn_seq)
|
111 |
+
logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
|
112 |
+
|
113 |
+
print(f"phn vocab size: {len(list(phn_vocab))}")
|
114 |
+
print("phn sequence stats: ")
|
115 |
+
print(f"longest: {max(all_lens)}")
|
116 |
+
print(f"shortest: {min(all_lens)}")
|
117 |
+
print(f"median: {np.quantile(all_lens, 0.5)}")
|
118 |
+
print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}")
|
119 |
+
print("write vocabulary to ", vocab_fn)
|
120 |
+
with open(vocab_fn, "w") as f:
|
121 |
+
for i, phn in enumerate(list(phn_vocab)):
|
122 |
+
if i < len(list(phn_vocab)) - 1:
|
123 |
+
f.write(f"{str(i)} {phn}\n")
|
124 |
+
else:
|
125 |
+
f.write(f"{str(i)} {phn}")
|
126 |
+
|
127 |
+
class mydataset(torch.utils.data.Dataset):
|
128 |
+
def __init__(self, split):
|
129 |
+
super().__init__()
|
130 |
+
self.data = gs[split]
|
131 |
+
def __len__(self):
|
132 |
+
return len(self.data)
|
133 |
+
def __getitem__(self, ind):
|
134 |
+
try:
|
135 |
+
segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time']
|
136 |
+
except:
|
137 |
+
return None, None, None, None, None, None
|
138 |
+
|
139 |
+
return segment_id, audio, sr, text, begin_time, end_time
|
140 |
+
def collate(self, batch):
|
141 |
+
res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []}
|
142 |
+
for item in batch:
|
143 |
+
if item[0] != None:
|
144 |
+
res['segment_id'].append(item[0])
|
145 |
+
res['audio'].append(item[1])
|
146 |
+
res['sr'].append(item[2])
|
147 |
+
res['text'].append(item[3])
|
148 |
+
res['begin_time'].append(item[4])
|
149 |
+
res['end_time'].append(item[5])
|
150 |
+
return res
|
151 |
+
|
152 |
+
|
153 |
+
## encodec codes extraction
|
154 |
+
logging.info("encodec encoding...")
|
155 |
+
train_dataset = mydataset('train')
|
156 |
+
train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
|
157 |
+
validation_dataset = mydataset('validation')
|
158 |
+
validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
|
159 |
+
test_dataset = mydataset('test')
|
160 |
+
test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
|
161 |
+
splits = ['validation', 'test', 'train']
|
162 |
+
loaders = [validation_loader, test_loader, train_loader]
|
163 |
+
# splits = ['validation'] # for debug
|
164 |
+
# loaders = [validation_loader]
|
165 |
+
for split, loader in zip(splits, loaders):
|
166 |
+
skip = 0
|
167 |
+
logging.info(f"now processing split {split}...")
|
168 |
+
mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size))
|
169 |
+
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples")
|
170 |
+
for m, mega_batch in enumerate(loader):
|
171 |
+
logging.info(f"====================================")
|
172 |
+
logging.info(f"====================================")
|
173 |
+
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
|
174 |
+
lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time'])
|
175 |
+
sorted_inds = sort_by_audio_len(lengths)
|
176 |
+
for j in range(len(sorted_inds))[::-1]:
|
177 |
+
if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s)
|
178 |
+
skip += 1
|
179 |
+
del sorted_inds[j]
|
180 |
+
|
181 |
+
n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
|
182 |
+
for n in tqdm.tqdm(range(n_steps), disable=True):
|
183 |
+
inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
|
184 |
+
audio_batch = [mega_batch['audio'][id] for id in inds_used]
|
185 |
+
sr_batch = [mega_batch['sr'][id] for id in inds_used]
|
186 |
+
segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used]
|
187 |
+
text_batch = [mega_batch['text'][id] for id in inds_used]
|
188 |
+
padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
|
189 |
+
all_lens = [lengths[id] for id in inds_used]
|
190 |
+
with torch.no_grad():
|
191 |
+
if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes
|
192 |
+
codes = []
|
193 |
+
inwav = padded_wav.cuda()
|
194 |
+
codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
|
195 |
+
codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
|
196 |
+
codes = torch.cat(codes, dim=0)
|
197 |
+
else:
|
198 |
+
encoded_frames = model.encode(padded_wav.cuda())
|
199 |
+
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
|
200 |
+
codes = encoded_frames[0].cpu()
|
201 |
+
|
202 |
+
for i, length in enumerate(all_lens):
|
203 |
+
save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
|
204 |
+
actual_len = round(length * args.model_code_sr) # 320 is downsample rate for this model
|
205 |
+
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
|
206 |
+
write_array_to_txt_file(cur_code, save_fn)
|
data/tokenizer.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
|
2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import re
|
17 |
+
from dataclasses import asdict, dataclass
|
18 |
+
from typing import Any, Dict, List, Optional, Pattern, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import torchaudio
|
23 |
+
# from lhotse.features import FeatureExtractor
|
24 |
+
# from lhotse.utils import Seconds, compute_num_frames
|
25 |
+
from phonemizer.backend import EspeakBackend
|
26 |
+
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
27 |
+
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
28 |
+
from phonemizer.punctuation import Punctuation
|
29 |
+
from phonemizer.separator import Separator
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
class TextTokenizer:
|
34 |
+
"""Phonemize Text."""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
language="en-us",
|
39 |
+
backend="espeak",
|
40 |
+
separator=Separator(word="_", syllable="-", phone="|"),
|
41 |
+
preserve_punctuation=True,
|
42 |
+
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
43 |
+
with_stress: bool = False,
|
44 |
+
tie: Union[bool, str] = False,
|
45 |
+
language_switch: LanguageSwitch = "keep-flags",
|
46 |
+
words_mismatch: WordMismatch = "ignore",
|
47 |
+
) -> None:
|
48 |
+
phonemizer = EspeakBackend(
|
49 |
+
language,
|
50 |
+
punctuation_marks=punctuation_marks,
|
51 |
+
preserve_punctuation=preserve_punctuation,
|
52 |
+
with_stress=with_stress,
|
53 |
+
tie=tie,
|
54 |
+
language_switch=language_switch,
|
55 |
+
words_mismatch=words_mismatch,
|
56 |
+
)
|
57 |
+
|
58 |
+
self.backend = phonemizer
|
59 |
+
self.separator = separator
|
60 |
+
|
61 |
+
def to_list(self, phonemized: str) -> List[str]:
|
62 |
+
fields = []
|
63 |
+
for word in phonemized.split(self.separator.word):
|
64 |
+
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
65 |
+
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
66 |
+
fields.extend(
|
67 |
+
[p for p in pp if p != self.separator.phone]
|
68 |
+
+ [self.separator.word]
|
69 |
+
)
|
70 |
+
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
71 |
+
self.separator.phone
|
72 |
+
)
|
73 |
+
return fields[:-1]
|
74 |
+
|
75 |
+
def __call__(self, text, strip=True) -> List[List[str]]:
|
76 |
+
if isinstance(text, str):
|
77 |
+
text = [text]
|
78 |
+
|
79 |
+
phonemized = self.backend.phonemize(
|
80 |
+
text, separator=self.separator, strip=strip, njobs=1
|
81 |
+
)
|
82 |
+
return [self.to_list(p) for p in phonemized]
|
83 |
+
|
84 |
+
|
85 |
+
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
|
86 |
+
phonemes = tokenizer([text.strip()])
|
87 |
+
return phonemes[0] # k2symbols
|
88 |
+
|
89 |
+
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
|
90 |
+
assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
|
91 |
+
if target_channels == 1:
|
92 |
+
wav = wav.mean(0, keepdim=True)
|
93 |
+
elif target_channels == 2:
|
94 |
+
*shape, _, length = wav.shape
|
95 |
+
wav = wav.expand(*shape, target_channels, length)
|
96 |
+
elif wav.shape[0] == 1:
|
97 |
+
wav = wav.expand(target_channels, -1)
|
98 |
+
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
|
99 |
+
return wav
|
100 |
+
|
101 |
+
class AudioTokenizer:
|
102 |
+
"""EnCodec audio."""
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
device: Any = None,
|
107 |
+
signature = None
|
108 |
+
) -> None:
|
109 |
+
from audiocraft.solvers import CompressionSolver
|
110 |
+
model = CompressionSolver.model_from_checkpoint(signature)
|
111 |
+
self.sample_rate = model.sample_rate
|
112 |
+
self.channels = model.channels
|
113 |
+
|
114 |
+
if not device:
|
115 |
+
device = torch.device("cpu")
|
116 |
+
if torch.cuda.is_available():
|
117 |
+
device = torch.device("cuda:0")
|
118 |
+
|
119 |
+
self._device = device
|
120 |
+
|
121 |
+
self.codec = model.to(device)
|
122 |
+
|
123 |
+
@property
|
124 |
+
def device(self):
|
125 |
+
return self._device
|
126 |
+
|
127 |
+
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
128 |
+
codes = self.codec.encode(wav.to(self.device))
|
129 |
+
return [(codes[0], None)]
|
130 |
+
|
131 |
+
def decode(self, frames: torch.Tensor) -> torch.Tensor:
|
132 |
+
frames = frames[0][0] # [1,4,T]
|
133 |
+
return self.codec.decode(frames)
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
|
138 |
+
# Load and pre-process the audio waveform
|
139 |
+
if offset != -1 and num_frames!=-1:
|
140 |
+
wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
|
141 |
+
else:
|
142 |
+
wav, sr = torchaudio.load(audio_path)
|
143 |
+
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
|
144 |
+
wav = wav.unsqueeze(0)
|
145 |
+
|
146 |
+
# Extract discrete codes from EnCodec
|
147 |
+
with torch.no_grad():
|
148 |
+
encoded_frames = tokenizer.encode(wav)
|
149 |
+
return encoded_frames
|
demo/84_121550_000074_000000.wav
ADDED
Binary file (508 kB). View file
|
|
demo/generated_se/84_121550_000074_000000_new_seed1.wav
ADDED
Binary file (426 kB). View file
|
|
demo/generated_se/84_121550_000074_000000_orig.wav
ADDED
Binary file (508 kB). View file
|
|
demo/generated_tts/84_121550_000074_000000_concat_seed1.wav
ADDED
Binary file (522 kB). View file
|
|
demo/generated_tts/84_121550_000074_000000_gen_seed1.wav
ADDED
Binary file (329 kB). View file
|
|
demo/temp/84_121550_000074_000000.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,
|
demo/temp/84_121550_000074_000000.wav
ADDED
Binary file (508 kB). View file
|
|
demo/temp/mfa_alignments/84_121550_000074_000000.csv
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Begin,End,Label,Type,Speaker
|
2 |
+
0.03,0.18,but,words,temp
|
3 |
+
0.18,0.32,when,words,temp
|
4 |
+
0.32,0.48,i,words,temp
|
5 |
+
0.48,0.64,had,words,temp
|
6 |
+
0.64,1.19,approached,words,temp
|
7 |
+
1.22,1.58,so,words,temp
|
8 |
+
1.58,1.91,near,words,temp
|
9 |
+
1.91,2.07,to,words,temp
|
10 |
+
2.07,2.42,them,words,temp
|
11 |
+
2.53,2.61,the,words,temp
|
12 |
+
2.61,3.01,common,words,temp
|
13 |
+
3.05,3.62,object,words,temp
|
14 |
+
3.68,3.93,which,words,temp
|
15 |
+
3.93,4.02,the,words,temp
|
16 |
+
4.02,4.34,sense,words,temp
|
17 |
+
4.34,4.97,deceives,words,temp
|
18 |
+
5.04,5.54,lost,words,temp
|
19 |
+
5.54,6.0,not,words,temp
|
20 |
+
6.0,6.14,by,words,temp
|
21 |
+
6.14,6.67,distance,words,temp
|
22 |
+
6.79,7.05,any,words,temp
|
23 |
+
7.05,7.18,of,words,temp
|
24 |
+
7.18,7.34,its,words,temp
|
25 |
+
7.34,7.87,marks,words,temp
|
26 |
+
0.03,0.06,B,phones,temp
|
27 |
+
0.06,0.09,AH1,phones,temp
|
28 |
+
0.09,0.18,T,phones,temp
|
29 |
+
0.18,0.23,W,phones,temp
|
30 |
+
0.23,0.27,EH1,phones,temp
|
31 |
+
0.27,0.32,N,phones,temp
|
32 |
+
0.32,0.48,AY1,phones,temp
|
33 |
+
0.48,0.49,HH,phones,temp
|
34 |
+
0.49,0.6,AE1,phones,temp
|
35 |
+
0.6,0.64,D,phones,temp
|
36 |
+
0.64,0.7,AH0,phones,temp
|
37 |
+
0.7,0.83,P,phones,temp
|
38 |
+
0.83,0.88,R,phones,temp
|
39 |
+
0.88,0.99,OW1,phones,temp
|
40 |
+
0.99,1.12,CH,phones,temp
|
41 |
+
1.12,1.19,T,phones,temp
|
42 |
+
1.22,1.4,S,phones,temp
|
43 |
+
1.4,1.58,OW1,phones,temp
|
44 |
+
1.58,1.7,N,phones,temp
|
45 |
+
1.7,1.84,IH1,phones,temp
|
46 |
+
1.84,1.91,R,phones,temp
|
47 |
+
1.91,2.01,T,phones,temp
|
48 |
+
2.01,2.07,AH0,phones,temp
|
49 |
+
2.07,2.13,DH,phones,temp
|
50 |
+
2.13,2.3,EH1,phones,temp
|
51 |
+
2.3,2.42,M,phones,temp
|
52 |
+
2.53,2.55,DH,phones,temp
|
53 |
+
2.55,2.61,AH0,phones,temp
|
54 |
+
2.61,2.73,K,phones,temp
|
55 |
+
2.73,2.85,AA1,phones,temp
|
56 |
+
2.85,2.9,M,phones,temp
|
57 |
+
2.9,2.95,AH0,phones,temp
|
58 |
+
2.95,3.01,N,phones,temp
|
59 |
+
3.05,3.22,AA1,phones,temp
|
60 |
+
3.22,3.27,B,phones,temp
|
61 |
+
3.27,3.34,JH,phones,temp
|
62 |
+
3.34,3.48,EH0,phones,temp
|
63 |
+
3.48,3.54,K,phones,temp
|
64 |
+
3.54,3.62,T,phones,temp
|
65 |
+
3.68,3.69,HH,phones,temp
|
66 |
+
3.69,3.76,W,phones,temp
|
67 |
+
3.76,3.8,IH1,phones,temp
|
68 |
+
3.8,3.93,CH,phones,temp
|
69 |
+
3.93,3.95,DH,phones,temp
|
70 |
+
3.95,4.02,AH0,phones,temp
|
71 |
+
4.02,4.12,S,phones,temp
|
72 |
+
4.12,4.21,EH1,phones,temp
|
73 |
+
4.21,4.27,N,phones,temp
|
74 |
+
4.27,4.34,S,phones,temp
|
75 |
+
4.34,4.42,D,phones,temp
|
76 |
+
4.42,4.45,IH0,phones,temp
|
77 |
+
4.45,4.59,S,phones,temp
|
78 |
+
4.59,4.79,IY1,phones,temp
|
79 |
+
4.79,4.87,V,phones,temp
|
80 |
+
4.87,4.97,Z,phones,temp
|
81 |
+
5.04,5.12,L,phones,temp
|
82 |
+
5.12,5.33,AO1,phones,temp
|
83 |
+
5.33,5.42,S,phones,temp
|
84 |
+
5.42,5.54,T,phones,temp
|
85 |
+
5.54,5.7,N,phones,temp
|
86 |
+
5.7,5.89,AA1,phones,temp
|
87 |
+
5.89,6.0,T,phones,temp
|
88 |
+
6.0,6.05,B,phones,temp
|
89 |
+
6.05,6.14,AY1,phones,temp
|
90 |
+
6.14,6.24,D,phones,temp
|
91 |
+
6.24,6.3,IH1,phones,temp
|
92 |
+
6.3,6.38,S,phones,temp
|
93 |
+
6.38,6.45,T,phones,temp
|
94 |
+
6.45,6.51,AH0,phones,temp
|
95 |
+
6.51,6.57,N,phones,temp
|
96 |
+
6.57,6.67,S,phones,temp
|
97 |
+
6.79,6.89,EH1,phones,temp
|
98 |
+
6.89,6.95,N,phones,temp
|
99 |
+
6.95,7.05,IY0,phones,temp
|
100 |
+
7.05,7.13,AH0,phones,temp
|
101 |
+
7.13,7.18,V,phones,temp
|
102 |
+
7.18,7.22,IH0,phones,temp
|
103 |
+
7.22,7.29,T,phones,temp
|
104 |
+
7.29,7.34,S,phones,temp
|
105 |
+
7.34,7.39,M,phones,temp
|
106 |
+
7.39,7.5,AA1,phones,temp
|
107 |
+
7.5,7.58,R,phones,temp
|
108 |
+
7.58,7.7,K,phones,temp
|
109 |
+
7.7,7.87,S,phones,temp
|
gradio_app.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
3 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "5" # these are only used if developping locally
|
4 |
+
import gradio as gr
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
from data.tokenizer import (
|
8 |
+
AudioTokenizer,
|
9 |
+
TextTokenizer,
|
10 |
+
)
|
11 |
+
from models import voicecraft
|
12 |
+
import io
|
13 |
+
import numpy as np
|
14 |
+
import random
|
15 |
+
import spaces
|
16 |
+
|
17 |
+
|
18 |
+
whisper_model, voicecraft_model = None, None
|
19 |
+
|
20 |
+
@spaces.GPU(duration=20)
|
21 |
+
def seed_everything(seed):
|
22 |
+
if seed != -1:
|
23 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
24 |
+
random.seed(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed(seed)
|
28 |
+
torch.backends.cudnn.benchmark = False
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
|
31 |
+
@spaces.GPU(duration=120)
|
32 |
+
def load_models(whisper_model_choice, voicecraft_model_choice):
|
33 |
+
global whisper_model, voicecraft_model
|
34 |
+
|
35 |
+
if whisper_model_choice is not None:
|
36 |
+
import whisper
|
37 |
+
from whisper.tokenizer import get_tokenizer
|
38 |
+
whisper_model = {
|
39 |
+
"model": whisper.load_model(whisper_model_choice),
|
40 |
+
"tokenizer": get_tokenizer(multilingual=False)
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
+
|
46 |
+
voicecraft_name = f"{voicecraft_model_choice}.pth"
|
47 |
+
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
|
48 |
+
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
|
49 |
+
if not os.path.exists(ckpt_fn):
|
50 |
+
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
|
51 |
+
os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
|
52 |
+
if not os.path.exists(encodec_fn):
|
53 |
+
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
|
54 |
+
os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
|
55 |
+
|
56 |
+
ckpt = torch.load(ckpt_fn, map_location="cpu")
|
57 |
+
model = voicecraft.VoiceCraft(ckpt["config"])
|
58 |
+
model.load_state_dict(ckpt["model"])
|
59 |
+
model.to(device)
|
60 |
+
model.eval()
|
61 |
+
voicecraft_model = {
|
62 |
+
"ckpt": ckpt,
|
63 |
+
"model": model,
|
64 |
+
"text_tokenizer": TextTokenizer(backend="espeak"),
|
65 |
+
"audio_tokenizer": AudioTokenizer(signature=encodec_fn)
|
66 |
+
}
|
67 |
+
|
68 |
+
return gr.Accordion()
|
69 |
+
|
70 |
+
@spaces.GPU(duration=60)
|
71 |
+
def transcribe(seed, audio_path):
|
72 |
+
if whisper_model is None:
|
73 |
+
raise gr.Error("Whisper model not loaded")
|
74 |
+
seed_everything(seed)
|
75 |
+
|
76 |
+
number_tokens = [
|
77 |
+
i
|
78 |
+
for i in range(whisper_model["tokenizer"].eot)
|
79 |
+
if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
|
80 |
+
]
|
81 |
+
result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
|
82 |
+
words = [word_info for segment in result["segments"] for word_info in segment["words"]]
|
83 |
+
|
84 |
+
transcript = result["text"]
|
85 |
+
transcript_with_start_time = " ".join([f"{word['start']} {word['word']}" for word in words])
|
86 |
+
transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
|
87 |
+
|
88 |
+
choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
|
89 |
+
|
90 |
+
return [
|
91 |
+
transcript, transcript_with_start_time, transcript_with_end_time,
|
92 |
+
gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word
|
93 |
+
gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
|
94 |
+
gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
|
95 |
+
words
|
96 |
+
]
|
97 |
+
|
98 |
+
|
99 |
+
def get_output_audio(audio_tensors, codec_audio_sr):
|
100 |
+
result = torch.cat(audio_tensors, 1)
|
101 |
+
buffer = io.BytesIO()
|
102 |
+
torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
|
103 |
+
buffer.seek(0)
|
104 |
+
return buffer.read()
|
105 |
+
|
106 |
+
@spaces.GPU(duration=90)
|
107 |
+
def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
|
108 |
+
stop_repetition, sample_batch_size, kvcache, silence_tokens,
|
109 |
+
audio_path, word_info, transcript, smart_transcript,
|
110 |
+
mode, prompt_end_time, edit_start_time, edit_end_time,
|
111 |
+
split_text, selected_sentence, previous_audio_tensors):
|
112 |
+
if voicecraft_model is None:
|
113 |
+
raise gr.Error("VoiceCraft model not loaded")
|
114 |
+
if smart_transcript and (word_info is None):
|
115 |
+
raise gr.Error("Can't use smart transcript: whisper transcript not found")
|
116 |
+
|
117 |
+
seed_everything(seed)
|
118 |
+
if mode == "Long TTS":
|
119 |
+
if split_text == "Newline":
|
120 |
+
sentences = transcript.split('\n')
|
121 |
+
else:
|
122 |
+
from nltk.tokenize import sent_tokenize
|
123 |
+
sentences = sent_tokenize(transcript.replace("\n", " "))
|
124 |
+
elif mode == "Rerun":
|
125 |
+
colon_position = selected_sentence.find(':')
|
126 |
+
selected_sentence_idx = int(selected_sentence[:colon_position])
|
127 |
+
sentences = [selected_sentence[colon_position + 1:]]
|
128 |
+
else:
|
129 |
+
sentences = [transcript.replace("\n", " ")]
|
130 |
+
|
131 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
132 |
+
info = torchaudio.info(audio_path)
|
133 |
+
audio_dur = info.num_frames / info.sample_rate
|
134 |
+
|
135 |
+
audio_tensors = []
|
136 |
+
inference_transcript = ""
|
137 |
+
for sentence in sentences:
|
138 |
+
decode_config = {"top_k": top_k, "top_p": top_p, "temperature": temperature, "stop_repetition": stop_repetition,
|
139 |
+
"kvcache": kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
|
140 |
+
"silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
|
141 |
+
if mode != "Edit":
|
142 |
+
from inference_tts_scale import inference_one_sample
|
143 |
+
|
144 |
+
if smart_transcript:
|
145 |
+
target_transcript = ""
|
146 |
+
for word in word_info:
|
147 |
+
if word["end"] < prompt_end_time:
|
148 |
+
target_transcript += word["word"]
|
149 |
+
elif (word["start"] + word["end"]) / 2 < prompt_end_time:
|
150 |
+
# include part of the word it it's big, but adjust prompt_end_time
|
151 |
+
target_transcript += word["word"]
|
152 |
+
prompt_end_time = word["end"]
|
153 |
+
break
|
154 |
+
else:
|
155 |
+
break
|
156 |
+
target_transcript += f" {sentence}"
|
157 |
+
else:
|
158 |
+
target_transcript = sentence
|
159 |
+
|
160 |
+
inference_transcript += target_transcript + "\n"
|
161 |
+
|
162 |
+
prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
|
163 |
+
_, gen_audio = inference_one_sample(voicecraft_model["model"],
|
164 |
+
voicecraft_model["ckpt"]["config"],
|
165 |
+
voicecraft_model["ckpt"]["phn2num"],
|
166 |
+
voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
|
167 |
+
audio_path, target_transcript, device, decode_config,
|
168 |
+
prompt_end_frame)
|
169 |
+
else:
|
170 |
+
from inference_speech_editing_scale import inference_one_sample
|
171 |
+
|
172 |
+
if smart_transcript:
|
173 |
+
target_transcript = ""
|
174 |
+
for word in word_info:
|
175 |
+
if word["start"] < edit_start_time:
|
176 |
+
target_transcript += word["word"]
|
177 |
+
else:
|
178 |
+
break
|
179 |
+
target_transcript += f" {sentence}"
|
180 |
+
for word in word_info:
|
181 |
+
if word["end"] > edit_end_time:
|
182 |
+
target_transcript += word["word"]
|
183 |
+
else:
|
184 |
+
target_transcript = sentence
|
185 |
+
|
186 |
+
inference_transcript += target_transcript + "\n"
|
187 |
+
|
188 |
+
morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr), min(edit_end_time + right_margin, audio_dur))
|
189 |
+
mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]
|
190 |
+
mask_interval = torch.LongTensor(mask_interval)
|
191 |
+
|
192 |
+
_, gen_audio = inference_one_sample(voicecraft_model["model"],
|
193 |
+
voicecraft_model["ckpt"]["config"],
|
194 |
+
voicecraft_model["ckpt"]["phn2num"],
|
195 |
+
voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
|
196 |
+
audio_path, target_transcript, mask_interval, device, decode_config)
|
197 |
+
gen_audio = gen_audio[0].cpu()
|
198 |
+
audio_tensors.append(gen_audio)
|
199 |
+
|
200 |
+
if mode != "Rerun":
|
201 |
+
output_audio = get_output_audio(audio_tensors, codec_audio_sr)
|
202 |
+
sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)]
|
203 |
+
component = gr.Dropdown(choices=sentences, value=sentences[0])
|
204 |
+
return output_audio, inference_transcript, component, audio_tensors
|
205 |
+
else:
|
206 |
+
previous_audio_tensors[selected_sentence_idx] = audio_tensors[0]
|
207 |
+
output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr)
|
208 |
+
sentence_audio = get_output_audio(audio_tensors, codec_audio_sr)
|
209 |
+
return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
|
210 |
+
|
211 |
+
|
212 |
+
def update_input_audio(audio_path):
|
213 |
+
if audio_path is None:
|
214 |
+
return 0, 0, 0
|
215 |
+
|
216 |
+
info = torchaudio.info(audio_path)
|
217 |
+
max_time = round(info.num_frames / info.sample_rate, 2)
|
218 |
+
return [
|
219 |
+
gr.Slider(maximum=max_time, value=max_time),
|
220 |
+
gr.Slider(maximum=max_time, value=0),
|
221 |
+
gr.Slider(maximum=max_time, value=max_time),
|
222 |
+
]
|
223 |
+
|
224 |
+
|
225 |
+
def change_mode(mode):
|
226 |
+
tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
|
227 |
+
return [
|
228 |
+
gr.Group(visible=mode != "Edit"),
|
229 |
+
gr.Group(visible=mode == "Edit"),
|
230 |
+
gr.Radio(visible=mode == "Edit"),
|
231 |
+
gr.Radio(visible=mode == "Long TTS"),
|
232 |
+
gr.Group(visible=mode == "Long TTS"),
|
233 |
+
]
|
234 |
+
|
235 |
+
|
236 |
+
def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
|
237 |
+
if selected_sentence is None:
|
238 |
+
return None
|
239 |
+
colon_position = selected_sentence.find(':')
|
240 |
+
selected_sentence_idx = int(selected_sentence[:colon_position])
|
241 |
+
return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr)
|
242 |
+
|
243 |
+
|
244 |
+
def update_bound_word(is_first_word, selected_word, edit_word_mode):
|
245 |
+
if selected_word is None:
|
246 |
+
return None
|
247 |
+
|
248 |
+
word_start_time = float(selected_word.split(' ')[0])
|
249 |
+
word_end_time = float(selected_word.split(' ')[-1])
|
250 |
+
if edit_word_mode == "Replace half":
|
251 |
+
bound_time = (word_start_time + word_end_time) / 2
|
252 |
+
elif is_first_word:
|
253 |
+
bound_time = word_start_time
|
254 |
+
else:
|
255 |
+
bound_time = word_end_time
|
256 |
+
|
257 |
+
return bound_time
|
258 |
+
|
259 |
+
|
260 |
+
def update_bound_words(from_selected_word, to_selected_word, edit_word_mode):
|
261 |
+
return [
|
262 |
+
update_bound_word(True, from_selected_word, edit_word_mode),
|
263 |
+
update_bound_word(False, to_selected_word, edit_word_mode),
|
264 |
+
]
|
265 |
+
|
266 |
+
|
267 |
+
smart_transcript_info = """
|
268 |
+
If enabled, the target transcript will be constructed for you:</br>
|
269 |
+
- In TTS and Long TTS mode just write the text you want to synthesize.</br>
|
270 |
+
- In Edit mode just write the text to replace selected editing segment.</br>
|
271 |
+
If disabled, you should write the target transcript yourself:</br>
|
272 |
+
- In TTS mode write prompt transcript followed by generation transcript.</br>
|
273 |
+
- In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
|
274 |
+
- In Edit mode write full prompt</br>
|
275 |
+
"""
|
276 |
+
|
277 |
+
demo_original_transcript = " But when I had approached so near to them, the common object, which the sense deceives, lost not by distance any of its marks."
|
278 |
+
|
279 |
+
demo_text = {
|
280 |
+
"TTS": {
|
281 |
+
"smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
|
282 |
+
"regular": "But when I had approached so near to them, the common I cannot believe that the same model can also do text to speech synthesis as well!"
|
283 |
+
},
|
284 |
+
"Edit": {
|
285 |
+
"smart": "saw the mirage of the lake in the distance,",
|
286 |
+
"regular": "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,"
|
287 |
+
},
|
288 |
+
"Long TTS": {
|
289 |
+
"smart": "You can run generation on a big text!\n"
|
290 |
+
"Just write it line-by-line. Or sentence-by-sentence.\n"
|
291 |
+
"If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!",
|
292 |
+
"regular": "But when I had approached so near to them, the common You can run generation on a big text!\n"
|
293 |
+
"But when I had approached so near to them, the common Just write it line-by-line. Or sentence-by-sentence.\n"
|
294 |
+
"But when I had approached so near to them, the common If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!"
|
295 |
+
}
|
296 |
+
}
|
297 |
+
|
298 |
+
all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
|
299 |
+
|
300 |
+
demo_words = [
|
301 |
+
"0.03 but 0.18",
|
302 |
+
"0.18 when 0.32",
|
303 |
+
"0.32 i 0.48",
|
304 |
+
"0.48 had 0.64",
|
305 |
+
"0.64 approached 1.19",
|
306 |
+
"1.22 so 1.58",
|
307 |
+
"1.58 near 1.91",
|
308 |
+
"1.91 to 2.07",
|
309 |
+
"2.07 them 2.42",
|
310 |
+
"2.53 the 2.61",
|
311 |
+
"2.61 common 3.01",
|
312 |
+
"3.05 object 3.62",
|
313 |
+
"3.68 which 3.93",
|
314 |
+
"3.93 the 4.02",
|
315 |
+
"4.02 sense 4.34",
|
316 |
+
"4.34 deceives 4.97",
|
317 |
+
"5.04 lost 5.54",
|
318 |
+
"5.54 not 6.00",
|
319 |
+
"6.00 by 6.14",
|
320 |
+
"6.14 distance 6.67",
|
321 |
+
"6.79 any 7.05",
|
322 |
+
"7.05 of 7.18",
|
323 |
+
"7.18 its 7.34",
|
324 |
+
"7.34 marks 7.87"
|
325 |
+
]
|
326 |
+
|
327 |
+
demo_word_info = [
|
328 |
+
{"word": "but", "start": 0.03, "end": 0.18},
|
329 |
+
{"word": "when", "start": 0.18, "end": 0.32},
|
330 |
+
{"word": "i", "start": 0.32, "end": 0.48},
|
331 |
+
{"word": "had", "start": 0.48, "end": 0.64},
|
332 |
+
{"word": "approached", "start": 0.64, "end": 1.19},
|
333 |
+
{"word": "so", "start": 1.22, "end": 1.58},
|
334 |
+
{"word": "near", "start": 1.58, "end": 1.91},
|
335 |
+
{"word": "to", "start": 1.91, "end": 2.07},
|
336 |
+
{"word": "them", "start": 2.07, "end": 2.42},
|
337 |
+
{"word": "the", "start": 2.53, "end": 2.61},
|
338 |
+
{"word": "common", "start": 2.61, "end": 3.01},
|
339 |
+
{"word": "object", "start": 3.05, "end": 3.62},
|
340 |
+
{"word": "which", "start": 3.68, "end": 3.93},
|
341 |
+
{"word": "the", "start": 3.93, "end": 4.02},
|
342 |
+
{"word": "sense", "start": 4.02, "end": 4.34},
|
343 |
+
{"word": "deceives", "start": 4.34, "end": 4.97},
|
344 |
+
{"word": "lost", "start": 5.04, "end": 5.54},
|
345 |
+
{"word": "not", "start": 5.54, "end": 6.0},
|
346 |
+
{"word": "by", "start": 6.0, "end": 6.14},
|
347 |
+
{"word": "distance", "start": 6.14, "end": 6.67},
|
348 |
+
{"word": "any", "start": 6.79, "end": 7.05},
|
349 |
+
{"word": "of", "start": 7.05, "end": 7.18},
|
350 |
+
{"word": "its", "start": 7.18, "end": 7.34},
|
351 |
+
{"word": "marks", "start": 7.34, "end": 7.87}
|
352 |
+
]
|
353 |
+
|
354 |
+
|
355 |
+
def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
|
356 |
+
if transcript not in all_demo_texts:
|
357 |
+
return transcript, edit_from_word, edit_to_word
|
358 |
+
|
359 |
+
replace_half = edit_word_mode == "Replace half"
|
360 |
+
change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
|
361 |
+
change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
|
362 |
+
demo_edit_from_word_value = demo_words[2] if replace_half else demo_words[3]
|
363 |
+
demo_edit_to_word_value = demo_words[12] if replace_half else demo_words[11]
|
364 |
+
return [
|
365 |
+
demo_text[mode]["smart" if smart_transcript else "regular"],
|
366 |
+
demo_edit_from_word_value if change_edit_from_word else edit_from_word,
|
367 |
+
demo_edit_to_word_value if change_edit_to_word else edit_to_word,
|
368 |
+
]
|
369 |
+
|
370 |
+
|
371 |
+
with gr.Blocks() as app:
|
372 |
+
with gr.Row():
|
373 |
+
with gr.Column(scale=2):
|
374 |
+
load_models_btn = gr.Button(value="Load models")
|
375 |
+
with gr.Column(scale=5):
|
376 |
+
with gr.Accordion("Select models", open=False) as models_selector:
|
377 |
+
with gr.Row():
|
378 |
+
voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
|
379 |
+
whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
|
380 |
+
choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
|
381 |
+
|
382 |
+
with gr.Row():
|
383 |
+
with gr.Column(scale=2):
|
384 |
+
input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath")
|
385 |
+
with gr.Group():
|
386 |
+
original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False,
|
387 |
+
info="Use whisper model to get the transcript. Fix it if necessary.")
|
388 |
+
with gr.Accordion("Word start time", open=False):
|
389 |
+
transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
|
390 |
+
with gr.Accordion("Word end time", open=False):
|
391 |
+
transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
|
392 |
+
|
393 |
+
transcribe_btn = gr.Button(value="Transcribe")
|
394 |
+
|
395 |
+
with gr.Column(scale=3):
|
396 |
+
with gr.Group():
|
397 |
+
transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"])
|
398 |
+
with gr.Row():
|
399 |
+
smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
|
400 |
+
with gr.Accordion(label="?", open=False):
|
401 |
+
info = gr.Markdown(value=smart_transcript_info)
|
402 |
+
|
403 |
+
with gr.Row():
|
404 |
+
mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
|
405 |
+
split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
|
406 |
+
info="Split text into parts and run TTS for each part.", visible=False)
|
407 |
+
edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half",
|
408 |
+
info="What to do with first and last word", visible=False)
|
409 |
+
|
410 |
+
with gr.Group() as tts_mode_controls:
|
411 |
+
prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
|
412 |
+
prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
|
413 |
+
|
414 |
+
with gr.Group(visible=False) as edit_mode_controls:
|
415 |
+
with gr.Row():
|
416 |
+
edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
|
417 |
+
edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
|
418 |
+
with gr.Row():
|
419 |
+
edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35)
|
420 |
+
edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75)
|
421 |
+
|
422 |
+
run_btn = gr.Button(value="Run")
|
423 |
+
|
424 |
+
with gr.Column(scale=2):
|
425 |
+
output_audio = gr.Audio(label="Output Audio")
|
426 |
+
with gr.Accordion("Inference transcript", open=False):
|
427 |
+
inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
|
428 |
+
info="Inference was performed on this transcript.")
|
429 |
+
with gr.Group(visible=False) as long_tts_sentence_editor:
|
430 |
+
sentence_selector = gr.Dropdown(label="Sentence", value=None,
|
431 |
+
info="Select sentence you want to regenerate")
|
432 |
+
sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
|
433 |
+
rerun_btn = gr.Button(value="Rerun")
|
434 |
+
|
435 |
+
with gr.Row():
|
436 |
+
with gr.Accordion("VoiceCraft config", open=False):
|
437 |
+
seed = gr.Number(label="seed", value=-1, precision=0)
|
438 |
+
left_margin = gr.Number(label="left_margin", value=0.08)
|
439 |
+
right_margin = gr.Number(label="right_margin", value=0.08)
|
440 |
+
codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
|
441 |
+
codec_sr = gr.Number(label="codec_sr", value=50)
|
442 |
+
top_k = gr.Number(label="top_k", value=0)
|
443 |
+
top_p = gr.Number(label="top_p", value=0.8)
|
444 |
+
temperature = gr.Number(label="temperature", value=1)
|
445 |
+
stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3], value=3,
|
446 |
+
info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1, -1 = disabled")
|
447 |
+
sample_batch_size = gr.Number(label="sample_batch_size", value=4, precision=0,
|
448 |
+
info="generate this many samples and choose the shortest one")
|
449 |
+
kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
|
450 |
+
info="set to 0 to use less VRAM, but with slower inference")
|
451 |
+
silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
|
452 |
+
|
453 |
+
|
454 |
+
audio_tensors = gr.State()
|
455 |
+
word_info = gr.State(value=demo_word_info)
|
456 |
+
|
457 |
+
|
458 |
+
mode.change(fn=update_demo,
|
459 |
+
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
|
460 |
+
outputs=[transcript, edit_from_word, edit_to_word])
|
461 |
+
edit_word_mode.change(fn=update_demo,
|
462 |
+
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
|
463 |
+
outputs=[transcript, edit_from_word, edit_to_word])
|
464 |
+
smart_transcript.change(fn=update_demo,
|
465 |
+
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
|
466 |
+
outputs=[transcript, edit_from_word, edit_to_word])
|
467 |
+
|
468 |
+
load_models_btn.click(fn=load_models,
|
469 |
+
inputs=[whisper_model_choice, voicecraft_model_choice],
|
470 |
+
outputs=[models_selector])
|
471 |
+
|
472 |
+
input_audio.upload(fn=update_input_audio,
|
473 |
+
inputs=[input_audio],
|
474 |
+
outputs=[prompt_end_time, edit_start_time, edit_end_time])
|
475 |
+
transcribe_btn.click(fn=transcribe,
|
476 |
+
inputs=[seed, input_audio],
|
477 |
+
outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
|
478 |
+
prompt_to_word, edit_from_word, edit_to_word, word_info])
|
479 |
+
|
480 |
+
mode.change(fn=change_mode,
|
481 |
+
inputs=[mode],
|
482 |
+
outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
|
483 |
+
|
484 |
+
run_btn.click(fn=run,
|
485 |
+
inputs=[
|
486 |
+
seed, left_margin, right_margin,
|
487 |
+
codec_audio_sr, codec_sr,
|
488 |
+
top_k, top_p, temperature,
|
489 |
+
stop_repetition, sample_batch_size,
|
490 |
+
kvcache, silence_tokens,
|
491 |
+
input_audio, word_info, transcript, smart_transcript,
|
492 |
+
mode, prompt_end_time, edit_start_time, edit_end_time,
|
493 |
+
split_text, sentence_selector, audio_tensors
|
494 |
+
],
|
495 |
+
outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
|
496 |
+
|
497 |
+
sentence_selector.change(fn=load_sentence,
|
498 |
+
inputs=[sentence_selector, codec_audio_sr, audio_tensors],
|
499 |
+
outputs=[sentence_audio])
|
500 |
+
rerun_btn.click(fn=run,
|
501 |
+
inputs=[
|
502 |
+
seed, left_margin, right_margin,
|
503 |
+
codec_audio_sr, codec_sr,
|
504 |
+
top_k, top_p, temperature,
|
505 |
+
stop_repetition, sample_batch_size,
|
506 |
+
kvcache, silence_tokens,
|
507 |
+
input_audio, word_info, transcript, smart_transcript,
|
508 |
+
gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
|
509 |
+
split_text, sentence_selector, audio_tensors
|
510 |
+
],
|
511 |
+
outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
|
512 |
+
|
513 |
+
prompt_to_word.change(fn=update_bound_word,
|
514 |
+
inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
|
515 |
+
outputs=[prompt_end_time])
|
516 |
+
edit_from_word.change(fn=update_bound_word,
|
517 |
+
inputs=[gr.State(True), edit_from_word, edit_word_mode],
|
518 |
+
outputs=[edit_start_time])
|
519 |
+
edit_to_word.change(fn=update_bound_word,
|
520 |
+
inputs=[gr.State(False), edit_to_word, edit_word_mode],
|
521 |
+
outputs=[edit_end_time])
|
522 |
+
edit_word_mode.change(fn=update_bound_words,
|
523 |
+
inputs=[edit_from_word, edit_to_word, edit_word_mode],
|
524 |
+
outputs=[edit_start_time, edit_end_time])
|
525 |
+
|
526 |
+
|
527 |
+
if __name__ == "__main__":
|
528 |
+
app.launch()
|
inference_speech_editing_scale.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, pickle
|
2 |
+
import logging
|
3 |
+
import os, random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
|
8 |
+
from data.tokenizer import (
|
9 |
+
AudioTokenizer,
|
10 |
+
TextTokenizer,
|
11 |
+
tokenize_audio,
|
12 |
+
tokenize_text
|
13 |
+
)
|
14 |
+
|
15 |
+
from models import voicecraft
|
16 |
+
import argparse, time, tqdm
|
17 |
+
|
18 |
+
# this script only works for the musicgen architecture
|
19 |
+
def get_args():
|
20 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
21 |
+
parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
|
22 |
+
parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
|
23 |
+
parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
|
24 |
+
parser.add_argument("--left_margin", type=float, default=0.08, help="extra space on the left to the word boundary")
|
25 |
+
parser.add_argument("--right_margin", type=float, default=0.08, help="extra space on the right to the word boundary")
|
26 |
+
parser.add_argument("--seed", type=int, default=1)
|
27 |
+
parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
|
28 |
+
parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
|
29 |
+
parser.add_argument("--top_k", type=int, default=-1, help="sampling param")
|
30 |
+
parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
|
31 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
|
32 |
+
parser.add_argument("--output_dir", type=str, default=None)
|
33 |
+
parser.add_argument("--device", type=str, default="cuda")
|
34 |
+
parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
|
35 |
+
parser.add_argument("--stop_repetition", type=int, default=2, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
|
36 |
+
parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
|
37 |
+
parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
|
38 |
+
return parser.parse_args()
|
39 |
+
|
40 |
+
@torch.no_grad()
|
41 |
+
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, device, decode_config):
|
42 |
+
# phonemize
|
43 |
+
text_tokens = [phn2num[phn] for phn in
|
44 |
+
tokenize_text(
|
45 |
+
text_tokenizer, text=target_text.strip()
|
46 |
+
) if phn in phn2num
|
47 |
+
]
|
48 |
+
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
|
49 |
+
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
|
50 |
+
|
51 |
+
encoded_frames = tokenize_audio(audio_tokenizer, audio_fn)
|
52 |
+
original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
|
53 |
+
assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
|
54 |
+
logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
|
55 |
+
|
56 |
+
# forward
|
57 |
+
stime = time.time()
|
58 |
+
encoded_frames = model.inference(
|
59 |
+
text_tokens.to(device),
|
60 |
+
text_tokens_lens.to(device),
|
61 |
+
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
|
62 |
+
mask_interval=mask_interval.unsqueeze(0).to(device),
|
63 |
+
top_k=decode_config['top_k'],
|
64 |
+
top_p=decode_config['top_p'],
|
65 |
+
temperature=decode_config['temperature'],
|
66 |
+
stop_repetition=decode_config['stop_repetition'],
|
67 |
+
kvcache=decode_config['kvcache'],
|
68 |
+
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens'],
|
69 |
+
) # output is [1,K,T]
|
70 |
+
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
|
71 |
+
if type(encoded_frames) == tuple:
|
72 |
+
encoded_frames = encoded_frames[0]
|
73 |
+
logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.")
|
74 |
+
|
75 |
+
|
76 |
+
# decode (both original and generated)
|
77 |
+
original_sample = audio_tokenizer.decode(
|
78 |
+
[(original_audio.transpose(2,1), None)] # [1,T,8] -> [1,8,T]
|
79 |
+
)
|
80 |
+
generated_sample = audio_tokenizer.decode(
|
81 |
+
[(encoded_frames, None)]
|
82 |
+
)
|
83 |
+
|
84 |
+
return original_sample, generated_sample
|
85 |
+
|
86 |
+
def get_model(exp_dir, device=None):
|
87 |
+
with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
|
88 |
+
model_args = pickle.load(f)
|
89 |
+
|
90 |
+
logging.info("load model weights...")
|
91 |
+
model = voicecraft.VoiceCraft(model_args)
|
92 |
+
ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
|
93 |
+
ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
|
94 |
+
phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
|
95 |
+
model.load_state_dict(ckpt)
|
96 |
+
del ckpt
|
97 |
+
logging.info("done loading weights...")
|
98 |
+
if device == None:
|
99 |
+
device = torch.device("cpu")
|
100 |
+
if torch.cuda.is_available():
|
101 |
+
device = torch.device("cuda:0")
|
102 |
+
model.to(device)
|
103 |
+
model.eval()
|
104 |
+
return model, model_args, phn2num
|
105 |
+
|
106 |
+
|
107 |
+
def get_mask_interval(ali_fn, word_span_ind, editType):
|
108 |
+
with open(ali_fn, "r") as rf:
|
109 |
+
data = [l.strip().split(",") for l in rf.readlines()]
|
110 |
+
data = data[1:]
|
111 |
+
tmp = word_span_ind.split(",")
|
112 |
+
s, e = int(tmp[0]), int(tmp[-1])
|
113 |
+
start = None
|
114 |
+
for j, item in enumerate(data):
|
115 |
+
if j == s and item[3] == "words":
|
116 |
+
if editType == 'insertion':
|
117 |
+
start = float(item[1])
|
118 |
+
else:
|
119 |
+
start = float(item[0])
|
120 |
+
if j == e and item[3] == "words":
|
121 |
+
if editType == 'insertion':
|
122 |
+
end = float(item[0])
|
123 |
+
else:
|
124 |
+
end = float(item[1])
|
125 |
+
assert start != None
|
126 |
+
break
|
127 |
+
return (start, end)
|
128 |
+
|
129 |
+
if __name__ == "__main__":
|
130 |
+
def seed_everything(seed):
|
131 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
132 |
+
random.seed(seed)
|
133 |
+
np.random.seed(seed)
|
134 |
+
torch.manual_seed(seed)
|
135 |
+
torch.cuda.manual_seed(seed)
|
136 |
+
torch.backends.cudnn.benchmark = False
|
137 |
+
torch.backends.cudnn.deterministic = True
|
138 |
+
formatter = (
|
139 |
+
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
|
140 |
+
)
|
141 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
142 |
+
args = get_args()
|
143 |
+
# args.device = 'cpu'
|
144 |
+
args.allowed_repeat_tokens = eval(args.allowed_repeat_tokens)
|
145 |
+
seed_everything(args.seed)
|
146 |
+
|
147 |
+
# load model
|
148 |
+
stime = time.time()
|
149 |
+
logging.info(f"loading model from {args.exp_dir}")
|
150 |
+
model, model_args, phn2num = get_model(args.exp_dir)
|
151 |
+
if not os.path.isfile(model_args.exp_dir):
|
152 |
+
model_args.exp_dir = args.exp_dir
|
153 |
+
logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
|
154 |
+
|
155 |
+
# setup text and audio tokenizer
|
156 |
+
text_tokenizer = TextTokenizer(backend="espeak")
|
157 |
+
audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
|
158 |
+
|
159 |
+
with open(args.manifest_fn, "r") as rf:
|
160 |
+
manifest = [l.strip().split("\t") for l in rf.readlines()]
|
161 |
+
manifest = manifest[1:]
|
162 |
+
|
163 |
+
# wav_fn txt_fn alingment_fn num_words word_span_ind
|
164 |
+
audio_fns = []
|
165 |
+
target_texts = []
|
166 |
+
mask_intervals = []
|
167 |
+
edit_types = []
|
168 |
+
new_spans = []
|
169 |
+
orig_spans = []
|
170 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
171 |
+
if args.crop_concat:
|
172 |
+
mfa_temp = f"{args.output_dir}/mfa_temp"
|
173 |
+
os.makedirs(mfa_temp, exist_ok=True)
|
174 |
+
for item in manifest:
|
175 |
+
audio_fn = os.path.join(args.audio_root, item[0])
|
176 |
+
temp = torchaudio.info(audio_fn)
|
177 |
+
audio_dur = temp.num_frames/temp.sample_rate
|
178 |
+
audio_fns.append(audio_fn)
|
179 |
+
target_text = item[2].split("|")[-1]
|
180 |
+
edit_types.append(item[5].split("|"))
|
181 |
+
new_spans.append(item[4].split("|"))
|
182 |
+
orig_spans.append(item[3].split("|"))
|
183 |
+
target_texts.append(target_text) # the last transcript is the target
|
184 |
+
# mi needs to be created from word_ind_span and alignment_fn, along with args.left_margin and args.right_margin
|
185 |
+
mis = []
|
186 |
+
all_ind_intervals = item[3].split("|")
|
187 |
+
editTypes = item[5].split("|")
|
188 |
+
smaller_indx = []
|
189 |
+
alignment_fn = os.path.join(args.audio_root, "aligned", item[0].replace(".wav", ".csv"))
|
190 |
+
if not os.path.isfile(alignment_fn):
|
191 |
+
alignment_fn = alignment_fn.replace("/aligned/", "/aligned_csv/")
|
192 |
+
assert os.path.isfile(alignment_fn), alignment_fn
|
193 |
+
for ind_inter,editType in zip(all_ind_intervals, editTypes):
|
194 |
+
# print(ind_inter)
|
195 |
+
mi = get_mask_interval(alignment_fn, ind_inter, editType)
|
196 |
+
mi = (max(mi[0] - args.left_margin, 1/args.codec_sr), min(mi[1] + args.right_margin, audio_dur)) # in seconds
|
197 |
+
mis.append(mi)
|
198 |
+
smaller_indx.append(mi[0])
|
199 |
+
ind = np.argsort(smaller_indx)
|
200 |
+
mis = [mis[id] for id in ind]
|
201 |
+
mask_intervals.append(mis)
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
for i, (audio_fn, target_text, mask_interval) in enumerate(tqdm.tqdm(zip(audio_fns, target_texts, mask_intervals))):
|
206 |
+
orig_mask_interval = mask_interval
|
207 |
+
mask_interval = [[round(cmi[0]*args.codec_sr), round(cmi[1]*args.codec_sr)] for cmi in mask_interval]
|
208 |
+
# logging.info(f"i: {i}, mask_interval: {mask_interval}")
|
209 |
+
mask_interval = torch.LongTensor(mask_interval) # [M,2]
|
210 |
+
orig_audio, new_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, args.device, vars(args))
|
211 |
+
|
212 |
+
# save segments for comparison
|
213 |
+
orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()
|
214 |
+
# logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")
|
215 |
+
|
216 |
+
save_fn_new = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{args.seed}.wav"
|
217 |
+
|
218 |
+
torchaudio.save(save_fn_new, new_audio, args.codec_audio_sr)
|
219 |
+
|
220 |
+
save_fn_orig = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav"
|
221 |
+
if not os.path.isfile(save_fn_orig):
|
222 |
+
orig_audio, orig_sr = torchaudio.load(audio_fn)
|
223 |
+
if orig_sr != args.codec_audio_sr:
|
224 |
+
orig_audio = torchaudio.transforms.Resample(orig_sr, args.codec_audio_sr)(orig_audio)
|
225 |
+
torchaudio.save(save_fn_orig, orig_audio, args.codec_audio_sr)
|
226 |
+
|
inference_tts_scale.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, pickle
|
2 |
+
import logging
|
3 |
+
import os, random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
|
8 |
+
from data.tokenizer import (
|
9 |
+
AudioTokenizer,
|
10 |
+
TextTokenizer,
|
11 |
+
tokenize_audio,
|
12 |
+
tokenize_text
|
13 |
+
)
|
14 |
+
|
15 |
+
from models import voicecraft
|
16 |
+
import argparse, time, tqdm
|
17 |
+
|
18 |
+
|
19 |
+
# this script only works for the musicgen architecture
|
20 |
+
def get_args():
|
21 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
22 |
+
parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
|
23 |
+
parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
|
24 |
+
parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
|
25 |
+
parser.add_argument("--seed", type=int, default=1)
|
26 |
+
parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
|
27 |
+
parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
|
28 |
+
parser.add_argument("--top_k", type=int, default=0, help="sampling param")
|
29 |
+
parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
|
30 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
|
31 |
+
parser.add_argument("--output_dir", type=str, default=None)
|
32 |
+
parser.add_argument("--device", type=str, default="cuda")
|
33 |
+
parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
|
34 |
+
parser.add_argument("--crop_concat", type=int, default=0)
|
35 |
+
parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
|
36 |
+
parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
|
37 |
+
parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation")
|
38 |
+
parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
|
39 |
+
return parser.parse_args()
|
40 |
+
|
41 |
+
|
42 |
+
@torch.no_grad()
|
43 |
+
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame):
|
44 |
+
# phonemize
|
45 |
+
text_tokens = [phn2num[phn] for phn in
|
46 |
+
tokenize_text(
|
47 |
+
text_tokenizer, text=target_text.strip()
|
48 |
+
) if phn in phn2num
|
49 |
+
]
|
50 |
+
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
|
51 |
+
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
|
52 |
+
|
53 |
+
# encode audio
|
54 |
+
encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
|
55 |
+
original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
|
56 |
+
assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
|
57 |
+
logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
|
58 |
+
|
59 |
+
# forward
|
60 |
+
stime = time.time()
|
61 |
+
if decode_config['sample_batch_size'] <= 1:
|
62 |
+
logging.info(f"running inference with batch size 1")
|
63 |
+
concat_frames, gen_frames = model.inference_tts(
|
64 |
+
text_tokens.to(device),
|
65 |
+
text_tokens_lens.to(device),
|
66 |
+
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
|
67 |
+
top_k=decode_config['top_k'],
|
68 |
+
top_p=decode_config['top_p'],
|
69 |
+
temperature=decode_config['temperature'],
|
70 |
+
stop_repetition=decode_config['stop_repetition'],
|
71 |
+
kvcache=decode_config['kvcache'],
|
72 |
+
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
|
73 |
+
) # output is [1,K,T]
|
74 |
+
else:
|
75 |
+
logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.")
|
76 |
+
concat_frames, gen_frames = model.inference_tts_batch(
|
77 |
+
text_tokens.to(device),
|
78 |
+
text_tokens_lens.to(device),
|
79 |
+
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
|
80 |
+
top_k=decode_config['top_k'],
|
81 |
+
top_p=decode_config['top_p'],
|
82 |
+
temperature=decode_config['temperature'],
|
83 |
+
stop_repetition=decode_config['stop_repetition'],
|
84 |
+
kvcache=decode_config['kvcache'],
|
85 |
+
batch_size = decode_config['sample_batch_size'],
|
86 |
+
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
|
87 |
+
) # output is [1,K,T]
|
88 |
+
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
|
89 |
+
|
90 |
+
logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")
|
91 |
+
|
92 |
+
# for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)):
|
93 |
+
# logging.info(f"{timestamp}: {codes.tolist()}")
|
94 |
+
# decode (both original and generated)
|
95 |
+
concat_sample = audio_tokenizer.decode(
|
96 |
+
[(concat_frames, None)] # [1,T,8] -> [1,8,T]
|
97 |
+
)
|
98 |
+
gen_sample = audio_tokenizer.decode(
|
99 |
+
[(gen_frames, None)]
|
100 |
+
)
|
101 |
+
|
102 |
+
# return
|
103 |
+
return concat_sample, gen_sample
|
104 |
+
|
105 |
+
def get_model(exp_dir, device=None):
|
106 |
+
with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
|
107 |
+
model_args = pickle.load(f)
|
108 |
+
|
109 |
+
logging.info("load model weights...")
|
110 |
+
model = voicecraft.VoiceCraft(model_args)
|
111 |
+
ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
|
112 |
+
ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
|
113 |
+
phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
|
114 |
+
model.load_state_dict(ckpt)
|
115 |
+
del ckpt
|
116 |
+
logging.info("done loading weights...")
|
117 |
+
if device == None:
|
118 |
+
device = torch.device("cpu")
|
119 |
+
if torch.cuda.is_available():
|
120 |
+
device = torch.device("cuda:0")
|
121 |
+
model.to(device)
|
122 |
+
model.eval()
|
123 |
+
return model, model_args, phn2num
|
124 |
+
|
125 |
+
if __name__ == "__main__":
|
126 |
+
def seed_everything(seed):
|
127 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
128 |
+
random.seed(seed)
|
129 |
+
np.random.seed(seed)
|
130 |
+
torch.manual_seed(seed)
|
131 |
+
torch.cuda.manual_seed(seed)
|
132 |
+
torch.backends.cudnn.benchmark = False
|
133 |
+
torch.backends.cudnn.deterministic = True
|
134 |
+
formatter = (
|
135 |
+
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
|
136 |
+
)
|
137 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
138 |
+
args = get_args()
|
139 |
+
# args.device='cpu'
|
140 |
+
seed_everything(args.seed)
|
141 |
+
|
142 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
143 |
+
# load model
|
144 |
+
|
145 |
+
with open(args.manifest_fn, "r") as rf:
|
146 |
+
manifest = [l.strip().split("\t") for l in rf.readlines()]
|
147 |
+
manifest = manifest[1:]
|
148 |
+
manifest = [[item[0], item[2], item[3], item[1], item[5]] for item in manifest]
|
149 |
+
|
150 |
+
stime = time.time()
|
151 |
+
logging.info(f"loading model from {args.exp_dir}")
|
152 |
+
model, model_args, phn2num = get_model(args.exp_dir)
|
153 |
+
logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
|
154 |
+
|
155 |
+
# setup text and audio tokenizer
|
156 |
+
text_tokenizer = TextTokenizer(backend="espeak")
|
157 |
+
audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
|
158 |
+
|
159 |
+
audio_fns = []
|
160 |
+
texts = []
|
161 |
+
prompt_end_frames = []
|
162 |
+
new_audio_fns = []
|
163 |
+
text_to_syn = []
|
164 |
+
|
165 |
+
for item in manifest:
|
166 |
+
audio_fn = os.path.join(args.audio_root, item[0])
|
167 |
+
audio_fns.append(audio_fn)
|
168 |
+
temp = torchaudio.info(audio_fn)
|
169 |
+
prompt_end_frames.append(round(float(item[2])*temp.sample_rate))
|
170 |
+
texts.append(item[1])
|
171 |
+
new_audio_fns.append(item[-2])
|
172 |
+
all_text = item[1].split(" ")
|
173 |
+
start_ind = int(item[-1].split(",")[0])
|
174 |
+
text_to_syn.append(" ".join(all_text[start_ind:]))
|
175 |
+
|
176 |
+
for i, (audio_fn, text, prompt_end_frame, new_audio_fn, to_syn) in enumerate(tqdm.tqdm((zip(audio_fns, texts, prompt_end_frames, new_audio_fns, text_to_syn)))):
|
177 |
+
output_expected_sr = args.codec_audio_sr
|
178 |
+
concated_audio, gen_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, text, args.device, vars(args), prompt_end_frame)
|
179 |
+
|
180 |
+
# save segments for comparison
|
181 |
+
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
|
182 |
+
if output_expected_sr != args.codec_audio_sr:
|
183 |
+
gen_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(gen_audio)
|
184 |
+
concated_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(concated_audio)
|
185 |
+
|
186 |
+
seg_save_fn_gen = f"{args.output_dir}/gen_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
|
187 |
+
seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
|
188 |
+
|
189 |
+
torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr)
|
190 |
+
torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr)
|
models/__pycache__/codebooks_patterns.cpython-310.pyc
ADDED
Binary file (25 kB). View file
|
|
models/__pycache__/voicecraft.cpython-310.pyc
ADDED
Binary file (40.1 kB). View file
|
|
models/codebooks_patterns.py
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import namedtuple
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from functools import lru_cache
|
10 |
+
import logging
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
from abc import ABC, abstractmethod
|
14 |
+
import torch
|
15 |
+
|
16 |
+
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
|
17 |
+
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class Pattern:
|
22 |
+
"""Base implementation of a pattern over a sequence with multiple codebooks.
|
23 |
+
|
24 |
+
The codebook pattern consists in a layout, defining for each sequence step
|
25 |
+
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
|
26 |
+
The first item of the pattern is always an empty list in order to properly insert a special token
|
27 |
+
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
|
28 |
+
and ``timesteps`` the number of timesteps corresponding to the original sequence.
|
29 |
+
|
30 |
+
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
31 |
+
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
32 |
+
to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
|
33 |
+
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
34 |
+
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
35 |
+
is returned along with a mask indicating valid tokens.
|
36 |
+
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
|
37 |
+
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
|
38 |
+
to fill and specify invalid positions if needed.
|
39 |
+
See the dedicated methods for more details.
|
40 |
+
"""
|
41 |
+
# Pattern layout, for each sequence step, we have a list of coordinates
|
42 |
+
# corresponding to the original codebook timestep and position.
|
43 |
+
# The first list is always an empty list in order to properly insert
|
44 |
+
# a special token to start with.
|
45 |
+
layout: PatternLayout
|
46 |
+
timesteps: int
|
47 |
+
n_q: int
|
48 |
+
|
49 |
+
def __post_init__(self):
|
50 |
+
assert len(self.layout) > 0
|
51 |
+
assert self.layout[0] == []
|
52 |
+
self._validate_layout()
|
53 |
+
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
54 |
+
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
55 |
+
# logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
|
56 |
+
|
57 |
+
def _validate_layout(self):
|
58 |
+
"""Runs checks on the layout to ensure a valid pattern is defined.
|
59 |
+
A pattern is considered invalid if:
|
60 |
+
- Multiple timesteps for a same codebook are defined in the same sequence step
|
61 |
+
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
|
62 |
+
(this would mean that we have future timesteps before past timesteps).
|
63 |
+
"""
|
64 |
+
q_timesteps = {q: 0 for q in range(self.n_q)}
|
65 |
+
for s, seq_coords in enumerate(self.layout):
|
66 |
+
if len(seq_coords) > 0:
|
67 |
+
qs = set()
|
68 |
+
for coord in seq_coords:
|
69 |
+
qs.add(coord.q)
|
70 |
+
last_q_timestep = q_timesteps[coord.q]
|
71 |
+
assert coord.t >= last_q_timestep, \
|
72 |
+
f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
|
73 |
+
q_timesteps[coord.q] = coord.t
|
74 |
+
# each sequence step contains at max 1 coordinate per codebook
|
75 |
+
assert len(qs) == len(seq_coords), \
|
76 |
+
f"Multiple entries for a same codebook are found at step {s}"
|
77 |
+
|
78 |
+
@property
|
79 |
+
def num_sequence_steps(self):
|
80 |
+
return len(self.layout) - 1
|
81 |
+
|
82 |
+
@property
|
83 |
+
def max_delay(self):
|
84 |
+
max_t_in_seq_coords = 0
|
85 |
+
for seq_coords in self.layout[1:]:
|
86 |
+
for coords in seq_coords:
|
87 |
+
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
|
88 |
+
return max_t_in_seq_coords - self.timesteps
|
89 |
+
|
90 |
+
@property
|
91 |
+
def valid_layout(self):
|
92 |
+
valid_step = len(self.layout) - self.max_delay
|
93 |
+
return self.layout[:valid_step]
|
94 |
+
|
95 |
+
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
96 |
+
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
97 |
+
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
98 |
+
and the actual codebook coordinates.
|
99 |
+
"""
|
100 |
+
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
|
101 |
+
if q is not None:
|
102 |
+
assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
|
103 |
+
coords = []
|
104 |
+
for s, seq_codes in enumerate(self.layout):
|
105 |
+
for code in seq_codes:
|
106 |
+
if code.t == t and (q is None or code.q == q):
|
107 |
+
coords.append((s, code))
|
108 |
+
return coords
|
109 |
+
|
110 |
+
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
|
111 |
+
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
|
112 |
+
|
113 |
+
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
|
114 |
+
steps_with_timesteps = self.get_steps_with_timestep(t, q)
|
115 |
+
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
|
116 |
+
|
117 |
+
def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
|
118 |
+
device: tp.Union[torch.device, str] = 'cpu'):
|
119 |
+
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
timesteps (int): Maximum number of timesteps steps to consider.
|
123 |
+
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
|
124 |
+
device (Union[torch.device, str]): Device for created tensors.
|
125 |
+
Returns:
|
126 |
+
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
|
127 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
|
128 |
+
"""
|
129 |
+
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
130 |
+
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
|
131 |
+
# use the proper layout based on whether we limit ourselves to valid steps only or not,
|
132 |
+
# note that using the valid_layout will result in a truncated sequence up to the valid steps
|
133 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
134 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
135 |
+
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
|
136 |
+
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
|
137 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
138 |
+
# the last value is n_q * timesteps as we have flattened z and append special token as the last token
|
139 |
+
# which will correspond to the index: n_q * timesteps
|
140 |
+
indexes[:] = n_q * timesteps
|
141 |
+
# iterate over the pattern and fill scattered indexes and mask
|
142 |
+
for s, sequence_coords in enumerate(ref_layout):
|
143 |
+
for coords in sequence_coords:
|
144 |
+
if coords.t < timesteps:
|
145 |
+
indexes[coords.q, s] = coords.t + coords.q * timesteps
|
146 |
+
mask[coords.q, s] = 1
|
147 |
+
indexes = torch.from_numpy(indexes).to(device)
|
148 |
+
mask = torch.from_numpy(mask).to(device)
|
149 |
+
return indexes, mask
|
150 |
+
|
151 |
+
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
152 |
+
"""Build sequence corresponding to the pattern from the input tensor z.
|
153 |
+
The sequence is built using up to sequence_steps if specified, and non-pattern
|
154 |
+
coordinates are filled with the special token.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
|
158 |
+
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
|
159 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
160 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
161 |
+
Returns:
|
162 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
|
163 |
+
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
|
164 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
|
165 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
|
166 |
+
"""
|
167 |
+
B, K, T = z.shape
|
168 |
+
indexes, mask = self._build_pattern_sequence_scatter_indexes(
|
169 |
+
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
|
170 |
+
)
|
171 |
+
z = z.view(B, -1)
|
172 |
+
# we append the special token as the last index of our flattened z tensor
|
173 |
+
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
|
174 |
+
values = z[:, indexes.view(-1)]
|
175 |
+
values = values.view(B, K, indexes.shape[-1])
|
176 |
+
return values, indexes, mask
|
177 |
+
|
178 |
+
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
|
179 |
+
keep_only_valid_steps: bool = False,
|
180 |
+
is_model_output: bool = False,
|
181 |
+
device: tp.Union[torch.device, str] = 'cpu'):
|
182 |
+
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
|
183 |
+
from interleaving pattern.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
sequence_steps (int): Sequence steps.
|
187 |
+
n_q (int): Number of codebooks.
|
188 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
189 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
190 |
+
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
|
191 |
+
device (Union[torch.device, str]): Device for created tensors.
|
192 |
+
Returns:
|
193 |
+
torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
|
194 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
195 |
+
"""
|
196 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
197 |
+
# TODO(jade): Do we want to further truncate to only valid timesteps here as well?
|
198 |
+
timesteps = self.timesteps
|
199 |
+
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
200 |
+
assert sequence_steps <= len(ref_layout), \
|
201 |
+
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
202 |
+
|
203 |
+
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
204 |
+
if is_model_output:
|
205 |
+
ref_layout = ref_layout[1:]
|
206 |
+
|
207 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
208 |
+
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
|
209 |
+
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
|
210 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
211 |
+
indexes[:] = n_q * sequence_steps
|
212 |
+
for s, sequence_codes in enumerate(ref_layout):
|
213 |
+
if s < sequence_steps:
|
214 |
+
for code in sequence_codes:
|
215 |
+
if code.t < timesteps:
|
216 |
+
indexes[code.q, code.t] = s + code.q * sequence_steps
|
217 |
+
mask[code.q, code.t] = 1
|
218 |
+
indexes = torch.from_numpy(indexes).to(device)
|
219 |
+
mask = torch.from_numpy(mask).to(device)
|
220 |
+
return indexes, mask
|
221 |
+
|
222 |
+
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
223 |
+
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
|
224 |
+
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
|
225 |
+
are filled with the special token.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
|
229 |
+
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
|
230 |
+
Returns:
|
231 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
|
232 |
+
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
|
233 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
|
234 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
235 |
+
"""
|
236 |
+
B, K, S = s.shape
|
237 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
238 |
+
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
|
239 |
+
)
|
240 |
+
s = s.view(B, -1)
|
241 |
+
# we append the special token as the last index of our flattened z tensor
|
242 |
+
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
243 |
+
values = s[:, indexes.view(-1)]
|
244 |
+
values = values.view(B, K, indexes.shape[-1])
|
245 |
+
return values, indexes, mask
|
246 |
+
|
247 |
+
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
|
248 |
+
"""Revert model logits obtained on a sequence built from the pattern
|
249 |
+
back to a tensor matching the original sequence.
|
250 |
+
|
251 |
+
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
252 |
+
1. It is designed to work with the extra cardinality dimension
|
253 |
+
2. We return the logits for the first sequence item that matches the special_token and
|
254 |
+
which matching target in the original sequence is the first item of the sequence,
|
255 |
+
while we skip the last logits as there is no matching target
|
256 |
+
"""
|
257 |
+
B, card, K, S = logits.shape
|
258 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
259 |
+
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
|
260 |
+
)
|
261 |
+
logits = logits.reshape(B, card, -1)
|
262 |
+
# we append the special token as the last index of our flattened z tensor
|
263 |
+
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
|
264 |
+
values = logits[:, :, indexes.view(-1)]
|
265 |
+
values = values.view(B, card, K, indexes.shape[-1])
|
266 |
+
return values, indexes, mask
|
267 |
+
|
268 |
+
|
269 |
+
class CodebooksPatternProvider(ABC):
|
270 |
+
"""Abstraction around providing pattern for interleaving codebooks.
|
271 |
+
|
272 |
+
The CodebooksPatternProvider abstraction allows to implement various strategies to
|
273 |
+
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
274 |
+
number of codebooks `n_q`, the pattern provider can generate a specified pattern
|
275 |
+
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
|
276 |
+
can be used to construct a new sequence from the original codes respecting the specified
|
277 |
+
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
278 |
+
being a tuple with the original timestep and codebook to build the new sequence.
|
279 |
+
Note that all patterns must start with an empty list that is then used to insert a first
|
280 |
+
sequence step of special tokens in the newly generated sequence.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
n_q (int): number of codebooks.
|
284 |
+
cached (bool): if True, patterns for a given length are cached. In general
|
285 |
+
that should be true for efficiency reason to avoid synchronization points.
|
286 |
+
"""
|
287 |
+
def __init__(self, n_q: int, cached: bool = True):
|
288 |
+
assert n_q > 0
|
289 |
+
self.n_q = n_q
|
290 |
+
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
291 |
+
|
292 |
+
@abstractmethod
|
293 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
294 |
+
"""Builds pattern with specific interleaving between codebooks.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
timesteps (int): Total numer of timesteps.
|
298 |
+
"""
|
299 |
+
raise NotImplementedError()
|
300 |
+
|
301 |
+
|
302 |
+
class DelayedPatternProvider(CodebooksPatternProvider):
|
303 |
+
"""Provider for delayed pattern across delayed codebooks.
|
304 |
+
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
305 |
+
from different timesteps.
|
306 |
+
|
307 |
+
Example:
|
308 |
+
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
|
309 |
+
[[1, 2, 3, 4],
|
310 |
+
[1, 2, 3, 4],
|
311 |
+
[1, 2, 3, 4]]
|
312 |
+
The resulting sequence obtained from the returned pattern is:
|
313 |
+
[[S, 1, 2, 3, 4],
|
314 |
+
[S, S, 1, 2, 3],
|
315 |
+
[S, S, S, 1, 2]]
|
316 |
+
(with S being a special token)
|
317 |
+
|
318 |
+
Args:
|
319 |
+
n_q (int): Number of codebooks.
|
320 |
+
delays (Optional[List[int]]): Delay for each of the codebooks.
|
321 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
322 |
+
flatten_first (int): Flatten the first N timesteps.
|
323 |
+
empty_initial (int): Prepend with N empty list of coordinates.
|
324 |
+
"""
|
325 |
+
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
|
326 |
+
flatten_first: int = 0, empty_initial: int = 0):
|
327 |
+
super().__init__(n_q)
|
328 |
+
if delays is None:
|
329 |
+
delays = list(range(n_q))
|
330 |
+
self.delays = delays
|
331 |
+
self.flatten_first = flatten_first
|
332 |
+
self.empty_initial = empty_initial
|
333 |
+
assert len(self.delays) == self.n_q
|
334 |
+
assert sorted(self.delays) == self.delays
|
335 |
+
|
336 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
337 |
+
out: PatternLayout = [[]]
|
338 |
+
max_delay = max(self.delays)
|
339 |
+
if self.empty_initial:
|
340 |
+
out += [[] for _ in range(self.empty_initial)]
|
341 |
+
if self.flatten_first:
|
342 |
+
for t in range(min(timesteps, self.flatten_first)):
|
343 |
+
for q in range(self.n_q):
|
344 |
+
out.append([LayoutCoord(t, q)])
|
345 |
+
for t in range(self.flatten_first, timesteps + max_delay):
|
346 |
+
v = []
|
347 |
+
for q, delay in enumerate(self.delays):
|
348 |
+
t_for_q = t - delay
|
349 |
+
if t_for_q >= self.flatten_first:
|
350 |
+
v.append(LayoutCoord(t_for_q, q))
|
351 |
+
out.append(v)
|
352 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
353 |
+
|
354 |
+
|
355 |
+
class ParallelPatternProvider(DelayedPatternProvider):
|
356 |
+
"""Provider for parallel pattern across codebooks.
|
357 |
+
This pattern provider is a special case of the delayed pattern with actually no delay,
|
358 |
+
hence delays=repeat(0, n_q).
|
359 |
+
|
360 |
+
Args:
|
361 |
+
n_q (int): Number of codebooks.
|
362 |
+
"""
|
363 |
+
def __init__(self, n_q: int):
|
364 |
+
super().__init__(n_q, [0] * n_q)
|
365 |
+
|
366 |
+
|
367 |
+
class UnrolledPatternProvider(CodebooksPatternProvider):
|
368 |
+
"""Provider for unrolling codebooks pattern.
|
369 |
+
This pattern provider enables to represent the codebook flattened completely or only to some extend
|
370 |
+
while also specifying a given delay between the flattened codebooks representation, allowing to
|
371 |
+
unroll the codebooks in the sequence.
|
372 |
+
|
373 |
+
Example:
|
374 |
+
1. Flattening of the codebooks.
|
375 |
+
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
|
376 |
+
taking n_q = 3 and timesteps = 4:
|
377 |
+
[[1, 2, 3, 4],
|
378 |
+
[1, 2, 3, 4],
|
379 |
+
[1, 2, 3, 4]]
|
380 |
+
will result into:
|
381 |
+
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
|
382 |
+
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
383 |
+
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
384 |
+
2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
|
385 |
+
for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
|
386 |
+
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
|
387 |
+
[[1, 2, 3, 4],
|
388 |
+
[1, 2, 3, 4],
|
389 |
+
[1, 2, 3, 4]]
|
390 |
+
will result into:
|
391 |
+
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
392 |
+
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
393 |
+
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
394 |
+
3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
|
395 |
+
allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
|
396 |
+
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
|
397 |
+
and delays = [0, 3, 3]:
|
398 |
+
[[1, 2, 3, 4],
|
399 |
+
[1, 2, 3, 4],
|
400 |
+
[1, 2, 3, 4]]
|
401 |
+
will result into:
|
402 |
+
[[S, S, S, 1, S, 2, S, 3, S, 4],
|
403 |
+
[S, S, S, 1, S, 2, S, 3, S, 4],
|
404 |
+
[1, 2, 3, S, 4, S, 5, S, 6, S]]
|
405 |
+
|
406 |
+
Args:
|
407 |
+
n_q (int): Number of codebooks.
|
408 |
+
flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
|
409 |
+
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
|
410 |
+
have n_q extra steps for each timestep.
|
411 |
+
delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
|
412 |
+
no delay is added and therefore will default to [0] * ``n_q``.
|
413 |
+
Note that two codebooks that will be flattened to the same inner step
|
414 |
+
should have the same delay, otherwise the pattern is considered as invalid.
|
415 |
+
"""
|
416 |
+
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
|
417 |
+
|
418 |
+
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
|
419 |
+
delays: tp.Optional[tp.List[int]] = None):
|
420 |
+
super().__init__(n_q)
|
421 |
+
if flattening is None:
|
422 |
+
flattening = list(range(n_q))
|
423 |
+
if delays is None:
|
424 |
+
delays = [0] * n_q
|
425 |
+
assert len(flattening) == n_q
|
426 |
+
assert len(delays) == n_q
|
427 |
+
assert sorted(flattening) == flattening
|
428 |
+
assert sorted(delays) == delays
|
429 |
+
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
|
430 |
+
self.max_delay = max(delays)
|
431 |
+
|
432 |
+
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
|
433 |
+
"""Build a flattened codebooks representation as a dictionary of inner step
|
434 |
+
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
|
435 |
+
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
|
436 |
+
"""
|
437 |
+
flattened_codebooks: dict = {}
|
438 |
+
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
|
439 |
+
if inner_step not in flattened_codebooks:
|
440 |
+
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
|
441 |
+
else:
|
442 |
+
flat_codebook = flattened_codebooks[inner_step]
|
443 |
+
assert flat_codebook.delay == delay, (
|
444 |
+
"Delay and flattening between codebooks is inconsistent: ",
|
445 |
+
"two codebooks flattened to the same position should have the same delay."
|
446 |
+
)
|
447 |
+
flat_codebook.codebooks.append(q)
|
448 |
+
flattened_codebooks[inner_step] = flat_codebook
|
449 |
+
return flattened_codebooks
|
450 |
+
|
451 |
+
@property
|
452 |
+
def _num_inner_steps(self):
|
453 |
+
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
|
454 |
+
"""
|
455 |
+
return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
|
456 |
+
|
457 |
+
def num_virtual_steps(self, timesteps: int) -> int:
|
458 |
+
return timesteps * self._num_inner_steps + 1
|
459 |
+
|
460 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
461 |
+
"""Builds pattern for delay across codebooks.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
timesteps (int): Total numer of timesteps.
|
465 |
+
"""
|
466 |
+
# the PatternLayout is built as a tuple of sequence position and list of coordinates
|
467 |
+
# so that it can be reordered properly given the required delay between codebooks of given timesteps
|
468 |
+
indexed_out: list = [(-1, [])]
|
469 |
+
max_timesteps = timesteps + self.max_delay
|
470 |
+
for t in range(max_timesteps):
|
471 |
+
# for each timestep, we unroll the flattened codebooks,
|
472 |
+
# emitting the sequence step with the corresponding delay
|
473 |
+
for step in range(self._num_inner_steps):
|
474 |
+
if step in self._flattened_codebooks:
|
475 |
+
# we have codebooks at this virtual step to emit
|
476 |
+
step_codebooks = self._flattened_codebooks[step]
|
477 |
+
t_for_q = t + step_codebooks.delay
|
478 |
+
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
|
479 |
+
if t_for_q < max_timesteps and t < max_timesteps:
|
480 |
+
indexed_out.append((t_for_q, coords))
|
481 |
+
else:
|
482 |
+
# there is no codebook in this virtual step so we emit an empty list
|
483 |
+
indexed_out.append((t, []))
|
484 |
+
out = [coords for _, coords in sorted(indexed_out)]
|
485 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
486 |
+
|
487 |
+
|
488 |
+
class VALLEPattern(CodebooksPatternProvider):
|
489 |
+
"""Almost VALL-E style pattern. We futher allow some delays for the
|
490 |
+
codebooks other than the first one.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
n_q (int): Number of codebooks.
|
494 |
+
delays (Optional[List[int]]): Delay for each of the codebooks.
|
495 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
496 |
+
"""
|
497 |
+
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
|
498 |
+
super().__init__(n_q)
|
499 |
+
if delays is None:
|
500 |
+
delays = [0] * (n_q - 1)
|
501 |
+
self.delays = delays
|
502 |
+
assert len(self.delays) == self.n_q - 1
|
503 |
+
assert sorted(self.delays) == self.delays
|
504 |
+
|
505 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
506 |
+
out: PatternLayout = [[]]
|
507 |
+
for t in range(timesteps):
|
508 |
+
out.append([LayoutCoord(t, 0)])
|
509 |
+
max_delay = max(self.delays)
|
510 |
+
for t in range(timesteps + max_delay):
|
511 |
+
v = []
|
512 |
+
for q, delay in enumerate(self.delays):
|
513 |
+
t_for_q = t - delay
|
514 |
+
if t_for_q >= 0:
|
515 |
+
v.append(LayoutCoord(t_for_q, q + 1))
|
516 |
+
out.append(v)
|
517 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
518 |
+
|
519 |
+
|
520 |
+
class MusicLMPattern(CodebooksPatternProvider):
|
521 |
+
"""Almost MusicLM style pattern. This is equivalent to full flattening
|
522 |
+
but in a different order.
|
523 |
+
|
524 |
+
Args:
|
525 |
+
n_q (int): Number of codebooks.
|
526 |
+
group_by (int): Number of codebooks to group together.
|
527 |
+
"""
|
528 |
+
def __init__(self, n_q: int, group_by: int = 2):
|
529 |
+
super().__init__(n_q)
|
530 |
+
self.group_by = group_by
|
531 |
+
|
532 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
533 |
+
out: PatternLayout = [[]]
|
534 |
+
for offset in range(0, self.n_q, self.group_by):
|
535 |
+
for t in range(timesteps):
|
536 |
+
for q in range(offset, offset + self.group_by):
|
537 |
+
out.append([LayoutCoord(t, q)])
|
538 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
models/modules/__init__.py
ADDED
File without changes
|
models/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (140 Bytes). View file
|
|
models/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (132 Bytes). View file
|
|
models/modules/__pycache__/activation.cpython-310.pyc
ADDED
Binary file (18.8 kB). View file
|
|
models/modules/__pycache__/activation.cpython-39.pyc
ADDED
Binary file (18.8 kB). View file
|
|
models/modules/__pycache__/embedding.cpython-310.pyc
ADDED
Binary file (3.07 kB). View file
|
|
models/modules/__pycache__/embedding.cpython-39.pyc
ADDED
Binary file (3.04 kB). View file
|
|
models/modules/__pycache__/scaling.cpython-310.pyc
ADDED
Binary file (40.4 kB). View file
|
|
models/modules/__pycache__/scaling.cpython-39.pyc
ADDED
Binary file (40 kB). View file
|
|
models/modules/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (16.1 kB). View file
|
|
models/modules/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (15.8 kB). View file
|
|
models/modules/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.42 kB). View file
|
|
models/modules/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.41 kB). View file
|
|
models/modules/__pycache__/visualizer.cpython-39.pyc
ADDED
Binary file (2.02 kB). View file
|
|
models/modules/activation.py
ADDED
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.nn import Linear, Module
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
9 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
10 |
+
from torch.nn.parameter import Parameter
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Optional, Tuple, Union
|
13 |
+
from typing import TYPE_CHECKING
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from torch.types import _dtype as DType
|
16 |
+
else:
|
17 |
+
# The JIT doesn't understand Union, nor torch.dtype here
|
18 |
+
DType = int
|
19 |
+
|
20 |
+
def _canonical_mask(
|
21 |
+
mask: Optional[Tensor],
|
22 |
+
mask_name: str,
|
23 |
+
other_type: Optional[DType],
|
24 |
+
other_name: str,
|
25 |
+
target_type: DType,
|
26 |
+
check_other: bool = True,
|
27 |
+
) -> Optional[Tensor]:
|
28 |
+
|
29 |
+
if mask is not None:
|
30 |
+
_mask_dtype = mask.dtype
|
31 |
+
_mask_is_float = torch.is_floating_point(mask)
|
32 |
+
if _mask_dtype != torch.bool and not _mask_is_float:
|
33 |
+
raise AssertionError(
|
34 |
+
f"only bool and floating types of {mask_name} are supported")
|
35 |
+
if check_other and other_type is not None:
|
36 |
+
if _mask_dtype != other_type:
|
37 |
+
warnings.warn(
|
38 |
+
f"Support for mismatched {mask_name} and {other_name} "
|
39 |
+
"is deprecated. Use same type for both instead."
|
40 |
+
)
|
41 |
+
if not _mask_is_float:
|
42 |
+
mask = (
|
43 |
+
torch.zeros_like(mask, dtype=target_type)
|
44 |
+
.masked_fill_(mask, float("-inf"))
|
45 |
+
)
|
46 |
+
return mask
|
47 |
+
|
48 |
+
def _in_projection_packed(
|
49 |
+
q: Tensor,
|
50 |
+
k: Tensor,
|
51 |
+
v: Tensor,
|
52 |
+
w: Tensor,
|
53 |
+
b: Optional[Tensor] = None,
|
54 |
+
) -> List[Tensor]:
|
55 |
+
r"""
|
56 |
+
Performs the in-projection step of the attention operation, using packed weights.
|
57 |
+
Output is a triple containing projection tensors for query, key and value.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
q, k, v: query, key and value tensors to be projected. For self-attention,
|
61 |
+
these are typically the same tensor; for encoder-decoder attention,
|
62 |
+
k and v are typically the same tensor. (We take advantage of these
|
63 |
+
identities for performance if they are present.) Regardless, q, k and v
|
64 |
+
must share a common embedding dimension; otherwise their shapes may vary.
|
65 |
+
w: projection weights for q, k and v, packed into a single tensor. Weights
|
66 |
+
are packed along dimension 0, in q, k, v order.
|
67 |
+
b: optional projection biases for q, k and v, packed into a single tensor
|
68 |
+
in q, k, v order.
|
69 |
+
|
70 |
+
Shape:
|
71 |
+
Inputs:
|
72 |
+
- q: :math:`(..., E)` where E is the embedding dimension
|
73 |
+
- k: :math:`(..., E)` where E is the embedding dimension
|
74 |
+
- v: :math:`(..., E)` where E is the embedding dimension
|
75 |
+
- w: :math:`(E * 3, E)` where E is the embedding dimension
|
76 |
+
- b: :math:`E * 3` where E is the embedding dimension
|
77 |
+
|
78 |
+
Output:
|
79 |
+
- in output list :math:`[q', k', v']`, each output tensor will have the
|
80 |
+
same shape as the corresponding input tensor.
|
81 |
+
"""
|
82 |
+
E = q.size(-1)
|
83 |
+
if k is v:
|
84 |
+
if q is k:
|
85 |
+
# self-attention
|
86 |
+
proj = F.linear(q, w, b)
|
87 |
+
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
|
88 |
+
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
|
89 |
+
return proj[0], proj[1], proj[2]
|
90 |
+
else:
|
91 |
+
# encoder-decoder attention
|
92 |
+
w_q, w_kv = w.split([E, E * 2])
|
93 |
+
if b is None:
|
94 |
+
b_q = b_kv = None
|
95 |
+
else:
|
96 |
+
b_q, b_kv = b.split([E, E * 2])
|
97 |
+
q_proj = F.linear(q, w_q, b_q)
|
98 |
+
kv_proj = F.linear(k, w_kv, b_kv)
|
99 |
+
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
|
100 |
+
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
|
101 |
+
return (q_proj, kv_proj[0], kv_proj[1])
|
102 |
+
else:
|
103 |
+
w_q, w_k, w_v = w.chunk(3)
|
104 |
+
if b is None:
|
105 |
+
b_q = b_k = b_v = None
|
106 |
+
else:
|
107 |
+
b_q, b_k, b_v = b.chunk(3)
|
108 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
109 |
+
|
110 |
+
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
|
111 |
+
if input is None:
|
112 |
+
return None
|
113 |
+
elif isinstance(input, torch.Tensor):
|
114 |
+
return input.dtype
|
115 |
+
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
|
116 |
+
class MultiheadAttention(Module):
|
117 |
+
r"""Allows the model to jointly attend to information
|
118 |
+
from different representation subspaces as described in the paper:
|
119 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
120 |
+
|
121 |
+
Multi-Head Attention is defined as:
|
122 |
+
|
123 |
+
.. math::
|
124 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
125 |
+
|
126 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
127 |
+
|
128 |
+
``forward()`` will use a special optimized implementation if all of the following
|
129 |
+
conditions are met:
|
130 |
+
|
131 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
|
132 |
+
restriction will be loosened in the future.)
|
133 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
134 |
+
- training is disabled (using ``.eval()``)
|
135 |
+
- dropout is 0
|
136 |
+
- ``add_bias_kv`` is ``False``
|
137 |
+
- ``add_zero_attn`` is ``False``
|
138 |
+
- ``batch_first`` is ``True`` and the input is batched
|
139 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
140 |
+
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
|
141 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
142 |
+
nor ``attn_mask`` is passed
|
143 |
+
|
144 |
+
If the optimized implementation is in use, a
|
145 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
146 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
147 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
148 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
149 |
+
that is padding can be expected.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
embed_dim: Total dimension of the model.
|
153 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
154 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
155 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
156 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
157 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
158 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
159 |
+
Default: ``False``.
|
160 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
161 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
162 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
163 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
164 |
+
|
165 |
+
Examples::
|
166 |
+
|
167 |
+
>>> # xdoctest: +SKIP
|
168 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
169 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
170 |
+
|
171 |
+
"""
|
172 |
+
__constants__ = ["batch_first"]
|
173 |
+
bias_k: Optional[torch.Tensor]
|
174 |
+
bias_v: Optional[torch.Tensor]
|
175 |
+
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
embed_dim,
|
179 |
+
num_heads,
|
180 |
+
dropout=0.0,
|
181 |
+
bias=True,
|
182 |
+
add_bias_kv=False,
|
183 |
+
add_zero_attn=False,
|
184 |
+
kdim=None,
|
185 |
+
vdim=None,
|
186 |
+
batch_first=False,
|
187 |
+
linear1_cls=Linear,
|
188 |
+
linear2_cls=Linear,
|
189 |
+
device=None,
|
190 |
+
dtype=None,
|
191 |
+
) -> None:
|
192 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
193 |
+
super(MultiheadAttention, self).__init__()
|
194 |
+
self.embed_dim = embed_dim
|
195 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
196 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
197 |
+
self._qkv_same_embed_dim = (
|
198 |
+
self.kdim == embed_dim and self.vdim == embed_dim
|
199 |
+
)
|
200 |
+
|
201 |
+
self.num_heads = num_heads
|
202 |
+
self.dropout = dropout
|
203 |
+
self.batch_first = batch_first
|
204 |
+
self.head_dim = embed_dim // num_heads
|
205 |
+
assert (
|
206 |
+
self.head_dim * num_heads == self.embed_dim
|
207 |
+
), "embed_dim must be divisible by num_heads"
|
208 |
+
|
209 |
+
if add_bias_kv:
|
210 |
+
self.bias_k = Parameter(
|
211 |
+
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
212 |
+
)
|
213 |
+
self.bias_v = Parameter(
|
214 |
+
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
215 |
+
)
|
216 |
+
else:
|
217 |
+
self.bias_k = self.bias_v = None
|
218 |
+
|
219 |
+
if linear1_cls == Linear:
|
220 |
+
if not self._qkv_same_embed_dim:
|
221 |
+
self.q_proj_weight = Parameter(
|
222 |
+
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
223 |
+
)
|
224 |
+
self.k_proj_weight = Parameter(
|
225 |
+
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
226 |
+
)
|
227 |
+
self.v_proj_weight = Parameter(
|
228 |
+
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
229 |
+
)
|
230 |
+
self.register_parameter("in_proj_weight", None)
|
231 |
+
else:
|
232 |
+
# go down this route with voicecraft
|
233 |
+
self.in_proj_weight = Parameter(
|
234 |
+
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
235 |
+
)
|
236 |
+
self.register_parameter("q_proj_weight", None)
|
237 |
+
self.register_parameter("k_proj_weight", None)
|
238 |
+
self.register_parameter("v_proj_weight", None)
|
239 |
+
|
240 |
+
if bias: # True by default
|
241 |
+
self.in_proj_bias = Parameter(
|
242 |
+
torch.empty(3 * embed_dim, **factory_kwargs)
|
243 |
+
)
|
244 |
+
else:
|
245 |
+
self.register_parameter("in_proj_bias", None)
|
246 |
+
self.out_proj = NonDynamicallyQuantizableLinear(
|
247 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
248 |
+
)
|
249 |
+
|
250 |
+
self._reset_parameters()
|
251 |
+
else:
|
252 |
+
if not self._qkv_same_embed_dim:
|
253 |
+
raise NotImplementedError
|
254 |
+
else:
|
255 |
+
self.in_proj_linear = linear1_cls(
|
256 |
+
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
257 |
+
)
|
258 |
+
self.in_proj_weight = self.in_proj_linear.weight
|
259 |
+
|
260 |
+
self.register_parameter("q_proj_weight", None)
|
261 |
+
self.register_parameter("k_proj_weight", None)
|
262 |
+
self.register_parameter("v_proj_weight", None)
|
263 |
+
|
264 |
+
if bias:
|
265 |
+
self.in_proj_bias = self.in_proj_linear.bias
|
266 |
+
else:
|
267 |
+
self.register_parameter("in_proj_bias", None)
|
268 |
+
|
269 |
+
self.out_proj = linear2_cls(
|
270 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
271 |
+
)
|
272 |
+
|
273 |
+
if self.bias_k is not None:
|
274 |
+
xavier_normal_(self.bias_k)
|
275 |
+
if self.bias_v is not None:
|
276 |
+
xavier_normal_(self.bias_v)
|
277 |
+
|
278 |
+
self.add_zero_attn = add_zero_attn
|
279 |
+
|
280 |
+
def _reset_parameters(self):
|
281 |
+
if self._qkv_same_embed_dim:
|
282 |
+
xavier_uniform_(self.in_proj_weight)
|
283 |
+
else:
|
284 |
+
xavier_uniform_(self.q_proj_weight)
|
285 |
+
xavier_uniform_(self.k_proj_weight)
|
286 |
+
xavier_uniform_(self.v_proj_weight)
|
287 |
+
|
288 |
+
if self.in_proj_bias is not None:
|
289 |
+
constant_(self.in_proj_bias, 0.0)
|
290 |
+
constant_(self.out_proj.bias, 0.0)
|
291 |
+
|
292 |
+
if self.bias_k is not None:
|
293 |
+
xavier_normal_(self.bias_k)
|
294 |
+
if self.bias_v is not None:
|
295 |
+
xavier_normal_(self.bias_v)
|
296 |
+
|
297 |
+
def __setstate__(self, state):
|
298 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
299 |
+
if "_qkv_same_embed_dim" not in state:
|
300 |
+
state["_qkv_same_embed_dim"] = True
|
301 |
+
|
302 |
+
super(MultiheadAttention, self).__setstate__(state)
|
303 |
+
|
304 |
+
def forward(
|
305 |
+
self,
|
306 |
+
query: Tensor,
|
307 |
+
key: Tensor,
|
308 |
+
value: Tensor,
|
309 |
+
key_padding_mask: Optional[Tensor] = None,
|
310 |
+
need_weights: bool = True,
|
311 |
+
attn_mask: Optional[Tensor] = None,
|
312 |
+
average_attn_weights: bool = True,
|
313 |
+
past: Optional[Tensor] = None,
|
314 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
315 |
+
r"""
|
316 |
+
Args:
|
317 |
+
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
318 |
+
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
319 |
+
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
320 |
+
Queries are compared against key-value pairs to produce the output.
|
321 |
+
See "Attention Is All You Need" for more details.
|
322 |
+
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
323 |
+
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
324 |
+
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
325 |
+
See "Attention Is All You Need" for more details.
|
326 |
+
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
327 |
+
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
328 |
+
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
329 |
+
See "Attention Is All You Need" for more details.
|
330 |
+
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
331 |
+
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
332 |
+
Binary and byte masks are supported.
|
333 |
+
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
334 |
+
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
335 |
+
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
336 |
+
Default: ``True``.
|
337 |
+
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
338 |
+
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
339 |
+
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
340 |
+
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
341 |
+
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
342 |
+
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
|
343 |
+
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
344 |
+
the attention weight.
|
345 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
346 |
+
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
347 |
+
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
348 |
+
|
349 |
+
Outputs:
|
350 |
+
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
351 |
+
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
352 |
+
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
353 |
+
embedding dimension ``embed_dim``.
|
354 |
+
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
355 |
+
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
356 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
357 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
358 |
+
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
359 |
+
|
360 |
+
.. note::
|
361 |
+
`batch_first` argument is ignored for unbatched inputs.
|
362 |
+
"""
|
363 |
+
is_batched = query.dim() == 3
|
364 |
+
if key_padding_mask is not None:
|
365 |
+
_kpm_dtype = key_padding_mask.dtype
|
366 |
+
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
367 |
+
key_padding_mask
|
368 |
+
):
|
369 |
+
raise AssertionError(
|
370 |
+
"only bool and floating types of key_padding_mask are supported"
|
371 |
+
)
|
372 |
+
why_not_fast_path = ""
|
373 |
+
if not is_batched:
|
374 |
+
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
375 |
+
elif query is not key or key is not value:
|
376 |
+
# When lifting this restriction, don't forget to either
|
377 |
+
# enforce that the dtypes all match or test cases where
|
378 |
+
# they don't!
|
379 |
+
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
380 |
+
elif (
|
381 |
+
self.in_proj_bias is not None
|
382 |
+
and query.dtype != self.in_proj_bias.dtype
|
383 |
+
):
|
384 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
385 |
+
elif (
|
386 |
+
self.in_proj_weight is not None
|
387 |
+
and query.dtype != self.in_proj_weight.dtype
|
388 |
+
):
|
389 |
+
# this case will fail anyway, but at least they'll get a useful error message.
|
390 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
391 |
+
elif self.training:
|
392 |
+
why_not_fast_path = "training is enabled"
|
393 |
+
elif not self.batch_first:
|
394 |
+
why_not_fast_path = "batch_first was not True"
|
395 |
+
elif self.bias_k is not None:
|
396 |
+
why_not_fast_path = "self.bias_k was not None"
|
397 |
+
elif self.bias_v is not None:
|
398 |
+
why_not_fast_path = "self.bias_v was not None"
|
399 |
+
elif self.dropout:
|
400 |
+
why_not_fast_path = f"dropout was {self.dropout}, required zero"
|
401 |
+
elif self.add_zero_attn:
|
402 |
+
why_not_fast_path = "add_zero_attn was enabled"
|
403 |
+
elif not self._qkv_same_embed_dim:
|
404 |
+
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
405 |
+
elif attn_mask is not None:
|
406 |
+
why_not_fast_path = "attn_mask was not None"
|
407 |
+
elif query.is_nested and key_padding_mask is not None:
|
408 |
+
why_not_fast_path = (
|
409 |
+
"key_padding_mask is not supported with NestedTensor input"
|
410 |
+
)
|
411 |
+
elif self.num_heads % 2 == 1:
|
412 |
+
why_not_fast_path = "num_heads is odd"
|
413 |
+
elif torch.is_autocast_enabled():
|
414 |
+
why_not_fast_path = "autocast is enabled"
|
415 |
+
|
416 |
+
if not why_not_fast_path:
|
417 |
+
tensor_args = (
|
418 |
+
query,
|
419 |
+
key,
|
420 |
+
value,
|
421 |
+
self.in_proj_weight,
|
422 |
+
self.in_proj_bias,
|
423 |
+
self.out_proj.weight,
|
424 |
+
self.out_proj.bias,
|
425 |
+
)
|
426 |
+
# We have to use list comprehensions below because TorchScript does not support
|
427 |
+
# generator expressions.
|
428 |
+
if torch.overrides.has_torch_function(tensor_args):
|
429 |
+
why_not_fast_path = "some Tensor argument has_torch_function"
|
430 |
+
elif not all(
|
431 |
+
[
|
432 |
+
(x is None or x.is_cuda or "cpu" in str(x.device))
|
433 |
+
for x in tensor_args
|
434 |
+
]
|
435 |
+
):
|
436 |
+
why_not_fast_path = (
|
437 |
+
"some Tensor argument is neither CUDA nor CPU"
|
438 |
+
)
|
439 |
+
elif torch.is_grad_enabled() and any(
|
440 |
+
[x is not None and x.requires_grad for x in tensor_args]
|
441 |
+
):
|
442 |
+
why_not_fast_path = (
|
443 |
+
"grad is enabled and at least one of query or the "
|
444 |
+
"input/output projection weights or biases requires_grad"
|
445 |
+
)
|
446 |
+
if not why_not_fast_path:
|
447 |
+
return torch._native_multi_head_attention(
|
448 |
+
query,
|
449 |
+
key,
|
450 |
+
value,
|
451 |
+
self.embed_dim,
|
452 |
+
self.num_heads,
|
453 |
+
self.in_proj_weight,
|
454 |
+
self.in_proj_bias,
|
455 |
+
self.out_proj.weight,
|
456 |
+
self.out_proj.bias,
|
457 |
+
key_padding_mask
|
458 |
+
if key_padding_mask is not None
|
459 |
+
else attn_mask,
|
460 |
+
need_weights,
|
461 |
+
average_attn_weights,
|
462 |
+
1
|
463 |
+
if key_padding_mask is not None
|
464 |
+
else 0
|
465 |
+
if attn_mask is not None
|
466 |
+
else None,
|
467 |
+
)
|
468 |
+
|
469 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
470 |
+
assert not any_nested, (
|
471 |
+
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
472 |
+
+ f"The fast path was not hit because {why_not_fast_path}"
|
473 |
+
)
|
474 |
+
|
475 |
+
if self.batch_first and is_batched:
|
476 |
+
# make sure that the transpose op does not affect the "is" property
|
477 |
+
if key is value:
|
478 |
+
if query is key:
|
479 |
+
query = key = value = query.transpose(1, 0)
|
480 |
+
else:
|
481 |
+
query, key = [x.transpose(1, 0) for x in (query, key)]
|
482 |
+
value = key
|
483 |
+
else:
|
484 |
+
query, key, value = [
|
485 |
+
x.transpose(1, 0) for x in (query, key, value)
|
486 |
+
]
|
487 |
+
|
488 |
+
if not self._qkv_same_embed_dim:
|
489 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
490 |
+
query,
|
491 |
+
key,
|
492 |
+
value,
|
493 |
+
self.embed_dim,
|
494 |
+
self.num_heads,
|
495 |
+
self.in_proj_weight,
|
496 |
+
self.in_proj_bias,
|
497 |
+
self.bias_k,
|
498 |
+
self.bias_v,
|
499 |
+
self.add_zero_attn,
|
500 |
+
self.dropout,
|
501 |
+
self.out_proj.weight,
|
502 |
+
self.out_proj.bias,
|
503 |
+
training=self.training,
|
504 |
+
key_padding_mask=key_padding_mask,
|
505 |
+
need_weights=need_weights,
|
506 |
+
attn_mask=attn_mask,
|
507 |
+
use_separate_proj_weight=True,
|
508 |
+
q_proj_weight=self.q_proj_weight,
|
509 |
+
k_proj_weight=self.k_proj_weight,
|
510 |
+
v_proj_weight=self.v_proj_weight,
|
511 |
+
average_attn_weights=average_attn_weights,
|
512 |
+
)
|
513 |
+
else:
|
514 |
+
# re-write the self.attention here, to get k, v cache
|
515 |
+
tgt_len, bsz, embed_dim = query.shape
|
516 |
+
src_len, _, _ = key.shape
|
517 |
+
num_heads = self.num_heads
|
518 |
+
key_padding_mask = _canonical_mask(
|
519 |
+
mask=key_padding_mask,
|
520 |
+
mask_name="key_padding_mask",
|
521 |
+
other_type=_none_or_dtype(attn_mask),
|
522 |
+
other_name="attn_mask",
|
523 |
+
target_type=query.dtype
|
524 |
+
)
|
525 |
+
attn_mask = _canonical_mask(
|
526 |
+
mask=attn_mask,
|
527 |
+
mask_name="attn_mask",
|
528 |
+
other_type=None,
|
529 |
+
other_name="",
|
530 |
+
target_type=query.dtype,
|
531 |
+
check_other=False,
|
532 |
+
)
|
533 |
+
head_dim = self.embed_dim // self.num_heads
|
534 |
+
assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
|
535 |
+
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
536 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
537 |
+
# k_present, v_present = k, v
|
538 |
+
|
539 |
+
#
|
540 |
+
# reshape q, k, v for multihead attention and make em batch first
|
541 |
+
#
|
542 |
+
|
543 |
+
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
544 |
+
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
545 |
+
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
|
546 |
+
src_len = k.size(1)
|
547 |
+
if past is not None and past.ndim > 2:
|
548 |
+
expected_src_len = src_len + past[0].shape[-2]
|
549 |
+
else:
|
550 |
+
expected_src_len = src_len
|
551 |
+
|
552 |
+
|
553 |
+
# ensure attn_mask's dim is 3
|
554 |
+
if attn_mask.dim() == 2:
|
555 |
+
correct_2d_size = (tgt_len, expected_src_len)
|
556 |
+
if attn_mask.shape != correct_2d_size:
|
557 |
+
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
|
558 |
+
attn_mask = attn_mask.unsqueeze(0)
|
559 |
+
elif attn_mask.dim() == 3:
|
560 |
+
correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
|
561 |
+
if attn_mask.shape != correct_3d_size:
|
562 |
+
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
|
563 |
+
else:
|
564 |
+
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
565 |
+
|
566 |
+
if key_padding_mask is not None:
|
567 |
+
assert key_padding_mask.shape == (bsz, expected_src_len), \
|
568 |
+
f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
|
569 |
+
key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
|
570 |
+
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
|
571 |
+
if attn_mask is None:
|
572 |
+
attn_mask = key_padding_mask
|
573 |
+
else:
|
574 |
+
attn_mask = attn_mask + key_padding_mask
|
575 |
+
|
576 |
+
if not self.training:
|
577 |
+
dropout_p = 0.0
|
578 |
+
else:
|
579 |
+
dropout_p = self.dropout
|
580 |
+
|
581 |
+
if need_weights:
|
582 |
+
raise NotImplementedError("need_weights not implemented for voicecraft")
|
583 |
+
# B, Nt, E = q.shape
|
584 |
+
# q_scaled = q / math.sqrt(E)
|
585 |
+
|
586 |
+
# assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
587 |
+
|
588 |
+
# if attn_mask is not None:
|
589 |
+
# attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
590 |
+
# else:
|
591 |
+
# attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
592 |
+
# attn_output_weights = softmax(attn_output_weights, dim=-1)
|
593 |
+
# if dropout_p > 0.0:
|
594 |
+
# attn_output_weights = dropout(attn_output_weights, p=dropout_p)
|
595 |
+
|
596 |
+
# attn_output = torch.bmm(attn_output_weights, v)
|
597 |
+
|
598 |
+
# attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
599 |
+
# attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
600 |
+
# attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
601 |
+
|
602 |
+
# # optionally average attention weights over heads
|
603 |
+
# attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
604 |
+
# if average_attn_weights:
|
605 |
+
# attn_output_weights = attn_output_weights.mean(dim=1)
|
606 |
+
|
607 |
+
# if not is_batched:
|
608 |
+
# # squeeze the output if input was unbatched
|
609 |
+
# attn_output = attn_output.squeeze(1)
|
610 |
+
# attn_output_weights = attn_output_weights.squeeze(0)
|
611 |
+
# return attn_output, attn_output_weights
|
612 |
+
else:
|
613 |
+
# attn_mask can be either (L,S) or (N*num_heads, L, S)
|
614 |
+
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
|
615 |
+
# in order to match the input for SDPA of (N, num_heads, L, S)
|
616 |
+
if attn_mask is not None:
|
617 |
+
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
|
618 |
+
attn_mask = attn_mask.unsqueeze(0)
|
619 |
+
else:
|
620 |
+
attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
|
621 |
+
|
622 |
+
q = q.view(bsz, num_heads, tgt_len, head_dim)
|
623 |
+
k = k.view(bsz, num_heads, src_len, head_dim)
|
624 |
+
v = v.view(bsz, num_heads, src_len, head_dim)
|
625 |
+
# logging.info(f"shape of past: {past.shape}")
|
626 |
+
if past is not None:
|
627 |
+
present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
|
628 |
+
if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
|
629 |
+
pk, pv = past
|
630 |
+
k = torch.cat([pk, k], dim=-2)
|
631 |
+
v = torch.cat([pv, v], dim=-2)
|
632 |
+
else:
|
633 |
+
present = None
|
634 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
|
635 |
+
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
636 |
+
|
637 |
+
attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
|
638 |
+
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
639 |
+
if not is_batched:
|
640 |
+
# squeeze the output if input was unbatched
|
641 |
+
attn_output = attn_output.squeeze(1)
|
642 |
+
# if self.training:
|
643 |
+
# return attn_output, None
|
644 |
+
# else:
|
645 |
+
# return (attn_output, present), None
|
646 |
+
|
647 |
+
# harded coded, the code do not support returning attn weigths yet
|
648 |
+
attn_output_weights=None
|
649 |
+
if self.batch_first and is_batched:
|
650 |
+
return attn_output.transpose(1, 0), present
|
651 |
+
else:
|
652 |
+
return attn_output, present
|
653 |
+
|
models/modules/embedding.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
|
22 |
+
class TokenEmbedding(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
dim_model: int,
|
26 |
+
vocab_size: int,
|
27 |
+
dropout: float = 0.0,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
self.vocab_size = vocab_size
|
32 |
+
self.dim_model = dim_model
|
33 |
+
|
34 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
35 |
+
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
|
36 |
+
|
37 |
+
@property
|
38 |
+
def weight(self) -> torch.Tensor:
|
39 |
+
return self.word_embeddings.weight
|
40 |
+
|
41 |
+
def embedding(self, index: int) -> torch.Tensor:
|
42 |
+
return self.word_embeddings.weight[index : index + 1]
|
43 |
+
|
44 |
+
def forward(self, x: torch.Tensor):
|
45 |
+
X = self.word_embeddings(x)
|
46 |
+
X = self.dropout(X)
|
47 |
+
|
48 |
+
return X
|
49 |
+
|
50 |
+
|
51 |
+
class SinePositionalEmbedding(nn.Module):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
dim_model: int,
|
55 |
+
dropout: float = 0.0,
|
56 |
+
scale: bool = False,
|
57 |
+
alpha: bool = False,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
self.dim_model = dim_model
|
61 |
+
self.x_scale = math.sqrt(dim_model) if scale else 1.0
|
62 |
+
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
63 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
64 |
+
|
65 |
+
self.reverse = False
|
66 |
+
self.pe = None
|
67 |
+
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
|
68 |
+
|
69 |
+
def extend_pe(self, x):
|
70 |
+
"""Reset the positional encodings."""
|
71 |
+
if self.pe is not None:
|
72 |
+
if self.pe.size(1) >= x.size(1):
|
73 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
74 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
75 |
+
return
|
76 |
+
pe = torch.zeros(x.size(1), self.dim_model)
|
77 |
+
if self.reverse:
|
78 |
+
position = torch.arange(
|
79 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
80 |
+
).unsqueeze(1)
|
81 |
+
else:
|
82 |
+
position = torch.arange(
|
83 |
+
0, x.size(1), dtype=torch.float32
|
84 |
+
).unsqueeze(1)
|
85 |
+
div_term = torch.exp(
|
86 |
+
torch.arange(0, self.dim_model, 2, dtype=torch.float32)
|
87 |
+
* -(math.log(10000.0) / self.dim_model)
|
88 |
+
)
|
89 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
90 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
91 |
+
pe = pe.unsqueeze(0)
|
92 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
|
93 |
+
|
94 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
95 |
+
self.extend_pe(x)
|
96 |
+
output = x.unsqueeze(-1) if x.ndim == 2 else x
|
97 |
+
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
|
98 |
+
return self.dropout(output)
|
models/modules/sampling.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
def top_k_top_p_filtering(
|
5 |
+
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
6 |
+
):
|
7 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
8 |
+
Args:
|
9 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
10 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
11 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
12 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
13 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
14 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
15 |
+
"""
|
16 |
+
if top_k > 0:
|
17 |
+
top_k = min(
|
18 |
+
max(top_k, min_tokens_to_keep), logits.size(-1)
|
19 |
+
) # Safety check
|
20 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
21 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
22 |
+
logits[indices_to_remove] = filter_value
|
23 |
+
|
24 |
+
if top_p < 1.0:
|
25 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
26 |
+
cumulative_probs = torch.cumsum(
|
27 |
+
F.softmax(sorted_logits, dim=-1), dim=-1
|
28 |
+
)
|
29 |
+
|
30 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
31 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
32 |
+
if min_tokens_to_keep > 1:
|
33 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
34 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
35 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
36 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
37 |
+
..., :-1
|
38 |
+
].clone()
|
39 |
+
sorted_indices_to_remove[..., 0] = 0
|
40 |
+
|
41 |
+
# scatter sorted tensors to original indexing
|
42 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
43 |
+
1, sorted_indices, sorted_indices_to_remove
|
44 |
+
)
|
45 |
+
logits[indices_to_remove] = filter_value
|
46 |
+
return logits
|
47 |
+
|
48 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
49 |
+
# temperature: (`optional`) float
|
50 |
+
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
51 |
+
# top_k: (`optional`) int
|
52 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
53 |
+
# top_p: (`optional`) float
|
54 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
55 |
+
|
56 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
57 |
+
if temperature != 1.0:
|
58 |
+
logits = logits / temperature
|
59 |
+
# Top-p/top-k filtering
|
60 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
61 |
+
# Sample
|
62 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
63 |
+
return token
|
models/modules/scaling.py
ADDED
@@ -0,0 +1,1406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
import collections
|
20 |
+
import logging
|
21 |
+
import random
|
22 |
+
import math
|
23 |
+
from functools import reduce
|
24 |
+
from itertools import repeat
|
25 |
+
from typing import Optional, Tuple, Union
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torch import Tensor
|
31 |
+
from torch.nn import Embedding as ScaledEmbedding
|
32 |
+
|
33 |
+
# from valle.utils import Transpose
|
34 |
+
|
35 |
+
class Transpose(nn.Identity):
|
36 |
+
"""(N, T, D) -> (N, D, T)"""
|
37 |
+
|
38 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
39 |
+
return input.transpose(1, 2)
|
40 |
+
|
41 |
+
class ActivationBalancerFunction(torch.autograd.Function):
|
42 |
+
@staticmethod
|
43 |
+
def forward(
|
44 |
+
ctx,
|
45 |
+
x: Tensor,
|
46 |
+
scale_factor: Tensor,
|
47 |
+
sign_factor: Optional[Tensor],
|
48 |
+
channel_dim: int,
|
49 |
+
) -> Tensor:
|
50 |
+
if channel_dim < 0:
|
51 |
+
channel_dim += x.ndim
|
52 |
+
ctx.channel_dim = channel_dim
|
53 |
+
xgt0 = x > 0
|
54 |
+
if sign_factor is None:
|
55 |
+
ctx.save_for_backward(xgt0, scale_factor)
|
56 |
+
else:
|
57 |
+
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
|
58 |
+
return x
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
62 |
+
if len(ctx.saved_tensors) == 3:
|
63 |
+
xgt0, scale_factor, sign_factor = ctx.saved_tensors
|
64 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
65 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
66 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
67 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
68 |
+
else:
|
69 |
+
xgt0, scale_factor = ctx.saved_tensors
|
70 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
71 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
72 |
+
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
73 |
+
neg_delta_grad = x_grad.abs() * factor
|
74 |
+
return (
|
75 |
+
x_grad - neg_delta_grad,
|
76 |
+
None,
|
77 |
+
None,
|
78 |
+
None,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def _compute_scale_factor(
|
83 |
+
x: Tensor,
|
84 |
+
channel_dim: int,
|
85 |
+
min_abs: float,
|
86 |
+
max_abs: float,
|
87 |
+
gain_factor: float,
|
88 |
+
max_factor: float,
|
89 |
+
) -> Tensor:
|
90 |
+
if channel_dim < 0:
|
91 |
+
channel_dim += x.ndim
|
92 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
93 |
+
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
94 |
+
|
95 |
+
if min_abs == 0.0:
|
96 |
+
below_threshold = 0.0
|
97 |
+
else:
|
98 |
+
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
99 |
+
# x_abs)_mean , min_abs.
|
100 |
+
below_threshold = (
|
101 |
+
(min_abs - x_abs_mean) * (gain_factor / min_abs)
|
102 |
+
).clamp(min=0, max=max_factor)
|
103 |
+
|
104 |
+
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
105 |
+
min=0, max=max_factor
|
106 |
+
)
|
107 |
+
|
108 |
+
return below_threshold - above_threshold
|
109 |
+
|
110 |
+
|
111 |
+
def _compute_sign_factor(
|
112 |
+
x: Tensor,
|
113 |
+
channel_dim: int,
|
114 |
+
min_positive: float,
|
115 |
+
max_positive: float,
|
116 |
+
gain_factor: float,
|
117 |
+
max_factor: float,
|
118 |
+
) -> Tensor:
|
119 |
+
if channel_dim < 0:
|
120 |
+
channel_dim += x.ndim
|
121 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
122 |
+
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
|
123 |
+
if min_positive == 0.0:
|
124 |
+
factor1 = 0.0
|
125 |
+
else:
|
126 |
+
# 0 if proportion_positive >= min_positive, else can be
|
127 |
+
# as large as max_factor.
|
128 |
+
factor1 = (
|
129 |
+
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
130 |
+
).clamp_(min=0, max=max_factor)
|
131 |
+
|
132 |
+
if max_positive == 1.0:
|
133 |
+
factor2 = 0.0
|
134 |
+
else:
|
135 |
+
# 0 if self.proportion_positive <= max_positive, else can be
|
136 |
+
# as large as -max_factor.
|
137 |
+
factor2 = (
|
138 |
+
(proportion_positive - max_positive)
|
139 |
+
* (gain_factor / (1.0 - max_positive))
|
140 |
+
).clamp_(min=0, max=max_factor)
|
141 |
+
sign_factor = factor1 - factor2
|
142 |
+
# require min_positive != 0 or max_positive != 1:
|
143 |
+
assert not isinstance(sign_factor, float)
|
144 |
+
return sign_factor
|
145 |
+
|
146 |
+
|
147 |
+
class ActivationScaleBalancerFunction(torch.autograd.Function):
|
148 |
+
"""
|
149 |
+
This object is used in class ActivationBalancer when the user specified
|
150 |
+
min_positive=0, max_positive=1, so there are no constraints on the signs
|
151 |
+
of the activations and only the absolute value has a constraint.
|
152 |
+
"""
|
153 |
+
|
154 |
+
@staticmethod
|
155 |
+
def forward(
|
156 |
+
ctx,
|
157 |
+
x: Tensor,
|
158 |
+
sign_factor: Tensor,
|
159 |
+
scale_factor: Tensor,
|
160 |
+
channel_dim: int,
|
161 |
+
) -> Tensor:
|
162 |
+
if channel_dim < 0:
|
163 |
+
channel_dim += x.ndim
|
164 |
+
ctx.channel_dim = channel_dim
|
165 |
+
xgt0 = x > 0
|
166 |
+
ctx.save_for_backward(xgt0, sign_factor, scale_factor)
|
167 |
+
return x
|
168 |
+
|
169 |
+
@staticmethod
|
170 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
171 |
+
xgt0, sign_factor, scale_factor = ctx.saved_tensors
|
172 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
173 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
174 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
175 |
+
|
176 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
177 |
+
neg_delta_grad = x_grad.abs() * factor
|
178 |
+
return (
|
179 |
+
x_grad - neg_delta_grad,
|
180 |
+
None,
|
181 |
+
None,
|
182 |
+
None,
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
class RandomClampFunction(torch.autograd.Function):
|
187 |
+
@staticmethod
|
188 |
+
def forward(
|
189 |
+
ctx,
|
190 |
+
x: Tensor,
|
191 |
+
min: Optional[float],
|
192 |
+
max: Optional[float],
|
193 |
+
prob: float,
|
194 |
+
reflect: float,
|
195 |
+
) -> Tensor:
|
196 |
+
x_clamped = torch.clamp(x, min=min, max=max)
|
197 |
+
mask = torch.rand_like(x) < prob
|
198 |
+
ans = torch.where(mask, x_clamped, x)
|
199 |
+
if x.requires_grad:
|
200 |
+
ctx.save_for_backward(ans == x)
|
201 |
+
ctx.reflect = reflect
|
202 |
+
if reflect != 0.0:
|
203 |
+
ans = ans * (1.0 + reflect) - (x * reflect)
|
204 |
+
return ans
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def backward(
|
208 |
+
ctx, ans_grad: Tensor
|
209 |
+
) -> Tuple[Tensor, None, None, None, None]:
|
210 |
+
(is_same,) = ctx.saved_tensors
|
211 |
+
x_grad = ans_grad * is_same.to(ans_grad.dtype)
|
212 |
+
reflect = ctx.reflect
|
213 |
+
if reflect != 0.0:
|
214 |
+
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
|
215 |
+
return x_grad, None, None, None, None
|
216 |
+
|
217 |
+
|
218 |
+
def random_clamp(
|
219 |
+
x: Tensor,
|
220 |
+
min: Optional[float] = None,
|
221 |
+
max: Optional[float] = None,
|
222 |
+
prob: float = 0.5,
|
223 |
+
reflect: float = 0.0,
|
224 |
+
):
|
225 |
+
return RandomClampFunction.apply(x, min, max, prob, reflect)
|
226 |
+
|
227 |
+
|
228 |
+
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
|
229 |
+
"""
|
230 |
+
A randomized way of casting a floating point value to half precision.
|
231 |
+
"""
|
232 |
+
if x.dtype == torch.float16:
|
233 |
+
return x
|
234 |
+
x_abs = x.abs()
|
235 |
+
is_too_small = x_abs < min_abs
|
236 |
+
# for elements where is_too_small is true, random_val will contain +-min_abs with
|
237 |
+
# probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
|
238 |
+
# for those elements].
|
239 |
+
random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
|
240 |
+
return torch.where(is_too_small, random_val, x).to(torch.float16)
|
241 |
+
|
242 |
+
|
243 |
+
class RandomGradFunction(torch.autograd.Function):
|
244 |
+
"""
|
245 |
+
Does nothing in forward pass; in backward pass, gets rid of very small grads using
|
246 |
+
randomized approach that preserves expectations (intended to reduce roundoff).
|
247 |
+
"""
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
|
251 |
+
ctx.min_abs = min_abs
|
252 |
+
return x
|
253 |
+
|
254 |
+
@staticmethod
|
255 |
+
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
|
256 |
+
if ans_grad.dtype == torch.float16:
|
257 |
+
return (
|
258 |
+
random_cast_to_half(
|
259 |
+
ans_grad.to(torch.float32), min_abs=ctx.min_abs
|
260 |
+
),
|
261 |
+
None,
|
262 |
+
)
|
263 |
+
else:
|
264 |
+
return ans_grad, None
|
265 |
+
|
266 |
+
|
267 |
+
class RandomGrad(torch.nn.Module):
|
268 |
+
"""
|
269 |
+
Gets rid of very small gradients using an expectation-preserving method, intended to increase
|
270 |
+
accuracy of training when using amp (automatic mixed precision)
|
271 |
+
"""
|
272 |
+
|
273 |
+
def __init__(self, min_abs: float = 5.0e-06):
|
274 |
+
super(RandomGrad, self).__init__()
|
275 |
+
self.min_abs = min_abs
|
276 |
+
|
277 |
+
def forward(self, x: Tensor):
|
278 |
+
if (
|
279 |
+
torch.jit.is_scripting()
|
280 |
+
or not self.training
|
281 |
+
or torch.jit.is_tracing()
|
282 |
+
):
|
283 |
+
return x
|
284 |
+
else:
|
285 |
+
return RandomGradFunction.apply(x, self.min_abs)
|
286 |
+
|
287 |
+
|
288 |
+
class SoftmaxFunction(torch.autograd.Function):
|
289 |
+
"""
|
290 |
+
Tries to handle half-precision derivatives in a randomized way that should
|
291 |
+
be more accurate for training than the default behavior.
|
292 |
+
"""
|
293 |
+
|
294 |
+
@staticmethod
|
295 |
+
def forward(ctx, x: Tensor, dim: int):
|
296 |
+
ans = x.softmax(dim=dim)
|
297 |
+
# if x dtype is float16, x.softmax() returns a float32 because
|
298 |
+
# (presumably) that op does not support float16, and autocast
|
299 |
+
# is enabled.
|
300 |
+
if torch.is_autocast_enabled():
|
301 |
+
ans = ans.to(torch.float16)
|
302 |
+
ctx.save_for_backward(ans)
|
303 |
+
ctx.x_dtype = x.dtype
|
304 |
+
ctx.dim = dim
|
305 |
+
return ans
|
306 |
+
|
307 |
+
@staticmethod
|
308 |
+
def backward(ctx, ans_grad: Tensor):
|
309 |
+
(ans,) = ctx.saved_tensors
|
310 |
+
with torch.cuda.amp.autocast(enabled=False):
|
311 |
+
ans_grad = ans_grad.to(torch.float32)
|
312 |
+
ans = ans.to(torch.float32)
|
313 |
+
x_grad = ans_grad * ans
|
314 |
+
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
315 |
+
return x_grad, None
|
316 |
+
|
317 |
+
|
318 |
+
def softmax(x: Tensor, dim: int):
|
319 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
320 |
+
return x.softmax(dim)
|
321 |
+
|
322 |
+
return SoftmaxFunction.apply(x, dim)
|
323 |
+
|
324 |
+
|
325 |
+
class MaxEigLimiterFunction(torch.autograd.Function):
|
326 |
+
@staticmethod
|
327 |
+
def forward(
|
328 |
+
ctx,
|
329 |
+
x: Tensor,
|
330 |
+
coeffs: Tensor,
|
331 |
+
direction: Tensor,
|
332 |
+
channel_dim: int,
|
333 |
+
grad_scale: float,
|
334 |
+
) -> Tensor:
|
335 |
+
ctx.channel_dim = channel_dim
|
336 |
+
ctx.grad_scale = grad_scale
|
337 |
+
ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
|
338 |
+
return x
|
339 |
+
|
340 |
+
@staticmethod
|
341 |
+
def backward(ctx, x_grad, *args):
|
342 |
+
with torch.enable_grad():
|
343 |
+
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
344 |
+
x_orig.requires_grad = True
|
345 |
+
num_channels = x_orig.shape[ctx.channel_dim]
|
346 |
+
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
347 |
+
new_direction.requires_grad = False
|
348 |
+
x = x - x.mean(dim=0)
|
349 |
+
x_var = (x ** 2).mean()
|
350 |
+
x_residual = x - coeffs * new_direction
|
351 |
+
x_residual_var = (x_residual ** 2).mean()
|
352 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
353 |
+
# by the top eigen-direction. This is to be minimized.
|
354 |
+
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
355 |
+
variance_proportion.backward()
|
356 |
+
x_orig_grad = x_orig.grad
|
357 |
+
x_extra_grad = (
|
358 |
+
x_orig.grad
|
359 |
+
* ctx.grad_scale
|
360 |
+
* x_grad.norm()
|
361 |
+
/ (x_orig_grad.norm() + 1.0e-20)
|
362 |
+
)
|
363 |
+
return x_grad + x_extra_grad.detach(), None, None, None, None
|
364 |
+
|
365 |
+
|
366 |
+
class BasicNorm(torch.nn.Module):
|
367 |
+
"""
|
368 |
+
This is intended to be a simpler, and hopefully cheaper, replacement for
|
369 |
+
LayerNorm. The observation this is based on, is that Transformer-type
|
370 |
+
networks, especially with pre-norm, sometimes seem to set one of the
|
371 |
+
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
372 |
+
the LayerNorm because the output magnitude is then not strongly dependent
|
373 |
+
on the other (useful) features. Presumably the weight and bias of the
|
374 |
+
LayerNorm are required to allow it to do this.
|
375 |
+
|
376 |
+
So the idea is to introduce this large constant value as an explicit
|
377 |
+
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
378 |
+
doesn't have to do this trick. We make the "eps" learnable.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
num_channels: the number of channels, e.g. 512.
|
382 |
+
channel_dim: the axis/dimension corresponding to the channel,
|
383 |
+
interprted as an offset from the input's ndim if negative.
|
384 |
+
shis is NOT the num_channels; it should typically be one of
|
385 |
+
{-2, -1, 0, 1, 2, 3}.
|
386 |
+
eps: the initial "epsilon" that we add as ballast in:
|
387 |
+
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
388 |
+
Note: our epsilon is actually large, but we keep the name
|
389 |
+
to indicate the connection with conventional LayerNorm.
|
390 |
+
learn_eps: if true, we learn epsilon; if false, we keep it
|
391 |
+
at the initial value.
|
392 |
+
eps_min: float
|
393 |
+
eps_max: float
|
394 |
+
"""
|
395 |
+
|
396 |
+
def __init__(
|
397 |
+
self,
|
398 |
+
num_channels: int,
|
399 |
+
channel_dim: int = -1, # CAUTION: see documentation.
|
400 |
+
eps: float = 0.25,
|
401 |
+
learn_eps: bool = True,
|
402 |
+
eps_min: float = -3.0,
|
403 |
+
eps_max: float = 3.0,
|
404 |
+
) -> None:
|
405 |
+
super(BasicNorm, self).__init__()
|
406 |
+
self.num_channels = num_channels
|
407 |
+
self.channel_dim = channel_dim
|
408 |
+
if learn_eps:
|
409 |
+
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
410 |
+
else:
|
411 |
+
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
412 |
+
self.eps_min = eps_min
|
413 |
+
self.eps_max = eps_max
|
414 |
+
|
415 |
+
def forward(self, x: Tensor) -> Tensor:
|
416 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
417 |
+
eps = self.eps
|
418 |
+
if self.training and random.random() < 0.25:
|
419 |
+
# with probability 0.25, in training mode, clamp eps between the min
|
420 |
+
# and max; this will encourage it to learn parameters within the
|
421 |
+
# allowed range by making parameters that are outside the allowed
|
422 |
+
# range noisy.
|
423 |
+
|
424 |
+
# gradients to allow the parameter to get back into the allowed region if it happens to exit it.
|
425 |
+
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
|
426 |
+
scales = (
|
427 |
+
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
|
428 |
+
) ** -0.5
|
429 |
+
return x * scales
|
430 |
+
|
431 |
+
|
432 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
|
433 |
+
"""
|
434 |
+
Behaves like a constructor of a modified version of nn.Linear
|
435 |
+
that gives an easy way to set the default initial parameter scale.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
439 |
+
e.g. in_features, out_features, bias=False.
|
440 |
+
|
441 |
+
initial_scale: you can override this if you want to increase
|
442 |
+
or decrease the initial magnitude of the module's output
|
443 |
+
(affects the initialization of weight_scale and bias_scale).
|
444 |
+
Another option, if you want to do something like this, is
|
445 |
+
to re-initialize the parameters.
|
446 |
+
"""
|
447 |
+
ans = nn.Linear(*args, **kwargs)
|
448 |
+
with torch.no_grad():
|
449 |
+
ans.weight[:] *= initial_scale
|
450 |
+
if ans.bias is not None:
|
451 |
+
torch.nn.init.uniform_(
|
452 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
453 |
+
)
|
454 |
+
return ans
|
455 |
+
|
456 |
+
|
457 |
+
def ScaledConv1d(
|
458 |
+
*args,
|
459 |
+
initial_scale: float = 1.0,
|
460 |
+
kernel_size: int = 3,
|
461 |
+
padding: str = "same",
|
462 |
+
**kwargs,
|
463 |
+
) -> nn.Conv1d:
|
464 |
+
"""
|
465 |
+
Behaves like a constructor of a modified version of nn.Conv1d
|
466 |
+
that gives an easy way to set the default initial parameter scale.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
470 |
+
e.g. in_features, out_features, bias=False.
|
471 |
+
|
472 |
+
initial_scale: you can override this if you want to increase
|
473 |
+
or decrease the initial magnitude of the module's output
|
474 |
+
(affects the initialization of weight_scale and bias_scale).
|
475 |
+
Another option, if you want to do something like this, is
|
476 |
+
to re-initialize the parameters.
|
477 |
+
"""
|
478 |
+
ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
|
479 |
+
with torch.no_grad():
|
480 |
+
ans.weight[:] *= initial_scale
|
481 |
+
if ans.bias is not None:
|
482 |
+
torch.nn.init.uniform_(
|
483 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
484 |
+
)
|
485 |
+
return ans
|
486 |
+
|
487 |
+
|
488 |
+
def TransposeScaledConv1d(
|
489 |
+
*args,
|
490 |
+
initial_scale: float = 1.0,
|
491 |
+
kernel_size: int = 3,
|
492 |
+
padding: str = "same",
|
493 |
+
**kwargs,
|
494 |
+
) -> nn.Sequential:
|
495 |
+
"""
|
496 |
+
Transpose -> ScaledConv1d
|
497 |
+
"""
|
498 |
+
return nn.Sequential(
|
499 |
+
Transpose(),
|
500 |
+
ScaledConv1d(
|
501 |
+
*args,
|
502 |
+
initial_scale=initial_scale,
|
503 |
+
kernel_size=kernel_size,
|
504 |
+
padding=padding,
|
505 |
+
**kwargs,
|
506 |
+
),
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
def ScaledConv1dTranspose(
|
511 |
+
*args,
|
512 |
+
initial_scale: float = 1.0,
|
513 |
+
kernel_size: int = 3,
|
514 |
+
padding: str = "same",
|
515 |
+
**kwargs,
|
516 |
+
) -> nn.Sequential:
|
517 |
+
"""
|
518 |
+
Transpose -> ScaledConv1d
|
519 |
+
"""
|
520 |
+
return nn.Sequential(
|
521 |
+
ScaledConv1d(
|
522 |
+
*args,
|
523 |
+
initial_scale=initial_scale,
|
524 |
+
kernel_size=kernel_size,
|
525 |
+
padding=padding,
|
526 |
+
**kwargs,
|
527 |
+
),
|
528 |
+
Transpose(),
|
529 |
+
)
|
530 |
+
|
531 |
+
|
532 |
+
def TransposeConv1d(
|
533 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
534 |
+
) -> nn.Sequential:
|
535 |
+
"""
|
536 |
+
Transpose -> Conv1d
|
537 |
+
"""
|
538 |
+
return nn.Sequential(
|
539 |
+
Transpose(),
|
540 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
541 |
+
)
|
542 |
+
|
543 |
+
|
544 |
+
def Conv1dTranspose(
|
545 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
546 |
+
) -> nn.Sequential:
|
547 |
+
"""
|
548 |
+
ScaledConv1d -> Transpose
|
549 |
+
"""
|
550 |
+
return nn.Sequential(
|
551 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
552 |
+
Transpose(),
|
553 |
+
)
|
554 |
+
|
555 |
+
|
556 |
+
class SRLinear(nn.Linear):
|
557 |
+
"""https://arxiv.org/abs/2303.06296
|
558 |
+
Stabilizing Transformer Training by Preventing Attention Entropy Collapse
|
559 |
+
"""
|
560 |
+
|
561 |
+
def __init__(self, in_features, out_features, bias=True, **kwargs):
|
562 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
563 |
+
self.register_buffer(
|
564 |
+
"u", nn.functional.normalize(torch.randn(in_features), dim=0)
|
565 |
+
)
|
566 |
+
with torch.no_grad():
|
567 |
+
sigma = self.get_sigma()
|
568 |
+
self.register_buffer("spectral_norm", sigma)
|
569 |
+
self.sigma = nn.Parameter(torch.ones(1))
|
570 |
+
|
571 |
+
def get_sigma(self):
|
572 |
+
with torch.no_grad():
|
573 |
+
u = self.u
|
574 |
+
v = self.weight.mv(u)
|
575 |
+
v = nn.functional.normalize(v, dim=0)
|
576 |
+
u = self.weight.T.mv(v)
|
577 |
+
u = nn.functional.normalize(u, dim=0)
|
578 |
+
self.u.data.copy_(u)
|
579 |
+
return torch.einsum("c,cd,d->", v, self.weight, u)
|
580 |
+
|
581 |
+
def get_weight(self):
|
582 |
+
sigma = self.get_sigma()
|
583 |
+
if self.training:
|
584 |
+
self.spectral_norm.data.copy_(sigma)
|
585 |
+
weight = (self.sigma / sigma) * self.weight
|
586 |
+
return weight
|
587 |
+
|
588 |
+
def forward(self, x):
|
589 |
+
return nn.functional.linear(x, self.get_weight(), self.bias)
|
590 |
+
|
591 |
+
|
592 |
+
class SRConv1d(SRLinear):
|
593 |
+
def __init__(
|
594 |
+
self,
|
595 |
+
in_features,
|
596 |
+
out_features,
|
597 |
+
kernel_size,
|
598 |
+
stride: int = 1,
|
599 |
+
padding: str = "same",
|
600 |
+
bias: bool = True,
|
601 |
+
**kwargs,
|
602 |
+
):
|
603 |
+
in_features = in_features * kernel_size
|
604 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
605 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
606 |
+
self.kernel_size = kernel_size
|
607 |
+
self.stride = stride
|
608 |
+
self.padding = padding
|
609 |
+
|
610 |
+
def forward(self, x):
|
611 |
+
in_features = self.in_features // self.kernel_size
|
612 |
+
weight = self.get_weight().view(
|
613 |
+
self.out_features, in_features, self.kernel_size
|
614 |
+
)
|
615 |
+
return nn.functional.conv1d(
|
616 |
+
x, weight, bias=self.bias, stride=self.stride, padding=self.padding
|
617 |
+
)
|
618 |
+
|
619 |
+
|
620 |
+
def TransposeSRConv1d(
|
621 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
622 |
+
) -> nn.Sequential:
|
623 |
+
"""
|
624 |
+
Transpose -> SRConv1d
|
625 |
+
"""
|
626 |
+
return nn.Sequential(
|
627 |
+
Transpose(),
|
628 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
629 |
+
)
|
630 |
+
|
631 |
+
|
632 |
+
def SRConv1dTranspose(
|
633 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
634 |
+
) -> nn.Sequential:
|
635 |
+
"""
|
636 |
+
SRConv1d -> Transpose
|
637 |
+
"""
|
638 |
+
return nn.Sequential(
|
639 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
640 |
+
Transpose(),
|
641 |
+
)
|
642 |
+
|
643 |
+
|
644 |
+
class ActivationBalancer(torch.nn.Module):
|
645 |
+
"""
|
646 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
647 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
648 |
+
time. It does this by multiplying negative derivative values by up to
|
649 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
650 |
+
interpolated from 1 at the threshold to those extremal values when none
|
651 |
+
of the inputs are positive.
|
652 |
+
|
653 |
+
Args:
|
654 |
+
num_channels: the number of channels
|
655 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
656 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
657 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
658 |
+
that (x > 0), below which we start to modify the derivatives.
|
659 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
660 |
+
that (x > 0), above which we start to modify the derivatives.
|
661 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
662 |
+
either the sign constraint or the magnitude constraint;
|
663 |
+
e.g. with max_factor=0.02, the the derivatives would be multiplied by
|
664 |
+
values in the range [0.98..1.02].
|
665 |
+
sign_gain_factor: determines the 'gain' with which we increase the
|
666 |
+
change in gradient once the constraints on min_positive and max_positive
|
667 |
+
are violated.
|
668 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
669 |
+
change in gradient once the constraints on min_abs and max_abs
|
670 |
+
are violated.
|
671 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
672 |
+
value per channel, which we allow, before we start to modify
|
673 |
+
the derivatives to prevent this.
|
674 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
675 |
+
value per channel, which we allow, before we start to modify
|
676 |
+
the derivatives to prevent this.
|
677 |
+
min_prob: determines the minimum probability with which we modify the
|
678 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
679 |
+
on each forward(). This is done randomly to prevent all layers
|
680 |
+
from doing it at the same time. Early in training we may use
|
681 |
+
higher probabilities than this; it will decay to this value.
|
682 |
+
"""
|
683 |
+
|
684 |
+
def __init__(
|
685 |
+
self,
|
686 |
+
num_channels: int,
|
687 |
+
channel_dim: int,
|
688 |
+
min_positive: float = 0.05,
|
689 |
+
max_positive: float = 0.95,
|
690 |
+
max_factor: float = 0.04,
|
691 |
+
sign_gain_factor: float = 0.01,
|
692 |
+
scale_gain_factor: float = 0.02,
|
693 |
+
min_abs: float = 0.2,
|
694 |
+
max_abs: float = 100.0,
|
695 |
+
min_prob: float = 0.1,
|
696 |
+
):
|
697 |
+
super(ActivationBalancer, self).__init__()
|
698 |
+
self.num_channels = num_channels
|
699 |
+
self.channel_dim = channel_dim
|
700 |
+
self.min_positive = min_positive
|
701 |
+
self.max_positive = max_positive
|
702 |
+
self.max_factor = max_factor
|
703 |
+
self.min_abs = min_abs
|
704 |
+
self.max_abs = max_abs
|
705 |
+
self.min_prob = min_prob
|
706 |
+
self.sign_gain_factor = sign_gain_factor
|
707 |
+
self.scale_gain_factor = scale_gain_factor
|
708 |
+
|
709 |
+
# count measures how many times the forward() function has been called.
|
710 |
+
# We occasionally sync this to a tensor called `count`, that exists to
|
711 |
+
# make sure it is synced to disk when we load and save the model.
|
712 |
+
self.cpu_count = 0
|
713 |
+
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
714 |
+
|
715 |
+
def forward(self, x: Tensor) -> Tensor:
|
716 |
+
if (
|
717 |
+
torch.jit.is_scripting()
|
718 |
+
or not x.requires_grad
|
719 |
+
or torch.jit.is_tracing()
|
720 |
+
):
|
721 |
+
return _no_op(x)
|
722 |
+
|
723 |
+
count = self.cpu_count
|
724 |
+
self.cpu_count += 1
|
725 |
+
|
726 |
+
if random.random() < 0.01:
|
727 |
+
# Occasionally sync self.cpu_count with self.count.
|
728 |
+
# count affects the decay of 'prob'. don't do this on every iter,
|
729 |
+
# because syncing with the GPU is slow.
|
730 |
+
self.cpu_count = max(self.cpu_count, self.count.item())
|
731 |
+
self.count.fill_(self.cpu_count)
|
732 |
+
|
733 |
+
# the prob of doing some work exponentially decreases from 0.5 till it hits
|
734 |
+
# a floor at min_prob (==0.1, by default)
|
735 |
+
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
|
736 |
+
|
737 |
+
if random.random() < prob:
|
738 |
+
sign_gain_factor = 0.5
|
739 |
+
if self.min_positive != 0.0 or self.max_positive != 1.0:
|
740 |
+
sign_factor = _compute_sign_factor(
|
741 |
+
x,
|
742 |
+
self.channel_dim,
|
743 |
+
self.min_positive,
|
744 |
+
self.max_positive,
|
745 |
+
gain_factor=self.sign_gain_factor / prob,
|
746 |
+
max_factor=self.max_factor,
|
747 |
+
)
|
748 |
+
else:
|
749 |
+
sign_factor = None
|
750 |
+
|
751 |
+
scale_factor = _compute_scale_factor(
|
752 |
+
x.detach(),
|
753 |
+
self.channel_dim,
|
754 |
+
min_abs=self.min_abs,
|
755 |
+
max_abs=self.max_abs,
|
756 |
+
gain_factor=self.scale_gain_factor / prob,
|
757 |
+
max_factor=self.max_factor,
|
758 |
+
)
|
759 |
+
return ActivationBalancerFunction.apply(
|
760 |
+
x,
|
761 |
+
scale_factor,
|
762 |
+
sign_factor,
|
763 |
+
self.channel_dim,
|
764 |
+
)
|
765 |
+
else:
|
766 |
+
return _no_op(x)
|
767 |
+
|
768 |
+
|
769 |
+
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
|
770 |
+
"""
|
771 |
+
Returns x unmodified, but in backprop will put a penalty for the excess of
|
772 |
+
the absolute values of elements of x over the limit "limit". E.g. if
|
773 |
+
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
774 |
+
|
775 |
+
Caution: the value of this penalty will be affected by grad scaling used
|
776 |
+
in automatic mixed precision training. For this reasons we use this,
|
777 |
+
it shouldn't really matter, or may even be helpful; we just use this
|
778 |
+
to disallow really implausible values of scores to be given to softmax.
|
779 |
+
"""
|
780 |
+
x_sign = x.sign()
|
781 |
+
over_limit = (x.abs() - limit) > 0
|
782 |
+
# The following is a memory efficient way to penalize the absolute values of
|
783 |
+
# x that's over the limit. (The memory efficiency comes when you think
|
784 |
+
# about which items torch needs to cache for the autograd, and which ones it
|
785 |
+
# can throw away). The numerical value of aux_loss as computed here will
|
786 |
+
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
787 |
+
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
788 |
+
# limit).relu().
|
789 |
+
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
790 |
+
# note: we don't do sum() here on aux)_loss, but it's as if we had done
|
791 |
+
# sum() due to how with_loss() works.
|
792 |
+
x = with_loss(x, aux_loss)
|
793 |
+
# you must use x for something, or this will be ineffective.
|
794 |
+
return x
|
795 |
+
|
796 |
+
|
797 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
798 |
+
if x.ndim == 2:
|
799 |
+
return x.diag()
|
800 |
+
else:
|
801 |
+
(batch, dim, dim) = x.shape
|
802 |
+
x = x.reshape(batch, dim * dim)
|
803 |
+
x = x[:, :: dim + 1]
|
804 |
+
assert x.shape == (batch, dim)
|
805 |
+
return x
|
806 |
+
|
807 |
+
|
808 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
809 |
+
"""
|
810 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
|
811 |
+
of the centered feature covariance are the same within each group's covariance matrix
|
812 |
+
and also between groups.
|
813 |
+
Args:
|
814 |
+
x: a Tensor of shape (*, num_channels)
|
815 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
816 |
+
Returns:
|
817 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
818 |
+
greater than 1.0 otherwise.
|
819 |
+
"""
|
820 |
+
assert x.dtype != torch.float16
|
821 |
+
x = x.reshape(-1, x.shape[-1])
|
822 |
+
(num_frames, num_channels) = x.shape
|
823 |
+
assert num_channels % num_groups == 0
|
824 |
+
channels_per_group = num_channels // num_groups
|
825 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
826 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
827 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
828 |
+
# My experience has been that when we "mess with the gradients" like this,
|
829 |
+
# it's better not do anything that tries to move the mean around, because
|
830 |
+
# that can easily cause instability.
|
831 |
+
x = x - x.mean(dim=1, keepdim=True)
|
832 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
833 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
834 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
835 |
+
# the following expression is what we'd get if we took the matrix product
|
836 |
+
# of each covariance and measured the mean of its trace, i.e.
|
837 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
838 |
+
x_covarsq_mean_diag = (x_covar ** 2).sum() / (
|
839 |
+
num_groups * channels_per_group
|
840 |
+
)
|
841 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
842 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
|
843 |
+
return metric
|
844 |
+
|
845 |
+
|
846 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
847 |
+
@staticmethod
|
848 |
+
def forward(
|
849 |
+
ctx,
|
850 |
+
x: Tensor,
|
851 |
+
num_groups: int,
|
852 |
+
whitening_limit: float,
|
853 |
+
grad_scale: float,
|
854 |
+
) -> Tensor:
|
855 |
+
ctx.save_for_backward(x)
|
856 |
+
ctx.num_groups = num_groups
|
857 |
+
ctx.whitening_limit = whitening_limit
|
858 |
+
ctx.grad_scale = grad_scale
|
859 |
+
return x
|
860 |
+
|
861 |
+
@staticmethod
|
862 |
+
def backward(ctx, x_grad: Tensor):
|
863 |
+
(x_orig,) = ctx.saved_tensors
|
864 |
+
with torch.enable_grad():
|
865 |
+
with torch.cuda.amp.autocast(enabled=False):
|
866 |
+
x_detached = x_orig.to(torch.float32).detach()
|
867 |
+
x_detached.requires_grad = True
|
868 |
+
|
869 |
+
metric = _whitening_metric(x_detached, ctx.num_groups)
|
870 |
+
|
871 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
872 |
+
logging.info(
|
873 |
+
f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
|
874 |
+
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
|
875 |
+
)
|
876 |
+
|
877 |
+
(metric - ctx.whitening_limit).relu().backward()
|
878 |
+
penalty_grad = x_detached.grad
|
879 |
+
scale = ctx.grad_scale * (
|
880 |
+
x_grad.to(torch.float32).norm()
|
881 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
882 |
+
)
|
883 |
+
penalty_grad = penalty_grad * scale
|
884 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
885 |
+
|
886 |
+
|
887 |
+
class Whiten(nn.Module):
|
888 |
+
def __init__(
|
889 |
+
self,
|
890 |
+
num_groups: int,
|
891 |
+
whitening_limit: float,
|
892 |
+
prob: Union[float, Tuple[float, float]],
|
893 |
+
grad_scale: float,
|
894 |
+
):
|
895 |
+
"""
|
896 |
+
Args:
|
897 |
+
num_groups: the number of groups to divide the channel dim into before
|
898 |
+
whitening. We will attempt to make the feature covariance
|
899 |
+
within each group, after mean subtraction, as "white" as possible,
|
900 |
+
while having the same trace across all groups.
|
901 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
902 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
903 |
+
white, with exactly the same trace across groups; larger values
|
904 |
+
give more freedom. E.g. 2.0.
|
905 |
+
prob: the probability with which we apply the gradient modification
|
906 |
+
(also affects the grad scale). May be supplied as a float,
|
907 |
+
or as a pair (min_prob, max_prob)
|
908 |
+
|
909 |
+
grad_scale: determines the scale on the gradient term from this object,
|
910 |
+
relative to the rest of the gradient on the attention weights.
|
911 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
912 |
+
"""
|
913 |
+
super(Whiten, self).__init__()
|
914 |
+
assert num_groups >= 1
|
915 |
+
assert whitening_limit >= 1
|
916 |
+
assert grad_scale >= 0
|
917 |
+
self.num_groups = num_groups
|
918 |
+
self.whitening_limit = whitening_limit
|
919 |
+
if isinstance(prob, float):
|
920 |
+
assert 0 < prob <= 1
|
921 |
+
self.prob = prob
|
922 |
+
else:
|
923 |
+
(self.min_prob, self.max_prob) = prob
|
924 |
+
assert 0 < self.min_prob < self.max_prob <= 1
|
925 |
+
self.prob = self.max_prob
|
926 |
+
|
927 |
+
self.grad_scale = grad_scale
|
928 |
+
|
929 |
+
def forward(self, x: Tensor) -> Tensor:
|
930 |
+
"""
|
931 |
+
In the forward pass, this function just returns the input unmodified.
|
932 |
+
In the backward pass, it will modify the gradients to ensure that the
|
933 |
+
distribution in each group has close to (lambda times I) as the covariance
|
934 |
+
after mean subtraction, with the same lambda across groups.
|
935 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
936 |
+
constraint.
|
937 |
+
|
938 |
+
Args:
|
939 |
+
x: the input of shape (*, num_channels)
|
940 |
+
|
941 |
+
Returns:
|
942 |
+
x, unmodified. You should make sure
|
943 |
+
you use the returned value, or the graph will be freed
|
944 |
+
and nothing will happen in backprop.
|
945 |
+
"""
|
946 |
+
if (
|
947 |
+
not x.requires_grad
|
948 |
+
or random.random() > self.prob
|
949 |
+
or self.grad_scale == 0
|
950 |
+
):
|
951 |
+
return _no_op(x)
|
952 |
+
else:
|
953 |
+
if hasattr(self, "min_prob") and random.random() < 0.25:
|
954 |
+
# occasionally switch between min_prob and max_prob, based on whether
|
955 |
+
# we are above or below the threshold.
|
956 |
+
if (
|
957 |
+
_whitening_metric(x.to(torch.float32), self.num_groups)
|
958 |
+
> self.whitening_limit
|
959 |
+
):
|
960 |
+
# there would be a change to the grad.
|
961 |
+
self.prob = self.max_prob
|
962 |
+
else:
|
963 |
+
self.prob = self.min_prob
|
964 |
+
|
965 |
+
return WhiteningPenaltyFunction.apply(
|
966 |
+
x, self.num_groups, self.whitening_limit, self.grad_scale
|
967 |
+
)
|
968 |
+
|
969 |
+
|
970 |
+
class WithLoss(torch.autograd.Function):
|
971 |
+
@staticmethod
|
972 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
973 |
+
ctx.y_shape = y.shape
|
974 |
+
return x
|
975 |
+
|
976 |
+
@staticmethod
|
977 |
+
def backward(ctx, ans_grad: Tensor):
|
978 |
+
return ans_grad, torch.ones(
|
979 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
980 |
+
)
|
981 |
+
|
982 |
+
|
983 |
+
def with_loss(x, y):
|
984 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
985 |
+
return x
|
986 |
+
# returns x but adds y.sum() to the loss function.
|
987 |
+
return WithLoss.apply(x, y)
|
988 |
+
|
989 |
+
|
990 |
+
def _no_op(x: Tensor) -> Tensor:
|
991 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
992 |
+
return x
|
993 |
+
else:
|
994 |
+
# a no-op function that will have a node in the autograd graph,
|
995 |
+
# to avoid certain bugs relating to backward hooks
|
996 |
+
return x.chunk(1, dim=-1)[0]
|
997 |
+
|
998 |
+
|
999 |
+
class Identity(torch.nn.Module):
|
1000 |
+
def __init__(self):
|
1001 |
+
super(Identity, self).__init__()
|
1002 |
+
|
1003 |
+
def forward(self, x):
|
1004 |
+
return _no_op(x)
|
1005 |
+
|
1006 |
+
|
1007 |
+
class MaxEig(torch.nn.Module):
|
1008 |
+
"""
|
1009 |
+
Modifies the backpropped derivatives of a function to try to discourage
|
1010 |
+
that any given direction in activation space accounts for more than
|
1011 |
+
a specified proportion of the covariance (e.g. 0.2).
|
1012 |
+
|
1013 |
+
|
1014 |
+
Args:
|
1015 |
+
num_channels: the number of channels
|
1016 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
1017 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
1018 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
1019 |
+
features/channels, after mean subtraction, that can come from
|
1020 |
+
any given eigenvalue.
|
1021 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
1022 |
+
of forward(), assuming last time we applied the constraint it was
|
1023 |
+
not active; supplied for speed.
|
1024 |
+
scale: determines the scale with which we modify the gradients, relative
|
1025 |
+
to the existing / unmodified gradients
|
1026 |
+
"""
|
1027 |
+
|
1028 |
+
def __init__(
|
1029 |
+
self,
|
1030 |
+
num_channels: int,
|
1031 |
+
channel_dim: int,
|
1032 |
+
max_var_per_eig: float = 0.2,
|
1033 |
+
min_prob: float = 0.01,
|
1034 |
+
scale: float = 0.01,
|
1035 |
+
):
|
1036 |
+
super(MaxEig, self).__init__()
|
1037 |
+
self.num_channels = num_channels
|
1038 |
+
self.channel_dim = channel_dim
|
1039 |
+
self.scale = scale
|
1040 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
1041 |
+
self.max_var_per_eig = max_var_per_eig
|
1042 |
+
|
1043 |
+
# we figure out the dominant direction using the power method: starting with
|
1044 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
1045 |
+
with torch.no_grad():
|
1046 |
+
# arbitrary.. would use randn() but want to leave the rest of the model's
|
1047 |
+
# random parameters unchanged for comparison
|
1048 |
+
direction = torch.arange(num_channels).to(torch.float)
|
1049 |
+
direction = direction / direction.norm()
|
1050 |
+
self.register_buffer("max_eig_direction", direction)
|
1051 |
+
|
1052 |
+
self.min_prob = min_prob
|
1053 |
+
# cur_prob is the current probability we'll use to apply the ActivationBalancer.
|
1054 |
+
# We'll regress this towards prob, each tiem we try to apply it and it is not
|
1055 |
+
# active.
|
1056 |
+
self.cur_prob = 1.0
|
1057 |
+
|
1058 |
+
def forward(self, x: Tensor) -> Tensor:
|
1059 |
+
if (
|
1060 |
+
torch.jit.is_scripting()
|
1061 |
+
or self.max_var_per_eig <= 0
|
1062 |
+
or random.random() > self.cur_prob
|
1063 |
+
or torch.jit.is_tracing()
|
1064 |
+
):
|
1065 |
+
return _no_op(x)
|
1066 |
+
|
1067 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1068 |
+
eps = 1.0e-20
|
1069 |
+
orig_x = x
|
1070 |
+
x = x.to(torch.float32)
|
1071 |
+
with torch.no_grad():
|
1072 |
+
x = x.transpose(self.channel_dim, -1).reshape(
|
1073 |
+
-1, self.num_channels
|
1074 |
+
)
|
1075 |
+
x = x - x.mean(dim=0)
|
1076 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
1077 |
+
x, self.max_eig_direction
|
1078 |
+
)
|
1079 |
+
x_var = (x ** 2).mean()
|
1080 |
+
x_residual = x - coeffs * new_direction
|
1081 |
+
x_residual_var = (x_residual ** 2).mean()
|
1082 |
+
|
1083 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
1084 |
+
# by the top eigen-direction.
|
1085 |
+
variance_proportion = (x_var - x_residual_var) / (
|
1086 |
+
x_var + 1.0e-20
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
1090 |
+
self._set_direction(
|
1091 |
+
0.1 * self.max_eig_direction + new_direction
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
1095 |
+
logging.info(
|
1096 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
1097 |
+
)
|
1098 |
+
|
1099 |
+
if variance_proportion >= self.max_var_per_eig:
|
1100 |
+
# The constraint is active. Note, we should quite rarely
|
1101 |
+
# reach here, only near the beginning of training if we are
|
1102 |
+
# starting to diverge, should this constraint be active.
|
1103 |
+
cur_prob = self.cur_prob
|
1104 |
+
self.cur_prob = (
|
1105 |
+
1.0 # next time, do the update with probability 1.0.
|
1106 |
+
)
|
1107 |
+
return MaxEigLimiterFunction.apply(
|
1108 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
1109 |
+
)
|
1110 |
+
else:
|
1111 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
1112 |
+
# long as the constraint is inactive.
|
1113 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
1114 |
+
return orig_x
|
1115 |
+
|
1116 |
+
def _set_direction(self, direction: Tensor):
|
1117 |
+
"""
|
1118 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
1119 |
+
"""
|
1120 |
+
direction = direction.detach()
|
1121 |
+
direction = direction / direction.norm()
|
1122 |
+
direction_sum = direction.sum().item()
|
1123 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
1124 |
+
self.max_eig_direction[:] = direction
|
1125 |
+
else:
|
1126 |
+
logging.info(
|
1127 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
1128 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
def _find_direction_coeffs(
|
1132 |
+
self, x: Tensor, prev_direction: Tensor
|
1133 |
+
) -> Tuple[Tensor, Tensor, Tensor]:
|
1134 |
+
"""
|
1135 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
1136 |
+
feature vectors that can be attributed to the top eigen-direction.
|
1137 |
+
Args:
|
1138 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
1139 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
1140 |
+
of the top eigen-direction, or a random direction if this is the first
|
1141 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
1142 |
+
|
1143 |
+
Returns: (cur_direction, coeffs), where:
|
1144 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
1145 |
+
estimate of the top eigen-direction.
|
1146 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
1147 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
1148 |
+
"""
|
1149 |
+
(num_frames, num_channels) = x.shape
|
1150 |
+
assert num_channels > 1 and num_frames > 1
|
1151 |
+
assert prev_direction.shape == (num_channels,)
|
1152 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
1153 |
+
# actually represent the coeffs up to a constant positive factor.
|
1154 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
1155 |
+
cur_direction = (x * coeffs).sum(dim=0) / (
|
1156 |
+
(coeffs ** 2).sum() + 1.0e-20
|
1157 |
+
)
|
1158 |
+
return cur_direction, coeffs
|
1159 |
+
|
1160 |
+
|
1161 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
1162 |
+
"""
|
1163 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
1164 |
+
This is a definition, originally motivated by its close numerical
|
1165 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
1166 |
+
|
1167 |
+
Memory-efficient derivative computation:
|
1168 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
1169 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
1170 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
1171 |
+
double_swish'(x) = x * s'(x) + s(x).
|
1172 |
+
= x * s(x) * (1-s(x)) + s(x).
|
1173 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
1174 |
+
... so we just need to remember s(x) but not x itself.
|
1175 |
+
"""
|
1176 |
+
|
1177 |
+
@staticmethod
|
1178 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
1179 |
+
requires_grad = x.requires_grad
|
1180 |
+
x_dtype = x.dtype
|
1181 |
+
if x.dtype == torch.float16:
|
1182 |
+
x = x.to(torch.float32)
|
1183 |
+
|
1184 |
+
s = torch.sigmoid(x - 1.0)
|
1185 |
+
y = x * s
|
1186 |
+
|
1187 |
+
if requires_grad:
|
1188 |
+
deriv = y * (1 - s) + s
|
1189 |
+
# notes on derivative of x * sigmoid(x - 1):
|
1190 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
1191 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
|
1192 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
1193 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
1194 |
+
# floors), should be expectation-preserving.
|
1195 |
+
floor = -0.043637
|
1196 |
+
ceil = 1.2
|
1197 |
+
d_scaled = (deriv - floor) * (
|
1198 |
+
255.0 / (ceil - floor)
|
1199 |
+
) + torch.rand_like(deriv)
|
1200 |
+
if __name__ == "__main__":
|
1201 |
+
# for self-testing only.
|
1202 |
+
assert d_scaled.min() >= 0.0
|
1203 |
+
assert d_scaled.max() < 256.0
|
1204 |
+
d_int = d_scaled.to(torch.uint8)
|
1205 |
+
ctx.save_for_backward(d_int)
|
1206 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
1207 |
+
y = y.to(torch.float16)
|
1208 |
+
return y
|
1209 |
+
|
1210 |
+
@staticmethod
|
1211 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
1212 |
+
(d,) = ctx.saved_tensors
|
1213 |
+
# the same constants as used in forward pass.
|
1214 |
+
floor = -0.043637
|
1215 |
+
ceil = 1.2
|
1216 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
1217 |
+
return y_grad * d
|
1218 |
+
|
1219 |
+
|
1220 |
+
class DoubleSwish(torch.nn.Module):
|
1221 |
+
def forward(self, x: Tensor) -> Tensor:
|
1222 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
1223 |
+
that we approximate closely with x * sigmoid(x-1).
|
1224 |
+
"""
|
1225 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
1226 |
+
return x * torch.sigmoid(x - 1.0)
|
1227 |
+
return DoubleSwishFunction.apply(x)
|
1228 |
+
|
1229 |
+
|
1230 |
+
def BalancedDoubleSwish(
|
1231 |
+
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
1232 |
+
) -> nn.Sequential:
|
1233 |
+
"""
|
1234 |
+
ActivationBalancer -> DoubleSwish
|
1235 |
+
"""
|
1236 |
+
balancer = ActivationBalancer(
|
1237 |
+
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
1238 |
+
)
|
1239 |
+
return nn.Sequential(
|
1240 |
+
balancer,
|
1241 |
+
DoubleSwish(),
|
1242 |
+
)
|
1243 |
+
|
1244 |
+
|
1245 |
+
def _test_max_eig():
|
1246 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1247 |
+
logging.info(f"proportion = {proportion}")
|
1248 |
+
x = torch.randn(100, 128)
|
1249 |
+
direction = torch.randn(128)
|
1250 |
+
coeffs = torch.randn(100, 1)
|
1251 |
+
x += proportion * direction * coeffs
|
1252 |
+
|
1253 |
+
x.requires_grad = True
|
1254 |
+
|
1255 |
+
num_channels = 128
|
1256 |
+
m = MaxEig(
|
1257 |
+
num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
|
1258 |
+
) # grad_scale
|
1259 |
+
|
1260 |
+
for _ in range(4):
|
1261 |
+
y = m(x)
|
1262 |
+
|
1263 |
+
y_grad = torch.randn_like(x)
|
1264 |
+
y.backward(gradient=y_grad)
|
1265 |
+
|
1266 |
+
if proportion < 0.2:
|
1267 |
+
assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
|
1268 |
+
elif proportion > 1.0:
|
1269 |
+
assert not torch.allclose(x.grad, y_grad)
|
1270 |
+
|
1271 |
+
|
1272 |
+
def _test_whiten():
|
1273 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1274 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
1275 |
+
x = torch.randn(100, 128)
|
1276 |
+
direction = torch.randn(128)
|
1277 |
+
coeffs = torch.randn(100, 1)
|
1278 |
+
x += proportion * direction * coeffs
|
1279 |
+
|
1280 |
+
x.requires_grad = True
|
1281 |
+
|
1282 |
+
num_channels = 128
|
1283 |
+
m = Whiten(
|
1284 |
+
1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
|
1285 |
+
) # grad_scale
|
1286 |
+
|
1287 |
+
for _ in range(4):
|
1288 |
+
y = m(x)
|
1289 |
+
|
1290 |
+
y_grad = torch.randn_like(x)
|
1291 |
+
y.backward(gradient=y_grad)
|
1292 |
+
|
1293 |
+
if proportion < 0.2:
|
1294 |
+
assert torch.allclose(x.grad, y_grad)
|
1295 |
+
elif proportion > 1.0:
|
1296 |
+
assert not torch.allclose(x.grad, y_grad)
|
1297 |
+
|
1298 |
+
|
1299 |
+
def _test_activation_balancer_sign():
|
1300 |
+
probs = torch.arange(0, 1, 0.01)
|
1301 |
+
N = 1000
|
1302 |
+
x = 1.0 * (
|
1303 |
+
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
|
1304 |
+
)
|
1305 |
+
x = x.detach()
|
1306 |
+
x.requires_grad = True
|
1307 |
+
m = ActivationBalancer(
|
1308 |
+
probs.numel(),
|
1309 |
+
channel_dim=0,
|
1310 |
+
min_positive=0.05,
|
1311 |
+
max_positive=0.95,
|
1312 |
+
max_factor=0.2,
|
1313 |
+
min_abs=0.0,
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
1317 |
+
|
1318 |
+
y = m(x)
|
1319 |
+
y.backward(gradient=y_grad)
|
1320 |
+
print("_test_activation_balancer_sign: x = ", x)
|
1321 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
1322 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
1323 |
+
|
1324 |
+
|
1325 |
+
def _test_activation_balancer_magnitude():
|
1326 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
1327 |
+
N = 1000
|
1328 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
1329 |
+
-1
|
1330 |
+
)
|
1331 |
+
x = x.detach()
|
1332 |
+
x.requires_grad = True
|
1333 |
+
m = ActivationBalancer(
|
1334 |
+
magnitudes.numel(),
|
1335 |
+
channel_dim=0,
|
1336 |
+
min_positive=0.0,
|
1337 |
+
max_positive=1.0,
|
1338 |
+
max_factor=0.2,
|
1339 |
+
min_abs=0.2,
|
1340 |
+
max_abs=0.8,
|
1341 |
+
min_prob=1.0,
|
1342 |
+
)
|
1343 |
+
|
1344 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
1345 |
+
|
1346 |
+
y = m(x)
|
1347 |
+
y.backward(gradient=y_grad)
|
1348 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
1349 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
1350 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
1351 |
+
|
1352 |
+
|
1353 |
+
def _test_basic_norm():
|
1354 |
+
num_channels = 128
|
1355 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
1356 |
+
|
1357 |
+
x = torch.randn(500, num_channels)
|
1358 |
+
|
1359 |
+
y = m(x)
|
1360 |
+
|
1361 |
+
assert y.shape == x.shape
|
1362 |
+
x_rms = (x ** 2).mean().sqrt()
|
1363 |
+
y_rms = (y ** 2).mean().sqrt()
|
1364 |
+
print("x rms = ", x_rms)
|
1365 |
+
print("y rms = ", y_rms)
|
1366 |
+
assert y_rms < x_rms
|
1367 |
+
assert y_rms > 0.5 * x_rms
|
1368 |
+
|
1369 |
+
|
1370 |
+
def _test_double_swish_deriv():
|
1371 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
1372 |
+
x.requires_grad = True
|
1373 |
+
m = DoubleSwish()
|
1374 |
+
|
1375 |
+
tol = (1.2 - (-0.043637)) / 255.0
|
1376 |
+
torch.autograd.gradcheck(m, x, atol=tol)
|
1377 |
+
|
1378 |
+
# for self-test.
|
1379 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
1380 |
+
x.requires_grad = True
|
1381 |
+
y = m(x)
|
1382 |
+
|
1383 |
+
|
1384 |
+
def _test_softmax():
|
1385 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
1386 |
+
b = a.clone()
|
1387 |
+
a.requires_grad = True
|
1388 |
+
b.requires_grad = True
|
1389 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
1390 |
+
print("a grad = ", a.grad)
|
1391 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
1392 |
+
print("b grad = ", b.grad)
|
1393 |
+
assert torch.allclose(a.grad, b.grad)
|
1394 |
+
|
1395 |
+
|
1396 |
+
if __name__ == "__main__":
|
1397 |
+
logging.getLogger().setLevel(logging.INFO)
|
1398 |
+
torch.set_num_threads(1)
|
1399 |
+
torch.set_num_interop_threads(1)
|
1400 |
+
_test_softmax()
|
1401 |
+
_test_whiten()
|
1402 |
+
_test_max_eig()
|
1403 |
+
_test_activation_balancer_sign()
|
1404 |
+
_test_activation_balancer_magnitude()
|
1405 |
+
_test_basic_norm()
|
1406 |
+
_test_double_swish_deriv()
|
models/modules/transformer.py
ADDED
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
|
2 |
+
import copy
|
3 |
+
import numbers
|
4 |
+
from functools import partial
|
5 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import Tensor, nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
from .activation import MultiheadAttention
|
12 |
+
from .scaling import ActivationBalancer, BalancedDoubleSwish
|
13 |
+
from .scaling import BasicNorm as _BasicNorm
|
14 |
+
|
15 |
+
_shape_t = Union[int, List[int], torch.Size]
|
16 |
+
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
20 |
+
normalized_shape: Tuple[int, ...]
|
21 |
+
eps: float
|
22 |
+
elementwise_affine: bool
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
normalized_shape: _shape_t,
|
27 |
+
eps: float = 1e-5,
|
28 |
+
elementwise_affine: bool = True,
|
29 |
+
device=None,
|
30 |
+
dtype=None,
|
31 |
+
) -> None:
|
32 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
33 |
+
super(LayerNorm, self).__init__()
|
34 |
+
if isinstance(normalized_shape, numbers.Integral):
|
35 |
+
# mypy error: incompatible types in assignment
|
36 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
37 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
38 |
+
self.eps = eps
|
39 |
+
self.elementwise_affine = elementwise_affine
|
40 |
+
if self.elementwise_affine:
|
41 |
+
self.weight = nn.Parameter(
|
42 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
43 |
+
)
|
44 |
+
self.bias = nn.Parameter(
|
45 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
self.register_parameter("weight", None)
|
49 |
+
self.register_parameter("bias", None)
|
50 |
+
|
51 |
+
self.reset_parameters()
|
52 |
+
|
53 |
+
def reset_parameters(self) -> None:
|
54 |
+
if self.elementwise_affine:
|
55 |
+
nn.init.ones_(self.weight)
|
56 |
+
nn.init.zeros_(self.bias)
|
57 |
+
|
58 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
59 |
+
if isinstance(input, tuple):
|
60 |
+
input, embedding = input
|
61 |
+
return (
|
62 |
+
F.layer_norm(
|
63 |
+
input,
|
64 |
+
self.normalized_shape,
|
65 |
+
self.weight,
|
66 |
+
self.bias,
|
67 |
+
self.eps,
|
68 |
+
),
|
69 |
+
embedding,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert embedding is None
|
73 |
+
return F.layer_norm(
|
74 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
75 |
+
)
|
76 |
+
|
77 |
+
def extra_repr(self) -> str:
|
78 |
+
return (
|
79 |
+
"{normalized_shape}, eps={eps}, "
|
80 |
+
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
class AdaptiveLayerNorm(nn.Module):
|
85 |
+
r"""Adaptive Layer Normalization"""
|
86 |
+
|
87 |
+
def __init__(self, d_model, norm) -> None:
|
88 |
+
super(AdaptiveLayerNorm, self).__init__()
|
89 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
90 |
+
self.norm = norm
|
91 |
+
self.d_model = d_model
|
92 |
+
self.eps = self.norm.eps
|
93 |
+
|
94 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
95 |
+
if isinstance(input, tuple):
|
96 |
+
input, embedding = input
|
97 |
+
weight, bias = torch.split(
|
98 |
+
self.project_layer(embedding),
|
99 |
+
split_size_or_sections=self.d_model,
|
100 |
+
dim=-1,
|
101 |
+
)
|
102 |
+
return (weight * self.norm(input) + bias, embedding)
|
103 |
+
|
104 |
+
weight, bias = torch.split(
|
105 |
+
self.project_layer(embedding),
|
106 |
+
split_size_or_sections=self.d_model,
|
107 |
+
dim=-1,
|
108 |
+
)
|
109 |
+
return weight * self.norm(input) + bias
|
110 |
+
|
111 |
+
|
112 |
+
class BasicNorm(_BasicNorm):
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
d_model: int,
|
116 |
+
eps: float = 1e-5,
|
117 |
+
device=None,
|
118 |
+
dtype=None,
|
119 |
+
):
|
120 |
+
super(BasicNorm, self).__init__(d_model, eps=eps)
|
121 |
+
|
122 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
123 |
+
if isinstance(input, tuple):
|
124 |
+
input, embedding = input
|
125 |
+
return (
|
126 |
+
super(BasicNorm, self).forward(input),
|
127 |
+
embedding,
|
128 |
+
)
|
129 |
+
|
130 |
+
assert embedding is None
|
131 |
+
return super(BasicNorm, self).forward(input)
|
132 |
+
|
133 |
+
|
134 |
+
class BalancedBasicNorm(nn.Module):
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
d_model: int,
|
138 |
+
eps: float = 1e-5,
|
139 |
+
device=None,
|
140 |
+
dtype=None,
|
141 |
+
):
|
142 |
+
super(BalancedBasicNorm, self).__init__()
|
143 |
+
self.balancer = ActivationBalancer(
|
144 |
+
d_model,
|
145 |
+
channel_dim=-1,
|
146 |
+
min_positive=0.45,
|
147 |
+
max_positive=0.55,
|
148 |
+
max_abs=6.0,
|
149 |
+
)
|
150 |
+
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
|
151 |
+
|
152 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
153 |
+
if isinstance(input, tuple):
|
154 |
+
input, embedding = input
|
155 |
+
return self.norm((self.balancer(input), embedding))
|
156 |
+
|
157 |
+
assert embedding is None
|
158 |
+
return self.norm(self.balancer(input))
|
159 |
+
|
160 |
+
|
161 |
+
class IdentityNorm(nn.Module):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
d_model: int,
|
165 |
+
eps: float = 1e-5,
|
166 |
+
device=None,
|
167 |
+
dtype=None,
|
168 |
+
) -> None:
|
169 |
+
super(IdentityNorm, self).__init__()
|
170 |
+
|
171 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
172 |
+
if isinstance(input, tuple):
|
173 |
+
return input
|
174 |
+
|
175 |
+
assert embedding is None
|
176 |
+
return input
|
177 |
+
|
178 |
+
|
179 |
+
class TransformerEncoderLayer(nn.Module):
|
180 |
+
__constants__ = ["batch_first", "norm_first"]
|
181 |
+
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
d_model: int,
|
185 |
+
nhead: int,
|
186 |
+
dim_feedforward: int = 2048,
|
187 |
+
dropout: float = 0.1,
|
188 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
189 |
+
batch_first: bool = False,
|
190 |
+
norm_first: bool = False,
|
191 |
+
device=None,
|
192 |
+
dtype=None,
|
193 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
194 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
195 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
196 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
197 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
198 |
+
layer_norm_eps: float = 1e-5,
|
199 |
+
adaptive_layer_norm=False,
|
200 |
+
) -> None:
|
201 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
202 |
+
super(TransformerEncoderLayer, self).__init__()
|
203 |
+
self.self_attn = MultiheadAttention(
|
204 |
+
d_model,
|
205 |
+
nhead,
|
206 |
+
dropout=dropout,
|
207 |
+
batch_first=batch_first,
|
208 |
+
linear1_cls=linear1_self_attention_cls,
|
209 |
+
linear2_cls=linear2_self_attention_cls,
|
210 |
+
**factory_kwargs,
|
211 |
+
)
|
212 |
+
|
213 |
+
# Implementation of Feedforward model
|
214 |
+
self.linear1 = linear1_feedforward_cls(
|
215 |
+
d_model, dim_feedforward, **factory_kwargs
|
216 |
+
)
|
217 |
+
self.dropout = nn.Dropout(dropout)
|
218 |
+
self.linear2 = linear2_feedforward_cls(
|
219 |
+
dim_feedforward, d_model, **factory_kwargs
|
220 |
+
)
|
221 |
+
|
222 |
+
self.norm_first = norm_first
|
223 |
+
self.dropout1 = nn.Dropout(dropout)
|
224 |
+
self.dropout2 = nn.Dropout(dropout)
|
225 |
+
|
226 |
+
# Legacy string support for activation function.
|
227 |
+
if isinstance(activation, str):
|
228 |
+
activation = _get_activation_fn(activation)
|
229 |
+
elif isinstance(activation, partial):
|
230 |
+
activation = activation(d_model)
|
231 |
+
elif activation == BalancedDoubleSwish:
|
232 |
+
activation = BalancedDoubleSwish(d_model)
|
233 |
+
|
234 |
+
# # We can't test self.activation in forward() in TorchScript,
|
235 |
+
# # so stash some information about it instead.
|
236 |
+
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
237 |
+
# self.activation_relu_or_gelu = 1
|
238 |
+
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
239 |
+
# self.activation_relu_or_gelu = 2
|
240 |
+
# else:
|
241 |
+
# self.activation_relu_or_gelu = 0
|
242 |
+
self.activation = activation
|
243 |
+
|
244 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
245 |
+
if layer_norm_cls == IdentityNorm:
|
246 |
+
norm2 = BalancedBasicNorm(
|
247 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
norm2 = layer_norm_cls(
|
251 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
252 |
+
)
|
253 |
+
|
254 |
+
if adaptive_layer_norm:
|
255 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
256 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
257 |
+
else:
|
258 |
+
self.norm1 = norm1
|
259 |
+
self.norm2 = norm2
|
260 |
+
|
261 |
+
def __setstate__(self, state):
|
262 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
263 |
+
if not hasattr(self, "activation"):
|
264 |
+
self.activation = F.relu
|
265 |
+
|
266 |
+
def forward(
|
267 |
+
self,
|
268 |
+
src: Tensor,
|
269 |
+
src_mask: Optional[Tensor] = None,
|
270 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
271 |
+
need_weights: Optional[bool] = False,
|
272 |
+
past: Optional[Tensor] = None,
|
273 |
+
) -> Tensor:
|
274 |
+
r"""Pass the input through the encoder layer.
|
275 |
+
|
276 |
+
Args:
|
277 |
+
src: the sequence to the encoder layer (required).
|
278 |
+
src_mask: the mask for the src sequence (optional).
|
279 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
280 |
+
|
281 |
+
Shape:
|
282 |
+
see the docs in Transformer class.
|
283 |
+
"""
|
284 |
+
x, stage_embedding = src, None
|
285 |
+
is_src_tuple = False
|
286 |
+
if isinstance(src, tuple):
|
287 |
+
x, stage_embedding = src
|
288 |
+
is_src_tuple = True
|
289 |
+
|
290 |
+
if src_key_padding_mask is not None:
|
291 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
292 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
293 |
+
src_key_padding_mask
|
294 |
+
):
|
295 |
+
raise AssertionError(
|
296 |
+
"only bool and floating types of key_padding_mask are supported"
|
297 |
+
)
|
298 |
+
if need_weights:
|
299 |
+
if self.norm_first:
|
300 |
+
out, attn = self._sa_block_attn(
|
301 |
+
self.norm1(x, stage_embedding),
|
302 |
+
src_mask,
|
303 |
+
src_key_padding_mask,
|
304 |
+
past
|
305 |
+
)
|
306 |
+
out, present = out # present is the kvcache of the present timestep
|
307 |
+
x = x + out
|
308 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
309 |
+
else:
|
310 |
+
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past)
|
311 |
+
out, present = out # present is the kvcache of the present timestep
|
312 |
+
x = self.norm1(
|
313 |
+
x + out,
|
314 |
+
stage_embedding,
|
315 |
+
)
|
316 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
317 |
+
assert not is_src_tuple
|
318 |
+
# return (x, stage_embedding)
|
319 |
+
return (x, attn)
|
320 |
+
else:
|
321 |
+
if self.norm_first:
|
322 |
+
out = self._sa_block(
|
323 |
+
self.norm1(x, stage_embedding),
|
324 |
+
src_mask,
|
325 |
+
src_key_padding_mask, past
|
326 |
+
)
|
327 |
+
out, present = out # present is the kvcache of the present timestep
|
328 |
+
x = x + out
|
329 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
330 |
+
else:
|
331 |
+
out = self._sa_block(x, src_mask, src_key_padding_mask)
|
332 |
+
out, present = out # present is the kvcache of the present timestep
|
333 |
+
x = self.norm1(
|
334 |
+
x + out,
|
335 |
+
stage_embedding, past
|
336 |
+
)
|
337 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
338 |
+
|
339 |
+
if is_src_tuple:
|
340 |
+
x = (x, stage_embedding)
|
341 |
+
if present != None:
|
342 |
+
x = [x, present]
|
343 |
+
return x
|
344 |
+
|
345 |
+
# self-attention block
|
346 |
+
def _sa_block(
|
347 |
+
self,
|
348 |
+
x: Tensor,
|
349 |
+
attn_mask: Optional[Tensor],
|
350 |
+
key_padding_mask: Optional[Tensor],
|
351 |
+
past: Optional[Tensor] = None,
|
352 |
+
) -> Tensor:
|
353 |
+
x = self.self_attn(
|
354 |
+
x,
|
355 |
+
x,
|
356 |
+
x,
|
357 |
+
attn_mask=attn_mask,
|
358 |
+
key_padding_mask=key_padding_mask,
|
359 |
+
need_weights=False,
|
360 |
+
past=past
|
361 |
+
)
|
362 |
+
x, present = x
|
363 |
+
return self.dropout1(x), present
|
364 |
+
|
365 |
+
# self-attention block, also return attention weights
|
366 |
+
def _sa_block_attn(
|
367 |
+
self,
|
368 |
+
x: Tensor,
|
369 |
+
attn_mask: Optional[Tensor],
|
370 |
+
key_padding_mask: Optional[Tensor],
|
371 |
+
past: Optional[Tensor] = None,
|
372 |
+
) -> Tensor:
|
373 |
+
x, attn = self.self_attn(
|
374 |
+
x,
|
375 |
+
x,
|
376 |
+
x,
|
377 |
+
attn_mask=attn_mask,
|
378 |
+
key_padding_mask=key_padding_mask,
|
379 |
+
need_weights=True,
|
380 |
+
past=past
|
381 |
+
)
|
382 |
+
x, present = x
|
383 |
+
return (self.dropout1(x), present), attn
|
384 |
+
|
385 |
+
# feed forward block
|
386 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
387 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
388 |
+
return self.dropout2(x)
|
389 |
+
|
390 |
+
|
391 |
+
class TransformerEncoder(nn.Module):
|
392 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
393 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
397 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
398 |
+
norm: the layer normalization component (optional).
|
399 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
400 |
+
(and convert back on output). This will improve the overall performance of
|
401 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
402 |
+
|
403 |
+
Examples::
|
404 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
405 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
406 |
+
>>> src = torch.rand(10, 32, 512)
|
407 |
+
>>> out = transformer_encoder(src)
|
408 |
+
"""
|
409 |
+
__constants__ = ["norm"]
|
410 |
+
|
411 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
412 |
+
super(TransformerEncoder, self).__init__()
|
413 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
414 |
+
self.num_layers = num_layers
|
415 |
+
self.norm = norm
|
416 |
+
|
417 |
+
def forward(
|
418 |
+
self,
|
419 |
+
src: Tensor,
|
420 |
+
mask: Optional[Tensor] = None,
|
421 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
422 |
+
return_layer_states: bool = False,
|
423 |
+
need_weights:Optional[bool] = False,
|
424 |
+
past: Optional[Tensor] = None,
|
425 |
+
) -> Tensor:
|
426 |
+
r"""Pass the input through the encoder layers in turn.
|
427 |
+
|
428 |
+
Args:
|
429 |
+
src: the sequence to the encoder (required).
|
430 |
+
mask: the mask for the src sequence (optional).
|
431 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
432 |
+
return_layer_states: return layers' state (optional).
|
433 |
+
|
434 |
+
Shape:
|
435 |
+
see the docs in Transformer class.
|
436 |
+
"""
|
437 |
+
if return_layer_states:
|
438 |
+
assert not need_weights
|
439 |
+
layer_states = [] # layers' output
|
440 |
+
output = src
|
441 |
+
for mod in self.layers:
|
442 |
+
output = mod(
|
443 |
+
output,
|
444 |
+
src_mask=mask,
|
445 |
+
src_key_padding_mask=src_key_padding_mask,
|
446 |
+
past=past
|
447 |
+
)
|
448 |
+
layer_states.append(output[0])
|
449 |
+
|
450 |
+
if self.norm is not None:
|
451 |
+
output = self.norm(output)
|
452 |
+
|
453 |
+
return layer_states, output
|
454 |
+
if need_weights:
|
455 |
+
assert not return_layer_states
|
456 |
+
layer_attn = [] # layers' output
|
457 |
+
output = src
|
458 |
+
for mod in self.layers:
|
459 |
+
output = mod(
|
460 |
+
output,
|
461 |
+
src_mask=mask,
|
462 |
+
src_key_padding_mask=src_key_padding_mask,
|
463 |
+
need_weights=True,
|
464 |
+
past=past
|
465 |
+
)
|
466 |
+
layer_attn.append(output[1])
|
467 |
+
|
468 |
+
if self.norm is not None:
|
469 |
+
output = self.norm(output)
|
470 |
+
|
471 |
+
return layer_attn, output
|
472 |
+
|
473 |
+
output = src
|
474 |
+
all_present = []
|
475 |
+
for n_layer, mod in enumerate(self.layers):
|
476 |
+
output = mod(
|
477 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
|
478 |
+
)
|
479 |
+
if isinstance(output, list):
|
480 |
+
output, present = output
|
481 |
+
all_present.append(present)
|
482 |
+
|
483 |
+
if self.norm is not None:
|
484 |
+
output = self.norm(output)
|
485 |
+
if all_present != []:
|
486 |
+
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
487 |
+
output = [output, all_present]
|
488 |
+
return output
|
489 |
+
|
490 |
+
|
491 |
+
class TransformerDecoderLayer(nn.Module):
|
492 |
+
__constants__ = ["batch_first", "norm_first"]
|
493 |
+
|
494 |
+
def __init__(
|
495 |
+
self,
|
496 |
+
d_model: int,
|
497 |
+
nhead: int,
|
498 |
+
dim_feedforward: int = 2048,
|
499 |
+
dropout: float = 0.1,
|
500 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
501 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
502 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
503 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
504 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
505 |
+
batch_first: bool = False,
|
506 |
+
norm_first: bool = False,
|
507 |
+
device=None,
|
508 |
+
dtype=None,
|
509 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
510 |
+
layer_norm_eps: float = 1e-5,
|
511 |
+
adaptive_layer_norm=False,
|
512 |
+
) -> None:
|
513 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
514 |
+
super(TransformerDecoderLayer, self).__init__()
|
515 |
+
self.self_attn = MultiheadAttention(
|
516 |
+
d_model,
|
517 |
+
nhead,
|
518 |
+
dropout=dropout,
|
519 |
+
batch_first=batch_first,
|
520 |
+
linear1_cls=linear1_self_attention_cls,
|
521 |
+
linear2_cls=linear2_self_attention_cls,
|
522 |
+
**factory_kwargs,
|
523 |
+
)
|
524 |
+
self.multihead_attn = MultiheadAttention(
|
525 |
+
d_model,
|
526 |
+
nhead,
|
527 |
+
dropout=dropout,
|
528 |
+
batch_first=batch_first,
|
529 |
+
linear1_cls=linear1_self_attention_cls,
|
530 |
+
linear2_cls=linear2_self_attention_cls,
|
531 |
+
**factory_kwargs,
|
532 |
+
)
|
533 |
+
# Implementation of Feedforward model
|
534 |
+
self.linear1 = linear1_feedforward_cls(
|
535 |
+
d_model, dim_feedforward, **factory_kwargs
|
536 |
+
)
|
537 |
+
self.dropout = nn.Dropout(dropout)
|
538 |
+
self.linear2 = linear2_feedforward_cls(
|
539 |
+
dim_feedforward, d_model, **factory_kwargs
|
540 |
+
)
|
541 |
+
|
542 |
+
self.norm_first = norm_first
|
543 |
+
self.dropout1 = nn.Dropout(dropout)
|
544 |
+
self.dropout2 = nn.Dropout(dropout)
|
545 |
+
self.dropout3 = nn.Dropout(dropout)
|
546 |
+
|
547 |
+
# Legacy string support for activation function.
|
548 |
+
if isinstance(activation, str):
|
549 |
+
self.activation = _get_activation_fn(activation)
|
550 |
+
elif isinstance(activation, partial):
|
551 |
+
self.activation = activation(d_model)
|
552 |
+
elif activation == BalancedDoubleSwish:
|
553 |
+
self.activation = BalancedDoubleSwish(d_model)
|
554 |
+
else:
|
555 |
+
self.activation = activation
|
556 |
+
|
557 |
+
if adaptive_layer_norm:
|
558 |
+
norm1 = layer_norm_cls(
|
559 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
560 |
+
)
|
561 |
+
norm2 = layer_norm_cls(
|
562 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
563 |
+
)
|
564 |
+
norm3 = layer_norm_cls(
|
565 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
566 |
+
)
|
567 |
+
|
568 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
569 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
570 |
+
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
|
571 |
+
else:
|
572 |
+
self.norm1 = layer_norm_cls(
|
573 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
574 |
+
)
|
575 |
+
self.norm2 = layer_norm_cls(
|
576 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
577 |
+
)
|
578 |
+
if layer_norm_cls == IdentityNorm:
|
579 |
+
self.norm3 = BalancedBasicNorm(
|
580 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
581 |
+
)
|
582 |
+
else:
|
583 |
+
self.norm3 = layer_norm_cls(
|
584 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
585 |
+
)
|
586 |
+
|
587 |
+
def forward(
|
588 |
+
self,
|
589 |
+
tgt: Tensor,
|
590 |
+
memory: Tensor,
|
591 |
+
tgt_mask: Optional[Tensor] = None,
|
592 |
+
memory_mask: Optional[Tensor] = None,
|
593 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
594 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
595 |
+
) -> Tensor:
|
596 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
tgt: the sequence to the decoder layer (required).
|
600 |
+
memory: the sequence from the last layer of the encoder (required).
|
601 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
602 |
+
memory_mask: the mask for the memory sequence (optional).
|
603 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
604 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
605 |
+
|
606 |
+
Shape:
|
607 |
+
see the docs in Transformer class.
|
608 |
+
"""
|
609 |
+
tgt_is_tuple = False
|
610 |
+
if isinstance(tgt, tuple):
|
611 |
+
x, stage_embedding = tgt
|
612 |
+
tgt_is_tuple = True
|
613 |
+
else:
|
614 |
+
x, stage_embedding = tgt, None
|
615 |
+
|
616 |
+
if self.norm_first:
|
617 |
+
x = x + self._sa_block(
|
618 |
+
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
|
619 |
+
)
|
620 |
+
x = x + self._mha_block(
|
621 |
+
self.norm2(x, stage_embedding),
|
622 |
+
memory,
|
623 |
+
memory_mask,
|
624 |
+
memory_key_padding_mask,
|
625 |
+
)
|
626 |
+
x = x + self._ff_block(self.norm3(x, stage_embedding))
|
627 |
+
else:
|
628 |
+
x = self.norm1(
|
629 |
+
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
|
630 |
+
stage_embedding,
|
631 |
+
)
|
632 |
+
x = self.norm2(
|
633 |
+
x
|
634 |
+
+ self._mha_block(
|
635 |
+
x, memory, memory_mask, memory_key_padding_mask
|
636 |
+
),
|
637 |
+
stage_embedding,
|
638 |
+
)
|
639 |
+
x = self.norm3(x + self._ff_block(x), stage_embedding)
|
640 |
+
|
641 |
+
if tgt_is_tuple:
|
642 |
+
return (x, stage_embedding)
|
643 |
+
return x
|
644 |
+
|
645 |
+
# self-attention block
|
646 |
+
def _sa_block(
|
647 |
+
self,
|
648 |
+
x: Tensor,
|
649 |
+
attn_mask: Optional[Tensor],
|
650 |
+
key_padding_mask: Optional[Tensor],
|
651 |
+
) -> Tensor:
|
652 |
+
x = self.self_attn(
|
653 |
+
x,
|
654 |
+
x,
|
655 |
+
x,
|
656 |
+
attn_mask=attn_mask,
|
657 |
+
key_padding_mask=key_padding_mask,
|
658 |
+
need_weights=False,
|
659 |
+
)[0]
|
660 |
+
return self.dropout1(x)
|
661 |
+
|
662 |
+
# multihead attention block
|
663 |
+
def _mha_block(
|
664 |
+
self,
|
665 |
+
x: Tensor,
|
666 |
+
mem: Tensor,
|
667 |
+
attn_mask: Optional[Tensor],
|
668 |
+
key_padding_mask: Optional[Tensor],
|
669 |
+
) -> Tensor:
|
670 |
+
x = self.multihead_attn(
|
671 |
+
x,
|
672 |
+
mem,
|
673 |
+
mem,
|
674 |
+
attn_mask=attn_mask,
|
675 |
+
key_padding_mask=key_padding_mask,
|
676 |
+
need_weights=False,
|
677 |
+
)[0]
|
678 |
+
return self.dropout2(x)
|
679 |
+
|
680 |
+
# feed forward block
|
681 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
682 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
683 |
+
return self.dropout3(x)
|
684 |
+
|
685 |
+
|
686 |
+
def _get_clones(module, N):
|
687 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
688 |
+
|
689 |
+
|
690 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
691 |
+
if activation == "relu":
|
692 |
+
return F.relu
|
693 |
+
elif activation == "gelu":
|
694 |
+
return F.gelu
|
695 |
+
|
696 |
+
raise RuntimeError(
|
697 |
+
"activation should be relu/gelu, not {}".format(activation)
|
698 |
+
)
|
models/modules/utils.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
6 |
+
"""
|
7 |
+
Args:
|
8 |
+
lengths:
|
9 |
+
A 1-D tensor containing sentence lengths.
|
10 |
+
max_len:
|
11 |
+
The length of masks.
|
12 |
+
Returns:
|
13 |
+
Return a 2-D bool tensor, where masked positions
|
14 |
+
are filled with `True` and non-masked positions are
|
15 |
+
filled with `False`.
|
16 |
+
|
17 |
+
>>> lengths = torch.tensor([1, 3, 2, 5])
|
18 |
+
>>> make_pad_mask(lengths)
|
19 |
+
tensor([[False, True, True, True, True],
|
20 |
+
[False, False, False, True, True],
|
21 |
+
[False, False, True, True, True],
|
22 |
+
[False, False, False, False, False]])
|
23 |
+
"""
|
24 |
+
assert lengths.ndim == 1, lengths.ndim
|
25 |
+
max_len = max(max_len, lengths.max())
|
26 |
+
n = lengths.size(0)
|
27 |
+
seq_range = torch.arange(0, max_len, device=lengths.device)
|
28 |
+
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
|
29 |
+
|
30 |
+
return expaned_lengths >= lengths.unsqueeze(-1)
|
31 |
+
|
32 |
+
def generate_partial_autoregressive_mask(sz, start, end):
|
33 |
+
mask = torch.zeros(sz, sz).bool()
|
34 |
+
mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1)
|
35 |
+
mask[:start, start:end] = True
|
36 |
+
mask[end:, start:end] = True
|
37 |
+
return mask
|
models/voicecraft.py
ADDED
@@ -0,0 +1,1406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import logging
|
5 |
+
import argparse, copy
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchmetrics.classification import MulticlassAccuracy
|
10 |
+
|
11 |
+
from .modules.utils import make_pad_mask
|
12 |
+
|
13 |
+
from .modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
14 |
+
from .modules.transformer import (
|
15 |
+
LayerNorm,
|
16 |
+
TransformerEncoder,
|
17 |
+
TransformerEncoderLayer,
|
18 |
+
)
|
19 |
+
from .codebooks_patterns import DelayedPatternProvider
|
20 |
+
|
21 |
+
def top_k_top_p_filtering(
|
22 |
+
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
23 |
+
):
|
24 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
25 |
+
Args:
|
26 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
27 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
28 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
29 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
30 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
31 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
32 |
+
"""
|
33 |
+
if top_k > 0:
|
34 |
+
top_k = min(
|
35 |
+
max(top_k, min_tokens_to_keep), logits.size(-1)
|
36 |
+
) # Safety check
|
37 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
38 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
39 |
+
logits[indices_to_remove] = filter_value
|
40 |
+
|
41 |
+
if top_p < 1.0:
|
42 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
43 |
+
cumulative_probs = torch.cumsum(
|
44 |
+
F.softmax(sorted_logits, dim=-1), dim=-1
|
45 |
+
)
|
46 |
+
|
47 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
48 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
49 |
+
if min_tokens_to_keep > 1:
|
50 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
51 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
52 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
53 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
54 |
+
..., :-1
|
55 |
+
].clone()
|
56 |
+
sorted_indices_to_remove[..., 0] = 0
|
57 |
+
|
58 |
+
# scatter sorted tensors to original indexing
|
59 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
60 |
+
1, sorted_indices, sorted_indices_to_remove
|
61 |
+
)
|
62 |
+
logits[indices_to_remove] = filter_value
|
63 |
+
return logits
|
64 |
+
|
65 |
+
|
66 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
67 |
+
# temperature: (`optional`) float
|
68 |
+
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
69 |
+
# top_k: (`optional`) int
|
70 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
71 |
+
# top_p: (`optional`) float
|
72 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
73 |
+
|
74 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
75 |
+
if temperature != 1.0:
|
76 |
+
logits = logits / temperature
|
77 |
+
# Top-p/top-k filtering
|
78 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
79 |
+
# Sample
|
80 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
81 |
+
return token
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
class VoiceCraft(nn.Module):
|
86 |
+
def __init__(self, args):
|
87 |
+
super().__init__()
|
88 |
+
self.args = copy.copy(args)
|
89 |
+
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
|
90 |
+
if not getattr(self.args, "special_first", False):
|
91 |
+
self.args.special_first = 0
|
92 |
+
if not getattr(self.args, "n_special", False):
|
93 |
+
self.args.n_special = 3
|
94 |
+
self.args.eos = getattr(self.args, "eos", -1)
|
95 |
+
self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1]
|
96 |
+
if self.args.eos > 0:
|
97 |
+
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
|
98 |
+
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
|
99 |
+
if type(self.args.audio_vocab_size) == str:
|
100 |
+
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
|
101 |
+
|
102 |
+
self.n_text_tokens = self.args.text_vocab_size + 1
|
103 |
+
assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"
|
104 |
+
|
105 |
+
self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token
|
106 |
+
assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
|
107 |
+
assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
|
108 |
+
assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token
|
109 |
+
|
110 |
+
self.text_embedding = TokenEmbedding(
|
111 |
+
dim_model=self.args.d_model,
|
112 |
+
vocab_size=self.n_text_tokens,
|
113 |
+
dropout=self.args.text_embedding_dropout
|
114 |
+
)
|
115 |
+
|
116 |
+
self.audio_embedding = nn.ModuleList(
|
117 |
+
[
|
118 |
+
TokenEmbedding(
|
119 |
+
dim_model=self.args.audio_embedding_dim,
|
120 |
+
vocab_size=self.n_audio_tokens[k],
|
121 |
+
dropout=self.args.audio_embedding_dropout
|
122 |
+
) for k in range(self.args.n_codebooks)
|
123 |
+
]
|
124 |
+
)
|
125 |
+
self.mask_embedding = nn.Parameter(torch.randn(self.args.max_n_spans, self.args.d_model), requires_grad=True)
|
126 |
+
self.text_positional_embedding = SinePositionalEmbedding(
|
127 |
+
self.args.d_model,
|
128 |
+
dropout=self.args.text_positional_embedding_dropout,
|
129 |
+
scale=False,
|
130 |
+
alpha=True, # learnable scaler, scale the volume of positional embedding
|
131 |
+
)
|
132 |
+
self.audio_positional_embedding = SinePositionalEmbedding(
|
133 |
+
self.args.d_model,
|
134 |
+
dropout=self.args.audio_positional_embedding_dropout,
|
135 |
+
scale=False,
|
136 |
+
alpha=True, # learnable scaler, scale the volume of positional embedding
|
137 |
+
)
|
138 |
+
|
139 |
+
dec_layer = TransformerEncoderLayer(
|
140 |
+
self.args.d_model,
|
141 |
+
self.args.nhead,
|
142 |
+
dim_feedforward=self.args.d_model * 4,
|
143 |
+
dropout=self.args.trm_dropout,
|
144 |
+
batch_first=True,
|
145 |
+
norm_first=True,
|
146 |
+
layer_norm_cls=LayerNorm
|
147 |
+
)
|
148 |
+
self.decoder = TransformerEncoder(
|
149 |
+
dec_layer,
|
150 |
+
num_layers=self.args.num_decoder_layers,
|
151 |
+
norm=LayerNorm(self.args.d_model),
|
152 |
+
)
|
153 |
+
|
154 |
+
self.predict_layer = nn.ModuleList(
|
155 |
+
[
|
156 |
+
nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
|
157 |
+
]
|
158 |
+
)
|
159 |
+
|
160 |
+
self.accuracy_metrics = nn.ModuleList(
|
161 |
+
[MulticlassAccuracy(
|
162 |
+
self.n_audio_tokens[k],
|
163 |
+
top_k=10,
|
164 |
+
average="micro",
|
165 |
+
multidim_average="global",
|
166 |
+
ignore_index=None,
|
167 |
+
) for k in range(self.args.n_codebooks)]
|
168 |
+
)
|
169 |
+
|
170 |
+
|
171 |
+
def prepare_mask_intervals(self, y_lens):
|
172 |
+
mask_intervals = []
|
173 |
+
non_mask_intervals = []
|
174 |
+
|
175 |
+
for i, y_len in enumerate(y_lens):
|
176 |
+
if self.args.mask_sample_dist == "uniform":
|
177 |
+
n_spans = random.choice(range(1, self.args.max_n_spans+1))
|
178 |
+
elif "poisson" in self.args.mask_sample_dist.lower():
|
179 |
+
param = float(self.args.mask_sample_dist[len("poisson"):])
|
180 |
+
poisson_sample = torch.poisson(torch.tensor([param]))
|
181 |
+
n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item())
|
182 |
+
|
183 |
+
starts = random.sample(range(1, y_len-1-self.args.mask_len_min), n_spans)
|
184 |
+
starts = sorted(starts)
|
185 |
+
|
186 |
+
for j in range(len(starts)-1, 0, -1):
|
187 |
+
if starts[j] - starts[j-1] < self.args.min_gap:
|
188 |
+
del starts[j] # If elements are too close, delete the later one
|
189 |
+
assert len(starts) > 0, f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}"
|
190 |
+
|
191 |
+
temp_starts = starts + [y_len]
|
192 |
+
gaps = [temp_starts[j+1] - temp_starts[j] for j in range(len(temp_starts)-1)]
|
193 |
+
|
194 |
+
ends = []
|
195 |
+
|
196 |
+
for j, (start, gap) in enumerate(zip(starts, gaps)):
|
197 |
+
mask_len = random.randint(self.args.mask_len_min, self.args.mask_len_max)
|
198 |
+
# if mask_len > gap * self.args.max_mask_portion: # make sure the masks are not overlapping with each other
|
199 |
+
if mask_len > gap - 1: # make sure the masks are not overlapping with each other
|
200 |
+
# temp_mask_start = int(0.6*gap*self.args.max_mask_portion)
|
201 |
+
# temp_mask_end = int(gap*self.args.max_mask_portion)
|
202 |
+
temp_mask_start = 1
|
203 |
+
temp_mask_end = gap - 1
|
204 |
+
mask_len = random.randint(temp_mask_start, temp_mask_end)
|
205 |
+
ends.append(start + mask_len)
|
206 |
+
|
207 |
+
mask_intervals.append([(s,e) for s,e in zip(starts, ends)])
|
208 |
+
non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])])
|
209 |
+
|
210 |
+
return mask_intervals, non_mask_intervals
|
211 |
+
|
212 |
+
def rearrange(self, y, non_mask_intervals, mask_intervals):
|
213 |
+
reduced_eog = getattr(self.args, "reduced_eog", 0)
|
214 |
+
rearranged_y = []
|
215 |
+
for i in range(len(y)):
|
216 |
+
if self.args.eos > 0:
|
217 |
+
assert reduced_eog
|
218 |
+
cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eos], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends
|
219 |
+
else:
|
220 |
+
if reduced_eog:
|
221 |
+
cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eog], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends
|
222 |
+
else:
|
223 |
+
cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment)
|
224 |
+
rearranged_y.append(cur_y)
|
225 |
+
return rearranged_y
|
226 |
+
|
227 |
+
def shift(self, rearranged_y):
|
228 |
+
shifted_y = []
|
229 |
+
patterns = []
|
230 |
+
for i in range(len(rearranged_y)):
|
231 |
+
cur_patterns = [self.pattern.get_pattern(cur_y.shape[1]) for cur_y in rearranged_y[i]]
|
232 |
+
out = [cur_pattern.build_pattern_sequence(z=cur_y.unsqueeze(0).contiguous(), special_token=self.args.empty_token, keep_only_valid_steps=False) for cur_pattern, cur_y in zip(cur_patterns, rearranged_y[i])]
|
233 |
+
shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask
|
234 |
+
patterns.append(cur_patterns)
|
235 |
+
return shifted_y, patterns
|
236 |
+
|
237 |
+
def insert_mask(self, shifted_y):
|
238 |
+
inserted_y = []
|
239 |
+
mask_position = []
|
240 |
+
mask_value = []
|
241 |
+
for i in range(len(shifted_y)):
|
242 |
+
num_masks = (len(shifted_y[i]) - 1) // 2
|
243 |
+
assert num_masks == (len(shifted_y[i]) - 1) / 2, len(shifted_y[i])
|
244 |
+
emb_inds = list(range(self.args.max_n_spans))
|
245 |
+
if self.args.shuffle_mask_embedding:
|
246 |
+
random.shuffle(emb_inds)
|
247 |
+
emb_inds_use = emb_inds[:num_masks]
|
248 |
+
emb_inds_use = emb_inds_use + emb_inds_use
|
249 |
+
mask_value.append(emb_inds_use)
|
250 |
+
cur_inserted_y = []
|
251 |
+
cur_mask_position = []
|
252 |
+
for j in range(len(shifted_y[i])-1):
|
253 |
+
cur_inserted_y.append(shifted_y[i][j])
|
254 |
+
cur_mask_position.append(sum([item.shape[1] for item in cur_inserted_y])) # each item is of shape [K S], so take shape[1]
|
255 |
+
cur_inserted_y.append(self.eog) # insert mask token of shape [K, 1], BUT we are actually using the eog token as a place holder here, as the real mask will be inserted in embed_y function
|
256 |
+
|
257 |
+
cur_inserted_y.append(shifted_y[i][-1])
|
258 |
+
|
259 |
+
inserted_y.append(cur_inserted_y)
|
260 |
+
mask_position.append(cur_mask_position)
|
261 |
+
return inserted_y, mask_position, mask_value
|
262 |
+
|
263 |
+
def cat_y(self, inserted_y, mask_position, y_lens):
|
264 |
+
reduced_eog = getattr(self.args, "reduced_eog", 0)
|
265 |
+
cated_y = []
|
266 |
+
new_y_lens = []
|
267 |
+
for i in range(len(inserted_y)):
|
268 |
+
cur_cated_y = torch.cat(inserted_y[i], dim=1) #[K S]
|
269 |
+
cur_cated_y = cur_cated_y.transpose(1,0) # [S K]
|
270 |
+
cur_cated_y_len = cur_cated_y.shape[0]
|
271 |
+
if reduced_eog:
|
272 |
+
assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i])/2 + 1) ({len(mask_position[i])/2 + 1})={y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1)}"
|
273 |
+
else:
|
274 |
+
assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i]) + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i]) + 1) ({len(mask_position[i]) + 1})" # the last term represent the inserted eog token, originally it's inserted at the end of every token, but this is wrong
|
275 |
+
new_y_lens.append(cur_cated_y_len)
|
276 |
+
cated_y.append(cur_cated_y)
|
277 |
+
|
278 |
+
cated_y = torch.nn.utils.rnn.pad_sequence(cated_y, batch_first=False, padding_value=self.args.audio_pad_token)
|
279 |
+
assert cated_y.shape == torch.Size([max(new_y_lens),len(inserted_y), self.args.n_codebooks]), f"cated_y.shape: {cated_y.shape}, but it should be {torch.Size([max(new_y_lens,len(inserted_y), self.args.n_codebooks)])}"
|
280 |
+
cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B]
|
281 |
+
assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
|
282 |
+
return cated_y, torch.LongTensor(new_y_lens).to(cated_y.device)
|
283 |
+
|
284 |
+
def embed_y(self, cated_y, mask_position, mask_value):
|
285 |
+
embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D]
|
286 |
+
assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
|
287 |
+
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
288 |
+
embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
|
289 |
+
embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
|
290 |
+
for i in range(len(embedded_y)):
|
291 |
+
if len(mask_position[i]) > 0:
|
292 |
+
embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
|
293 |
+
return embedded_y
|
294 |
+
|
295 |
+
def prepare_input_target(self, y, y_lens):
|
296 |
+
# rearrange y
|
297 |
+
# assume y shape: [B T K], K is n_codebooks
|
298 |
+
assert y.shape[1] == self.args.n_codebooks, y.shape
|
299 |
+
# sample mask_intervals
|
300 |
+
mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_lens)
|
301 |
+
|
302 |
+
# need to have EOG in each section (SOG will be generated by the pattern class)
|
303 |
+
# but mask can be inserted later after we have shifted the input
|
304 |
+
# y could be rearranged in this way:
|
305 |
+
# [
|
306 |
+
# [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
|
307 |
+
# [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
|
308 |
+
# ...
|
309 |
+
# ]
|
310 |
+
# for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part.
|
311 |
+
# NOTE #non_masked_part = #masked_part + 1
|
312 |
+
# NOTE *these are also the targets*
|
313 |
+
# added eog at the end of each segment (masked segment and unmasked segment)
|
314 |
+
rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
|
315 |
+
targets = rearranged_y # each element in each sample is of shape [K T]
|
316 |
+
assert targets[0][0].shape[0] == self.args.n_codebooks, targets[0][0].shape
|
317 |
+
|
318 |
+
# next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
|
319 |
+
# [[5, 1, 2, 3, 4, 5, 5],
|
320 |
+
# [5, 5, 1, 2, 3, 4, 5],
|
321 |
+
# [5, 5, 5, 1, 2, 3, 4]]
|
322 |
+
shifted_y, patterns = self.shift(rearranged_y) # each element [K S]
|
323 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape[0]
|
324 |
+
|
325 |
+
|
326 |
+
# then, insert mask token at the intersection of each tensor (we want to decide the arrangement of the mask (shuffle or not)), we better have a separate nn.embedding for it
|
327 |
+
# we also need to record the position of the inserted mask
|
328 |
+
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
|
329 |
+
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
|
330 |
+
assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
|
331 |
+
|
332 |
+
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
|
333 |
+
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
|
334 |
+
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
|
335 |
+
|
336 |
+
|
337 |
+
# embed remember to separately embed the mask tokens
|
338 |
+
embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD
|
339 |
+
assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape
|
340 |
+
|
341 |
+
# positional embedding
|
342 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
343 |
+
|
344 |
+
# make attention mask and padding mask
|
345 |
+
y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
|
346 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device)
|
347 |
+
return y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns
|
348 |
+
|
349 |
+
def remove_mask(self, logits, mask_position, new_y_lens):
|
350 |
+
# logits: [B K S card]
|
351 |
+
logits_use = []
|
352 |
+
for i in range(len(logits)):
|
353 |
+
non_mask_positions = [-1] + mask_position[i] + [new_y_lens[i]]
|
354 |
+
non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)]
|
355 |
+
cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals]
|
356 |
+
logits_use.append(cur_logits_use)
|
357 |
+
|
358 |
+
return logits_use
|
359 |
+
|
360 |
+
def revert_pattern(self, patterns, logits_use):
|
361 |
+
logits_final = []
|
362 |
+
logit_masks = []
|
363 |
+
for i in range(len(logits_use)):
|
364 |
+
cur_logits = [
|
365 |
+
item.unsqueeze(0).permute(0, 3, 1, 2).contiguous() for item in logits_use[i]
|
366 |
+
] # each item is of shape [1 K S card] [1 card K S]
|
367 |
+
cur_logits_final = [
|
368 |
+
cur_pattern.revert_pattern_logits(
|
369 |
+
item, 0, keep_only_valid_steps=False
|
370 |
+
)
|
371 |
+
for cur_pattern, item in zip(patterns[i], cur_logits)
|
372 |
+
] # if input output order doesn't match, this step will give an error
|
373 |
+
cur_logits_final_ret = [item[0].permute(0,2,3,1).squeeze(0) for item in cur_logits_final] # each element is of shape [K,T,card]
|
374 |
+
logits_final.append(cur_logits_final_ret)
|
375 |
+
logit_masks.append([item[2] for item in cur_logits_final])
|
376 |
+
|
377 |
+
return logits_final, logit_masks
|
378 |
+
|
379 |
+
def dec_forward(
|
380 |
+
self,
|
381 |
+
x_input,
|
382 |
+
x_lens,
|
383 |
+
x_attention_mask,
|
384 |
+
x_padding_mask,
|
385 |
+
y_input,
|
386 |
+
new_y_lens,
|
387 |
+
y_attention_mask,
|
388 |
+
y_padding_mask,
|
389 |
+
past=None,
|
390 |
+
last_3_tokens=False
|
391 |
+
):
|
392 |
+
x_attn_mask = F.pad(
|
393 |
+
x_attention_mask,
|
394 |
+
(0, new_y_lens.max()),
|
395 |
+
value=True,
|
396 |
+
) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper
|
397 |
+
y_attn_mask = F.pad(
|
398 |
+
y_attention_mask,
|
399 |
+
(x_lens.max(), 0), # y is padded at the front
|
400 |
+
value=False,
|
401 |
+
) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive
|
402 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
403 |
+
|
404 |
+
# merge key padding and attention masks
|
405 |
+
bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
|
406 |
+
xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
|
407 |
+
_xy_padding_mask = (
|
408 |
+
xy_padding_mask.view(bsz, 1, 1, src_len)
|
409 |
+
.expand(-1, self.args.nhead, -1, -1)
|
410 |
+
.reshape(bsz * self.args.nhead, 1, src_len)
|
411 |
+
)
|
412 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
413 |
+
|
414 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask)
|
415 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
416 |
+
xy_attn_mask = new_attn_mask
|
417 |
+
|
418 |
+
xy_input = torch.cat([x_input, y_input], dim=1)
|
419 |
+
|
420 |
+
if past == None: # do not use kvcache
|
421 |
+
out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
|
422 |
+
return out[:, x_lens.max():], None
|
423 |
+
else: # use kvcache
|
424 |
+
if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
|
425 |
+
if last_3_tokens:
|
426 |
+
xy_input = xy_input[:, -3:]
|
427 |
+
xy_attn_mask = xy_attn_mask[:, -3:]
|
428 |
+
else:
|
429 |
+
xy_input = xy_input[:, -1:]
|
430 |
+
xy_attn_mask = xy_attn_mask[:, -1:]
|
431 |
+
|
432 |
+
out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
|
433 |
+
if isinstance(out, tuple): # get rid of stage_embedding
|
434 |
+
out = out[0]
|
435 |
+
|
436 |
+
if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
|
437 |
+
return out[:, x_lens.max():], present
|
438 |
+
else: # used kvcache
|
439 |
+
return out, present
|
440 |
+
|
441 |
+
def forward(self, batch):
|
442 |
+
"""
|
443 |
+
Args:
|
444 |
+
x:
|
445 |
+
A 2-D tensor of shape (N, S).
|
446 |
+
x_lens:
|
447 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
448 |
+
before padding.
|
449 |
+
y:
|
450 |
+
A 3-D tensor of shape (N, K, T).
|
451 |
+
where K is the number of codebooks
|
452 |
+
y_lens:
|
453 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
454 |
+
before padding.
|
455 |
+
"""
|
456 |
+
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
457 |
+
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
458 |
+
y = y[:, :y_lens.max()]
|
459 |
+
assert x.ndim == 2, x.shape
|
460 |
+
assert x_lens.ndim == 1, x_lens.shape
|
461 |
+
assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
|
462 |
+
assert y_lens.ndim == 1, y_lens.shape
|
463 |
+
# makes attention mask and padding mask for x
|
464 |
+
x_padding_mask = make_pad_mask(x_lens).to(x.device)
|
465 |
+
x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device)
|
466 |
+
x_input = self.text_embedding(x)
|
467 |
+
x_input = self.text_positional_embedding(x_input)
|
468 |
+
y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens)
|
469 |
+
y_out = self.dec_forward(
|
470 |
+
x_input,
|
471 |
+
x_lens,
|
472 |
+
x_attention_mask,
|
473 |
+
x_padding_mask,
|
474 |
+
y_input,
|
475 |
+
new_y_lens,
|
476 |
+
y_attention_mask,
|
477 |
+
y_padding_mask
|
478 |
+
)
|
479 |
+
y_out = y_out[0] # no kv-caching during training
|
480 |
+
assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
|
481 |
+
|
482 |
+
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
|
483 |
+
# take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern)
|
484 |
+
assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
|
485 |
+
|
486 |
+
logits_use = self.remove_mask(logits, mask_position, new_y_lens)
|
487 |
+
|
488 |
+
# revert the pattern shift for each logits section in each sample
|
489 |
+
logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
|
490 |
+
assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"
|
491 |
+
# testing
|
492 |
+
sample_to_test = 0
|
493 |
+
assert len(logits_final[sample_to_test]) == len(targets[sample_to_test]), f"{len(logits_final[sample_to_test])}, {len(targets[sample_to_test])}"
|
494 |
+
temp = sum([logits_final[sample_to_test][i].shape[:-1] != targets[sample_to_test][i].shape for i in range(len(targets[sample_to_test]))])
|
495 |
+
assert temp == 0, f"none equal positions: {temp}, total number of elements: {len(targets[sample_to_test])}"
|
496 |
+
|
497 |
+
logit_masked = sum([(item==False).any() for cur_mask in logit_masks for item in cur_mask])
|
498 |
+
assert logit_masked == 0, logit_masks
|
499 |
+
|
500 |
+
logits = torch.cat([torch.cat(item, dim=1) for item in logits_final], dim=1) # [K, T1+T2+T3+..., card]
|
501 |
+
targets = torch.cat([torch.cat(item, dim=1) for item in targets], dim=1) # [K, T1+T2+T3+...]
|
502 |
+
assert targets.shape[0] == logits.shape[0], f"{targets.shape}, {logits.shape}"
|
503 |
+
loss = []
|
504 |
+
ntokens = []
|
505 |
+
top10acc = []
|
506 |
+
for k, (logit, target) in enumerate(zip(logits, targets)):
|
507 |
+
loss.append(F.cross_entropy(logit, target, reduction='mean'))
|
508 |
+
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
|
509 |
+
ntokens.append(len(logit))
|
510 |
+
|
511 |
+
all_ntokens = sum(ntokens)
|
512 |
+
if self.args.codebook_weight != None:
|
513 |
+
codebook_weight = eval(self.args.codebook_weight)
|
514 |
+
else:
|
515 |
+
codebook_weight = [1.] * self.args.n_codebooks
|
516 |
+
loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])
|
517 |
+
top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)]
|
518 |
+
top10acc = sum(top10acc_by_codebook)
|
519 |
+
ntokens = torch.tensor(all_ntokens).to(logits.device)
|
520 |
+
|
521 |
+
return {
|
522 |
+
"loss": loss,
|
523 |
+
"top10acc": top10acc,
|
524 |
+
"top10acc_by_codebook": top10acc_by_codebook,
|
525 |
+
"effective_ntoken": ntokens,
|
526 |
+
}
|
527 |
+
|
528 |
+
def inference(
|
529 |
+
self,
|
530 |
+
x: torch.Tensor,
|
531 |
+
x_lens: torch.Tensor,
|
532 |
+
y: torch.Tensor,
|
533 |
+
mask_interval: list[torch.Tensor],
|
534 |
+
top_k: int=-100,
|
535 |
+
top_p: float=1.0,
|
536 |
+
temperature: float=1.0,
|
537 |
+
stop_repetition: int=-1,
|
538 |
+
kvcache: int=1,
|
539 |
+
silence_tokens: list[int]=[1388,1898,131],
|
540 |
+
) -> torch.Tensor:
|
541 |
+
"""
|
542 |
+
Args:
|
543 |
+
x:
|
544 |
+
A 2-D tensor of shape (1, L).
|
545 |
+
x_lens:
|
546 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
547 |
+
before padding.
|
548 |
+
y:
|
549 |
+
A 3-D tensor of shape (1, T, K).
|
550 |
+
mask_interval:
|
551 |
+
a list of tensors of shape (M, 2). contains M mask_start and mask_end. list length is actually 1, because we only support single sample inference for now
|
552 |
+
top_k: (`optional`) int
|
553 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
554 |
+
top_p: (`optional`) float
|
555 |
+
For Neucleus sampling
|
556 |
+
temperature: (`optional`) float
|
557 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
558 |
+
eog_coef: (`optional`) float
|
559 |
+
if 0, no change to eog token logits, otherwise, will adjust eog token logit based on the difference between acoustic token and phn token length
|
560 |
+
stop_repetition (`optional`) int
|
561 |
+
if not -1, will set the logits of a token that repeated this many times to be -100000, to avoid generating it again. This only apply to tokens from the first codebook
|
562 |
+
allowed_repeat_tokens (`optional`) list of ints
|
563 |
+
by inspecting the validation set, get a few tokens that indeed repeat a significant amount of time, and exclude those tokens from prevent repetition
|
564 |
+
ultimate_stop_repetition (`optional`) int
|
565 |
+
no matter that token it is, stop repetition once after this number
|
566 |
+
"""
|
567 |
+
assert x.ndim == 2, x.shape
|
568 |
+
assert x_lens.ndim == 1, x_lens.shape
|
569 |
+
assert y.ndim == 3, y.shape
|
570 |
+
if self.args.special_first:
|
571 |
+
y = y + int(self.args.n_special)
|
572 |
+
y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
|
573 |
+
assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
|
574 |
+
assert mask_interval.shape == torch.Size((1, mask_interval.shape[1], 2)), mask_interval
|
575 |
+
|
576 |
+
# make x attention mask and x_input
|
577 |
+
x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
|
578 |
+
# x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
|
579 |
+
x_input = self.text_embedding(x)
|
580 |
+
x_input = self.text_positional_embedding(x_input)
|
581 |
+
|
582 |
+
# make initial y_input
|
583 |
+
|
584 |
+
# make mask_interval and non_mask_interval
|
585 |
+
y_len = y.shape[2]
|
586 |
+
y_lens = torch.LongTensor([y_len]).to(y.device)
|
587 |
+
mask_interval = mask_interval[0]
|
588 |
+
starts = [item[0].item() for item in mask_interval] + [y_len]
|
589 |
+
ends = [0] + [item[1].item() for item in mask_interval]
|
590 |
+
mask_intervals = [[
|
591 |
+
(item[0].item(), item[1].item()) for item in mask_interval
|
592 |
+
]] # a werid name change, mask_interval is input, now is mask_intervals, with one more dimension
|
593 |
+
non_mask_intervals = [[
|
594 |
+
(ns, ne) for ns, ne in zip(ends, starts)
|
595 |
+
]]
|
596 |
+
|
597 |
+
# rearrange y
|
598 |
+
# will add have EOG in each section (SOG will be generated by the pattern class)
|
599 |
+
# but mask can be inserted later after we have shifted the input
|
600 |
+
# y could be rearranged in this way:
|
601 |
+
# [
|
602 |
+
# [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
|
603 |
+
# [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
|
604 |
+
# ...
|
605 |
+
# ]
|
606 |
+
# for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part.
|
607 |
+
# NOTE #non_masked_part = #masked_part + 1
|
608 |
+
rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
|
609 |
+
assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
|
610 |
+
|
611 |
+
# shift each element of y
|
612 |
+
# next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
|
613 |
+
# [
|
614 |
+
# [empty, 1, 2, 3, eog, empty, empty, empty],
|
615 |
+
# [empty, empty, 1, 2, 3, eog, empty, empty],
|
616 |
+
# [empty, empty, empty, 1, 2, 3, eog, empty],
|
617 |
+
# [empty, empty, empty, empty, 1, 2, 3, eog]
|
618 |
+
# ]
|
619 |
+
shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
|
620 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
|
621 |
+
|
622 |
+
# insert mask token at the intersction of each tensor, but *actually inserted eog as place holder*
|
623 |
+
# the position of inserted mask is also recorded
|
624 |
+
# and the mask_value, the index of the mask emb is recorded
|
625 |
+
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
|
626 |
+
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
|
627 |
+
assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
|
628 |
+
|
629 |
+
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
|
630 |
+
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
|
631 |
+
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
|
632 |
+
assert not (cated_y == self.args.audio_pad_token).any(), cated_y
|
633 |
+
|
634 |
+
### NOTE this is different from forward, as we will remove the masked tokens
|
635 |
+
### say there are two masked region
|
636 |
+
### the cated_y should be like
|
637 |
+
### [empty a a a a mask0 empty b b b mask1 empty c c mask0 empty]
|
638 |
+
### which means we need to take the part after the last empty out
|
639 |
+
num_mask = len(mask_position[0])//2
|
640 |
+
assert num_mask == len(mask_position[0])/2, mask_position
|
641 |
+
cated_y = cated_y[:, :mask_position[0][num_mask]+2] # of shape [K,T,B]
|
642 |
+
# logging.info(f"mask_position[0][num_mask]+2: {mask_position[0][num_mask]+2}")
|
643 |
+
more_mask_value = mask_value[0][num_mask+1:] # NOTE this will be used in the generation loop for reference for inserting mask embedding
|
644 |
+
new_y_lens[0] = mask_position[0][num_mask]+2
|
645 |
+
mask_position[0] = mask_position[0][:num_mask+1]
|
646 |
+
assert mask_position[0][num_mask]+2 == cated_y.shape[1], f"num_mask: {num_mask}, mask_position: {mask_position}, cated_y.shape: {cated_y.shape}"
|
647 |
+
|
648 |
+
# embed: remember to separately embed the mask tokens
|
649 |
+
embedded_y = self.embed_y(cated_y, mask_position, [mask_value[0][:num_mask+1]]) #BTD
|
650 |
+
# assert embedded_y.shape == torch.Size((y.shape[0], max(new_y_lens), self.args.d_model)), embedded_y.shape
|
651 |
+
|
652 |
+
# positional embedding
|
653 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
654 |
+
|
655 |
+
# make attention mask and padding mask
|
656 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
657 |
+
# y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
658 |
+
|
659 |
+
x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
|
660 |
+
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
661 |
+
|
662 |
+
|
663 |
+
codebook_eog = [False] * self.args.n_codebooks
|
664 |
+
generated = [] # doesn't contain any empty_token, contains eog
|
665 |
+
cur_generated = []
|
666 |
+
# say 0 is empty, 4 is eog
|
667 |
+
# tensor([[ 1, 2, 3, 4, 0, 0],
|
668 |
+
# [ 0, 1, 2, 3, 4, 0],
|
669 |
+
# [ 0, 0, 1, 2, 3, 4]])
|
670 |
+
num_gen = []
|
671 |
+
cur_num_gen = 0
|
672 |
+
##################### silence repetition handling #####################
|
673 |
+
##################### silence repetition handling #####################
|
674 |
+
logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
|
675 |
+
consec_silence_count = 0
|
676 |
+
prev_token = None
|
677 |
+
##################### silence repetition handling #####################
|
678 |
+
##################### silence repetition handling #####################
|
679 |
+
# prepare the cache placeholder
|
680 |
+
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
681 |
+
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
682 |
+
# handle multi-span kv-cache
|
683 |
+
new_masked_span = False
|
684 |
+
|
685 |
+
def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
|
686 |
+
if n_eog == 0:
|
687 |
+
logits_adjust = logits
|
688 |
+
for jj in range(1,self.args.n_codebooks):
|
689 |
+
logits_adjust[jj][self.args.eog] = -10000
|
690 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
691 |
+
##################### silence repetition handling #####################
|
692 |
+
if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
|
693 |
+
if logits_adjust[0, prev_token] < 0:
|
694 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
|
695 |
+
else:
|
696 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
|
697 |
+
##################### silence repetition handling #####################
|
698 |
+
if type(logits_adjust) == list:
|
699 |
+
samples_list= []
|
700 |
+
for logit in logits_adjust:
|
701 |
+
# print(logit)
|
702 |
+
# print(logit.shape)
|
703 |
+
cur_sample = topk_sampling(
|
704 |
+
logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature
|
705 |
+
) # [1, 1]
|
706 |
+
samples_list.append(cur_sample)
|
707 |
+
samples = torch.cat(samples_list, dim=0) # [K, 1]
|
708 |
+
else:
|
709 |
+
samples = topk_sampling(
|
710 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
711 |
+
) # [K, 1]
|
712 |
+
assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
|
713 |
+
if cur_num_gen < self.args.n_codebooks-1:
|
714 |
+
for jj in range(1, self.args.n_codebooks - cur_num_gen):
|
715 |
+
samples[-jj, 0] = self.args.empty_token
|
716 |
+
|
717 |
+
if (
|
718 |
+
samples[0,0] == self.args.eog or torch.argmax(logits[0], dim=-1) == self.args.eog or y_input.shape[1] > x_lens[0] * 10
|
719 |
+
): # last one means y is already too long, shouldn't happen, but put it here
|
720 |
+
samples[0,0] = self.args.eog
|
721 |
+
codebook_eog[0] = True
|
722 |
+
##################### silence repetition handling #####################
|
723 |
+
##################### silence repetition handling #####################
|
724 |
+
if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
|
725 |
+
consec_silence_count += 1
|
726 |
+
else:
|
727 |
+
consec_silence_count = 0
|
728 |
+
prev_token = samples[0,0]
|
729 |
+
##################### silence repetition handling #####################
|
730 |
+
##################### silence repetition handling #####################
|
731 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
732 |
+
else:
|
733 |
+
assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
|
734 |
+
logits_adjust = logits
|
735 |
+
for jj in range(n_eog+1,self.args.n_codebooks):
|
736 |
+
logits_adjust[jj][self.args.eog] = -10000
|
737 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
738 |
+
if type(logits_adjust) == list:
|
739 |
+
samples_list= []
|
740 |
+
for logit in logits_adjust:
|
741 |
+
cur_sample = topk_sampling(
|
742 |
+
logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature
|
743 |
+
) # [1, 1]
|
744 |
+
samples_list.append(cur_sample)
|
745 |
+
samples = torch.cat(samples_list, dim=0) # [K, 1]
|
746 |
+
else:
|
747 |
+
samples = topk_sampling(
|
748 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
749 |
+
) # [K, 1]
|
750 |
+
for jj in range(n_eog):
|
751 |
+
samples[jj, 0] = self.args.empty_token
|
752 |
+
samples[n_eog, 0] = self.args.eog
|
753 |
+
codebook_eog[n_eog] = True
|
754 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
755 |
+
|
756 |
+
while True:
|
757 |
+
y_out, present = self.dec_forward(
|
758 |
+
x_input,
|
759 |
+
x_lens,
|
760 |
+
x_attention_mask,
|
761 |
+
x_padding_mask,
|
762 |
+
y_input,
|
763 |
+
new_y_lens,
|
764 |
+
y_attention_mask,
|
765 |
+
y_padding_mask,
|
766 |
+
past=past,
|
767 |
+
last_3_tokens = new_masked_span
|
768 |
+
)
|
769 |
+
if new_masked_span:
|
770 |
+
new_masked_span = False
|
771 |
+
|
772 |
+
if past != None:
|
773 |
+
past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
|
774 |
+
|
775 |
+
y_out = y_out[:, -1:] # only take the last one
|
776 |
+
|
777 |
+
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
|
778 |
+
logits = logits.squeeze(0).squeeze(1) # [K card]
|
779 |
+
assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
|
780 |
+
|
781 |
+
n_eog = sum(codebook_eog)
|
782 |
+
assert n_eog < self.args.n_codebooks
|
783 |
+
if self.args.eos > 0: # eos stands for end-of-sentence, which shouldn't be used as we are doing speech editing
|
784 |
+
for jj in range(self.args.n_codebooks):
|
785 |
+
logits[jj][self.args.eos] = -10000.
|
786 |
+
# need to use a helper function to hand different n_eog cases
|
787 |
+
samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
|
788 |
+
cur_num_gen += 1
|
789 |
+
cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
|
790 |
+
# get samples_emb
|
791 |
+
samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
|
792 |
+
samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
|
793 |
+
|
794 |
+
if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
|
795 |
+
# re-init
|
796 |
+
codebook_eog = [False] * self.args.n_codebooks
|
797 |
+
num_gen.append(cur_num_gen)
|
798 |
+
cur_num_gen = 0
|
799 |
+
generated.append(cur_generated)
|
800 |
+
cur_generated = []
|
801 |
+
|
802 |
+
# if the current mask span is the last span, then all done
|
803 |
+
# else
|
804 |
+
# append the next mask token and the four empty tokens to start the next generation
|
805 |
+
if len(more_mask_value) > 0:
|
806 |
+
next_mask_ind = more_mask_value.pop(0)
|
807 |
+
mask_emb = self.mask_embedding[next_mask_ind].unsqueeze(0).unsqueeze(0) # [1,1,D]
|
808 |
+
assert mask_emb.shape == torch.Size((1,1,self.args.d_model)), mask_emb.shape
|
809 |
+
empty_token = torch.LongTensor([self.args.empty_token]).to(y.device)
|
810 |
+
empty_emb = torch.stack([
|
811 |
+
self.audio_embedding[k](empty_token) for k in range(self.args.n_codebooks)], dim=0
|
812 |
+
).sum(dim=0, keepdim=True) # [1,1,D]
|
813 |
+
assert empty_emb.shape == torch.Size((1,1,self.args.d_model)), empty_emb.shape
|
814 |
+
extra_emb = torch.cat([mask_emb, empty_emb], dim=1) # [1,2,D]
|
815 |
+
samples_emb = torch.cat([samples_emb, extra_emb], dim=1) # [1,3,D] # prev_last_token, mask_token, empty token
|
816 |
+
assert samples_emb.shape == torch.Size((1,3,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
|
817 |
+
##################### silence repetition handling #####################
|
818 |
+
##################### silence repetition handling #####################
|
819 |
+
consec_silence_count = 0
|
820 |
+
prev_token = None
|
821 |
+
##################### silence repetition handling #####################
|
822 |
+
##################### silence repetition handling #####################
|
823 |
+
|
824 |
+
# handling kv-caching for multi-span editing
|
825 |
+
new_masked_span = True
|
826 |
+
else:
|
827 |
+
break
|
828 |
+
else:
|
829 |
+
assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
|
830 |
+
|
831 |
+
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
832 |
+
# positional embedding
|
833 |
+
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
834 |
+
# make attention mask and padding mask
|
835 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
836 |
+
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
837 |
+
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
838 |
+
|
839 |
+
assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}"
|
840 |
+
|
841 |
+
# # combine non_masked_span with generated spans
|
842 |
+
# first need to shift the generated part back
|
843 |
+
flatten_gen = []
|
844 |
+
for l, orig_span in enumerate(generated):
|
845 |
+
span = torch.stack(orig_span, dim=0) # [T K]
|
846 |
+
span = span.transpose(1,0) # [K, T]
|
847 |
+
assert span.shape[0] == self.args.n_codebooks, span.shape
|
848 |
+
unshifted_span = []
|
849 |
+
for j, s in enumerate(span):
|
850 |
+
start_from = j
|
851 |
+
end_at = - (self.args.n_codebooks - start_from)
|
852 |
+
unshifted_span.append(s[start_from:end_at])
|
853 |
+
unshifted_span = torch.stack(unshifted_span, dim=0)
|
854 |
+
|
855 |
+
assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
|
856 |
+
flatten_gen.append(unshifted_span)
|
857 |
+
# logging.info(f"unshfited_span: {unshifted_span.shape}")
|
858 |
+
# raise
|
859 |
+
assert len(non_mask_intervals[0]) - 1 == len(flatten_gen), f"len(non_mask_intervals[0]): {len(non_mask_intervals[0])}, len(flatten_gen): {len(flatten_gen)}"
|
860 |
+
res = []
|
861 |
+
for orig_interval, gen in zip(non_mask_intervals[0], flatten_gen):
|
862 |
+
res.append(y[0, :, orig_interval[0]:orig_interval[1]])
|
863 |
+
res.append(gen)
|
864 |
+
res.append(y[0, :, non_mask_intervals[0][-1][0]:non_mask_intervals[0][-1][1]])
|
865 |
+
res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T]
|
866 |
+
|
867 |
+
expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen])
|
868 |
+
assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
869 |
+
|
870 |
+
if self.args.special_first:
|
871 |
+
res = res - int(self.args.n_special)
|
872 |
+
|
873 |
+
return res
|
874 |
+
|
875 |
+
def inference_tts(
|
876 |
+
self,
|
877 |
+
x: torch.Tensor,
|
878 |
+
x_lens: torch.Tensor,
|
879 |
+
y: torch.Tensor,
|
880 |
+
top_k: int=-100,
|
881 |
+
top_p: float=1.0,
|
882 |
+
temperature: float=1.0,
|
883 |
+
stop_repetition: int=3,
|
884 |
+
kvcache: int=1,
|
885 |
+
silence_tokens: list[int]=[1388,1898,131],
|
886 |
+
*kargs
|
887 |
+
) -> torch.Tensor:
|
888 |
+
"""
|
889 |
+
different from inference_tts, this implementation uses kvcache, which should have significant speed up
|
890 |
+
Args:
|
891 |
+
x:
|
892 |
+
A 2-D tensor of shape (1, L).
|
893 |
+
x_lens:
|
894 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
895 |
+
before padding.
|
896 |
+
y:
|
897 |
+
A 3-D tensor of shape (1, T, K).
|
898 |
+
top_k: (`optional`) int
|
899 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
900 |
+
top_p: (`optional`) float
|
901 |
+
For Neucleus sampling
|
902 |
+
temperature: (`optional`) float
|
903 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
904 |
+
"""
|
905 |
+
eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
|
906 |
+
assert x.ndim == 2, x.shape
|
907 |
+
assert x_lens.ndim == 1, x_lens.shape
|
908 |
+
assert y.ndim == 3, y.shape
|
909 |
+
if self.args.special_first:
|
910 |
+
y = y + int(self.args.n_special)
|
911 |
+
y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
|
912 |
+
assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
|
913 |
+
|
914 |
+
# make x attention mask and x_input
|
915 |
+
x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
|
916 |
+
# x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
|
917 |
+
x_input = self.text_embedding(x)
|
918 |
+
x_input = self.text_positional_embedding(x_input)
|
919 |
+
|
920 |
+
y_len = y.shape[2]
|
921 |
+
y_lens = torch.LongTensor([y_len]).to(y.device)
|
922 |
+
|
923 |
+
# rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
|
924 |
+
rearranged_y = [[y[0]]]
|
925 |
+
assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
|
926 |
+
|
927 |
+
# shift y to create the delayed pattern
|
928 |
+
shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
|
929 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
|
930 |
+
assert len(shifted_y[0]) == 1, len(shifted_y[0])
|
931 |
+
|
932 |
+
# below is different from forward or inference
|
933 |
+
# where we cut this shifted part
|
934 |
+
shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
|
935 |
+
assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
|
936 |
+
|
937 |
+
# next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
|
938 |
+
# next section is concate tensors of each sample to one tensor, which we also don't need
|
939 |
+
cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
|
940 |
+
new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
|
941 |
+
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
|
942 |
+
assert not (cated_y == self.args.audio_pad_token).any(), cated_y
|
943 |
+
|
944 |
+
# replace tokens in y with the embeddings, add sum codebooks up
|
945 |
+
embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
|
946 |
+
assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
|
947 |
+
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
948 |
+
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
|
949 |
+
embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
|
950 |
+
|
951 |
+
# positional embedding
|
952 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
953 |
+
|
954 |
+
# make attention mask and padding mask
|
955 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
956 |
+
|
957 |
+
x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
|
958 |
+
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
959 |
+
|
960 |
+
# entering the generation stage
|
961 |
+
# starting from line 708
|
962 |
+
codebook_eog = [False] * self.args.n_codebooks
|
963 |
+
generated = [] # doesn't contain any empty token, contain eog
|
964 |
+
cur_generated = []
|
965 |
+
# say 0 is empty, 4 is eog
|
966 |
+
# tensor([[ 1, 2, 3, 4, 0, 0],
|
967 |
+
# [ 0, 1, 2, 3, 4, 0],
|
968 |
+
# [ 0, 0, 1, 2, 3, 4]])
|
969 |
+
num_gen = []
|
970 |
+
cur_num_gen = 0
|
971 |
+
##################### silence repetition handling #####################
|
972 |
+
##################### silence repetition handling #####################
|
973 |
+
logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
|
974 |
+
consec_silence_count = 0
|
975 |
+
prev_token = None
|
976 |
+
##################### silence repetition handling #####################
|
977 |
+
##################### silence repetition handling #####################
|
978 |
+
|
979 |
+
# prepare the cache placeholder
|
980 |
+
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
981 |
+
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
982 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
983 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
984 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
985 |
+
def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
|
986 |
+
if n_eog == 0:
|
987 |
+
logits_adjust = logits
|
988 |
+
for jj in range(1,self.args.n_codebooks):
|
989 |
+
logits_adjust[jj][eog_inference] = -10000
|
990 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
991 |
+
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
|
992 |
+
logits_adjust[0][eog_inference] = -10000
|
993 |
+
##################### silence repetition handling #####################
|
994 |
+
if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
|
995 |
+
if logits_adjust[0, prev_token] < 0:
|
996 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
|
997 |
+
else:
|
998 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
|
999 |
+
##################### silence repetition handling #####################
|
1000 |
+
samples = topk_sampling(
|
1001 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
1002 |
+
) # [K, 1]
|
1003 |
+
assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
|
1004 |
+
if cur_num_gen < self.args.n_codebooks-1:
|
1005 |
+
for jj in range(1, self.args.n_codebooks - cur_num_gen):
|
1006 |
+
samples[-jj, 0] = self.args.empty_token
|
1007 |
+
|
1008 |
+
if (
|
1009 |
+
samples[0,0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr//5)
|
1010 |
+
): # last one means y is already too long, shouldn't happen, but put it here
|
1011 |
+
samples[0,0] = eog_inference
|
1012 |
+
codebook_eog[0] = True
|
1013 |
+
##################### silence repetition handling #####################
|
1014 |
+
if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
|
1015 |
+
consec_silence_count += 1
|
1016 |
+
else:
|
1017 |
+
consec_silence_count = 0
|
1018 |
+
prev_token = samples[0,0]
|
1019 |
+
##################### silence repetition handling #####################
|
1020 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
1021 |
+
else:
|
1022 |
+
assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
|
1023 |
+
logits_adjust = logits
|
1024 |
+
for jj in range(n_eog+1,self.args.n_codebooks):
|
1025 |
+
logits_adjust[jj][eog_inference] = -10000
|
1026 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
1027 |
+
samples = topk_sampling(
|
1028 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
1029 |
+
) # [K, 1]
|
1030 |
+
for jj in range(n_eog):
|
1031 |
+
samples[jj, 0] = self.args.empty_token
|
1032 |
+
samples[n_eog, 0] = eog_inference
|
1033 |
+
codebook_eog[n_eog] = True
|
1034 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
1035 |
+
while True:
|
1036 |
+
y_out, present = self.dec_forward(
|
1037 |
+
x_input,
|
1038 |
+
x_lens,
|
1039 |
+
x_attention_mask,
|
1040 |
+
x_padding_mask,
|
1041 |
+
y_input,
|
1042 |
+
new_y_lens,
|
1043 |
+
y_attention_mask,
|
1044 |
+
y_padding_mask,
|
1045 |
+
past=past
|
1046 |
+
)
|
1047 |
+
if past != None:
|
1048 |
+
past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
|
1049 |
+
|
1050 |
+
|
1051 |
+
y_out = y_out[:, -1:] # only take the last token
|
1052 |
+
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
|
1053 |
+
logits = logits.squeeze(0).squeeze(1) # [K card]
|
1054 |
+
assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
|
1055 |
+
|
1056 |
+
n_eog = sum(codebook_eog)
|
1057 |
+
assert n_eog < self.args.n_codebooks
|
1058 |
+
if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
|
1059 |
+
for jj in range(self.args.n_codebooks):
|
1060 |
+
logits[jj][self.args.eog] = -10000.
|
1061 |
+
|
1062 |
+
samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
|
1063 |
+
|
1064 |
+
cur_num_gen += 1
|
1065 |
+
cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
|
1066 |
+
|
1067 |
+
# samples.shape is [K,1]
|
1068 |
+
# ge samples_emb
|
1069 |
+
samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
|
1070 |
+
samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
|
1071 |
+
|
1072 |
+
if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
|
1073 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1074 |
+
num_gen.append(cur_num_gen)
|
1075 |
+
cur_num_gen = 0
|
1076 |
+
generated.append(cur_generated)
|
1077 |
+
cur_generated = []
|
1078 |
+
break
|
1079 |
+
else:
|
1080 |
+
assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
|
1081 |
+
|
1082 |
+
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
1083 |
+
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
1084 |
+
# make attention mask and padding mask
|
1085 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
1086 |
+
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
1087 |
+
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
1088 |
+
|
1089 |
+
assert len(generated) == 1, f"len(generated): {len(generated)}"
|
1090 |
+
|
1091 |
+
# revert the pattern
|
1092 |
+
flatten_gen = []
|
1093 |
+
for l, orig_span in enumerate(generated):
|
1094 |
+
span = torch.stack(orig_span, dim=0) # [T, K]
|
1095 |
+
span = span.transpose(1,0) # [K, T]
|
1096 |
+
assert span.shape[0] == self.args.n_codebooks, span.shape
|
1097 |
+
unshifted_span = []
|
1098 |
+
for j, s in enumerate(span):
|
1099 |
+
start_from = j
|
1100 |
+
end_at = - (self.args.n_codebooks - start_from)
|
1101 |
+
unshifted_span.append(s[start_from:end_at])
|
1102 |
+
unshifted_span = torch.stack(unshifted_span, dim=0)
|
1103 |
+
|
1104 |
+
assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
|
1105 |
+
|
1106 |
+
flatten_gen.append(unshifted_span)
|
1107 |
+
assert len(flatten_gen) == 1, len(flatten_gen)
|
1108 |
+
|
1109 |
+
# combine
|
1110 |
+
res = [y[0], flatten_gen[0]]
|
1111 |
+
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
|
1112 |
+
|
1113 |
+
expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
|
1114 |
+
assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
1115 |
+
|
1116 |
+
if self.args.special_first:
|
1117 |
+
res = res - int(self.args.n_special)
|
1118 |
+
flatten_gen = flatten_gen - int(self.args.n_special)
|
1119 |
+
|
1120 |
+
return res, flatten_gen[0].unsqueeze(0)
|
1121 |
+
|
1122 |
+
|
1123 |
+
def inference_tts_batch(
|
1124 |
+
self,
|
1125 |
+
x: torch.Tensor,
|
1126 |
+
x_lens: torch.Tensor,
|
1127 |
+
y: torch.Tensor,
|
1128 |
+
top_k: int=-100,
|
1129 |
+
top_p: float=1.0,
|
1130 |
+
temperature: float=1.0,
|
1131 |
+
stop_repetition: int=3,
|
1132 |
+
kvcache: int=1,
|
1133 |
+
batch_size: int=5,
|
1134 |
+
silence_tokens: list[int]=[1388,1898,131],
|
1135 |
+
*kargs
|
1136 |
+
) -> torch.Tensor:
|
1137 |
+
"""
|
1138 |
+
have a batch size when forward passing, but they are equivalant to same example but different random seed, therefore as long as one example generated eog, we can drop all other samlpes
|
1139 |
+
different from inference_tts, this implementation uses kvcache, which should have significant speed up
|
1140 |
+
Args:
|
1141 |
+
x:
|
1142 |
+
A 2-D tensor of shape (1, L).
|
1143 |
+
x_lens:
|
1144 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
1145 |
+
before padding.
|
1146 |
+
y:
|
1147 |
+
A 3-D tensor of shape (1, T, K).
|
1148 |
+
top_k: (`optional`) int
|
1149 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
1150 |
+
top_p: (`optional`) float
|
1151 |
+
For Neucleus sampling
|
1152 |
+
temperature: (`optional`) float
|
1153 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
1154 |
+
"""
|
1155 |
+
eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
|
1156 |
+
assert x.ndim == 2, x.shape
|
1157 |
+
assert x_lens.ndim == 1, x_lens.shape
|
1158 |
+
assert y.ndim == 3, y.shape
|
1159 |
+
if self.args.special_first:
|
1160 |
+
y = y + int(self.args.n_special)
|
1161 |
+
y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
|
1162 |
+
assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
|
1163 |
+
|
1164 |
+
# make x attention mask and x_input
|
1165 |
+
x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
|
1166 |
+
# x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
|
1167 |
+
x_input = self.text_embedding(x)
|
1168 |
+
x_input = self.text_positional_embedding(x_input)
|
1169 |
+
|
1170 |
+
y_len = y.shape[2]
|
1171 |
+
y_lens = torch.LongTensor([y_len]).to(y.device)
|
1172 |
+
|
1173 |
+
# rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
|
1174 |
+
rearranged_y = [[y[0]]]
|
1175 |
+
assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
|
1176 |
+
|
1177 |
+
# shift y to create the delayed pattern
|
1178 |
+
shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
|
1179 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
|
1180 |
+
assert len(shifted_y[0]) == 1, len(shifted_y[0])
|
1181 |
+
|
1182 |
+
# below is different from forward or inference
|
1183 |
+
# where we cut this shifted part
|
1184 |
+
shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
|
1185 |
+
assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
|
1186 |
+
|
1187 |
+
# next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
|
1188 |
+
# next section is concate tensors of each sample to one tensor, which we also don't need
|
1189 |
+
cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
|
1190 |
+
new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
|
1191 |
+
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
|
1192 |
+
assert not (cated_y == self.args.audio_pad_token).any(), cated_y
|
1193 |
+
|
1194 |
+
# replace tokens in y with the embeddings, add sum codebooks up
|
1195 |
+
embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
|
1196 |
+
assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
|
1197 |
+
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
1198 |
+
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
|
1199 |
+
embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
|
1200 |
+
|
1201 |
+
# positional embedding
|
1202 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
1203 |
+
|
1204 |
+
# make attention mask and padding mask
|
1205 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
1206 |
+
|
1207 |
+
x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
|
1208 |
+
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
1209 |
+
|
1210 |
+
# entering the generation stage
|
1211 |
+
# starting from line 708
|
1212 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1213 |
+
generated = [] # doesn't contain any empty token, contain eog
|
1214 |
+
cur_generated = [[] for _ in range(batch_size)]
|
1215 |
+
# say 0 is empty, 4 is eog
|
1216 |
+
# tensor([[ 1, 2, 3, 4, 0, 0],
|
1217 |
+
# [ 0, 1, 2, 3, 4, 0],
|
1218 |
+
# [ 0, 0, 1, 2, 3, 4]])
|
1219 |
+
num_gen = []
|
1220 |
+
cur_num_gen = 0
|
1221 |
+
##################### silence repetition handling #####################
|
1222 |
+
##################### silence repetition handling #####################
|
1223 |
+
logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
|
1224 |
+
consec_silence_counts = [0 for _ in range(batch_size)]
|
1225 |
+
prev_tokens = [None for _ in range(batch_size)]
|
1226 |
+
##################### silence repetition handling #####################
|
1227 |
+
##################### silence repetition handling #####################
|
1228 |
+
|
1229 |
+
# prepare the cache placeholder
|
1230 |
+
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
1231 |
+
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
1232 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1233 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1234 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1235 |
+
keep = None # NOTE: this very important, tells which sample to keep
|
1236 |
+
def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep):
|
1237 |
+
if n_eog == 0:
|
1238 |
+
logits_adjust = logits
|
1239 |
+
for jj in range(1,self.args.n_codebooks):
|
1240 |
+
logits_adjust[:,jj,eog_inference] = -10000
|
1241 |
+
logits_adjust[:,jj,self.args.empty_token] = -10000
|
1242 |
+
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
|
1243 |
+
logits_adjust[:,:,eog_inference] = -10000
|
1244 |
+
##################### silence repetition handling #####################
|
1245 |
+
for b in range(batch_size):
|
1246 |
+
prev_token = prev_tokens[b]
|
1247 |
+
consec_silence_count = consec_silence_counts[b]
|
1248 |
+
if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
|
1249 |
+
if logits_adjust[b, 0, prev_token] < 0:
|
1250 |
+
logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] * (consec_silence_count - (stop_repetition-1))
|
1251 |
+
else:
|
1252 |
+
logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] / (consec_silence_count - (stop_repetition-1))
|
1253 |
+
##################### silence repetition handling #####################
|
1254 |
+
samples = topk_sampling(
|
1255 |
+
logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature
|
1256 |
+
) # [B*K, 1]
|
1257 |
+
samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
|
1258 |
+
assert samples.shape == torch.Size((batch_size, self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
|
1259 |
+
for b in range(batch_size):
|
1260 |
+
if cur_num_gen < self.args.n_codebooks-1:
|
1261 |
+
for jj in range(1, self.args.n_codebooks - cur_num_gen):
|
1262 |
+
samples[b, -jj, 0] = self.args.empty_token
|
1263 |
+
|
1264 |
+
if (
|
1265 |
+
samples[b,0,0] == eog_inference or torch.argmax(logits[b,0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[b] * (self.args.encodec_sr//5)
|
1266 |
+
): # last one means y is already too long, shouldn't happen, but put it here
|
1267 |
+
samples[b,0,0] = eog_inference
|
1268 |
+
codebook_eog[0] = True
|
1269 |
+
keep = b # NOTE keep is a very important variable, we only return this one, note that if eog shows up in two samples, keep will be overwritten by the later one (or the last one)
|
1270 |
+
##################### silence repetition handling #####################
|
1271 |
+
if samples[b,0,0] in silence_tokens and samples[b,0,0] == prev_tokens[b]:
|
1272 |
+
consec_silence_counts[b] += 1
|
1273 |
+
else:
|
1274 |
+
consec_silence_counts[b] = 0
|
1275 |
+
prev_tokens[b] = samples[b,0,0]
|
1276 |
+
##################### silence repetition handling #####################
|
1277 |
+
return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
|
1278 |
+
else:
|
1279 |
+
assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
|
1280 |
+
logits_adjust = logits
|
1281 |
+
for jj in range(n_eog+1,self.args.n_codebooks):
|
1282 |
+
logits_adjust[:,jj,eog_inference] = -10000
|
1283 |
+
logits_adjust[:,jj,self.args.empty_token] = -10000
|
1284 |
+
samples = topk_sampling(
|
1285 |
+
logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature
|
1286 |
+
) # [B, K, 1]
|
1287 |
+
samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
|
1288 |
+
for jj in range(n_eog):
|
1289 |
+
samples[keep, jj, 0] = self.args.empty_token
|
1290 |
+
samples[keep, n_eog, 0] = eog_inference
|
1291 |
+
codebook_eog[n_eog] = True
|
1292 |
+
return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
|
1293 |
+
while True:
|
1294 |
+
# if cur_num_gen > 0, should have everything in kvcache, so only pass in the last token
|
1295 |
+
# in the first generation step, we repeat each tensor to make their first dimension of length the batch size
|
1296 |
+
if cur_num_gen == 0:
|
1297 |
+
assert x_input.ndim == 3 and x_input.shape[0] == 1, x_input.shape
|
1298 |
+
assert x_padding_mask.ndim == 2 and x_padding_mask.shape[0] == 1, x_padding_mask.shape
|
1299 |
+
assert y_input.ndim == 3 and y_input.shape[0] == 1 and y_input.shape[1] == new_y_lens[0], y_input.shape
|
1300 |
+
assert embedded_y.ndim == 3 and embedded_y.shape[0] == 1 and embedded_y.shape[1] == new_y_lens[0], embedded_y.shape
|
1301 |
+
x_input = x_input.repeat(batch_size, 1, 1)
|
1302 |
+
x_lens = x_lens.repeat(batch_size)
|
1303 |
+
# x_attention_mask = x_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
|
1304 |
+
x_padding_mask = x_padding_mask.repeat(batch_size, 1)
|
1305 |
+
y_input = y_input.repeat(batch_size, 1, 1)
|
1306 |
+
new_y_lens = new_y_lens.repeat(batch_size)
|
1307 |
+
# y_attention_mask = y_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
|
1308 |
+
y_padding_mask = y_padding_mask.repeat(batch_size, 1)
|
1309 |
+
embedded_y = embedded_y.repeat(batch_size, 1, 1) # will be used to concat with newly generated token embedding
|
1310 |
+
past = past.repeat(1, 1, batch_size) if past != None else None
|
1311 |
+
else:
|
1312 |
+
assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
|
1313 |
+
y_out, present = self.dec_forward(
|
1314 |
+
x_input,
|
1315 |
+
x_lens,
|
1316 |
+
x_attention_mask,
|
1317 |
+
x_padding_mask,
|
1318 |
+
y_input,
|
1319 |
+
new_y_lens,
|
1320 |
+
y_attention_mask,
|
1321 |
+
y_padding_mask,
|
1322 |
+
past=past
|
1323 |
+
)
|
1324 |
+
if past != None:
|
1325 |
+
past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
|
1326 |
+
|
1327 |
+
# if no eog emerges, y_out should have batch size of batch_size
|
1328 |
+
if sum(codebook_eog) == 0:
|
1329 |
+
assert y_out.shape[0] == batch_size and y_out.ndim == 3, y_out.shape
|
1330 |
+
y_out = y_out[:, -1:] # only take the last token
|
1331 |
+
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], S==1, so [B K 1 card]
|
1332 |
+
logits = logits.squeeze(2) # [B K card]
|
1333 |
+
assert logits.shape == torch.Size((batch_size, self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
|
1334 |
+
|
1335 |
+
n_eog = sum(codebook_eog)
|
1336 |
+
if self.args.eos > 0:
|
1337 |
+
for jj in range(self.args.n_codebooks):
|
1338 |
+
logits[:,jj,self.args.eog] = -10000.
|
1339 |
+
samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep)
|
1340 |
+
|
1341 |
+
cur_num_gen += 1
|
1342 |
+
if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples
|
1343 |
+
assert keep == None
|
1344 |
+
for b in range(batch_size):
|
1345 |
+
cur_generated[b].append(samples[b].squeeze(-1))
|
1346 |
+
elif sum(codebook_eog) == 1: # the first eog just showed up in this step
|
1347 |
+
assert keep != None
|
1348 |
+
cur_generated = cur_generated[keep]
|
1349 |
+
cur_generated.append(samples[keep].squeeze(-1))
|
1350 |
+
else: # we are generating the rest eogs for the 'keep' sample
|
1351 |
+
cur_generated.append(samples[keep].squeeze(-1))
|
1352 |
+
|
1353 |
+
# samples.shape is [K,1]
|
1354 |
+
# ge samples_emb
|
1355 |
+
samples_emb = torch.stack([self.audio_embedding[k](samples[:, k]) for k in range(self.args.n_codebooks)], dim=1) # [B, K,1,D]
|
1356 |
+
assert samples_emb.shape == torch.Size([batch_size, self.args.n_codebooks, 1, self.args.d_model])
|
1357 |
+
samples_emb = samples_emb.sum(dim=1,keepdim=False) # [B,1,D]
|
1358 |
+
if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
|
1359 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1360 |
+
num_gen.append(cur_num_gen)
|
1361 |
+
cur_num_gen = 0
|
1362 |
+
generated.append(cur_generated)
|
1363 |
+
cur_generated = [[] for _ in range(batch_size)]
|
1364 |
+
break
|
1365 |
+
else:
|
1366 |
+
assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
|
1367 |
+
|
1368 |
+
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
1369 |
+
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
1370 |
+
# make attention mask and padding mask
|
1371 |
+
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
1372 |
+
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
|
1373 |
+
y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device)
|
1374 |
+
|
1375 |
+
assert len(generated) == 1, f"len(generated): {len(generated)}"
|
1376 |
+
|
1377 |
+
# revert the pattern
|
1378 |
+
flatten_gen = []
|
1379 |
+
for l, orig_span in enumerate(generated):
|
1380 |
+
span = torch.stack(orig_span, dim=0) # [T, K]
|
1381 |
+
span = span.transpose(1,0) # [K, T]
|
1382 |
+
assert span.shape[0] == self.args.n_codebooks, span.shape
|
1383 |
+
unshifted_span = []
|
1384 |
+
for j, s in enumerate(span):
|
1385 |
+
start_from = j
|
1386 |
+
end_at = - (self.args.n_codebooks - start_from)
|
1387 |
+
unshifted_span.append(s[start_from:end_at])
|
1388 |
+
unshifted_span = torch.stack(unshifted_span, dim=0)
|
1389 |
+
|
1390 |
+
assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
|
1391 |
+
|
1392 |
+
flatten_gen.append(unshifted_span)
|
1393 |
+
assert len(flatten_gen) == 1, len(flatten_gen)
|
1394 |
+
|
1395 |
+
# combine
|
1396 |
+
res = [y[0], flatten_gen[0]]
|
1397 |
+
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
|
1398 |
+
|
1399 |
+
expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
|
1400 |
+
assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
1401 |
+
|
1402 |
+
if self.args.special_first:
|
1403 |
+
res = res - int(self.args.n_special)
|
1404 |
+
flatten_gen = flatten_gen - int(self.args.n_special)
|
1405 |
+
|
1406 |
+
return res, flatten_gen[0].unsqueeze(0)
|
pretrained_models/encodec_4cb2048_giga.th
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:caa0c595d4919527a9728d627150aa2a0b15b6d117b21855165851333dc63378
|
3 |
+
size 1167842971
|
pretrained_models/giga330M.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:35e028b8c5237cb4a6050ca81d4569b98e3a34ad9175fa252f7b1d13e6a9ad26
|
3 |
+
size 1746844161
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
|
2 |
+
xformers==0.0.22
|
3 |
+
torchaudio==2.0.2
|
4 |
+
torch==2.0.1
|
5 |
+
phonemizer==3.2.1
|
6 |
+
gradio==3.50.2
|
7 |
+
nltk>=3.8.1
|
8 |
+
openai-whisper>=20231117
|
9 |
+
spaces
|