"""Gradio demo for English text editing with a TinyLlama editor and a Flan-T5 grammar corrector."""

import json
import os
from typing import Iterator

import gradio as gr
import errant
import spacy
import nltk
from llama_cpp import Llama
from transformers import pipeline

from utils import get_random_prompt, instruction_prompts
import config

# Load necessary models and resources
nlp = spacy.load("en_core_web_sm")
annotator = errant.load('en', nlp)

# os.path.dirname("./") is ".", so this path is resolved against the current
# working directory (like the model and NLTK paths below).
errant_path = os.path.join(os.path.dirname("./"), 'errant_verbose.json')
# Use a context manager so the file handle is closed once the JSON is parsed.
with open(errant_path, "r", encoding="utf-8") as errant_file:
    errant_verbose = json.load(errant_file)

sent_detector = nltk.data.load('./nltk_data/tokenizers/punkt/english.pickle')

# Load text editor (TinyLlama)
text_editor = Llama(
    model_path="./texteditor-model/coedit-tinyllama-chat-bnb-4bit-unsloth.Q4_K_M.gguf",
    verbose=False,
)
print("text editor is loaded!")

# Load grammar corrector (Flan-T5)
grammar_corrector = pipeline(
    'text2text-generation',
    'pszemraj/flan-t5-large-grammar-synthesis',
)
print("grammar corrector is loaded!")


def correcting_text(src: str) -> str:
    """
    Corrects grammatical errors in the given text using the grammar corrector model.

    The text is split on newlines, every non-empty line is split into
    sentences, and the sentences are sent to the corrector in batches of
    ``config.BATCH_SIZE``. Corrected sentences are appended back onto the line
    they came from, so the output has as many lines as the input and empty
    lines stay empty.

    Args:
        src: The text to be corrected.

    Returns:
        The grammatically corrected text. Every corrected sentence is followed
        by one space, so each non-empty output line ends with a space.
    """
    lines = src.split('\n')

    # Flatten the text into sentences, remembering the source line of each.
    sentences = []
    line_idx = []
    for l_idx, line in enumerate(lines):
        if not line:
            continue
        for sent in sent_detector.tokenize(line):
            sentences.append(sent)
            line_idx.append(l_idx)

    final_outs = []
    for start in range(0, len(sentences), config.BATCH_SIZE):
        batch = sentences[start:start + config.BATCH_SIZE]
        final_outs += grammar_corrector(batch, max_length=128, num_beams=5, early_stopping=True)

    # Re-assemble the corrected sentences line by line.
    out_lines = ["" for _ in lines]
    for l_idx, corrected in zip(line_idx, final_outs):
        out_lines[l_idx] += corrected["generated_text"] + " "

    return "\n".join(out_lines)


def annotate_text(src: str, tag: str, analyze: bool = True) -> list:
    """
    Annotates the text with edits based on the provided tag using the Errant library.
    original code from: https://github.com/nusnlp/ALLECS

    Args:
        src: The source text.
        tag: The target text.
        analyze: Whether to build the list of segments. When False an empty
            list is returned.

    Returns:
        A list of tuples representing the edits, where each tuple is:
            - (edit_text, edit_type)
        Unchanged stretches of text carry ``None`` as their type; corrected
        stretches carry the edit type with its leading operation prefix
        (first two characters) removed.

    Raises:
        KeyError: If an edit type has no entry in ``errant_verbose``.
    """
    src_doc = annotator.parse(src)
    tag_doc = annotator.parse(tag)
    edits = [
        (e.o_start, e.o_end, e.type, e.c_str)
        for e in annotator.annotate(src_doc, tag_doc)
    ]

    result = []
    if analyze:
        # NOTE(review): edit offsets are applied to a split on single spaces;
        # if ERRANT tokenises on any whitespace, multi-space or multi-line
        # input could misalign these indices - confirm.
        tokens = src.split(' ')
        last_pos = 0
        # Running difference in token count caused by the edits applied so far.
        offset = 0

        for e_start, e_end, e_type, e_rep in edits:
            e_rep = e_rep.strip()
            op_type = e_type[0]
            pos_type = e_type[2:]
            errant_info = errant_verbose[pos_type]
            # NOTE(review): `title` and `msg` are computed but are not part of
            # the returned tuples - presumably meant for a tooltip; confirm
            # before removing.
            title = errant_info["title"]

            # Unchanged text between the previous edit and this one.
            result.append((' '.join(tokens[last_pos:e_start + offset]), None))
            ori_str = ' '.join(tokens[e_start + offset:e_end + offset]).strip()

            if pos_type == "ORTH":
                # check if it's a casing issue
                if ori_str.lower() == e_rep.lower():
                    if e_rep[0].isupper() and ori_str[0].islower():
                        msg = "{ori} should be capitalized."
                    elif e_rep[0].islower() and ori_str[0].isupper():
                        msg = "{ori} should not be capitalized."
                    else:
                        msg = "The casing of the word {ori} is wrong."
                # then it should be a spacing issue
                else:
                    if len(ori_str) - 1 == len(e_rep):
                        msg = "The word {ori} should not be written separately."
                    elif len(ori_str) + 1 == len(e_rep):
                        msg = "The word {ori} should be separated into {cor}."
                    else:
                        msg = "The word {ori} has orthography error."
            else:
                if op_type in errant_info:
                    msg = errant_info[op_type]
                else:
                    msg = errant_verbose["Default"][op_type]
            msg = '\n' + msg.format(ori=ori_str, cor=e_rep) + '\n'

            # Apply the edit in place so later offsets refer to the edited tokens.
            e_cor = e_rep.split()
            len_cor = len(e_cor)
            tokens[e_start + offset:e_end + offset] = e_cor
            last_pos = e_start + offset + len_cor
            offset = offset - (e_end - e_start) + len_cor
            result.append((e_rep, pos_type))

        # Whatever follows the last edit is unchanged text.
        result.append((' '.join(tokens[last_pos:]), None))

    print(result)
    return result


