Kamtera commited on
Commit
1217548
1 Parent(s): 12aa548

Upload train_vits-0.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_vits-0.py +104 -0
train_vits-0.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from trainer import Trainer, TrainerArgs
4
+
5
+ from TTS.tts.configs.shared_configs import BaseDatasetConfig , CharactersConfig
6
+ from TTS.config.shared_configs import BaseAudioConfig
7
+ from TTS.tts.configs.vits_config import VitsConfig
8
+ from TTS.tts.datasets import load_tts_samples
9
+ from TTS.tts.models.vits import Vits, VitsAudioConfig
10
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
11
+ from TTS.utils.audio import AudioProcessor
12
+ from TTS.utils.downloaders import download_thorsten_de
13
+
14
+ output_path = os.path.dirname(os.path.abspath(__file__))
15
+ dataset_config = BaseDatasetConfig(
16
+ formatter="mozilla", meta_file_train="metadata.csv", path="/kaggle/input/persian-tts-dataset-male"
17
+ )
18
+
19
+
20
+
21
+ audio_config = BaseAudioConfig(
22
+ sample_rate=22050,
23
+ do_trim_silence=True,
24
+ resample=False,
25
+ mel_fmin=0,
26
+ mel_fmax=None
27
+ )
28
+ character_config=CharactersConfig(
29
+ characters='ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآأؤإئًَُّ',
30
+ punctuations='!(),-.:;? ̠،؛؟‌<>',
31
+ phonemes='ˈˌːˑpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟaegiouwyɪʊ̩æɑɔəɚɛɝɨ̃ʉʌʍ0123456789"#$%*+/=ABCDEFGHIJKLMNOPRSTUVWXYZ[]^_{}',
32
+ pad="<PAD>",
33
+ eos="<EOS>",
34
+ bos="<BOS>",
35
+ blank="<BLNK>",
36
+ characters_class="TTS.tts.utils.text.characters.IPAPhonemes",
37
+ )
38
+ config = VitsConfig(
39
+ audio=audio_config,
40
+ run_name="vits_fa_male",
41
+ batch_size=8,
42
+ eval_batch_size=4,
43
+ batch_group_size=5,
44
+ num_loader_workers=0,
45
+ num_eval_loader_workers=2,
46
+ run_eval=True,
47
+ test_delay_epochs=-1,
48
+ epochs=1000,
49
+ save_step=1000,
50
+ text_cleaner="basic_cleaners",
51
+ use_phonemes=True,
52
+ phoneme_language="fa",
53
+ characters=character_config,
54
+ phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
55
+ compute_input_seq_cache=True,
56
+ print_step=25,
57
+ print_eval=True,
58
+ mixed_precision=False,
59
+ test_sentences=[
60
+ ["سلطان محمود در زمستانی سخت به طلخک گفت که: با این جامه ی یک لا در این سرما چه می کنی "],
61
+ ["مردی نزد بقالی آمد و گفت پیاز هم ده تا دهان بدان خو شبوی سازم."],
62
+ ["از مال خود پاره ای گوشت بستان و زیره بایی معطّر بساز"],
63
+ ["یک بار هم از جهنم بگویید."],
64
+ ["یکی اسبی به عاریت خواست"]
65
+ ],
66
+ output_path=output_path,
67
+ datasets=[dataset_config],
68
+ )
69
+
70
+ # INITIALIZE THE AUDIO PROCESSOR
71
+ # Audio processor is used for feature extraction and audio I/O.
72
+ # It mainly serves to the dataloader and the training loggers.
73
+ ap = AudioProcessor.init_from_config(config)
74
+
75
+ # INITIALIZE THE TOKENIZER
76
+ # Tokenizer is used to convert text to sequences of token IDs.
77
+ # config is updated with the default characters if not defined in the config.
78
+ tokenizer, config = TTSTokenizer.init_from_config(config)
79
+
80
+ # LOAD DATA SAMPLES
81
+ # Each sample is a list of ```[text, audio_file_path, speaker_name]```
82
+ # You can define your custom sample loader returning the list of samples.
83
+ # Or define your custom formatter and pass it to the `load_tts_samples`.
84
+ # Check `TTS.tts.datasets.load_tts_samples` for more details.
85
+ train_samples, eval_samples = load_tts_samples(
86
+ dataset_config,
87
+ eval_split=True,
88
+ eval_split_max_size=config.eval_split_max_size,
89
+ eval_split_size=config.eval_split_size,
90
+ )
91
+
92
+ # init model
93
+ model = Vits(config, ap, tokenizer, speaker_manager=None)
94
+
95
+ # init the trainer and 🚀
96
+ trainer = Trainer(
97
+ TrainerArgs(),
98
+ config,
99
+ output_path,
100
+ model=model,
101
+ train_samples=train_samples,
102
+ eval_samples=eval_samples,
103
+ )
104
+ trainer.fit()