# -*- coding:utf-8 -*-
# @FileName :asr_all_in_one.py
# @Time     :2023/8/14 09:31
# @Author   :lovemefan
# @Email    :[email protected]
import time

import numpy as np

from paraformer.runtime.python.cttPunctuator import CttPunctuator
from paraformer.runtime.python.fsmnVadInfer import FSMNVadOnline
from paraformer.runtime.python.paraformerInfer import (
    ParaformerOffline,
    ParaformerOnline,
)
from paraformer.runtime.python.svInfer import SpeakerVerificationInfer
from paraformer.runtime.python.utils.logger import logger
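
# The four supported modes, as wired up in AsrAllInOne.__init__ below:
#   offline            - single-utterance recognition with the offline model
#   online             - streaming recognition with the online model only
#   2pass              - streaming partials, with offline + VAD + punctuation finals
#   file_transcription - offline ASR + VAD + punctuation over a whole recording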

mode_available = ["offline", "file_transcription", "online", "2pass"]


class AsrAllInOne:
    def __init__(
        self,
        mode: str,
        *,
        speaker_verification=False,
        time_stamp=False,
        chunk_interval=10,
        sv_model_name="cam++",
        sv_threshold=0.6,
        sv_max_start_silence_time=3000,
        vad_speech_max_length=20000,
        vad_speech_noise_thresh_low=-0.1,
        vad_speech_noise_thresh_high=0.3,
        vad_speech_noise_thresh=0.6,
        hot_words="",
    ):
""" | |
Args: | |
mode: | |
speaker_verification: | |
time_stamp: | |
""" | |
assert ( | |
mode in mode_available | |
), f"{mode} is not support, only {mode_available} is available" | |
        self.mode = mode
        self.speaker_verification = speaker_verification
        self.time_stamp = time_stamp
        self.start_frame = 0
        self.end_frame = 0
        self.vad_pre_idx = 0
        self.chunk_interval = chunk_interval
        self.speech_start = False
        self.frames = []
        self.offset = 0
        self.hot_words = hot_words
        if mode == "offline":
            self.asr_offline = ParaformerOffline()
            # one_sentence_asr needs the punctuator, so load it here as well
            self.punc = CttPunctuator(online=False)
        elif mode == "online":
            self.asr_online = ParaformerOnline()
        elif mode == "2pass":
            self.asr_offline = ParaformerOffline()
            self.asr_online = ParaformerOnline()
            self.vad = FSMNVadOnline()
            self.vad.vad.vad_opts.max_single_segment_time = vad_speech_max_length
            self.vad.vad.vad_opts.max_start_silence_time = sv_max_start_silence_time
            self.vad.vad.vad_opts.speech_noise_thresh_low = vad_speech_noise_thresh_low
            self.vad.vad.vad_opts.speech_noise_thresh_high = (
                vad_speech_noise_thresh_high
            )
            self.vad.vad.vad_opts.speech_noise_thresh = vad_speech_noise_thresh
            self.punc = CttPunctuator(online=True)
            self.text_cache = ""
            # chunk buffers used by two_pass_for_dialogue
            self.frames_asr_online = []
            self.frames_asr_offline = []
        elif mode == "file_transcription":
            self.asr_offline = ParaformerOffline()
            self.vad = FSMNVadOnline()
            self.vad.vad.vad_opts.speech_noise_thresh_low = vad_speech_noise_thresh_low
            self.vad.vad.vad_opts.speech_noise_thresh_high = (
                vad_speech_noise_thresh_high
            )
            self.vad.vad.vad_opts.speech_noise_thresh = vad_speech_noise_thresh
            self.vad.vad.vad_opts.max_single_segment_time = vad_speech_max_length
            self.vad.vad.vad_opts.max_start_silence_time = sv_max_start_silence_time
            self.punc = CttPunctuator(online=False)
        else:
            raise ValueError(f"Do not support mode: {mode}")

        if speaker_verification:
            self.sv = SpeakerVerificationInfer(
                model_name=sv_model_name, threshold=sv_threshold
            )
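
    # Example (a sketch): constructing a 2pass instance with speaker
    # verification enabled. The hot-word string below is illustrative; check
    # ParaformerOffline for the exact format it expects.
    #
    #   asr = AsrAllInOne(
    #       "2pass",
    #       speaker_verification=True,
    #       sv_model_name="cam++",
    #       sv_threshold=0.6,
    #       hot_words="paraformer funasr",
    #   )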

    def reset_asr(self):
        self.frames = []
        self.start_frame = 0
        self.end_frame = 0
        self.vad_pre_idx = 0
        self.vad.vad.all_reset_detection()

    def online(self, chunk: np.ndarray, is_final: bool = False):
        return self.asr_online.infer_online(chunk, is_final)

    def offline(self, audio_data: np.ndarray):
        return self.asr_offline.infer_offline(audio_data, hot_words=self.hot_words)

    def extract_endpoint_from_vad_result(self, segments_result):
        """Normalize VAD segments into [start, end] pairs.

        Times are in milliseconds; -1 means that endpoint has not been
        detected in the stream yet.
        """
        return [[start, end] for start, end in segments_result]

    def one_sentence_asr(self, audio: np.ndarray):
        """asr offline + punc"""
        result = self.asr_offline.infer_offline(audio, hot_words=self.hot_words)
        result = self.punc.punctuate(result)[0]
        return result
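
    # Example (a sketch): one-shot transcription of a pre-segmented utterance.
    # soundfile and the file name are assumptions for illustration; any 16 kHz
    # mono float32 loader works.
    #
    #   import soundfile as sf
    #   audio, sr = sf.read("utterance.wav", dtype="float32")
    #   asr = AsrAllInOne("offline")
    #   print(asr.one_sentence_asr(audio))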

    def file_transcript(self, audio: np.ndarray, step=9600):
        """
        asr offline + vad + punc
        Args:
            audio: 16 kHz mono waveform of the whole recording
            step: chunk size in samples fed to the VAD (9600 = 600 ms at 16 kHz)
        Returns:
            generator of per-chunk dicts; a completed segment carries "text",
            "time_stamp" (ms) and, when enabled, "speaker_id"
        """
        speech_length = len(audio)
        for sample_offset in range(0, speech_length, step):
            # truncate the last chunk to the end of the audio
            if sample_offset + step >= speech_length:
                step = speech_length - sample_offset
                is_final = True
            else:
                is_final = False
            chunk = audio[sample_offset : sample_offset + step]
            segments_result = self.vad.segments_online(chunk, is_final=is_final)
            result = {}
            for start, end in segments_result:
                if start != -1:
                    start_ms = start
                # paraformer offline inference on the completed segment
                if end != -1:
                    end_ms = end
                    # VAD times are in ms; 16 samples per ms at 16 kHz
                    data = np.array(audio[start_ms * 16 : end_ms * 16])
                    time_start = time.time()
                    asr_offline_final = self.asr_offline.infer_offline(
                        data, hot_words=self.hot_words
                    )
                    logger.debug(
                        f"asr offline inference use {time.time() - time_start} s"
                    )
                    if self.speaker_verification:
                        time_start = time.time()
                        speaker_id = self.sv.recognize(data)
                        result["speaker_id"] = speaker_id
                        logger.debug(
                            f"speaker verification inference use "
                            f"{time.time() - time_start} s"
                        )
                    self.speech_start = False
                    time_start = time.time()
                    _final = self.punc.punctuate(asr_offline_final)[0]
                    logger.debug(
                        f"punc inference use {time.time() - time_start} s"
                    )
                    result["text"] = _final
                    result["time_stamp"] = {"start": start_ms, "end": end_ms}
            if is_final:
                self.reset_asr()
            yield result
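
    # Example (a sketch): whole-file transcription, printing each VAD segment
    # as it completes (audio loaded as in the one_sentence_asr example above).
    #
    #   asr = AsrAllInOne("file_transcription")
    #   for segment in asr.file_transcript(audio):
    #       if "text" in segment:
    #           print(segment["time_stamp"], segment["text"])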

    def two_pass_asr(self, chunk: np.ndarray, is_final: bool = False, hot_words=None):
        self.frames.extend(chunk.tolist())
        self.vad_pre_idx += len(chunk)

        # paraformer online inference for the streaming partial
        if self.end_frame != -1:
            time_start = time.time()
            partial = self.asr_online.infer_online(chunk, is_final)
            self.text_cache += partial
            logger.debug(f"asr online inference use {time.time() - time_start} s")

        # paraformer vad inference
        time_start = time.time()
        segments_result = self.vad.segments_online(chunk, is_final=is_final)
        logger.debug(f"vad online inference use {time.time() - time_start} s")
        segments = self.extract_endpoint_from_vad_result(segments_result)

        final = None
        time_stamp_start = 0
        time_stamp_end = 0
        for start, end in segments:
            if start != -1:
                self.speech_start = True
                self.start_frame = start * 16
                # drop buffered samples that precede the detected speech start
                start = self.start_frame + len(self.frames) - self.vad_pre_idx
                self.frames = self.frames[start:]
            # paraformer offline inference on the completed segment
            if end != -1:
                self.end_frame = end * 16
                time_stamp_start = self.start_frame / 16
                time_stamp_end = end
                time_start = time.time()
                end = self.end_frame + len(self.frames) - self.vad_pre_idx
                data = np.array(self.frames[:end])
                self.frames = self.frames[end:]
                asr_offline_final = self.asr_offline.infer_offline(
                    data, hot_words=(hot_words or self.hot_words)
                )
                logger.debug(f"asr offline inference use {time.time() - time_start} s")
                if self.speaker_verification:
                    time_start = time.time()
                    speaker_id = self.sv.recognize(data)
                    logger.debug(
                        f"speaker verification inference use "
                        f"{time.time() - time_start} s"
                    )
                self.speech_start = False
                time_start = time.time()
                _final = self.punc.punctuate(asr_offline_final)[0]
                final = _final
                logger.debug(f"punc online inference use {time.time() - time_start} s")

        result = {
            "partial": self.text_cache,
        }
        if final is not None:
            result["final"] = final
            result["partial"] = ""
            result["time_stamp"] = {"start": time_stamp_start, "end": time_stamp_end}
            if self.speaker_verification:
                result["speaker_id"] = speaker_id
            self.text_cache = ""
        if is_final:
            self.reset_asr()
        return result
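
    # Example (a sketch): feeding a live stream in 600 ms chunks (9600 samples
    # at 16 kHz, the same step file_transcript uses); the chunk size the online
    # model prefers is an assumption worth verifying.
    #
    #   asr = AsrAllInOne("2pass")
    #   chunk_size = 9600
    #   for i in range(0, len(audio), chunk_size):
    #       chunk = audio[i : i + chunk_size]
    #       result = asr.two_pass_asr(chunk, is_final=i + chunk_size >= len(audio))
    #       if result.get("final"):
    #           print(result["time_stamp"], result["final"])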

    def two_pass_for_dialogue(self, chunk, is_final=False):
        """
        two-pass asr for dialogue: a streaming partial per chunk plus an
        offline + punctuation pass on every VAD-completed segment.
        Note: this method buffers whole chunks in self.frames, while
        two_pass_asr buffers flat samples; do not mix the two on one instance.
        :return: dict with "partial" text and, when a segment completes, "final"
        """
        self.frames.append(chunk)
        self.vad_pre_idx += len(chunk) // 16

        # paraformer online inference
        self.frames_asr_online.append(chunk)
        if self.speaker_verification and len(self.frames) > 3:
            time_start = time.time()
            speaker_id = self.sv.recognize(np.concatenate(self.frames[-3:]))
            logger.debug(
                f"speaker verification inference use {time.time() - time_start} s"
            )
        if len(self.frames_asr_online) > 0 or self.end_frame != -1:
            time_start = time.time()
            data = np.concatenate(self.frames_asr_online)
            partial = self.asr_online.infer_online(data, is_final)
            self.text_cache += partial
            logger.debug(f"asr online inference use {time.time() - time_start} s")
            # empty asr online buffer
            self.frames_asr_online = []
        if self.speech_start:
            self.frames_asr_offline.append(chunk)

        # paraformer vad inference
        time_start = time.time()
        segments_result = self.vad.segments_online(chunk, is_final=is_final)
        logger.debug(f"vad online inference use {time.time() - time_start} s")
        segments = self.extract_endpoint_from_vad_result(segments_result)

        final = None
        for start, end in segments:
            self.start_frame = start
            self.end_frame = end
            if self.start_frame != -1:
                self.speech_start = True
                # how many chunks back (fractionally) the speech start lies;
                # VAD times are in ms, each chunk covers len(chunk) // 16 ms
                beg_bias = (self.vad_pre_idx - self.start_frame) / (len(chunk) // 16)
                end_idx = (beg_bias % 1) * len(self.frames[-int(beg_bias) - 1])
                frames_pre = [self.frames[-int(beg_bias) - 1][-int(end_idx):]]
                if int(beg_bias) != 0:
                    frames_pre.extend(self.frames[-int(beg_bias):])
                    frames_pre = [np.concatenate(frames_pre)]
                self.frames_asr_offline = []
                self.frames_asr_offline.extend(frames_pre)
            # paraformer offline inference
            if self.end_frame != -1 and len(self.frames_asr_offline) > 0:
                time_start = time.time()
                if len(self.frames_asr_offline) > 1:
                    data = np.concatenate(self.frames_asr_offline[:-1])
                else:
                    data = np.concatenate(self.frames_asr_offline)
                asr_offline_final = self.asr_offline.infer_offline(data)
                logger.debug(f"asr offline inference use {time.time() - time_start} s")
                if len(self.frames_asr_offline) > 1:
                    self.frames_asr_offline = [self.frames_asr_offline[-1]]
                else:
                    self.frames_asr_offline = []
                self.speech_start = False
                time_start = time.time()
                _final = self.punc.punctuate(asr_offline_final)[0]
                if final is not None:
                    final += _final
                else:
                    final = _final
                logger.debug(f"punc online inference use {time.time() - time_start} s")

        result = {
            "partial": self.text_cache,
        }
        if final is not None:
            result["final"] = final
            result["partial"] = ""
            # if self.speaker_verification:
            #     result["speaker_id"] = speaker_id
            self.text_cache = ""
        if is_final:
            self.reset_asr()
        return result
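

if __name__ == "__main__":
    # Minimal smoke test (a sketch): offline transcription of one file.
    # soundfile and "example.wav" are assumptions; substitute any 16 kHz
    # mono float32 source.
    import soundfile as sf

    audio, sample_rate = sf.read("example.wav", dtype="float32")
    assert sample_rate == 16000, "the models in this file assume 16 kHz audio"
    asr = AsrAllInOne("offline")
    print(asr.one_sentence_asr(audio))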