seyoungsong commited on
Commit
8064c8d
1 Parent(s): 44d5f5a
Files changed (2) hide show
  1. app.py +177 -8
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,31 +1,200 @@
1
  import gradio as gr
 
2
 
3
- lang_list = ["Korean", "English", "Chinese"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def translate(prompt: str, source_lang: str, target_lang: str) -> str:
7
- return f'"{prompt}" in {source_lang} means "{prompt}" in {target_lang}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  inputs = [
11
  gr.Textbox(lines=4, value="Hello world!", label="Input Text"),
12
- gr.Dropdown(lang_list, value="English", label="Source Language"),
13
- gr.Dropdown(lang_list, value="Korean", label="Target Language"),
14
  ]
15
 
16
 
17
- outputs = gr.Textbox(label="Output Text")
18
 
19
 
20
  demo = gr.Interface(
21
  fn=translate,
22
  inputs=inputs,
23
  outputs=outputs,
24
- title="Beyond English-Centric Multilingual Machine Translation",
 
25
  )
26
 
27
  if __name__ == "__main__":
 
 
28
  # gradio src/pretrained/gradio/app.py
29
  # http://127.0.0.1:7860
30
- # https://huggingface.co/spaces/seyoungsong/flores101_mm100_175M
31
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
3
 
4
+ lang_to_code = {
5
+ "Akrikaans": "af",
6
+ "Albanian": "sq",
7
+ "Amharic": "am",
8
+ "Arabic": "ar",
9
+ "Armenian": "hy",
10
+ "Assamese": "as",
11
+ "Asturian": "ast",
12
+ "Aymara": "ay",
13
+ "Azerbaijani": "az",
14
+ "Bashkir": "ba",
15
+ "Belarusian": "be",
16
+ "Bengali": "bn",
17
+ "Bosnian": "bs",
18
+ "Breton": "br",
19
+ "Bulgarian": "bg",
20
+ "Burmese": "my",
21
+ "Catalan": "ca",
22
+ "Cebuano": "ceb",
23
+ "Central Khmer": "km",
24
+ "Chinese": "zh",
25
+ "Chokwe": "cjk",
26
+ "Croatian": "hr",
27
+ "Czech": "cs",
28
+ "Danish": "da",
29
+ "Dutch": "nl",
30
+ "Dyula": "dyu",
31
+ "English": "en",
32
+ "Estonian": "et",
33
+ "Finnish": "fi",
34
+ "French": "fr",
35
+ "Fulah": "ff",
36
+ "Galician": "gl",
37
+ "Ganda": "lg",
38
+ "Georgian": "ka",
39
+ "German": "de",
40
+ "Greek": "el",
41
+ "Gujarati": "gu",
42
+ "Haitian Creole": "ht",
43
+ "Hausa": "ha",
44
+ "Hebrew": "he",
45
+ "Hindi": "hi",
46
+ "Hungarian": "hu",
47
+ "Icelandic": "is",
48
+ "Igbo": "ig",
49
+ "Iloko": "ilo",
50
+ "Indonesian": "id",
51
+ "Irish": "ga",
52
+ "Italian": "it",
53
+ "Japanese": "ja",
54
+ "Javanese": "jv",
55
+ "Kabuverdianu": "kea",
56
+ "Kachin": "kac",
57
+ "Kamba": "kam",
58
+ "Kannada": "kn",
59
+ "Kazakh": "kk",
60
+ "Kimbundu": "kmb",
61
+ "Kongo": "kg",
62
+ "Korean": "ko",
63
+ "Kurdish": "ku",
64
+ "Kyrgyz": "ky",
65
+ "Lao": "lo",
66
+ "Latvian": "lv",
67
+ "Lingala": "ln",
68
+ "Lithuanian": "lt",
69
+ "Luo": "luo",
70
+ "Luxembourgish": "lb",
71
+ "Macedonian": "mk",
72
+ "Malagasy": "mg",
73
+ "Malay": "ms",
74
+ "Malayalam": "ml",
75
+ "Maltese": "mt",
76
+ "Maori": "mi",
77
+ "Marathi": "mr",
78
+ "Mongolian": "mn",
79
+ "Nepali": "ne",
80
+ "Northern Kurdish": "kmr",
81
+ "Northern Sotho": "ns",
82
+ "Norwegian": "no",
83
+ "Nyanja": "ny",
84
+ "Occitan": "oc",
85
+ "Oriya": "or",
86
+ "Oromo": "om",
87
+ "Pashto": "ps",
88
+ "Persian": "fa",
89
+ "Polish": "pl",
90
+ "Portuguese": "pt",
91
+ "Punjabi": "pa",
92
+ "Quechua": "qu",
93
+ "Romanian": "ro",
94
+ "Russian": "ru",
95
+ "Scottish Gaelic": "gd",
96
+ "Serbian": "sr",
97
+ "Shan": "shn",
98
+ "Shona": "sn",
99
+ "Sindhi": "sd",
100
+ "Sinhala": "si",
101
+ "Slovak": "sk",
102
+ "Slovenian": "sl",
103
+ "Somali": "so",
104
+ "Spanish": "es",
105
+ "Sundanese": "su",
106
+ "Swahili": "sw",
107
+ "Swati": "ss",
108
+ "Swedish": "sv",
109
+ "Tagalog": "tl",
110
+ "Tajik": "tg",
111
+ "Tamil": "ta",
112
+ "Telugu": "te",
113
+ "Thai": "th",
114
+ "Tigrinya": "ti",
115
+ "Tswana": "tn",
116
+ "Turkish": "tr",
117
+ "Ukrainian": "uk",
118
+ "Umbundu": "umb",
119
+ "Urdu": "ur",
120
+ "Uzbek": "uz",
121
+ "Vietnamese": "vi",
122
+ "Welsh": "cy",
123
+ "Western Frisian": "fy",
124
+ "Wolof": "wo",
125
+ "Xhosa": "xh",
126
+ "Yiddish": "yi",
127
+ "Yoruba": "yo",
128
+ "Zulu": "zu",
129
+ }
130
+ lang_names = list(lang_to_code.keys())
131
 
