dipesh1701
commited on
Commit
•
06d2814
1
Parent(s):
2722db9
optimize code
Browse files
app.py
CHANGED
@@ -1,13 +1,11 @@
|
|
1 |
-
import os
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
import time
|
5 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
6 |
from flores200_codes import flores_codes
|
7 |
|
8 |
-
|
9 |
def load_models():
|
10 |
-
# build model and tokenizer
|
11 |
model_name_dict = {
|
12 |
"nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
|
13 |
}
|
@@ -15,76 +13,74 @@ def load_models():
|
|
15 |
model_dict = {}
|
16 |
|
17 |
for call_name, real_name in model_name_dict.items():
|
18 |
-
print("\tLoading model:
|
19 |
model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
|
20 |
tokenizer = AutoTokenizer.from_pretrained(real_name)
|
21 |
-
model_dict[call_name
|
22 |
-
|
|
|
|
|
23 |
|
24 |
return model_dict
|
25 |
|
26 |
-
|
27 |
-
def
|
28 |
-
|
29 |
-
model_name = "nllb-distilled-600M"
|
30 |
|
31 |
start_time = time.time()
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
model = model_dict[model_name
|
36 |
-
tokenizer = model_dict[model_name
|
37 |
|
38 |
translator = pipeline(
|
39 |
"translation",
|
40 |
model=model,
|
41 |
tokenizer=tokenizer,
|
42 |
-
src_lang=
|
43 |
-
tgt_lang=
|
44 |
)
|
45 |
-
|
46 |
|
47 |
end_time = time.time()
|
48 |
|
49 |
-
|
50 |
-
result = {
|
51 |
"inference_time": end_time - start_time,
|
52 |
-
"source":
|
53 |
-
"target":
|
54 |
-
"result":
|
55 |
}
|
56 |
-
return
|
57 |
-
|
58 |
|
59 |
if __name__ == "__main__":
|
60 |
-
|
61 |
|
|
|
62 |
model_dict = load_models()
|
63 |
|
64 |
-
# define gradio demo
|
65 |
lang_codes = list(flores_codes.keys())
|
66 |
inputs = [
|
67 |
-
gr.inputs.Dropdown(lang_codes, default="English", label="Source"),
|
68 |
-
gr.inputs.Dropdown(lang_codes, default="Nepali", label="Target"),
|
69 |
-
gr.inputs.Textbox(lines=5, label="Input
|
70 |
]
|
71 |
|
72 |
outputs = gr.outputs.JSON()
|
73 |
|
74 |
-
title = "
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
f"{desc}"
|
79 |
)
|
80 |
-
examples = [["English", "Nepali", "
|
81 |
|
82 |
gr.Interface(
|
83 |
-
|
84 |
inputs,
|
85 |
outputs,
|
86 |
title=title,
|
87 |
-
description=
|
88 |
examples=examples,
|
89 |
examples_per_page=50,
|
90 |
).launch()
|
|
|
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import time
|
4 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
5 |
from flores200_codes import flores_codes
|
6 |
|
7 |
+
# Load models and tokenizers once during initialization
|
8 |
def load_models():
|
|
|
9 |
model_name_dict = {
|
10 |
"nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
|
11 |
}
|
|
|
13 |
model_dict = {}
|
14 |
|
15 |
for call_name, real_name in model_name_dict.items():
|
16 |
+
print("\tLoading model:", call_name)
|
17 |
model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
|
18 |
tokenizer = AutoTokenizer.from_pretrained(real_name)
|
19 |
+
model_dict[call_name] = {
|
20 |
+
"model": model,
|
21 |
+
"tokenizer": tokenizer,
|
22 |
+
}
|
23 |
|
24 |
return model_dict
|
25 |
|
26 |
+
# Translate text using preloaded models and tokenizers
|
27 |
+
def translate_text(source_lang, target_lang, input_text, model_dict):
|
28 |
+
model_name = "nllb-distilled-600M"
|
|
|
29 |
|
30 |
start_time = time.time()
|
31 |
+
source_code = flores_codes[source_lang]
|
32 |
+
target_code = flores_codes[target_lang]
|
33 |
|
34 |
+
model = model_dict[model_name]["model"]
|
35 |
+
tokenizer = model_dict[model_name]["tokenizer"]
|
36 |
|
37 |
translator = pipeline(
|
38 |
"translation",
|
39 |
model=model,
|
40 |
tokenizer=tokenizer,
|
41 |
+
src_lang=source_code,
|
42 |
+
tgt_lang=target_code,
|
43 |
)
|
44 |
+
translated_output = translator(input_text, max_length=400)
|
45 |
|
46 |
end_time = time.time()
|
47 |
|
48 |
+
translated_result = {
|
|
|
49 |
"inference_time": end_time - start_time,
|
50 |
+
"source": source_lang,
|
51 |
+
"target": target_lang,
|
52 |
+
"result": translated_output[0]["translation_text"],
|
53 |
}
|
54 |
+
return translated_result
|
|
|
55 |
|
56 |
if __name__ == "__main__":
|
57 |
+
print("\tInitializing models")
|
58 |
|
59 |
+
# Load models and tokenizers
|
60 |
model_dict = load_models()
|
61 |
|
|
|
62 |
lang_codes = list(flores_codes.keys())
|
63 |
inputs = [
|
64 |
+
gr.inputs.Dropdown(lang_codes, default="English", label="Source Language"),
|
65 |
+
gr.inputs.Dropdown(lang_codes, default="Nepali", label="Target Language"),
|
66 |
+
gr.inputs.Textbox(lines=5, label="Input Text"),
|
67 |
]
|
68 |
|
69 |
outputs = gr.outputs.JSON()
|
70 |
|
71 |
+
title = "Masterful Translator"
|
72 |
|
73 |
+
app_description = (
|
74 |
+
"This is a beta version of the Masterful Translator that utilizes pre-trained language models for translation."
|
|
|
75 |
)
|
76 |
+
examples = [["English", "Nepali", "Hello, how are you?"]]
|
77 |
|
78 |
gr.Interface(
|
79 |
+
translate_text,
|
80 |
inputs,
|
81 |
outputs,
|
82 |
title=title,
|
83 |
+
description=app_description,
|
84 |
examples=examples,
|
85 |
examples_per_page=50,
|
86 |
).launch()
|