Spaces:
Runtime error
Runtime error
File size: 4,806 Bytes
4efe6b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os
import sys
import tqdm
import torch
import torch.nn.functional as F
import soundfile as sf
import numpy as np
import time
# Add the current working directory to the import path before importing the
# project-local ``rvc`` package (presumably the script is launched from the
# repository root -- confirm against how it is invoked).
now_dir = os.getcwd()
sys.path.append(now_dir)
from rvc.lib.utils import load_embedding
from rvc.configs.config import Config
# Module-wide configuration instance; ``read_wave`` reads ``config.is_half``
# from it to choose between half and single precision.
config = Config()
def setup_paths(exp_dir: str, version: str):
    """Resolve the input and output directories of an experiment.

    The input directory is ``<exp_dir>/sliced_audios_16k``. The output
    directory is ``<exp_dir>/v1_extracted`` when ``version`` is ``"v1"`` and
    ``<exp_dir>/v2_extracted`` for every other value; it is created if it
    does not exist yet. The input directory is not created or checked.

    Returns:
        A ``(wav_path, out_path)`` tuple of directory paths.
    """
    extracted_dir = "v2_extracted" if version != "v1" else "v1_extracted"
    features_dir = os.path.join(exp_dir, extracted_dir)
    os.makedirs(features_dir, exist_ok=True)

    audio_dir = os.path.join(exp_dir, "sliced_audios_16k")
    return audio_dir, features_dir
def read_wave(wav_path: str, normalize: bool = False):
    """Load a 16 kHz audio file as a ``(1, num_samples)`` tensor.

    The samples are cast to half precision when ``config.is_half`` is set and
    to single precision otherwise. Two-dimensional data (multi-channel audio)
    is averaged over its last axis before being reshaped.

    Args:
        wav_path: Path of the audio file to read.
        normalize: When true, apply layer normalization over the whole
            ``(1, num_samples)`` tensor.

    Returns:
        A tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: If the file's sample rate is not 16000 Hz.
    """
    wav, sr = sf.read(wav_path)
    # Explicit raise instead of ``assert`` so the check is not stripped when
    # Python runs with ``-O``.
    if sr != 16000:
        raise ValueError(f"Sample rate must be 16000, got {sr} ({wav_path})")

    feats = torch.from_numpy(wav)
    feats = feats.half() if config.is_half else feats.float()
    if feats.dim() == 2:
        # Down-mix multi-channel audio by averaging the last axis.
        feats = feats.mean(-1)
    feats = feats.view(1, -1)

    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    return feats
def process_file(
    file: str,
    wav_path: str,
    out_path: str,
    model: torch.nn.Module,
    device: str,
    version: str,
    saved_cfg: "Config",
):
    """Extract embedder features for one audio file and save them as ``.npy``.

    The function returns immediately when the output file already exists, so
    repeated calls for the same file are cheap. Features containing NaN values
    are reported on stdout and not written.

    Args:
        file: File name of the audio file inside ``wav_path``.
        wav_path: Directory holding the input audio files.
        out_path: Directory that receives ``<stem>.npy``.
        model: Embedder exposing ``extract_features`` and, for ``"v1"``,
            ``final_proj``. It is moved to ``device`` on each call.
        device: Torch device string; names starting with ``"cuda"`` run in
            float16, everything else in float32.
        version: ``"v1"`` uses output layer 9 followed by ``final_proj``;
            any other value uses output layer 12 directly.
        saved_cfg: Embedder configuration; ``saved_cfg.task.normalize`` decides
            whether the waveform is layer-normalized.
    """
    wav_file_path = os.path.join(wav_path, file)
    # Swap only the extension. ``file.replace("wav", "npy")`` also rewrote any
    # "wav" inside the stem (e.g. "wave_01.wav" -> "npye_01.npy").
    out_file_path = os.path.join(out_path, os.path.splitext(file)[0] + ".npy")
    if os.path.exists(out_file_path):
        return

    # Load and prepare features
    feats = read_wave(wav_file_path, normalize=saved_cfg.task.normalize)

    # Adjust dtype based on the device
    dtype = torch.float16 if device.startswith("cuda") else torch.float32
    feats = feats.to(dtype).to(device)
    # NOTE(review): the all-False mask is cast to the floating dtype exactly
    # as before; confirm whether the embedder expects a boolean mask instead.
    padding_mask = torch.BoolTensor(feats.shape).fill_(False).to(dtype).to(device)
    inputs = {
        "source": feats,
        "padding_mask": padding_mask,
        "output_layer": 9 if version == "v1" else 12,
    }

    with torch.no_grad():
        model = model.to(device).to(dtype)
        logits = model.extract_features(**inputs)
        feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

    # Always store float32 on the CPU, whatever precision the model ran in.
    feats = feats.squeeze(0).float().cpu().numpy()
    if np.isnan(feats).any():
        print(f"{file} contains NaN values and will be skipped.")
        return
    np.save(out_file_path, feats, allow_pickle=False)
