from json import JSONDecodeError
import gradio as gr
from openai import OpenAI
import os
import json
from typing import *
import sys
class Chat:
    """Tracks a chat transcript plus like/dislike feedback and persists it to JSON."""

    # -1 means unlimited rounds.
    max_chat_round = -1

    def __init__(self):
        # Fix: the original stored these as mutable class attributes, which
        # would be shared across every Chat instance. They are now per-instance.
        self.chat_history: List[dict] = []
        self.like_history: List[List[int]] = []
        self.dislike_history: List[List[int]] = []
        self.round = 0

    def init_chat_history(self):
        """Reset the transcript to a single system prompt and zero the round counter."""
        prompt = f'''关于哪吒的提示词'''
        self.chat_history = [
            {
                "role": "system",
                "content": prompt
            }
        ]
        self.round = 0

    def add_to_chat_history(self, chat: dict):
        """Append one message dict ({"role": ..., "content": ...}) to the transcript."""
        self.chat_history.append(chat)

    def remove_from_chat_history(self, chat: dict):
        """Not implemented yet."""
        pass

    def get_chat_history(self):
        """Return the live transcript list (not a copy)."""
        return self.chat_history

    def add_to_like_history(self, index: List):
        """Record a liked message position as [row, col] (gradio LikeData.index)."""
        self.like_history.append(index)

    def remove_from_like_history(self, index: List):
        """Not implemented yet."""
        pass

    def add_to_dislike_history(self, index: List):
        """Record a disliked message position as [row, col]."""
        self.dislike_history.append(index)

    def remove_from_dislike_history(self, index: List):
        """Not implemented yet.

        Fix: the original defined a second ``remove_from_like_history``, which
        silently shadowed the method above and left dislikes with no remover.
        """
        pass

    def format(self) -> List[dict]:
        """Return a copy of the transcript with a 'liked' flag (1/0/-1) per message.

        Fix: the original aliased ``self.chat_history`` and injected the
        'liked' key into the live transcript; this builds shallow copies
        instead, leaving the transcript untouched.
        NOTE(review): the row*2+col index mapping assumes the chatbot grid
        lines up with the transcript — confirm, since the system prompt
        occupies slot 0.
        """
        result = [dict(chat, liked=0) for chat in self.chat_history]
        for like in self.like_history:
            like_index = like[0] * 2 + like[1]
            result[like_index]["liked"] = 1
        for dislike in self.dislike_history:
            dislike_index = dislike[0] * 2 + dislike[1]
            result[dislike_index]["liked"] = -1
        return result

    def save(self, file_path: str):
        """Write the annotated transcript to file_path as pretty-printed UTF-8 JSON."""
        # The with-block closes the file; the original's extra close() was redundant.
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(self.format(), file, ensure_ascii=False, indent=4)

    def round_increase(self):
        """Advance the round counter; return True while more rounds are allowed.

        Fix: when ``max_chat_round`` was set, the original returned None
        (falsy) for every non-terminal round and None past the limit; it now
        returns an explicit bool on every path.
        """
        self.round += 1
        if self.max_chat_round == -1:
            return True
        return self.round < self.max_chat_round
# Module-level singleton session shared by the Gradio callbacks below.
chat = Chat()
def save_single_data(instruction: str, input: str, output: str, file_path: str):
    """Append one instruction/input/output record to a JSON array file.

    Creates the file if missing and tolerates an empty/corrupt file, or (fix)
    a file whose top-level JSON value is not a list, by starting fresh.
    NOTE: the parameter name ``input`` shadows the builtin but is kept for
    backward compatibility with keyword callers.
    """
    record = {
        "instruction": instruction,
        "input": input,
        "output": output,
    }
    # Load existing records; fall back to an empty list when the file is
    # missing (FileNotFoundError) or holds empty/invalid JSON (ValueError,
    # which json.JSONDecodeError subclasses). Debug print removed.
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            original_data = json.load(file)
    except (FileNotFoundError, ValueError):
        original_data = []
    if not isinstance(original_data, list):
        # Fix: appending onto a non-list top-level value would crash or
        # mis-merge (e.g. dict.extend does not exist).
        original_data = []
    original_data.append(record)
    # Write the updated list back, pretty-printed, preserving non-ASCII text.
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(original_data, file, ensure_ascii=False, indent=4)
# def response(message, history):
# client = OpenAI(
# api_key=os.getenv("KIMI_API_KEY"),
# base_url="https://api.moonshot.cn/v1",
# )
#
# chat.add_to_chat_history({"role": "user", "content": message})
# messages = chat.get_chat_history()
# completion = client.chat.completions.create(
# model="moonshot-v1-8k",
# messages=messages,
# temperature=0.3,
# )
#
# chat.add_to_chat_history({"role": "assistant", "content": completion.choices[0].message.content})
# chat.round_increase()
# return completion.choices[0].message.content
def response(message, history):
    """Placeholder chat handler: returns a fixed notice until an LLM backend is wired in."""
    _ = (message, history)  # arguments intentionally unused in the stub
    return "未接入LLM..."
def vote(data: gr.LikeData):
    """Route a thumbs-up/down event into the shared chat session's feedback lists."""
    record = chat.add_to_like_history if data.liked else chat.add_to_dislike_history
    record(data.index)
def end():
    """Persist the annotated chat transcript to a local JSON file."""
    # Refresh the interface / save the data.
    output_path = 'test.json'
    chat.save(output_path)
    # chat.init_chat_history()
# Build the Gradio UI: a chatbot with like/dislike voting, a chat input wired
# to `response`, and a button that dumps the annotated transcript to disk.
theme = gr.themes.Base()
with gr.Blocks(theme=theme) as demo:
    # Fix: the placeholder string literal was split across two physical source
    # lines — a SyntaxError for a plain quoted string. The line break is now an
    # explicit \n inside the literal.
    chatbot = gr.Chatbot(height=500, placeholder="哪吒-魔童降世\nChat with Me", type='tuples')
    chatbot.like(vote, None, None)
    gr.ChatInterface(fn=response, chatbot=chatbot, show_progress='full', retry_btn=None, undo_btn=None)
    end_btn = gr.Button("Upload Chat Data")
    end_btn.click(end)
demo.launch()
# if __name__=='__main__':
# path = 'test.json'
# save_single_data('asd','asd','asdd',path)