File size: 2,577 Bytes
463444e
 
 
c3e5a3b
463444e
 
06d2814
463444e
 
c3e5a3b
463444e
 
 
 
 
06d2814
820763d
463444e
06d2814
 
 
 
463444e
 
 
06d2814
 
 
463444e
 
06d2814
 
463444e
06d2814
 
463444e
 
 
 
 
06d2814
 
463444e
06d2814
463444e
 
 
06d2814
463444e
06d2814
 
 
463444e
06d2814
463444e
 
06d2814
463444e
06d2814
a89b78d
463444e
 
 
06d2814
 
 
463444e
 
 
 
9e6f72f
463444e
06d2814
b0ef2d7
463444e
06d2814
463444e
 
b0ef2d7
 
 
463444e
06d2814
463444e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes

# Load models and tokenizers once during initialization
def load_models():
    """Load every configured translation model and its tokenizer.

    Returns a registry mapping a short model name to a dict with keys
    "model" (AutoModelForSeq2SeqLM) and "tokenizer" (AutoTokenizer).
    Downloads weights from the Hugging Face hub on first use.
    """
    model_name_dict = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }

    registry = {}
    for short_name, hub_id in model_name_dict.items():
        print("\tLoading model:", short_name)
        registry[short_name] = {
            "model": AutoModelForSeq2SeqLM.from_pretrained(hub_id),
            "tokenizer": AutoTokenizer.from_pretrained(hub_id),
        }

    return registry

# Translate text using preloaded models and tokenizers
def translate_text(source_lang, target_lang, input_text, model_dict):
    """Translate *input_text* between two FLORES-200 languages.

    source_lang/target_lang are human-readable names looked up in
    flores_codes; model_dict is the registry produced by load_models().
    Returns a dict with "inference_time", "source", "target", "result".
    """
    model_name = "nllb-distilled-600M"

    started_at = time.time()
    # Resolve the human-readable names to NLLB/FLORES language codes first,
    # matching the original lookup order.
    src_code = flores_codes[source_lang]
    tgt_code = flores_codes[target_lang]

    entry = model_dict[model_name]

    # NOTE: a fresh pipeline is constructed per call (as in the original);
    # the underlying model/tokenizer objects are still shared.
    translator = pipeline(
        "translation",
        model=entry["model"],
        tokenizer=entry["tokenizer"],
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    output = translator(input_text, max_length=400)

    return {
        "inference_time": time.time() - started_at,
        "source": source_lang,
        "target": target_lang,
        "result": output[0]["translation_text"],
    }

if __name__ == "__main__":
    print("\tInitializing models")

    # Load the model/tokenizer registry once at startup.
    model_dict = load_models()

    lang_codes = list(flores_codes.keys())

    # BUG FIX: gr.Interface calls fn with exactly one argument per input
    # component (3 here), but translate_text requires a fourth argument
    # (model_dict) — exposing it directly raised TypeError on every request.
    # Bind the registry in a closure so the UI-facing function takes 3 args.
    def _translate(source_lang, target_lang, input_text):
        return translate_text(source_lang, target_lang, input_text, model_dict)

    # Modern component API: gr.inputs.* / gr.outputs.* were deprecated in
    # Gradio 2 and removed in Gradio 3+; "default" became "value".
    inputs = [
        gr.Dropdown(choices=lang_codes, value="English", label="Source Language"),
        gr.Dropdown(choices=lang_codes, value="Nepali", label="Target Language"),
        gr.Textbox(lines=5, label="Input Text"),
    ]

    outputs = gr.JSON()

    title = "The Master Betters Translator"

    app_description = (
        "This is a beta version of The Master Betters Translator that utilizes pre-trained language models for translation."
    )
    examples = [["English", "Nepali", "Hello, how are you?"]]

    gr.Interface(
        fn=_translate,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=app_description,
        examples=examples,
        examples_per_page=50,
    ).launch()