def main():
    """Run embedding extraction for every ``.wav`` file of an experiment.

    Command-line arguments (``sys.argv``):
        1. experiment directory,
        2. model version (``"v1"``; any other value selects the v2 layout),
        3. GPU indices joined by ``-`` (a lone ``-`` selects the CPU),
        4. embedder model name,
        5. optional custom embedder model, passed through to
           ``load_embedding``.

    Exits with status 1 when a required argument is missing or when the input
    directory contains no ``.wav`` files.
    """
    if len(sys.argv) < 5:
        print("Invalid arguments provided.")
        sys.exit(1)
    exp_dir, version, gpus, embedder_model = sys.argv[1:5]
    embedder_model_custom = sys.argv[5] if len(sys.argv) > 5 else None
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")

    wav_path, out_path = setup_paths(exp_dir, version)

    print("Starting feature extraction...")
    start_time = time.time()

    models, saved_cfg, _ = load_embedding(embedder_model, embedder_model_custom)
    model = models[0]

    gpu_fallback_message = (
        "Oops, there was an issue initializing GPU. Maybe you don't have a GPU? "
        "No worries, switching to CPU for now."
    )
    # NOTE(review): CUDA_VISIBLE_DEVICES renumbers the visible GPUs from 0, so
    # comparing the original index with device_count() may send a valid GPU
    # (e.g. "1" alone) to the CPU fallback -- confirm the intended behaviour.
    devices = []
    for gpu in gpus.split("-") if gpus != "-" else ["cpu"]:
        if gpu == "cpu":
            devices.append("cpu")
            continue
        try:
            index = int(gpu)
        except ValueError:
            index = None
        if index is not None and 0 <= index < torch.cuda.device_count():
            devices.append(f"cuda:{index}")
        else:
            # Previously the non-numeric case built this message without
            # printing it; both fallback paths now warn the user.
            print(gpu_fallback_message)
            devices.append("cpu")

    # Only ``.wav`` files are processed, so only they count as input and as
    # progress-bar steps.
    wav_files = [
        name for name in sorted(os.listdir(wav_path)) if name.endswith(".wav")
    ]
    if not wav_files:
        print("No audio files found. Make sure you have provided the audios correctly.")
        sys.exit(1)

    # Each file is offered to every device in order. ``process_file`` returns
    # immediately once the output exists, so later devices only do work when
    # an earlier attempt produced no file.
    with tqdm.tqdm(total=len(wav_files), desc="Embedding Extraction") as pbar:
        for file in wav_files:
            for device in devices:
                try:
                    process_file(
                        file, wav_path, out_path, model, device, version, saved_cfg
                    )
                except Exception as error:
                    print(f"An error occurred processing {file}: {error}")
            pbar.update(1)

    elapsed_time = time.time() - start_time
    print(f"Embedding extraction completed in {elapsed_time:.2f} seconds.")
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|