|
|
|
|
|
|
|
|
|
import time |
|
import torch |
|
import gradio as gr |
|
from info import article |
|
from transformers import FillMaskPipeline |
|
from transformers import BertTokenizer |
|
from kplug.modeling_kplug import KplugForMaskedLM |
|
from pycorrector.bert.bert_corrector import BertCorrector |
|
from pycorrector import config |
|
from loguru import logger |
|
|
|
# Device index used as the default `device` for the fill-mask pipeline built in
# KplugCorrector: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device_id = 0
else:
    device_id = -1
|
|
|
|
|
# Custom stylesheet handed to the Gradio interface. It hides every element
# carrying the `category-legend` class.
# NOTE(review): presumably this is the legend rendered by HighlightedText — confirm.
css = """
.category-legend {display: none !important}
"""
|
|
|
class KplugCorrector(BertCorrector):
    """Corrector whose fill-mask model is the KPLUG masked language model.

    Only ``__init__`` is defined here; every other method comes from the
    base classes.
    """

    def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id):
        """Load the KPLUG encoder and wrap it in a fill-mask pipeline.

        Args:
            bert_model_dir: Only echoed in the debug log line; the tokenizer
                and weights are always loaded from "eson/kplug-base-encoder".
            device: Device index forwarded to ``FillMaskPipeline``.
        """
        # The lookup starts *after* BertCorrector in the MRO, so
        # BertCorrector.__init__ itself is not executed.
        # NOTE(review): presumably intentional, to avoid loading the stock
        # BERT model before replacing it — confirm.
        super(BertCorrector, self).__init__()
        self.name = 'kplug_corrector'
        started_at = time.time()

        checkpoint = "eson/kplug-base-encoder"
        kplug_tokenizer = BertTokenizer.from_pretrained(checkpoint)
        kplug_model = KplugForMaskedLM.from_pretrained(checkpoint)

        self.model = FillMaskPipeline(
            model=kplug_model,
            tokenizer=kplug_tokenizer,
            device=device,
        )
        if self.model:
            # Mask token string taken from the pipeline's tokenizer.
            self.mask = self.model.tokenizer.mask_token
            logger.debug('Loaded bert model: %s, spend: %.3f s.' % (bert_model_dir, time.time() - started_at))
|
|
|
|
|
# Module-level corrector shared by correct() and test(); it is constructed
# once, when this module is executed.
corrector = KplugCorrector()

# Sample sentences: offered as clickable examples in the UI and iterated by test().
error_sentences = [
    '少先队员因该为老人让坐',
    '机七学习是人工智能领遇最能体现智能的一个分知',
    '今天心情很好',
]
|
|
|
|
|
def mock_data():
    """Return a canned ``(corrected_sentence, errors)`` pair without running a model.

    Each error is a 4-tuple. Judging by the sample values it reads as
    ``(wrong_char, right_char, start, end)``, with ``end`` exclusive.
    A new list is built on every call.
    """
    sample_errs = [
        ('七', '器', 1, 2),
        ('遇', '域', 10, 11),
    ]
    return '机器学习是人工智能领域最能体现智能的一个分知', sample_errs
|
|
|
|
|
def correct(sent):
    """Correct *sent* and shape the result for the two Gradio outputs.

    The first returned item is ``{"text": ..., "entities": [{}, {}]}``, the
    format Gradio requires for highlighted text; see
    https://www.gradio.app/docs/highlightedtext

    Args:
        sent: Sentence passed to ``corrector.bert_correct``.

    Returns:
        A ``(highlight_payload, errs)`` tuple, where ``errs`` is returned
        exactly as ``corrector.bert_correct`` produced it.
    """
    corrected_sent, errs = corrector.bert_correct(sent)

    print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))

    # Each err is read positionally: [1] becomes the highlighted word,
    # [2] and [3] become the span boundaries inside the corrected text.
    entities = [
        dict(entity="纠错", score=0.5, word=err[1], start=err[2], end=err[3])
        for err in errs
    ]
    highlight_payload = {"text": corrected_sent, "entities": entities}
    return highlight_payload, errs
|
|
|
|
|
def test():
    """Run every sample in ``error_sentences`` through the corrector and print each result."""
    report = "original sentence:{} => {}, err:{}"
    for sample in error_sentences:
        fixed, found = corrector.bert_correct(sample)
        print(report.format(sample, fixed, found))
|
|
|
|
|
# Gradio UI: one text box feeds `correct`, whose two return values are shown
# as highlighted text and as raw JSON, in that order.
corr_iface = gr.Interface(
    fn=correct,
    inputs=gr.Textbox(label="输入文本", value="少先队员因该为老人让坐"),
    outputs=[
        gr.HighlightedText(label="文本纠错", show_legend=True),
        gr.JSON(),
    ],
    examples=error_sentences,
    title="文本纠错(Corrector)",
    description='自动对汉语文本中的拼写、语法、标点等多种问题进行纠错校对,提示错误位置并返回修改建议',
    article=article,
    css=css,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    corr_iface.launch()
|
|