VoiceCloning-be's picture
new file: .github/FUNDING.yml
4efe6b5
raw
history blame
No virus
2.99 kB
import dataclasses
import pathlib
import libf0
import librosa
import numpy as np
import resampy
import torch
import torchcrepe
import torchfcpe
import os
# from tools.anyf0.rmvpe import RMVPE
from rvc.lib.predictors.RMVPE import RMVPE0Predictor
from rvc.configs.config import Config
# Shared runtime configuration; this file reads its `device` and `is_half` attributes.
config = Config()
@dataclasses.dataclass
class F0Extractor:
    """Estimate a fundamental-frequency (F0) contour for one audio file.

    The audio is loaded with ``librosa.load`` as soon as the instance is
    created.  ``extract_f0`` then runs the backend selected by ``method``
    (``"crepe"``, ``"fcpe"`` or ``"rmvpe"``) and returns the contour converted
    to cents relative to the frequency of MIDI note 0.
    """

    # Path of the audio file handed to ``librosa.load``.
    wav_path: pathlib.Path
    # Rate requested from ``librosa.load``; overwritten with the rate it returns.
    sample_rate: int = 44100
    # Hop used by ``hop_size`` and to derive the FCPE output length.
    hop_length: int = 512
    # Pitch bounds forwarded to the CREPE and FCPE backends (presumably Hz);
    # the RMVPE call does not receive them.
    f0_min: int = 50
    f0_max: int = 1600
    # Backend name: "crepe", "fcpe" or "rmvpe".
    method: str = "rmvpe"
    # Waveform returned by ``librosa.load``; filled in by ``__post_init__``.
    x: np.ndarray = dataclasses.field(init=False)

    def __post_init__(self):
        """Load the audio and record the sample rate reported by the loader."""
        waveform, loaded_rate = librosa.load(self.wav_path, sr=self.sample_rate)
        self.x = waveform
        self.sample_rate = loaded_rate

    @property
    def hop_size(self) -> float:
        """Return ``hop_length / sample_rate``."""
        return self.hop_length / self.sample_rate

    @property
    def wav16k(self) -> np.ndarray:
        """Return ``x`` resampled from ``sample_rate`` to 16000 with resampy.

        The resampling is redone on every access; nothing is cached.
        """
        return resampy.resample(self.x, self.sample_rate, 16000)

    def extract_f0(self) -> np.ndarray:
        """Run the selected backend and return its contour in cents.

        Returns:
            The backend output passed through ``libf0.hz_to_cents`` with
            ``librosa.midi_to_hz(0)`` as the reference frequency.

        Raises:
            ValueError: If ``method`` is not "crepe", "fcpe" or "rmvpe".
        """
        backend = self.method
        if backend == "crepe":
            f0_hz = self._extract_crepe()
        elif backend == "fcpe":
            f0_hz = self._extract_fcpe()
        elif backend == "rmvpe":
            f0_hz = self._extract_rmvpe()
        else:
            raise ValueError(f"Unknown method: {self.method}")
        # NOTE(review): any frame a backend reports as 0 would presumably turn
        # into -inf in the log-based conversion -- confirm callers expect that.
        return libf0.hz_to_cents(f0_hz, librosa.midi_to_hz(0))

    def _extract_crepe(self) -> np.ndarray:
        """Run torchcrepe on the 16 kHz signal and return its first batch row."""
        signal = torch.FloatTensor(self.wav16k).unsqueeze(0)
        signal = signal.to(config.device)
        pitch = torchcrepe.predict(
            signal,
            sample_rate=16000,
            hop_length=160,  # fixed hop: 160 samples at 16000 samples/s = 10 ms
            batch_size=512,
            fmin=self.f0_min,
            fmax=self.f0_max,
            device=config.device,
        )
        return pitch[0].cpu().numpy()

    def _extract_fcpe(self) -> np.ndarray:
        """Run the bundled torchfcpe model on the audio at ``sample_rate``."""
        mono = librosa.to_mono(self.x)
        # One output frame per hop, plus one.
        target_frames = (len(mono) // self.hop_length) + 1
        # Add leading and trailing singleton axes before moving to the device.
        batch = torch.from_numpy(mono).float().unsqueeze(0).unsqueeze(-1)
        batch = batch.to(config.device)
        model = torchfcpe.spawn_bundled_infer_model(device=config.device)
        pitch = model.infer(
            batch,
            sr=self.sample_rate,
            decoder_mode="local_argmax",
            threshold=0.006,
            f0_min=self.f0_min,
            f0_max=self.f0_max,
            interp_uv=False,
            output_interp_target_length=target_frames,
        )
        return pitch.squeeze().cpu().numpy()

    def _extract_rmvpe(self) -> np.ndarray:
        """Run the RMVPE predictor on the 16 kHz signal."""
        # The checkpoint path is relative, so it resolves against the current
        # working directory.
        weights_path = os.path.join("rvc", "models", "predictors", "rmvpe.pt")
        predictor = RMVPE0Predictor(
            weights_path,
            is_half=config.is_half,
            device=config.device,
        )
        return predictor.infer_from_audio(self.wav16k, thred=0.03)

    def plot_f0(self, f0):
        """Show ``f0`` as a line plot titled with the current method name.

        matplotlib is imported here so it is only needed when plotting.
        """
        from matplotlib import pyplot as plt

        plt.figure(figsize=(10, 4))
        plt.plot(f0)
        plt.title(self.method)
        plt.xlabel("Time (frames)")
        plt.ylabel("F0 (cents)")
        plt.show()