"""

Example script for fine-tuning the pretrained model to your own data.



Comments in ALL CAPS are instructions

"""

import os
import time

import torch
import wandb
from torch.utils.data import ConcatDataset

from Architectures.ToucanTTS.ToucanTTS import ToucanTTS
from Architectures.ToucanTTS.toucantts_train_loop_arbiter import train_loop
from Utility.corpus_preparation import prepare_tts_corpus
from Utility.path_to_transcript_dicts import *
from Utility.storage_config import MODELS_DIR
from Utility.storage_config import PREPROCESSING_DIR
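

# If none of the builders in Utility/path_to_transcript_dicts.py fit your data, you can
# write your own: prepare_tts_corpus expects a dict that maps absolute paths of audio
# files to their transcripts. The sketch below assumes a pipe-separated metadata.csv
# and a wavs/ directory (LJSpeech-style layout); adapt the parsing to your corpus.
def build_path_to_transcript_dict_my_data(data_root="/path/to/my_data"):
    path_to_transcript = dict()
    with open(os.path.join(data_root, "metadata.csv"), encoding="utf8") as metadata_file:
        for line in metadata_file:
            line = line.strip()
            if line == "":
                continue
            wav_name, transcript = line.split("|", maxsplit=1)
            path_to_transcript[os.path.join(data_root, "wavs", wav_name)] = transcript
    return path_to_transcript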


def run(
    gpu_id,
    resume_checkpoint,
    finetune,
    model_dir,
    resume,
    use_wandb,
    wandb_resume_id,
    gpu_count,
):
    if gpu_id == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
    assert gpu_count == 1  # distributed finetuning is not supported

    # IF YOU'RE ADDING A NEW LANGUAGE, YOU MIGHT NEED TO ADD HANDLING FOR IT IN Preprocessing/TextFrontend.py

    print("Preparing")

    if model_dir is not None:
        save_dir = model_dir
    else:
        save_dir = os.path.join(
            MODELS_DIR, "ToucanTTS_Shan"
        )  # RENAME TO SOMETHING MEANINGFUL FOR YOUR DATA
    os.makedirs(save_dir, exist_ok=True)

    train_data = prepare_tts_corpus(
        transcript_dict=build_path_to_transcript_dict_shan(),
        corpus_dir=os.path.join(PREPROCESSING_DIR, "integration_shan"),
        lang="shn",
    )  # CHANGE THE TRANSCRIPT DICT, THE NAME OF THE CACHE DIRECTORY AND THE LANGUAGE TO YOUR NEEDS

    model = ToucanTTS()

    if use_wandb:
        wandb.init(
            name=(
                f"{__name__.split('.')[-1]}_{time.strftime('%Y%m%d-%H%M%S')}"
                if wandb_resume_id is None
                else None
            ),
            id=wandb_resume_id,  # this is None if not specified in the command line arguments.
            resume="must" if wandb_resume_id is not None else None,
        )

    print("Training model")
    train_loop(
        net=model,
        datasets=[train_data],
        device=device,
        save_directory=save_dir,
        batch_size=32,  # YOU MIGHT RUN OUT OF MEMORY ON SMALL GPUs; IF SO, DECREASE THIS.
        eval_lang="shn",  # THE LANGUAGE YOUR PROGRESS PLOTS WILL BE MADE IN
        warmup_steps=500,
        lr=1e-5,  # if you have enough data (over ~1000 datapoints), you can increase this up to 1e-4; training will still be stable but learn quicker.
        # DOWNLOAD THE INITIALIZATION MODELS FROM THE GITHUB RELEASE PAGE OR RUN THE DOWNLOADER SCRIPT TO GET THEM AUTOMATICALLY
        path_to_checkpoint=(
            os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")
            if resume_checkpoint is None
            else resume_checkpoint
        ),
        fine_tune=True if resume_checkpoint is None and not resume else finetune,
        resume=resume,
        steps=5000,
        use_wandb=use_wandb,
        train_samplers=[torch.utils.data.RandomSampler(train_data)],
        gpu_count=1,
    )
    if use_wandb:
        wandb.finish()
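

# NOTE: In the surrounding repository, recipe scripts like this one are usually launched
# through a central training-pipeline entry point that passes the command line arguments
# into run(). For a quick standalone test you could also call it directly, e.g. (the
# argument values below are just an illustration):
#
#     run(gpu_id="0", resume_checkpoint=None, finetune=True, model_dir=None,
#         resume=False, use_wandb=False, wandb_resume_id=None, gpu_count=1)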