AnswerMate / app.py
Brasd99's picture
Added Sber Gpt
97f265b
raw
history blame
5.48 kB
from typing import List, Tuple, Dict, Any
import time
import json
import requests
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the runtime settings once at import time. JSON text is UTF-8 by
# specification, so the encoding is pinned instead of relying on the
# platform's default locale encoding.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# Input limits enforced by the validators and reflected in the UI labels.
max_questions_count = config["MAX_QUESTIONS_COUNT"]
max_tags_count = config["MAX_TAGS_COUNT"]

# NOTE: "MAX_ATTEMPS" and "SBER_GRT" look misspelled, but they are looked up
# verbatim in config.json; do not rename them here without updating that file.
max_attempts = config["MAX_ATTEMPS"]  # attempts per question in find_answers
wait_time = config["WAIT_TIME"]  # value passed to time.sleep between requests
chatgpt_url = config["CHATGPT_URL"]  # endpoint that get_answer POSTs to
system_prompt = config["SYSTEM_PROMPT"]  # content of the 'system' chat message
sber_gpt = config["SBER_GRT"]  # name/path handed to from_pretrained
use_sber = config["USE_SBER"]  # when truthy, answer with the local model

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The local model and tokenizer are only loaded (and only defined) when the
# local backend is enabled; get_answer checks the same flag before using them.
if use_sber:
    tokenizer = GPT2Tokenizer.from_pretrained(sber_gpt)
    model = GPT2LMHeadModel.from_pretrained(sber_gpt).to(DEVICE)
def generate(
    model, tok, text,
    do_sample=True, max_length=10000, repetition_penalty=5.0,
    top_k=5, top_p=0.95, temperature=1,
    num_beams=None,
    no_repeat_ngram_size=3
):
    """Generate text from ``text`` with ``model`` and return it decoded.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tok: Tokenizer exposing ``encode(text, return_tensors="pt")`` and
            ``decode(ids)``.
        text: Prompt to encode and feed to the model.
        do_sample, repetition_penalty, top_k, top_p, temperature, num_beams,
        no_repeat_ngram_size: Forwarded unchanged to ``model.generate``.
        max_length: Upper bound on the generated sequence length. It is
            clamped to ``model.config.n_positions`` when the model reports
            that attribute.

    Returns:
        The decoded text of the first sequence returned by ``model.generate``.
    """
    input_ids = tok.encode(text, return_tensors="pt").to(DEVICE)

    # The default max_length (10000) can exceed the model's position-embedding
    # table; generating past that size fails with an index error, so cap it.
    # The lookup is defensive so models without this attribute behave as before.
    model_config = getattr(model, "config", None)
    context_limit = getattr(model_config, "n_positions", None)
    if context_limit is not None:
        max_length = min(max_length, context_limit)

    out = model.generate(
        input_ids,  # already on DEVICE; no second transfer needed
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        top_k=top_k, top_p=top_p, temperature=temperature,
        num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
    )
    # Only the first sequence is returned, so only that one is decoded.
    return tok.decode(out[0])
def get_answer(question: str, timeout: float = 60.0) -> Dict[str, Any]:
    """Ask the configured backend a single question.

    When ``use_sber`` is truthy the local model answers via ``generate``;
    otherwise the question is POSTed to ``chatgpt_url`` as a chat-completion
    request together with ``system_prompt``.

    Args:
        question: The text sent as the user's message / model prompt.
        timeout: Seconds to wait for the HTTP endpoint before giving up.
            Without a timeout a stalled connection would block forever.

    Returns:
        ``{'status': True, 'content': <answer>}`` on success, or
        ``{'status': False}`` when the HTTP request fails or the response
        does not have the expected shape.
    """
    if use_sber:
        content = generate(model, tokenizer, question)
        return {
            'status': True,
            'content': content
        }

    headers = {
        'Content-Type': 'application/json; charset=utf-8'
    }
    payload = {
        'model': 'gpt-3.5-turbo',
        'messages': [
            {
                'role': 'system',
                'content': system_prompt
            },
            {
                'role': 'user',
                'content': question
            }
        ]
    }

    try:
        response = requests.post(
            chatgpt_url,
            headers=headers,
            data=json.dumps(payload),
            timeout=timeout,
        )
        response.raise_for_status()
        content = response.json()['choices'][0]['message']['content']
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as error:
        # Transport/HTTP errors, invalid JSON, or an unexpected response
        # layout. KeyboardInterrupt and SystemExit are deliberately not
        # caught so the process can still be stopped.
        print(f'Request for an answer failed: {error!r}')
        return {
            'status': False
        }

    return {
        'status': True,
        'content': content
    }
