ing0 commited on
Commit
b96e750
·
1 Parent(s): 4273121
diffrhythm/g2p/g2p/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from diffrhythm.g2p.g2p import cleaners
7
+ from tokenizers import Tokenizer
8
+ from diffrhythm.g2p.g2p.text_tokenizers import TextTokenizer
9
+ import LangSegment
10
+ import json
11
+ import re
12
+
13
+
14
+ class PhonemeBpeTokenizer:
15
+
16
+ def __init__(self, vacab_path="./diffrhythm/g2p/g2p/vocab.json"):
17
+ self.lang2backend = {
18
+ "zh": "cmn",
19
+ "ja": "ja",
20
+ "en": "en-us",
21
+ "fr": "fr-fr",
22
+ "ko": "ko",
23
+ "de": "de",
24
+ }
25
+ self.text_tokenizers = {}
26
+ self.int_text_tokenizers()
27
+
28
+ with open(vacab_path, "r") as f:
29
+ json_data = f.read()
30
+ data = json.loads(json_data)
31
+ self.vocab = data["vocab"]
32
+ LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])
33
+
34
+ def int_text_tokenizers(self):
35
+ for key, value in self.lang2backend.items():
36
+ self.text_tokenizers[key] = TextTokenizer(language=value)
37
+
38
+ def tokenize(self, text, sentence, language):
39
+
40
+ # 1. convert text to phoneme
41
+ phonemes = []
42
+ if language == "auto":
43
+ seglist = LangSegment.getTexts(text)
44
+ tmp_ph = []
45
+ for seg in seglist:
46
+ tmp_ph.append(
47
+ self._clean_text(
48
+ seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
49
+ )
50
+ )
51
+ phonemes = "|_|".join(tmp_ph)
52
+ else:
53
+ phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
54
+ # print('clean text: ', phonemes)
55
+
56
+ # 2. tokenize phonemes
57
+ phoneme_tokens = self.phoneme2token(phonemes)
58
+ # print('encode: ', phoneme_tokens)
59
+
60
+ # # 3. decode tokens [optional]
61
+ # decoded_text = self.tokenizer.decode(phoneme_tokens)
62
+ # print('decoded: ', decoded_text)
63
+
64
+ return phonemes, phoneme_tokens
65
+
66
+ def _clean_text(self, text, sentence, language, cleaner_names):
67
+ for name in cleaner_names:
68
+ cleaner = getattr(cleaners, name)
69
+ if not cleaner:
70
+ raise Exception("Unknown cleaner: %s" % name)
71
+ text = cleaner(text, sentence, language, self.text_tokenizers)
72
+ return text
73
+
74
+ def phoneme2token(self, phonemes):
75
+ tokens = []
76
+ if isinstance(phonemes, list):
77
+ for phone in phonemes:
78
+ phone = phone.split("\t")[0]
79
+ phonemes_split = phone.split("|")
80
+ tokens.append(
81
+ [self.vocab[p] for p in phonemes_split if p in self.vocab]
82
+ )
83
+ else:
84
+ phonemes = phonemes.split("\t")[0]
85
+ phonemes_split = phonemes.split("|")
86
+ tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
87
+ return tokens
diffrhythm/g2p/g2p/chinese_model_g2p.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ import json
11
+ from transformers import BertTokenizer
12
+ from torch.utils.data import Dataset
13
+ from transformers.models.bert.modeling_bert import *
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from onnxruntime import InferenceSession, GraphOptimizationLevel, SessionOptions
17
+
18
+
19
+ class PolyDataset(Dataset):
20
+ def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
21
+ self.dataset = self.preprocess(words, labels)
22
+ self.word_pad_idx = word_pad_idx
23
+ self.label_pad_idx = label_pad_idx
24
+
25
+ def preprocess(self, origin_sentences, origin_labels):
26
+ """
27
+ Maps tokens and tags to their indices and stores them in the dict data.
28
+ examples:
29
+ word:['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
30
+ sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
31
+ array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
32
+ label:[3, 13, 13, 13, 0, 0, 0, 0, 0]
33
+ """
34
+ data = []
35
+ labels = []
36
+ sentences = []
37
+ # tokenize
38
+ for line in origin_sentences:
39
+ # replace each token by its index
40
+ # we can not use encode_plus because our sentences are aligned to labels in list type
41
+ words = []
42
+ word_lens = []
43
+ for token in line:
44
+ words.append(token)
45
+ word_lens.append(1)
46
+ token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
47
+ sentences.append(((words, token_start_idxs), 0))
48
+ ###
49
+ for tag in origin_labels:
50
+ labels.append(tag)
51
+
52
+ for sentence, label in zip(sentences, labels):
53
+ data.append((sentence, label))
54
+ return data
55
+
56
+ def __getitem__(self, idx):
57
+ """sample data to get batch"""
58
+ word = self.dataset[idx][0]
59
+ label = self.dataset[idx][1]
60
+ return [word, label]
61
+
62
+ def __len__(self):
63
+ """get dataset size"""
64
+ return len(self.dataset)
65
+
66
+ def collate_fn(self, batch):
67
+
68
+ sentences = [x[0][0] for x in batch]
69
+ ori_sents = [x[0][1] for x in batch]
70
+ labels = [x[1] for x in batch]
71
+ batch_len = len(sentences)
72
+
73
+ # compute length of longest sentence in batch
74
+ max_len = max([len(s[0]) for s in sentences])
75
+ max_label_len = 0
76
+ batch_data = np.ones((batch_len, max_len))
77
+ batch_label_starts = []
78
+
79
+ # padding and aligning
80
+ for j in range(batch_len):
81
+ cur_len = len(sentences[j][0])
82
+ batch_data[j][:cur_len] = sentences[j][0]
83
+ label_start_idx = sentences[j][-1]
84
+ label_starts = np.zeros(max_len)
85
+ label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
86
+ batch_label_starts.append(label_starts)
87
+ max_label_len = max(int(sum(label_starts)), max_label_len)
88
+
89
+ # padding label
90
+ batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
91
+ batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
92
+ for j in range(batch_len):
93
+ cur_tags_len = len(labels[j])
94
+ batch_labels[j][:cur_tags_len] = labels[j]
95
+ batch_pmasks[j][:cur_tags_len] = [
96
+ 1 if item > 0 else 0 for item in labels[j]
97
+ ]
98
+
99
+ # convert data to torch LongTensors
100
+ batch_data = torch.tensor(batch_data, dtype=torch.long)
101
+ batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
102
+ batch_labels = torch.tensor(batch_labels, dtype=torch.long)
103
+ batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
104
+ return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
105
+
106
+
107
+ class BertPolyPredict:
108
+ def __init__(self, bert_model, jsonr_file, json_file):
109
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
110
+ with open(jsonr_file, "r", encoding="utf8") as fp:
111
+ self.pron_dict = json.load(fp)
112
+ with open(json_file, "r", encoding="utf8") as fp:
113
+ self.pron_dict_id_2_pinyin = json.load(fp)
114
+ self.num_polyphone = len(self.pron_dict)
115
+ self.device = "cpu"
116
+ self.polydataset = PolyDataset
117
+ options = SessionOptions() # initialize session options
118
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
119
+ print(os.path.join(bert_model, "poly_bert_model.onnx"))
120
+ self.session = InferenceSession(
121
+ os.path.join(bert_model, "poly_bert_model.onnx"),
122
+ sess_options=options,
123
+ providers=[
124
+ "CUDAExecutionProvider",
125
+ "CPUExecutionProvider",
126
+ ], # CPUExecutionProvider #CUDAExecutionProvider
127
+ )
128
+ # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])
129
+
130
+ # disable session.run() fallback mechanism, it prevents for a reset of the execution provider
131
+ self.session.disable_fallback()
132
+
133
+ def predict_process(self, txt_list):
134
+ word_test, label_test, texts_test = self.get_examples_po(txt_list)
135
+ data = self.polydataset(word_test, label_test)
136
+ predict_loader = DataLoader(
137
+ data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
138
+ )
139
+ pred_tags = self.predict_onnx(predict_loader)
140
+ return pred_tags
141
+
142
+ def predict_onnx(self, dev_loader):
143
+ pred_tags = []
144
+ with torch.no_grad():
145
+ for idx, batch_samples in enumerate(dev_loader):
146
+ # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
147
+ batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
148
+ batch_samples
149
+ )
150
+ # shift tensors to GPU if available
151
+ batch_data = batch_data.to(self.device)
152
+ batch_label_starts = batch_label_starts.to(self.device)
153
+ batch_labels = batch_labels.to(self.device)
154
+ batch_pmasks = batch_pmasks.to(self.device)
155
+ batch_data = np.asarray(batch_data, dtype=np.int32)
156
+ batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
157
+ # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids":batch_data, "input_pmasks": batch_pmasks})[0][0]
158
+ batch_output = self.session.run(
159
+ output_names=["outputs"], input_feed={"input_ids": batch_data}
160
+ )[0]
161
+ label_masks = batch_pmasks == 1
162
+ batch_labels = batch_labels.to("cpu").numpy()
163
+ for i, indices in enumerate(np.argmax(batch_output, axis=2)):
164
+ for j, idx in enumerate(indices):
165
+ if label_masks[i][j]:
166
+ # pred_tag.append(idx)
167
+ pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
168
+ return pred_tags
169
+
170
+ def get_examples_po(self, text_list):
171
+
172
+ word_list = []
173
+ label_list = []
174
+ sentence_list = []
175
+ id = 0
176
+ for line in [text_list]:
177
+ sentence = line[0]
178
+ words = []
179
+ tokens = line[0]
180
+ index = line[-1]
181
+ front = index
182
+ back = len(tokens) - index - 1
183
+ labels = [0] * front + [1] + [0] * back
184
+ words = ["[CLS]"] + [item for item in sentence]
185
+ words = self.tokenizer.convert_tokens_to_ids(words)
186
+ word_list.append(words)
187
+ label_list.append(labels)
188
+ sentence_list.append(sentence)
189
+
190
+ id += 1
191
+ # mask_list.append(masks)
192
+ assert len(labels) + 1 == len(words), print(
193
+ (
194
+ poly,
195
+ sentence,
196
+ words,
197
+ labels,
198
+ sentence,
199
+ len(sentence),
200
+ len(words),
201
+ len(labels),
202
+ )
203
+ )
204
+ assert len(labels) + 1 == len(
205
+ words
206
+ ), "Number of labels does not match number of words"
207
+ assert len(labels) == len(
208
+ sentence
209
+ ), "Number of labels does not match number of sentences"
210
+ assert len(word_list) == len(
211
+ label_list
212
+ ), "Number of label sentences does not match number of word sentences"
213
+ return word_list, label_list, text_list
diffrhythm/g2p/g2p/cleaners.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from diffrhythm.g2p.g2p.japanese import japanese_to_ipa
8
+ from diffrhythm.g2p.g2p.mandarin import chinese_to_ipa
9
+ from diffrhythm.g2p.g2p.english import english_to_ipa
10
+ from diffrhythm.g2p.g2p.french import french_to_ipa
11
+ from diffrhythm.g2p.g2p.korean import korean_to_ipa
12
+ from diffrhythm.g2p.g2p.german import german_to_ipa
13
+
14
+
15
+ def cjekfd_cleaners(text, sentence, language, text_tokenizers):
16
+
17
+ if language == "zh":
18
+ return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
19
+ elif language == "ja":
20
+ return japanese_to_ipa(text, text_tokenizers["ja"])
21
+ elif language == "en":
22
+ return english_to_ipa(text, text_tokenizers["en"])
23
+ elif language == "fr":
24
+ return french_to_ipa(text, text_tokenizers["fr"])
25
+ elif language == "ko":
26
+ return korean_to_ipa(text, text_tokenizers["ko"])
27
+ elif language == "de":
28
+ return german_to_ipa(text, text_tokenizers["de"])
29
+ else:
30
+ raise Exception("Unknown language: %s" % language)
31
+ return None
diffrhythm/g2p/g2p/english.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from unidecode import unidecode
8
+ import inflect
9
+
10
+ """
11
+ Text clean time
12
+ """
13
+ _inflect = inflect.engine()
14
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
15
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
16
+ _percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
17
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
18
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
19
+ _fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
20
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
21
+ _number_re = re.compile(r"[0-9]+")
22
+
23
+ # List of (regular expression, replacement) pairs for abbreviations:
24
+ _abbreviations = [
25
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
26
+ for x in [
27
+ ("mrs", "misess"),
28
+ ("mr", "mister"),
29
+ ("dr", "doctor"),
30
+ ("st", "saint"),
31
+ ("co", "company"),
32
+ ("jr", "junior"),
33
+ ("maj", "major"),
34
+ ("gen", "general"),
35
+ ("drs", "doctors"),
36
+ ("rev", "reverend"),
37
+ ("lt", "lieutenant"),
38
+ ("hon", "honorable"),
39
+ ("sgt", "sergeant"),
40
+ ("capt", "captain"),
41
+ ("esq", "esquire"),
42
+ ("ltd", "limited"),
43
+ ("col", "colonel"),
44
+ ("ft", "fort"),
45
+ ("etc", "et cetera"),
46
+ ("btw", "by the way"),
47
+ ]
48
+ ]
49
+
50
+ _special_map = [
51
+ ("t|ɹ", "tɹ"),
52
+ ("d|ɹ", "dɹ"),
53
+ ("t|s", "ts"),
54
+ ("d|z", "dz"),
55
+ ("ɪ|ɹ", "ɪɹ"),
56
+ ("ɐ", "ɚ"),
57
+ ("ᵻ", "ɪ"),
58
+ ("əl", "l"),
59
+ ("x", "k"),
60
+ ("ɬ", "l"),
61
+ ("ʔ", "t"),
62
+ ("n̩", "n"),
63
+ ("oː|ɹ", "oːɹ"),
64
+ ]
65
+
66
+
67
+ def expand_abbreviations(text):
68
+ for regex, replacement in _abbreviations:
69
+ text = re.sub(regex, replacement, text)
70
+ return text
71
+
72
+
73
+ def _remove_commas(m):
74
+ return m.group(1).replace(",", "")
75
+
76
+
77
+ def _expand_decimal_point(m):
78
+ return m.group(1).replace(".", " point ")
79
+
80
+
81
+ def _expand_percent(m):
82
+ return m.group(1).replace("%", " percent ")
83
+
84
+
85
+ def _expand_dollars(m):
86
+ match = m.group(1)
87
+ parts = match.split(".")
88
+ if len(parts) > 2:
89
+ return " " + match + " dollars " # Unexpected format
90
+ dollars = int(parts[0]) if parts[0] else 0
91
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
92
+ if dollars and cents:
93
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
94
+ cent_unit = "cent" if cents == 1 else "cents"
95
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
96
+ elif dollars:
97
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
98
+ return " %s %s " % (dollars, dollar_unit)
99
+ elif cents:
100
+ cent_unit = "cent" if cents == 1 else "cents"
101
+ return " %s %s " % (cents, cent_unit)
102
+ else:
103
+ return " zero dollars "
104
+
105
+
106
+ def fraction_to_words(numerator, denominator):
107
+ if numerator == 1 and denominator == 2:
108
+ return " one half "
109
+ if numerator == 1 and denominator == 4:
110
+ return " one quarter "
111
+ if denominator == 2:
112
+ return " " + _inflect.number_to_words(numerator) + " halves "
113
+ if denominator == 4:
114
+ return " " + _inflect.number_to_words(numerator) + " quarters "
115
+ return (
116
+ " "
117
+ + _inflect.number_to_words(numerator)
118
+ + " "
119
+ + _inflect.ordinal(_inflect.number_to_words(denominator))
120
+ + " "
121
+ )
122
+
123
+
124
+ def _expand_fraction(m):
125
+ numerator = int(m.group(1))
126
+ denominator = int(m.group(2))
127
+ return fraction_to_words(numerator, denominator)
128
+
129
+
130
+ def _expand_ordinal(m):
131
+ return " " + _inflect.number_to_words(m.group(0)) + " "
132
+
133
+
134
+ def _expand_number(m):
135
+ num = int(m.group(0))
136
+ if num > 1000 and num < 3000:
137
+ if num == 2000:
138
+ return " two thousand "
139
+ elif num > 2000 and num < 2010:
140
+ return " two thousand " + _inflect.number_to_words(num % 100) + " "
141
+ elif num % 100 == 0:
142
+ return " " + _inflect.number_to_words(num // 100) + " hundred "
143
+ else:
144
+ return (
145
+ " "
146
+ + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
147
+ ", ", " "
148
+ )
149
+ + " "
150
+ )
151
+ else:
152
+ return " " + _inflect.number_to_words(num, andword="") + " "
153
+
154
+
155
+ # Normalize numbers pronunciation
156
+ def normalize_numbers(text):
157
+ text = re.sub(_comma_number_re, _remove_commas, text)
158
+ text = re.sub(_pounds_re, r"\1 pounds", text)
159
+ text = re.sub(_dollars_re, _expand_dollars, text)
160
+ text = re.sub(_fraction_re, _expand_fraction, text)
161
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
162
+ text = re.sub(_percent_number_re, _expand_percent, text)
163
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
164
+ text = re.sub(_number_re, _expand_number, text)
165
+ return text
166
+
167
+
168
+ def _english_to_ipa(text):
169
+ # text = unidecode(text).lower()
170
+ text = expand_abbreviations(text)
171
+ text = normalize_numbers(text)
172
+ return text
173
+
174
+
175
+ # special map
176
+ def special_map(text):
177
+ for regex, replacement in _special_map:
178
+ regex = regex.replace("|", "\|")
179
+ while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
180
+ text = re.sub(
181
+ r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
182
+ )
183
+ # text = re.sub(r'([,.!?])', r'|\1', text)
184
+ return text
185
+
186
+
187
+ # Add some special operation
188
+ def english_to_ipa(text, text_tokenizer):
189
+ if type(text) == str:
190
+ text = _english_to_ipa(text)
191
+ else:
192
+ text = [_english_to_ipa(t) for t in text]
193
+ phonemes = text_tokenizer(text)
194
+ if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
195
+ phonemes += "|_"
196
+ if type(text) == str:
197
+ return special_map(phonemes)
198
+ else:
199
+ result_ph = []
200
+ for phone in phonemes:
201
+ result_ph.append(special_map(phone))
202
+ return result_ph
diffrhythm/g2p/g2p/french.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ # List of (regular expression, replacement) pairs for abbreviations in french:
12
+ _abbreviations = [
13
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
14
+ for x in [
15
+ ("M", "monsieur"),
16
+ ("Mlle", "mademoiselle"),
17
+ ("Mlles", "mesdemoiselles"),
18
+ ("Mme", "Madame"),
19
+ ("Mmes", "Mesdames"),
20
+ ("N.B", "nota bene"),
21
+ ("M", "monsieur"),
22
+ ("p.c.q", "parce que"),
23
+ ("Pr", "professeur"),
24
+ ("qqch", "quelque chose"),
25
+ ("rdv", "rendez-vous"),
26
+ ("max", "maximum"),
27
+ ("min", "minimum"),
28
+ ("no", "numéro"),
29
+ ("adr", "adresse"),
30
+ ("dr", "docteur"),
31
+ ("st", "saint"),
32
+ ("co", "companie"),
33
+ ("jr", "junior"),
34
+ ("sgt", "sergent"),
35
+ ("capt", "capitain"),
36
+ ("col", "colonel"),
37
+ ("av", "avenue"),
38
+ ("av. J.-C", "avant Jésus-Christ"),
39
+ ("apr. J.-C", "après Jésus-Christ"),
40
+ ("art", "article"),
41
+ ("boul", "boulevard"),
42
+ ("c.-à-d", "c’est-à-dire"),
43
+ ("etc", "et cetera"),
44
+ ("ex", "exemple"),
45
+ ("excl", "exclusivement"),
46
+ ("boul", "boulevard"),
47
+ ]
48
+ ] + [
49
+ (re.compile("\\b%s" % x[0]), x[1])
50
+ for x in [
51
+ ("Mlle", "mademoiselle"),
52
+ ("Mlles", "mesdemoiselles"),
53
+ ("Mme", "Madame"),
54
+ ("Mmes", "Mesdames"),
55
+ ]
56
+ ]
57
+
58
+ rep_map = {
59
+ ":": ",",
60
+ ";": ",",
61
+ ",": ",",
62
+ "。": ".",
63
+ "!": "!",
64
+ "?": "?",
65
+ "\n": ".",
66
+ "·": ",",
67
+ "、": ",",
68
+ "...": ".",
69
+ "…": ".",
70
+ "$": ".",
71
+ "“": "",
72
+ "”": "",
73
+ "‘": "",
74
+ "’": "",
75
+ "(": "",
76
+ ")": "",
77
+ "(": "",
78
+ ")": "",
79
+ "《": "",
80
+ "》": "",
81
+ "【": "",
82
+ "】": "",
83
+ "[": "",
84
+ "]": "",
85
+ "—": "",
86
+ "~": "-",
87
+ "~": "-",
88
+ "「": "",
89
+ "」": "",
90
+ "¿": "",
91
+ "¡": "",
92
+ }
93
+
94
+
95
+ def collapse_whitespace(text):
96
+ # Regular expression matching whitespace:
97
+ _whitespace_re = re.compile(r"\s+")
98
+ return re.sub(_whitespace_re, " ", text).strip()
99
+
100
+
101
+ def remove_punctuation_at_begin(text):
102
+ return re.sub(r"^[,.!?]+", "", text)
103
+
104
+
105
+ def remove_aux_symbols(text):
106
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
107
+ return text
108
+
109
+
110
+ def replace_symbols(text):
111
+ text = text.replace(";", ",")
112
+ text = text.replace("-", " ")
113
+ text = text.replace(":", ",")
114
+ text = text.replace("&", " et ")
115
+ return text
116
+
117
+
118
+ def expand_abbreviations(text):
119
+ for regex, replacement in _abbreviations:
120
+ text = re.sub(regex, replacement, text)
121
+ return text
122
+
123
+
124
+ def replace_punctuation(text):
125
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
126
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
127
+ return replaced_text
128
+
129
+
130
+ def text_normalize(text):
131
+ text = expand_abbreviations(text)
132
+ text = replace_punctuation(text)
133
+ text = replace_symbols(text)
134
+ text = remove_aux_symbols(text)
135
+ text = remove_punctuation_at_begin(text)
136
+ text = collapse_whitespace(text)
137
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
138
+ return text
139
+
140
+
141
+ def french_to_ipa(text, text_tokenizer):
142
+ if type(text) == str:
143
+ text = text_normalize(text)
144
+ phonemes = text_tokenizer(text)
145
+ return phonemes
146
+ else:
147
+ for i, t in enumerate(text):
148
+ text[i] = text_normalize(t)
149
+ return text_tokenizer(text)
diffrhythm/g2p/g2p/german.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ rep_map = {
12
+ ":": ",",
13
+ ";": ",",
14
+ ",": ",",
15
+ "。": ".",
16
+ "!": "!",
17
+ "?": "?",
18
+ "\n": ".",
19
+ "·": ",",
20
+ "、": ",",
21
+ "...": ".",
22
+ "…": ".",
23
+ "$": ".",
24
+ "“": "",
25
+ "”": "",
26
+ "‘": "",
27
+ "’": "",
28
+ "(": "",
29
+ ")": "",
30
+ "(": "",
31
+ ")": "",
32
+ "《": "",
33
+ "》": "",
34
+ "【": "",
35
+ "】": "",
36
+ "[": "",
37
+ "]": "",
38
+ "—": "",
39
+ "~": "-",
40
+ "~": "-",
41
+ "「": "",
42
+ "」": "",
43
+ "¿": "",
44
+ "¡": "",
45
+ }
46
+
47
+
48
+ def collapse_whitespace(text):
49
+ # Regular expression matching whitespace:
50
+ _whitespace_re = re.compile(r"\s+")
51
+ return re.sub(_whitespace_re, " ", text).strip()
52
+
53
+
54
+ def remove_punctuation_at_begin(text):
55
+ return re.sub(r"^[,.!?]+", "", text)
56
+
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text):
64
+ text = text.replace(";", ",")
65
+ text = text.replace("-", " ")
66
+ text = text.replace(":", ",")
67
+ return text
68
+
69
+
70
+ def replace_punctuation(text):
71
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
72
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
73
+ return replaced_text
74
+
75
+
76
+ def text_normalize(text):
77
+ text = replace_punctuation(text)
78
+ text = replace_symbols(text)
79
+ text = remove_aux_symbols(text)
80
+ text = remove_punctuation_at_begin(text)
81
+ text = collapse_whitespace(text)
82
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
83
+ return text
84
+
85
+
86
+ def german_to_ipa(text, text_tokenizer):
87
+ if type(text) == str:
88
+ text = text_normalize(text)
89
+ phonemes = text_tokenizer(text)
90
+ return phonemes
91
+ else:
92
+ for i, t in enumerate(text):
93
+ text[i] = text_normalize(t)
94
+ return text_tokenizer(text)
diffrhythm/g2p/g2p/japanese.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import io, re, os, sys, time, argparse, pdb, json
7
+ from io import StringIO
8
+ from typing import Optional
9
+ import numpy as np
10
+ import traceback
11
+ import pyopenjtalk
12
+ from pykakasi import kakasi
13
+
14
+ punctuation = [",", ".", "!", "?", ":", ";", "'", "…"]
15
+
16
+ jp_xphone2ipa = [
17
+ " a a",
18
+ " i i",
19
+ " u ɯ",
20
+ " e e",
21
+ " o o",
22
+ " a: aː",
23
+ " i: iː",
24
+ " u: ɯː",
25
+ " e: eː",
26
+ " o: oː",
27
+ " k k",
28
+ " s s",
29
+ " t t",
30
+ " n n",
31
+ " h ç",
32
+ " f ɸ",
33
+ " m m",
34
+ " y j",
35
+ " r ɾ",
36
+ " w ɰᵝ",
37
+ " N ɴ",
38
+ " g g",
39
+ " j d ʑ",
40
+ " z z",
41
+ " d d",
42
+ " b b",
43
+ " p p",
44
+ " q q",
45
+ " v v",
46
+ " : :",
47
+ " by b j",
48
+ " ch t ɕ",
49
+ " dy d e j",
50
+ " ty t e j",
51
+ " gy g j",
52
+ " gw g ɯ",
53
+ " hy ç j",
54
+ " ky k j",
55
+ " kw k ɯ",
56
+ " my m j",
57
+ " ny n j",
58
+ " py p j",
59
+ " ry ɾ j",
60
+ " sh ɕ",
61
+ " ts t s ɯ",
62
+ ]
63
+
64
+ _mora_list_minimum: list[tuple[str, Optional[str], str]] = [
65
+ ("ヴォ", "v", "o"),
66
+ ("ヴェ", "v", "e"),
67
+ ("ヴィ", "v", "i"),
68
+ ("ヴァ", "v", "a"),
69
+ ("ヴ", "v", "u"),
70
+ ("ン", None, "N"),
71
+ ("ワ", "w", "a"),
72
+ ("ロ", "r", "o"),
73
+ ("レ", "r", "e"),
74
+ ("ル", "r", "u"),
75
+ ("リョ", "ry", "o"),
76
+ ("リュ", "ry", "u"),
77
+ ("リャ", "ry", "a"),
78
+ ("リェ", "ry", "e"),
79
+ ("リ", "r", "i"),
80
+ ("ラ", "r", "a"),
81
+ ("ヨ", "y", "o"),
82
+ ("ユ", "y", "u"),
83
+ ("ヤ", "y", "a"),
84
+ ("モ", "m", "o"),
85
+ ("メ", "m", "e"),
86
+ ("ム", "m", "u"),
87
+ ("ミョ", "my", "o"),
88
+ ("ミュ", "my", "u"),
89
+ ("ミャ", "my", "a"),
90
+ ("ミェ", "my", "e"),
91
+ ("ミ", "m", "i"),
92
+ ("マ", "m", "a"),
93
+ ("ポ", "p", "o"),
94
+ ("ボ", "b", "o"),
95
+ ("ホ", "h", "o"),
96
+ ("ペ", "p", "e"),
97
+ ("ベ", "b", "e"),
98
+ ("ヘ", "h", "e"),
99
+ ("プ", "p", "u"),
100
+ ("ブ", "b", "u"),
101
+ ("フォ", "f", "o"),
102
+ ("フェ", "f", "e"),
103
+ ("フィ", "f", "i"),
104
+ ("ファ", "f", "a"),
105
+ ("フ", "f", "u"),
106
+ ("ピョ", "py", "o"),
107
+ ("ピュ", "py", "u"),
108
+ ("ピャ", "py", "a"),
109
+ ("ピェ", "py", "e"),
110
+ ("ピ", "p", "i"),
111
+ ("ビョ", "by", "o"),
112
+ ("ビュ", "by", "u"),
113
+ ("ビャ", "by", "a"),
114
+ ("ビェ", "by", "e"),
115
+ ("ビ", "b", "i"),
116
+ ("ヒョ", "hy", "o"),
117
+ ("ヒュ", "hy", "u"),
118
+ ("ヒャ", "hy", "a"),
119
+ ("ヒェ", "hy", "e"),
120
+ ("ヒ", "h", "i"),
121
+ ("パ", "p", "a"),
122
+ ("バ", "b", "a"),
123
+ ("ハ", "h", "a"),
124
+ ("ノ", "n", "o"),
125
+ ("ネ", "n", "e"),
126
+ ("ヌ", "n", "u"),
127
+ ("ニョ", "ny", "o"),
128
+ ("ニュ", "ny", "u"),
129
+ ("ニャ", "ny", "a"),
130
+ ("ニェ", "ny", "e"),
131
+ ("ニ", "n", "i"),
132
+ ("ナ", "n", "a"),
133
+ ("ドゥ", "d", "u"),
134
+ ("ド", "d", "o"),
135
+ ("トゥ", "t", "u"),
136
+ ("ト", "t", "o"),
137
+ ("デョ", "dy", "o"),
138
+ ("デュ", "dy", "u"),
139
+ ("デャ", "dy", "a"),
140
+ # ("デェ", "dy", "e"),
141
+ ("ディ", "d", "i"),
142
+ ("デ", "d", "e"),
143
+ ("テョ", "ty", "o"),
144
+ ("テュ", "ty", "u"),
145
+ ("テャ", "ty", "a"),
146
+ ("ティ", "t", "i"),
147
+ ("テ", "t", "e"),
148
+ ("ツォ", "ts", "o"),
149
+ ("ツェ", "ts", "e"),
150
+ ("ツィ", "ts", "i"),
151
+ ("ツァ", "ts", "a"),
152
+ ("ツ", "ts", "u"),
153
+ ("ッ", None, "q"), # 「cl」から「q」に変更
154
+ ("チョ", "ch", "o"),
155
+ ("チュ", "ch", "u"),
156
+ ("チャ", "ch", "a"),
157
+ ("チェ", "ch", "e"),
158
+ ("チ", "ch", "i"),
159
+ ("ダ", "d", "a"),
160
+ ("タ", "t", "a"),
161
+ ("ゾ", "z", "o"),
162
+ ("ソ", "s", "o"),
163
+ ("ゼ", "z", "e"),
164
+ ("セ", "s", "e"),
165
+ ("ズィ", "z", "i"),
166
+ ("ズ", "z", "u"),
167
+ ("スィ", "s", "i"),
168
+ ("ス", "s", "u"),
169
+ ("ジョ", "j", "o"),
170
+ ("ジュ", "j", "u"),
171
+ ("ジャ", "j", "a"),
172
+ ("ジェ", "j", "e"),
173
+ ("ジ", "j", "i"),
174
+ ("ショ", "sh", "o"),
175
+ ("シュ", "sh", "u"),
176
+ ("シャ", "sh", "a"),
177
+ ("シェ", "sh", "e"),
178
+ ("シ", "sh", "i"),
179
+ ("ザ", "z", "a"),
180
+ ("サ", "s", "a"),
181
+ ("ゴ", "g", "o"),
182
+ ("コ", "k", "o"),
183
+ ("ゲ", "g", "e"),
184
+ ("ケ", "k", "e"),
185
+ ("グヮ", "gw", "a"),
186
+ ("グ", "g", "u"),
187
+ ("クヮ", "kw", "a"),
188
+ ("ク", "k", "u"),
189
+ ("ギョ", "gy", "o"),
190
+ ("ギュ", "gy", "u"),
191
+ ("ギャ", "gy", "a"),
192
+ ("ギェ", "gy", "e"),
193
+ ("ギ", "g", "i"),
194
+ ("キョ", "ky", "o"),
195
+ ("キュ", "ky", "u"),
196
+ ("キャ", "ky", "a"),
197
+ ("キェ", "ky", "e"),
198
+ ("キ", "k", "i"),
199
+ ("ガ", "g", "a"),
200
+ ("カ", "k", "a"),
201
+ ("オ", None, "o"),
202
+ ("エ", None, "e"),
203
+ ("ウォ", "w", "o"),
204
+ ("ウェ", "w", "e"),
205
+ ("ウィ", "w", "i"),
206
+ ("ウ", None, "u"),
207
+ ("イェ", "y", "e"),
208
+ ("イ", None, "i"),
209
+ ("ア", None, "a"),
210
+ ]
211
+
212
+ _mora_list_additional: list[tuple[str, Optional[str], str]] = [
213
+ ("ヴョ", "by", "o"),
214
+ ("ヴュ", "by", "u"),
215
+ ("ヴャ", "by", "a"),
216
+ ("ヲ", None, "o"),
217
+ ("ヱ", None, "e"),
218
+ ("ヰ", None, "i"),
219
+ ("ヮ", "w", "a"),
220
+ ("ョ", "y", "o"),
221
+ ("ュ", "y", "u"),
222
+ ("ヅ", "z", "u"),
223
+ ("ヂ", "j", "i"),
224
+ ("ヶ", "k", "e"),
225
+ ("ャ", "y", "a"),
226
+ ("ォ", None, "o"),
227
+ ("ェ", None, "e"),
228
+ ("ゥ", None, "u"),
229
+ ("ィ", None, "i"),
230
+ ("ァ", None, "a"),
231
+ ]
232
+
233
+ # 例: "vo" -> "ヴォ", "a" -> "ア"
234
+ mora_phonemes_to_mora_kata: dict[str, str] = {
235
+ (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
236
+ }
237
+
238
+ # 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a")
239
+ mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
240
+ kana: (consonant, vowel)
241
+ for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
242
+ }
243
+
244
+
245
+ # 正規化で記号を変換するための辞書
246
+ rep_map = {
247
+ ":": ":",
248
+ ";": ";",
249
+ ",": ",",
250
+ "。": ".",
251
+ "!": "!",
252
+ "?": "?",
253
+ "\n": ".",
254
+ ".": ".",
255
+ "⋯": "…",
256
+ "···": "…",
257
+ "・・・": "…",
258
+ "·": ",",
259
+ "・": ",",
260
+ "•": ",",
261
+ "、": ",",
262
+ "$": ".",
263
+ # "“": "'",
264
+ # "”": "'",
265
+ # '"': "'",
266
+ "‘": "'",
267
+ "’": "'",
268
+ # "(": "'",
269
+ # ")": "'",
270
+ # "(": "'",
271
+ # ")": "'",
272
+ # "《": "'",
273
+ # "》": "'",
274
+ # "【": "'",
275
+ # "】": "'",
276
+ # "[": "'",
277
+ # "]": "'",
278
+ # "——": "-",
279
+ # "−": "-",
280
+ # "-": "-",
281
+ # "『": "'",
282
+ # "』": "'",
283
+ # "〈": "'",
284
+ # "〉": "'",
285
+ # "«": "'",
286
+ # "»": "'",
287
+ # # "~": "-", # これは長音記号「ー」として扱うよう変更
288
+ # # "~": "-", # これは長音記号「ー」として扱うよう変更
289
+ # "「": "'",
290
+ # "」": "'",
291
+ }
292
+
293
+
294
+ def _numeric_feature_by_regex(regex, s):
295
+ match = re.search(regex, s)
296
+ if match is None:
297
+ return -50
298
+ return int(match.group(1))
299
+
300
+
301
+ def replace_punctuation(text: str) -> str:
302
+ """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す:
303
+ 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字
304
+ """
305
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
306
+ # print("before: ", text)
307
+ # 句読点を辞書で置換
308
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
309
+
310
+ replaced_text = re.sub(
311
+ # ↓ ひらがな、カタカナ、漢字
312
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
313
+ # ↓ 半角アルファベット(大文字と小文字)
314
+ + r"\u0041-\u005A\u0061-\u007A"
315
+ # ↓ 全角アルファベット(大文字と小文字)
316
+ + r"\uFF21-\uFF3A\uFF41-\uFF5A"
317
+ # ↓ ギリシャ文字
318
+ + r"\u0370-\u03FF\u1F00-\u1FFF"
319
+ # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている
320
+ + "".join(punctuation) + r"]+",
321
+ # 上述以外の文字を削除
322
+ "",
323
+ replaced_text,
324
+ )
325
+ # print("after: ", replaced_text)
326
+ return replaced_text
327
+
328
+
329
+ def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
330
+ """
331
+ `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。
332
+ 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)]
333
+ """
334
+ tone_values = set(tone for _, tone in phone_tone_list)
335
+ if len(tone_values) == 1:
336
+ assert tone_values == {0}, tone_values
337
+ return phone_tone_list
338
+ elif len(tone_values) == 2:
339
+ if tone_values == {0, 1}:
340
+ return phone_tone_list
341
+ elif tone_values == {-1, 0}:
342
+ return [
343
+ (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
344
+ ]
345
+ else:
346
+ raise ValueError(f"Unexpected tone values: {tone_values}")
347
+ else:
348
+ raise ValueError(f"Unexpected tone values: {tone_values}")
349
+
350
+
351
+ def fix_phone_tone_wplen(phone_tone_list, word_phone_length_list):
352
+ phones = []
353
+ tones = []
354
+ w_p_len = []
355
+ p_len = len(phone_tone_list)
356
+ idx = 0
357
+ w_idx = 0
358
+ while idx < p_len:
359
+ offset = 0
360
+ if phone_tone_list[idx] == "▁":
361
+ w_p_len.append(w_idx + 1)
362
+
363
+ curr_w_p_len = word_phone_length_list[w_idx]
364
+ for i in range(curr_w_p_len):
365
+ p, t = phone_tone_list[idx]
366
+ if p == ":" and len(phones) > 0:
367
+ if phones[-1][-1] != ":":
368
+ phones[-1] += ":"
369
+ offset -= 1
370
+ else:
371
+ phones.append(p)
372
+ tones.append(str(t))
373
+ idx += 1
374
+ if idx >= p_len:
375
+ break
376
+ w_p_len.append(curr_w_p_len + offset)
377
+ w_idx += 1
378
+ # print(w_p_len)
379
+ return phones, tones, w_p_len
380
+
381
+
382
+ def g2phone_tone_wo_punct(prosodies) -> list[tuple[str, int]]:
383
+ """
384
+ テキストに対して、音素とアクセント(0か1)のペアのリストを返す。
385
+ ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。
386
+ 非音素記号を含める処理は`align_tones()`で行われる。
387
+ また「っ」は「cl」でなく「q」に変換される(「ん」は「N」のまま)。
388
+ 例: "こんにちは、世界ー。。元気?!" →
389
+ [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
390
+ """
391
+ result: list[tuple[str, int]] = []
392
+ current_phrase: list[tuple[str, int]] = []
393
+ current_tone = 0
394
+ last_accent = ""
395
+ for i, letter in enumerate(prosodies):
396
+ # 特殊記号の処理
397
+
398
+ # 文頭記号、無視する
399
+ if letter == "^":
400
+ assert i == 0, "Unexpected ^"
401
+ # アクセント句の終わりに来る記号
402
+ elif letter in ("$", "?", "_", "#"):
403
+ # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加
404
+ result.extend(fix_phone_tone(current_phrase))
405
+ # 末尾に来る終了記号、無視(文中の疑問文は`_`になる)
406
+ if letter in ("$", "?"):
407
+ assert i == len(prosodies) - 1, f"Unexpected {letter}"
408
+ # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ
409
+ # これらは残さず、次のアクセント句に備える。
410
+
411
+ current_phrase = []
412
+ # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る)
413
+ current_tone = 0
414
+ last_accent = ""
415
+ # アクセント上昇記号
416
+ elif letter == "[":
417
+ if last_accent != letter:
418
+ current_tone = current_tone + 1
419
+ last_accent = letter
420
+ # アクセント下降記号
421
+ elif letter == "]":
422
+ if last_accent != letter:
423
+ current_tone = current_tone - 1
424
+ last_accent = letter
425
+ # それ以外は通常の音素
426
+ else:
427
+ if letter == "cl": # 「っ」の処理
428
+ letter = "q"
429
+ current_phrase.append((letter, current_tone))
430
+ return result
431
+
432
+
433
+ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
434
+ for i in range(len(sep_phonemes)):
435
+ if sep_phonemes[i][0] == "ー":
436
+ # sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
437
+ sep_phonemes[i][0] = ":"
438
+ if "ー" in sep_phonemes[i]:
439
+ for j in range(len(sep_phonemes[i])):
440
+ if sep_phonemes[i][j] == "ー":
441
+ # sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
442
+ sep_phonemes[i][j] = ":"
443
+ return sep_phonemes
444
+
445
+
446
+ def handle_long_word(sep_phonemes: list[list[str]]) -> list[list[str]]:
447
+ res = []
448
+ for i in range(len(sep_phonemes)):
449
+ if sep_phonemes[i][0] == "ー":
450
+ sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
451
+ # sep_phonemes[i][0] = ':'
452
+ if "ー" in sep_phonemes[i]:
453
+ for j in range(len(sep_phonemes[i])):
454
+ if sep_phonemes[i][j] == "ー":
455
+ sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
456
+ # sep_phonemes[i][j] = ':'
457
+ res.append(sep_phonemes[i])
458
+ res.append("▁")
459
+ return res
460
+
461
+
462
+ def align_tones(
463
+ phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
464
+ ) -> list[tuple[str, int]]:
465
+ """
466
+ 例:
467
+ …私は、、そう思う。
468
+ phones_with_punct:
469
+ [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
470
+ phone_tone_list:
471
+ [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))]
472
+ Return:
473
+ [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
474
+ """
475
+ result: list[tuple[str, int]] = []
476
+ tone_index = 0
477
+ for phone in phones_with_punct:
478
+ if tone_index >= len(phone_tone_list):
479
+ # 余ったpunctuationがある場合 → (punctuation, 0)を追加
480
+ result.append((phone, 0))
481
+ elif phone == phone_tone_list[tone_index][0]:
482
+ # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加
483
+ result.append((phone, phone_tone_list[tone_index][1]))
484
+ # 探すindexを1つ進める
485
+ tone_index += 1
486
+ elif phone in punctuation or phone == "▁":
487
+ # phoneがpunctuationの場合 → (phone, 0)を追加
488
+ result.append((phone, 0))
489
+ else:
490
+ print(f"phones: {phones_with_punct}")
491
+ print(f"phone_tone_list: {phone_tone_list}")
492
+ print(f"result: {result}")
493
+ print(f"tone_index: {tone_index}")
494
+ print(f"phone: {phone}")
495
+ raise ValueError(f"Unexpected phone: {phone}")
496
+ return result
497
+
498
+
499
+ def kata2phoneme_list(text: str) -> list[str]:
500
+ """
501
+ 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。
502
+ 注意点:
503
+ - punctuationが来た場合(punctuationが1文字の場合がありうる)、処理せず1文字のリストを返す
504
+ - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される)
505
+ - 文中の「ー」は前の音素記号の最後の音素記号に変換される。
506
+ 例:
507
+ `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
508
+ `?` → ["?"]
509
+ """
510
+ if text in punctuation:
511
+ return [text]
512
+ # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック
513
+ if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
514
+ raise ValueError(f"Input must be katakana only: {text}")
515
+ sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
516
+ pattern = "|".join(map(re.escape, sorted_keys))
517
+
518
+ def mora2phonemes(mora: str) -> str:
519
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
520
+ if cosonant is None:
521
+ return f" {vowel}"
522
+ return f" {cosonant} {vowel}"
523
+
524
+ spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
525
+
526
+ # 長音記号「ー」の処理
527
+ long_pattern = r"(\w)(ー*)"
528
+ long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
529
+ spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
530
+ # spaced_phonemes += ' ▁'
531
+ return spaced_phonemes.strip().split(" ")
532
+
533
+
534
+ def frontend2phoneme(labels, drop_unvoiced_vowels=False):
535
+ N = len(labels)
536
+
537
+ phones = []
538
+ for n in range(N):
539
+ lab_curr = labels[n]
540
+ # print(lab_curr)
541
+ # current phoneme
542
+ p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
543
+
544
+ # deal unvoiced vowels as normal vowels
545
+ if drop_unvoiced_vowels and p3 in "AEIOU":
546
+ p3 = p3.lower()
547
+
548
+ # deal with sil at the beginning and the end of text
549
+ if p3 == "sil":
550
+ # assert n == 0 or n == N - 1
551
+ # if n == 0:
552
+ # phones.append("^")
553
+ # elif n == N - 1:
554
+ # # check question form or not
555
+ # e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
556
+ # if e3 == 0:
557
+ # phones.append("$")
558
+ # elif e3 == 1:
559
+ # phones.append("?")
560
+ continue
561
+ elif p3 == "pau":
562
+ phones.append("_")
563
+ continue
564
+ else:
565
+ phones.append(p3)
566
+
567
+ # accent type and position info (forward or backward)
568
+ a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
569
+ a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
570
+ a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
571
+
572
+ # number of mora in accent phrase
573
+ f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
574
+
575
+ a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
576
+ # accent phrase border
577
+ # print(p3, a1, a2, a3, f1, a2_next, lab_curr)
578
+ if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
579
+ phones.append("#")
580
+ # pitch falling
581
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
582
+ phones.append("]")
583
+ # pitch rising
584
+ elif a2 == 1 and a2_next == 2:
585
+ phones.append("[")
586
+
587
+ # phones = ' '.join(phones)
588
+ return phones
589
+
590
+
591
+ class JapanesePhoneConverter(object):
592
+ def __init__(self, lexicon_path=None, ipa_dict_path=None):
593
+ # lexicon_lines = open(lexicon_path, 'r', encoding='utf-8').readlines()
594
+ # self.lexicon = {}
595
+ # self.single_dict = {}
596
+ # self.double_dict = {}
597
+ # for curr_line in lexicon_lines:
598
+ # k,v = curr_line.strip().split('+',1)
599
+ # self.lexicon[k] = v
600
+ # if len(k) == 2:
601
+ # self.double_dict[k] = v
602
+ # elif len(k) == 1:
603
+ # self.single_dict[k] = v
604
+ self.ipa_dict = {}
605
+ for curr_line in jp_xphone2ipa:
606
+ k, v = curr_line.strip().split(" ", 1)
607
+ self.ipa_dict[k] = re.sub("\s", "", v)
608
+ # kakasi1 = kakasi()
609
+ # kakasi1.setMode("H","K")
610
+ # kakasi1.setMode("J","K")
611
+ # kakasi1.setMode("r","Hepburn")
612
+ self.japan_JH2K = kakasi()
613
+ self.table = {ord(f): ord(t) for f, t in zip("67", "_¯")}
614
+
615
+ def text2sep_kata(self, parsed) -> tuple[list[str], list[str]]:
616
+ """
617
+ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、
618
+ 分割された単語リストとその読み(カタカナor記号1文字)のリス���のタプルを返す。
619
+ 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。
620
+ 例:
621
+ `私はそう思う!って感じ?` →
622
+ ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
623
+ """
624
+ # parsed: OpenJTalkの解析結果
625
+ sep_text: list[str] = []
626
+ sep_kata: list[str] = []
627
+ fix_parsed = []
628
+ i = 0
629
+ while i <= len(parsed) - 1:
630
+ # word: 実際の単語の文字列
631
+ # yomi: その読み、但し無声化サインの`’`は除去
632
+ # print(parsed)
633
+ yomi = parsed[i]["pron"]
634
+ tmp_parsed = parsed[i]
635
+ if i != len(parsed) - 1 and parsed[i + 1]["string"] in [
636
+ "々",
637
+ "ゝ",
638
+ "ヽ",
639
+ "ゞ",
640
+ "ヾ",
641
+ "゛",
642
+ ]:
643
+ word = parsed[i]["string"] + parsed[i + 1]["string"]
644
+ i += 1
645
+ else:
646
+ word = parsed[i]["string"]
647
+ word, yomi = replace_punctuation(word), yomi.replace("’", "")
648
+ """
649
+ ここで`yomi`の取りうる値は以下の通りのはず。
650
+ - `word`が通常単語 → 通常の読み(カタカナ)
651
+ (カタカナからなり、長音記号も含みうる、`アー` 等)
652
+ - `word`が`ー` から始まる → `ーラー` や `ーーー` など
653
+ - `word`が句読点や空白等 → `、`
654
+ - `word`が`?` → `?`(全角になる)
655
+ 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。
656
+ また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。
657
+ 処理すべきは`yomi`が`、`の場合のみのはず。
658
+ """
659
+ assert yomi != "", f"Empty yomi: {word}"
660
+ if yomi == "、":
661
+ # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`のいずれか
662
+ if word not in (
663
+ ".",
664
+ ",",
665
+ "!",
666
+ "'",
667
+ "-",
668
+ "?",
669
+ ":",
670
+ ";",
671
+ "…",
672
+ "",
673
+ ):
674
+ # ここはpyopenjtalkが読めない文字等のときに起こる
675
+ #print(
676
+ # "{}Cannot read:{}, yomi:{}, new_word:{};".format(
677
+ # parsed, word, yomi, self.japan_JH2K.convert(word)[0]["kana"]
678
+ # )
679
+ #)
680
+ # raise ValueError(word)
681
+ word = self.japan_JH2K.convert(word)[0]["kana"]
682
+ # print(word, self.japan_JH2K.convert(word)[0]['kana'], kata2phoneme_list(self.japan_JH2K.convert(word)[0]['kana']))
683
+ tmp_parsed["pron"] = word
684
+ # yomi = "-"
685
+ # word = ','
686
+ # yomiは元の記号のままに変更
687
+ # else:
688
+ # parsed[i]['pron'] = parsed[i]["string"]
689
+ yomi = word
690
+ elif yomi == "?":
691
+ assert word == "?", f"yomi `?` comes from: {word}"
692
+ yomi = "?"
693
+ if word == "":
694
+ i += 1
695
+ continue
696
+ sep_text.append(word)
697
+ sep_kata.append(yomi)
698
+ # print(word, yomi, parts)
699
+ fix_parsed.append(tmp_parsed)
700
+ i += 1
701
+ # print(sep_text, sep_kata)
702
+ return sep_text, sep_kata, fix_parsed
703
+
704
+ def getSentencePhone(self, sentence, blank_mode=True, phoneme_mode=False):
705
+ # print("origin:", sentence)
706
+ words = []
707
+ words_phone_len = []
708
+ short_char_flag = False
709
+ output_duration_flag = []
710
+ output_before_sil_flag = []
711
+ normed_text = []
712
+ sentence = sentence.strip().strip("'")
713
+ sentence = re.sub(r"\s+", "", sentence)
714
+ output_res = []
715
+ failed_words = []
716
+ last_long_pause = 4
717
+ last_word = None
718
+ frontend_text = pyopenjtalk.run_frontend(sentence)
719
+ # print("frontend_text: ", frontend_text)
720
+ try:
721
+ frontend_text = pyopenjtalk.estimate_accent(frontend_text)
722
+ except:
723
+ pass
724
+ # print("estimate_accent: ", frontend_text)
725
+ # sep_text: 単語単位の単語のリスト
726
+ # sep_kata: 単語単位の単語のカタカナ読みのリスト
727
+ sep_text, sep_kata, frontend_text = self.text2sep_kata(frontend_text)
728
+ # print("sep_text: ", sep_text)
729
+ # print("sep_kata: ", sep_kata)
730
+ # print("frontend_text: ", frontend_text)
731
+ # sep_phonemes: 各単語ご���の音素のリストのリスト
732
+ sep_phonemes = handle_long_word([kata2phoneme_list(i) for i in sep_kata])
733
+ # print("sep_phonemes: ", sep_phonemes)
734
+
735
+ pron_text = [x["pron"].strip().replace("’", "") for x in frontend_text]
736
+ # pdb.set_trace()
737
+ prosodys = pyopenjtalk.make_label(frontend_text)
738
+ prosodys = frontend2phoneme(prosodys, drop_unvoiced_vowels=True)
739
+ # print("prosodys: ", ' '.join(prosodys))
740
+ # print("pron_text: ", pron_text)
741
+ normed_text = [x["string"].strip() for x in frontend_text]
742
+ # punctuationがすべて消えた、音素とアクセントのタプルのリスト
743
+ phone_tone_list_wo_punct = g2phone_tone_wo_punct(prosodys)
744
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
745
+
746
+ # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列
747
+ phone_w_punct: list[str] = []
748
+ w_p_len = []
749
+ for i in sep_phonemes:
750
+ phone_w_punct += i
751
+ w_p_len.append(len(i))
752
+ phone_w_punct = phone_w_punct[:-1]
753
+ # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る
754
+ # print("phone_w_punct: ", phone_w_punct)
755
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
756
+ phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
757
+
758
+ jp_item = {}
759
+ jp_p = ""
760
+ jp_t = ""
761
+ # mye rye pye bye nye
762
+ # je she
763
+ # print(phone_tone_list)
764
+ for p, t in phone_tone_list:
765
+ if p in self.ipa_dict:
766
+ curr_p = self.ipa_dict[p]
767
+ jp_p += curr_p
768
+ jp_t += str(t + 6) * len(curr_p)
769
+ elif p in punctuation:
770
+ jp_p += p
771
+ jp_t += "0"
772
+ elif p == "▁":
773
+ jp_p += p
774
+ jp_t += " "
775
+ else:
776
+ print(p, t)
777
+ jp_p += "|"
778
+ jp_t += "0"
779
+ # return phones, tones, w_p_len
780
+ jp_p = jp_p.replace("▁", " ")
781
+ jp_t = jp_t.translate(self.table)
782
+ jp_l = ""
783
+ for t in jp_t:
784
+ if t == " ":
785
+ jp_l += " "
786
+ else:
787
+ jp_l += "2"
788
+ # print(jp_p)
789
+ # print(jp_t)
790
+ # print(jp_l)
791
+ # print(len(jp_p_len), sum(w_p_len), len(jp_p), sum(jp_p_len))
792
+ assert len(jp_p) == len(jp_t) and len(jp_p) == len(jp_l)
793
+
794
+ jp_item["jp_p"] = jp_p.replace("| |", "|").rstrip("|")
795
+ jp_item["jp_t"] = jp_t
796
+ jp_item["jp_l"] = jp_l
797
+ jp_item["jp_normed_text"] = " ".join(normed_text)
798
+ jp_item["jp_pron_text"] = " ".join(pron_text)
799
+ # jp_item['jp_ruoma'] = sep_phonemes
800
+ # print(len(normed_text), len(sep_phonemes))
801
+ # print(normed_text)
802
+ return jp_item
803
+
804
+
805
+ jpc = JapanesePhoneConverter()
806
+
807
+
808
+ def japanese_to_ipa(text, text_tokenizer):
809
+ # phonemes = text_tokenizer(text)
810
+ if type(text) == str:
811
+ return jpc.getSentencePhone(text)["jp_p"]
812
+ else:
813
+ result_ph = []
814
+ for t in text:
815
+ result_ph.append(jpc.getSentencePhone(t)["jp_p"])
816
+ return result_ph
diffrhythm/g2p/g2p/korean.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ english_dictionary = {
12
+ "KOREA": "코리아",
13
+ "IDOL": "아이돌",
14
+ "IT": "아이티",
15
+ "IQ": "아이큐",
16
+ "UP": "업",
17
+ "DOWN": "다운",
18
+ "PC": "피씨",
19
+ "CCTV": "씨씨티비",
20
+ "SNS": "에스엔에스",
21
+ "AI": "에이아이",
22
+ "CEO": "씨이오",
23
+ "A": "에이",
24
+ "B": "비",
25
+ "C": "씨",
26
+ "D": "디",
27
+ "E": "이",
28
+ "F": "에프",
29
+ "G": "지",
30
+ "H": "에이치",
31
+ "I": "아이",
32
+ "J": "제이",
33
+ "K": "케이",
34
+ "L": "엘",
35
+ "M": "엠",
36
+ "N": "엔",
37
+ "O": "오",
38
+ "P": "피",
39
+ "Q": "큐",
40
+ "R": "알",
41
+ "S": "에스",
42
+ "T": "티",
43
+ "U": "유",
44
+ "V": "브이",
45
+ "W": "더블유",
46
+ "X": "엑스",
47
+ "Y": "와이",
48
+ "Z": "제트",
49
+ }
50
+
51
+
52
+ def normalize(text):
53
+ text = text.strip()
54
+ text = re.sub(
55
+ "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text
56
+ )
57
+ text = normalize_english(text)
58
+ text = text.lower()
59
+ return text
60
+
61
+
62
+ def normalize_english(text):
63
+ def fn(m):
64
+ word = m.group()
65
+ if word in english_dictionary:
66
+ return english_dictionary.get(word)
67
+ return word
68
+
69
+ text = re.sub("([A-Za-z]+)", fn, text)
70
+ return text
71
+
72
+
73
+ def korean_to_ipa(text, text_tokenizer):
74
+ if type(text) == str:
75
+ text = normalize(text)
76
+ phonemes = text_tokenizer(text)
77
+ return phonemes
78
+ else:
79
+ for i, t in enumerate(text):
80
+ text[i] = normalize(t)
81
+ return text_tokenizer(text)
diffrhythm/g2p/g2p/mandarin.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import jieba
8
+ import cn2an
9
+ from pypinyin import lazy_pinyin, BOPOMOFO
10
+ from typing import List
11
+ from diffrhythm.g2p.g2p.chinese_model_g2p import BertPolyPredict
12
+ from diffrhythm.g2p.utils.front_utils import *
13
+ import os
14
+
15
+ # from g2pw import G2PWConverter
16
+
17
+
18
+ # set blank level, {0:"none",1:"char", 2:"word"}
19
+ BLANK_LEVEL = 0
20
+
21
+ # conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
22
+ resource_path = r"./diffrhythm/g2p"
23
+ poly_all_class_path = os.path.join(
24
+ resource_path, "sources", "g2p_chinese_model", "polychar.txt"
25
+ )
26
+ if not os.path.exists(poly_all_class_path):
27
+ print(
28
+ "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
29
+ poly_all_class_path
30
+ )
31
+ )
32
+ exit()
33
+ poly_dict = generate_poly_lexicon(poly_all_class_path)
34
+
35
+ # Set up G2PW model parameters
36
+ g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
37
+ if not os.path.exists(g2pw_poly_model_path):
38
+ print(
39
+ "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
40
+ g2pw_poly_model_path
41
+ )
42
+ )
43
+ exit()
44
+
45
+ json_file_path = os.path.join(
46
+ resource_path, "sources", "g2p_chinese_model", "polydict.json"
47
+ )
48
+ if not os.path.exists(json_file_path):
49
+ print(
50
+ "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
51
+ json_file_path
52
+ )
53
+ )
54
+ exit()
55
+
56
+ jsonr_file_path = os.path.join(
57
+ resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
58
+ )
59
+ if not os.path.exists(jsonr_file_path):
60
+ print(
61
+ "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
62
+ jsonr_file_path
63
+ )
64
+ )
65
+ exit()
66
+
67
+ g2pw_poly_predict = BertPolyPredict(
68
+ g2pw_poly_model_path, jsonr_file_path, json_file_path
69
+ )
70
+
71
+
72
+ """
73
+ Text clean time
74
+ """
75
+ # List of (Latin alphabet, bopomofo) pairs:
76
+ _latin_to_bopomofo = [
77
+ (re.compile("%s" % x[0], re.IGNORECASE), x[1])
78
+ for x in [
79
+ ("a", "ㄟˉ"),
80
+ ("b", "ㄅㄧˋ"),
81
+ ("c", "ㄙㄧˉ"),
82
+ ("d", "ㄉㄧˋ"),
83
+ ("e", "ㄧˋ"),
84
+ ("f", "ㄝˊㄈㄨˋ"),
85
+ ("g", "ㄐㄧˋ"),
86
+ ("h", "ㄝˇㄑㄩˋ"),
87
+ ("i", "ㄞˋ"),
88
+ ("j", "ㄐㄟˋ"),
89
+ ("k", "ㄎㄟˋ"),
90
+ ("l", "ㄝˊㄛˋ"),
91
+ ("m", "ㄝˊㄇㄨˋ"),
92
+ ("n", "ㄣˉ"),
93
+ ("o", "ㄡˉ"),
94
+ ("p", "ㄆㄧˉ"),
95
+ ("q", "ㄎㄧㄡˉ"),
96
+ ("r", "ㄚˋ"),
97
+ ("s", "ㄝˊㄙˋ"),
98
+ ("t", "ㄊㄧˋ"),
99
+ ("u", "ㄧㄡˉ"),
100
+ ("v", "ㄨㄧˉ"),
101
+ ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
102
+ ("x", "ㄝˉㄎㄨˋㄙˋ"),
103
+ ("y", "ㄨㄞˋ"),
104
+ ("z", "ㄗㄟˋ"),
105
+ ]
106
+ ]
107
+
108
+ # List of (bopomofo, ipa) pairs:
109
+ _bopomofo_to_ipa = [
110
+ (re.compile("%s" % x[0]), x[1])
111
+ for x in [
112
+ ("ㄅㄛ", "p⁼wo"),
113
+ ("ㄆㄛ", "pʰwo"),
114
+ ("ㄇㄛ", "mwo"),
115
+ ("ㄈㄛ", "fwo"),
116
+ ("ㄧㄢ", "|jɛn"),
117
+ ("ㄩㄢ", "|ɥæn"),
118
+ ("ㄧㄣ", "|in"),
119
+ ("ㄩㄣ", "|ɥn"),
120
+ ("ㄧㄥ", "|iŋ"),
121
+ ("ㄨㄥ", "|ʊŋ"),
122
+ ("ㄩㄥ", "|jʊŋ"),
123
+ # Add
124
+ ("ㄧㄚ", "|ia"),
125
+ ("ㄧㄝ", "|iɛ"),
126
+ ("ㄧㄠ", "|iɑʊ"),
127
+ ("ㄧㄡ", "|ioʊ"),
128
+ ("ㄧㄤ", "|iɑŋ"),
129
+ ("ㄨㄚ", "|ua"),
130
+ ("ㄨㄛ", "|uo"),
131
+ ("ㄨㄞ", "|uaɪ"),
132
+ ("ㄨㄟ", "|ueɪ"),
133
+ ("ㄨㄢ", "|uan"),
134
+ ("ㄨㄣ", "|uən"),
135
+ ("ㄨㄤ", "|uɑŋ"),
136
+ ("ㄩㄝ", "|ɥɛ"),
137
+ # End
138
+ ("ㄅ", "p⁼"),
139
+ ("ㄆ", "pʰ"),
140
+ ("ㄇ", "m"),
141
+ ("ㄈ", "f"),
142
+ ("ㄉ", "t⁼"),
143
+ ("ㄊ", "tʰ"),
144
+ ("ㄋ", "n"),
145
+ ("ㄌ", "l"),
146
+ ("ㄍ", "k⁼"),
147
+ ("ㄎ", "kʰ"),
148
+ ("ㄏ", "x"),
149
+ ("ㄐ", "tʃ⁼"),
150
+ ("ㄑ", "tʃʰ"),
151
+ ("ㄒ", "ʃ"),
152
+ ("ㄓ", "ts`⁼"),
153
+ ("ㄔ", "ts`ʰ"),
154
+ ("ㄕ", "s`"),
155
+ ("ㄖ", "ɹ`"),
156
+ ("ㄗ", "ts⁼"),
157
+ ("ㄘ", "tsʰ"),
158
+ ("ㄙ", "|s"),
159
+ ("ㄚ", "|a"),
160
+ ("ㄛ", "|o"),
161
+ ("ㄜ", "|ə"),
162
+ ("ㄝ", "|ɛ"),
163
+ ("ㄞ", "|aɪ"),
164
+ ("ㄟ", "|eɪ"),
165
+ ("ㄠ", "|ɑʊ"),
166
+ ("ㄡ", "|oʊ"),
167
+ ("ㄢ", "|an"),
168
+ ("ㄣ", "|ən"),
169
+ ("ㄤ", "|ɑŋ"),
170
+ ("ㄥ", "|əŋ"),
171
+ ("ㄦ", "əɹ"),
172
+ ("ㄧ", "|i"),
173
+ ("ㄨ", "|u"),
174
+ ("ㄩ", "|ɥ"),
175
+ ("ˉ", "→|"),
176
+ ("ˊ", "↑|"),
177
+ ("ˇ", "↓↑|"),
178
+ ("ˋ", "↓|"),
179
+ ("˙", "|"),
180
+ ]
181
+ ]
182
+ must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}
183
+
184
+ word_pinyin_dict = {}
185
+ with open(
186
+ r"./diffrhythm/g2p/sources/chinese_lexicon.txt", "r", encoding="utf-8"
187
+ ) as fread:
188
+ txt_list = fread.readlines()
189
+ for txt in txt_list:
190
+ word, pinyin = txt.strip().split("\t")
191
+ word_pinyin_dict[word] = pinyin
192
+ fread.close()
193
+
194
+ pinyin_2_bopomofo_dict = {}
195
+ with open(
196
+ r"./diffrhythm/g2p/sources/pinyin_2_bpmf.txt", "r", encoding="utf-8"
197
+ ) as fread:
198
+ txt_list = fread.readlines()
199
+ for txt in txt_list:
200
+ pinyin, bopomofo = txt.strip().split("\t")
201
+ pinyin_2_bopomofo_dict[pinyin] = bopomofo
202
+ fread.close()
203
+
204
+ tone_dict = {
205
+ "0": "˙",
206
+ "5": "˙",
207
+ "1": "",
208
+ "2": "ˊ",
209
+ "3": "ˇ",
210
+ "4": "ˋ",
211
+ }
212
+
213
+ bopomofos2pinyin_dict = {}
214
+ with open(
215
+ r"./diffrhythm/g2p/sources/bpmf_2_pinyin.txt", "r", encoding="utf-8"
216
+ ) as fread:
217
+ txt_list = fread.readlines()
218
+ for txt in txt_list:
219
+ v, k = txt.strip().split("\t")
220
+ bopomofos2pinyin_dict[k] = v
221
+ fread.close()
222
+
223
+
224
+ def bpmf_to_pinyin(text):
225
+ bopomofo_list = text.split("|")
226
+ pinyin_list = []
227
+ for info in bopomofo_list:
228
+ pinyin = ""
229
+ for c in info:
230
+ if c in bopomofos2pinyin_dict:
231
+ pinyin += bopomofos2pinyin_dict[c]
232
+ if len(pinyin) == 0:
233
+ continue
234
+ if pinyin[-1] not in "01234":
235
+ pinyin += "1"
236
+ if pinyin[:-1] == "ve":
237
+ pinyin = "y" + pinyin
238
+ if pinyin[:-1] == "sh":
239
+ pinyin = pinyin[:-1] + "i" + pinyin[-1]
240
+ if pinyin == "sh":
241
+ pinyin = pinyin[:-1] + "i"
242
+ if pinyin[:-1] == "s":
243
+ pinyin = "si" + pinyin[-1]
244
+ if pinyin[:-1] == "c":
245
+ pinyin = "ci" + pinyin[-1]
246
+ if pinyin[:-1] == "i":
247
+ pinyin = "yi" + pinyin[-1]
248
+ if pinyin[:-1] == "iou":
249
+ pinyin = "you" + pinyin[-1]
250
+ if pinyin[:-1] == "ien":
251
+ pinyin = "yin" + pinyin[-1]
252
+ if "iou" in pinyin and pinyin[-4:-1] == "iou":
253
+ pinyin = pinyin[:-4] + "iu" + pinyin[-1]
254
+ if "uei" in pinyin:
255
+ if pinyin[:-1] == "uei":
256
+ pinyin = "wei" + pinyin[-1]
257
+ elif pinyin[-4:-1] == "uei":
258
+ pinyin = pinyin[:-4] + "ui" + pinyin[-1]
259
+ if "uen" in pinyin and pinyin[-4:-1] == "uen":
260
+ if pinyin[:-1] == "uen":
261
+ pinyin = "wen" + pinyin[-1]
262
+ elif pinyin[-4:-1] == "uei":
263
+ pinyin = pinyin[:-4] + "un" + pinyin[-1]
264
+ if "van" in pinyin and pinyin[-4:-1] == "van":
265
+ if pinyin[:-1] == "van":
266
+ pinyin = "yuan" + pinyin[-1]
267
+ elif pinyin[-4:-1] == "van":
268
+ pinyin = pinyin[:-4] + "uan" + pinyin[-1]
269
+ if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
270
+ pinyin = pinyin[:-5] + "ong" + pinyin[-1]
271
+ if pinyin[:-1] == "veng":
272
+ pinyin = "yong" + pinyin[-1]
273
+ if "veng" in pinyin and pinyin[-5:-1] == "veng":
274
+ pinyin = pinyin[:-5] + "iong" + pinyin[-1]
275
+ if pinyin[:-1] == "ieng":
276
+ pinyin = "ying" + pinyin[-1]
277
+ if pinyin[:-1] == "u":
278
+ pinyin = "wu" + pinyin[-1]
279
+ if pinyin[:-1] == "v":
280
+ pinyin = "yv" + pinyin[-1]
281
+ if pinyin[:-1] == "ing":
282
+ pinyin = "ying" + pinyin[-1]
283
+ if pinyin[:-1] == "z":
284
+ pinyin = "zi" + pinyin[-1]
285
+ if pinyin[:-1] == "zh":
286
+ pinyin = "zhi" + pinyin[-1]
287
+ if pinyin[0] == "u":
288
+ pinyin = "w" + pinyin[1:]
289
+ if pinyin[0] == "i":
290
+ pinyin = "y" + pinyin[1:]
291
+ pinyin = pinyin.replace("ien", "in")
292
+
293
+ pinyin_list.append(pinyin)
294
+ return " ".join(pinyin_list)
295
+
296
+
297
+ # Convert numbers to Chinese pronunciation
298
+ def number_to_chinese(text):
299
+ # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
300
+ # for number in numbers:
301
+ # text = text.replace(number, cn2an.an2cn(number), 1)
302
+ text = cn2an.transform(text, "an2cn")
303
+ return text
304
+
305
+
306
+ def normalization(text):
307
+ text = text.replace(",", ",")
308
+ text = text.replace("。", ".")
309
+ text = text.replace("!", "!")
310
+ text = text.replace("?", "?")
311
+ text = text.replace(";", ";")
312
+ text = text.replace(":", ":")
313
+ text = text.replace("、", ",")
314
+ text = text.replace("‘", "'")
315
+ text = text.replace("’", "'")
316
+ text = text.replace("⋯", "…")
317
+ text = text.replace("···", "…")
318
+ text = text.replace("・・・", "…")
319
+ text = text.replace("...", "…")
320
+ text = re.sub(r"\s+", "", text)
321
+ text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text)
322
+ text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text)
323
+ return text
324
+
325
+
326
+ def change_tone(bopomofo: str, tone: str) -> str:
327
+ if bopomofo[-1] not in "˙ˊˇˋ":
328
+ bopomofo = bopomofo + tone
329
+ else:
330
+ bopomofo = bopomofo[:-1] + tone
331
+ return bopomofo
332
+
333
+
334
+ def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
335
+ if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
336
+ bopomofos[-1] = change_tone(bopomofos[-1], "˙")
337
+ return bopomofos
338
+
339
+
340
+ def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
341
+ valid_char = set(word)
342
+ if len(valid_char) == 1 and "不" in valid_char:
343
+ pass
344
+ elif word in ["不字"]:
345
+ pass
346
+ elif len(word) == 3 and word[1] == "不" and bopomofos[1][:-1] == "ㄅㄨ":
347
+ bopomofos[1] = bopomofos[1][:-1] + "˙"
348
+ else:
349
+ for i, char in enumerate(word):
350
+ if (
351
+ i + 1 < len(bopomofos)
352
+ and char == "不"
353
+ and i + 1 < len(word)
354
+ and 0 < len(bopomofos[i + 1])
355
+ and bopomofos[i + 1][-1] == "ˋ"
356
+ ):
357
+ bopomofos[i] = bopomofos[i][:-1] + "ˊ"
358
+ return bopomofos
359
+
360
+
361
+ def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
362
+ punc = ":,;。?!“”‘’':,;.?!()(){}【】[]-~`、 "
363
+ if word.find("一") != -1 and any(
364
+ [item.isnumeric() for item in word if item != "一"]
365
+ ):
366
+ for i in range(len(word)):
367
+ if (
368
+ i == 0
369
+ and word[0] == "一"
370
+ and len(word) > 1
371
+ and word[1]
372
+ not in [
373
+ "零",
374
+ "一",
375
+ "二",
376
+ "三",
377
+ "四",
378
+ "五",
379
+ "六",
380
+ "七",
381
+ "八",
382
+ "九",
383
+ "十",
384
+ ]
385
+ ):
386
+ if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
387
+ bopomofos[0] = change_tone(bopomofos[0], "ˊ")
388
+ else:
389
+ bopomofos[0] = change_tone(bopomofos[0], "ˋ")
390
+ elif word[i] == "一":
391
+ bopomofos[i] = change_tone(bopomofos[i], "")
392
+ return bopomofos
393
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
394
+ bopomofos[1] = change_tone(bopomofos[1], "˙")
395
+ elif word.startswith("第一"):
396
+ bopomofos[1] = change_tone(bopomofos[1], "")
397
+ elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
398
+ bopomofos[0] = change_tone(bopomofos[0], "")
399
+ else:
400
+ for i, char in enumerate(word):
401
+ if char == "一" and i + 1 < len(word):
402
+ if (
403
+ len(bopomofos) > i + 1
404
+ and len(bopomofos[i + 1]) > 0
405
+ and bopomofos[i + 1][-1] in {"ˋ"}
406
+ ):
407
+ bopomofos[i] = change_tone(bopomofos[i], "ˊ")
408
+ else:
409
+ if word[i + 1] not in punc:
410
+ bopomofos[i] = change_tone(bopomofos[i], "ˋ")
411
+ else:
412
+ pass
413
+ return bopomofos
414
+
415
+
416
+ def merge_bu(seg: List) -> List:
417
+ new_seg = []
418
+ last_word = ""
419
+ for word in seg:
420
+ if word != "不":
421
+ if last_word == "不":
422
+ word = last_word + word
423
+ new_seg.append(word)
424
+ last_word = word
425
+ return new_seg
426
+
427
+
428
+ def merge_er(seg: List) -> List:
429
+ new_seg = []
430
+ for i, word in enumerate(seg):
431
+ if i - 1 >= 0 and word == "儿":
432
+ new_seg[-1] = new_seg[-1] + seg[i]
433
+ else:
434
+ new_seg.append(word)
435
+ return new_seg
436
+
437
+
438
+ def merge_yi(seg: List) -> List:
439
+ new_seg = []
440
+ # function 1
441
+ for i, word in enumerate(seg):
442
+ if (
443
+ i - 1 >= 0
444
+ and word == "一"
445
+ and i + 1 < len(seg)
446
+ and seg[i - 1] == seg[i + 1]
447
+ ):
448
+ if i - 1 < len(new_seg):
449
+ new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
450
+ else:
451
+ new_seg.append(word)
452
+ new_seg.append(seg[i + 1])
453
+ else:
454
+ if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
455
+ continue
456
+ else:
457
+ new_seg.append(word)
458
+ seg = new_seg
459
+ new_seg = []
460
+ isnumeric_flag = False
461
+ for i, word in enumerate(seg):
462
+ if all([item.isnumeric() for item in word]) and not isnumeric_flag:
463
+ isnumeric_flag = True
464
+ new_seg.append(word)
465
+ else:
466
+ new_seg.append(word)
467
+ seg = new_seg
468
+ new_seg = []
469
+ # function 2
470
+ for i, word in enumerate(seg):
471
+ if new_seg and new_seg[-1] == "一":
472
+ new_seg[-1] = new_seg[-1] + word
473
+ else:
474
+ new_seg.append(word)
475
+ return new_seg
476
+
477
+
478
+ # Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
479
+ def chinese_to_bopomofo(text_short, sentence):
480
+ # bopomofos = conv(text_short)
481
+ words = jieba.lcut(text_short, cut_all=False)
482
+ words = merge_yi(words)
483
+ words = merge_bu(words)
484
+ words = merge_er(words)
485
+ text = ""
486
+
487
+ char_index = 0
488
+ for word in words:
489
+ bopomofos = []
490
+ if word in word_pinyin_dict and word not in poly_dict:
491
+ pinyin = word_pinyin_dict[word]
492
+ for py in pinyin.split(" "):
493
+ if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
494
+ bopomofos.append(
495
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
496
+ )
497
+ if BLANK_LEVEL == 1:
498
+ bopomofos.append("_")
499
+ else:
500
+ bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
501
+ bopomofos += bopomofos_lazy
502
+ if BLANK_LEVEL == 1:
503
+ bopomofos.append("_")
504
+ else:
505
+ for i in range(len(word)):
506
+ c = word[i]
507
+ if c in poly_dict:
508
+ poly_pinyin = g2pw_poly_predict.predict_process(
509
+ [text_short, char_index + i]
510
+ )[0]
511
+ py = poly_pinyin[2:-1]
512
+ bopomofos.append(
513
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
514
+ )
515
+ if BLANK_LEVEL == 1:
516
+ bopomofos.append("_")
517
+ elif c in word_pinyin_dict:
518
+ py = word_pinyin_dict[c]
519
+ bopomofos.append(
520
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
521
+ )
522
+ if BLANK_LEVEL == 1:
523
+ bopomofos.append("_")
524
+ else:
525
+ bopomofos.append(c)
526
+ if BLANK_LEVEL == 1:
527
+ bopomofos.append("_")
528
+ if BLANK_LEVEL == 2:
529
+ bopomofos.append("_")
530
+ char_index += len(word)
531
+
532
+ if (
533
+ len(word) == 3
534
+ and bopomofos[0][-1] == "ˇ"
535
+ and bopomofos[1][-1] == "ˇ"
536
+ and bopomofos[-1][-1] == "ˇ"
537
+ ):
538
+ bopomofos[0] = bopomofos[0] + "ˊ"
539
+ bopomofos[1] = bopomofos[1] + "ˊ"
540
+ if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
541
+ bopomofos[0] = bopomofos[0][:-1] + "ˊ"
542
+ bopomofos = bu_sandhi(word, bopomofos)
543
+ bopomofos = yi_sandhi(word, bopomofos)
544
+ bopomofos = er_sandhi(word, bopomofos)
545
+ if not re.search("[\u4e00-\u9fff]", word):
546
+ text += "|" + word
547
+ continue
548
+ for i in range(len(bopomofos)):
549
+ bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i])
550
+ if text != "":
551
+ text += "|"
552
+ text += "|".join(bopomofos)
553
+ return text
554
+
555
+
556
+ # Convert latin pronunciation to pinyin (bopomofo)
557
+ def latin_to_bopomofo(text):
558
+ for regex, replacement in _latin_to_bopomofo:
559
+ text = re.sub(regex, replacement, text)
560
+ return text
561
+
562
+
563
+ # Convert pinyin (bopomofo) to IPA
564
+ def bopomofo_to_ipa(text):
565
+ for regex, replacement in _bopomofo_to_ipa:
566
+ text = re.sub(regex, replacement, text)
567
+ return text
568
+
569
+
570
+ def _chinese_to_ipa(text, sentence):
571
+ text = number_to_chinese(text.strip())
572
+ text = normalization(text)
573
+ text = chinese_to_bopomofo(text, sentence)
574
+ # pinyin = bpmf_to_pinyin(text)
575
+ text = latin_to_bopomofo(text)
576
+ text = bopomofo_to_ipa(text)
577
+ text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
578
+ text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
579
+ text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text)
580
+ text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text)
581
+ text = re.sub(r"\|+", "|", text)
582
+ text = text.rstrip("|")
583
+ return text
584
+
585
+
586
+ # Convert Chinese to IPA
587
+ def chinese_to_ipa(text, sentence, text_tokenizer):
588
+ # phonemes = text_tokenizer(text.strip())
589
+ if type(text) == str:
590
+ return _chinese_to_ipa(text, sentence)
591
+ else:
592
+ result_ph = []
593
+ for t in text:
594
+ result_ph.append(_chinese_to_ipa(t, sentence))
595
+ return result_ph
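The file above implements the full Mandarin front end: numeral expansion, punctuation normalization, jieba segmentation with 不/一/儿 merging, lexicon- and BERT-based pinyin lookup, tone sandhi, and bopomofo-to-IPA conversion. A minimal usage sketch follows; the module path `diffrhythm.g2p.g2p.mandarin` is an assumption (the file name is not visible in this excerpt), and importing the module loads the lexicons and the BertPolyPredict model, so the resources under `./diffrhythm/g2p/sources` must exist on disk.

```python
# Hedged sketch: the module path and the presence of the on-disk resources are assumptions.
from diffrhythm.g2p.g2p.mandarin import chinese_to_ipa

text = "今天天气不错"
# The second argument is the sentence-level context; for a plain string input the
# third argument (text_tokenizer) is not used by the code above, so None is fine.
ipa = chinese_to_ipa(text, text, None)
print(ipa)  # a "|"-separated string of IPA symbols and tone arrows
```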
diffrhythm/g2p/g2p/text_tokenizers.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import os
8
+ from typing import List, Pattern, Union
9
+ from phonemizer.utils import list2str, str2list
10
+ from phonemizer.backend import EspeakBackend
11
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
12
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
13
+ from phonemizer.punctuation import Punctuation
14
+ from phonemizer.separator import Separator
15
+
16
+
17
+ class TextTokenizer:
18
+ """Phonemize Text."""
19
+
20
+ def __init__(
21
+ self,
22
+ language="en-us",
23
+ backend="espeak",
24
+ separator=Separator(word="|_|", syllable="-", phone="|"),
25
+ preserve_punctuation=True,
26
+ with_stress: bool = False,
27
+ tie: Union[bool, str] = False,
28
+ language_switch: LanguageSwitch = "remove-flags",
29
+ words_mismatch: WordMismatch = "ignore",
30
+ ) -> None:
31
+ self.preserve_punctuation_marks = ",.?!;:'…"
32
+ self.backend = EspeakBackend(
33
+ language,
34
+ punctuation_marks=self.preserve_punctuation_marks,
35
+ preserve_punctuation=preserve_punctuation,
36
+ with_stress=with_stress,
37
+ tie=tie,
38
+ language_switch=language_switch,
39
+ words_mismatch=words_mismatch,
40
+ )
41
+
42
+ self.separator = separator
43
+
44
+ # convert chinese punctuation to english punctuation
45
+ def convert_chinese_punctuation(self, text: str) -> str:
46
+ text = text.replace(",", ",")
47
+ text = text.replace("。", ".")
48
+ text = text.replace("!", "!")
49
+ text = text.replace("?", "?")
50
+ text = text.replace(";", ";")
51
+ text = text.replace(":", ":")
52
+ text = text.replace("、", ",")
53
+ text = text.replace("‘", "'")
54
+ text = text.replace("’", "'")
55
+ text = text.replace("⋯", "…")
56
+ text = text.replace("···", "…")
57
+ text = text.replace("・・・", "…")
58
+ text = text.replace("...", "…")
59
+ return text
60
+
61
+ def __call__(self, text, strip=True) -> List[str]:
62
+
63
+ text_type = type(text)
64
+ normalized_text = []
65
+ for line in str2list(text):
66
+ line = self.convert_chinese_punctuation(line.strip())
67
+ line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line)
68
+ line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line)
69
+ line = re.sub(r"\s+", " ", line)
70
+ normalized_text.append(line)
71
+ # print("Normalized text: ", normalized_text[0])
72
+ phonemized = self.backend.phonemize(
73
+ normalized_text, separator=self.separator, strip=strip, njobs=1
74
+ )
75
+ if text_type == str:
76
+ phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized))
77
+ phonemized = re.sub(r"\|+", "|", phonemized)
78
+ phonemized = phonemized.rstrip("|")
79
+ else:
80
+ for i in range(len(phonemized)):
81
+ phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i])
82
+ phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
83
+ phonemized[i] = phonemized[i].rstrip("|")
84
+ return phonemized
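For reference, a small usage sketch of the `TextTokenizer` defined above. It assumes the espeak-ng backend required by `phonemizer` is installed on the system, which this commit does not provide.

```python
# Hedged sketch; espeak-ng must be installed for EspeakBackend to initialize.
from diffrhythm.g2p.g2p.text_tokenizers import TextTokenizer

tokenizer = TextTokenizer(language="en-us")

# A plain string is normalized, phonemized, and returned as one string with
# "|" separating phones, "|_|" separating words, and punctuation wrapped in "|".
phones = tokenizer("Hello, world!")
print(phones)
```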
diffrhythm/g2p/g2p/vocab.json ADDED
@@ -0,0 +1,372 @@
1
+ {
2
+ "vocab": {
3
+ ",": 0,
4
+ ".": 1,
5
+ "?": 2,
6
+ "!": 3,
7
+ "_": 4,
8
+ "iː": 5,
9
+ "ɪ": 6,
10
+ "ɜː": 7,
11
+ "ɚ": 8,
12
+ "oːɹ": 9,
13
+ "ɔː": 10,
14
+ "ɔːɹ": 11,
15
+ "ɑː": 12,
16
+ "uː": 13,
17
+ "ʊ": 14,
18
+ "ɑːɹ": 15,
19
+ "ʌ": 16,
20
+ "ɛ": 17,
21
+ "æ": 18,
22
+ "eɪ": 19,
23
+ "aɪ": 20,
24
+ "ɔɪ": 21,
25
+ "aʊ": 22,
26
+ "oʊ": 23,
27
+ "ɪɹ": 24,
28
+ "ɛɹ": 25,
29
+ "ʊɹ": 26,
30
+ "p": 27,
31
+ "b": 28,
32
+ "t": 29,
33
+ "d": 30,
34
+ "k": 31,
35
+ "ɡ": 32,
36
+ "f": 33,
37
+ "v": 34,
38
+ "θ": 35,
39
+ "ð": 36,
40
+ "s": 37,
41
+ "z": 38,
42
+ "ʃ": 39,
43
+ "ʒ": 40,
44
+ "h": 41,
45
+ "tʃ": 42,
46
+ "dʒ": 43,
47
+ "m": 44,
48
+ "n": 45,
49
+ "ŋ": 46,
50
+ "j": 47,
51
+ "w": 48,
52
+ "ɹ": 49,
53
+ "l": 50,
54
+ "tɹ": 51,
55
+ "dɹ": 52,
56
+ "ts": 53,
57
+ "dz": 54,
58
+ "i": 55,
59
+ "ɔ": 56,
60
+ "ə": 57,
61
+ "ɾ": 58,
62
+ "iə": 59,
63
+ "r": 60,
64
+ "u": 61,
65
+ "oː": 62,
66
+ "ɛː": 63,
67
+ "ɪː": 64,
68
+ "aɪə": 65,
69
+ "aɪɚ": 66,
70
+ "ɑ̃": 67,
71
+ "ç": 68,
72
+ "ɔ̃": 69,
73
+ "ææ": 70,
74
+ "ɐɐ": 71,
75
+ "ɡʲ": 72,
76
+ "nʲ": 73,
77
+ "iːː": 74,
78
+
79
+ "p⁼": 75,
80
+ "pʰ": 76,
81
+ "t⁼": 77,
82
+ "tʰ": 78,
83
+ "k⁼": 79,
84
+ "kʰ": 80,
85
+ "x": 81,
86
+ "tʃ⁼": 82,
87
+ "tʃʰ": 83,
88
+ "ts`⁼": 84,
89
+ "ts`ʰ": 85,
90
+ "s`": 86,
91
+ "ɹ`": 87,
92
+ "ts⁼": 88,
93
+ "tsʰ": 89,
94
+ "p⁼wo": 90,
95
+ "p⁼wo→": 91,
96
+ "p⁼wo↑": 92,
97
+ "p⁼wo↓↑": 93,
98
+ "p⁼wo↓": 94,
99
+ "pʰwo": 95,
100
+ "pʰwo→": 96,
101
+ "pʰwo↑": 97,
102
+ "pʰwo↓↑": 98,
103
+ "pʰwo↓": 99,
104
+ "mwo": 100,
105
+ "mwo→": 101,
106
+ "mwo↑": 102,
107
+ "mwo↓↑": 103,
108
+ "mwo↓": 104,
109
+ "fwo": 105,
110
+ "fwo→": 106,
111
+ "fwo↑": 107,
112
+ "fwo↓↑": 108,
113
+ "fwo↓": 109,
114
+ "jɛn": 110,
115
+ "jɛn→": 111,
116
+ "jɛn↑": 112,
117
+ "jɛn↓↑": 113,
118
+ "jɛn↓": 114,
119
+ "ɥæn": 115,
120
+ "ɥæn→": 116,
121
+ "ɥæn↑": 117,
122
+ "ɥæn↓↑": 118,
123
+ "ɥæn↓": 119,
124
+ "in": 120,
125
+ "in→": 121,
126
+ "in↑": 122,
127
+ "in↓↑": 123,
128
+ "in↓": 124,
129
+ "ɥn": 125,
130
+ "ɥn→": 126,
131
+ "ɥn↑": 127,
132
+ "ɥn↓↑": 128,
133
+ "ɥn↓": 129,
134
+ "iŋ": 130,
135
+ "iŋ→": 131,
136
+ "iŋ↑": 132,
137
+ "iŋ↓↑": 133,
138
+ "iŋ↓": 134,
139
+ "ʊŋ": 135,
140
+ "ʊŋ→": 136,
141
+ "ʊŋ↑": 137,
142
+ "ʊŋ↓↑": 138,
143
+ "ʊŋ↓": 139,
144
+ "jʊŋ": 140,
145
+ "jʊŋ→": 141,
146
+ "jʊŋ↑": 142,
147
+ "jʊŋ↓↑": 143,
148
+ "jʊŋ↓": 144,
149
+ "ia": 145,
150
+ "ia→": 146,
151
+ "ia↑": 147,
152
+ "ia↓↑": 148,
153
+ "ia↓": 149,
154
+ "iɛ": 150,
155
+ "iɛ→": 151,
156
+ "iɛ↑": 152,
157
+ "iɛ↓↑": 153,
158
+ "iɛ↓": 154,
159
+ "iɑʊ": 155,
160
+ "iɑʊ→": 156,
161
+ "iɑʊ↑": 157,
162
+ "iɑʊ↓↑": 158,
163
+ "iɑʊ↓": 159,
164
+ "ioʊ": 160,
165
+ "ioʊ→": 161,
166
+ "ioʊ↑": 162,
167
+ "ioʊ↓↑": 163,
168
+ "ioʊ↓": 164,
169
+ "iɑŋ": 165,
170
+ "iɑŋ→": 166,
171
+ "iɑŋ↑": 167,
172
+ "iɑŋ↓↑": 168,
173
+ "iɑŋ↓": 169,
174
+ "ua": 170,
175
+ "ua→": 171,
176
+ "ua↑": 172,
177
+ "ua↓↑": 173,
178
+ "ua↓": 174,
179
+ "uo": 175,
180
+ "uo→": 176,
181
+ "uo↑": 177,
182
+ "uo↓↑": 178,
183
+ "uo↓": 179,
184
+ "uaɪ": 180,
185
+ "uaɪ→": 181,
186
+ "uaɪ↑": 182,
187
+ "uaɪ↓↑": 183,
188
+ "uaɪ↓": 184,
189
+ "ueɪ": 185,
190
+ "ueɪ→": 186,
191
+ "ueɪ↑": 187,
192
+ "ueɪ↓↑": 188,
193
+ "ueɪ↓": 189,
194
+ "uan": 190,
195
+ "uan→": 191,
196
+ "uan↑": 192,
197
+ "uan↓↑": 193,
198
+ "uan↓": 194,
199
+ "uən": 195,
200
+ "uən→": 196,
201
+ "uən↑": 197,
202
+ "uən↓↑": 198,
203
+ "uən↓": 199,
204
+ "uɑŋ": 200,
205
+ "uɑŋ→": 201,
206
+ "uɑŋ↑": 202,
207
+ "uɑŋ↓↑": 203,
208
+ "uɑŋ↓": 204,
209
+ "ɥɛ": 205,
210
+ "ɥɛ→": 206,
211
+ "ɥɛ↑": 207,
212
+ "ɥɛ↓↑": 208,
213
+ "ɥɛ↓": 209,
214
+ "a": 210,
215
+ "a→": 211,
216
+ "a↑": 212,
217
+ "a↓↑": 213,
218
+ "a↓": 214,
219
+ "o": 215,
220
+ "o→": 216,
221
+ "o↑": 217,
222
+ "o↓↑": 218,
223
+ "o↓": 219,
224
+ "ə→": 220,
225
+ "ə↑": 221,
226
+ "ə↓↑": 222,
227
+ "ə↓": 223,
228
+ "ɛ→": 224,
229
+ "ɛ↑": 225,
230
+ "ɛ↓↑": 226,
231
+ "ɛ↓": 227,
232
+ "aɪ→": 228,
233
+ "aɪ↑": 229,
234
+ "aɪ↓↑": 230,
235
+ "aɪ↓": 231,
236
+ "eɪ→": 232,
237
+ "eɪ↑": 233,
238
+ "eɪ↓↑": 234,
239
+ "eɪ↓": 235,
240
+ "ɑʊ": 236,
241
+ "ɑʊ→": 237,
242
+ "ɑʊ↑": 238,
243
+ "ɑʊ↓↑": 239,
244
+ "ɑʊ↓": 240,
245
+ "oʊ→": 241,
246
+ "oʊ↑": 242,
247
+ "oʊ↓↑": 243,
248
+ "oʊ↓": 244,
249
+ "an": 245,
250
+ "an→": 246,
251
+ "an↑": 247,
252
+ "an↓↑": 248,
253
+ "an↓": 249,
254
+ "ən": 250,
255
+ "ən→": 251,
256
+ "ən↑": 252,
257
+ "ən↓↑": 253,
258
+ "ən↓": 254,
259
+ "ɑŋ": 255,
260
+ "ɑŋ→": 256,
261
+ "ɑŋ↑": 257,
262
+ "ɑŋ↓↑": 258,
263
+ "ɑŋ↓": 259,
264
+ "əŋ": 260,
265
+ "əŋ→": 261,
266
+ "əŋ↑": 262,
267
+ "əŋ↓↑": 263,
268
+ "əŋ↓": 264,
269
+ "əɹ": 265,
270
+ "əɹ→": 266,
271
+ "əɹ↑": 267,
272
+ "əɹ↓↑": 268,
273
+ "əɹ↓": 269,
274
+ "i→": 270,
275
+ "i↑": 271,
276
+ "i↓↑": 272,
277
+ "i↓": 273,
278
+ "u→": 274,
279
+ "u↑": 275,
280
+ "u↓↑": 276,
281
+ "u↓": 277,
282
+ "ɥ": 278,
283
+ "ɥ→": 279,
284
+ "ɥ↑": 280,
285
+ "ɥ↓↑": 281,
286
+ "ɥ↓": 282,
287
+ "ts`⁼ɹ": 283,
288
+ "ts`⁼ɹ→": 284,
289
+ "ts`⁼ɹ↑": 285,
290
+ "ts`⁼ɹ↓↑": 286,
291
+ "ts`⁼ɹ↓": 287,
292
+ "ts`ʰɹ": 288,
293
+ "ts`ʰɹ→": 289,
294
+ "ts`ʰɹ↑": 290,
295
+ "ts`ʰɹ↓↑": 291,
296
+ "ts`ʰɹ↓": 292,
297
+ "s`ɹ": 293,
298
+ "s`ɹ→": 294,
299
+ "s`ɹ↑": 295,
300
+ "s`ɹ↓↑": 296,
301
+ "s`ɹ↓": 297,
302
+ "ɹ`ɹ": 298,
303
+ "ɹ`ɹ→": 299,
304
+ "ɹ`ɹ↑": 300,
305
+ "ɹ`ɹ↓↑": 301,
306
+ "ɹ`ɹ↓": 302,
307
+ "ts⁼ɹ": 303,
308
+ "ts⁼ɹ→": 304,
309
+ "ts⁼ɹ↑": 305,
310
+ "ts⁼ɹ↓↑": 306,
311
+ "ts⁼ɹ↓": 307,
312
+ "tsʰɹ": 308,
313
+ "tsʰɹ→": 309,
314
+ "tsʰɹ↑": 310,
315
+ "tsʰɹ↓↑": 311,
316
+ "tsʰɹ↓": 312,
317
+ "sɹ": 313,
318
+ "sɹ→": 314,
319
+ "sɹ↑": 315,
320
+ "sɹ↓↑": 316,
321
+ "sɹ↓": 317,
322
+
323
+ "ɯ": 318,
324
+ "e": 319,
325
+ "aː": 320,
326
+ "ɯː": 321,
327
+ "eː": 322,
328
+ "ç": 323,
329
+ "ɸ": 324,
330
+ "ɰᵝ": 325,
331
+ "ɴ": 326,
332
+ "g": 327,
333
+ "dʑ": 328,
334
+ "q": 329,
335
+ "ː": 330,
336
+ "bj": 331,
337
+ "tɕ": 332,
338
+ "dej": 333,
339
+ "tej": 334,
340
+ "gj": 335,
341
+ "gɯ": 336,
342
+ "çj": 337,
343
+ "kj": 338,
344
+ "kɯ": 339,
345
+ "mj": 340,
346
+ "nj": 341,
347
+ "pj": 342,
348
+ "ɾj": 343,
349
+ "ɕ": 344,
350
+ "tsɯ": 345,
351
+
352
+ "ɐ": 346,
353
+ "ɑ": 347,
354
+ "ɒ": 348,
355
+ "ɜ": 349,
356
+ "ɫ": 350,
357
+ "ʑ": 351,
358
+ "ʲ": 352,
359
+
360
+ "y": 353,
361
+ "ø": 354,
362
+ "œ": 355,
363
+ "ʁ": 356,
364
+ "̃": 357,
365
+ "ɲ": 358,
366
+
367
+ ":": 359,
368
+ ";": 360,
369
+ "'": 361,
370
+ "…": 362
371
+ }
372
+ }
diffrhythm/g2p/utils/front_utils.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+
8
+
9
+ def generate_poly_lexicon(file_path: str):
10
+ """Generate poly char lexicon for Mandarin Chinese."""
11
+ poly_dict = {}
12
+
13
+ with open(file_path, "r", encoding="utf-8") as readf:
14
+ txt_list = readf.readlines()
15
+ for txt in txt_list:
16
+ word = txt.strip("\n")
17
+ if word not in poly_dict:
18
+ poly_dict[word] = 1
19
+ readf.close()
20
+ return poly_dict
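`generate_poly_lexicon` simply turns a one-character-per-line file into a membership dict; a quick sketch, using the same path the Mandarin front end passes in above:

```python
from diffrhythm.g2p.utils.front_utils import generate_poly_lexicon

poly_dict = generate_poly_lexicon(
    "./diffrhythm/g2p/sources/g2p_chinese_model/polychar.txt"
)
print(len(poly_dict))  # number of polyphonic characters in the shipped list
```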
diffrhythm/g2p/utils/g2p.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from phonemizer.backend import EspeakBackend
7
+ from phonemizer.separator import Separator
8
+ from phonemizer.utils import list2str, str2list
9
+ from typing import List, Union
10
+ import os
11
+ import json
12
+ import sys
13
+
14
+ # separator=Separator(phone=' ', word=' _ ', syllable='|'),
15
+ separator = Separator(word=" _ ", syllable="|", phone=" ")
16
+
17
+ phonemizer_zh = EspeakBackend(
18
+ "cmn", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
19
+ )
20
+ # phonemizer_zh.separator = separator
21
+
22
+ phonemizer_en = EspeakBackend(
23
+ "en-us",
24
+ preserve_punctuation=False,
25
+ with_stress=False,
26
+ language_switch="remove-flags",
27
+ )
28
+ # phonemizer_en.separator = separator
29
+
30
+ phonemizer_ja = EspeakBackend(
31
+ "ja", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
32
+ )
33
+ # phonemizer_ja.separator = separator
34
+
35
+ phonemizer_ko = EspeakBackend(
36
+ "ko", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
37
+ )
38
+ # phonemizer_ko.separator = separator
39
+
40
+ phonemizer_fr = EspeakBackend(
41
+ "fr-fr",
42
+ preserve_punctuation=False,
43
+ with_stress=False,
44
+ language_switch="remove-flags",
45
+ )
46
+ # phonemizer_fr.separator = separator
47
+
48
+ phonemizer_de = EspeakBackend(
49
+ "de", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
50
+ )
51
+ # phonemizer_de.separator = separator
52
+
53
+
54
+ lang2backend = {
55
+ "zh": phonemizer_zh,
56
+ "ja": phonemizer_ja,
57
+ "en": phonemizer_en,
58
+ "fr": phonemizer_fr,
59
+ "ko": phonemizer_ko,
60
+ "de": phonemizer_de,
61
+ }
62
+
63
+ with open("./diffrhythm/g2p/utils/mls_en.json", "r") as f:
64
+ json_data = f.read()
65
+ token = json.loads(json_data)
66
+
67
+
68
+ def phonemizer_g2p(text, language):
69
+ langbackend = lang2backend[language]
70
+ phonemes = _phonemize(
71
+ langbackend,
72
+ text,
73
+ separator,
74
+ strip=True,
75
+ njobs=1,
76
+ prepend_text=False,
77
+ preserve_empty_lines=False,
78
+ )
79
+ token_id = []
80
+ if isinstance(phonemes, list):
81
+ for phone in phonemes:
82
+ phonemes_split = phone.split(" ")
83
+ token_id.append([token[p] for p in phonemes_split if p in token])
84
+ else:
85
+ phonemes_split = phonemes.split(" ")
86
+ token_id = [token[p] for p in phonemes_split if p in token]
87
+ return phonemes, token_id
88
+
89
+
90
+ def _phonemize( # pylint: disable=too-many-arguments
91
+ backend,
92
+ text: Union[str, List[str]],
93
+ separator: Separator,
94
+ strip: bool,
95
+ njobs: int,
96
+ prepend_text: bool,
97
+ preserve_empty_lines: bool,
98
+ ):
99
+ """Auxiliary function to phonemize()
100
+
101
+ Does the phonemization and returns the phonemized text. Raises a
102
+ RuntimeError on error.
103
+
104
+ """
105
+ # remember the text type for output (either list or string)
106
+ text_type = type(text)
107
+
108
+ # force the text as a list
109
+ text = [line.strip(os.linesep) for line in str2list(text)]
110
+
111
+ # if preserving empty lines, note the index of each empty line
112
+ if preserve_empty_lines:
113
+ empty_lines = [n for n, line in enumerate(text) if not line.strip()]
114
+
115
+ # ignore empty lines
116
+ text = [line for line in text if line.strip()]
117
+
118
+ if text:
119
+ # phonemize the text
120
+ phonemized = backend.phonemize(
121
+ text, separator=separator, strip=strip, njobs=njobs
122
+ )
123
+ else:
124
+ phonemized = []
125
+
126
+ # if preserving empty lines, reinsert them into text and phonemized lists
127
+ if preserve_empty_lines:
128
+ for i in empty_lines: # noqa
129
+ if prepend_text:
130
+ text.insert(i, "")
131
+ phonemized.insert(i, "")
132
+
133
+ # at that point, the phonemized text is a list of str. Format it as
134
+ # expected by the parameters
135
+ if prepend_text:
136
+ return list(zip(text, phonemized))
137
+ if text_type == str:
138
+ return list2str(phonemized)
139
+ return phonemized
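A hedged sketch of calling `phonemizer_g2p` directly. Note that the module builds six espeak backends and reads `./diffrhythm/g2p/utils/mls_en.json` at import time, so espeak-ng and that file must both be available.

```python
from diffrhythm.g2p.utils.g2p import phonemizer_g2p

# Returns the space-separated phone string plus the token ids looked up in
# mls_en.json; phones that are not in the vocabulary are silently dropped.
phones, token_ids = phonemizer_g2p("hello world", "en")
print(phones)
print(token_ids)
```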
diffrhythm/g2p/utils/log.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import functools
8
+ import logging
9
+
10
+ __all__ = [
11
+ "logger",
12
+ ]
13
+
14
+
15
+ class Logger(object):
16
+ def __init__(self, name: str = None):
17
+ name = "PaddleSpeech" if not name else name
18
+ self.logger = logging.getLogger(name)
19
+
20
+ log_config = {
21
+ "DEBUG": 10,
22
+ "INFO": 20,
23
+ "TRAIN": 21,
24
+ "EVAL": 22,
25
+ "WARNING": 30,
26
+ "ERROR": 40,
27
+ "CRITICAL": 50,
28
+ "EXCEPTION": 100,
29
+ }
30
+ for key, level in log_config.items():
31
+ logging.addLevelName(level, key)
32
+ if key == "EXCEPTION":
33
+ self.__dict__[key.lower()] = self.logger.exception
34
+ else:
35
+ self.__dict__[key.lower()] = functools.partial(self.__call__, level)
36
+
37
+ self.format = logging.Formatter(
38
+ fmt="[%(asctime)-15s] [%(levelname)8s] - %(message)s"
39
+ )
40
+
41
+ self.handler = logging.StreamHandler()
42
+ self.handler.setFormatter(self.format)
43
+
44
+ self.logger.addHandler(self.handler)
45
+ self.logger.setLevel(logging.INFO)
46
+ self.logger.propagate = False
47
+
48
+ def __call__(self, log_level: str, msg: str):
49
+ self.logger.log(log_level, msg)
50
+
51
+
52
+ logger = Logger()
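The custom levels registered in `Logger.__init__` (TRAIN, EVAL, ...) become lowercase methods on the `logger` singleton, so usage looks like standard logging:

```python
from diffrhythm.g2p.utils.log import logger

logger.info("loading g2p resources")
logger.train("step logged at the custom TRAIN level (21)")
logger.warning("falling back to lazy_pinyin for an out-of-lexicon word")
```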
diffrhythm/g2p/utils/mls_en.json ADDED
@@ -0,0 +1,335 @@
1
+ {
2
+ "[UNK]": 0,
3
+ "_": 1,
4
+ "b": 2,
5
+ "d": 3,
6
+ "f": 4,
7
+ "h": 5,
8
+ "i": 6,
9
+ "j": 7,
10
+ "k": 8,
11
+ "l": 9,
12
+ "m": 10,
13
+ "n": 11,
14
+ "p": 12,
15
+ "r": 13,
16
+ "s": 14,
17
+ "t": 15,
18
+ "v": 16,
19
+ "w": 17,
20
+ "x": 18,
21
+ "z": 19,
22
+ "æ": 20,
23
+ "ç": 21,
24
+ "ð": 22,
25
+ "ŋ": 23,
26
+ "ɐ": 24,
27
+ "ɔ": 25,
28
+ "ə": 26,
29
+ "ɚ": 27,
30
+ "ɛ": 28,
31
+ "ɡ": 29,
32
+ "ɪ": 30,
33
+ "ɬ": 31,
34
+ "ɹ": 32,
35
+ "ɾ": 33,
36
+ "ʃ": 34,
37
+ "ʊ": 35,
38
+ "ʌ": 36,
39
+ "ʒ": 37,
40
+ "ʔ": 38,
41
+ "θ": 39,
42
+ "ᵻ": 40,
43
+ "aɪ": 41,
44
+ "aʊ": 42,
45
+ "dʒ": 43,
46
+ "eɪ": 44,
47
+ "iə": 45,
48
+ "iː": 46,
49
+ "n̩": 47,
50
+ "oʊ": 48,
51
+ "oː": 49,
52
+ "tʃ": 50,
53
+ "uː": 51,
54
+ "ææ": 52,
55
+ "ɐɐ": 53,
56
+ "ɑː": 54,
57
+ "ɑ̃": 55,
58
+ "ɔɪ": 56,
59
+ "ɔː": 57,
60
+ "ɔ̃": 58,
61
+ "əl": 59,
62
+ "ɛɹ": 60,
63
+ "ɜː": 61,
64
+ "ɡʲ": 62,
65
+ "ɪɹ": 63,
66
+ "ʊɹ": 64,
67
+ "aɪə": 65,
68
+ "aɪɚ": 66,
69
+ "iːː": 67,
70
+ "oːɹ": 68,
71
+ "ɑːɹ": 69,
72
+ "ɔːɹ": 70,
73
+
74
+ "1": 71,
75
+ "a": 72,
76
+ "e": 73,
77
+ "o": 74,
78
+ "q": 75,
79
+ "u": 76,
80
+ "y": 77,
81
+ "ɑ": 78,
82
+ "ɒ": 79,
83
+ "ɕ": 80,
84
+ "ɣ": 81,
85
+ "ɫ": 82,
86
+ "ɯ": 83,
87
+ "ʐ": 84,
88
+ "ʲ": 85,
89
+ "a1": 86,
90
+ "a2": 87,
91
+ "a5": 88,
92
+ "ai": 89,
93
+ "aɜ": 90,
94
+ "aː": 91,
95
+ "ei": 92,
96
+ "eə": 93,
97
+ "i.": 94,
98
+ "i1": 95,
99
+ "i2": 96,
100
+ "i5": 97,
101
+ "io": 98,
102
+ "iɑ": 99,
103
+ "iɛ": 100,
104
+ "iɜ": 101,
105
+ "i̪": 102,
106
+ "kh": 103,
107
+ "nʲ": 104,
108
+ "o1": 105,
109
+ "o2": 106,
110
+ "o5": 107,
111
+ "ou": 108,
112
+ "oɜ": 109,
113
+ "ph": 110,
114
+ "s.": 111,
115
+ "th": 112,
116
+ "ts": 113,
117
+ "tɕ": 114,
118
+ "u1": 115,
119
+ "u2": 116,
120
+ "u5": 117,
121
+ "ua": 118,
122
+ "uo": 119,
123
+ "uə": 120,
124
+ "uɜ": 121,
125
+ "y1": 122,
126
+ "y2": 123,
127
+ "y5": 124,
128
+ "yu": 125,
129
+ "yæ": 126,
130
+ "yə": 127,
131
+ "yɛ": 128,
132
+ "yɜ": 129,
133
+ "ŋɜ": 130,
134
+ "ŋʲ": 131,
135
+ "ɑ1": 132,
136
+ "ɑ2": 133,
137
+ "ɑ5": 134,
138
+ "ɑu": 135,
139
+ "ɑɜ": 136,
140
+ "ɑʲ": 137,
141
+ "ə1": 138,
142
+ "ə2": 139,
143
+ "ə5": 140,
144
+ "ər": 141,
145
+ "əɜ": 142,
146
+ "əʊ": 143,
147
+ "ʊə": 144,
148
+ "ai1": 145,
149
+ "ai2": 146,
150
+ "ai5": 147,
151
+ "aiɜ": 148,
152
+ "ei1": 149,
153
+ "ei2": 150,
154
+ "ei5": 151,
155
+ "eiɜ": 152,
156
+ "i.1": 153,
157
+ "i.2": 154,
158
+ "i.5": 155,
159
+ "i.ɜ": 156,
160
+ "io5": 157,
161
+ "iou": 158,
162
+ "iɑ1": 159,
163
+ "iɑ2": 160,
164
+ "iɑ5": 161,
165
+ "iɑɜ": 162,
166
+ "iɛ1": 163,
167
+ "iɛ2": 164,
168
+ "iɛ5": 165,
169
+ "iɛɜ": 166,
170
+ "i̪1": 167,
171
+ "i̪2": 168,
172
+ "i̪5": 169,
173
+ "i̪ɜ": 170,
174
+ "onɡ": 171,
175
+ "ou1": 172,
176
+ "ou2": 173,
177
+ "ou5": 174,
178
+ "ouɜ": 175,
179
+ "ts.": 176,
180
+ "tsh": 177,
181
+ "tɕh": 178,
182
+ "u5ʲ": 179,
183
+ "ua1": 180,
184
+ "ua2": 181,
185
+ "ua5": 182,
186
+ "uai": 183,
187
+ "uaɜ": 184,
188
+ "uei": 185,
189
+ "uo1": 186,
190
+ "uo2": 187,
191
+ "uo5": 188,
192
+ "uoɜ": 189,
193
+ "uə1": 190,
194
+ "uə2": 191,
195
+ "uə5": 192,
196
+ "uəɜ": 193,
197
+ "yiɜ": 194,
198
+ "yu2": 195,
199
+ "yu5": 196,
200
+ "yæ2": 197,
201
+ "yæ5": 198,
202
+ "yæɜ": 199,
203
+ "yə2": 200,
204
+ "yə5": 201,
205
+ "yəɜ": 202,
206
+ "yɛ1": 203,
207
+ "yɛ2": 204,
208
+ "yɛ5": 205,
209
+ "yɛɜ": 206,
210
+ "ɑu1": 207,
211
+ "ɑu2": 208,
212
+ "ɑu5": 209,
213
+ "ɑuɜ": 210,
214
+ "ər1": 211,
215
+ "ər2": 212,
216
+ "ər5": 213,
217
+ "ərɜ": 214,
218
+ "əː1": 215,
219
+ "iou1": 216,
220
+ "iou2": 217,
221
+ "iou5": 218,
222
+ "iouɜ": 219,
223
+ "onɡ1": 220,
224
+ "onɡ2": 221,
225
+ "onɡ5": 222,
226
+ "onɡɜ": 223,
227
+ "ts.h": 224,
228
+ "uai2": 225,
229
+ "uai5": 226,
230
+ "uaiɜ": 227,
231
+ "uei1": 228,
232
+ "uei2": 229,
233
+ "uei5": 230,
234
+ "ueiɜ": 231,
235
+ "uoɜʲ": 232,
236
+ "yɛ5ʲ": 233,
237
+ "ɑu2ʲ": 234,
238
+
239
+ "2": 235,
240
+ "5": 236,
241
+ "ɜ": 237,
242
+ "ʂ": 238,
243
+ "dʑ": 239,
244
+ "iɪ": 240,
245
+ "uɪ": 241,
246
+ "xʲ": 242,
247
+ "ɑt": 243,
248
+ "ɛɜ": 244,
249
+ "ɛː": 245,
250
+ "ɪː": 246,
251
+ "phʲ": 247,
252
+ "ɑ5ʲ": 248,
253
+ "ɑuʲ": 249,
254
+ "ərə": 250,
255
+ "uozʰ": 251,
256
+ "ər1ʲ": 252,
257
+ "tɕhtɕh": 253,
258
+
259
+ "c": 254,
260
+ "ʋ": 255,
261
+ "ʍ": 256,
262
+ "ʑ": 257,
263
+ "ː": 258,
264
+ "aə": 259,
265
+ "eː": 260,
266
+ "hʲ": 261,
267
+ "iʊ": 262,
268
+ "kʲ": 263,
269
+ "lʲ": 264,
270
+ "oə": 265,
271
+ "oɪ": 266,
272
+ "oʲ": 267,
273
+ "pʲ": 268,
274
+ "sʲ": 269,
275
+ "u4": 270,
276
+ "uʲ": 271,
277
+ "yi": 272,
278
+ "yʲ": 273,
279
+ "ŋ2": 274,
280
+ "ŋ5": 275,
281
+ "ŋ̩": 276,
282
+ "ɑɪ": 277,
283
+ "ɑʊ": 278,
284
+ "ɕʲ": 279,
285
+ "ət": 280,
286
+ "əə": 281,
287
+ "əɪ": 282,
288
+ "əʲ": 283,
289
+ "ɛ1": 284,
290
+ "ɛ5": 285,
291
+ "aiə": 286,
292
+ "aiɪ": 287,
293
+ "azʰ": 288,
294
+ "eiə": 289,
295
+ "eiɪ": 290,
296
+ "eiʊ": 291,
297
+ "i.ə": 292,
298
+ "i.ɪ": 293,
299
+ "i.ʊ": 294,
300
+ "ioɜ": 295,
301
+ "izʰ": 296,
302
+ "iɑə": 297,
303
+ "iɑʊ": 298,
304
+ "iɑʲ": 299,
305
+ "iɛə": 300,
306
+ "iɛɪ": 301,
307
+ "iɛʊ": 302,
308
+ "i̪ə": 303,
309
+ "i̪ʊ": 304,
310
+ "khʲ": 305,
311
+ "ouʲ": 306,
312
+ "tsʲ": 307,
313
+ "u2ʲ": 308,
314
+ "uoɪ": 309,
315
+ "uzʰ": 310,
316
+ "uɜʲ": 311,
317
+ "yæɪ": 312,
318
+ "yəʊ": 313,
319
+ "ərt": 314,
320
+ "ərɪ": 315,
321
+ "ərʲ": 316,
322
+ "əːt": 317,
323
+ "iouə": 318,
324
+ "iouʊ": 319,
325
+ "iouʲ": 320,
326
+ "iɛzʰ": 321,
327
+ "onɡə": 322,
328
+ "onɡɪ": 323,
329
+ "onɡʊ": 324,
330
+ "ouzʰ": 325,
331
+ "uai1": 326,
332
+ "ueiɪ": 327,
333
+ "ɑuzʰ": 328,
334
+ "iouzʰ": 329
335
+ }
diffrhythm/infer/infer.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torchaudio
3
+ from einops import rearrange
4
+ import argparse
5
+ import json
6
+ import os
7
+ from tqdm import tqdm
8
+ import random
9
+ import numpy as np
10
+ import time
11
+
12
+ from diffrhythm.infer.infer_utils import (
13
+ get_reference_latent,
14
+ get_lrc_token,
15
+ get_style_prompt,
16
+ prepare_model,
17
+ get_negative_style_prompt
18
+ )
19
+
20
+ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
21
+ downsampling_ratio = 2048
22
+ io_channels = 2
23
+ if not chunked:
24
+ # default behavior. Decode the entire latent in parallel
25
+ return vae_model.decode_export(latents)
26
+ else:
27
+ # chunked decoding
28
+ hop_size = chunk_size - overlap
29
+ total_size = latents.shape[2]
30
+ batch_size = latents.shape[0]
31
+ chunks = []
32
+ i = 0
33
+ for i in range(0, total_size - chunk_size + 1, hop_size):
34
+ chunk = latents[:,:,i:i+chunk_size]
35
+ chunks.append(chunk)
36
+ if i+chunk_size != total_size:
37
+ # Final chunk
38
+ chunk = latents[:,:,-chunk_size:]
39
+ chunks.append(chunk)
40
+ chunks = torch.stack(chunks)
41
+ num_chunks = chunks.shape[0]
42
+ # samples_per_latent is just the downsampling ratio
43
+ samples_per_latent = downsampling_ratio
44
+ # Create an empty waveform; we will populate it with chunks as we decode them
45
+ y_size = total_size * samples_per_latent
46
+ y_final = torch.zeros((batch_size,io_channels,y_size)).to(latents.device)
47
+ for i in range(num_chunks):
48
+ x_chunk = chunks[i,:]
49
+ # decode the chunk
50
+ y_chunk = vae_model.decode_export(x_chunk)
51
+ # figure out where to put the audio along the time domain
52
+ if i == num_chunks-1:
53
+ # final chunk always goes at the end
54
+ t_end = y_size
55
+ t_start = t_end - y_chunk.shape[2]
56
+ else:
57
+ t_start = i * hop_size * samples_per_latent
58
+ t_end = t_start + chunk_size * samples_per_latent
59
+ # remove the edges of the overlaps
60
+ ol = (overlap//2) * samples_per_latent
61
+ chunk_start = 0
62
+ chunk_end = y_chunk.shape[2]
63
+ if i > 0:
64
+ # no overlap for the start of the first chunk
65
+ t_start += ol
66
+ chunk_start += ol
67
+ if i < num_chunks-1:
68
+ # no overlap for the end of the last chunk
69
+ t_end -= ol
70
+ chunk_end -= ol
71
+ # paste the chunked audio into our y_final output audio
72
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
73
+ return y_final
74
+
75
+ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, start_time):
76
+ # import pdb; pdb.set_trace()
77
+ with torch.inference_mode():
78
+ generated, _ = cfm_model.sample(
79
+ cond=cond,
80
+ text=text,
81
+ duration=duration,
82
+ style_prompt=style_prompt,
83
+ negative_style_prompt=negative_style_prompt,
84
+ steps=32,
85
+ cfg_strength=4.0,
86
+ start_time=start_time
87
+ )
88
+
89
+ generated = generated.to(torch.float32)
90
+ latent = generated.transpose(1, 2) # [b d t]
91
+
92
+ output = decode_audio(latent, vae_model)
93
+
94
+ # Rearrange audio batch to a single sequence
95
+ output = rearrange(output, "b d n -> d (b n)")
96
+ # Peak normalize, clip, convert to int16, and save to file
97
+ output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
98
+
99
+ return output
100
+
101
+ if __name__ == "__main__":
102
+ parser = argparse.ArgumentParser()
103
+ parser.add_argument('--lrc-path', type=str, default="/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/eg.lrc") # lyrics of target song
104
+ parser.add_argument('--ref-audio-path', type=str, default="/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/eg.mp3") # reference audio as style prompt for target song
105
+ parser.add_argument('--audio-length', type=int, default=95) # length of target song
106
+ parser.add_argument('--output-dir', type=str, default="/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/output")
107
+ args = parser.parse_args()
108
+
109
+ device = 'cuda'
110
+
111
+ audio_length = args.audio_length
112
+ if audio_length == 95:
113
+ max_frames = 2048
114
+ elif audio_length == 285:
115
+ max_frames = 6144
116
+
117
+ cfm, tokenizer, muq, vae = prepare_model(device)
118
+
119
+ with open(args.lrc_path, 'r') as f:
120
+ lrc = f.read()
121
+ lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
122
+
123
+ style_prompt = get_style_prompt(muq, args.ref_audio_path)
124
+
125
+ negative_style_prompt = get_negative_style_prompt(device)
126
+
127
+ latent_prompt = get_reference_latent(device, max_frames)
128
+
129
+ s_t = time.time()
130
+ generated_song = inference(cfm_model=cfm,
131
+ vae_model=vae,
132
+ cond=latent_prompt,
133
+ text=lrc_prompt,
134
+ duration=max_frames,
135
+ style_prompt=style_prompt,
136
+ negative_style_prompt=negative_style_prompt,
137
+ start_time=start_time
138
+ )
139
+ e_t = time.time() - s_t
140
+ print(f"inference cost {e_t} seconds")
141
+
142
+ output_dir = args.output_dir
143
+ os.makedirs(output_dir, exist_ok=True)
144
+
145
+ output_path = os.path.join(output_dir, "output.wav")
146
+ torchaudio.save(output_path, generated_song, sample_rate=44100)
147
+
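A quick sanity check of the frame/duration arithmetic behind the `--audio-length` choices above, using the 44100 Hz sample rate and the latent downsampling ratio of 2048 that appear in `decode_audio` and `get_lrc_token`:

```python
sampling_rate = 44100
downsample_rate = 2048
frames_per_second = sampling_rate / downsample_rate  # ~21.53 latent frames per second

print(2048 / frames_per_second)  # ~95.1 s  -> --audio-length 95  uses max_frames=2048
print(6144 / frames_per_second)  # ~285.3 s -> --audio-length 285 uses max_frames=6144
```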
diffrhythm/infer/infer_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import torch
2
+ import librosa
3
+ import random
4
+ import json
5
+ from muq import MuQMuLan
6
+ from mutagen.mp3 import MP3
7
+ import os
8
+ import numpy as np
9
+
10
+ from diffrhythm.model import DiT, CFM
11
+
12
+
13
+ def prepare_model(device):
14
+ # prepare cfm model
15
+ dit_ckpt_path = "/home/node59_tmpdata3/hkchen/music_opensource/dit_model_dpo_normal.pt"
16
+ dit_config_path = "/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/config/diffrhythm-1b.json"
17
+ with open(dit_config_path) as f:
18
+ model_config = json.load(f)
19
+ dit_model_cls = DiT
20
+ cfm = CFM(
21
+ transformer=dit_model_cls(**model_config["model"], use_style_prompt=True),
22
+ num_channels=model_config["model"]['mel_dim'],
23
+ use_style_prompt=True
24
+ )
25
+ cfm = cfm.to(device)
26
+ cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
27
+
28
+ # prepare tokenizer
29
+ tokenizer = CNENTokenizer()
30
+
31
+ # prepare muq
32
+ muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
33
+ muq = muq.to(device).eval()
34
+
35
+ # prepare vae
36
+ vae = torch.jit.load("/home/node59_tmpdata3/hkchen/F5-TTS-V0/infer/vae_infer.pt").to(device)
37
+
38
+ return cfm, tokenizer, muq, vae
39
+
40
+
41
+ # for song edit, will be added in the future
42
+ def get_reference_latent(device, max_frames):
43
+ return torch.zeros(1, max_frames, 64).to(device)
44
+
45
+ def get_negative_style_prompt(device):
46
+ file_path = "/home/node59_tmpdata3/hkchen/DiffRhythm/diffrhythm/diffrhythm/infer/example/vocal.npy"
47
+ vocal_style = np.load(file_path)
48
+
49
+ vocal_style = torch.from_numpy(vocal_style).to(device) # [1, 512]
50
+ vocal_style = vocal_style.half()
51
+
52
+ return vocal_style
53
+
54
+ def get_style_prompt(model, wav_path):
55
+ mulan = model
56
+
57
+ ext = os.path.splitext(wav_path)[-1].lower()
58
+ if ext == '.mp3':
59
+ meta = MP3(wav_path)
60
+ audio_len = meta.info.length
61
+ src_sr = meta.info.sample_rate
62
+ elif ext == '.wav':
63
+ audio, sr = librosa.load(wav_path, sr=None)
64
+ audio_len = librosa.get_duration(y=audio, sr=sr)
65
+ src_sr = sr
66
+ else:
67
+ raise ValueError("Unsupported file format: {}".format(ext))
68
+
69
+ assert(audio_len >= 10)
70
+
71
+ mid_time = audio_len // 2
72
+ start_time = mid_time - 5
73
+ wav, sr = librosa.load(wav_path, sr=None, offset=start_time, duration=10)
74
+
75
+ resampled_wav = librosa.resample(wav, orig_sr=src_sr, target_sr=24000)
76
+ resampled_wav = torch.tensor(resampled_wav).unsqueeze(0).to(model.device)
77
+
78
+ with torch.no_grad():
79
+ audio_emb = mulan(wavs = resampled_wav) # [1, 512]
80
+
81
+ audio_emb = audio_emb
82
+ audio_emb = audio_emb.half()
83
+
84
+ return audio_emb
85
+
86
+ def parse_lyrics(lyrics: str):
87
+ lyrics_with_time = []
88
+ lyrics = lyrics.strip()
89
+ for line in lyrics.split('\n'):
90
+ try:
91
+ time, lyric = line[1:9], line[10:]
92
+ lyric = lyric.strip()
93
+ mins, secs = time.split(':')
94
+ secs = int(mins) * 60 + float(secs)
95
+ lyrics_with_time.append((secs, lyric))
96
+ except:
97
+ continue
98
+ return lyrics_with_time
99
+
100
+ class CNENTokenizer():
101
+ def __init__(self):
102
+ with open('./diffrhythm/g2p/g2p/vocab.json', 'r') as file:
103
+ self.phone2id:dict = json.load(file)['vocab']
104
+ self.id2phone = {v:k for (k, v) in self.phone2id.items()}
105
+ # from f5_tts.g2p.g2p_generation import chn_eng_g2p
106
+ from diffrhythm.g2p.g2p_generation import chn_eng_g2p
107
+ self.tokenizer = chn_eng_g2p
108
+ def encode(self, text):
109
+ phone, token = self.tokenizer(text)
110
+ token = [x+1 for x in token]
111
+ return token
112
+ def decode(self, token):
113
+ return "|".join([self.id2phone[x-1] for x in token])
114
+
115
+ def get_lrc_token(text, tokenizer, device):
116
+
117
+ max_frames = 2048
118
+ lyrics_shift = 0
119
+ sampling_rate = 44100
120
+ downsample_rate = 2048
121
+ max_secs = max_frames / (sampling_rate / downsample_rate)
122
+
123
+ pad_token_id = 0
124
+ comma_token_id = 1
125
+ period_token_id = 2
126
+
127
+ lrc_with_time = parse_lyrics(text)
128
+
129
+ modified_lrc_with_time = []
130
+ for i in range(len(lrc_with_time)):
131
+ time, line = lrc_with_time[i]
132
+ line_token = tokenizer.encode(line)
133
+ modified_lrc_with_time.append((time, line_token))
134
+ lrc_with_time = modified_lrc_with_time
135
+
136
+ lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
137
+ lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
138
+
139
+ normalized_start_time = 0.
140
+
141
+ lrc = torch.zeros((max_frames,), dtype=torch.long)
142
+
143
+ tokens_count = 0
144
+ last_end_pos = 0
145
+ for time_start, line in lrc_with_time:
146
+ tokens = [token if token != period_token_id else comma_token_id for token in line] + [period_token_id]
147
+ tokens = torch.tensor(tokens, dtype=torch.long)
148
+ num_tokens = tokens.shape[0]
149
+
150
+ gt_frame_start = int(time_start * sampling_rate / downsample_rate)
151
+
152
+ frame_shift = random.randint(int(lyrics_shift), int(lyrics_shift))
153
+
154
+ frame_start = max(gt_frame_start - frame_shift, last_end_pos)
155
+ frame_len = min(num_tokens, max_frames - frame_start)
156
+
157
+ #print(gt_frame_start, frame_shift, frame_start, frame_len, tokens_count, last_end_pos, full_pos_emb.shape)
158
+
159
+ lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]
160
+
161
+ tokens_count += num_tokens
162
+ last_end_pos = frame_start + frame_len
163
+
164
+ lrc_emb = lrc.unsqueeze(0).to(device)
165
+
166
+ normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
167
+ normalized_start_time = normalized_start_time.half()
168
+
169
+ return lrc_emb, normalized_start_time
170
+
171
+ def load_checkpoint(model, ckpt_path, device, use_ema=True):
172
+ if device == "cuda":
173
+ model = model.half()
174
+
175
+ ckpt_type = ckpt_path.split(".")[-1]
176
+ if ckpt_type == "safetensors":
177
+ from safetensors.torch import load_file
178
+
179
+ checkpoint = load_file(ckpt_path)
180
+ else:
181
+ checkpoint = torch.load(ckpt_path, weights_only=True)
182
+
183
+ if use_ema:
184
+ if ckpt_type == "safetensors":
185
+ checkpoint = {"ema_model_state_dict": checkpoint}
186
+ checkpoint["model_state_dict"] = {
187
+ k.replace("ema_model.", ""): v
188
+ for k, v in checkpoint["ema_model_state_dict"].items()
189
+ if k not in ["initted", "step"]
190
+ }
191
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
192
+ else:
193
+ if ckpt_type == "safetensors":
194
+ checkpoint = {"model_state_dict": checkpoint}
195
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
196
+
197
+ return model.to(device)
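As a small, self-contained check of the LRC format expected by `parse_lyrics` above (a `[mm:ss.xx]` timestamp occupying characters 1-8 of each line; lines without one are skipped):

```python
from diffrhythm.infer.infer_utils import parse_lyrics  # pulls in torch/muq at import time

lrc = "[00:10.00]First line\n[00:17.50]Second line\nno timestamp here"
print(parse_lyrics(lrc))
# -> [(10.0, 'First line'), (17.5, 'Second line')]
```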