# Spaces: Runtime error
# (Status header captured together with this listing; it is not part of the program.)
import re

import nltk
import openai
from nltk import sent_tokenize

# Download the 'punkt' resource at import time.
# NOTE(review): presumably this is the tokenizer model that sent_tokenize
# needs; the call touches the network/disk on every import -- confirm that
# this is acceptable for the deployment.
nltk.download('punkt')
class SynonymEditor:
    """Rewrite a text sentence by sentence through an OpenAI completion model.

    The text is split into paragraphs (on blank lines) and sentences; every
    sentence is sent to the model with a language-specific prompt and the
    answers are joined back together in the original paragraph layout.

    Double quotes are temporarily replaced by a placeholder token
    (``self.quote``) in zero-shot mode and restored in the model output.
    """

    def __init__(self, api_key, model_engine, max_tokens, temperature, language):
        """Store the generation settings and register the API key.

        Args:
            api_key: Key assigned to ``openai.api_key``.
            model_engine: Value passed as ``model`` to the completion call.
            max_tokens: Value passed as ``max_tokens`` to the completion call.
            temperature: Value passed as ``temperature`` to the completion call.
            language: ``"de"`` selects the German prompts and placeholder;
                any other value selects the English ones.
        """
        # The key is stored on the openai module itself, not on this instance.
        openai.api_key = api_key
        self.model_engine = model_engine
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.language = language
        # Placeholder that stands in for '"' while a sentence is processed.
        self.quote = '__ZITIEREN__' if language == 'de' else '__QUOTE__'

    # Play with the prompts here to see their effect on the output quality.
    # Note that the longer the prompt, the more tokens are used and hence
    # the higher the billing.
    def _get_prompt(self, sentence, few_shots):
        """Build the prompt for one sentence.

        Args:
            sentence: The sentence to rewrite.
            few_shots: Few-shot example text to embed in the prompt, or a
                falsy value for the zero-shot prompts.

        Returns:
            The prompt string to send to the model.
        """
        is_german = self.language == "de"

        if few_shots:
            if is_german:
                return ('Modernisiere den deutschen Text. Fasse direkte Reden NIE zusammen.\n'
                        + few_shots + "\nEingang:" + sentence + " Ausgang:")
            return ("Replace exactly one word with a synonym while preserving the overall sentence structure and meaning.\n"
                    + few_shots + "\nInput:" + sentence + " Output:")

        if is_german:
            # The German zero-shot prompt is the same whether or not the
            # sentence contains the quote placeholder.
            return 'Modernisiere den deutschen Text. Fasse direkte Reden NIE zusammen.\n' + sentence + '\n'

        if self.quote in sentence:
            return "Replace exactly one word with a synonym while preserving __QUOTE__ in the following sentence:\n" + sentence + "\n"
        return "Replace exactly one word with a synonym in the following sentence:\n" + sentence + "\n"

    # Call the OpenAI API here
    def __call_ai(self, sentence, few_shots):
        """Send one sentence to the model and return the post-processed answer."""
        prompt = self._get_prompt(sentence, few_shots)
        print(prompt)
        response = openai.Completion.create(
            model=self.model_engine,
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        return self._post_process_sentence(response.choices[0].text.strip())

    def _prepare_text(self, text, few_shots):
        """Mask quotes (zero-shot mode only) and normalise whitespace.

        Args:
            text: Raw paragraph text.
            few_shots: Few-shot example text, or a falsy value for zero-shot.

        Returns:
            ``text`` with every whitespace run collapsed to one space and
            leading/trailing whitespace removed; in zero-shot mode every
            ``"`` is replaced by ``self.quote`` first.
        """
        # Use truthiness, exactly like _get_prompt does, so that None or ""
        # (which select the zero-shot prompt) also get their quotes masked.
        if not few_shots:
            text = text.replace('"', self.quote)
        return re.sub(r'\s+', ' ', text).strip()

    # Split the paragraph to preserve quotation marks
    def _split_into_sentences(self, text, few_shots):
        """Return the sentences of ``text`` after quote masking and cleanup."""
        return sent_tokenize(self._prepare_text(text, few_shots))

    def _post_process_sentence(self, text):
        """Log the model output and turn quote placeholders back into ``"``."""
        print(text)
        print("==============")
        return text.replace(self.quote, '"')

    # Preprocess the text, perform the edit task and join back to get the original format
    def _edit_text(self, text, few_shots=False):
        """Rewrite every sentence of ``text`` and rebuild the paragraphs.

        Args:
            text: Full input text; paragraphs are separated by blank lines.
            few_shots: Few-shot example text, or a falsy value for zero-shot.

        Returns:
            The edited text: sentences joined with single spaces, paragraphs
            joined with blank lines.
        """
        edited_paragraphs = []
        for paragraph in text.split("\n\n"):
            edited_sentences = [
                self.__call_ai(sentence, few_shots)
                for sentence in self._split_into_sentences(paragraph, few_shots)
            ]
            # join edited sentences to form an edited paragraph
            edited_paragraphs.append(' '.join(edited_sentences))
        # join edited paragraphs to form the edited text
        return '\n\n'.join(edited_paragraphs)

    # File Read Write operation
    def edit_file(self, input_file, output_file):
        """Read ``input_file``, rewrite its text and save it to ``output_file``.

        Args:
            input_file: Path of the UTF-8 text file to read; undecodable
                bytes are dropped (``errors="ignore"``).
            output_file: Path the edited text is written to, as UTF-8.
        """
        print("Opening File")
        with open(input_file, "r", encoding="utf8", errors="ignore") as f:
            text = f.read()
        print("Editing")
        edited_text = self._edit_text(text)
        print("Finishing up")
        # Write with the same encoding used for reading; the platform default
        # may be unable to encode non-ASCII characters such as umlauts.
        with open(output_file, "w", encoding="utf8") as f:
            f.write(edited_text)
        print("Done!")