Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,555 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import shutil
import warnings
import argparse
import torch
import os
import yaml
warnings.simplefilter("ignore")
from .modules.commons import *
import time
import torchaudio
import librosa
from collections import OrderedDict
class FAcodecInference(object):
    """Run a trained FAcodec model for reconstruction and voice conversion.

    On construction the model is built from ``cfg.model_params``, moved to
    CUDA when available (CPU otherwise), populated from
    ``args.checkpoint_path`` and switched to eval mode.

    Args:
        args: Object exposing a ``checkpoint_path`` attribute.
        cfg: Config object exposing ``model_params`` (including
            ``n_c_codebooks``) and ``preprocess_params.sr``.
    """

    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._build_model()
        self._load_checkpoint()

    @staticmethod
    def _file_stem(path):
        """Return the file name of *path* without directories or extension.

        Only the final extension is removed, so ``"a/speaker.01.wav"`` yields
        ``"speaker.01"``. Accepts ``str`` and ``os.PathLike`` paths.
        """
        return os.path.splitext(os.path.basename(path))[0]

    def _load_audio(self, path):
        """Load *path* at the configured sample rate as a model-ready tensor.

        The returned float tensor lives on ``self.device`` and has two leading
        singleton dimensions added in front of the loaded samples
        (presumably ``(batch=1, channel=1, samples)`` -- TODO confirm).
        """
        samples = librosa.load(path, sr=self.cfg.preprocess_params.sr)[0]
        wave = torch.tensor(samples).unsqueeze(0).float().to(self.device)
        return wave[None, ...]

    def _build_model(self):
        """Build the sub-models from ``cfg.model_params`` and move them to the device."""
        model = build_model(self.cfg.model_params)
        for key in model:
            model[key].to(self.device)
        return model

    def _load_checkpoint(self):
        """Load weights from ``args.checkpoint_path`` and put the model in eval mode.

        The checkpoint may either be the per-sub-model mapping itself or wrap
        it under a ``"net"`` key. Entries whose key is not part of
        ``self.model`` are ignored.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.args.checkpoint_path, map_location="cpu")
        checkpoint = checkpoint["net"] if "net" in checkpoint else checkpoint

        # Prefix presumably left behind by DataParallel/DDP wrapping at
        # training time; it has to go for the keys to match the bare modules.
        prefix = "module."
        for key, state_dict in checkpoint.items():
            if key not in self.model:
                continue
            cleaned = OrderedDict(
                (name[len(prefix):] if name.startswith(prefix) else name, value)
                for name, value in state_dict.items()
            )
            self.model[key].load_state_dict(cleaned)

        for key in self.model:
            self.model[key].eval()

    @torch.no_grad()
    def inference(self, source, output_dir):
        """Encode, quantize and decode *source*, saving the reconstruction.

        The result is written to
        ``<output_dir>/reconstructed_<source file stem>.wav`` (the directory
        is created if needed) and the path is printed.

        Args:
            source: Path of the audio file to reconstruct.
            output_dir: Directory that receives the reconstructed wav file.

        Returns:
            Tuple ``(quantized, codes)`` as produced by the quantizer.
        """
        wave = self._load_audio(source)
        z = self.model.encoder(wave)
        (
            z,
            quantized,
            _commitment_loss,
            _codebook_loss,
            _timbre,
            codes,
        ) = self.model.quantizer(
            z,
            wave,
            n_c=self.cfg.model_params.n_c_codebooks,
            return_codes=True,
        )
        full_pred_wave = self.model.decoder(z)

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(
            output_dir, f"reconstructed_{self._file_stem(source)}.wav"
        )
        torchaudio.save(
            output_path,
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )
        print("Reconstructed audio saved as: ", output_path)
        return quantized, codes

    @torch.no_grad()
    def voice_conversion(self, source, reference, output_dir):
        """Convert *source* using *reference* and save the result.

        The result is written to
        ``<output_dir>/converted_<source stem>_to_<reference stem>.wav`` (the
        directory is created if needed) and the path is printed.

        Args:
            source: Path of the audio file to convert.
            reference: Path of the audio file passed to the quantizer's
                ``voice_conversion`` as the reference.
            output_dir: Directory that receives the converted wav file.

        Returns:
            None.
        """
        n_codebooks = self.cfg.model_params.n_c_codebooks
        source_wave = self._load_audio(source)
        reference_wave = self._load_audio(reference)

        z = self.model.encoder(source_wave)
        _z, quantized, _commitment_loss, _codebook_loss, _timbre = (
            self.model.quantizer(z, source_wave, n_c=n_codebooks)
        )

        # NOTE(review): nothing below consumes the outputs of this reference
        # pass. The calls are kept because dropping them is only safe if
        # these forwards are side-effect free -- confirm, then remove.
        z_ref = self.model.encoder(reference_wave)
        self.model.quantizer(z_ref, reference_wave, n_c=n_codebooks)

        # Only the first two quantized streams are fed to the conversion.
        z_conv = self.model.quantizer.voice_conversion(
            quantized[0] + quantized[1],
            reference_wave,
        )
        full_pred_wave = self.model.decoder(z_conv)

        os.makedirs(output_dir, exist_ok=True)
        output_name = (
            f"converted_{self._file_stem(source)}"
            f"_to_{self._file_stem(reference)}.wav"
        )
        output_path = os.path.join(output_dir, output_name)
        torchaudio.save(
            output_path,
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )
        print("Voice conversion results saved as: ", output_path)
|