mlinmg committed · verified
Commit aebdb64 · 1 Parent(s): 69c1c0a

Update tokenizer.py

Files changed (1):
  1. tokenizer.py +750 -55
tokenizer.py CHANGED
@@ -1,25 +1,675 @@
1
- from typing import List, Optional, Union, Dict, Tuple, Any
2
- import os
3
  from functools import cached_property
4
 
5
- from transformers import PreTrainedTokenizerFast
6
- from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
7
- from tokenizers import Tokenizer, processors
8
- from tokenizers.pre_tokenizers import WhitespaceSplit
9
- from tokenizers.processors import TemplateProcessing
10
  import torch
11
  from hangul_romanize import Transliter
12
  from hangul_romanize.rule import academic
13
  import cutlet
14
 
15
- from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
16
- chinese_transliterate, korean_transliterate,
17
- japanese_cleaners)
18
 
19
  class XTTSTokenizerFast(PreTrainedTokenizerFast):
20
  """
21
  Fast Tokenizer implementation for XTTS model using HuggingFace's PreTrainedTokenizerFast
22
  """
 
23
  def __init__(
24
  self,
25
  vocab_file: str = None,
@@ -28,6 +678,7 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
28
  pad_token: str = "[PAD]",
29
  bos_token: str = "[START]",
30
  eos_token: str = "[STOP]",
 
31
  clean_up_tokenization_spaces: bool = True,
32
  **kwargs
33
  ):
@@ -37,11 +688,6 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
37
  if tokenizer_object is not None:
38
  # Configure the tokenizer
39
  tokenizer_object.pre_tokenizer = WhitespaceSplit()
40
- tokenizer_object.enable_padding(
41
- direction='right',
42
- pad_id=tokenizer_object.token_to_id(pad_token) or 0,
43
- pad_token=pad_token
44
- )
45
  tokenizer_object.post_processor = TemplateProcessing(
46
  single=f"{bos_token} $A {eos_token}",
47
  special_tokens=[
@@ -72,41 +718,89 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
72
  self._katsu = None
73
  self._korean_transliter = Transliter(academic)
74

75
  @cached_property
76
  def katsu(self):
77
  if self._katsu is None:
78
  self._katsu = cutlet.Cutlet()
79
  return self._katsu
80
 
81
- def check_input_length(self, text: str, lang: str):
82
- """Check if input text length is within limits for language"""
83
- lang = lang.split("-")[0] # remove region
84
- limit = self.char_limits.get(lang, 250)
85
- if len(text) > limit:
86
- print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")
87
-
88
  def preprocess_text(self, text: str, lang: str) -> str:
89
  """Apply text preprocessing for language"""
90
- if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
91
- "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
92
- text = multilingual_cleaners(text, lang)
93
- if lang == "zh":
 
94
  text = chinese_transliterate(text)
95
- if lang == "ko":
96
- text = korean_transliterate(text)
97
- elif lang == "ja":
98
  text = japanese_cleaners(text, self.katsu)
99
  else:
100
  text = basic_cleaners(text)
101
  return text
102

103
  def _batch_encode_plus(
104
  self,
105
  batch_text_or_text_pairs,
106
  add_special_tokens: bool = True,
107
- padding_strategy = PaddingStrategy.DO_NOT_PAD,
108
- truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE,
109
- max_length: Optional[int] = 402,
110
  stride: int = 0,
111
  is_split_into_words: bool = False,
112
  pad_to_multiple_of: Optional[int] = None,
@@ -125,18 +819,26 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
125
  """
126
  lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
127
  if isinstance(lang, str):
128
- lang = [lang] * len(batch_text_or_text_pairs)
129
 
130
  # Preprocess each text in the batch with its corresponding language
131
  processed_texts = []
132
  for text, text_lang in zip(batch_text_or_text_pairs, lang):
133
  if isinstance(text, str):
134
  # Check length and preprocess
135
- self.check_input_length(text, text_lang)
136
  processed_text = self.preprocess_text(text, text_lang)
137
 
138
  # Format text with language tag and spaces
139
- lang_code = "zh-cn" if text_lang == "zh" else text_lang
 
140
  processed_text = f"[{lang_code}]{processed_text}"
141
  processed_text = processed_text.replace(" ", "[SPACE]")
142
 
@@ -165,47 +867,40 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
165
  **kwargs
166
  )
167
 
 
168
  def __call__(
169
  self,
170
  text: Union[str, List[str]],
171
  lang: Union[str, List[str]] = "en",
172
  add_special_tokens: bool = True,
173
- padding: Union[bool, str, PaddingStrategy] = True, # Changed default to True
174
- truncation: Union[bool, str, TruncationStrategy] = True, # Changed default to True
175
- max_length: Optional[int] = 402,
176
  stride: int = 0,
177
  return_tensors: Optional[str] = None,
178
  return_token_type_ids: Optional[bool] = None,
179
- return_attention_mask: Optional[bool] = True, # Changed default to True
180
  **kwargs
181
  ):
182
  """
183
  Main tokenization method
184
- Args:
185
- text: Text or list of texts to tokenize
186
- lang: Language code or list of language codes corresponding to each text
187
- add_special_tokens: Whether to add special tokens
188
- padding: Padding strategy (default True)
189
- truncation: Truncation strategy (default True)
190
- max_length: Maximum length
191
- stride: Stride for truncation
192
- return_tensors: Format of output tensors ("pt" for PyTorch)
193
- return_token_type_ids: Whether to return token type IDs
194
- return_attention_mask: Whether to return attention mask (default True)
195
  """
196
  # Convert single string to list for batch processing
197
  if isinstance(text, str):
198
  text = [text]
199
- if isinstance(lang, str):
200
- lang = [lang]
201
 
202
  # Ensure text and lang lists have same length
203
  if len(text) != len(lang):
204
- raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")
205
 
206
  # Convert padding strategy
207
  if isinstance(padding, bool):
208
- padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
209
  else:
210
  padding_strategy = PaddingStrategy(padding)
211
 
@@ -230,4 +925,4 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
230
  **kwargs
231
  )
232
 
233
- return encoded
 
1
+ import re
2
+ from typing import List, Optional, Union, Dict, Any
3
  from functools import cached_property
4
 
5
+ import pypinyin
6
  import torch
7
  from hangul_romanize import Transliter
8
  from hangul_romanize.rule import academic
9
+ from num2words import num2words
10
+ from spacy.lang.ar import Arabic
11
+ from spacy.lang.en import English
12
+ from spacy.lang.es import Spanish
13
+ from spacy.lang.ja import Japanese
14
+ from spacy.lang.zh import Chinese
15
+ from transformers import PreTrainedTokenizerFast, BatchEncoding
16
+ from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
17
+ from tokenizers import Tokenizer
18
+ from tokenizers.pre_tokenizers import WhitespaceSplit
19
+ from tokenizers.processors import TemplateProcessing
20
+
21
+ from auralis.models.xttsv2.components.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
22
+
23
  import cutlet
24
 
25
+ def get_spacy_lang(lang):
26
+ if lang == "zh":
27
+ return Chinese()
28
+ elif lang == "ja":
29
+ return Japanese()
30
+ elif lang == "ar":
31
+ return Arabic()
32
+ elif lang == "es":
33
+ return Spanish()
34
+ else:
35
+ # For most languages, English does the job
36
+ return English()
37
+
38
+
39
+ def find_best_split_point(text: str, target_pos: int, window_size: int = 30) -> int:
40
+ """
41
+ Find best split point near target position considering punctuation and language markers.
42
+ Added for better sentence splitting in TTS.
43
+ """
44
+ # Define split markers by priority
45
+ markers = [
46
+ # Strong breaks (longest pause)
47
+ (r'[.!?؟။။။]+[\s]*', 1.0), # Periods, exclamation, question (multi-script)
48
+ (r'[\n\r]+\s*[\n\r]+', 1.0), # Multiple newlines
49
+ (r'[:|;;:;][\s]*', 0.9), # Colons, semicolons (multi-script)
50
+
51
+ # Medium breaks
52
+ (r'[,,،、][\s]*', 0.8), # Commas (multi-script)
53
+ (r'[)}\])】』»›》\s]+', 0.7), # Closing brackets/parentheses
54
+ (r'[-—−]+[\s]*', 0.7), # Dashes
55
+
56
+ # Weak breaks
57
+ (r'\s+[&+=/\s]+\s+', 0.6), # Special characters with spaces
58
+ (r'[\s]+', 0.5), # Any whitespace as last resort
59
+ ]
60
+
61
+ # Calculate window boundaries
62
+ start = max(0, target_pos - window_size)
63
+ end = min(len(text), target_pos + window_size)
64
+ window = text[start:end]
65
+
66
+ best_pos = target_pos
67
+ best_score = 0
68
+
69
+ for pattern, priority in markers:
70
+ matches = list(re.finditer(pattern, window))
71
+ for match in matches:
72
+ # Calculate position score based on distance from target
73
+ pos = start + match.end()
74
+ distance = abs(pos - target_pos)
75
+ distance_score = 1 - (distance / (window_size * 2))
76
+
77
+ # Combine priority and position scores
78
+ score = priority * distance_score
79
+
80
+ if score > best_score:
81
+ best_score = score
82
+ best_pos = pos
83
+
84
+ return best_pos
85
+
86
+
87
+ def split_sentence(text: str, lang: str, text_split_length: int = 250) -> List[str]:
88
+ """
89
+ Enhanced sentence splitting with language awareness and optimal breakpoints.
90
+
91
+ Args:
92
+ text: Input text to split
93
+ lang: Language code
94
+ text_split_length: Target length for splits
95
+
96
+ Returns:
97
+ List of text splits optimized for TTS
98
+ """
99
+ text = text.strip()
100
+ if len(text) <= text_split_length:
101
+ return [text]
102
+
103
+ nlp = get_spacy_lang(lang)
104
+ if "sentencizer" not in nlp.pipe_names:
105
+ nlp.add_pipe("sentencizer")
106
+
107
+ # Get base sentences using spaCy
108
+ doc = nlp(text)
109
+ sentences = list(doc.sents)
110
+
111
+ splits = []
112
+ current_split = []
113
+ current_length = 0
114
+
115
+ for sent in sentences:
116
+ sentence_text = str(sent).strip()
117
+ sentence_length = len(sentence_text)
118
+
119
+ # If sentence fits in current split
120
+ if current_length + sentence_length <= text_split_length:
121
+ current_split.append(sentence_text)
122
+ current_length += sentence_length + 1
123
+
124
+ # Handle long sentences
125
+ elif sentence_length > text_split_length:
126
+ # Add current split if exists
127
+ if current_split:
128
+ splits.append(" ".join(current_split))
129
+ current_split = []
130
+ current_length = 0
131
+
132
+ # Split long sentence at optimal points
133
+ remaining = sentence_text
134
+ while len(remaining) > text_split_length:
135
+ split_pos = find_best_split_point(
136
+ remaining,
137
+ text_split_length,
138
+ window_size=30
139
+ )
140
+
141
+ # Add split and continue with remainder
142
+ splits.append(remaining[:split_pos].strip())
143
+ remaining = remaining[split_pos:].strip()
144
+
145
+ # Handle remaining text
146
+ if remaining:
147
+ current_split = [remaining]
148
+ current_length = len(remaining)
149
+
150
+ # Start new split
151
+ else:
152
+ splits.append(" ".join(current_split))
153
+ current_split = [sentence_text]
154
+ current_length = sentence_length
155
+
156
+ # Add final split if needed
157
+ if current_split:
158
+ splits.append(" ".join(current_split))
159
+
160
+ cleaned_sentences = [s[:-1]+' ' if s.endswith('.') else s for s in splits if s] # prevents annoying sounds in italian
161
+ # Clean up splits
162
+ return cleaned_sentences
163
+
164
+ _whitespace_re = re.compile(r"\s+")
165
+
166
+ # List of (regular expression, replacement) pairs for abbreviations:
167
+ _abbreviations = {
168
+ "en": [
169
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
170
+ for x in [
171
+ ("mrs", "misess"),
172
+ ("mr", "mister"),
173
+ ("dr", "doctor"),
174
+ ("st", "saint"),
175
+ ("co", "company"),
176
+ ("jr", "junior"),
177
+ ("maj", "major"),
178
+ ("gen", "general"),
179
+ ("drs", "doctors"),
180
+ ("rev", "reverend"),
181
+ ("lt", "lieutenant"),
182
+ ("hon", "honorable"),
183
+ ("sgt", "sergeant"),
184
+ ("capt", "captain"),
185
+ ("esq", "esquire"),
186
+ ("ltd", "limited"),
187
+ ("col", "colonel"),
188
+ ("ft", "fort"),
189
+ ]
190
+ ],
191
+ "es": [
192
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
193
+ for x in [
194
+ ("sra", "señora"),
195
+ ("sr", "señor"),
196
+ ("dr", "doctor"),
197
+ ("dra", "doctora"),
198
+ ("st", "santo"),
199
+ ("co", "compañía"),
200
+ ("jr", "junior"),
201
+ ("ltd", "limitada"),
202
+ ]
203
+ ],
204
+ "fr": [
205
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
206
+ for x in [
207
+ ("mme", "madame"),
208
+ ("mr", "monsieur"),
209
+ ("dr", "docteur"),
210
+ ("st", "saint"),
211
+ ("co", "compagnie"),
212
+ ("jr", "junior"),
213
+ ("ltd", "limitée"),
214
+ ]
215
+ ],
216
+ "de": [
217
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
218
+ for x in [
219
+ ("fr", "frau"),
220
+ ("dr", "doktor"),
221
+ ("st", "sankt"),
222
+ ("co", "firma"),
223
+ ("jr", "junior"),
224
+ ]
225
+ ],
226
+ "pt": [
227
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
228
+ for x in [
229
+ ("sra", "senhora"),
230
+ ("sr", "senhor"),
231
+ ("dr", "doutor"),
232
+ ("dra", "doutora"),
233
+ ("st", "santo"),
234
+ ("co", "companhia"),
235
+ ("jr", "júnior"),
236
+ ("ltd", "limitada"),
237
+ ]
238
+ ],
239
+ "it": [
240
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
241
+ for x in [
242
+ # ("sig.ra", "signora"),
243
+ ("sig", "signore"),
244
+ ("dr", "dottore"),
245
+ ("st", "santo"),
246
+ ("co", "compagnia"),
247
+ ("jr", "junior"),
248
+ ("ltd", "limitata"),
249
+ ]
250
+ ],
251
+ "pl": [
252
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
253
+ for x in [
254
+ ("p", "pani"),
255
+ ("m", "pan"),
256
+ ("dr", "doktor"),
257
+ ("sw", "święty"),
258
+ ("jr", "junior"),
259
+ ]
260
+ ],
261
+ "ar": [
262
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
263
+ for x in [
264
+ # There are not many common abbreviations in Arabic as in English.
265
+ ]
266
+ ],
267
+ "zh": [
268
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
269
+ for x in [
270
+ # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
271
+ ]
272
+ ],
273
+ "cs": [
274
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
275
+ for x in [
276
+ ("dr", "doktor"), # doctor
277
+ ("ing", "inženýr"), # engineer
278
+ ("p", "pan"), # Could also map to pani for woman but no easy way to do it
279
+ # Other abbreviations would be specialized and not as common.
280
+ ]
281
+ ],
282
+ "ru": [
283
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
284
+ for x in [
285
+ ("г-жа", "госпожа"), # Mrs.
286
+ ("г-н", "господин"), # Mr.
287
+ ("д-р", "доктор"), # doctor
288
+ # Other abbreviations are less common or specialized.
289
+ ]
290
+ ],
291
+ "nl": [
292
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
293
+ for x in [
294
+ ("dhr", "de heer"), # Mr.
295
+ ("mevr", "mevrouw"), # Mrs.
296
+ ("dr", "dokter"), # doctor
297
+ ("jhr", "jonkheer"), # young lord or nobleman
298
+ # Dutch uses more abbreviations, but these are the most common ones.
299
+ ]
300
+ ],
301
+ "tr": [
302
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
303
+ for x in [
304
+ ("b", "bay"), # Mr.
305
+ ("byk", "büyük"), # büyük
306
+ ("dr", "doktor"), # doctor
307
+ # Add other Turkish abbreviations here if needed.
308
+ ]
309
+ ],
310
+ "hu": [
311
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
312
+ for x in [
313
+ ("dr", "doktor"), # doctor
314
+ ("b", "bácsi"), # Mr.
315
+ ("nőv", "nővér"), # nurse
316
+ # Add other Hungarian abbreviations here if needed.
317
+ ]
318
+ ],
319
+ "ko": [
320
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
321
+ for x in [
322
+ # Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
323
+ ]
324
+ ],
325
+ }
326
+
327
+ def expand_abbreviations_multilingual(text, lang="en"):
328
+ if lang in _abbreviations:
329
+ for regex, replacement in _abbreviations[lang]:
330
+ text = re.sub(regex, replacement, text)
331
+ return text
332
+
333
+ _symbols_multilingual = {
334
+ "en": [
335
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
336
+ for x in [
337
+ ("&", " and "),
338
+ ("@", " at "),
339
+ ("%", " percent "),
340
+ ("#", " hash "),
341
+ ("$", " dollar "),
342
+ ("£", " pound "),
343
+ ("°", " degree "),
344
+ ]
345
+ ],
346
+ "es": [
347
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
348
+ for x in [
349
+ ("&", " y "),
350
+ ("@", " arroba "),
351
+ ("%", " por ciento "),
352
+ ("#", " numeral "),
353
+ ("$", " dolar "),
354
+ ("£", " libra "),
355
+ ("°", " grados "),
356
+ ]
357
+ ],
358
+ "fr": [
359
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
360
+ for x in [
361
+ ("&", " et "),
362
+ ("@", " arobase "),
363
+ ("%", " pour cent "),
364
+ ("#", " dièse "),
365
+ ("$", " dollar "),
366
+ ("£", " livre "),
367
+ ("°", " degrés "),
368
+ ]
369
+ ],
370
+ "de": [
371
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
372
+ for x in [
373
+ ("&", " und "),
374
+ ("@", " at "),
375
+ ("%", " prozent "),
376
+ ("#", " raute "),
377
+ ("$", " dollar "),
378
+ ("£", " pfund "),
379
+ ("°", " grad "),
380
+ ]
381
+ ],
382
+ "pt": [
383
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
384
+ for x in [
385
+ ("&", " e "),
386
+ ("@", " arroba "),
387
+ ("%", " por cento "),
388
+ ("#", " cardinal "),
389
+ ("$", " dólar "),
390
+ ("£", " libra "),
391
+ ("°", " graus "),
392
+ ]
393
+ ],
394
+ "it": [
395
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
396
+ for x in [
397
+ ("&", " e "),
398
+ ("@", " chiocciola "),
399
+ ("%", " per cento "),
400
+ ("#", " cancelletto "),
401
+ ("$", " dollaro "),
402
+ ("£", " sterlina "),
403
+ ("°", " gradi "),
404
+ ]
405
+ ],
406
+ "pl": [
407
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
408
+ for x in [
409
+ ("&", " i "),
410
+ ("@", " małpa "),
411
+ ("%", " procent "),
412
+ ("#", " krzyżyk "),
413
+ ("$", " dolar "),
414
+ ("£", " funt "),
415
+ ("°", " stopnie "),
416
+ ]
417
+ ],
418
+ "ar": [
419
+ # Arabic
420
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
421
+ for x in [
422
+ ("&", " و "),
423
+ ("@", " على "),
424
+ ("%", " في المئة "),
425
+ ("#", " رقم "),
426
+ ("$", " دولار "),
427
+ ("£", " جنيه "),
428
+ ("°", " درجة "),
429
+ ]
430
+ ],
431
+ "zh": [
432
+ # Chinese
433
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
434
+ for x in [
435
+ ("&", " 和 "),
436
+ ("@", " 在 "),
437
+ ("%", " 百分之 "),
438
+ ("#", " 号 "),
439
+ ("$", " 美元 "),
440
+ ("£", " 英镑 "),
441
+ ("°", " 度 "),
442
+ ]
443
+ ],
444
+ "cs": [
445
+ # Czech
446
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
447
+ for x in [
448
+ ("&", " a "),
449
+ ("@", " na "),
450
+ ("%", " procento "),
451
+ ("#", " křížek "),
452
+ ("$", " dolar "),
453
+ ("£", " libra "),
454
+ ("°", " stupně "),
455
+ ]
456
+ ],
457
+ "ru": [
458
+ # Russian
459
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
460
+ for x in [
461
+ ("&", " и "),
462
+ ("@", " собака "),
463
+ ("%", " процентов "),
464
+ ("#", " номер "),
465
+ ("$", " доллар "),
466
+ ("£", " фунт "),
467
+ ("°", " градус "),
468
+ ]
469
+ ],
470
+ "nl": [
471
+ # Dutch
472
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
473
+ for x in [
474
+ ("&", " en "),
475
+ ("@", " bij "),
476
+ ("%", " procent "),
477
+ ("#", " hekje "),
478
+ ("$", " dollar "),
479
+ ("£", " pond "),
480
+ ("°", " graden "),
481
+ ]
482
+ ],
483
+ "tr": [
484
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
485
+ for x in [
486
+ ("&", " ve "),
487
+ ("@", " at "),
488
+ ("%", " yüzde "),
489
+ ("#", " diyez "),
490
+ ("$", " dolar "),
491
+ ("£", " sterlin "),
492
+ ("°", " derece "),
493
+ ]
494
+ ],
495
+ "hu": [
496
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
497
+ for x in [
498
+ ("&", " és "),
499
+ ("@", " kukac "),
500
+ ("%", " százalék "),
501
+ ("#", " kettőskereszt "),
502
+ ("$", " dollár "),
503
+ ("£", " font "),
504
+ ("°", " fok "),
505
+ ]
506
+ ],
507
+ "ko": [
508
+ # Korean
509
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
510
+ for x in [
511
+ ("&", " 그리고 "),
512
+ ("@", " 에 "),
513
+ ("%", " 퍼센트 "),
514
+ ("#", " 번호 "),
515
+ ("$", " 달러 "),
516
+ ("£", " 파운드 "),
517
+ ("°", " 도 "),
518
+ ]
519
+ ],
520
+ }
521
+
522
+ def expand_symbols_multilingual(text, lang="en"):
523
+ if lang in _symbols_multilingual:
524
+ for regex, replacement in _symbols_multilingual[lang]:
525
+ text = re.sub(regex, replacement, text)
526
+ text = text.replace("  ", " ")  # Ensure there are no double spaces
527
+ return text.strip()
528
+
529
+ _ordinal_re = {
530
+ "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
531
+ "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
532
+ "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
533
+ "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
534
+ "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
535
+ "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
536
+ "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
537
+ "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
538
+ "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
539
+ "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
540
+ "nl": re.compile(r"([0-9]+)(de|ste|e)"),
541
+ "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
542
+ "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
543
+ "ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
544
+ }
545
+ _number_re = re.compile(r"[0-9]+")
546
+ # noinspection Annotator
547
+ _currency_re = {
548
+ "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
549
+ "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
550
+ "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
551
+ }
552
+
553
+ _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
554
+ _dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(\,\d+)?\b")
555
+ _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
556
+
557
+ def _remove_commas(m):
558
+ text = m.group(0)
559
+ if "," in text:
560
+ text = text.replace(",", "")
561
+ return text
562
+
563
+ def _remove_dots(m):
564
+ text = m.group(0)
565
+ if "." in text:
566
+ text = text.replace(".", "")
567
+ return text
568
+
569
+ def _expand_decimal_point(m, lang="en"):
570
+ amount = m.group(1).replace(",", ".")
571
+ return num2words(float(amount), lang=lang if lang != "cs" else "cz")
572
+
573
+ def _expand_currency(m, lang="en", currency="USD"):
574
+ amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
575
+ full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
576
+
577
+ and_equivalents = {
578
+ "en": ", ",
579
+ "es": " con ",
580
+ "fr": " et ",
581
+ "de": " und ",
582
+ "pt": " e ",
583
+ "it": " e ",
584
+ "pl": ", ",
585
+ "cs": ", ",
586
+ "ru": ", ",
587
+ "nl": ", ",
588
+ "ar": ", ",
589
+ "tr": ", ",
590
+ "hu": ", ",
591
+ "ko": ", ",
592
+ }
593
+
594
+ if amount.is_integer():
595
+ last_and = full_amount.rfind(and_equivalents.get(lang, ", "))
596
+ if last_and != -1:
597
+ full_amount = full_amount[:last_and]
598
+
599
+ return full_amount
600
+
601
+ def _expand_ordinal(m, lang="en"):
602
+ return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
603
+
604
+ def _expand_number(m, lang="en"):
605
+ return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
606
+
607
+ def expand_numbers_multilingual(text, lang="en"):
608
+ if lang == "zh":
609
+ text = zh_num2words()(text)
610
+ else:
611
+ if lang in ["en", "ru"]:
612
+ text = re.sub(_comma_number_re, _remove_commas, text)
613
+ else:
614
+ text = re.sub(_dot_number_re, _remove_dots, text)
615
+ try:
616
+ text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
617
+ text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
618
+ text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
619
+ except Exception as e:
620
+ pass
621
+ if lang != "tr":
622
+ text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
623
+ if lang in _ordinal_re:
624
+ text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
625
+ text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
626
+ return text
627
+
628
+ def lowercase(text):
629
+ return text.lower()
630
+
631
+ def collapse_whitespace(text):
632
+ return re.sub(_whitespace_re, " ", text)
633
+
634
+ def multilingual_cleaners(text, lang):
635
+ text = text.replace('"', "")
636
+ if lang == "tr":
637
+ text = text.replace("İ", "i")
638
+ text = text.replace("Ö", "ö")
639
+ text = text.replace("Ü", "ü")
640
+ text = lowercase(text)
641
+ text = expand_numbers_multilingual(text, lang)
642
+ text = expand_abbreviations_multilingual(text, lang)
643
+ text = expand_symbols_multilingual(text, lang=lang)
644
+ text = collapse_whitespace(text)
645
+ return text
646
+
647
+ def basic_cleaners(text):
648
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
649
+ text = lowercase(text)
650
+ text = collapse_whitespace(text)
651
+ return text
652
+
653
+ def chinese_transliterate(text):
654
+ return "".join(
655
+ [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
656
+ )
657
+
658
+ def japanese_cleaners(text, katsu):
659
+ text = katsu.romaji(text)
660
+ text = lowercase(text)
661
+ return text
662
+
663
+ def korean_transliterate(text, transliter):
664
+ return transliter.translit(text)
665
+
666
+ # Fast Tokenizer Class
667
 
668
  class XTTSTokenizerFast(PreTrainedTokenizerFast):
669
  """
670
  Fast Tokenizer implementation for XTTS model using HuggingFace's PreTrainedTokenizerFast
671
  """
672
+
673
  def __init__(
674
  self,
675
  vocab_file: str = None,
 
678
  pad_token: str = "[PAD]",
679
  bos_token: str = "[START]",
680
  eos_token: str = "[STOP]",
681
+ auto_map: dict = {"AutoTokenizer": ["AstraMindAI/xtts2-gpt--tokenizer.XTTSTokenizerFast", None]},
682
  clean_up_tokenization_spaces: bool = True,
683
  **kwargs
684
  ):
 
688
  if tokenizer_object is not None:
689
  # Configure the tokenizer
690
  tokenizer_object.pre_tokenizer = WhitespaceSplit()
691
  tokenizer_object.post_processor = TemplateProcessing(
692
  single=f"{bos_token} $A {eos_token}",
693
  special_tokens=[
 
718
  self._katsu = None
719
  self._korean_transliter = Transliter(academic)
720
 
721
+ # Ensure pad_token_id is set
722
+ if self.pad_token_id is None:
723
+ self.pad_token_id = self.tokenizer.token_to_id(self.pad_token)
724
+
725
  @cached_property
726
  def katsu(self):
727
  if self._katsu is None:
728
  self._katsu = cutlet.Cutlet()
729
  return self._katsu
730
 
 
 
 
 
 
 
 
731
  def preprocess_text(self, text: str, lang: str) -> str:
732
  """Apply text preprocessing for language"""
733
+ base_lang = lang.split("-")[0] # remove region
734
+ if base_lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
735
+ "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
736
+ text = multilingual_cleaners(text, base_lang)
737
+ if base_lang == "zh":
738
  text = chinese_transliterate(text)
739
+ if base_lang == "ko":
740
+ text = korean_transliterate(text, self._korean_transliter)
741
+ elif base_lang == "ja":
742
  text = japanese_cleaners(text, self.katsu)
743
  else:
744
  text = basic_cleaners(text)
745
  return text
746
 
747
+ def batch_encode_with_split(self, texts: Union[str, List[str]], lang: Union[str, List[str]],
748
+ **kwargs) -> torch.Tensor:
749
+ """
750
+ Split texts into smaller chunks based on per-language character limits and encode them with the HuggingFace fast tokenizer.
751
+ Strictly mimics the XTTS-v2 tokenizer's chunking behavior.
752
+ """
753
+ # Convert single inputs to lists
754
+ if isinstance(texts, str):
755
+ texts = [texts]
756
+ if isinstance(lang, str):
757
+ lang = [lang]
758
+ # Ensure lang list matches texts list
759
+ if len(lang) == 1 and len(texts) > 1:
760
+ lang = lang * len(texts)
761
+
762
+ # Check if texts and lang have the same length
763
+ if len(texts) != len(lang):
764
+ raise ValueError(f"Number of texts ({len(texts)}) does not match number of languages ({len(lang)}).")
765
+
766
+ chunk_list = []
767
+ max_splits = 0
768
+
769
+ # For each text, split into chunks based on character limit
770
+ for text, text_lang in zip(texts, lang):
771
+ # Get language character limit
772
+ base_lang = text_lang.split("-")[0]
773
+ char_limit = self.char_limits.get(base_lang, 250)
774
+
775
+ # Clean and preprocess
776
+ text = self.preprocess_text(text, text_lang)
777
+
778
+ # Split text into sentences/chunks based on language
779
+ chunk_list.extend(split_sentence(text, base_lang, text_split_length=char_limit))
780
+
781
+ # Ensure the tokenizer is a fast tokenizer
782
+ if not self.is_fast:
783
+ raise ValueError("The tokenizer must be a fast tokenizer.")
784
+
785
+ # Encode all chunks using the fast tokenizer
786
+ encoding: BatchEncoding = self(
787
+ chunk_list,
788
+ lang = lang,
789
+ add_special_tokens=False,
790
+ padding=False,
791
+ **kwargs
792
+ )
793
+
794
+ # The 'input_ids' tensor will have shape [total_chunks, max_sequence_length]
795
+ return encoding['input_ids'] # Tensor of shape [total_chunks, sequence_length]
796
+
797
  def _batch_encode_plus(
798
  self,
799
  batch_text_or_text_pairs,
800
  add_special_tokens: bool = True,
801
+ padding_strategy=PaddingStrategy.DO_NOT_PAD,
802
+ truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
803
+ max_length: Optional[int] = None,
804
  stride: int = 0,
805
  is_split_into_words: bool = False,
806
  pad_to_multiple_of: Optional[int] = None,
 
819
  """
820
  lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
821
  if isinstance(lang, str):
822
+ lang = [lang]
823
+ # Ensure lang list matches texts list
824
+ if len(lang) == 1 and len(batch_text_or_text_pairs) > 1:
825
+ lang = lang * len(batch_text_or_text_pairs)
826
+
827
+ # Check if batch_text_or_text_pairs and lang have the same length
828
+ if len(batch_text_or_text_pairs) != len(lang):
829
+ raise ValueError(f"Number of texts ({len(batch_text_or_text_pairs)}) does not match number of languages ({len(lang)}).")
830
 
831
  # Preprocess each text in the batch with its corresponding language
832
  processed_texts = []
833
  for text, text_lang in zip(batch_text_or_text_pairs, lang):
834
  if isinstance(text, str):
835
  # Check length and preprocess
836
+ #self.check_input_length(text, text_lang)
837
  processed_text = self.preprocess_text(text, text_lang)
838
 
839
  # Format text with language tag and spaces
840
+ base_lang = text_lang.split("-")[0]
841
+ lang_code = "zh-cn" if base_lang == "zh" else base_lang
842
  processed_text = f"[{lang_code}]{processed_text}"
843
  processed_text = processed_text.replace(" ", "[SPACE]")
844
 
 
867
  **kwargs
868
  )
869
 
870
+
871
  def __call__(
872
  self,
873
  text: Union[str, List[str]],
874
  lang: Union[str, List[str]] = "en",
875
  add_special_tokens: bool = True,
876
+ padding: Union[bool, str, PaddingStrategy] = False,
877
+ truncation: Union[bool, str, TruncationStrategy] = False,
878
+ max_length: Optional[int] = None,
879
  stride: int = 0,
880
  return_tensors: Optional[str] = None,
881
  return_token_type_ids: Optional[bool] = None,
882
+ return_attention_mask: Optional[bool] = True,
883
  **kwargs
884
  ):
885
  """
886
  Main tokenization method
887
  """
888
  # Convert single string to list for batch processing
889
  if isinstance(text, str):
890
  text = [text]
891
+ if isinstance(lang, str):
892
+ lang = [lang]
893
+ # Ensure lang list matches texts list
894
+ if len(lang) == 1 and len(text) > 1:
895
+ lang = lang * len(text)
896
 
897
  # Ensure text and lang lists have same length
898
  if len(text) != len(lang):
899
+ raise ValueError(f"Number of texts ({len(text)}) does not match number of languages ({len(lang)}).")
900
 
901
  # Convert padding strategy
902
  if isinstance(padding, bool):
903
+ padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
904
  else:
905
  padding_strategy = PaddingStrategy(padding)
906
 
 
925
  **kwargs
926
  )
927
 
928
+ return encoded
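
For reference, below is a minimal usage sketch of the updated tokenizer (illustrative only, not part of the commit). It assumes the checkpoint ships this tokenizer.py and is resolvable through the auto_map entry added above ("AstraMindAI/xtts2-gpt--tokenizer.XTTSTokenizerFast"); the repository id, example texts, and languages are placeholders to adapt to your setup.

from transformers import AutoTokenizer

# trust_remote_code is required so AutoTokenizer can load XTTSTokenizerFast
# from the repo's tokenizer.py via the auto_map entry.
tokenizer = AutoTokenizer.from_pretrained("AstraMindAI/xtts2-gpt", trust_remote_code=True)

# lang selects the language-specific cleaning; each text is then prefixed with
# "[lang]" and spaces are replaced by "[SPACE]" before encoding.
enc = tokenizer(
    ["Hello world, this is a test.", "Bonjour tout le monde."],
    lang=["en", "fr"],
    padding=True,          # with the new __call__, bool True maps to PaddingStrategy.LONGEST
    return_tensors="pt",
)
print(enc["input_ids"].shape, enc["attention_mask"].shape)

Padding and truncation now default to off in __call__, so callers that previously relied on the old max_length=402 default must pass padding/truncation explicitly, as in the sketch above.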