Spaces:
Runtime error
Runtime error
Commit
•
cf91771
0
Parent(s):
Duplicate from GitMylo/bark-voice-cloning
Browse filesCo-authored-by: Mylo <[email protected]>
- .gitattributes +34 -0
- README.md +16 -0
- app.py +98 -0
- data/models/hubert/hubert.pt +3 -0
- data/models/hubert/tokenizer.pth +3 -0
- hubert/__init__.py +0 -0
- hubert/customtokenizer.py +182 -0
- hubert/hubert_manager.py +33 -0
- hubert/pre_kmeans_hubert.py +85 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Bark Voice Cloning
|
3 |
+
emoji: 🐶
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.29.0
|
8 |
+
python_version: 3.10.11
|
9 |
+
app_file: app.py
|
10 |
+
models:
|
11 |
+
- facebook/hubert-base-ls960
|
12 |
+
- GitMylo/bark-voice-cloning
|
13 |
+
pinned: false
|
14 |
+
license: mit
|
15 |
+
duplicated_from: GitMylo/bark-voice-cloning
|
16 |
+
---
|
app.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os.path
|
3 |
+
import uuid
|
4 |
+
|
5 |
+
import gradio
|
6 |
+
import numpy
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from hubert.hubert_manager import HuBERTManager
|
10 |
+
from hubert.pre_kmeans_hubert import CustomHubert
|
11 |
+
from hubert.customtokenizer import CustomTokenizer
|
12 |
+
from encodec import EncodecModel
|
13 |
+
from encodec.utils import convert_audio
|
14 |
+
|
15 |
+
|
16 |
+
hubert_model = CustomHubert(HuBERTManager.make_sure_hubert_installed())
|
17 |
+
tokenizer_model = CustomTokenizer.load_from_checkpoint(
|
18 |
+
HuBERTManager.make_sure_tokenizer_installed(model='quantifier_V1_hubert_base_ls960_23.pth'),
|
19 |
+
map_location=torch.device('cpu')
|
20 |
+
)
|
21 |
+
encodec_model = EncodecModel.encodec_model_24khz()
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def clone(audio, *args):
|
26 |
+
sr, wav = audio
|
27 |
+
|
28 |
+
wav = torch.tensor(wav)
|
29 |
+
|
30 |
+
if wav.dtype == torch.int16:
|
31 |
+
wav = wav.float() / 32767.0
|
32 |
+
|
33 |
+
if len(wav.shape) == 2:
|
34 |
+
if wav.shape[0] == 2: # Stereo to mono if needed
|
35 |
+
wav = wav.mean(0, keepdim=True)
|
36 |
+
if wav.shape[1] == 2:
|
37 |
+
wav = wav.mean(1, keepdim=False).unsqueeze(-1)
|
38 |
+
|
39 |
+
wav = wav[-int(sr*20):] # Take only the last 20 seconds
|
40 |
+
|
41 |
+
wav = wav.reshape(1, -1) # Reshape from gradio style to HuBERT shape. (N, 1) to (1, N)
|
42 |
+
|
43 |
+
semantic_vectors = hubert_model.forward(wav, input_sample_hz=sr)
|
44 |
+
semantic_tokens = tokenizer_model.get_token(semantic_vectors)
|
45 |
+
|
46 |
+
encodec_model.set_target_bandwidth(6.0)
|
47 |
+
wav = convert_audio(wav, sr, encodec_model.sample_rate, 1)
|
48 |
+
wav = wav.unsqueeze(0)
|
49 |
+
|
50 |
+
with torch.no_grad():
|
51 |
+
encoded_frames = encodec_model.encode(wav)
|
52 |
+
|
53 |
+
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [B, n_q, T]
|
54 |
+
|
55 |
+
if not os.path.isdir('data/speakers'):
|
56 |
+
os.makedirs('data/speakers')
|
57 |
+
|
58 |
+
file_path = f'data/speakers/{uuid.uuid4().hex}.npz'
|
59 |
+
|
60 |
+
numpy.savez(
|
61 |
+
file_path,
|
62 |
+
semantic_prompt=semantic_tokens,
|
63 |
+
fine_prompt=codes,
|
64 |
+
coarse_prompt=codes[:2, :]
|
65 |
+
)
|
66 |
+
|
67 |
+
return file_path
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
iface = gradio.interface.Interface(fn=clone, inputs=[
|
72 |
+
'audio',
|
73 |
+
gradio.Markdown(
|
74 |
+
'''
|
75 |
+
# Bark text to speech voice cloning
|
76 |
+
[Model](https://huggingface.co/GitMylo/bark-voice-cloning/), [Model GitHub](https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer), [Webui GitHub](https://github.com/gitmylo/audio-webui)
|
77 |
+
|
78 |
+
For faster creation of voice clones [Duplicate this space](https://huggingface.co/spaces/GitMylo/bark-voice-cloning?duplicate=true)
|
79 |
+
|
80 |
+
Uploaded audio files get cut to 20 seconds in order to keep it fast for everyone. Only the last 20 seconds will be used. (Bark only uses the last 14 seconds anyway)
|
81 |
+
|
82 |
+
## Tips for better cloning
|
83 |
+
### Make sure these things are **NOT** in your voice input: (in no particular order)
|
84 |
+
* Noise (You can use a noise remover before)
|
85 |
+
* Music (There are also music remover tools) (Unless you want music in the background)
|
86 |
+
* A cut-off at the end (This will cause it to try and continue on the generation)
|
87 |
+
* Under 1 second of training data (i personally suggest around 10 seconds for good potential, but i've had great results with 5 seconds as well.)
|
88 |
+
|
89 |
+
### What makes for good prompt audio? (in no particular order)
|
90 |
+
* Clearly spoken
|
91 |
+
* No weird background noises
|
92 |
+
* Only one speaker
|
93 |
+
* Audio which ends after a sentence ends
|
94 |
+
* Regular/common voice (They usually have more success, it's still capable of cloning complex voices, but not as good at it)
|
95 |
+
* Around 10 seconds of data
|
96 |
+
''')
|
97 |
+
], outputs='file')
|
98 |
+
iface.launch()
|
data/models/hubert/hubert.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1703cf8d2cdc76f8c046f5f6a9bcd224e0e6caf4744cad1a1f4199c32cac8c8d
|
3 |
+
size 1136468879
|
data/models/hubert/tokenizer.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d94c5dd646bcfe1a8bb470372f0004c189acf65d913831f3a6ed6414c9ba86f
|
3 |
+
size 243656111
|
hubert/__init__.py
ADDED
File without changes
|
hubert/customtokenizer.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os.path
|
3 |
+
from zipfile import ZipFile
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
from torch import nn, optim
|
8 |
+
from torch.serialization import MAP_LOCATION
|
9 |
+
|
10 |
+
|
11 |
+
class CustomTokenizer(nn.Module):
|
12 |
+
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
|
13 |
+
super(CustomTokenizer, self).__init__()
|
14 |
+
next_size = input_size
|
15 |
+
if version == 0:
|
16 |
+
self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
|
17 |
+
next_size = hidden_size
|
18 |
+
if version == 1:
|
19 |
+
self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
|
20 |
+
self.intermediate = nn.Linear(hidden_size, 4096)
|
21 |
+
next_size = 4096
|
22 |
+
|
23 |
+
self.fc = nn.Linear(next_size, output_size)
|
24 |
+
self.softmax = nn.LogSoftmax(dim=1)
|
25 |
+
self.optimizer: optim.Optimizer = None
|
26 |
+
self.lossfunc = nn.CrossEntropyLoss()
|
27 |
+
self.input_size = input_size
|
28 |
+
self.hidden_size = hidden_size
|
29 |
+
self.output_size = output_size
|
30 |
+
self.version = version
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x, _ = self.lstm(x)
|
34 |
+
if self.version == 1:
|
35 |
+
x = self.intermediate(x)
|
36 |
+
x = self.fc(x)
|
37 |
+
x = self.softmax(x)
|
38 |
+
return x
|
39 |
+
|
40 |
+
@torch.no_grad()
|
41 |
+
def get_token(self, x):
|
42 |
+
"""
|
43 |
+
Used to get the token for the first
|
44 |
+
:param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model.
|
45 |
+
:return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model.
|
46 |
+
"""
|
47 |
+
return torch.argmax(self(x), dim=1)
|
48 |
+
|
49 |
+
def prepare_training(self):
|
50 |
+
self.optimizer = optim.Adam(self.parameters(), 0.001)
|
51 |
+
|
52 |
+
def train_step(self, x_train, y_train, log_loss=False):
|
53 |
+
# y_train = y_train[:-1]
|
54 |
+
# y_train = y_train[1:]
|
55 |
+
|
56 |
+
optimizer = self.optimizer
|
57 |
+
lossfunc = self.lossfunc
|
58 |
+
# Zero the gradients
|
59 |
+
self.zero_grad()
|
60 |
+
|
61 |
+
# Forward pass
|
62 |
+
y_pred = self(x_train)
|
63 |
+
|
64 |
+
y_train_len = len(y_train)
|
65 |
+
y_pred_len = y_pred.shape[0]
|
66 |
+
|
67 |
+
if y_train_len > y_pred_len:
|
68 |
+
diff = y_train_len - y_pred_len
|
69 |
+
y_train = y_train[diff:]
|
70 |
+
elif y_train_len < y_pred_len:
|
71 |
+
diff = y_pred_len - y_train_len
|
72 |
+
y_pred = y_pred[:-diff, :]
|
73 |
+
|
74 |
+
y_train_hot = torch.zeros(len(y_train), self.output_size)
|
75 |
+
y_train_hot[range(len(y_train)), y_train] = 1
|
76 |
+
y_train_hot = y_train_hot.to('cuda')
|
77 |
+
|
78 |
+
# Calculate the loss
|
79 |
+
loss = lossfunc(y_pred, y_train_hot)
|
80 |
+
|
81 |
+
# Print loss
|
82 |
+
if log_loss:
|
83 |
+
print('Loss', loss.item())
|
84 |
+
|
85 |
+
# Backward pass
|
86 |
+
loss.backward()
|
87 |
+
|
88 |
+
# Update the weights
|
89 |
+
optimizer.step()
|
90 |
+
|
91 |
+
def save(self, path):
|
92 |
+
info_path = os.path.basename(path) + '/.info'
|
93 |
+
torch.save(self.state_dict(), path)
|
94 |
+
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
|
95 |
+
with ZipFile(path, 'a') as model_zip:
|
96 |
+
model_zip.writestr(info_path, data_from_model.save())
|
97 |
+
model_zip.close()
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def load_from_checkpoint(path, map_location: MAP_LOCATION = None):
|
101 |
+
old = True
|
102 |
+
with ZipFile(path) as model_zip:
|
103 |
+
filesMatch = [file for file in model_zip.namelist() if file.endswith('/.info')]
|
104 |
+
file = filesMatch[0] if filesMatch else None
|
105 |
+
if file:
|
106 |
+
old = False
|
107 |
+
data_from_model = Data.load(model_zip.read(file).decode('utf-8'))
|
108 |
+
model_zip.close()
|
109 |
+
if old:
|
110 |
+
model = CustomTokenizer()
|
111 |
+
else:
|
112 |
+
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
|
113 |
+
model.load_state_dict(torch.load(path, map_location))
|
114 |
+
return model
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
class Data:
|
119 |
+
input_size: int
|
120 |
+
hidden_size: int
|
121 |
+
output_size: int
|
122 |
+
version: int
|
123 |
+
|
124 |
+
def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0):
|
125 |
+
self.input_size = input_size
|
126 |
+
self.hidden_size = hidden_size
|
127 |
+
self.output_size = output_size
|
128 |
+
self.version = version
|
129 |
+
|
130 |
+
@staticmethod
|
131 |
+
def load(string):
|
132 |
+
data = json.loads(string)
|
133 |
+
return Data(data['input_size'], data['hidden_size'], data['output_size'], data['version'])
|
134 |
+
|
135 |
+
def save(self):
|
136 |
+
data = {
|
137 |
+
'input_size': self.input_size,
|
138 |
+
'hidden_size': self.hidden_size,
|
139 |
+
'output_size': self.output_size,
|
140 |
+
'version': self.version,
|
141 |
+
}
|
142 |
+
return json.dumps(data)
|
143 |
+
|
144 |
+
|
145 |
+
def auto_train(data_path, save_path='model.pth', load_model: str | None = None, save_epochs=1):
|
146 |
+
data_x, data_y = [], []
|
147 |
+
|
148 |
+
if load_model and os.path.isfile(load_model):
|
149 |
+
print('Loading model from', load_model)
|
150 |
+
model_training = CustomTokenizer.load_from_checkpoint(load_model, 'cuda')
|
151 |
+
else:
|
152 |
+
print('Creating new model.')
|
153 |
+
model_training = CustomTokenizer(version=1).to('cuda') # Settings for the model to run without lstm
|
154 |
+
save_path = os.path.join(data_path, save_path)
|
155 |
+
base_save_path = '.'.join(save_path.split('.')[:-1])
|
156 |
+
|
157 |
+
sem_string = '_semantic.npy'
|
158 |
+
feat_string = '_semantic_features.npy'
|
159 |
+
|
160 |
+
ready = os.path.join(data_path, 'ready')
|
161 |
+
for input_file in os.listdir(ready):
|
162 |
+
full_path = os.path.join(ready, input_file)
|
163 |
+
if input_file.endswith(sem_string):
|
164 |
+
data_y.append(numpy.load(full_path))
|
165 |
+
elif input_file.endswith(feat_string):
|
166 |
+
data_x.append(numpy.load(full_path))
|
167 |
+
model_training.prepare_training()
|
168 |
+
|
169 |
+
epoch = 1
|
170 |
+
|
171 |
+
while 1:
|
172 |
+
for i in range(save_epochs):
|
173 |
+
j = 0
|
174 |
+
for x, y in zip(data_x, data_y):
|
175 |
+
model_training.train_step(torch.tensor(x).to('cuda'), torch.tensor(y).to('cuda'), j % 50 == 0) # Print loss every 50 steps
|
176 |
+
j += 1
|
177 |
+
save_p = save_path
|
178 |
+
save_p_2 = f'{base_save_path}_epoch_{epoch}.pth'
|
179 |
+
model_training.save(save_p)
|
180 |
+
model_training.save(save_p_2)
|
181 |
+
print(f'Epoch {epoch} completed')
|
182 |
+
epoch += 1
|
hubert/hubert_manager.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import shutil
|
3 |
+
import urllib.request
|
4 |
+
|
5 |
+
import huggingface_hub
|
6 |
+
|
7 |
+
|
8 |
+
class HuBERTManager:
|
9 |
+
@staticmethod
|
10 |
+
def make_sure_hubert_installed(download_url: str = 'https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt', file_name: str = 'hubert.pt'):
|
11 |
+
install_dir = os.path.join('data', 'models', 'hubert')
|
12 |
+
if not os.path.isdir(install_dir):
|
13 |
+
os.makedirs(install_dir, exist_ok=True)
|
14 |
+
install_file = os.path.join(install_dir, file_name)
|
15 |
+
if not os.path.isfile(install_file):
|
16 |
+
print('Downloading HuBERT base model')
|
17 |
+
urllib.request.urlretrieve(download_url, install_file)
|
18 |
+
print('Downloaded HuBERT')
|
19 |
+
return install_file
|
20 |
+
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def make_sure_tokenizer_installed(model: str = 'quantifier_hubert_base_ls960_14.pth', repo: str = 'GitMylo/bark-voice-cloning', local_file: str = 'tokenizer.pth'):
|
24 |
+
install_dir = os.path.join('data', 'models', 'hubert')
|
25 |
+
if not os.path.isdir(install_dir):
|
26 |
+
os.makedirs(install_dir, exist_ok=True)
|
27 |
+
install_file = os.path.join(install_dir, local_file)
|
28 |
+
if not os.path.isfile(install_file):
|
29 |
+
print('Downloading HuBERT custom tokenizer')
|
30 |
+
huggingface_hub.hf_hub_download(repo, model, local_dir=install_dir, local_dir_use_symlinks=False)
|
31 |
+
shutil.move(os.path.join(install_dir, model), install_file)
|
32 |
+
print('Downloaded tokenizer')
|
33 |
+
return install_file
|
hubert/pre_kmeans_hubert.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from einops import pack, unpack
|
6 |
+
|
7 |
+
import fairseq
|
8 |
+
|
9 |
+
from torchaudio.functional import resample
|
10 |
+
|
11 |
+
import logging
|
12 |
+
logging.root.setLevel(logging.ERROR)
|
13 |
+
|
14 |
+
|
15 |
+
def exists(val):
|
16 |
+
return val is not None
|
17 |
+
|
18 |
+
|
19 |
+
def default(val, d):
|
20 |
+
return val if exists(val) else d
|
21 |
+
|
22 |
+
|
23 |
+
class CustomHubert(nn.Module):
|
24 |
+
"""
|
25 |
+
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
|
26 |
+
or you can train your own
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
checkpoint_path,
|
32 |
+
target_sample_hz=16000,
|
33 |
+
seq_len_multiple_of=None,
|
34 |
+
output_layer=9
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
self.target_sample_hz = target_sample_hz
|
38 |
+
self.seq_len_multiple_of = seq_len_multiple_of
|
39 |
+
self.output_layer = output_layer
|
40 |
+
|
41 |
+
model_path = Path(checkpoint_path)
|
42 |
+
|
43 |
+
assert model_path.exists(), f'path {checkpoint_path} does not exist'
|
44 |
+
|
45 |
+
checkpoint = torch.load(checkpoint_path)
|
46 |
+
load_model_input = {checkpoint_path: checkpoint}
|
47 |
+
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
|
48 |
+
|
49 |
+
self.model = model[0]
|
50 |
+
self.model.eval()
|
51 |
+
|
52 |
+
@property
|
53 |
+
def groups(self):
|
54 |
+
return 1
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
wav_input,
|
60 |
+
flatten=True,
|
61 |
+
input_sample_hz=None
|
62 |
+
):
|
63 |
+
device = wav_input.device
|
64 |
+
|
65 |
+
if exists(input_sample_hz):
|
66 |
+
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
|
67 |
+
|
68 |
+
embed = self.model(
|
69 |
+
wav_input,
|
70 |
+
features_only=True,
|
71 |
+
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
|
72 |
+
output_layer=self.output_layer
|
73 |
+
)
|
74 |
+
|
75 |
+
embed, packed_shape = pack([embed['x']], '* d')
|
76 |
+
|
77 |
+
# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
|
78 |
+
|
79 |
+
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()
|
80 |
+
|
81 |
+
if flatten:
|
82 |
+
return codebook_indices
|
83 |
+
|
84 |
+
codebook_indices, = unpack(codebook_indices, packed_shape, '*')
|
85 |
+
return codebook_indices
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchaudio
|
3 |
+
encodec
|
4 |
+
joblib
|
5 |
+
fairseq
|