# Spaces: Sleeping
import argparse
import os
import random
import re
import secrets
import string
from datetime import timedelta

import requests
from flask import Flask, session, request, jsonify, render_template
from transformers.models.bert.tokenization_bert import BertTokenizer

from bot.chatbot import ChatBot
from bot.config import special_token_list
# Flask application object used by the view functions in this module.
app = Flask(__name__)
app.config.update(
    # Random cookie-signing key, generated anew every time the module is loaded.
    SECRET_KEY=os.urandom(74),
    PERMANENT_SESSION_LIFETIME=timedelta(days=7),
)

# Assigned in main() before the server starts; stays None until then.
tokenizer: BertTokenizer = None
# In-process store mapping a session hash to that conversation's encoded history.
history_matrix: dict = {}
def move_history_from_session_to_global_memory() -> None:
    """Copy the current session's history into the module-level ``history_matrix``.

    Nothing is stored unless the session holds both a ``session_hash`` and a
    non-empty ``history``.
    """
    session_hash = session.get("session_hash")
    # .get() rather than indexing: a session may carry a hash before any
    # history has been written, and indexing would raise KeyError then.
    history = session.get("history")
    if session_hash and history:
        history_matrix[session_hash] = history
def move_history_from_global_memory_to_session() -> None:
    """Load the history stored under the session's hash into the session.

    Does nothing when the session has no ``session_hash``. When the hash has
    no entry in ``history_matrix``, ``session["history"]`` is set to ``None``.
    """
    session_hash = session.get("session_hash")
    if not session_hash:
        return
    session["history"] = history_matrix.get(session_hash)
def set_args() -> argparse.Namespace:
    """Parse the server's command-line options.

    Returns:
        Namespace with ``vocab_path`` (default ``None``) and ``model_path``
        (default ``"lewiswu1209/Winnie"``).
    """
    arg_parser = argparse.ArgumentParser()
    # (flag, default, help text) for each optional string option.
    option_table = (
        ("--vocab_path", None, "选择词库"),
        ("--model_path", "lewiswu1209/Winnie", "对话模型路径"),
    )
    for flag, default_value, help_text in option_table:
        arg_parser.add_argument(
            flag, default=default_value, type=str, required=False, help=help_text
        )
    return arg_parser.parse_args()
def get_history_list() -> str:
    """Return the session's conversation history as a JSON array of strings.

    The session history is first refreshed from ``history_matrix``. Each
    stored list of token ids is converted back to tokens, the ``##`` prefix
    is removed from tokens that carry it, and the tokens are joined without
    a separator. A missing history yields an empty array.

    Returns:
        The value produced by ``jsonify`` for the list of decoded strings.
    """
    move_history_from_global_memory_to_session()
    stored_turns = session.get("history")
    turns = [] if stored_turns is None else stored_turns
    decoded_turns = [
        "".join(
            piece[2:] if piece.startswith("##") else piece
            for piece in tokenizer.convert_ids_to_tokens(turn_ids)
        )
        for turn_ids in turns
    ]
    return jsonify(decoded_turns)
_PREDICT_URL = 'https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/'
# Seconds to wait on the remote model; without a timeout a stalled upstream
# would block this request indefinitely.
_PREDICT_TIMEOUT = 60
# Inputs of the form "<TAG>=<VALUE>" are forwarded but kept out of the history.
_ROLE_SETTING_PATTERN = re.compile(r"^<.+>=.+$")


def _predict(text: str, session_hash: str) -> str:
    """POST *text* to the remote predict endpoint and return its first data item."""
    response = requests.post(
        url=_PREDICT_URL,
        json={"data": [text], "session_hash": session_hash},
        timeout=_PREDICT_TIMEOUT,
    )
    # Fail with a clear HTTP error instead of a confusing JSON/key error below.
    response.raise_for_status()
    return response.json()["data"][0]


def talk() -> str:
    """Handle one chat turn taken from the request's query string.

    Query parameters:
        hash: optional; when present it becomes the session's ``session_hash``.
        text: the user's input. ``HELP`` (any case) returns usage lines;
            ``ERASE MEMORY`` clears the history; any other text is sent to
            the remote model.

    When the stored history is empty, or the input is ``ERASE MEMORY``, an
    ``ERASE MEMORY`` request is sent to the remote endpoint first. Inputs
    matching ``<TAG>=<VALUE>`` are forwarded but neither the input nor the
    reply is appended to the history. The resulting history is written to
    both the session and ``history_matrix``.

    Returns:
        The value produced by ``jsonify``: the help lines for ``HELP``,
        ``[reply]`` for a processed turn, or ``[""]`` when ``text`` is
        missing or empty.

    Raises:
        requests.exceptions.RequestException: if the remote call times out
            or returns an HTTP error status.
    """
    if request.args.get("hash"):
        session["session_hash"] = request.args.get("hash")
        # NOTE(review): assumed to belong under this `if`; the nesting is not
        # recoverable from the copy under review - confirm against the original.
        move_history_from_global_memory_to_session()
    if session.get("session_hash") is None:
        # The hash is the only key to a conversation's history, so draw it
        # from the OS CSPRNG rather than the `random` module.
        alphabet = string.ascii_lowercase + string.digits
        session["session_hash"] = "".join(secrets.choice(alphabet) for _ in range(11))

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper() == "HELP":
        return jsonify([
            "输入任意文字,Winnie会和你对话",
            "输入ERASE MEMORY,Winnie会清空记忆",
            "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
            "例如:<NAME>=Vicky,Winnie会修改自己的名字",
            "可以修改的角色信息有:",
            "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
            "输入“上联:XXXXXXX”,Winnie会和你对对联",
            "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
            "以\"请问\"开头并以问号结尾,Winnie会回答该问题",
        ])

    session_hash = session["session_hash"]
    history_list = session.get("history")
    erase_requested = input_text == "ERASE MEMORY"

    if not history_list or erase_requested:
        # Reset the remote side together with the local history.
        history_list = []
        output_text = _predict("ERASE MEMORY", session_hash)

    if not erase_requested:
        is_role_setting = _ROLE_SETTING_PATTERN.match(input_text) is not None
        if not is_role_setting:
            history_list.append(tokenizer.encode(input_text, add_special_tokens=False))
        output_text = _predict(input_text, session_hash)
        if not is_role_setting:
            history_list.append(tokenizer.encode(output_text, add_special_tokens=False))

    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])
def index() -> str:
    """Return the fixed greeting string."""
    return "Hello world!"
def get_hash() -> str:
    """Return the session hash, adopting the ``hash`` query parameter if given.

    When the request carries a non-empty ``hash`` parameter it is stored as
    the session's ``session_hash`` and the matching history is loaded from
    ``history_matrix``.

    Returns:
        The session's hash, or a single space when the session has none.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        # NOTE(review): assumed to belong under this `if`; the nesting is not
        # recoverable from the copy under review - confirm against the original.
        move_history_from_global_memory_to_session()
    return session.get("session_hash") or " "
def chitchat() -> str:
    """Render and return the ``chat_template.html`` template."""
    page = render_template("chat_template.html")
    return page
def main() -> None:
    """Load the tokenizer named by the CLI options, then run the Flask server.

    The module-level ``tokenizer`` is assigned here; the server then listens
    on 127.0.0.1:8080.
    """
    global tokenizer
    cli_args = set_args()
    tokenizer_options = {
        "vocab_path": cli_args.vocab_path,
        "special_token_list": special_token_list,
    }
    tokenizer = ChatBot.get_tokenizer(cli_args.model_path, **tokenizer_options)
    app.run(host="127.0.0.1", port=8080)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()