lj1995 committed
Commit 21fe39c
Parent: 7474b27

Update inference_webui.py: remove debug print statements and psutil memory logging, and drop the broad try/except wrapper in get_tts_wav

Files changed (1)
  1. inference_webui.py +125 -160
inference_webui.py CHANGED
@@ -253,7 +253,6 @@ def get_first(text):
 from text import chinese
 def get_phones_and_bert(text,language,version):
     if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-        print(":1")
         language = language.replace("all_","")
         if language == "en":
             LangSegment.setfilters(["en"])
@@ -264,39 +263,27 @@ def get_phones_and_bert(text,language,version):
         while "  " in formattext:
             formattext = formattext.replace("  ", " ")
         if language == "zh":
-            print(":2")
             if re.search(r'[A-Za-z]', formattext):
-                print(":3")
                 formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
-                print(":4")
                 return get_phones_and_bert(formattext,"zh",version)
             else:
-                print(":5")
                 phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
-                print(":6")
                 bert = get_bert_feature(norm_text, word2ph).to(device)
-                print(":7")
         elif language == "yue" and re.search(r'[A-Za-z]', formattext):
             formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
             formattext = chinese.mix_text_normalize(formattext)
-            print(":8")
             return get_phones_and_bert(formattext,"yue",version)
         else:
-            print(":9")
             phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
-            print(":10")
             bert = torch.zeros(
                 (1024, len(phones)),
                 dtype=torch.float16 if is_half == True else torch.float32,
             ).to(device)
-            print(":11")
     elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
         textlist=[]
         langlist=[]
-        print(":12")
         LangSegment.setfilters(["zh","ja","en","ko"])
-        print(":13")
         if language == "auto":
             for tmp in LangSegment.getTexts(text):
                 langlist.append(tmp["lang"])
@@ -356,157 +343,135 @@ def merge_short_text_in_array(texts, threshold):
 cache= {}
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=123):
     global cache
-    import psutil
-
-    # get memory info
-    memory_info = psutil.virtual_memory()
-
-    # print total and available memory
-    total_memory = memory_info.total / (1024 ** 3)  # convert to GB
-    available_memory = memory_info.available / (1024 ** 3)  # convert to GB
-
-    print(f"总内存: {total_memory:.2f} GB")
-    print(f"剩余内存: {available_memory:.2f} GB")
-
-    try:
-        if ref_wav_path:pass
-        else:gr.Warning(i18n('请上传参考音频'))
-        if text:pass
-        else:gr.Warning(i18n('请填入推理文本'))
-        t = []
-        if prompt_text is None or len(prompt_text) == 0:
-            ref_free = True
-        t0 = ttime()
-        prompt_language = dict_language[prompt_language]
-        text_language = dict_language[text_language]
-
-        if not ref_free:
-            prompt_text = prompt_text.strip("\n")
-            if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
-            print(i18n("实际输入的参考文本:"), prompt_text)
-        text = text.strip("\n")
-        if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
-
-        print(i18n("实际输入的目标文本:"), text)
-        zero_wav = np.zeros(
-            int(hps.data.sampling_rate * 0.3),
-            dtype=np.float16 if is_half == True else np.float32,
-        )
-        if not ref_free:
-            with torch.no_grad():
-                wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-                if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
-                    gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
-                    raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
-                wav16k = torch.from_numpy(wav16k)
-                zero_wav_torch = torch.from_numpy(zero_wav)
-                if is_half == True:
-                    wav16k = wav16k.half().to(device)
-                    zero_wav_torch = zero_wav_torch.half().to(device)
-                else:
-                    wav16k = wav16k.to(device)
-                    zero_wav_torch = zero_wav_torch.to(device)
-                wav16k = torch.cat([wav16k, zero_wav_torch])
-                ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-                    "last_hidden_state"
-                ].transpose(
-                    1, 2
-                )  # .float()
-                codes = vq_model.extract_latent(ssl_content)
-                prompt_semantic = codes[0, 0]
-                prompt = prompt_semantic.unsqueeze(0).to(device)
-
-        t1 = ttime()
-        t.append(t1-t0)
-
-        if (how_to_cut == i18n("凑四句一切")):
-            text = cut1(text)
-        elif (how_to_cut == i18n("凑50字一切")):
-            text = cut2(text)
-        elif (how_to_cut == i18n("按中文句号。切")):
-            text = cut3(text)
-        elif (how_to_cut == i18n("按英文句号.切")):
-            text = cut4(text)
-        elif (how_to_cut == i18n("按标点符号切")):
-            text = cut5(text)
-        while "\n\n" in text:
-            text = text.replace("\n\n", "\n")
-        print(i18n("实际输入的目标文本(切句后):"), text)
-        texts = text.split("\n")
-        texts = process_text(texts)
-        print(2)
-        texts = merge_short_text_in_array(texts, 5)
-        print(3)
-        audio_opt = []
-        if not ref_free:
-            phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
-            print(4)
-        for i_text,text in enumerate(texts):
-            # skip blank lines in the target text to avoid errors
-            if (len(text.strip()) == 0):
-                continue
-            print(5)
-            if (text[-1] not in splits): text += "。" if text_language != "en" else "."
-            print(i18n("实际输入的目标文本(每句):"), text)
-            print(6)
-            phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
-            print(i18n("前端处理后的文本(每句):"), norm_text2)
-            print(7)
-            if not ref_free:
-                bert = torch.cat([bert1, bert2], 1)
-                all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
-            else:
-                bert = bert2
-                all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
-
-            bert = bert.to(device).unsqueeze(0)
-            all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-
-            t2 = ttime()
-            # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
-            # print(cache.keys(),if_freeze)
-            if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
-            else:
-                with torch.no_grad():
-                    pred_semantic, idx = t2s_model.model.infer_panel(
-                        all_phoneme_ids,
-                        all_phoneme_len,
-                        None if ref_free else prompt,
-                        bert,
-                        # prompt_phone_len=ph_offset,
-                        top_k=top_k,
-                        top_p=top_p,
-                        temperature=temperature,
-                        early_stop_num=hz * max_sec,
-                    )
-                pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-                cache[i_text]=pred_semantic
-            t3 = ttime()
-            refers=[]
-            if(inp_refs):
-                for path in inp_refs:
-                    try:
-                        refer = get_spepc(hps, path.name).to(dtype).to(device)
-                        refers.append(refer)
-                    except:
-                        traceback.print_exc()
-            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-            audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
-            max_audio=np.abs(audio).max()  # simple guard against 16-bit clipping
-            if max_audio>1:audio/=max_audio
-            audio_opt.append(audio)
-            audio_opt.append(zero_wav)
-            t4 = ttime()
-            t.extend([t2 - t1,t3 - t2, t4 - t3])
-            t1 = ttime()
-        print("%.3f\t%.3f\t%.3f\t%.3f" %
-              (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
-              )
-    except:
-        print(traceback.format_exc())
+    if ref_wav_path:pass
+    else:gr.Warning(i18n('请上传参考音频'))
+    if text:pass
+    else:gr.Warning(i18n('请填入推理文本'))
+    t = []
+    if prompt_text is None or len(prompt_text) == 0:
+        ref_free = True
+    t0 = ttime()
+    prompt_language = dict_language[prompt_language]
+    text_language = dict_language[text_language]
+
+    if not ref_free:
+        prompt_text = prompt_text.strip("\n")
+        if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+        print(i18n("实际输入的参考文本:"), prompt_text)
+    text = text.strip("\n")
+    if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
+
+    print(i18n("实际输入的目标文本:"), text)
+    zero_wav = np.zeros(
+        int(hps.data.sampling_rate * 0.3),
+        dtype=np.float16 if is_half == True else np.float32,
+    )
+    if not ref_free:
+        with torch.no_grad():
+            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+            if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+                gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
+                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
+            wav16k = torch.from_numpy(wav16k)
+            zero_wav_torch = torch.from_numpy(zero_wav)
+            if is_half == True:
+                wav16k = wav16k.half().to(device)
+                zero_wav_torch = zero_wav_torch.half().to(device)
+            else:
+                wav16k = wav16k.to(device)
+                zero_wav_torch = zero_wav_torch.to(device)
+            wav16k = torch.cat([wav16k, zero_wav_torch])
+            ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+                "last_hidden_state"
+            ].transpose(
+                1, 2
+            )  # .float()
+            codes = vq_model.extract_latent(ssl_content)
+            prompt_semantic = codes[0, 0]
+            prompt = prompt_semantic.unsqueeze(0).to(device)
+
+    t1 = ttime()
+    t.append(t1-t0)
+
+    if (how_to_cut == i18n("凑四句一切")):
+        text = cut1(text)
+    elif (how_to_cut == i18n("凑50字一切")):
+        text = cut2(text)
+    elif (how_to_cut == i18n("按中文句号。切")):
+        text = cut3(text)
+    elif (how_to_cut == i18n("按英文句号.切")):
+        text = cut4(text)
+    elif (how_to_cut == i18n("按标点符号切")):
+        text = cut5(text)
+    while "\n\n" in text:
+        text = text.replace("\n\n", "\n")
+    print(i18n("实际输入的目标文本(切句后):"), text)
+    texts = text.split("\n")
+    texts = process_text(texts)
+    texts = merge_short_text_in_array(texts, 5)
+    audio_opt = []
+    if not ref_free:
+        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
+
+    for i_text,text in enumerate(texts):
+        # skip blank lines in the target text to avoid errors
+        if (len(text.strip()) == 0):
+            continue
+        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+        print(i18n("实际输入的目标文本(每句):"), text)
+        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
+        print(i18n("前端处理后的文本(每句):"), norm_text2)
+        if not ref_free:
+            bert = torch.cat([bert1, bert2], 1)
+            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+        else:
+            bert = bert2
+            all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
+
+        bert = bert.to(device).unsqueeze(0)
+        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+
+        t2 = ttime()
+        # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
+        # print(cache.keys(),if_freeze)
+        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
+        else:
+            with torch.no_grad():
+                pred_semantic, idx = t2s_model.model.infer_panel(
+                    all_phoneme_ids,
+                    all_phoneme_len,
+                    None if ref_free else prompt,
+                    bert,
+                    # prompt_phone_len=ph_offset,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    early_stop_num=hz * max_sec,
+                )
+            pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
+            cache[i_text]=pred_semantic
+        t3 = ttime()
+        refers=[]
+        if(inp_refs):
+            for path in inp_refs:
+                try:
+                    refer = get_spepc(hps, path.name).to(dtype).to(device)
+                    refers.append(refer)
+                except:
+                    traceback.print_exc()
+        if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+        audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
+        max_audio=np.abs(audio).max()  # simple guard against 16-bit clipping
+        if max_audio>1:audio/=max_audio
+        audio_opt.append(audio)
+        audio_opt.append(zero_wav)
+        t4 = ttime()
+        t.extend([t2 - t1,t3 - t2, t4 - t3])
+        t1 = ttime()
+    print("%.3f\t%.3f\t%.3f\t%.3f" %
+          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
+          )
 yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
     np.int16
 )
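
For context, a minimal usage sketch of the updated get_tts_wav generator, assuming it runs with the globals defined in inference_webui.py in scope. The soundfile dependency, the example inputs, and the output paths are illustrative assumptions, not part of this commit:

import soundfile as sf  # assumed dependency, used only to save the yielded audio

# get_tts_wav is a generator that yields (sampling_rate, int16 waveform) tuples.
stream = get_tts_wav(
    ref_wav_path="ref.wav",          # 3-10 s reference clip; other lengths raise OSError above
    prompt_text="参考音频对应的文本",    # transcript of the reference clip (omit to run ref_free)
    prompt_language=i18n("中文"),      # must be a key of dict_language
    text="需要合成的目标文本",
    text_language=i18n("中文"),
    how_to_cut=i18n("凑四句一切"),      # selects cut1; the other options map to cut2..cut5
    inp_refs=None,                    # optional extra reference clips; falls back to ref_wav_path
)
for i, (sr, audio) in enumerate(stream):
    sf.write(f"out_{i}.wav", audio, sr)  # hypothetical output path

Note that with the try/except wrapper removed, invalid inputs now propagate as exceptions to the caller instead of being printed.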