def format_results(results: List[Tuple[str, str]]) -> str:
    """Render question/answer pairs as a single human-readable string.

    Each pair becomes a ``Question №<n>: ...`` line followed by an
    ``Answer: ...`` line (numbering starts at 1). Consecutive pairs are
    separated by a dashed line and a blank line. Leading and trailing
    whitespace of the whole text is stripped.

    Args:
        results: ``(question, answer)`` tuples in display order.

    Returns:
        The formatted text, or an empty string when ``results`` is empty.
    """
    separator = '--------------------------------------\n\n'
    entries = [
        f'Question №{number}: {question}\nAnswer: {answer}\n'
        for number, (question, answer) in enumerate(results, start=1)
    ]
    # Joining places the separator only between entries, never after the last.
    return separator.join(entries).strip()
def validate_and_get_tags(tags: str) -> List[str]:
    """Split the raw tag text into a list of cleaned, non-empty tags.

    Args:
        tags: Text with one tag per line.

    Returns:
        The tags with surrounding whitespace removed and blank lines dropped.

    Raises:
        gr.Error: If no tag was supplied, or if more than ``max_tags_count``
            tags were supplied.
    """
    # Strip every line and discard the ones that end up empty.
    cleaned = list(filter(None, map(str.strip, tags.split('\n'))))

    if not cleaned:
        raise gr.Error('Validation error. It is necessary to set at least one tag')

    if len(cleaned) > max_tags_count:
        raise gr.Error(f'Validation error. The maximum allowed number of tags is {max_tags_count}.')

    return cleaned
def validate_and_get_questions(questions: str) -> List[str]:
    """Split the raw question text into a list of cleaned, non-empty questions.

    Args:
        questions: Text with one question per line.

    Returns:
        The questions with surrounding whitespace removed and blank lines
        dropped.

    Raises:
        gr.Error: If no question was supplied, or if more than
            ``max_questions_count`` questions were supplied.
    """
    # Strip every line and discard the ones that end up empty.
    cleaned = list(filter(None, map(str.strip, questions.split('\n'))))

    if not cleaned:
        raise gr.Error('Validation error. It is necessary to ask at least one question')

    if len(cleaned) > max_questions_count:
        raise gr.Error(f'Validation error. The maximum allowed number of questions is {max_questions_count}.')

    return cleaned
def find_answers(tags: str, questions: str, progress=gr.Progress()) -> str:
    """Validate the input, fetch an answer for every question and format them.

    Every question is prefixed with all tags (each wrapped in square
    brackets) before being sent to ``get_answer``. A question is attempted up
    to ``max_attempts`` times, sleeping ``wait_time`` between attempts and
    before each question.

    Args:
        tags: Raw tag text, one tag per line.
        questions: Raw question text, one question per line.
        progress: Gradio progress tracker; its ``tqdm`` wraps the question
            loop.

    Returns:
        The formatted question/answer listing produced by ``format_results``.

    Raises:
        gr.Error: Propagated from the validators on invalid input.
    """
    tags = validate_and_get_tags(tags)
    questions = validate_and_get_questions(questions)

    print(f'New attempt to get answers. Got {len(tags)} tags and {len(questions)} questions')
    print(f'Tags: {tags}')
    print(f'Questions: {questions}')

    tags_str = ''.join(f'[{tag}]' for tag in tags)

    results = []
    for question in progress.tqdm(questions):
        time.sleep(wait_time)
        tagged_question = f'{tags_str} {question}'
        for attempt in range(max_attempts):
            answer = get_answer(tagged_question)
            if answer['status']:
                results.append((question, answer['content']))
                break
            # Pause before retrying, but not after the final attempt.
            if attempt < max_attempts - 1:
                time.sleep(wait_time)
        else:
            # Reached when no attempt succeeded, including the case where
            # max_attempts is 0 or negative, so a question is never silently
            # dropped from the output.
            results.append((question, 'An error occurred while receiving data.'))

    return format_results(results)
title = '<h1 style="text-align:center">AnswerMate</h1>'

# Assemble the page: heading, description, two side-by-side input boxes,
# then the submit button and the output box.
# NOTE(review): the nesting below is reconstructed; the button and output box
# are assumed to sit under the row rather than inside it. Confirm the layout.
with gr.Blocks(theme='soft', title='AnswerMate') as blocks:
    gr.HTML(title)
    gr.Markdown('The service allows you to get answers to all questions on the specified topic.')

    with gr.Row():
        # Each box is as tall as the maximum number of entries it accepts.
        tags_input = gr.Textbox(
            lines=max_tags_count,
            placeholder='.NET\nC#',
            label=f'Enter tags (each line is a separate tag). Maximum: {max_tags_count}.'
        )
        questions_input = gr.Textbox(
            lines=max_questions_count,
            placeholder='What is inheritance, encapsulation, abstraction, polymorphism?\nWhat is CLR?',
            label=f'Enter questions (each line is a separate question). Maximum: {max_questions_count}.'
        )

    process_button = gr.Button('Find answers')
    outputs = gr.Textbox(placeholder='Output will appear here', label='Output')

    # Clicking the button runs find_answers on both inputs and shows its result.
    process_button.click(
        fn=find_answers,
        inputs=[tags_input, questions_input],
        outputs=outputs
    )

# Enable the request queue with concurrency_count=1, then start the app.
blocks.queue(concurrency_count=1).launch()