first commit
Browse files- README.md +76 -10
- app.py +318 -0
- config.py +1 -0
- dockerfile +14 -0
- download_dependencies.py +8 -0
- errant_verbose.json +86 -0
- finetuning_tinyllama.py +160 -0
- requirements.txt +9 -0
- utils.py +193 -0
README.md
CHANGED
@@ -1,10 +1,76 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Lightweight English Text Editing Assistant (t5nyllama)
|
2 |
+
This repository houses the source code for t5nyllama, a lightweight English text editing assistant designed to provide a simple and efficient way to enhance your writing.
|
3 |
+
|
4 |
+
**Huggingface Spaces:**
|
5 |
+
https://huggingface.co/spaces/letheviet/t5nyllama
|
6 |
+
|
7 |
+
**How it Works:**
|
8 |
+
|
9 |
+
t5nyllama uses a two-step approach:
|
10 |
+
|
11 |
+
1. **Text Generation:** The core of the assistant is a TinyLlama model, specifically fine-tuned for text editing. This model is designed to improve the flow and clarity of your text, making it more polished and engaging. However, TinyLlama is **relatively small and not particularly adept at complex grammar correction.**
|
12 |
+
|
13 |
+
2. **Grammar Correction:** To address this limitation, we employ a powerful Flan-T5 model for a second pass. This model takes the output of the TinyLlama model and carefully analyzes it for grammatical errors. It then suggests corrections, ensuring your final text is grammatically sound and ready for publication.
|
14 |
+
|
15 |
+
**Key Features:**
|
16 |
+
|
17 |
+
* **Lightweight and Efficient:** The TinyLlama model is quantized to 4-bit precision, minimizing memory usage and computational demands, making it suitable for resource-constrained environments.
|
18 |
+
* **Focused on Text Improvement:** TinyLlama excels at refining the overall quality of your writing, making it more readable and engaging.
|
19 |
+
* **Enhanced Grammar Accuracy:** The Flan-T5 model provides a robust final check for grammatical errors, ensuring your text is free from mistakes.
|
20 |
+
|
21 |
+
|
22 |
+
**Design Principles:**
|
23 |
+
|
24 |
+
* **Local Application:** Prioritizes offline functionality, allowing you to edit text without requiring an internet connection.
|
25 |
+
* **Lightweight Design:** Minimizes resource consumption, making the application suitable for a wide range of devices and systems.
|
26 |
+
|
27 |
+
## Installation
|
28 |
+
|
29 |
+
**1. Clone the Repository:**
|
30 |
+
```shell
|
31 |
+
git clone https://github.com/LETHEVIET/t5nyllama.git
|
32 |
+
```
|
33 |
+
|
34 |
+
**2. Install Dependencies:**
|
35 |
+
```shell
|
36 |
+
pip3 install -r requirements.txt
|
37 |
+
python3 -m spacy download en_core_web_sm
|
38 |
+
python3 download_dependencies.py
|
39 |
+
```
|
40 |
+
|
41 |
+
**3. Run the Application:**
|
42 |
+
```shell
|
43 |
+
python3 app.py
|
44 |
+
```
|
45 |
+
|
46 |
+
## Docker Deployment
|
47 |
+
|
48 |
+
**1. Build Docker Image:**
|
49 |
+
```shell
|
50 |
+
docker build . -t t5nyllama
|
51 |
+
```
|
52 |
+
|
53 |
+
**2. Run Docker Image:**
|
54 |
+
```shell
|
55 |
+
docker run -p 7860:7860 t5nyllama
|
56 |
+
```
|
57 |
+
|
58 |
+
## Fine-Tuning TinyLlama
|
59 |
+
|
60 |
+
The fine-tuning script follows the UnslothAI example for fine-tuning Tiny Llama. Please install dependencies from [unsloth](https://github.com/unslothai/unsloth) before running the script.
|
61 |
+
|
62 |
+
```shell
|
63 |
+
python finetuning_tinyllama.py
|
64 |
+
```
|
65 |
+
|
66 |
+
## References
|
67 |
+
|
68 |
+
* **Unsloth Fast Fine-Tuning LLM:** https://github.com/unslothai/unsloth
|
69 |
+
* **Dataset Card for CoEdIT: Text Editing via Instruction Tuning :** https://huggingface.co/datasets/grammarly/coedit
|
70 |
+
* **Grammar-Synthesis-Large: FLAN-t5:** https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis
|
71 |
+
* **ALLECS: A Lightweight Language Error Correction System:** https://github.com/nusnlp/ALLECS
|
72 |
+
* **Python Bindings for llama.cpp:** https://github.com/abetlen/llama-cpp-python
|
73 |
+
* **Gradio: Build Machine Learning Web Apps — in Python:** https://github.com/gradio-app/gradio
|
74 |
+
## Demo
|
75 |
+
|
76 |
+
[Include a GIF or screenshot demonstrating the application's functionality.]
|
app.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import errant
|
3 |
+
import spacy
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import nltk
|
7 |
+
from utils import get_random_prompt, instruction_prompts
|
8 |
+
from llama_cpp import Llama
|
9 |
+
from transformers import pipeline
|
10 |
+
import config
|
11 |
+
|
12 |
+
# Load NLP resources: spaCy pipeline, ERRANT annotator, per-error-type
# message templates, and the NLTK Punkt sentence splitter.
nlp = spacy.load("en_core_web_sm")
annotator = errant.load('en', nlp)
errant_path = os.path.join(os.path.dirname("./"), 'errant_verbose.json')
# Use a context manager so the JSON file handle is closed deterministically
# (the original `json.load(open(...))` leaked the handle).
with open(errant_path, "r") as errant_file:
    errant_verbose = json.load(errant_file)
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')

# Load text editor (TinyLlama, 4-bit quantized GGUF served via llama.cpp)
text_editor = Llama(
    model_path="./texteditor-model/coedit-tinyllama-chat-bnb-4bit-unsloth.Q4_K_M.gguf",
    verbose=False
)
print("text editor is loaded!")

# Load grammar corrector (Flan-T5 grammar-synthesis model via transformers)
grammar_corrector = pipeline(
    'text2text-generation',
    'pszemraj/flan-t5-large-grammar-synthesis',
)
print("grammar corrector is loaded!")
|
32 |
+
|
33 |
+
def correcting_text(src: str) -> str:
    """
    Correct grammatical errors in *src* using the grammar corrector model.

    The text is split into lines and each line into sentences; sentences are
    corrected in batches of ``config.BATCH_SIZE`` and then re-joined on their
    original lines, preserving the line layout (empty lines stay empty).

    Args:
        src: The text to be corrected.

    Returns:
        The grammatically corrected text. Unlike the original implementation,
        corrected lines no longer carry a trailing space.
    """
    lines = src.split('\n')
    sentences = []
    line_idx = []
    for l_idx, line in enumerate(lines):
        if len(line) == 0:
            # Keep empty lines out of the model input; they are restored
            # as empty strings when re-joining below.
            continue
        for sent in sent_detector.tokenize(line):
            sentences.append(sent)
            line_idx.append(l_idx)

    # Run the corrector batch-by-batch to bound peak memory usage.
    final_outs = []
    for start in range(0, len(sentences), config.BATCH_SIZE):
        batch = sentences[start:start + config.BATCH_SIZE]
        final_outs += grammar_corrector(batch, max_length=128, num_beams=5, early_stopping=True)

    # Re-assemble corrected sentences into their original lines; joining with
    # a single space avoids the trailing-space bug of `line += sent + " "`.
    corrected = [[] for _ in lines]
    for i, out in enumerate(final_outs):
        corrected[line_idx[i]].append(out["generated_text"])
    return "\n".join(" ".join(parts) for parts in corrected)
|
68 |
+
|
69 |
+
def annotate_text(src: str, tag: str, analyze: bool = True) -> list:
    """
    Annotates the text with edits based on the provided tag using the Errant library.
    original code from: https://github.com/nusnlp/ALLECS

    Computes the ERRANT edits transforming `src` into `tag` and renders them
    as a list of (text, label) spans suitable for gr.HighlightedText:
    unchanged spans carry label None, replaced spans carry the ERRANT
    POS error type (e.g. "NOUN:NUM").

    Args:
        src: The source text.
        tag: The target (corrected) text.
        analyze: Whether to analyze and provide detailed information about edits.
            NOTE(review): with analyze=False the function returns an empty span
            list; the only caller in this file uses the default True.

    Returns:
        A list of tuples representing the edits, where each tuple is:
        - (edit_text, edit_type)
    """
    out = {"edits": []}
    out['source'] = src
    src_doc = annotator.parse(src)
    tag_doc = annotator.parse(tag)
    cur_edits = annotator.annotate(src_doc, tag_doc)

    # Flatten ERRANT edit objects into (orig_start, orig_end, type, correction).
    for e in cur_edits:
        out["edits"].append((e.o_start, e.o_end, e.type, e.c_str))
    result = []
    last_pos = 0  # index of the first source token not yet emitted into `result`
    if analyze:
        tokens = out['source']
        if isinstance(tokens, str):
            tokens = tokens.split(' ')
        edits = out['edits']
        # `offset` tracks how much token indices have shifted because earlier
        # replacements changed the token count of `tokens` in place.
        offset = 0
        for edit in edits:
            # Edits may arrive as dicts or tuples; normalize the four fields.
            if isinstance(edit, dict):
                e_start = edit['start']
                e_end = edit['end']
                e_type = edit['type']
                e_rep = edit['cor']
            elif isinstance(edit, tuple):
                e_start = edit[0]
                e_end = edit[1]
                e_type = edit[2]
                e_rep = edit[3]
            else:
                raise ValueError("Data type {} is not supported."\
                    .format(type(edit)))

            e_rep = e_rep.strip()
            # ERRANT codes look like "R:NOUN:NUM": operation char + POS part.
            op_type = e_type[0]
            pos_type = e_type[2:]
            errant_info = errant_verbose[pos_type]
            title = errant_info["title"]  # NOTE(review): currently unused

            # Emit the unchanged span preceding this edit.
            result.append((' '.join(tokens[last_pos:e_start + offset]), None))

            ori_str = ' '.join(tokens[e_start + offset:e_end + offset]).strip()
            if pos_type == "ORTH":
                # check if it's a casing issue
                if ori_str.lower() == e_rep.lower():
                    if e_rep[0].isupper() and ori_str[0].islower():
                        msg = "<b>{ori}</b> should be capitalized."
                    elif e_rep[0].islower() and ori_str[0].isupper():
                        msg = "<b>{ori}</b> should not be capitalized."
                    else:
                        msg = "The casing of the word <b>{ori}</b> is wrong."
                # then it should be a spacing issue
                else:
                    if len(ori_str) - 1 == len(e_rep):
                        msg = "The word <b>{ori}</b> should not be written separately."
                    elif len(ori_str) + 1 == len(e_rep):
                        msg = "The word <b>{ori}</b> should be separated into <b>{cor}</b>."
                    else:
                        msg = "The word <b>{ori}</b> has orthography error."
            else:
                # Prefer the operation-specific message for this POS type,
                # falling back to the generic "Default" templates.
                if op_type in errant_info:
                    msg = errant_info[op_type]
                else:
                    msg = errant_verbose["Default"][op_type]

            # NOTE(review): `msg` is built but not returned or displayed here;
            # in upstream ALLECS it feeds an explanation tooltip.
            msg = '<p>' + msg.format(ori=ori_str, cor=e_rep) + '</p>'

            # Apply the correction in place so later edit indices line up.
            e_cor = e_rep.split()
            len_cor = len(e_cor)
            tokens[e_start + offset:e_end + offset] = e_cor
            last_pos = e_start + offset + len_cor
            offset = offset - (e_end - e_start) + len_cor
            result.append((e_rep, pos_type))
        out = ' '.join(tokens)  # NOTE(review): rebinds `out`; value unused afterwards
        # Emit the unchanged tail after the last edit.
        result.append((' '.join(tokens[last_pos:]), None))
    print(result)
    return result
|
158 |
+
|
159 |
+
def choices2promts() -> list:
    """
    Return the list of available instruction names for text editing.

    Returns:
        A list of instruction names (the keys of ``instruction_prompts``).
    """
    # Materialize a real list: the annotated return type is ``list`` and a
    # dict_keys view is neither indexable nor a list, which consumers
    # (e.g. gr.Dropdown choices) may expect.
    return list(instruction_prompts.keys())
|
167 |
+
|
168 |
+
# Build the Gradio UI: instruction picker + input box on the left,
# streamed highlighted result on the right.
with gr.Blocks() as demo:

    def turn_off_legend(msg: str) -> gr.update:
        """
        Turns off the legend in the highlighted text component.

        Args:
            msg: The text input (unused; present to match the event signature).

        Returns:
            A Gradio update object to hide the legend.
        """
        return gr.update(show_legend=False)

    def turn_on_legend(annotate: bool) -> gr.update:
        """
        Turns on the legend in the highlighted text component if annotate is True.

        Args:
            annotate: Whether to show annotations.

        Returns:
            A Gradio update object to show or hide the legend.
        """
        if annotate:
            return gr.update(show_legend=True)
        else:
            return gr.update(show_legend=False)

    def bot(task: str, text: str, post_check: bool, annotate: bool) -> tuple:
        """
        Processes the user input and streams back the edited text with annotations.

        Args:
            task: The chosen instruction for editing.
            text: The text to be edited.
            post_check: Whether to check for grammatical errors after text generation.
            annotate: Whether to show annotations.

        Yields:
            Tuples of (highlighted-text spans, status message) to update the interface.
        """
        response = ""
        if task == "Grammar Error Correction":
            # Grammar-only requests skip TinyLlama and go straight to Flan-T5.
            yield [("Processing ...", None)], "Checking Grammar ..."
            response = correcting_text(text)
        else:
            # Pick a random phrasing of the instruction, mirroring how the
            # CoEdIT-style prompts were used during fine-tuning.
            instruction = get_random_prompt(task)
            prompt = instruction + ": " + text
            print(prompt)
            output = text_editor.create_chat_completion(
                messages=[
                    {
                        "role": "system",
                        "content": "You are an English writing assistant, editing the text of user input and response based on user instructions. Please do not provide explanations, but respond only with the edited text. Also, if the instruction is not provided, correct the grammar of the text. Finally, if the instruction is not for editing text, correct the grammar of the text.",
                    },
                    {"role": "user", "content": f"{prompt}"},
                ],
                temperature=0.0,  # deterministic decoding
                stream=True,  # stream chunks so the UI updates incrementally
            )

            response = ""
            # Accumulate streamed deltas, re-yielding the partial text each time.
            for chunk in output:
                delta = chunk["choices"][0]["delta"]
                if "role" in delta:
                    pass
                elif "content" in delta:
                    response+=delta['content']
                res = [(response, None), ]
                print(res)
                yield res, "Generating output ..."

            # NOTE(review): placed inside the else-branch so a "Grammar Error
            # Correction" task is not corrected twice; the original extraction
            # does not preserve indentation — confirm against upstream.
            if post_check:
                yield [(response, None)], "Checking Grammar ..."
                response = correcting_text(response)

        print(response)

        if annotate:
            # Diff source vs. response and label each changed span with its
            # ERRANT error type for the highlight view.
            e_edit = annotate_text(text, response)
        else:
            e_edit = [(response, None)]

        yield e_edit, "Done."

    def handle_highlight_selection():
        """
        Handles the selection event of the highlighted text component.

        This function is not implemented in the original code.
        """
        # print("hi")
        return

    gr.Markdown("# English Text Editing Application using T5 and Tiny Llama")
    gr.Markdown("> source code: https://github.com/LETHEVIET/t5nyllama")
    with gr.Row() as row:
        with gr.Column(scale=1) as col1:
            # Instruction category selector; choices come from utils.instruction_prompts.
            instruction = gr.Dropdown(
                choices=choices2promts(),
                value="Grammar Error Correction",
                multiselect=False,
                label="Choose your instruction",
                interactive=True,
                scale=0
            )

            with gr.Row() as row2:
                clear = gr.Button("Clear", scale=-1)
                submit = gr.Button("submit", scale=-1)

            # Status line shown while generating/correcting.
            info_msg = gr.Textbox(
                label="Information",
                scale=1,
                lines=3,
                value="Therefore careful analysis of a product has to be made before select a solution for testing and implementation.",
            )

            post_check = gr.Checkbox(label="Check grammaticality after text generation.", value=True)
            annotate = gr.Checkbox(label="Highlight different", value=True)
        with gr.Column(scale=2) as col2:
            msg = gr.Textbox(
                label="Input",
                scale=3,
                value="Therefore careful analysis of a product has to be made before select a solution for testing and implementation.",
            )

            # Result view; legend is toggled on/off around each run.
            result = gr.HighlightedText(
                label="Result",
                combine_adjacent=True,
                show_legend=False,
                scale=3
            )

            res_msg = gr.Textbox(
                scale=0,
                visible=False,
                label="Ouput",
            )

    # Enter in the input box and the submit button trigger the same pipeline:
    # hide legend -> run bot (streams spans + status) -> restore legend.
    msg.submit(turn_off_legend, msg, result).then(bot, [instruction, msg, post_check, annotate], [result, info_msg]).then(turn_on_legend, annotate, result)

    clear.click(lambda: None, None, result, queue=False)

    submit.click(turn_off_legend, msg, result).then(bot, [instruction, msg, post_check, annotate], [result, info_msg]).then(turn_on_legend, annotate, result)

    result.select(handle_highlight_selection, [], [])

if __name__ == "__main__":
    demo.launch(server_port=7860)
|
config.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Number of sentences sent to the grammar-correction pipeline per batch.
BATCH_SIZE = 4
|
dockerfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10

WORKDIR /app

COPY . .

# Combine update/upgrade/install in ONE layer so a cached `apt-get update`
# can never be paired with a fresh `install` (stale-index hazard), and
# remove the apt lists to keep the image smaller.
RUN apt-get update && apt-get -y upgrade && \
    apt-get install -y --no-install-recommends build-essential && \
    rm -rf /var/lib/apt/lists/*

# --no-cache-dir avoids baking pip's download cache into the layer.
RUN pip3 install --upgrade pip setuptools && \
    pip3 install --no-cache-dir -r requirements.txt

# spaCy model and runtime assets (GGUF model + NLTK punkt).
RUN python3 -m spacy download en_core_web_sm
RUN python3 download_dependencies.py

CMD ["python3", "app.py"]
|
download_dependencies.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Download runtime assets: the quantized TinyLlama GGUF model and NLTK punkt."""
import os

import gdown
import nltk

# Google Drive file id of the fine-tuned, 4-bit-quantized TinyLlama model
# (`id` renamed to avoid shadowing the builtin).
FILE_ID = "1TnPssg0CkWQ_thuAH8cY3hdB2J18A0Kl"
OUTPUT_PATH = "texteditor-model/coedit-tinyllama-chat-bnb-4bit-unsloth.Q4_K_M.gguf"

# Ensure the target directory exists before gdown tries to write the file.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
gdown.download(id=FILE_ID, output=OUTPUT_PATH)

# Punkt sentence tokenizer used by app.py's sentence splitter.
nltk.download('punkt')
|
errant_verbose.json
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Default": {
|
3 |
+
"M": "<b>{cor}</b> should be inserted here, considering the context.",
|
4 |
+
"R": "<b>{cor} is more appropriate than <b>{ori}</b> in this context.",
|
5 |
+
"U": "<b>{ori}</b> is unnecessary/incorrect in this context."
|
6 |
+
},
|
7 |
+
"ADJ": {
|
8 |
+
"title": "Adjective"
|
9 |
+
},
|
10 |
+
"ADJ:FORM": {
|
11 |
+
"title": "Adjective Form"
|
12 |
+
},
|
13 |
+
"ADV": {
|
14 |
+
"title": "Adverb"
|
15 |
+
},
|
16 |
+
"CONJ": {
|
17 |
+
"title": "Conjunction"
|
18 |
+
},
|
19 |
+
"CONTR": {
|
20 |
+
"title": "Contraction"
|
21 |
+
},
|
22 |
+
"DET": {
|
23 |
+
"title": "Determiner"
|
24 |
+
},
|
25 |
+
"NOUN": {
|
26 |
+
"title": "Noun"
|
27 |
+
},
|
28 |
+
"NOUN:POSS": {
|
29 |
+
"title": "Possessive Noun"
|
30 |
+
},
|
31 |
+
"OTHER": {
|
32 |
+
"title": ""
|
33 |
+
},
|
34 |
+
"PART": {
|
35 |
+
"title": "Particle"
|
36 |
+
},
|
37 |
+
"PREP": {
|
38 |
+
"title": "Preposition"
|
39 |
+
},
|
40 |
+
"PRON": {
|
41 |
+
"title": "Pronoun"
|
42 |
+
},
|
43 |
+
"PUNCT": {
|
44 |
+
"title": "Punctuation"
|
45 |
+
},
|
46 |
+
"VERB": {
|
47 |
+
"title": "Verb"
|
48 |
+
},
|
49 |
+
"VERB:FORM": {
|
50 |
+
"title": "Verb Form"
|
51 |
+
},
|
52 |
+
"VERB:TENSE": {
|
53 |
+
"title": "Verb Tense"
|
54 |
+
},
|
55 |
+
"MORPH": {
|
56 |
+
"title": "Morphology",
|
57 |
+
"R": "The word form of <b>{cor}</b> is more appropriate than <b>{ori}</b> here."
|
58 |
+
},
|
59 |
+
"NOUN:INFL": {
|
60 |
+
"title": "Noun Inflection",
|
61 |
+
"R": "<b>{ori}</b> is wrong and should be written as <b>{cor}</b>."
|
62 |
+
},
|
63 |
+
"NOUN:NUM": {
|
64 |
+
"title": "Noun Number",
|
65 |
+
"R": "The noun number of <b>{ori}</b> is wrong and should be written as <b>{cor}</b>."
|
66 |
+
},
|
67 |
+
"ORTH": {
|
68 |
+
"title": "Orthography"
|
69 |
+
},
|
70 |
+
"SPELL": {
|
71 |
+
"title": "Spelling",
|
72 |
+
"R": "<b>{ori}</b> is not the correct spelling of <b>{cor}</b>."
|
73 |
+
},
|
74 |
+
"VERB:INFL": {
|
75 |
+
"title": "Verb Inflection",
|
76 |
+
"R": "<b>{ori}</b> is wrong and should be written as <b>{cor}</b>."
|
77 |
+
},
|
78 |
+
"VERB:SVA": {
|
79 |
+
"title": "Subject-Verb Agreement",
|
80 |
+
"R": "The form of <b>{ori}</b> does not follow the subject-verb agreement."
|
81 |
+
},
|
82 |
+
"WO": {
|
83 |
+
"title": "Word Order",
|
84 |
+
"R": "The word order of '<b>{ori}</b>' is wrong."
|
85 |
+
}
|
86 |
+
}
|
finetuning_tinyllama.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unsloth import FastLanguageModel
|
2 |
+
import torch
|
3 |
+
|
4 |
+
# Define model parameters
max_seq_length = 4096  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.

# Load the base model and its tokenizer through Unsloth's fast loader.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/tinyllama-chat-bnb-4bit",  # "unsloth/tinyllama" for 16bit loading
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Apply PEFT (Parameter-Efficient Fine-Tuning): attach LoRA adapters to the
# attention and MLP projection matrices.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # LoRA rank. Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",],
    lora_alpha=32,
    lora_dropout=0,  # Currently only supports dropout = 0
    bias="none",  # Currently only supports bias = "none"
    use_gradient_checkpointing=False,  # @@@ IF YOU GET OUT OF MEMORY - set to True @@@
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
)
|
32 |
+
|
33 |
+
# Data preparation
import pandas as pd
from sklearn.model_selection import train_test_split
import datasets

# Load the CoEdIT dataset (instruction-based text editing pairs).
train = datasets.load_dataset("grammarly/coedit", split="train").to_pandas()
val = datasets.load_dataset("grammarly/coedit", split="validation").to_pandas()

# Data cleaning and preparation
data = pd.concat([train, val])
# 'src' is formatted as "<instruction>: <input text>"; split on the FIRST
# ": " only (n=1) so colons inside the text are preserved.
data[['instruction', 'input']] = data['src'].str.split(': ', n=1, expand=True)
data = data.rename(columns={"tgt": "output"})
data = data.drop(columns=["_id", "src"])

# Stratify based on task for balanced splits
stratify_col = data['task']

# Split the data into train and test sets (fixed seed for reproducibility).
train_df, test_df = train_test_split(
    data,
    test_size=0.2,
    random_state=42,
    stratify=stratify_col
)
|
58 |
+
|
59 |
+
def formatting_prompts_func(examples, tokenizer):
    """
    Format a batch of CoEdIT examples into the model's chat format.

    Each example becomes a two-turn conversation: the user turn holds
    "<instruction>: <input text>" and the assistant turn holds the edited
    target text; the pair is rendered with the tokenizer's chat template.

    Args:
        examples: A batched dict from the dataset with parallel
            "instruction", "input", and "output" lists.
        tokenizer: The tokenizer whose chat template is applied.

    Returns:
        A dictionary containing the formatted text for each example.
    """
    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    # `src_text` avoids shadowing the builtin `input` with the loop variable.
    for instruction, src_text, output in zip(instructions, inputs, outputs):
        message = [
            {"role": "user", "content": instruction + ": " + src_text},
            {"role": "assistant", "content": output},
        ]
        text = tokenizer.apply_chat_template(
            message, tokenize=False, add_generation_prompt=False)
        texts.append(text)
    return {"text": texts, }
|
83 |
+
|
84 |
+
# Create datasets from pandas DataFrames
train_ds = datasets.Dataset.from_pandas(train_df)
test_ds = datasets.Dataset.from_pandas(test_df)

# Map the formatting function to the datasets for chat format conversion
train_ds = train_ds.map(formatting_prompts_func, fn_kwargs={"tokenizer": tokenizer}, batched=True,)
test_ds = test_ds.map(formatting_prompts_func, fn_kwargs={"tokenizer": tokenizer}, batched=True,)

# Sanity check: show one fully formatted training example.
print(train_ds[0]['text'])

# Fine-tuning with trl
from trl import SFTTrainer
from transformers import TrainingArguments

# Define the trainer; effective batch size = 8 * 4 (grad accumulation) = 32.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=10,
    packing=False,  # Can make training 5x faster for short sequences.
    args=TrainingArguments(
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        num_train_epochs=2,
        # Pick the precision the GPU actually supports.
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        learning_rate=2e-4,
        logging_steps=1,
        save_steps=100,
        save_total_limit=4,  # Limit the total number of checkpoints
        evaluation_strategy="steps",
        eval_steps=100,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        load_best_model_at_end=True,
        save_strategy="steps",
    ),
)

# Print GPU information before training to establish a memory baseline.
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

# Train the model
trainer_stats = trainer.train()

# Print memory usage statistics (peak reserved, and the delta attributable
# to LoRA training on top of the loaded model).
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime'] / 60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

# Save the trained model and tokenizer
print("Saving model to local")
model.save_pretrained("coedit-tinyllama-chat-bnb-4bit")  # Local saving
tokenizer.save_pretrained("coedit-tinyllama-chat-bnb-4bit")

# Evaluate the model (Optional)
#trainer.evaluate()
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
spacy
|
2 |
+
transformers
|
3 |
+
gradio
|
4 |
+
errant
|
5 |
+
nltk
|
6 |
+
llama-cpp-python
|
7 |
+
gdown
|
8 |
+
tensorflow
|
9 |
+
tf-keras
|
utils.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prompt variants per editing task, taken from the CoEdIT dataset's
# instruction phrasings. Three entries in the original file were extraction
# artifacts where a category header fused two prompts into one string
# (e.g. "Fix grammatical mistakes Clarity Clarify the sentence"); they are
# split here and the orphaned prompts re-homed in their proper lists.
GEC = [
    "Fix grammar",
    "Fix grammar in this sentence",
    "Fix grammar in the sentence",
    "Fix grammar errors",
    "Fix grammatical errors",
    "Fix grammaticality",
    "Fix all grammatical errors",
    "Fix grammatical errors in this sentence",
    "Fix grammar errors in this sentence",
    "Fix grammatical mistakes in this sentence",
    "Fix grammaticality in this sentence",
    "Fix grammaticality of the sentence",
    "Fix disfluencies in the sentence",
    "Make the sentence grammatical",
    "Make the sentence fluent",
    "Fix errors in this text",
    "Update to remove grammar errors",
    "Remove all grammatical errors from this text",
    "Improve the grammar of this text",
    "Improve the grammaticality",
    "Improve the grammaticality of this text",
    "Improve the grammaticality of this sentence",
    "Grammar improvements",
    "Remove grammar mistakes",
    "Remove grammatical mistakes",
    "Fix the grammar mistakes",
    "Fix grammatical mistakes",
]
# Clarification and simplification prompts share one category.
Clarify = [
    "Clarify the sentence",
    "Clarify this sentence",
    "Clarify this text",
    "Write a clearer version for the sentence",
    "Write a clarified version of the sentence",
    "Write a readable version of the sentence",
    "Write a better readable version of the sentence",
    "Rewrite the sentence more clearly",
    "Rewrite this sentence clearly",
    "Rewrite this sentence for clarity",
    "Rewrite this sentence for readability",
    "Improve this sentence for readability",
    "Make this sentence better readable",
    "Make this sentence more readable",
    "Make this sentence readable",
    "Make the sentence clear",
    "Make the sentence clearer",
    "Clarify",
    "Make the text more understandable",
    "Make this easier to read",
    "Clarification",
    "Change to clearer wording",
    "Clarify this paragraph",
    "Use clearer wording",
    "Simplify the sentence",
    "Simplify this sentence",
    "Simplify this text",
    "Write a simpler version for the sentence",
    "Rewrite the sentence to be simpler",
    "Rewrite this sentence in a simpler manner",
    "Rewrite this sentence for simplicity",
    "Rewrite this with simpler wording",
    "Make the sentence simple",
    "Make the sentence simpler",
    "Make this text less complex",
    "Make this simpler",
    "Simplify",
    "Simplification",
    "Change to simpler wording",
    "Simplify this paragraph",
    "Simplify this text",
    "Use simpler wording",
    "Make this easier to understand",
]
Coherence = [
    "Fix coherence",
    "Fix coherence in this sentence",
    "Fix coherence in the sentence",
    "Fix coherence in this text",
    "Fix coherence in the text",
    "Fix coherence errors",
    "Fix sentence flow",
    "Fix sentence transition",
    "Fix coherence errors in this sentence",
    "Fix coherence mistakes in this sentence",
    "Fix coherence in this sentence",
    "Fix coherence of the sentence",
    "Fix lack of coherence in the sentence",
    "Make the text more coherent",
    "Make the text coherent",
    # Re-joined: the clause below had been split off onto its own line.
    "Make the text more cohesive, logically linked and consistent as a whole",
    "Make the text more cohesive",
    "Improve the cohesiveness of the text",
    "Make the text more logical",
    "Make the text more consistent",
    "Improve the consistency of the text",
    "Make the text clearer",
    "Improve the coherence of the text",
]
Formality_Style_Transfer = [
    "Formalize",
    "Improve formality",
    "Formalize the sentence",
    "Formalize this sentence",
    "Formalize the text",
    "Formalize this text",
    "Make this formal",
    "Make this more formal",
    "Make this sound more formal",
    "Make the sentence formal",
    "Make the sentence more formal",
    "Make the sentence sound more formal",
    "Write more formally",
    "Write less informally",
    "Rewrite more formally",
    "Write this more formally",
    "Rewrite this more formally",
    "Write in a formal manner",
    "Write in a more formal manner",
    "Rewrite in a more formal manner",
]
Neutralization = [
    "Remove POV",
    "Remove POVs",
    "Remove POV in this text",
    "Remove POVs in this text",
    "Neutralize this text",
    "Neutralize the text",
    "Neutralize this sentence",
    "Neutralize the sentence",
    "Make this more neutral",
    "Make this text more neutral",
    "Make this sentence more neutral",
    "Make this paragraph more neutral",
    "Remove unsourced opinions",
    "Remove unsourced opinions from this text",
    "Remove non-neutral POVs",
    "Remove non-neutral POV",
    "Remove non-neutral points of view",
    "Remove points of view",
    "Make this text less biased",
]
Paraphrase = [
    # Restored here from the garbled Neutralization entries.
    "Paraphrase the sentence",
    "Paraphrase this sentence",
    "Paraphrase this text",
    "Write a paraphrase for the sentence",
    "Write a paraphrased version of the sentence",
    "Rewrite the sentence with different wording",
    "Use different wording",
    "Rewrite this sentence",
    "Reword this sentence",
    "Rephrase this sentence",
    "Rewrite this text",
    "Reword this text",
    "Rephrase this text",
]


import random
import os

# Maps the UI-facing task name to its pool of equivalent prompt phrasings.
instruction_prompts = {
    "Grammar Error Correction": GEC,
    "Clarify": Clarify,
    "Coherence": Coherence,
    "Formality Style Transfer": Formality_Style_Transfer,
    "Neutralization": Neutralization,
    "Paraphrase": Paraphrase,
}
|
169 |
+
|
170 |
+
def get_prompt_list(instruction_type: str) -> list:
    """
    Look up the prompt variants registered for an instruction category.

    Args:
        instruction_type: Name of an instruction category, e.g.
            "Grammar Error Correction".

    Returns:
        The list of prompt strings for that category.

    Raises:
        KeyError: If *instruction_type* is not a known category.
    """
    prompts = instruction_prompts[instruction_type]
    return prompts
|
181 |
+
|
182 |
+
def get_random_prompt(instruction_type: str) -> str:
    """
    Pick one prompt at random from a category's prompt list.

    Args:
        instruction_type: Name of an instruction category, e.g.
            "Grammar Error Correction".

    Returns:
        A randomly chosen prompt string for that category.

    Raises:
        KeyError: If *instruction_type* is not a known category.
    """
    candidates = instruction_prompts[instruction_type]
    return random.choice(candidates)
|
193 |
+
|