from __future__ import annotations
import os
import gradio as gr
import torch
import torchaudio
from seamless_communication.models.inference import Translator
from lang_list import LANGUAGE_NAME_TO_CODE, S2TT_TARGET_LANGUAGE_NAMES
DESCRIPTION = """# Speech-to-Text Translation
[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
This unified model covers multiple tasks, such as Speech-to-Speech Translation (S2ST), Speech-to-Text Translation
(S2TT), and Text-to-Speech Translation (T2ST), without relying on separate models for each. This demo performs
Speech-to-Text Translation (S2TT).
"""
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1"
AUDIO_SAMPLE_RATE = 16000  # SeamlessM4T operates on 16 kHz audio
# MAX_INPUT_AUDIO_LENGTH = 1800  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
translator = Translator(
    model_name_or_card="seamlessM4T_medium",
    vocoder_name_or_card="vocoder_36langs",
    device=device,
)
task_name = "S2TT"  # Speech-to-Text Translation, the task token Translator.predict expects
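
# The same Translator instance can serve the other tasks mentioned in the
# description by changing task_str. A sketch, assuming the
# seamless_communication 1.x API (file paths here are hypothetical):
#   translated_text, wav, sr = translator.predict(
#       input="speech.wav", task_str="S2ST", tgt_lang="fra"
#   )
#   translated_text, wav, sr = translator.predict(
#       input="Hello, world!", task_str="T2ST", tgt_lang="fra", src_lang="eng"
#   )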

def predict(
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    target_language: str,
) -> str:
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    # Pick whichever input widget is active.
    if audio_source == "microphone":
        input_data = input_audio_mic
    else:
        input_data = input_audio_file
    if input_data is None:
        raise gr.Error("Please record or upload an audio clip first.")
    # Resample the input to the model's expected rate, overwriting the file in place.
    arr, org_sr = torchaudio.load(input_data)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    torchaudio.save(input_data, new_arr, sample_rate=AUDIO_SAMPLE_RATE)
    # For S2TT the returned waveform and sample rate are unused.
    text_out, _, _ = translator.predict(
        input=input_data,
        task_str=task_name,
        tgt_lang=target_language_code,
        ngram_filtering=True,
    )
    return str(text_out)
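
# Example call (the file path is hypothetical): translating an uploaded clip
# into French would look like predict("file", None, "clip.wav", "French").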

def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    mic = audio_source == "microphone"
    return (
        gr.update(visible=mic, value=None),  # input_audio_mic
        gr.update(visible=not mic, value=None),  # input_audio_file
    )

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2TT_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    # Swap the visible input widget when the audio source changes.
    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    btn.click(
        fn=predict,
        inputs=[
            audio_source,
            input_audio_mic,
            input_audio_file,
            target_language,
        ],
        outputs=[output_text],
        api_name="run",
    )
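
# If the demo sees concurrent traffic (e.g. on Spaces), Gradio's request queue
# can be enabled instead of a bare launch, for example:
#   demo.queue(max_size=20).launch()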
demo.launch()