sanjitaa's picture
remove comments
4ab8d0c
raw
history blame contribute delete
No virus
3.83 kB
from __future__ import annotations
import os
import gradio as gr
import numpy as np
import torch
import torchaudio
from seamless_communication.models.inference import Translator
from lang_list import LANGUAGE_NAME_TO_CODE, S2TT_TARGET_LANGUAGE_NAMES
# Markdown rendered at the top of the demo page.
DESCRIPTION = """# Speech to Text Translation
[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
translation and more, without relying on multiple separate models. Here the task is to do the speech to text translation.
"""

# True only when the CACHE_EXAMPLES environment variable is exactly "1".
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1"

# Sample rate (Hz) that input audio is resampled to before inference.
# SeamlessM4T expects 16 kHz speech; feeding it 44.1 kHz audio degrades
# the translation because the model does not resample on its own.
AUDIO_SAMPLE_RATE = 16_000

# NOTE: no cap on the input audio length is enforced; a former
# MAX_INPUT_AUDIO_LENGTH limit (1800 seconds) is disabled.

# Language name intended as the default target-language selection.
DEFAULT_TARGET_LANGUAGE = "French"
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# One shared translator instance, built once at import time and reused
# by every request.
translator = Translator(
    model_name_or_card="seamlessM4T_medium",
    vocoder_name_or_card="vocoder_36langs",
    device=device,
)

# Task identifier handed to the translator: the first whitespace-separated
# token of the human-readable label, i.e. "S2TT".
task_name = "S2TT (Speech to Text translation)".split(maxsplit=1)[0]
def predict(
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    target_language: str,
) -> str:
    """Translate a speech recording into text in the target language.

    Args:
        audio_source: ``"microphone"`` to use ``input_audio_mic``; any other
            value selects ``input_audio_file``.
        input_audio_mic: Path of the microphone recording, or ``None``.
        input_audio_file: Path of the uploaded audio file, or ``None``.
        target_language: Language name; must be a key of
            ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        The translated text produced by the translator.

    Raises:
        gr.Error: If no audio was provided for the selected source, or if
            ``target_language`` is missing or unknown.
    """
    if audio_source == "microphone":
        input_data = input_audio_mic
    else:
        input_data = input_audio_file
    # The audio widgets deliver None until the user records/uploads a clip.
    if not input_data:
        raise gr.Error("Please provide an input audio clip before translating.")

    try:
        target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    except KeyError:
        # Also covers an empty dropdown (target_language is None).
        raise gr.Error("Please select a valid target language.") from None

    # Bring the clip to the rate the model is fed at. The file is rewritten
    # in place because the translator is given the path, not the tensor.
    waveform, orig_sr = torchaudio.load(input_data)
    target_sr = int(AUDIO_SAMPLE_RATE)
    if orig_sr != target_sr:
        resampled = torchaudio.functional.resample(
            waveform, orig_freq=orig_sr, new_freq=target_sr
        )
        torchaudio.save(input_data, resampled, sample_rate=target_sr)

    # Only the text is needed; the audio outputs of the call are ignored.
    text_out, _wav, _sr = translator.predict(
        input=input_data,
        task_str=task_name,
        tgt_lang=target_language_code,
        ngram_filtering=True,
    )
    # Coerce to a plain str so the declared return type holds even if the
    # translator hands back a string-like object.
    return str(text_out)
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    """Toggle which audio input widget is shown for the chosen source.

    Returns a pair of updates, ordered (microphone widget, file widget).
    Exactly one of the two is made visible, and both have their current
    value cleared.
    """
    use_mic = audio_source == "microphone"
    mic_update = gr.update(visible=use_mic, value=None)
    file_update = gr.update(visible=not use_mic, value=None)
    return mic_update, file_update
# Page layout and event wiring for the demo.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2TT_TARGET_LANGUAGE_NAMES,
                # Pre-select a language so that clicking "Translate" without
                # touching the dropdown does not pass None to predict().
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            # Both audio widgets exist from the start; only the one that
            # matches the selected source is visible.
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    # Swap the visible audio widget whenever the source radio changes.
    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    # Run the translation; also exposed through the API under the name "run".
    btn.click(
        fn=predict,
        inputs=[
            audio_source,
            input_audio_mic,
            input_audio_file,
            target_language,
        ],
        outputs=[output_text],
        api_name="run",
    )

demo.launch()