File size: 1,978 Bytes
de6e35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import time

import torch
import wandb

from Architectures.ToucanTTS.StochasticToucanTTS.StochasticToucanTTS import StochasticToucanTTS
from Architectures.ToucanTTS.toucantts_train_loop_arbiter import train_loop
from Utility.corpus_preparation import prepare_tts_corpus
from Utility.path_to_transcript_dicts import *
from Utility.storage_config import MODELS_DIR
from Utility.storage_config import PREPROCESSING_DIR


def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id):
    """Train a StochasticToucanTTS model on the English Nancy corpus.

    Args:
        gpu_id: ``"cpu"`` selects the CPU; any other value selects the default
            CUDA device. NOTE(review): the id itself is not used here, so the
            concrete GPU is presumably selected by the caller (e.g. via
            ``CUDA_VISIBLE_DEVICES``) — confirm against the launcher.
        resume_checkpoint: Checkpoint path, forwarded to ``train_loop`` as
            ``path_to_checkpoint``.
        finetune: Forwarded to ``train_loop`` as ``fine_tune``.
        model_dir: Output directory. When ``None``, defaults to
            ``MODELS_DIR/StochasticToucanTTS_Nancy``. Created if it does not exist.
        resume: Forwarded unchanged to ``train_loop``.
        use_wandb: If true, a Weights & Biases run is started before training
            and finished after ``train_loop`` returns.
        wandb_resume_id: Id of an existing wandb run to resume (with
            ``resume="must"``). When ``None``, a new run is started and named
            after this module plus the current local timestamp.
    """
    # "cpu" is the only special value; everything else goes to the default CUDA device.
    device = torch.device("cpu") if gpu_id == "cpu" else torch.device("cuda")

    print("Preparing")

    save_dir = model_dir if model_dir is not None else os.path.join(MODELS_DIR, "StochasticToucanTTS_Nancy")
    os.makedirs(save_dir, exist_ok=True)

    train_set = prepare_tts_corpus(transcript_dict=build_path_to_transcript_dict_nancy(),
                                   corpus_dir=os.path.join(PREPROCESSING_DIR, "Nancy"),
                                   lang="eng",
                                   save_imgs=False)

    model = StochasticToucanTTS()

    if use_wandb:
        resuming = wandb_resume_id is not None
        # A resumed run keeps its existing name, so a fresh name is only generated for new runs.
        wandb.init(
            name=None if resuming else f"{__name__.split('.')[-1]}_{time.strftime('%Y%m%d-%H%M%S')}",
            id=wandb_resume_id,
            resume="must" if resuming else None)

    print("Training model")
    train_loop(net=model,
               datasets=[train_set],
               device=device,
               save_directory=save_dir,
               eval_lang="eng",
               path_to_checkpoint=resume_checkpoint,
               fine_tune=finetune,
               resume=resume,
               lr=0.0002,  # it seems the stochastic predictors need a smaller learning rate
               use_wandb=use_wandb)

    # Not placed in a `finally`: finishing here on an exception would report a crashed run as successful.
    if use_wandb:
        wandb.finish()