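# app.py -- Hugging Face Space source file (uploaded by penpen).
# Last commit: "Update app.py" (33ca0c4); file size 3.04 kB.
# (File-viewer page chrome -- raw / history / blame / contribute / delete,
# "No virus" scan badge -- was captured with the file and is not part of the program.)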
import gradio as gr
from transformers import pipeline
import pandas as pd
import numpy as np
import os
# Module-level setup: everything below runs once at import time.

# Identifier of the translation checkpoint handed to the transformers pipeline.
model_checkpoint = "penpen/novel-zh-en"
# Translation pipeline shared by every request. max_time=7 is passed through to the
# pipeline -- presumably a per-call generation time limit in seconds; TODO confirm
# against the transformers generation documentation.
translator = pipeline("translation", model=model_checkpoint, max_time=7)
# Default glossary. Column names are supplied here via names=..., so the first row of
# the CSV is read as data rather than as a header.
default_dict = pd.read_csv("example_dictionary.csv", names=["Chinese", "English"])
# Example inputs for the UI, read without a header row; column 0 is used further down
# in this file as the list of example texts.
examples = pd.read_csv("examples.csv", header = None)
def predict(text, df):
    """Translate Chinese text to English line by line, applying a glossary first.

    Every glossary term found in ``text`` is rewritten to
    ``<TERM>`` + ``"MASK" * len(term)`` + ``<GLOS>`` + english + ``</GLOS>``
    before the text is handed to the module-level ``translator`` pipeline.

    Args:
        text: Source text; may contain several lines. ``None`` or an empty
            string yields an empty result.
        df: DataFrame with a ``"Chinese"`` column (terms to look for) and an
            ``"English"`` column (their fixed renderings).

    Returns:
        The translated text. Each non-blank input line becomes one paragraph
        terminated by a blank line (``"\n\n"``); blank lines are dropped.
    """
    if not text:
        return ""

    # Empty CSV cells are read as NaN (a float); keeping them would make
    # ``key in text`` or the string concatenation below raise TypeError.
    # As with the plain dict build, a later duplicate term overrides an earlier one.
    terms_dict = {
        chinese: english
        for chinese, english in zip(df["Chinese"].tolist(), df["English"].tolist())
        if isinstance(chinese, str) and chinese and isinstance(english, str)
    }

    # Longest terms first, so a short term cannot be substituted inside a
    # longer term that contains it. sorted() is stable, so equal-length terms
    # keep their dictionary order.
    for key in sorted(terms_dict, key=len, reverse=True):
        if key in text:
            masking = "MASK" * len(key)
            text = text.replace(
                key, "<TERM>" + masking + "<GLOS>" + terms_dict[key] + "</GLOS>"
            )

    max_chars = 512  # longest slice passed to the translator in one call
    paragraphs = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # A line shorter than max_chars yields exactly one slice. Longer lines
        # are cut at fixed offsets, so a slice boundary may fall inside a
        # <TERM>...</GLOS> marker.
        pieces = [
            translator(line[start:start + max_chars])[0]["translation_text"]
            for start in range(0, len(line), max_chars)
        ]
        # Long lines get the same paragraph separator as short ones, so the
        # next paragraph does not run on.
        paragraphs.append("".join(pieces) + "\n\n")
    return "".join(paragraphs)
def load_dict(file):
    """Read an uploaded glossary CSV into a two-column DataFrame.

    Args:
        file: Object whose ``name`` attribute is the path of the CSV to read.

    Returns:
        The same DataFrame twice (columns ``"Chinese"`` and ``"English"``),
        one reference for each of the two outputs this handler feeds.
    """
    column_names = ["Chinese", "English"]
    glossary = pd.read_csv(file.name, names=column_names)
    return glossary, glossary
def search_dict(query, df):
    """Return the rows of ``df`` in which any column contains ``query``.

    The query is matched as literal, case-sensitive text, so characters such
    as ``(``, ``+`` or ``*`` are searched for as-is instead of being parsed
    as a regular expression.

    Args:
        query: Text to look for. An empty or ``None`` query returns ``df``
            unchanged.
        df: DataFrame to search; every column is inspected.

    Returns:
        ``df`` itself when there is nothing to search for or nothing to
        search in, otherwise the subset of rows with at least one matching
        cell.
    """
    if not query or df.empty:
        return df

    per_column_hits = [
        # astype(str) lets non-string columns be searched too; the notna()
        # guard stops missing cells from matching through their "nan"/"None"
        # string form.
        (
            df[col].notna()
            & df[col].astype(str).str.contains(query, regex=False, na=False)
        ).to_numpy(dtype=bool)
        for col in df
    ]
    mask = np.column_stack(per_column_hits)
    return df.loc[mask.any(axis=1)]
# Gradio UI definition and event wiring.
# NOTE(review): leading indentation is missing from these lines as captured, so the
# nesting of widgets inside the Tab/Row/Column contexts cannot be read off here --
# restore it from the original file before running.
with gr.Blocks() as project:
# Per-session copy of the active glossary; starts as the default dictionary.
dict_hidden = gr.State(default_dict)
gr.Markdown("<center><h1>Chinese Webnovel Translator</h1> A translator that is fine-tuned on Chinese Webnovels</center>")
# --- Tab 1: translation input/output ---
with gr.Tab("Translator"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
translate_input = gr.Textbox(label="Chinese", lines=7, max_lines = 100, placeholder="Chinese...")
translate_button = gr.Button("Translate")
# Created here but not referenced by any event handler in the visible code.
translate_hidden = gr.State("")
translate_output = gr.Textbox(label="English", lines=7, max_lines = 100, placeholder="English...")
# Clickable example inputs taken from column 0 of examples.csv.
example = gr.Examples(inputs = translate_input, examples=examples[0].tolist())
# --- Tab 2: glossary upload, search and display ---
with gr.Tab("Proper Noun Dictionary"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
# File component pre-filled with the bundled example dictionary CSV.
dict_example_file = gr.File(label="Example Dictionary", value = "example_dictionary.csv")
dict_file = gr.File(interactive = True, label="Upload a custom dictionary (CSV File)")
dict_upload_button = gr.Button("Upload")
dict_search = gr.Textbox(label="Search Dictionary")
dict_search_button = gr.Button("Search")
# Two fixed columns, initially showing the default dictionary.
dict_display = gr.Dataframe(value = default_dict, max_rows = 5, col_count=(2, "fixed"))
# Translate: input text + active glossary -> English output box.
translate_button.click(predict, inputs=[translate_input, dict_hidden], outputs=translate_output)
# Upload: the parsed CSV replaces the active glossary and refreshes the table.
dict_upload_button.click(load_dict, inputs=dict_file, outputs = [dict_hidden, dict_display])
# Search: filters the active glossary into the table; the stored glossary is not changed.
dict_search_button.click(search_dict, inputs=[dict_search, dict_hidden], outputs = dict_display)
# Start the app with debug=True.
project.launch(debug=True)