Spaces:
Runtime error
Runtime error
File size: 1,312 Bytes
890de26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# -*- coding:utf-8 -*-
# @FileName :svInfer.py
# @Time :2023/8/12 16:13
# @Author :lovemefan
# @Email :[email protected]
import os
from pathlib import Path
from typing import Union
import numpy as np
from paraformer.runtime.python.model.sv.campplus import Campplus
from paraformer.runtime.python.model.sv.eres2net import Eres2net
# Registry of supported speaker-verification backends.
# Each key is the user-facing model name; each value is a
# (wrapper class, default ONNX file name) pair. The default file is
# resolved relative to the project's onnx/sv directory by the caller.
model_names = {
    alias: (backend_cls, onnx_file)
    for alias, backend_cls, onnx_file in (
        ("cam++", Campplus, "campplus.onnx"),
        ("eres2net", Eres2net, "eres2net-aug-sv.onnx"),
        ("eres2net-quant", Eres2net, "eres2net-aug-sv-quant.onnx"),
    )
}
class SpeakerVerificationInfer:
    """Select a speaker-verification backend by name and delegate to it.

    The backend class and its default ONNX file name are looked up in the
    module-level ``model_names`` registry.
    """

    def __init__(self, model_path=None, model_name="cam++", threshold=0.5):
        """Construct the selected backend.

        Args:
            model_path: Path to the ONNX model file. When falsy (``None`` or
                an empty string), the registry's default file name is joined
                onto ``<project_dir>/onnx/sv``. Here ``project_dir`` is three
                ``os.path.dirname`` levels above this source file.
            model_name: Key into ``model_names`` (e.g. ``"cam++"``,
                ``"eres2net"``, ``"eres2net-quant"``).
            threshold: Decision threshold. It is stored on the instance,
                passed to the backend constructor, and passed again on every
                ``recognize`` call.

        Raises:
            ValueError: If ``model_name`` is not a registered model name.
        """
        if model_name not in model_names:
            # list(...) renders as "['cam++', ...]"; formatting the dict_keys
            # view directly would show "dict_keys([...])" to the user.
            raise ValueError(
                f"model name {model_name} not in {list(model_names)}"
            )
        model_cls, default_file = model_names[model_name]

        if not model_path:
            project_dir = os.path.dirname(
                os.path.dirname(os.path.dirname(__file__))
            )
            model_path = os.path.join(project_dir, "onnx", "sv", default_file)

        self.threshold = threshold
        self.model = model_cls(model_path, threshold)

    def register_speaker(self, emb: np.ndarray):
        """Forward ``emb`` to the backend's ``register_speaker``.

        NOTE(review): ``emb`` is presumably a speaker embedding. Its expected
        shape and dtype are defined by the backend; confirm there.
        """
        self.model.register_speaker(emb)

    def recognize(self, waveform: Union[str, Path, bytes]):
        """Run the backend's ``recognize`` on ``waveform`` using ``self.threshold``.

        Returns whatever the backend returns, unchanged. Accepted waveform
        forms (path vs. raw bytes) are handled by the backend.
        """
        return self.model.recognize(waveform, self.threshold)
|