Spaces:
Runtime error
Runtime error
Upload infer.py
Browse files
infer.py
CHANGED
@@ -85,22 +85,22 @@ def get_text(text, language_str, hps, device):
|
|
85 |
for i in range(len(word2ph)):
|
86 |
word2ph[i] = word2ph[i] * 2
|
87 |
word2ph[0] += 1
|
88 |
-
|
89 |
del word2ph
|
90 |
-
assert
|
91 |
|
92 |
if language_str == "ZH":
|
93 |
-
bert =
|
94 |
ja_bert = torch.zeros(1024, len(phone))
|
95 |
en_bert = torch.zeros(1024, len(phone))
|
96 |
elif language_str == "JP":
|
97 |
bert = torch.zeros(1024, len(phone))
|
98 |
-
ja_bert =
|
99 |
en_bert = torch.zeros(1024, len(phone))
|
100 |
elif language_str == "EN":
|
101 |
bert = torch.zeros(1024, len(phone))
|
102 |
ja_bert = torch.zeros(1024, len(phone))
|
103 |
-
en_bert =
|
104 |
else:
|
105 |
raise ValueError("language_str should be ZH, JP or EN")
|
106 |
|
@@ -125,6 +125,8 @@ def infer(
|
|
125 |
hps,
|
126 |
net_g,
|
127 |
device,
|
|
|
|
|
128 |
):
|
129 |
# 支持中日双语版本
|
130 |
inferMap_V2 = {
|
@@ -172,6 +174,20 @@ def infer(
|
|
172 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
173 |
text, language, hps, device
|
174 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
with torch.no_grad():
|
176 |
x_tst = phones.to(device).unsqueeze(0)
|
177 |
tones = tones.to(device).unsqueeze(0)
|
@@ -201,10 +217,11 @@ def infer(
|
|
201 |
.float()
|
202 |
.numpy()
|
203 |
)
|
204 |
-
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
|
205 |
-
torch.cuda.
|
|
|
206 |
return audio
|
207 |
-
|
208 |
|
209 |
def infer_multilang(
|
210 |
text,
|
|
|
85 |
for i in range(len(word2ph)):
|
86 |
word2ph[i] = word2ph[i] * 2
|
87 |
word2ph[0] += 1
|
88 |
+
bert_ori = get_bert(norm_text, word2ph, language_str, device)
|
89 |
del word2ph
|
90 |
+
assert bert_ori.shape[-1] == len(phone), phone
|
91 |
|
92 |
if language_str == "ZH":
|
93 |
+
bert = bert_ori
|
94 |
ja_bert = torch.zeros(1024, len(phone))
|
95 |
en_bert = torch.zeros(1024, len(phone))
|
96 |
elif language_str == "JP":
|
97 |
bert = torch.zeros(1024, len(phone))
|
98 |
+
ja_bert = bert_ori
|
99 |
en_bert = torch.zeros(1024, len(phone))
|
100 |
elif language_str == "EN":
|
101 |
bert = torch.zeros(1024, len(phone))
|
102 |
ja_bert = torch.zeros(1024, len(phone))
|
103 |
+
en_bert = bert_ori
|
104 |
else:
|
105 |
raise ValueError("language_str should be ZH, JP or EN")
|
106 |
|
|
|
125 |
hps,
|
126 |
net_g,
|
127 |
device,
|
128 |
+
skip_start=False,
|
129 |
+
skip_end=False,
|
130 |
):
|
131 |
# 支持中日双语版本
|
132 |
inferMap_V2 = {
|
|
|
174 |
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
175 |
text, language, hps, device
|
176 |
)
|
177 |
+
if skip_start:
|
178 |
+
phones = phones[1:]
|
179 |
+
tones = tones[1:]
|
180 |
+
lang_ids = lang_ids[1:]
|
181 |
+
bert = bert[:, 1:]
|
182 |
+
ja_bert = ja_bert[:, 1:]
|
183 |
+
en_bert = en_bert[:, 1:]
|
184 |
+
if skip_end:
|
185 |
+
phones = phones[:-1]
|
186 |
+
tones = tones[:-1]
|
187 |
+
lang_ids = lang_ids[:-1]
|
188 |
+
bert = bert[:, :-1]
|
189 |
+
ja_bert = ja_bert[:, :-1]
|
190 |
+
en_bert = en_bert[:, :-1]
|
191 |
with torch.no_grad():
|
192 |
x_tst = phones.to(device).unsqueeze(0)
|
193 |
tones = tones.to(device).unsqueeze(0)
|
|
|
217 |
.float()
|
218 |
.numpy()
|
219 |
)
|
220 |
+
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert
|
221 |
+
if torch.cuda.is_available():
|
222 |
+
torch.cuda.empty_cache()
|
223 |
return audio
|
224 |
+
|
225 |
|
226 |
def infer_multilang(
|
227 |
text,
|