132
+ # load model
133
+ tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained(
134
+ "seyoungsong/flores101_mm100_175M"
135
+ )
136
+ model = M2M100ForConditionalGeneration.from_pretrained(
137
+ "seyoungsong/flores101_mm100_175M"
138
+ )
139
+
140
+ # fix tokenizer
141
+ tokenizer.lang_token_to_id = {
142
+ t: i
143
+ for t, i in zip(tokenizer.all_special_tokens, tokenizer.all_special_ids)
144
+ if i > 5
145
+ }
146
+ tokenizer.lang_code_to_token = {s.strip("_"): s for s in tokenizer.lang_token_to_id}
147
+ tokenizer.lang_code_to_id = {
148
+ s.strip("_"): i for s, i in tokenizer.lang_token_to_id.items()
149
+ }
150
+ tokenizer.id_to_lang_token = {i: s for s, i in tokenizer.lang_token_to_id.items()}
151
+
152
+
153
+ def translate(src_text: str, source_lang: str, target_lang: str) -> str:
154
+ # get lang code
155
+ src_lang = lang_to_code[source_lang]
156
+ tgt_lang = lang_to_code[target_lang]
157
 
158
+ # encode
159
+ tokenizer.src_lang = src_lang
160
+ encoded_input = tokenizer(src_text, return_tensors="pt")
161
+
162
+ # inference
163
+ generated_tokens = model.generate(
164
+ **encoded_input,
165
+ forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
166
+ max_length=1024,
167
+ )
168
+
169
+ # decode
170
+ pred_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
171
+ pred_text = pred_texts[0]
172
+ assert isinstance(pred_text, str)
173
+
174
+ return pred_text
175
 
176
 
177
  inputs = [
178
  gr.Textbox(lines=4, value="Hello world!", label="Input Text"),
179
+ gr.Dropdown(lang_names, value="English", label="Source Language"),
180
+ gr.Dropdown(lang_names, value="Korean", label="Target Language"),
181
  ]
182
 
183
 
184
+ outputs = gr.Textbox(lines=4, label="Output Text")
185
 
186
 
187
  demo = gr.Interface(
188
  fn=translate,
189
  inputs=inputs,
190
  outputs=outputs,
191
+ title="Flores101: Large-Scale Multilingual Machine Translation",
192
+ description="[`seyoungsong/flores101_mm100_175M`](https://huggingface.co/seyoungsong/flores101_mm100_175M)",
193
  )
194
 
195
  if __name__ == "__main__":
196
+ # https://huggingface.co/seyoungsong/flores101_mm100_175M
197
+ # https://huggingface.co/spaces/seyoungsong/flores101_mm100_175M
198
  # gradio src/pretrained/gradio/app.py
199
  # http://127.0.0.1:7860
 
200
  demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  --find-links https://download.pytorch.org/whl/cpu
2
  sentencepiece
3
- torch
4
  transformers
 
1
  --find-links https://download.pytorch.org/whl/cpu
2
  sentencepiece
3
+ torch==2.1.1+cpu
4
  transformers