File size: 4,806 Bytes
4efe6b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import sys
import tqdm
import torch
import torch.nn.functional as F
import soundfile as sf
import numpy as np
import time

# Put the current working directory on the import path so the project-local
# `rvc` package imported below can be resolved (presumably the script is
# launched from the repository root — confirm against the launcher).
now_dir = os.getcwd()
sys.path.append(now_dir)

from rvc.lib.utils import load_embedding
from rvc.configs.config import Config

# Module-level configuration, created at import time; `read_wave` reads
# `config.is_half` to decide the dtype of the loaded waveform tensor.
config = Config()


def setup_paths(exp_dir: str, version: str):
    """Resolve the input and output directories for an experiment.

    Args:
        exp_dir: Experiment directory that holds ``sliced_audios_16k``.
        version: Model version. ``"v1"`` selects ``v1_extracted``; any other
            value selects ``v2_extracted``.

    Returns:
        A ``(wav_path, out_path)`` tuple. ``out_path`` is created on disk if
        it does not exist yet; ``wav_path`` is only joined, never created.
    """
    subdir = "v1_extracted"
    if version != "v1":
        subdir = "v2_extracted"
    extracted_dir = os.path.join(exp_dir, subdir)
    os.makedirs(extracted_dir, exist_ok=True)
    return os.path.join(exp_dir, "sliced_audios_16k"), extracted_dir


def read_wave(wav_path: str, normalize: bool = False):
    """Load a 16 kHz audio file as a ``(1, num_samples)`` tensor.

    Args:
        wav_path: Path to an audio file readable by ``soundfile``.
        normalize: If True, apply ``F.layer_norm`` over the whole waveform
            (no affine weight/bias is passed).

    Returns:
        Tensor of shape ``(1, num_samples)``. Multi-channel input is averaged
        to mono. The dtype is float16 when ``config.is_half`` is set, and
        float32 otherwise.

    Raises:
        ValueError: If the file's sample rate is not 16000 Hz.
    """
    wav, sr = sf.read(wav_path)
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if sr != 16000:
        raise ValueError(f"Sample rate must be 16000, got {sr} for {wav_path}")

    # Downmix and normalize in float32 and cast to half only at the end. This
    # keeps the reductions accurate and avoids float16 math on CPU tensors.
    feats = torch.from_numpy(wav).float()
    # soundfile returns (frames, channels) for multi-channel files.
    if feats.dim() == 2:
        feats = feats.mean(-1)
    feats = feats.view(1, -1)

    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)

    return feats.half() if config.is_half else feats


def process_file(
    file: str,
    wav_path: str,
    out_path: str,
    model: torch.nn.Module,
    device: str,
    version: str,
    saved_cfg: "Config",
):
    """Extract embeddings for one audio file and save them as ``.npy``.

    The file is skipped if its output already exists, so an interrupted run
    can be resumed. Features that contain NaN are not written.

    Args:
        file: File name (not a path) inside ``wav_path``.
        wav_path: Directory containing the 16 kHz input audio.
        out_path: Directory that receives ``<stem>.npy`` feature files.
        model: Embedder returned by ``load_embedding``. It must provide
            ``extract_features``, and ``final_proj`` for v1.
        device: Torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.
        version: ``"v1"`` takes layer 9 projected through ``final_proj``;
            any other value takes the raw layer-12 output.
        saved_cfg: Embedder config; ``saved_cfg.task.normalize`` controls
            waveform normalization in ``read_wave``.
    """
    wav_file_path = os.path.join(wav_path, file)
    # Swap only the extension. `str.replace("wav", "npy")` would also rewrite
    # "wav" inside the stem (e.g. "wavy.wav" -> "npyy.npy").
    stem, _ = os.path.splitext(file)
    out_file_path = os.path.join(out_path, stem + ".npy")

    if os.path.exists(out_file_path):
        return

    feats = read_wave(wav_file_path, normalize=saved_cfg.task.normalize)

    # float16 on CUDA devices, float32 everywhere else.
    dtype = torch.float16 if device.startswith("cuda") else torch.float32
    feats = feats.to(dtype).to(device)

    # A single unbatched clip has no padding: an all-False boolean mask,
    # created initialized instead of via an uninitialized BoolTensor.
    padding_mask = torch.zeros(feats.shape, dtype=torch.bool, device=device)

    with torch.no_grad():
        # `.to` is a no-op when the model is already on this device/dtype.
        model = model.to(device).to(dtype)
        logits = model.extract_features(
            source=feats,
            padding_mask=padding_mask,
            output_layer=9 if version == "v1" else 12,
        )
        feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

    feats = feats.squeeze(0).float().cpu().numpy()
    if np.isnan(feats).any():
        print(f"{file} contains NaN values and will be skipped.")
        return
    np.save(out_file_path, feats, allow_pickle=False)


