lucio's picture
omg
d27ee9b
raw
history blame
4.62 kB
from io import BytesIO
from typing import Tuple
import wave
import gradio as gr
import numpy as np
from pydub.audio_segment import AudioSegment
import requests
from os.path import exists
from stt import Model
import torch
from transformers import AutoModelForCTC, Wav2Vec2Processor
import torchaudio
from speechbrain.pretrained import EncoderClassifier
# Spoken-language identification model, created once at import time.
# The `lang-id-commonlanguage_ecapa` checkpoint is requested from `source`
# and its files are kept under `savedir`.
lang_classifier = EncoderClassifier.from_hparams(
    savedir="pretrained_models/lang-id-commonlanguage_ecapa",
    source="speechbrain/lang-id-commonlanguage_ecapa",
)
def load_hf_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
    """Load a wav2vec2 CTC checkpoint together with its processor.

    Args:
        model_path: Model identifier (or path) handed unchanged to both
            ``from_pretrained`` calls.

    Returns:
        A ``(processor, model)`` tuple.
    """
    # Tuple elements are evaluated left to right: processor first, then model.
    return (
        Wav2Vec2Processor.from_pretrained(model_path),
        AutoModelForCTC.from_pretrained(model_path),
    )
# Registry of speech-to-text models, keyed by the language name used in the UI.
# Every value is a (location, local name) pair:
#   * mixteco / chatino / totonaco: download URL of a Coqui .tflite model and
#     the file name it is saved under (see load_coqui_models).
#   * español / inglés: a model identifier passed to load_hf_model; the second
#     item is not read anywhere in the visible code.
model_info = {
    "mixteco": (
        "https://coqui.gateway.scarf.sh/mixtec/jemeyer/v1.0.0/model.tflite",
        "mixtec.tflite",
    ),
    "chatino": (
        "https://coqui.gateway.scarf.sh/chatino/bozden/v1.0.0/model.tflite",
        "chatino.tflite",
    ),
    "totonaco": (
        "https://coqui.gateway.scarf.sh/totonac/bozden/v1.0.0/model.tflite",
        "totonac.tflite",
    ),
    "español": (
        "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "spanish_xlsr",
    ),
    "inglés": (
        "facebook/wav2vec2-large-robust-ft-swbd-300h",
        "english_xlsr",
    ),
}
# Language name -> loaded model. The two Hugging Face entries are loaded
# eagerly here (English first, then Spanish) as (processor, model) tuples.
STT_MODELS = {
    hf_language: load_hf_model(model_info[hf_language][0])
    for hf_language in ("inglés", "español")
}
def client(audio_data: np.ndarray, sample_rate: int, default_lang: str):
    """Identify the spoken language of a recording and transcribe it.

    The audio is first converted to a 16 kHz mono WAV by ``_convert_audio``.
    If the language classifier labels it as Spanish, the Spanish wav2vec2
    model is used; otherwise the model registered for ``default_lang`` in
    ``STT_MODELS`` transcribes it.

    Args:
        audio_data: Raw samples of the recording.
        sample_rate: Sampling rate of ``audio_data`` in Hz.
        default_lang: Key into ``STT_MODELS`` used when the recording is not
            detected as Spanish. That entry must expose an ``stt`` method.

    Returns:
        A string of the form ``"<detected language>: <transcript>"``.
    """
    output_audio = _convert_audio(audio_data, sample_rate)
    waveform, _ = torchaudio.load(output_audio)
    _, _, _, text_lab = lang_classifier.classify_batch(waveform)

    # classify_batch yields one label per batch item, so the label arrives as
    # a sequence. Comparing that sequence to 'Spanish' is always False, which
    # made the Spanish branch unreachable.
    detected = text_lab[0] if isinstance(text_lab, (list, tuple)) else text_lab
    print(default_lang, detected)

    if detected == 'Spanish':
        processor, model = STT_MODELS['español']
        # The processor needs a 1-D array, the true sampling rate (the WAV
        # was resampled to 16 kHz) and must return tensors for the model.
        inputs = processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=16000,
            return_tensors="pt",
        )
        with torch.no_grad():  # inference only; no autograd graph needed
            logits = model(
                inputs.input_values,
                attention_mask=inputs.get("attention_mask"),
            ).logits
        # logits are batched (1, time, vocab); decode the single sequence.
        predicted_ids = torch.argmax(logits, dim=-1)[0]
        result = processor.decode(predicted_ids.cpu().tolist())
    else:
        # The other models consume the PCM frames as an int16 array.
        output_audio.seek(0)
        with wave.open(output_audio, 'rb') as fin:
            audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
        ds = STT_MODELS[default_lang]
        result = ds.stt(audio)
    return f"{detected}: {result}"
