"""Gradio demo for Flores101 / M2M100 (175M) multilingual machine translation.

Loads an M2M100 checkpoint from the local ``./model_files`` directory (no
network access: ``local_files_only=True``) and serves a text-in/text-out
translation UI with source/target language dropdowns.
"""

from pathlib import Path

import gradio as gr
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# Display name (shown in the dropdowns) -> language code. The code is looked up
# in the tokenizer's ``lang_code_to_*`` maps built further below.
lang_to_code = {
    "Afrikaans": "af",
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Assamese": "as",
    "Asturian": "ast",
    "Aymara": "ay",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bengali": "bn",
    "Bosnian": "bs",
    "Breton": "br",
    "Bulgarian": "bg",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Central Khmer": "km",
    "Chinese": "zh",
    "Chokwe": "cjk",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Dyula": "dyu",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "Fulah": "ff",
    "Galician": "gl",
    "Ganda": "lg",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Haitian Creole": "ht",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Hungarian": "hu",
    "Icelandic": "is",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Indonesian": "id",
    "Irish": "ga",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kabuverdianu": "kea",
    "Kachin": "kac",
    "Kamba": "kam",
    "Kannada": "kn",
    "Kazakh": "kk",
    "Kimbundu": "kmb",
    "Kongo": "kg",
    "Korean": "ko",
    "Kurdish": "ku",
    "Kyrgyz": "ky",
    "Lao": "lo",
    "Latvian": "lv",
    "Lingala": "ln",
    "Lithuanian": "lt",
    "Luo": "luo",
    "Luxembourgish": "lb",
    "Macedonian": "mk",
    "Malagasy": "mg",
    "Malay": "ms",
    "Malayalam": "ml",
    "Maltese": "mt",
    "Maori": "mi",
    "Marathi": "mr",
    "Mongolian": "mn",
    "Nepali": "ne",
    "Northern Kurdish": "kmr",
    "Northern Sotho": "ns",
    "Norwegian": "no",
    "Nyanja": "ny",
    "Occitan": "oc",
    "Oriya": "or",
    "Oromo": "om",
    "Pashto": "ps",
    "Persian": "fa",
    "Polish": "pl",
    "Portuguese": "pt",
    "Punjabi": "pa",
    "Quechua": "qu",
    "Romanian": "ro",
    "Russian": "ru",
    "Scottish Gaelic": "gd",
    "Serbian": "sr",
    "Shan": "shn",
    "Shona": "sn",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Spanish": "es",
    "Sundanese": "su",
    "Swahili": "sw",
    "Swati": "ss",
    "Swedish": "sv",
    "Tagalog": "tl",
    "Tajik": "tg",
    "Tamil": "ta",
    "Telugu": "te",
    "Thai": "th",
    "Tigrinya": "ti",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Umbundu": "umb",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Western Frisian": "fy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}
# Dropdown choices, in the dict's (alphabetical) insertion order.
lang_names = list(lang_to_code.keys())

# load model
# NOTE(review): "./model_files" is resolved against the current working
# directory, not this file's location — launch from the directory that
# contains model_files.
model_path = Path("./model_files").resolve()
print(f"model_path: {model_path}")
tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained(
    pretrained_model_name_or_path=str(model_path), local_files_only=True
)
model = M2M100ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=str(model_path), local_files_only=True
)

# fix tokenizer
# Rebuild the tokenizer's language lookup tables from its registered special
# tokens so that the language codes used above resolve via get_lang_id() and
# the src_lang setter.
# NOTE(review): the `i > 5` cutoff presumably drops the core special tokens
# (<s>, <pad>, </s>, <unk>) that sit at low ids — verify against the vocab.
tokenizer.lang_token_to_id = {
    t: i
    for t, i in zip(tokenizer.all_special_tokens, tokenizer.all_special_ids)
    if i > 5
}
# Language tokens look like "__xx__"; strip("_") recovers the bare code "xx".
tokenizer.lang_code_to_token = {s.strip("_"): s for s in tokenizer.lang_token_to_id}
tokenizer.lang_code_to_id = {
    s.strip("_"): i for s, i in tokenizer.lang_token_to_id.items()
}
tokenizer.id_to_lang_token = {i: s for s, i in tokenizer.lang_token_to_id.items()}


def translate(src_text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``src_text`` from ``source_lang`` into ``target_lang``.

    Args:
        src_text: Text to translate.
        source_lang: Display name of the source language (a key of
            ``lang_to_code``).
        target_lang: Display name of the target language (a key of
            ``lang_to_code``).

    Returns:
        The decoded translation of the first (and only) generated sequence.

    Raises:
        KeyError: If either language name is not in ``lang_to_code``, or its
            code is not in the tokenizer's language tables.
    """
    # get lang code
    src_lang = lang_to_code[source_lang]
    tgt_lang = lang_to_code[target_lang]

    # encode
    # NOTE(review): this mutates shared tokenizer state; if requests are ever
    # processed concurrently against this one tokenizer they could race —
    # confirm the app's concurrency settings.
    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(src_text, return_tensors="pt")

    # inference
    # Forcing the first generated token to the target-language id is what
    # selects the output language.
    generated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
        max_length=1024,
    )

    # decode
    pred_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    pred_text = pred_texts[0]
    assert isinstance(pred_text, str)  # type narrowing for static checkers
    return pred_text


inputs = [
    gr.Textbox(lines=4, value="Hello world!", label="Input Text"),
    gr.Dropdown(lang_names, value="English", label="Source Language"),
    gr.Dropdown(lang_names, value="Korean", label="Target Language"),
]
outputs = gr.Textbox(lines=4, label="Output Text")

demo = gr.Interface(
    fn=translate,
    inputs=inputs,
    outputs=outputs,
    title="Flores101: Large-Scale Multilingual Machine Translation",
    description="[`seyoungsong/flores101_mm100_175M`](https://huggingface.co/seyoungsong/flores101_mm100_175M)",
)


if __name__ == "__main__":
    # https://huggingface.co/seyoungsong/flores101_mm100_175M
    # https://huggingface.co/spaces/seyoungsong/flores101_mm100_175M
    # gradio src/pretrained/gradio/app.py
    # http://127.0.0.1:7860
    demo.launch()