# app.py -- uploaded by sbkapelner; commit 2c59794 "Create app.py" (verified), 3.3 kB.
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import langid
# Load one translation model per (source, target) language pair so that a
# lookup is simply models[source][target]. Every checkpoint id has the form
# "Helsinki-NLP/opus-mt-<source>-<target>", so the table is generated from the
# language codes instead of being spelled out pair by pair; a language is
# never paired with itself.
models = {
    source: {
        target: AutoModelForSeq2SeqLM.from_pretrained(
            f"Helsinki-NLP/opus-mt-{source}-{target}"
        )
        for target in ("en", "fr", "es")
        if target != source
    }
    for source in ("en", "fr", "es")
}
# One tokenizer per (source, target) language pair, indexed as
# tokenizers[source][target]. Every checkpoint id has the form
# "Helsinki-NLP/opus-mt-<source>-<target>", so the table is generated from the
# language codes instead of being spelled out pair by pair; a language is
# never paired with itself.
tokenizers = {
    source: {
        target: AutoTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{source}-{target}"
        )
        for target in ("en", "fr", "es")
        if target != source
    }
    for source in ("en", "fr", "es")
}
def translate(input_text, source_lang, target_lang):
    """Translate ``input_text`` from ``source_lang`` into ``target_lang``.

    Args:
        input_text: The text to translate.
        source_lang: First-level key into ``tokenizers`` / ``models``.
        target_lang: Second-level key, under ``source_lang``.

    Returns:
        The decoded translation of the input, or ``""`` when the input is
        empty or contains only whitespace.

    Raises:
        KeyError: If no tokenizer/model is registered for the language pair.
    """
    # Nothing to translate: return early instead of running generation on an
    # input that carries no text.
    if not input_text or not input_text.strip():
        return ""

    tokenizer = tokenizers[source_lang][target_lang]
    model = models[source_lang][target_lang]

    # truncation=True caps the encoded input at the tokenizer's configured
    # maximum length, so over-long text is cut rather than handed to the model
    # as-is (presumably it would exceed the model's position limit -- confirm).
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    translated_tokens = model.generate(**inputs)

    # A single input string yields a batch of one; return its only element.
    decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
    return decoded[0]
def translate_text(input_text):
    """Detect the language of ``input_text`` and translate it into the others.

    The language is detected with ``langid.classify``. For English, French or
    Spanish input, the text is translated into the two other languages and the
    slot for the detected language is left empty.

    Args:
        input_text: The text entered by the user.

    Returns:
        A 3-tuple ``(english, french, spanish)`` of strings. All three are
        empty when the input is empty or whitespace-only.

    Raises:
        gr.Error: If the detected language is not English, French or Spanish.
            Previously this message was stored in a dictionary entry that was
            never returned, so it was silently lost.
    """
    # Blank input: nothing to detect or translate.
    if not input_text or not input_text.strip():
        return "", "", ""

    detected_lang, _ = langid.classify(input_text)

    supported = ("en", "fr", "es")
    if detected_lang not in supported:
        raise gr.Error("Language not supported for translation.")

    # Translate into every supported language except the detected one, whose
    # slot stays empty.
    results = {
        code: "" if code == detected_lang else translate(input_text, detected_lang, code)
        for code in supported
    }
    return results["en"], results["fr"], results["es"]
def clear_textboxes():
    """Return empty values for the input box and the three translation boxes.

    The Clear button lists four output components (the input textbox plus the
    English, French and Spanish outputs), so four values must be returned --
    one per component.

    Returns:
        A 4-tuple of empty strings.
    """
    return "", "", "", ""
# Assemble the UI: one row holding two columns -- the input box with its
# Translate button, and the three translation outputs with the Clear button.
# NOTE(review): nesting reconstructed from statement order -- confirm layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text_to_translate = gr.Textbox(label="Text to Translate")
            translate_btn = gr.Button(value="Translate")

        with gr.Column():
            translation_en = gr.Textbox(label="Translation to English")
            translation_fr = gr.Textbox(label="Translation to French")
            translation_es = gr.Textbox(label="Translation to Spanish")
            clear_btn = gr.Button(value="Clear")

    # Translate: read the input box, fill the three output boxes.
    translate_btn.click(
        fn=translate_text,
        inputs=[text_to_translate],
        outputs=[translation_en, translation_fr, translation_es],
    )

    # Clear: takes no inputs and writes to the input box and all three outputs.
    clear_btn.click(
        fn=clear_textboxes,
        inputs=None,
        outputs=[
            text_to_translate,
            translation_en,
            translation_fr,
            translation_es,
        ],
    )

# Start the app with share=True, as in the original.
demo.launch(share=True)