def choices2promts() -> list:
    """
    Returns a list of available instructions for text editing.

    Returns:
        A list of instruction names (the keys of ``instruction_prompts``).
    """
    return list(instruction_prompts)


with gr.Blocks() as demo:

    def turn_off_legend(msg: str) -> gr.update:
        """
        Turns off the legend in the highlighted text component.

        Args:
            msg: The text input (unused).

        Returns:
            A Gradio update object to hide the legend.
        """
        return gr.update(show_legend=False)

    def turn_on_legend(annotate: bool) -> gr.update:
        """
        Turns on the legend in the highlighted text component if annotate is True.

        Args:
            annotate: Whether to show annotations.

        Returns:
            A Gradio update object to show or hide the legend.
        """
        if annotate:
            return gr.update(show_legend=True)
        else:
            return gr.update(show_legend=False)

    def bot(task: str, text: str, post_check: bool, annotate: bool) -> Iterator[tuple]:
        """
        Processes the user input and returns the edited text along with annotations.

        Args:
            task: The chosen instruction for editing.
            text: The text to be edited.
            post_check: Whether to check for grammatical errors after text generation.
            annotate: Whether to show annotations.

        Yields:
            Pairs of (highlighted-text segments, status message) to update the
            interface; the final pair carries the finished result and "Done.".
        """
        if task == "Grammar Error Correction":
            yield [("Processing ...", None)], "Checking Grammar ..."
            response = correcting_text(text)
        else:
            instruction_text = get_random_prompt(task)
            prompt = instruction_text + ": " + text
            print(prompt)

            output = text_editor.create_chat_completion(
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are an English writing assistant, editing the text of user input and response based on user instructions. Please do not provide explanations, but respond only with the edited text. \n"
                            "Also, if the instruction is not provided, correct the grammar of the text. Finally, if the instruction is not for editing text, correct the grammar of the text."
                        ),
                    },
                    {"role": "user", "content": f"{prompt}"},
                ],
                temperature=0.0,
                stream=True,
            )

            # Accumulate streamed content and push every partial result to the UI.
            response = ""
            for chunk in output:
                delta = chunk["choices"][0]["delta"]
                if "role" in delta:
                    pass
                elif "content" in delta:
                    response += delta['content']
                    res = [(response, None), ]
                    print(res)
                    yield res, "Generating output ..."

            # Optional grammar pass over the generated text.
            if post_check:
                yield [(response, None)], "Checking Grammar ..."
                response = correcting_text(response)

        print(response)
        if annotate:
            e_edit = annotate_text(text, response)
        else:
            e_edit = [(response, None)]
        yield e_edit, "Done."

    def handle_highlight_selection():
        """
        Handles the selection event of the highlighted text component.
        This function is not implemented in the original code.
        """
        # print("hi")
        return

    gr.Markdown("# English Text Editing Application using T5 and Tiny Llama")
    gr.Markdown("> source code: https://github.com/LETHEVIET/t5nyllama")

    with gr.Row() as row:
        with gr.Column(scale=1) as col1:
            instruction = gr.Dropdown(
                choices=choices2promts(),
                value="Grammar Error Correction",
                multiselect=False,
                label="Choose your instruction",
                interactive=True,
                scale=0,
            )
            with gr.Row() as row2:
                clear = gr.Button("Clear", scale=-1)
                submit = gr.Button("submit", scale=-1)
            # The defaults of the information box and the input box were
            # swapped: the placeholder belongs here, the sample sentence in
            # the input box.
            info_msg = gr.Textbox(
                label="Information",
                scale=1,
                lines=3,
                value="Information will show here.",
            )
            post_check = gr.Checkbox(label="Check grammaticality after text generation.", value=True)
            annotate = gr.Checkbox(label="Highlight different", value=True)

        with gr.Column(scale=2) as col2:
            msg = gr.Textbox(
                label="Input",
                scale=3,
                value="i can has cheezburger.",
            )
            result = gr.HighlightedText(
                label="Result",
                combine_adjacent=True,
                show_legend=False,
                scale=3,
            )
            res_msg = gr.Textbox(
                scale=0,
                visible=False,
                label="Ouput",
            )

    # Pressing Enter in the input box and clicking "submit" run the same
    # chain: hide the legend, stream the result, then restore the legend.
    for trigger in (msg.submit, submit.click):
        trigger(turn_off_legend, msg, result).then(
            bot, [instruction, msg, post_check, annotate], [result, info_msg]
        ).then(turn_on_legend, annotate, result)

    clear.click(lambda: None, None, result, queue=False)
    result.select(handle_highlight_selection, [], [])

if __name__ == "__main__":
    demo.launch(server_port=7860)