def _resolve_devices(gpus: str):
    """Map a dash-separated GPU spec (e.g. ``"0-1"``) to torch device strings.

    ``"-"`` means CPU only. An entry that is not an integer, or that indexes
    past ``torch.cuda.device_count()``, falls back to ``"cpu"`` and prints a
    notice. Duplicate devices are dropped; first-seen order is kept.
    """
    fallback_msg = (
        "Oops, there was an issue initializing GPU. Maybe you don't have a GPU? "
        "No worries, switching to CPU for now."
    )
    if gpus == "-":
        return ["cpu"]

    devices = []
    for gpu in gpus.split("-"):
        if gpu == "cpu":
            devices.append("cpu")
            continue
        try:
            index = int(gpu)
        except ValueError:
            index = None
        if index is not None and index < torch.cuda.device_count():
            devices.append(f"cuda:{index}")
        else:
            print(fallback_msg)
            devices.append("cpu")
    return list(dict.fromkeys(devices))


def main():
    """Run embedding extraction for every ``.wav`` in the experiment folder.

    Command line::

        exp_dir version gpus embedder_model [embedder_model_custom]

    ``gpus`` is a dash-separated list of GPU indices; ``"-"`` selects CPU.
    Exits with status 1 on missing arguments or when no ``.wav`` files are
    found.
    """
    if len(sys.argv) < 5:
        print("Invalid arguments provided.")
        sys.exit(1)
    exp_dir, version, gpus, embedder_model = sys.argv[1:5]
    embedder_model_custom = sys.argv[5] if len(sys.argv) > 5 else None

    # NOTE(review): this only restricts visible GPUs if CUDA has not been
    # initialized yet in this process (Config() already runs at import time);
    # if it does take effect, visible devices are renumbered from 0 — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")

    wav_path, out_path = setup_paths(exp_dir, version)

    print("Starting feature extraction...")
    start_time = time.time()

    models, saved_cfg, _task = load_embedding(embedder_model, embedder_model_custom)
    model = models[0]

    devices = _resolve_devices(gpus)

    wav_files = [
        name for name in sorted(os.listdir(wav_path)) if name.endswith(".wav")
    ]
    if not wav_files:
        print("No audio files found. Make sure you have provided the audios correctly.")
        sys.exit(1)

    # Every file is processed exactly once. Files are dealt round-robin across
    # the devices, and each device's share is handled in a single run, so the
    # model is moved between devices at most once per device.
    pbar = tqdm.tqdm(total=len(wav_files), desc="Embedding Extraction")
    for slot, device in enumerate(devices):
        for file in wav_files[slot :: len(devices)]:
            try:
                process_file(
                    file, wav_path, out_path, model, device, version, saved_cfg
                )
            except Exception as error:
                # Best-effort: report and continue with the remaining files.
                print(f"An error occurred processing {file}: {error}")
            pbar.update(1)
    pbar.close()

    elapsed_time = time.time() - start_time
    print(f"Embedding extraction completed in {elapsed_time:.2f} seconds.")


# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()