# drown0315
# feat: 增加双语字幕 (feat: add bilingual subtitles)
# 5e3d623
# raw
# history blame
# 5.17 kB
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
from transformers import pipeline, MarianMTModel, MarianTokenizer
import numpy as np
import sherpa_onnx
from model import sample_rate
@dataclass
class Segment:
    """One recognized speech segment with SRT-style timing and bilingual text."""

    start: float  # segment start time, in seconds
    duration: float  # segment length, in seconds
    text: str = ""  # recognized (English) text
    cn_text: str = ""  # Chinese translation of ``text``

    @property
    def end(self) -> float:
        """Segment end time in seconds."""
        return self.start + self.duration

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

        The previous implementation used ``f"0{timedelta(seconds=...)}"[:-3]``,
        which silently corrupts the timestamp whenever the value has no
        fractional part: ``str(timedelta(seconds=5))`` is ``"0:00:05"`` (no
        microseconds printed), so stripping three characters yields
        ``"00:00"`` instead of ``"00:00:05,000"``.
        """
        total_ms = round(seconds * 1000)
        ms = total_ms % 1000
        total_s = total_ms // 1000
        h, rem = divmod(total_s, 3600)
        m, s = divmod(rem, 60)
        # SRT uses a comma as the decimal separator.
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    def __str__(self) -> str:
        """Render the segment as one SRT cue body: timing line, English line,
        Chinese line (no trailing newline, no cue number)."""
        timing = (
            f"{self._format_timestamp(self.start)} --> "
            f"{self._format_timestamp(self.end)}"
        )
        return f"{timing}\n{self.text}\n{self.cn_text}"
def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> "tuple[str, str]":
    """Transcribe *filename* into bilingual SRT subtitles plus a full transcript.

    Pipeline: ffmpeg decodes the input to raw 16-bit mono PCM at
    ``sample_rate``; the VAD splits the stream into speech segments; the
    offline recognizer decodes each segment; each non-empty English result
    is translated to Chinese via the module-level ``_llm_translator``.

    Args:
      recognizer: Offline recognizer; one stream is created per VAD segment.
      punct: Optional punctuation model; when given, punctuation is added to
        each segment's English text and to the joined transcript.
      filename: Path to any media file ffmpeg can read.

    Returns:
      ``(srt, all_text)``: the numbered SRT document (cues joined by blank
      lines) and the full transcript as a single string.
      NOTE(review): the original annotation said ``-> str``, but the return
      statement yields two values; corrected to a tuple annotation.
    """
    # Decode to signed 16-bit little-endian PCM, mono, at the model's rate,
    # streamed to stdout so no temporary file is needed.
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]
    process = subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    )
    frames_per_read = int(sample_rate * 100)  # 100 second
    window_size = 512  # samples per VAD window
    buffer = []  # pending samples not yet consumed by the VAD
    segment_list = []  # Segment objects that made it into the SRT output
    logging.info("Started!")
    all_text = []  # transcript pieces, joined at the end
    is_last = False  # set after EOF so one silence chunk flushes the VAD
    while True:
        # *2 because int16_t has two bytes
        data = process.stdout.read(frames_per_read * 2)
        if not data:
            if is_last:
                break
            # EOF: feed one second of silence so the VAD emits any
            # still-open segment, then loop once more and exit.
            is_last = True
            data = np.zeros(sample_rate, dtype=np.int16)
        samples = np.frombuffer(data, dtype=np.int16)
        # Normalize int16 PCM to float32 in [-1, 1).
        samples = samples.astype(np.float32) / 32768
        buffer = np.concatenate([buffer, samples])
        # Feed the VAD in fixed-size windows; keep the remainder buffered.
        while len(buffer) > window_size:
            vad.accept_waveform(buffer[:window_size])
            buffer = buffer[window_size:]
        streams = []
        segments = []
        # Drain every speech segment the VAD has finalized so far; one
        # recognizer stream per segment, decoded in a batch below.
        while not vad.empty():
            segment = Segment(
                start=vad.front.start / sample_rate,
                duration=len(vad.front.samples) / sample_rate,
            )
            segments.append(segment)
            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, vad.front.samples)
            streams.append(stream)
            vad.pop()
        for s in streams:
            recognizer.decode_stream(s)
        for seg, stream in zip(segments, streams):
            en_text = stream.result.text.strip()
            seg.text = en_text
            if len(seg.text) == 0:
                logging.info("Skip empty segment")
                continue
            seg.cn_text = _llm_translator.translate(en_text)
            if len(all_text) == 0:
                all_text.append(seg.text)
            # A first character that encodes to one UTF-8 byte is ASCII
            # (Latin script), so insert a space between two such chunks.
            # NOTE(review): this inspects the FIRST char of the previous
            # chunk ([-1][0]); its last char ([-1][-1]) may have been the
            # intent — confirm against upstream.
            elif len(all_text[-1][0].encode()) == 1 and len(seg.text[0].encode()) == 1:
                all_text.append(" ")
                all_text.append(seg.text)
            else:
                all_text.append(seg.text)
            # Punctuate per-segment text only after the raw English text was
            # translated and appended to the transcript pieces.
            if punct is not None:
                seg.text = punct.add_punctuation(seg.text)
            segment_list.append(seg)
    all_text = "".join(all_text)
    if punct is not None:
        all_text = punct.add_punctuation(all_text)
    # SRT cues are 1-based and separated by a blank line.
    return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text
def translate_en_to_cn(src_text: str) -> str:
    """Translate English text to Chinese with the Helsinki-NLP Marian model.

    Delegates to the module-level ``_llm_translator`` singleton so that the
    expensive tokenizer/model load happens once at import time rather than
    on every call (the old body re-downloaded/loaded the model per call).
    It also joins the decoded pieces into one string, so the return value
    now actually matches the declared ``-> str`` (the old body returned a
    ``list`` of strings).

    Args:
      src_text: English source text.

    Returns:
      The Chinese translation as a single string.
    """
    return _llm_translator.translate(src_text)
class LLMTranslator:
    """English-to-Chinese translator backed by the Helsinki-NLP Marian model.

    Loads the tokenizer and model once in ``__init__`` and reuses them for
    every ``translate`` call.
    """

    _tokenizer: MarianTokenizer  # shared tokenizer for the en->zh model
    _model: MarianMTModel  # shared seq2seq translation model

    def __init__(self):
        """Load the en->zh Marian tokenizer and model (downloads on first use)."""
        model_name = "Helsinki-NLP/opus-mt-en-zh"
        self._tokenizer = MarianTokenizer.from_pretrained(model_name)
        self._model = MarianMTModel.from_pretrained(model_name)

    def translate(self, src_text: str) -> str:
        """Return *src_text* translated to Chinese as one concatenated string."""
        batch = self._tokenizer(src_text, return_tensors="pt", padding=True)
        generated_ids = self._model.generate(**batch)
        pieces = [
            self._tokenizer.decode(ids, skip_special_tokens=True)
            for ids in generated_ids
        ]
        return "".join(str(piece) for piece in pieces)
# Module-level singleton used by decode(); constructing it loads the Marian
# model/tokenizer at import time (a download on first run), so importing this
# module is deliberately heavyweight.
_llm_translator = LLMTranslator()