# -*- coding:utf-8 -*-
# @FileName :paraformer.py
# @Time :2023/8/8 20:04
# @Author :lovemefan
# @Email :[email protected]
import glob
import os
import pickle
from pathlib import Path
from typing import List, Tuple, Union
import numpy as np
from paraformer.runtime.python.model.lm.transformer_lm import TransformerLM
from paraformer.runtime.python.utils.asrOrtInferRuntimeSession import (
AsrOfflineOrtInferRuntimeSession,
AsrOnlineDecoderOrtInferRuntimeSession,
AsrOnlineEncoderOrtInferRuntimeSession,
CharTokenizer,
Hypothesis,
TokenIDConverter,
)
from paraformer.runtime.python.utils.audioHelper import AudioReader
from paraformer.runtime.python.utils.logger import logger
from paraformer.runtime.python.utils.postprocess import sentence_postprocess
from paraformer.runtime.python.utils.preprocess import (
SinusoidalPositionEncoderOnline,
WavFrontend,
WavFrontendOnline,
)
from paraformer.runtime.python.utils.singleton import get_all_instance, singleton


class ParaformerOnlineModel:
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
chunk_size: List = [5, 10, 5],
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4,
):
logger.info(f"init online context for client")
config_file = os.path.join(model_dir, "config.pkl")
with open(config_file, "rb") as file:
config = pickle.load(file)
cmvn_file = os.path.join(model_dir, "am.mvn")
self.converter = TokenIDConverter(config["token_list"])
self.tokenizer = CharTokenizer()
self.frontend = WavFrontendOnline(
cmvn_file=cmvn_file, **config["frontend_conf"]
)
if (
"AsrOnlineEncoderOrtInferRuntimeSession" not in get_all_instance()
and "AsrOnlineDecoderOrtInferRuntimeSession" not in get_all_instance()
):
if not Path(model_dir).exists():
                raise FileNotFoundError(f"{model_dir} does not exist")
encoder_model_file = os.path.join(model_dir, "model.onnx")
decoder_model_file = os.path.join(model_dir, "decoder.onnx")
if quantize:
encoder_model_file = glob.glob(
os.path.join(model_dir, "model_quant_*.onnx")
)
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
self.pe = SinusoidalPositionEncoderOnline()
self.ort_encoder_infer = AsrOnlineEncoderOrtInferRuntimeSession(
encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.ort_decoder_infer = AsrOnlineDecoderOrtInferRuntimeSession(
decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
else:
self.pe = SinusoidalPositionEncoderOnline()
self.ort_encoder_infer = get_all_instance().get(
"AsrOnlineEncoderOrtInferRuntimeSession"
)
self.ort_decoder_infer = get_all_instance().get(
"AsrOnlineDecoderOrtInferRuntimeSession"
)
self.batch_size = batch_size
self.chunk_size = chunk_size
self.encoder_output_size = config["encoder_conf"]["output_size"]
self.fsmn_layer = config["decoder_conf"]["num_blocks"]
self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
self.fsmn_dims = config["encoder_conf"]["output_size"]
self.feats_dims = (
config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
)
self.cif_threshold = config["predictor_conf"]["threshold"]
self.tail_threshold = config["predictor_conf"]["tail_threshold"]
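
    # Streaming geometry, as used by `prepare_cache` / `add_overlap_chunk` below:
    # `chunk_size[1]` is the number of new LFR feature frames consumed per call, while
    # the last `chunk_size[0] + chunk_size[2]` frames of the previous window are kept
    # in `cache["feats"]` and prepended as context, so each encoder pass sees
    # `sum(chunk_size)` frames. Likewise, `fsmn_lorder = kernel_size - 1` frames of
    # decoder FSMN memory are carried between chunks.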
    def prepare_cache(self, cache: dict = None, batch_size=1):
        # use None instead of a mutable default dict so state is not shared across calls
        cache = {} if cache is None else cache
if len(cache) > 0:
return cache
cache["start_idx"] = 0
cache["cif_hidden"] = np.zeros(
(batch_size, 1, self.encoder_output_size)
).astype(np.float32)
cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
cache["chunk_size"] = self.chunk_size
cache["last_chunk"] = False
cache["feats"] = np.zeros(
(batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
).astype(np.float32)
cache["decoder_fsmn"] = []
for i in range(self.fsmn_layer):
fsmn_cache = np.zeros(
(batch_size, self.fsmn_dims, self.fsmn_lorder)
).astype(np.float32)
cache["decoder_fsmn"].append(fsmn_cache)
return cache
    def add_overlap_chunk(self, feats: np.ndarray, cache: dict = None):
        # use None instead of a mutable default dict so state is not shared across calls
        cache = {} if cache is None else cache
if len(cache) == 0:
return feats
# process last chunk
overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
if cache["is_final"]:
cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
if not cache["last_chunk"]:
padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
overlap_feats = np.pad(
overlap_feats, ((0, 0), (0, padding_length), (0, 0))
)
else:
cache["feats"] = overlap_feats[
:, -(self.chunk_size[0] + self.chunk_size[2]) :, :
]
return overlap_feats
def __call__(self, audio_in: np.ndarray, **kwargs):
waveforms = np.expand_dims(audio_in, axis=0)
param_dict = kwargs.get("param_dict", dict())
is_final = param_dict.get("is_final", False)
cache = param_dict.get("cache", dict())
asr_res = []
        # treat a final call with fewer than 960 samples (60 ms at 16 kHz) as a
        # tail-only chunk: skip feature extraction and flush the cached features
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
cache["last_chunk"] = True
feats = cache["feats"]
feats_len = np.array([feats.shape[1]]).astype(np.int32)
asr_res = self.infer(feats, feats_len, cache)
return asr_res
feats, feats_len = self.extract_feat(waveforms, is_final)
if feats.shape[1] != 0:
feats *= self.encoder_output_size**0.5
cache = self.prepare_cache(cache)
cache["is_final"] = is_final
# fbank -> position encoding -> overlap chunk
feats = self.pe.forward(feats, cache["start_idx"])
cache["start_idx"] += feats.shape[1]
if is_final:
if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
cache["last_chunk"] = True
feats = self.add_overlap_chunk(feats, cache)
else:
# first chunk
feats_chunk1 = self.add_overlap_chunk(
feats[:, : self.chunk_size[1], :], cache
)
feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
# last chunk
cache["last_chunk"] = True
feats_chunk2 = self.add_overlap_chunk(
feats[
:,
-(
feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]
) :,
:,
],
cache,
)
feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
res = {}
for pred in asr_res_chunk:
for key, value in pred.items():
if key in res:
res[key][0] += value[0]
res[key][1].extend(value[1])
else:
res[key] = [value[0], value[1]]
return [res]
else:
feats = self.add_overlap_chunk(feats, cache)
feats_len = np.array([feats.shape[1]]).astype(np.int32)
asr_res = self.infer(feats, feats_len, cache)
return asr_res
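
    # Expected calling convention for `__call__` (a sketch; `waveform` and the 600 ms
    # step size are illustrative assumptions, not defined in this file):
    #
    #   model = ParaformerOnlineModel("/path/to/online/model")  # placeholder path
    #   cache, step = {}, 9600  # 9600 samples = 600 ms at 16 kHz
    #   for start in range(0, len(waveform), step):
    #       is_final = start + step >= len(waveform)
    #       res = model(waveform[start:start + step],
    #                   param_dict={"cache": cache, "is_final": is_final})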
def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
# encoder forward
enc_input = [feats, feats_len]
enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)
# predictor forward
acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)
# decoder forward
asr_res = []
if acoustic_embeds.shape[1] > 0:
dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
dec_input.extend(cache["decoder_fsmn"])
dec_output = self.ort_decoder_infer(dec_input)
logits, _, cache["decoder_fsmn"] = (
dec_output[0],
dec_output[1],
dec_output[2:],
)
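            # keep only the trailing `fsmn_lorder` frames of each layer's FSMN memory
            # so the decoder cache stays a fixed size from chunk to chunk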
cache["decoder_fsmn"] = [
item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
]
preds = self.decode(logits, acoustic_embeds_len)
for pred in preds:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
return asr_res
def load_data(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = AudioReader.read_wav_file(path)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(
self, waveforms: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
for idx, waveform in enumerate(waveforms):
waveforms_lens[idx] = waveform.shape[-1]
feats, feats_len = self.frontend.extract_fbank(
waveforms, waveforms_lens, is_final
)
return feats.astype(np.float32), feats_len.astype(np.int32)
    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray):
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
]
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = token[:valid_token_num]
# texts = sentence_postprocess(token)
return token
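
    # `cif_search` implements Continuous Integrate-and-Fire (CIF): per frame the
    # predictor weight `alpha` is accumulated; once the running sum crosses
    # `cif_threshold` a token "fires" and the alpha-weighted sum of encoder frames is
    # emitted as one acoustic embedding, with the overflow carried into the next token.
    # Worked example with cif_threshold = 1.0 and alphas = [0.4, 0.5, 0.3]: after the
    # third frame the sum is 1.2, so one embedding 0.4*h0 + 0.5*h1 + 0.1*h2 fires and
    # 0.2*h2 starts accumulating the next one.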
def cif_search(self, hidden, alphas, cache=None):
batch_size, len_time, hidden_size = hidden.shape
token_length = []
list_fires = []
list_frames = []
cache_alphas = []
cache_hiddens = []
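        # zero the predictor weights of the left/right context frames so that only the
        # chunk_size[1] "new" frames in the middle of the window can fire tokens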
alphas[:, : self.chunk_size[0]] = 0.0
alphas[:, sum(self.chunk_size[:2]) :] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
tail_alphas = np.tile(tail_alphas, (batch_size, 1))
hidden = np.concatenate((hidden, tail_hidden), axis=1)
alphas = np.concatenate((alphas, tail_alphas), axis=1)
len_time = alphas.shape[1]
for b in range(batch_size):
integrate = 0.0
frames = np.zeros(hidden_size).astype(np.float32)
list_frame = []
list_fire = []
for t in range(len_time):
alpha = alphas[b][t]
if alpha + integrate < self.cif_threshold:
integrate += alpha
list_fire.append(integrate)
frames += alpha * hidden[b][t]
else:
frames += (self.cif_threshold - integrate) * hidden[b][t]
list_frame.append(frames)
integrate += alpha
list_fire.append(integrate)
integrate -= self.cif_threshold
frames = integrate * hidden[b][t]
cache_alphas.append(integrate)
if integrate > 0.0:
cache_hiddens.append(frames / integrate)
else:
cache_hiddens.append(frames)
token_length.append(len(list_frame))
list_fires.append(list_fire)
list_frames.append(list_frame)
max_token_len = max(token_length)
list_ls = []
for b in range(batch_size):
pad_frames = np.zeros(
(max_token_len - token_length[b], hidden_size)
).astype(np.float32)
if token_length[b] == 0:
list_ls.append(pad_frames)
else:
list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
return np.stack(list_ls, axis=0).astype(np.float32), np.stack(
token_length, axis=0
).astype(np.int32)


@singleton
class ParaformerOfflineModel:
def __init__(
self, model_dir: str = None, use_lm=False, intra_op_num_threads=4
) -> None:
config_path = os.path.join(model_dir, "config.pkl")
with open(config_path, "rb") as file:
config = pickle.load(file)
self.use_lm = use_lm
self.converter = TokenIDConverter(config["token_list"])
self.tokenizer = CharTokenizer(**config["CharTokenizer"])
self.frontend = WavFrontend(
cmvn_file=os.path.join(model_dir, "am.mvn"), **config["frontend_conf"]
)
if os.path.exists(os.path.join(model_dir, "model_quant.onnx")):
model_file = os.path.join(model_dir, "model_quant.onnx")
else:
model_file = glob.glob(os.path.join(model_dir, "model_quant_*.onnx"))
contextual_model = os.path.join(model_dir, "model_eb.onnx")
if use_lm:
lm_model_path = os.path.join(model_dir, "lm")
self.lm = TransformerLM(lm_model_path, intra_op_num_threads)
self.ort_infer = AsrOfflineOrtInferRuntimeSession(
model_file,
contextual_model=contextual_model,
intra_op_num_threads=intra_op_num_threads,
)
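
    # `model_eb.onnx` is the hot-word (contextual biasing) embedding model: `infer`
    # feeds hot-word token ids through it to obtain `bias_embed`, which is passed to
    # the main model alongside the acoustic features.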
def extract_feat(self, waveforms: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
fbank, fbank_len = self.frontend.fbank(waveforms)
feats, feats_len = self.frontend.lfr_cmvn(fbank)
return feats.astype(np.float32), feats_len.astype(np.int32)
def decoder_with_greedy_search(self, am_score):
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
texts = sentence_postprocess(token)
return texts
def search(self, beams, am_score: np.ndarray, beam_size=5, lm_weight=0.25):
"""Search new tokens for running hypotheses and encoded speech x.
Args:
beams (List[Hypothesis]): Running hypotheses on beam
            am_score (np.ndarray): decoded output (L, vocab_size)
beam_size: beam size
lm_weight: the weight of lm
"""
best_hyps = []
n_vocab = len(self.converter.token_list)
part_ids = np.arange(n_vocab) # no pre-beam
for hyp in beams:
# scoring
weighted_scores = np.zeros(n_vocab)
weighted_scores += am_score
if self.use_lm:
lm_score = self.lm.lm(hyp.yseq[:, -20:])
weighted_scores += lm_weight * lm_score[0][0]
# add previous hyp score
weighted_scores += hyp.score
# update hyps
for j in np.argpartition(weighted_scores, -beam_size)[-beam_size:]:
# will be (2 x beam at most)
best_hyps.append(
Hypothesis(
score=weighted_scores[j],
yseq=np.concatenate(
(hyp.yseq[0], np.array([j], dtype=np.int64))
)[None, ...],
)
)
# sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
: min(len(best_hyps), beam_size)
]
return best_hyps
def decoder_with_beam_search(self, am_scores, beam_size=5, lm_weight=0.15):
        # run the beam search frame by frame over the decoder output
beams = [
Hypothesis(
score=0,
yseq=np.array([[1]], dtype=np.int64),
)
]
for score in am_scores:
beams = self.search(beams, score, beam_size=beam_size, lm_weight=lm_weight)
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), beams[0].yseq.tolist()[0]))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
texts = sentence_postprocess(token)
return texts
def infer(
self,
audio: Union[str, np.ndarray, bytes],
hot_words: str = None,
beam_search=False,
beam_size=5,
lm_weight=0.15,
):
if isinstance(audio, str):
audio, _ = AudioReader.read_wav_file(audio)
elif isinstance(audio, bytes):
audio, _ = AudioReader.read_wav_bytes(audio)
feats, feats_len = self.extract_feat(audio)
feats = feats[None, ...]
feats_len = feats_len[None, ...]
hot_words, hot_words_length = self.proc_hot_words(hot_words)
input_dict = dict(
zip(self.ort_infer.get_contextual_model_input_names(), (hot_words,))
)
[bias_embed] = self.ort_infer.contextual_model.run(None, input_dict)
        # for each hot word, take the embedding at its last character position
        # (hot_words_length[i] = len(word) - 1), then tile it across the batch
bias_embed = bias_embed.transpose(1, 0, 2)
_ind = np.arange(0, len(hot_words)).tolist()
bias_embed = bias_embed[_ind, hot_words_length]
bias_embed = np.expand_dims(bias_embed, axis=0)
bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
if feats_len > 0:
am_scores = self.ort_infer(
feats=feats, feats_length=feats_len, bias_embed=bias_embed
)
else:
am_scores = []
results = []
for am_score in am_scores:
if beam_search:
pred_res = self.decoder_with_beam_search(
am_score, beam_size=beam_size, lm_weight=lm_weight
)
else:
pred_res = self.decoder_with_greedy_search(am_score)
results.append(pred_res)
return results if len(results) != 0 else [[""]]
def proc_hot_words(self, hot_words: str):
        # tolerate hot_words=None / an empty string (the default in `infer`)
        hot_words = (hot_words or "").strip().split(" ")
hot_words_length = [len(i) - 1 for i in hot_words]
hot_words_length.append(0)
hot_words_length = np.array(hot_words_length).astype("int32")
def word_map(word):
return np.array([self.converter.tokens2ids(i)[0] for i in word])
hot_words_int = [word_map(i) for i in hot_words]
hot_words_int.append(np.array([1]))
hot_words = self._pad_list(hot_words_int, max_len=10)
return hot_words, hot_words_length
def _pad_list(self, xs, max_len=None):
n_batch = len(xs)
if max_len is None:
            max_len = max(x.shape[0] for x in xs)
pad = np.zeros((n_batch, max_len), dtype=np.int32)
for i in range(n_batch):
pad[i, : xs[i].shape[0]] = xs[i]
return pad
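

if __name__ == "__main__":
    # Offline usage sketch (not part of the original module). The directory and wav
    # paths are placeholders: the directory is expected to contain config.pkl, am.mvn,
    # the (quantized) model onnx files and model_eb.onnx, as loaded in __init__ above.
    offline_model_dir = "/path/to/paraformer-offline-model"  # placeholder
    wav_path = "/path/to/16k_mono.wav"  # placeholder

    model = ParaformerOfflineModel(offline_model_dir, use_lm=False)
    # hot_words is a space-separated string; each word is mapped to token ids and fed
    # to the contextual (hot-word embedding) model to build bias_embed
    results = model.infer(wav_path, hot_words="阿里巴巴 达摩院", beam_search=False)
    print(results)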