File size: 3,704 Bytes
521e6a5
d930560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521e6a5
 
d930560
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120

import re
import gradio as gr
import torch
import unicodedata
import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence

# Path to the VITS hyper-parameter/config JSON consumed by utils.get_hparams_from_file.
config_json = "muse_tricolor_b.json"
# Generator checkpoint loaded into SynthesizerTrn (presumably step 496 — verify against training run).
pth_path = "G=496.pth"


def get_text(text, hps, cleaned=False):
    """Convert input text to a LongTensor of symbol ids for the synthesizer.

    When *cleaned* is True the text is assumed pre-cleaned and no cleaner
    pipeline is applied; otherwise the cleaners from the config are used.
    """
    cleaners = [] if cleaned else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        # Interleave a blank (id 0) between symbols, matching training setup.
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)


def get_label(text, label):
    """Report whether ``[label]`` occurs in *text*.

    Returns a ``(found, stripped_text)`` pair where *stripped_text* has every
    occurrence of the tag removed when found, and is the original text otherwise.
    """
    tag = f'[{label}]'
    if tag not in text:
        return False, text
    return True, text.replace(tag, '')


def clean_text(text):
    """NFKC-normalize *text* and wrap it in language tags.

    Text containing any hiragana/katakana is tagged ``[JA]…[JA]``,
    everything else is tagged ``[ZH]…[ZH]``.
    """
    print(text)  # keep the original debug echo of the raw input
    normalized = unicodedata.normalize('NFKC', text)
    # U+3040–U+309F hiragana, U+30A0–U+30FF katakana -> treat as Japanese.
    if re.search(r'[\u3040-\u309F\u30A0-\u30FF]', normalized):
        return f"[JA]{normalized}[JA]"
    return f"[ZH]{normalized}[ZH]"


def load_model(config_json, pth_path):
    """Construct a SynthesizerTrn from a config file and load its checkpoint.

    Places the model on CUDA when available (CPU otherwise), switches it to
    eval mode, and returns it ready for inference.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    hps_ms = utils.get_hparams_from_file(f"{config_json}")
    # Older configs may omit these fields; fall back to 0 as the original did.
    speaker_count = hps_ms.data.n_speakers if 'n_speakers' in hps_ms.data.keys() else 0
    symbol_count = len(hps_ms.symbols) if 'symbols' in hps_ms.keys() else 0
    model = SynthesizerTrn(
        symbol_count,
        hps_ms.data.filter_length // 2 + 1,
        hps_ms.train.segment_size // hps_ms.data.hop_length,
        n_speakers=speaker_count,
        **hps_ms.model).to(device)
    model.eval()
    utils.load_checkpoint(pth_path, model)
    return model

net_g_ms = load_model(config_json, pth_path)

def selection(speaker):
    """Map a speaker display name to its integer speaker id.

    Returns the id used by the multi-speaker checkpoint, or ``None`` for an
    unknown name — matching the original if/elif chain, which fell through
    with no return for unrecognized speakers.
    """
    # Dict lookup replaces the nine-branch if/elif chain; ids mirror the
    # order of the module-level `idols` list.
    speaker_ids = {
        "南小鸟": 0,
        "园田海未": 1,
        "小泉花阳": 2,
        "星空凛": 3,
        "东条希": 4,
        "矢泽妮可": 5,
        "绚濑绘里": 6,
        "西木野真姬": 7,
        "高坂穗乃果": 8,
    }
    return speaker_ids.get(speaker)

def infer(text, speaker_id):
    """Synthesize speech for *text* spoken by the selected speaker.

    Gradio click handler: *text* is the raw textarea string, *speaker_id* is
    the dropdown's display name (a string), which `selection` maps to an
    integer id. Returns ``(sampling_rate, audio)`` — the tuple form Gradio's
    Audio component accepts.

    NOTE(review): an unknown speaker name makes `selection` return None and
    `int(None)` raises TypeError here — safe only while the dropdown limits
    choices to known names.
    """
    text = clean_text(text)
    speaker_id = int(selection(speaker_id))
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the config file is re-parsed on every call; could be
    # hoisted to module level alongside net_g_ms.
    hps_ms = utils.get_hparams_from_file(f"{config_json}")
    with torch.no_grad():
        stn_tst = get_text(text, hps_ms, cleaned=False)
        # Model expects batched inputs: (1, T) ids plus a length tensor.
        x_tst = stn_tst.unsqueeze(0).to(dev)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([speaker_id]).to(dev)
        # noise_scale/noise_scale_w control sampling variance; length_scale=1
        # keeps the predicted duration unchanged.
        audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1)[0][
            0, 0].data.cpu().float().numpy()
    return (hps_ms.data.sampling_rate, audio)

# Dropdown choices; order matches the speaker-id mapping in `selection`.
idols = ["南小鸟","园田海未","小泉花阳","星空凛","东条希","矢泽妮可","绚濑绘里","西木野真姬","高坂穗乃果"]
# Build the Gradio UI: one tab with text input, speaker picker, and audio output.
app = gr.Blocks()
with app:
    with gr.Tabs():
        with gr.TabItem("面板"):
            tts_input1 = gr.TextArea(label="请输入纯中文或纯日文", value="大家好,今天给大家来点想看的东西啊")
            speaker1 = gr.Dropdown(label="选择说话人",choices=idols, value="高坂穗乃果", interactive=True)
            tts_submit = gr.Button("Generate", variant="primary")
            tts_output2 = gr.Audio(label="Output")
            # Wire the button to the inference callback.
            tts_submit.click(infer, [tts_input1,speaker1], [tts_output2])
    # NOTE(review): launch() is called inside the `with app:` block; it works,
    # but conventionally sits after the block — confirm before moving.
    app.launch()