Spaces:
Sleeping
Sleeping
from template import TEMPLATE_v1, TEMPLATE_v2, TEMPLATE_v3, QUESTIONS | |
import json | |
class Prompt:
    """Build prompts from versioned templates and parse the JSON replies.

    The active template version is stored on the instance (``self.version``,
    default ``"v3"``) and is sticky: passing a valid ``version=`` to
    :meth:`get` or :meth:`process_result` switches the instance to that
    version for subsequent calls as well.
    """

    def __init__(self) -> None:
        self.template_v1 = TEMPLATE_v1
        self.template_v2 = TEMPLATE_v2
        self.template_v3 = TEMPLATE_v3
        self.version = "v3"

    def combine_questions(self, questions):
        """Render *questions* as a newline-separated, numbered block.

        Entries containing the substring ``'Input question'`` are skipped.
        Numbering follows each entry's position in the ORIGINAL list, so a
        skipped entry leaves a gap (e.g. ``['Input question: x', 'a']`` gives
        ``'Question 2: a'``).

        Args:
            questions: Iterable of question strings.

        Returns:
            One string with lines of the form ``'Question <n>: <text>'``.
        """
        numbered = [
            f'Question {id_ + 1}: {q}'
            for id_, q in enumerate(questions)
            if 'Input question' not in q
        ]
        return '\n'.join(numbered)

    def _get_v1(self, input, questions):
        """Fill the v1 template with *input* and the numbered *questions*."""
        # Format with the locally combined questions: the instance has no
        # ``questions`` attribute, so ``self.questions`` would raise.
        return self.template_v1.format(input, self.combine_questions(questions))

    def _get_v2(self, input, questions):
        """Fill the v2 template with *input* and the numbered *questions*."""
        return self.template_v2.format(input, self.combine_questions(questions))

    def _get_v3(self, input, questions):
        """Fill the v3 template with *input* only.

        *questions* is accepted to keep the same signature as the other
        builders but is not used.
        """
        return self.template_v3.format(input)

    def get(self, input, questions, version=None):
        """Return the prompt for *input* using the requested template version.

        Args:
            input: Value passed as the first positional ``str.format`` argument
                of the template.
            questions: Iterable of question strings (used by v1/v2 only).
            version: ``'v1'``, ``'v2'`` or ``'v3'``. A falsy value keeps the
                current ``self.version``; a valid value becomes the new
                ``self.version``.

        Returns:
            The formatted prompt string.

        Raises:
            ValueError: If the resolved version is not supported. The
                instance's current version is left unchanged in that case.
        """
        resolved = version if version else self.version
        builders = {'v1': self._get_v1, 'v2': self._get_v2, 'v3': self._get_v3}
        # Validate before mutating so a bad version cannot poison later calls.
        if resolved not in builders:
            raise ValueError('Version should be one of {v1, v2, v3}')
        self.version = resolved
        return builders[resolved](input, questions)

    def _process_v1(self, res):
        """Parse a v1 reply: the JSON string is returned as decoded."""
        return json.loads(res)

    def _process_v2(self, res):
        """Parse a v2 reply: the JSON string is returned as decoded."""
        return json.loads(res)

    def _process_v3(self, x):
        """Flatten a v3 reply into ``{'Question <n>': {...}}`` entries.

        Each top-level value of the decoded JSON object is either a single
        entry (a dict containing ``'answer'``) or a group: a dict whose values
        are such entries. Entries are numbered consecutively from 1 in the
        order they appear (groups expanded in place). Only the ``'answer'``
        and ``'original sentences'`` fields of each entry are kept.

        Raises:
            json.JSONDecodeError: If *x* is not valid JSON.
            KeyError: If an entry lacks ``'answer'`` or
                ``'original sentences'``.
        """
        parsed = json.loads(x)
        entries = []
        for value in parsed.values():
            if 'answer' in value:
                entries.append(value)
            else:
                # A group of sub-answers; any number of members is accepted.
                entries.extend(value.values())
        return {
            f'Question {idx}': {
                "answer": entry['answer'],
                "original sentences": entry['original sentences'],
            }
            for idx, entry in enumerate(entries, start=1)
        }

    def process_result(self, result, version=None):
        """Parse the model reply *result* (a JSON string) for a version.

        Args:
            result: JSON text returned by the model.
            version: ``'v1'``, ``'v2'`` or ``'v3'``. ``None`` uses the current
                ``self.version``. A different valid value becomes the new
                ``self.version`` and a notice is printed.

        Returns:
            The decoded result (v1/v2) or the flattened question mapping (v3).

        Raises:
            ValueError: If the resolved version is not supported. The
                instance's current version is left unchanged in that case.
        """
        resolved = self.version if version is None else version
        processors = {
            'v1': self._process_v1,
            'v2': self._process_v2,
            'v3': self._process_v3,
        }
        if resolved not in processors:
            raise ValueError('Version should be one of {v1, v2, v3}')
        if resolved != self.version:
            self.version = resolved
            print(f'Version changed to {resolved}')
        # Dispatch on the resolved version, so version=None works too.
        return processors[resolved](result)