# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
from torch.utils.data import DataLoader

from models.svc.base import SVCInference
from models.svc.vits.vits import SynthesizerTrn

from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator
from utils.io import save_audio
from utils.audio_slicer import is_silence


class VitsInference(SVCInference):
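    """SVC inference wrapper that decodes acoustic features to waveform with the VITS SynthesizerTrn model."""
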
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        # Forward infer_type instead of dropping it; _build_dataloader() reads self.infer_type,
        # which the base class sets from this argument.
        SVCInference.__init__(self, args, cfg, infer_type)

    def _build_model(self):
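        # SynthesizerTrn takes the number of spectrogram bins (n_fft // 2 + 1)
        # and the training segment length expressed in frames (samples // hop_size).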
        net_g = SynthesizerTrn(
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            self.cfg,
        )
        self.model = net_g
        return net_g

    def build_save_dir(self, dataset, speaker):
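        # Resulting layout: <output_dir>/svc_am_step-<step>_<mode>[/data_<dataset>][/spk_<speaker>]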
        save_dir = os.path.join(
            self.args.output_dir,
            "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir

    def _build_dataloader(self):
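        # Cap the batch size at the number of test utterances so it never
        # exceeds the dataset length.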
        datasets, collate = self._build_test_dataset()
        self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
        self.test_collate = collate(self.cfg)
        self.test_batch_size = min(
            self.cfg.inference.batch_size, len(self.test_dataset.metadata)
        )
        test_dataloader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_dataloader

    @torch.inference_mode()
    def inference(self):
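        """Run the model over the test dataloader and save one wav per utterance, named by its Uid."""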
        res = []
        for i, batch in enumerate(self.test_dataloader):
            pred_audio_list = self._inference_each_batch(batch)
            for j, wav in enumerate(pred_audio_list):
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                file = os.path.join(self.args.output_dir, f"{uid}.wav")
                print(f"Saving {file}")

                wav = wav.numpy(force=True)
                save_audio(
                    file,
                    wav,
                    self.cfg.preprocess.sample_rate,
                    add_silence=False,
                    turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
                )
                res.append(file)
        return res

    def _inference_each_batch(self, batch_data, noise_scale=0.667):
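        # noise_scale scales the prior noise used by VITS at inference time;
        # 0.667 is the value commonly used for VITS-style models.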
        device = self.accelerator.device
        pred_res = []
        self.model.eval()
        with torch.no_grad():
            # Move every tensor in the batch onto the inference device
            for k, v in batch_data.items():
                batch_data[k] = v.to(device)

            audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)

            pred_res.extend(audios)

        return pred_res