|
import gradio as gr |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
|
|
# Display name -> language code for every language offered in the UI.
# The codes are the ones passed to the tokenizer as ``src_lang`` and to
# ``tokenizer.get_lang_id`` inside ``translate``.
lang_to_code = {
    "Afrikaans": "af",  # was misspelled "Akrikaans" in the dropdowns
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Assamese": "as",
    "Asturian": "ast",
    "Aymara": "ay",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bengali": "bn",
    "Bosnian": "bs",
    "Breton": "br",
    "Bulgarian": "bg",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Central Khmer": "km",
    "Chinese": "zh",
    "Chokwe": "cjk",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Dyula": "dyu",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "Fulah": "ff",
    "Galician": "gl",
    "Ganda": "lg",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Haitian Creole": "ht",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Hungarian": "hu",
    "Icelandic": "is",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Indonesian": "id",
    "Irish": "ga",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kabuverdianu": "kea",
    "Kachin": "kac",
    "Kamba": "kam",
    "Kannada": "kn",
    "Kazakh": "kk",
    "Kimbundu": "kmb",
    "Kongo": "kg",
    "Korean": "ko",
    "Kurdish": "ku",
    "Kyrgyz": "ky",
    "Lao": "lo",
    "Latvian": "lv",
    "Lingala": "ln",
    "Lithuanian": "lt",
    "Luo": "luo",
    "Luxembourgish": "lb",
    "Macedonian": "mk",
    "Malagasy": "mg",
    "Malay": "ms",
    "Malayalam": "ml",
    "Maltese": "mt",
    "Maori": "mi",
    "Marathi": "mr",
    "Mongolian": "mn",
    "Nepali": "ne",
    "Northern Kurdish": "kmr",
    "Northern Sotho": "ns",
    "Norwegian": "no",
    "Nyanja": "ny",
    "Occitan": "oc",
    "Oriya": "or",
    "Oromo": "om",
    "Pashto": "ps",
    "Persian": "fa",
    "Polish": "pl",
    "Portuguese": "pt",
    "Punjabi": "pa",
    "Quechua": "qu",
    "Romanian": "ro",
    "Russian": "ru",
    "Scottish Gaelic": "gd",
    "Serbian": "sr",
    "Shan": "shn",
    "Shona": "sn",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Spanish": "es",
    "Sundanese": "su",
    "Swahili": "sw",
    "Swati": "ss",
    "Swedish": "sv",
    "Tagalog": "tl",
    "Tajik": "tg",
    "Tamil": "ta",
    "Telugu": "te",
    "Thai": "th",
    "Tigrinya": "ti",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Umbundu": "umb",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Western Frisian": "fy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}

# Dropdown choices, in the same order as the mapping above.
lang_names = list(lang_to_code.keys())
|
|
|
|
|
# Single checkpoint identifier used for both the tokenizer and the model,
# so the two can never drift apart.
_CHECKPOINT = "seyoungsong/flores101_mm100_175M"

# Both objects are loaded once at import time and reused by every request.
tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained(_CHECKPOINT)
model = M2M100ForConditionalGeneration.from_pretrained(_CHECKPOINT)
|
|
|
|
|
# Rebuild the tokenizer's language lookup tables from its special tokens.
# Every special token whose id is greater than 5 is treated as a language
# token. NOTE(review): presumably ids 0-5 are the non-language specials of
# this checkpoint -- confirm against the tokenizer's vocabulary.
_special_pairs = zip(tokenizer.all_special_tokens, tokenizer.all_special_ids)
tokenizer.lang_token_to_id = {
    token: token_id for token, token_id in _special_pairs if token_id > 5
}

# Language code (the token with surrounding underscores stripped) -> token.
tokenizer.lang_code_to_token = {
    token.strip("_"): token for token in tokenizer.lang_token_to_id
}

# Language code -> token id.
tokenizer.lang_code_to_id = {
    token.strip("_"): token_id
    for token, token_id in tokenizer.lang_token_to_id.items()
}

# Reverse lookup: token id -> language token.
tokenizer.id_to_lang_token = {
    token_id: token for token, token_id in tokenizer.lang_token_to_id.items()
}
|
|
|
|
|
def translate(src_text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``src_text`` from ``source_lang`` into ``target_lang``.

    Args:
        src_text: Text to translate.
        source_lang: Display name of the input language; must be a key of
            ``lang_to_code``.
        target_lang: Display name of the output language; must be a key of
            ``lang_to_code``.

    Returns:
        The decoded translation of the first (only) generated sequence, or
        an empty string when ``src_text`` is empty or whitespace-only.

    Raises:
        KeyError: If ``src_text`` is non-blank and either language name is
            not present in ``lang_to_code``.
    """
    # Nothing to translate: return early instead of spending a model
    # forward pass on an input that carries no content.
    if not src_text or not src_text.strip():
        return ""

    # Map the human-readable dropdown labels to language codes.
    src_lang = lang_to_code[source_lang]
    tgt_lang = lang_to_code[target_lang]

    # The source language is set on the shared module-level tokenizer
    # before encoding.
    # NOTE(review): this mutates shared state; confirm requests are not
    # served concurrently, or guard this section.
    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(src_text, return_tensors="pt")

    # Force the first generated token to be the target-language token and
    # cap the generated sequence length at 1024.
    generated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        max_length=1024,
    )

    # A single input yields a single-item batch; decode and take that item.
    pred_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    pred_text = pred_texts[0]
    assert isinstance(pred_text, str)

    return pred_text
|
|
|
|
|
# Input widgets, in the positional order of ``translate``'s parameters:
# the text to translate, then the source and target language pickers.
inputs = [
    gr.Textbox(label="Input Text", value="Hello world!", lines=4),
    gr.Dropdown(choices=lang_names, label="Source Language", value="English"),
    gr.Dropdown(choices=lang_names, label="Target Language", value="Korean"),
]

# Output widget that shows the translated text.
outputs = gr.Textbox(label="Output Text", lines=4)
|
|
|
|
|
# Gradio app wiring ``translate`` to the input/output widgets defined above.
# The description is a Markdown link to the model page; the host previously
# read "huggingface.co." (stray dot before the path), which is corrected here.
demo = gr.Interface(
    fn=translate,
    inputs=inputs,
    outputs=outputs,
    title="Flores101: Large-Scale Multilingual Machine Translation",
    description="[`seyoungsong/flores101_mm100_175M`](https://huggingface.co/seyoungsong/flores101_mm100_175M)",
)
|
|
|
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()
|
|