dipesh1701's picture
bug fix
48ff56c
raw
history blame
2.81 kB
import torch
import gradio as gr
import time
import asyncio
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes
# Load models and tokenizers once during initialization
async def load_models():
    """Load every configured translation checkpoint and its tokenizer.

    Each ``from_pretrained`` call is dispatched through
    ``asyncio.to_thread`` and awaited before the next one starts, so the
    checkpoints are loaded one after another (model first, then tokenizer).

    Returns:
        dict: Maps each short model name to a dict with the keys
        ``"model"`` and ``"tokenizer"``.
    """
    # Short, user-facing name -> Hugging Face Hub repository id.
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }

    loaded = {}
    for alias, repo_id in checkpoints.items():
        print("\tLoading model:", alias)
        seq2seq_model = await asyncio.to_thread(
            AutoModelForSeq2SeqLM.from_pretrained, repo_id
        )
        seq2seq_tokenizer = await asyncio.to_thread(
            AutoTokenizer.from_pretrained, repo_id
        )
        loaded[alias] = {"model": seq2seq_model, "tokenizer": seq2seq_tokenizer}

    return loaded
# Translate text using preloaded models and tokenizers
def translate_text(source_lang, target_lang, input_text, model_dict):
    """Translate ``input_text`` with the preloaded NLLB model.

    Args:
        source_lang: Key into ``flores_codes`` naming the source language.
        target_lang: Key into ``flores_codes`` naming the target language.
        input_text: Text handed to the translation pipeline.
        model_dict: Mapping of model name to a dict holding ``"model"``
            and ``"tokenizer"`` entries.

    Returns:
        dict: ``"inference_time"`` (seconds elapsed since the start of this
        call), ``"source"``, ``"target"`` and ``"result"`` (the translated
        text of the first pipeline output).

    Raises:
        KeyError: If either language is missing from ``flores_codes`` or the
            expected model is missing from ``model_dict``.
    """
    model_name = "nllb-distilled-600M"
    started_at = time.time()

    # Language codes are resolved before the model check, so an unknown
    # language fails first.
    src_code = flores_codes[source_lang]
    tgt_code = flores_codes[target_lang]

    if model_name not in model_dict:
        raise KeyError(f"Model '{model_name}' not found in model_dict")

    bundle = model_dict[model_name]
    translator = pipeline(
        "translation",
        model=bundle["model"],
        tokenizer=bundle["tokenizer"],
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    outputs = translator(input_text, max_length=400)
    elapsed = time.time() - started_at

    return {
        "inference_time": elapsed,
        "source": source_lang,
        "target": target_lang,
        "result": outputs[0]["translation_text"],
    }
async def main():
    """Load the translation models, then build and launch the Gradio UI.

    The interface exposes three inputs (source language, target language and
    input text) and shows the dict returned by ``translate_text`` as JSON.
    """
    print("\tInitializing models")
    # Load models and tokenizers
    model_dict = await load_models()

    def translate(source_lang, target_lang, input_text):
        # The interface supplies one argument per input component (three
        # here), while translate_text also needs the preloaded models.
        # Binding model_dict in this closure keeps the call from failing
        # with a missing-argument TypeError.
        return translate_text(source_lang, target_lang, input_text, model_dict)

    lang_codes = list(flores_codes.keys())
    inputs = [
        gr.inputs.Dropdown(lang_codes, default="English", label="Source Language"),
        gr.inputs.Dropdown(lang_codes, default="Nepali", label="Target Language"),
        gr.inputs.Textbox(lines=5, label="Input Text"),
    ]
    outputs = gr.outputs.JSON()
    title = "Masterful Translator"
    app_description = (
        "This is a beta version of the Masterful Translator that utilizes pre-trained language models for translation."
    )
    examples = [["English", "Nepali", "Hello, how are you?"]]
    gr.Interface(
        fn=translate,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=app_description,
        examples=examples,
        examples_per_page=50,
    ).launch()
if __name__ == "__main__":
    # Script entry point: run the async main() coroutine to completion on a
    # fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)