# Source: Hugging Face Space file by dipesh1701
# Last commit: ccf30c7 ("model change"); file size 2.73 kB
import os
import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes
def load_models():
    """Instantiate every translation model the app can use.

    For each configured checkpoint a progress line is printed, then the model
    is loaded with ``AutoModelForSeq2SeqLM.from_pretrained`` and its tokenizer
    with ``AutoTokenizer.from_pretrained``.

    Returns:
        dict: alias -> ``{"model": <model>, "tokenizer": <tokenizer>}``.
    """
    # App-facing alias -> Hugging Face Hub checkpoint id.
    checkpoints = {
        "nllb-distilled-1.3B": "facebook/nllb-200-distilled-1.3B",
    }
    loaded = {}
    for alias in checkpoints:
        hub_id = checkpoints[alias]
        print("\tLoading model:", alias)
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        seq2seq_tokenizer = AutoTokenizer.from_pretrained(hub_id)
        loaded[alias] = {"model": seq2seq_model, "tokenizer": seq2seq_tokenizer}
    return loaded
# Load every model/tokenizer pair once so translate_text() can reuse them
# across requests instead of reloading them on each call.
# NOTE: this sits outside the `if __name__ == "__main__"` guard, so it also
# runs (and may download the checkpoints) whenever this module is imported.
model_dict = load_models()
def translate_text(source, target, text):
    """Translate *text* from the *source* language into the *target* language.

    Uses the model/tokenizer pair preloaded into the module-level
    ``model_dict`` under the ``"nllb-distilled-1.3B"`` alias.

    Args:
        source: Source language display name; must be a key of
            ``flores_codes`` (e.g. ``"English"``).
        target: Target language display name; must be a key of
            ``flores_codes`` (e.g. ``"Nepali"``).
        text: The text to translate.

    Returns:
        dict with keys ``"inference_time"`` (elapsed seconds as a float),
        ``"source"`` and ``"target"`` (the language codes looked up in
        ``flores_codes``), and ``"result"`` (the translated text).

    Raises:
        KeyError: if the model is not loaded, or if *source* / *target* is not
            a key of ``flores_codes``.
    """
    # Must match the alias registered in load_models(); a mismatched name
    # (e.g. the old "nllb-distilled-600M") makes every call raise KeyError.
    model_name = "nllb-distilled-1.3B"
    entry = model_dict.get(model_name)
    if entry is None or entry["model"] is None:
        raise KeyError(f"Model '{model_name}' not found in model_dict")
    model = entry["model"]
    tokenizer = entry["tokenizer"]

    # perf_counter is a monotonic clock intended for measuring intervals,
    # unlike time.time(), which can jump if the wall clock is adjusted.
    start_time = time.perf_counter()

    # Map the display names to the codes the translation pipeline expects.
    src_code = flores_codes[source]
    tgt_code = flores_codes[target]

    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    output = translator(text, max_length=400)
    end_time = time.perf_counter()

    return {
        "inference_time": end_time - start_time,
        "source": src_code,
        "target": tgt_code,
        "result": output[0]["translation_text"],
    }
if __name__ == "__main__":
    print("\tInitializing models")
    language_names = list(flores_codes.keys())

    # Widgets are listed in the order translate_text takes its positional
    # arguments: source language, target language, input text.
    source_picker = gr.inputs.Dropdown(
        language_names, default="English", label="Source"
    )
    target_picker = gr.inputs.Dropdown(
        language_names, default="Nepali", label="Target"
    )
    text_box = gr.inputs.Textbox(lines=5, label="Input text")

    demo = gr.Interface(
        translate_text,
        [source_picker, target_picker, text_box],
        gr.outputs.JSON(),
        title="The Master Betters Translator",
        description=(
            "This is a beta version of The Master Betters Translator that utilizes "
            "pre-trained language models for translation. To use this app you need "
            "to have chosen the source and target language with your input text to "
            "get the output."
        ),
        examples=[["English", "Nepali", "Hello, how are you?"]],
        examples_per_page=50,
    )
    demo.launch()