import os
import random
import re
import requests
import argparse
import string
import secrets

from datetime import timedelta
from typing import Optional

from flask import Flask, session, request, jsonify, render_template
from transformers.models.bert.tokenization_bert import BertTokenizer

from bot.chatbot import ChatBot
from bot.config import special_token_list

app = Flask(__name__)
# Regenerated on every process start, so session cookies issued before a
# restart can no longer be decoded.
app.config["SECRET_KEY"] = os.urandom(74)
# NOTE(review): Flask applies this lifetime only to sessions with
# `session.permanent = True`; nothing in this file sets that flag — confirm
# whether persistent cookies were intended.
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=7)

# Loaded in main(); converts chat text to token ids and back.
tokenizer: Optional[BertTokenizer] = None
# Session hash -> list of token-id lists (user turn, bot turn, ...), kept in
# process memory so a client can resume a conversation by passing ?hash=.
history_matrix: dict = {}

# Remote Hugging Face Space that generates Winnie's replies.
REMOTE_PREDICT_URL = "https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/"
# Upper bound (seconds) for one remote call, so a stalled upstream cannot
# block a request handler indefinitely.
REMOTE_TIMEOUT_SECONDS = 60
# Inputs shaped like "<KEY>=value" set role information; neither they nor
# the reply to them are recorded in the local history.
ROLE_INFO_PATTERN = re.compile(r"^<.+>=.+$")
# Session hashes are 11 distinct characters drawn from [a-z0-9].
SESSION_HASH_ALPHABET = string.ascii_lowercase + string.digits
SESSION_HASH_LENGTH = 11


def move_history_from_session_to_global_memory() -> None:
    """Copy the session's history into ``history_matrix`` under its hash.

    Does nothing unless the session has both a hash and a non-empty history.
    """
    if session.get("session_hash") and session.get("history"):
        history_matrix[session["session_hash"]] = session["history"]


def move_history_from_global_memory_to_session() -> None:
    """Load the history stored for the session's hash into the session.

    If the hash has no stored history, ``session["history"]`` becomes None.
    Does nothing when the session has no hash.
    """
    session_hash = session.get("session_hash")
    if session_hash:
        session["history"] = history_matrix.get(session_hash)


def set_args() -> argparse.Namespace:
    """Parse command-line options.

    --vocab_path: vocabulary file (help text: "choose vocabulary"), default None.
    --model_path: dialogue model path, default "lewiswu1209/Winnie".
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    parser.add_argument("--vocab_path", default=None, type=str, required=False, help="选择词库")
    parser.add_argument("--model_path", default="lewiswu1209/Winnie", type=str, required=False, help="对话模型路径")
    return parser.parse_args()


@app.route("/chitchat/history", methods=["GET"])
def get_history_list() -> str:
    """Return the session's conversation as a JSON list of detokenized strings."""
    move_history_from_global_memory_to_session()
    history_list: list = session.get("history") or []
    history: list = []
    for history_ids in history_list:
        tokens = tokenizer.convert_ids_to_tokens(history_ids)
        # Drop the "##" continuation marker and glue the pieces back together.
        history.append(
            "".join(token[2:] if token.startswith("##") else token for token in tokens)
        )
    return jsonify(history)


def _new_session_hash() -> str:
    """Return a fresh session hash: 11 distinct characters from [a-z0-9]."""
    # The hash alone grants access to a conversation via ?hash=, so draw it
    # from the OS CSPRNG instead of the predictable `random` module.
    return "".join(
        secrets.SystemRandom().sample(SESSION_HASH_ALPHABET, SESSION_HASH_LENGTH)
    )


def _query_remote_bot(text: str, session_hash: str) -> str:
    """POST ``text`` to the remote Winnie Space and return the first reply item.

    Raises requests.RequestException on timeout, connection failure, or an
    HTTP error status.
    """
    response = requests.post(
        url=REMOTE_PREDICT_URL,
        json={"data": [text], "session_hash": session_hash},
        timeout=REMOTE_TIMEOUT_SECONDS,
    )
    response.raise_for_status()
    return response.json()["data"][0]


@app.route("/chitchat/chat", methods=["GET"])
def talk() -> str:
    """Handle one chat turn and return the bot's reply as a one-element JSON list.

    Query parameters:
        hash: optional session hash to adopt (restores history stored for it).
        text: the user's message. If missing or empty, returns ``[""]``.

    "HELP" (any case) returns the help lines. "ERASE MEMORY" clears the local
    history. Role-info inputs ("<KEY>=value") are forwarded but not recorded.
    """
    if request.args.get("hash"):
        session["session_hash"] = request.args.get("hash")
        move_history_from_global_memory_to_session()
    if session.get("session_hash") is None:
        session["session_hash"] = _new_session_hash()

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper() == "HELP":
        # NOTE(review): the "=" / "=Vicky" examples and ", , , , , , " look
        # like they lost their "<...>" role placeholders (ROLE_INFO_PATTERN
        # expects "<KEY>=value"); confirm against special_token_list.
        help_info_list = [
            "输入任意文字,Winnie会和你对话",
            "输入ERASE MEMORY,Winnie会清空记忆",
            "输入\"=\",Winnie会记录你的角色信息",
            "例如:=Vicky,Winnie会修改自己的名字",
            "可以修改的角色信息有:",
            ", , , , , , ",
            "输入“上联:XXXXXXX”,Winnie会和你对对联",
            "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
            "以\"请问\"开头并以问号结尾,Winnie会回答该问题",
        ]
        return jsonify(help_info_list)

    session_hash = session["session_hash"]
    history_list = session.get("history")
    if not history_list or input_text == "ERASE MEMORY":
        # Start a fresh local history and send "ERASE MEMORY" upstream too;
        # presumably this resets the remote memory for this session_hash.
        history_list = []
        output_text = _query_remote_bot("ERASE MEMORY", session_hash)

    if input_text != "ERASE MEMORY":
        is_role_info = ROLE_INFO_PATTERN.match(input_text) is not None
        if not is_role_info:
            history_list.append(tokenizer.encode(input_text, add_special_tokens=False))
        output_text = _query_remote_bot(input_text, session_hash)
        if not is_role_info:
            history_list.append(tokenizer.encode(output_text, add_special_tokens=False))

    # Keep the cookie session and the in-process store in sync.
    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])


@app.route("/")
def index() -> str:
    """Return a trivial liveness message."""
    return "Hello world!"
@app.route("/chitchat/hash", methods=["GET"])
def get_hash() -> str:
    """Return the current session hash, adopting ``?hash=`` when supplied.

    Adopting a hash also loads that hash's stored history into the session.
    Returns a single space when the session has no hash.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()
    session_hash = session.get("session_hash")
    return session_hash if session_hash else " "


@app.route("/chitchat", methods=["GET"])
def chitchat() -> str:
    """Render the chat page."""
    return render_template("chat_template.html")


def main() -> None:
    """Load the tokenizer for the CLI-selected model, then serve on 127.0.0.1:8080."""
    global tokenizer
    args = set_args()
    tokenizer = ChatBot.get_tokenizer(
        args.model_path,
        vocab_path=args.vocab_path,
        special_token_list=special_token_list,
    )
    app.run(host="127.0.0.1", port=8080)


if __name__ == "__main__":
    main()