ParisNeo commited on
Commit
3c977dc
·
1 Parent(s): 80a8250

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
3
+ import torch
4
+
5
+
6
+ class MBartTranslator:
7
+ """MBartTranslator class provides a simple interface for translating text using the MBart language model.
8
+
9
+ The class can translate between 50 languages and is based on the "facebook/mbart-large-50-many-to-many-mmt"
10
+ pre-trained MBart model. However, it is possible to use a different MBart model by specifying its name.
11
+
12
+ Attributes:
13
+ model (MBartForConditionalGeneration): The MBart language model.
14
+ tokenizer (MBart50TokenizerFast): The MBart tokenizer.
15
+ """
16
+
17
+ def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
18
+
19
+ self.supported_languages = [
20
+ "ar_AR",
21
+ "de_DE",
22
+ "en_XX",
23
+ "es_XX",
24
+ "fr_XX",
25
+ "hi_IN",
26
+ "it_IT",
27
+ "ja_XX",
28
+ "ko_XX",
29
+ "pt_XX",
30
+ "ru_XX",
31
+ "zh_XX",
32
+ "af_ZA",
33
+ "bn_BD",
34
+ "bs_XX",
35
+ "ca_XX",
36
+ "cs_CZ",
37
+ "da_XX",
38
+ "el_GR",
39
+ "et_EE",
40
+ "fa_IR",
41
+ "fi_FI",
42
+ "gu_IN",
43
+ "he_IL",
44
+ "hi_XX",
45
+ "hr_HR",
46
+ "hu_HU",
47
+ "id_ID",
48
+ "is_IS",
49
+ "ja_XX",
50
+ "jv_XX",
51
+ "ka_GE",
52
+ "kk_XX",
53
+ "km_KH",
54
+ "kn_IN",
55
+ "ko_KR",
56
+ "lo_LA",
57
+ "lt_LT",
58
+ "lv_LV",
59
+ "mk_MK",
60
+ "ml_IN",
61
+ "mr_IN",
62
+ "ms_MY",
63
+ "ne_NP",
64
+ "nl_XX",
65
+ "no_XX",
66
+ "pl_XX",
67
+ "ro_RO",
68
+ "si_LK",
69
+ "sk_SK",
70
+ "sl_SI",
71
+ "sq_AL",
72
+ "sr_XX",
73
+ "sv_XX",
74
+ "sw_TZ",
75
+ "ta_IN",
76
+ "te_IN",
77
+ "th_TH",
78
+ "tl_PH",
79
+ "tr_TR",
80
+ "uk_UA",
81
+ "ur_PK",
82
+ "vi_VN",
83
+ "war_PH",
84
+ "yue_XX",
85
+ "zh_CN",
86
+ "zh_TW",
87
+ ]
88
+ print("Building translator")
89
+ print("Loading generator (this may take few minutes the first time as I need to download the model)")
90
+ self.model = MBartForConditionalGeneration.from_pretrained(model_name)
91
+ print("Loading tokenizer")
92
+ self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
93
+ print("Translator is ready")
94
+
95
+ def translate(self, text: str, input_language: str, output_language: str) -> str:
96
+ """Translate the given text from the input language to the output language.
97
+
98
+ Args:
99
+ text (str): The text to translate.
100
+ input_language (str): The input language code (e.g. "hi_IN" for Hindi).
101
+ output_language (str): The output language code (e.g. "en_US" for English).
102
+
103
+ Returns:
104
+ str: The translated text.
105
+ """
106
+ if input_language not in self.supported_languages:
107
+ raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
108
+ if output_language not in self.supported_languages:
109
+ raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
110
+
111
+ self.tokenizer.src_lang = input_language
112
+ encoded_input = self.tokenizer(text, return_tensors="pt")
113
+ generated_tokens = self.model.generate(
114
+ **encoded_input, forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language]
115
+ )
116
+ translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
117
+
118
+ return translated_text[0]
119
+
120
+
121
+
122
+ def translate_text(source_lang, target_lang, text):
123
+ translator = MBartTranslator()
124
+ return translator.translate(text, source_lang, target_lang)
125
+
126
+ translation_interface = gr.Interface(fn=translate_text,
127
+ inputs=[gr.inputs.Dropdown(choices=["en_XX", "es_XX", "fr_XX", "zh_XX", "hi_IN"], label="Source Language"),
128
+ gr.inputs.Dropdown(choices=["en_XX", "es_XX", "fr_XX", "zh_XX", "hi_IN"], label="Target Language"),
129
+ gr.inputs.Textbox(label="Text to translate")],
130
+ outputs=gr.outputs.Textbox(label="Translated text"))
131
+ translation_interface.launch()