def load_coqui_models(language):
    """Return the Coqui STT model for ``language``, downloading it if needed.

    The download URL and local file name come from ``model_info``. A file
    that already exists locally is reused without contacting the network.

    Args:
        language: Key into ``model_info``.

    Returns:
        A ``Model`` built from the local model file.

    Raises:
        ValueError: If ``language`` has no entry in ``model_info``.
        requests.HTTPError: If the download returns an error status.
    """
    model_path, file_name = model_info.get(language, ("", ""))
    if not file_name:
        # Fail with a clear message instead of an obscure error from
        # requests.get("") further down.
        raise ValueError(f"No model is registered for language {language!r}")

    if exists(file_name):
        print(f"Found {file_name}. Skipping download...")
    else:
        print(f"Downloading {model_path}")
        r = requests.get(model_path, allow_redirects=True, timeout=60)
        # Check the status before touching the disk: otherwise an HTTP error
        # page would be saved and later mistaken for a cached model.
        r.raise_for_status()
        with open(file_name, 'wb') as file:
            file.write(r.content)
    return Model(file_name)
# Register the three Coqui models next to the Hugging Face ones. Each call
# downloads its model file on first use, then loads it from disk.
for lang in ("mixteco", "chatino", "totonaco"):
    STT_MODELS[lang] = load_coqui_models(lang)
def stt(default_lang: str, audio: Tuple[int, np.ndarray]) -> str:
    """Gradio callback: transcribe one recording.

    Args:
        default_lang: Language selected in the UI; forwarded to ``client`` as
            the fallback model key.
        audio: ``(sample_rate, samples)`` pair as delivered by the numpy
            audio input.

    Returns:
        The text produced by ``client``.
    """
    # Unpack into fresh names rather than rebinding the `audio` parameter.
    sample_rate, samples = audio
    return client(samples, sample_rate, default_lang)
def _convert_audio(audio_data: np.ndarray, sample_rate: int):
    """Convert raw samples into an in-memory 16 kHz mono 16-bit WAV.

    Args:
        audio_data: Sample array, either 1-D (mono) or 2-D with one column
            per channel. Integer PCM is passed through at its own width.
            Floating-point input is assumed to lie in [-1, 1] (TODO confirm
            against the caller) and is scaled to int16.
        sample_rate: Sampling rate of ``audio_data`` in Hz.

    Returns:
        A ``BytesIO`` positioned at offset 0 that holds the WAV file.
    """
    samples = np.asarray(audio_data)
    if np.issubdtype(samples.dtype, np.floating):
        # Raw float bytes would otherwise be reinterpreted as integer PCM.
        samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

    # Tell pydub the real layout instead of always claiming mono 16-bit:
    # a (frames, channels) array serialises in C order as interleaved PCM.
    channel_count = samples.shape[1] if samples.ndim > 1 else 1
    source_audio = BytesIO(samples.tobytes())

    output_audio = BytesIO()
    wav_file = AudioSegment.from_raw(
        source_audio,
        channels=channel_count,
        sample_width=samples.dtype.itemsize,
        frame_rate=sample_rate,
    )
    # Resample and downmix to the format the recognisers are fed.
    wav_file.set_frame_rate(16000).set_channels(1).export(output_audio, "wav", codec="pcm_s16le")
    output_audio.seek(0)
    return output_audio
# Web UI: a radio button selects the fallback language, an audio widget
# supplies the recording, and the text returned by `stt` is shown in a textbox.
# The description is bilingual (Spanish, then English). Both credit links to
# the Ukrainian demo previously used the malformed host "huggingface.co."
# and the Spanish text misspelled "ucraniano".
iface = gr.Interface(
    fn=stt,
    inputs=[
        gr.inputs.Radio(choices=("chatino", "mixteco", "totonaco"), default="mixteco", label="Lengua principal"),
        gr.inputs.Audio(type="numpy", label="Audio", optional=False),
    ],
    outputs=gr.outputs.Textbox(label="Output"),
    title="Coqui STT Yoloxochitl Mixtec",
    theme="huggingface",
    description="Prueba de dictado a texto para el mixteco de Yoloxochitl,"
                " usando [el modelo entrenado por Josh Meyer](https://coqui.ai/mixtec/jemeyer/v1.0.0/)"
                " con [los datos recopilados por Rey Castillo y sus colaboradores](https://www.openslr.org/89)."
                " Esta prueba es basada en la de [ucraniano](https://huggingface.co/spaces/robinhad/ukrainian-stt)."
                " \n\n"
                "Speech-to-text demo for Yoloxochitl Mixtec,"
                " using [the model trained by Josh Meyer](https://coqui.ai/mixtec/jemeyer/v1.0.0/)"
                " on [the corpus compiled by Rey Castillo and collaborators](https://www.openslr.org/89)."
                " This demo is based on the [Ukrainian STT demo](https://huggingface.co/spaces/robinhad/ukrainian-stt).",
)
# Start serving the interface as soon as the script runs.
iface.launch()