Spaces:
Running
Running
File size: 4,545 Bytes
3c977dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
class MBartTranslator:
    """Translate text between the languages of an mBART-50 many-to-many model.

    By default the "facebook/mbart-large-50-many-to-many-mmt" pre-trained
    checkpoint is used; a different MBart checkpoint can be chosen with
    ``model_name``. Language validation uses the mBART-50 language-code list
    (``MBART50_LANGUAGE_CODES``), so a checkpoint with a different language set
    needs that list adjusted.

    Attributes:
        supported_languages (list[str]): Language codes accepted by ``translate``.
        model (MBartForConditionalGeneration): The MBart language model.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
    """

    # The 52 language codes of mBART-50, as listed on the model card. Only
    # these exist as language tokens in the model's vocabulary; look-alike
    # codes such as "zh_XX", "ko_XX" or "ru_XX" are not valid.
    MBART50_LANGUAGE_CODES = (
        "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
        "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
        "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
        "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
        "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
        "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
        "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
    )

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        """Download (if needed) and load the model and tokenizer.

        Args:
            model_name: Hugging Face model id or local path of the MBart checkpoint.
            src_lang: Optional default source language code for the tokenizer.
            tgt_lang: Optional default target language code for the tokenizer.
        """
        # Per-instance list copy: keeps the public attribute a list (as before)
        # while preventing one instance's mutations from affecting another.
        self.supported_languages = list(self.MBART50_LANGUAGE_CODES)
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        print("Translator is ready")

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate the given text from the input language to the output language.

        Args:
            text (str): The text to translate. Empty or whitespace-only text
                yields "" without calling the model.
            input_language (str): The input language code (e.g. "hi_IN" for Hindi).
            output_language (str): The output language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text, without special tokens (language tag, </s>).

        Raises:
            ValueError: If either language code is not in ``supported_languages``.
        """
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
        # Nothing to translate: avoid an expensive and meaningless generate() call.
        if not text or not text.strip():
            return ""
        # Tell the tokenizer which language the input is written in.
        self.tokenizer.src_lang = input_language
        encoded_input = self.tokenizer(text, return_tensors="pt")
        generated_tokens = self.model.generate(
            **encoded_input,
            # Forcing the first generated token to the target-language code
            # selects the output language. convert_tokens_to_ids is the generic
            # tokenizer API and does not depend on the lang_code_to_id helper.
            forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(output_language),
        )
        # Drop the language tag and </s> so only the translation is returned.
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return translated_text[0]
# Lazily created, process-wide translator. Loading the model is slow, so it is
# done once on first use instead of on every request.
_TRANSLATOR = None


def _get_translator():
    """Return the shared MBartTranslator, creating it on first use."""
    global _TRANSLATOR
    if _TRANSLATOR is None:
        _TRANSLATOR = MBartTranslator()
    return _TRANSLATOR


def translate_text(source_lang, target_lang, text):
    """Gradio callback: translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        source_lang: Language code of the input text (e.g. "en_XX").
        target_lang: Language code of the desired output (e.g. "fr_XX").
        text: The text to translate.

    Returns:
        The translated text.

    Raises:
        ValueError: If either language code is not supported by the translator.
    """
    # NOTE(review): the translator is shared between requests, and
    # MBartTranslator.translate assigns tokenizer.src_lang before tokenizing.
    # If requests can run concurrently, consider serializing calls — confirm.
    return _get_translator().translate(text, source_lang, target_lang)
# Language codes offered in the UI. All are mBART-50 codes; Chinese is
# "zh_CN" ("zh_XX" is not an mBART-50 language code).
_UI_LANGUAGES = ["en_XX", "es_XX", "fr_XX", "zh_CN", "hi_IN"]

# gr.Dropdown / gr.Textbox replace the deprecated gr.inputs.* / gr.outputs.*
# namespaces. Defaults are set so an untouched dropdown does not send None.
translation_interface = gr.Interface(
    fn=translate_text,
    inputs=[
        gr.Dropdown(choices=_UI_LANGUAGES, value="en_XX", label="Source Language"),
        gr.Dropdown(choices=_UI_LANGUAGES, value="fr_XX", label="Target Language"),
        gr.Textbox(label="Text to translate"),
    ],
    outputs=gr.Textbox(label="Translated text"),
)
translation_interface.launch()
|