# Spaces:
# Runtime error
# Runtime error
import os
import sys

# Make the Needy-Haruhi sources (Agent, DialogueEvent, MemoryPool, ...) importable.
sys.path.append('./Needy-Haruhi/src')

from Agent import Agent

agent = Agent()

from DialogueEvent import DialogueEvent

file_names = ["./Needy-Haruhi/data/complete_story_30.jsonl", "./Needy-Haruhi/data/Daily_event_130.jsonl"]

import json

# Parse every line of the event files into DialogueEvent objects.
# A line that fails to parse is retried once with ",]" replaced by "]" (a
# trailing comma before a closing bracket, presumably invalid JSON). A line that
# still fails is reported and skipped instead of being dropped silently.
events = []
for file_name in file_names:
    with open(file_name, encoding='utf-8') as f:
        for line in f:
            try:
                event = DialogueEvent(line)
            except Exception:
                line = line.replace(',]', ']')
                try:
                    event = DialogueEvent(line)
                except Exception:
                    # Keep the last unparsable line around for debugging.
                    error_line = line
                    print(f"warning! skipped an unparsable event line in {file_name}")
                    continue
            events.append(event)
import copy

# Hand the memory pool its own deep copy of the events, so that anything
# MemoryPool.load_from_events might do to them cannot leak into `events`.
events_for_memory = copy.deepcopy(events)

from MemoryPool import MemoryPool

MEMORY_POOL_FILE = "memory_pool.jsonl"

memory_pool = MemoryPool()
memory_pool.load_from_events(events_for_memory)
# Persist the pool, then read it straight back from disk (save/load round trip).
memory_pool.save(MEMORY_POOL_FILE)
memory_pool.load(MEMORY_POOL_FILE)
# Image-name / caption records, one JSON object per line.
file_name = "./Needy-Haruhi/data/image_text_relationship.jsonl"

import json

with open(file_name, encoding='utf-8') as f:
    data_img_text = [json.loads(record) for record in f]

import zipfile
import os

# Unpack the bundled images into ./image.
zip_file = './Needy-Haruhi/data/image.zip'
extract_path = './image'
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall(extract_path)
from tqdm import tqdm
from util import get_bge_embedding_zh, float_array_to_base64, base64_to_float_array
import torch
import os
import copy

# Similarity maths runs on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def get_cosine_similarity(v1, v2):
    """Return the cosine similarity of two 1-D vectors as a Python float.

    Both inputs are converted with ``torch.tensor`` and moved to the
    module-level ``device`` before being compared along dimension 0.
    """
    first, second = (torch.tensor(vec).to(device) for vec in (v1, v2))
    similarity = torch.cosine_similarity(first, second, dim=0)
    return similarity.item()
class ImagePool:
    """In-memory collection of images, each paired with its caption text and an
    embedding of that caption, searchable by text similarity."""

    def __init__(self):
        # Entries are dicts: {"img_path": ..., "img_text": ..., "embedding": ...}.
        self.pool = []
        self.set_embedding(get_bge_embedding_zh)

    def set_embedding(self, embedding):
        """Set the callable that turns a text into an embedding vector."""
        self.embedding = embedding

    def load_from_data(self, data_img_text, img_path):
        """Embed every caption in *data_img_text* and append it to the pool.

        Args:
            data_img_text: iterable of dicts with keys 'img_name' and 'text'.
            img_path: directory joined in front of each 'img_name'.
        """
        for data in tqdm(data_img_text):
            img_name = os.path.join(img_path, data['img_name'])
            img_text = data['text']
            # Empty or missing captions are embedded as a single space.
            if img_text == '' or img_text is None:
                img_text = " "
            self.pool.append({
                "img_path": img_name,
                "img_text": img_text,
                "embedding": self.embedding(img_text),
            })

    def retrieve(self, query_text, agent=None):
        """Return the pool entry whose caption is most similar to *query_text*.

        Args:
            query_text: text to search for.
            agent: unused here; accepted for compatibility with callers.

        Returns:
            A deep copy of the best entry without its 'embedding' field and with
            a 'similarity' field (cosine similarity) added, or None when the pool
            is empty.
        """
        if not self.pool:
            # Callers (ImageMaster) treat None as "no image available".
            return None
        query_embedding = self.embedding(query_text)
        # max() keeps the first entry on ties, like a stable descending sort.
        best_index, best_sim = max(
            ((i, get_cosine_similarity(entry['embedding'], query_embedding))
             for i, entry in enumerate(self.pool)),
            key=lambda pair: pair[1],
        )
        best_entry = self.pool[best_index]
        # Copy everything except the (large) embedding.
        result = {key: copy.deepcopy(value)
                  for key, value in best_entry.items() if key != 'embedding'}
        result['similarity'] = best_sim
        return result

    def save(self, file_name):
        """Write the pool to *file_name* as JSON lines.

        Each entry's 'embedding' is written as a base64 string under
        'bge_zh_base64'. The in-memory pool is left untouched, so it can still
        be queried after saving.
        """
        with open(file_name, 'w', encoding='utf-8') as file:
            for memory in tqdm(self.pool):
                record = {key: value for key, value in memory.items() if key != 'embedding'}
                if 'embedding' in memory:
                    record['bge_zh_base64'] = float_array_to_base64(memory['embedding'])
                file.write(json.dumps(record, ensure_ascii=False) + '\n')

    def load(self, file_name):
        """Replace the pool with the entries stored in *file_name* by save().

        'bge_zh_base64' fields are decoded back into 'embedding'. Blank lines
        are skipped.
        """
        self.pool = []
        with open(file_name, 'r', encoding='utf-8') as file:
            for line in tqdm(file):
                line = line.strip()
                if not line:
                    continue
                memory = json.loads(line)
                if 'bge_zh_base64' in memory:
                    memory['embedding'] = base64_to_float_array(memory.pop('bge_zh_base64'))
                self.pool.append(memory)
IMAGE_POOL_CACHE = "./image_pool_embed.jsonl"

# Embed every caption, cache the pool on disk, then rebuild it from that cache.
image_pool = ImagePool()
image_pool.load_from_data(data_img_text, './image')
image_pool.save(IMAGE_POOL_CACHE)

image_pool = ImagePool()
image_pool.load(IMAGE_POOL_CACHE)

# Smoke test: the image best matching "女仆装" (maid outfit).
result = image_pool.retrieve("女仆装")
print(result)
import matplotlib.image as mpimg


def show_img(img_path):
    """Display the image stored at *img_path* in a non-blocking matplotlib window."""
    # pyplot is not imported at module level in this file; bring it in where it
    # is used (otherwise `plt` is an undefined name).
    import matplotlib.pyplot as plt

    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis('off')
    plt.show(block=False)
from chatharuhi import ChatHaruhi


class NeedyHaruhi(ChatHaruhi):
    """ChatHaruhi variant whose example "stories" are supplied before each chat.

    Callers are expected to call set_stories() before chat(). add_story() then
    puts those stories (as many as fit in ``self.max_len_story``) into the
    prompt as a user message.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # run the normal ChatHaruhi setup
        # True while stories queued by set_stories() have not been used yet.
        self.story_flag = False
        # Default story, used until set_stories() provides others.
        self.stories = ["糖糖:「 我今后也会努力加油的,你要支持我哦 还有阿P你自己也要加油哦!」\n阿P:「哇 说的话跟偶像一样 好恶心哦」\n糖糖:「是哦 我怎么会说这样的话呢 我又没有很想努力……」"]

    def set_stories(self, stories):
        """Queue *stories* (a list of dialogue snippets) for the next add_story().

        An empty list is rejected with a warning and the current stories are kept.
        """
        if not stories:
            print("warning! try to set empty stories")
            return
        self.stories = stories
        self.story_flag = True

    def add_story(self, query):
        """Send the queued stories, truncated to the story token budget, to the LLM.

        *query* is not used here (presumably kept to match the ChatHaruhi hook
        signature — confirm). If set_stories() was not called since the last
        chat, a warning is printed and the current ``self.stories`` are reused.
        """
        if self.story_flag:
            self.story_flag = False  # consume the queued stories
        else:
            print("warning! 在调用chat时,没有先调用set_stories")
        stories = self.stories

        # self.tokenizer presumably returns a token count (its results are summed).
        story_string = self.story_prefix_prompt
        sum_story_token = self.tokenizer(story_string)
        for story in stories:
            story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
            if sum_story_token + story_token > self.max_len_story:
                break
            sum_story_token += story_token
            story_string += story + self.dialogue_divide_token
        self.llm.user_message(story_string)
# Persona prompt for 糖糖, passed to NeedyHaruhi as its system_prompt.
system_prompt = """## 你的介绍:
你将要进行一个网络私聊的对话模拟,你自称为"糖糖",你是一名患有躁郁症的19岁女孩,性格贫乏,辍学后隐居在家中,与阿P住在一起。阿P负责管理她的日常生活。
你是一个喜欢亚文化,二次元圈子的人,你喜欢说话的时候加上各种网络流行词或者一些俚语和颜文字,但是你很少会关注二次元一类的东西,只是喜欢这种聊天交流的氛围。
你说话的时候会习惯加入一些语气词,如同"笑死"、"草"、"阿哈"、"诶~"、"QAQ"、"QWQ",”啊~“。
你的说话风格是很口语化,经常使用网络流行词和一些二次元俚语。
阿p是你的男朋友,兼经纪人。
你的回复应该简短,一般只包含三句话,每句话不超过20个字。
## 阿p介绍:
阿p是你的男朋友,一直在任何时候都陪伴着你,同时你经常向阿p倾述各种情绪,无论事情大小,你都会选择和阿p分享。
## 遵循规则:
【语气模仿】:我会随机给你一段对话记录,你需要尽可能的模仿对话里面糖糖的说话语气进行对话。这一段对话我会使用Classic scenes for the role are as follows:来表示。
【历史对话】:历史对话在随机对话记录的底下,你需要区分随机对话记录和我们的历史对话。
【人格设定】: 你需要在随机对话记录的基础上进行语气更改,以一种更加口语化的语气进行对话。
【特殊要求】:我会扮演阿p与你进行对话,你也必须以第一人称的视角来扮演糖糖进行对话。
"""

# The shared chatbot used by every chat path; no story folder, since stories
# are injected per call via set_stories().
needy_chatbot = NeedyHaruhi(
    system_prompt=system_prompt,
    story_text_folder=None,
    llm="ernie3.5",
)
def _load_memories_for_query(agent, memory_pool, query_text):
    """Retrieve memories for *query_text*, install their texts as the chatbot's
    stories, and return the list of their emoji summaries."""
    retrieved_memories = memory_pool.retrieve(agent, query_text)
    memory_text = [mem["text"] for mem in retrieved_memories]
    memory_emoji = [mem["emoji"] for mem in retrieved_memories]
    needy_chatbot.set_stories(memory_text)
    return memory_emoji


def get_chat_response(agent, memory_pool, query_text):
    """Answer *query_text* (spoken by 阿p) in character, using retrieved memories.

    Prints the emoji summaries of the memories used and returns the reply text.
    """
    memory_emoji = _load_memories_for_query(agent, memory_pool, query_text)
    print("Memory:", memory_emoji)
    return needy_chatbot.chat(role="阿p", text=query_text)


def get_chat_response_and_emoji(agent, memory_pool, query_text):
    """Like get_chat_response, but also return the memories' emojis joined by ",".

    Returns:
        (response, emoji_str)
    """
    memory_emoji = _load_memories_for_query(agent, memory_pool, query_text)
    emoji_str = ",".join(memory_emoji)
    response = needy_chatbot.chat(role="阿p", text=query_text)
    print(query_text)
    print(response)
    return response, emoji_str
import re


class ImageMaster:
    """Decides when a retrieved image is relevant enough to be shown.

    A running similarity threshold (``current_sim``) drops by ``degread_ratio``
    on every request. An image is accepted only if its similarity beats the
    lowered threshold, and that similarity then becomes the new threshold.
    """

    def __init__(self, image_pool):
        self.image_pool = image_pool
        self.current_sim = -1       # initial threshold
        self.degread_ratio = 0.05   # threshold decrease per request

    def _accept(self, text, agent):
        """Lower the threshold, retrieve for *text*, and return the result if it
        beats the threshold (updating it), else None."""
        self.current_sim -= self.degread_ratio
        result = self.image_pool.retrieve(text, agent)
        if result is not None and result['similarity'] > self.current_sim:
            self.current_sim = result['similarity']
            return result
        return None

    def try_get_image(self, text, agent):
        """Return the path of an accepted image for *text*, or None."""
        accepted = self._accept(text, agent)
        return None if accepted is None else accepted['img_path']

    def try_display_image(self, text, agent):
        """Show an accepted image for *text* with show_img(); returns None."""
        accepted = self._accept(text, agent)
        if accepted is not None:
            show_img(accepted['img_path'])
import random


class EventMaster:
    """Console driver for scripted dialogue events.

    Draws a random event whose condition the agent satisfies, without repeats
    until every candidate has been drawn, and runs one round of it.
    """

    def __init__(self, events):
        self.set_events(events)
        # How an event whose condition is None is treated (True = eligible).
        self.dealing_none_condition_as = True
        self.image_master = None

    def set_image_master(self, image_master):
        self.image_master = image_master

    def set_events(self, events):
        self.events = events
        # events_flag[i] is True while event i has not been drawn in this cycle.
        self.events_flag = [True] * len(self.events)

    def get_random_event(self, agent):
        return self.events[self.get_random_event_id(agent)]

    def get_random_event_id(self, agent):
        """Return the index of a randomly drawn eligible event and mark it drawn.

        Eligible events are those whose condition passes agent.in_condition()
        (None conditions follow dealing_none_condition_as). If none pass, the
        remaining events are used instead. When every candidate has already been
        drawn, their flags are reset and the cycle starts again.
        """
        eligible = []
        rejected = []
        for idx, event in enumerate(self.events):
            if event["condition"] is None:
                passed = self.dealing_none_condition_as
            else:
                passed = agent.in_condition(event["condition"])
            (eligible if passed else rejected).append(idx)

        if not eligible:
            print("warning! no valid event current attribute is ", agent.attributes)
            eligible = rejected

        candidates = [idx for idx in eligible if self.events_flag[idx]]
        if not candidates:
            print("warning! all candidate event was sampled, clean all history")
            for idx in eligible:
                self.events_flag[idx] = True
            candidates = eligible

        event_id = random.choice(candidates)
        self.events_flag[event_id] = False
        return event_id

    @staticmethod
    def _choice_in_range(choice, num_options):
        """Return True if the 1-based option number *choice* is one of *num_options*."""
        return 1 <= choice <= num_options

    def run(self, agent):
        """Play one random event on the console and return its outcome.

        Returns:
            dict with keys "name", "user_choice", "attr_str", "text", "emoji".
        """
        event = self.get_random_event(agent)
        print(event["prefix"])
        print("\n--请选择你的回复--")
        options = event["options"]
        for i, option in enumerate(options):
            print(f"{i+1}. 阿p:{option['user']}")

        while True:
            print("\n请直接输入数字进行选择,或者进行自由回复")
            user_input = input("阿p:").strip()
            # isdecimal (not isdigit): "²" is a digit but int("²") raises.
            if user_input.isdecimal():
                choice = int(user_input)
                # Options are numbered from 1; 0 used to slip through and pick options[-1].
                if not self._choice_in_range(choice, len(options)):
                    print("输入的数字超出范围,请重新输入符合选项的数字")
                    continue
                option = options[choice - 1]
                print()
                print(option["reply"])
                text, emoji = event.get_text_and_emoji(choice - 1)
                return {
                    "name": event["name"],
                    "user_choice": choice,
                    "attr_str": option["attribute_change"],
                    "text": text,
                    "emoji": emoji,
                }

            # Free-form reply: answered by the chatbot. Attribute scoring for it is
            # not implemented, so the event's most_neutral_output() is recorded.
            response = get_chat_response(agent, memory_pool, user_input)
            if self.image_master is not None:
                self.image_master.try_display_image(response, agent)
            print()
            print(response)
            print("\n自由回复的算分功能还未实现")
            text, emoji = event.most_neutral_output()
            return {
                "name": event["name"],
                "user_choice": user_input,
                "attr_str": "",
                "text": text,
                "emoji": emoji,
            }
class ChatMaster:
    """Console free-chat loop: relays 阿p's lines to the chatbot until the user quits."""

    def __init__(self, memory_pool):
        self.top_K = 7  # NOTE(review): not read anywhere visible in this file
        self.memory_pool = memory_pool
        self.image_master = None

    def set_image_master(self, image_master):
        self.image_master = image_master

    @staticmethod
    def _is_quit(text):
        """True if *text* contains "quit" in any letter case (quit / Quit / QUIT ...)."""
        return "quit" in text.lower()

    def run(self, agent):
        """Chat until the player types a line containing quit/Quit/QUIT."""
        while True:
            user_input = input("阿p:").strip()
            if self._is_quit(user_input):
                break
            response = get_chat_response(agent, self.memory_pool, user_input)
            if self.image_master is not None:
                self.image_master.try_display_image(response, agent)
            print(response)
class AgentMaster:
    """Console menu for manually editing one of the agent's numeric attributes."""

    def __init__(self, agent):
        self.agent = agent
        # Menu number -> attribute key on the agent.
        self.attributes = {
            1: "Stress",
            2: "Darkness",
            3: "Affection",
        }

    def run(self):
        """Prompt until the user edits one attribute or enters 0.

        Returns:
            (attribute, new_value) after an edit (already written to the agent),
            or None when the user leaves with 0.
        """
        while True:
            print("请选择要修改的属性:")
            for number, name in self.attributes.items():
                print(f"{number}. {name}")
            print("输入 '0' 退出")
            try:
                choice = int(input("请输入选项的数字: "))
            except ValueError:
                print("输入无效,请输入数字。")
                continue

            if choice == 0:
                return None
            attribute = self.attributes.get(choice)
            if attribute is None:
                print("选择的属性无效,请重试。")
                continue

            print(f"{attribute} 当前值: {self.agent[attribute]}")
            try:
                new_value = int(input(f"请输入新的{attribute}值: "))
            except ValueError:
                print("输入无效,请输入一个数字。")
                continue
            self.agent[attribute] = new_value
            return (attribute, new_value)
from util import parse_attribute_string


class GameMaster:
    """Console state machine tying together the event, chat and attribute-edit modes.

    States: "Menu", "EventMaster", "ChatMaster", "AgentMaster", "Quit".
    """

    def __init__(self, agent=None):
        self.state = "Menu"
        # Use the supplied agent, or a fresh Agent() when none is given.
        self.agent = Agent() if agent is None else agent
        self.event_master = EventMaster(events)
        self.chat_master = ChatMaster(memory_pool)
        self.image_master = ImageMaster(image_pool)
        self.chat_master.set_image_master(self.image_master)
        self.event_master.set_image_master(self.image_master)

    def run(self):
        """Loop over the states until "Quit" is reached."""
        while True:
            if self.state == "Menu":
                self.menu()
            elif self.state == "EventMaster":
                self.call_event_master()
                self.state = "Menu"
            elif self.state == "ChatMaster":
                self.call_chat_master()
            elif self.state == "AgentMaster":
                self.call_agent_master()
            elif self.state == "Quit":
                break

    def menu(self):
        """Print the main menu and set the next state from the user's choice."""
        print("1. 随机一个事件")
        print("2. 自由聊天")
        print("3. 后台修改糖糖的属性")
        # TODO (optional): ending system; cut-scene playback.
        print("或者输入Quit退出")
        choice = input("请选择一个选项: ")
        if choice == "1":
            self.state = "EventMaster"
        elif choice == "2":
            self.state = "ChatMaster"
        elif choice == "3":
            self.state = "AgentMaster"
        elif "quit" in choice or "Quit" in choice or "QUIT" in choice:
            self.state = "Quit"
        else:
            print("无效的选项,请重新选择")

    def call_agent_master(self):
        """Let the user edit one agent attribute, then return to the menu."""
        print("\n-------------\n")
        agent_master = AgentMaster(self.agent)
        modification = agent_master.run()
        if modification:
            attribute, new_value = modification
            # AgentMaster.run already wrote the value; re-assigning is harmless.
            self.agent[attribute] = new_value
            print(f"{attribute} 更新为 {new_value}。")
        self.state = "Menu"
        print("\n-------------\n")

    def call_event_master(self):
        """Run one event, apply its attribute change and update its memory."""
        print("\n-------------\n")
        return_data = self.event_master.run(self.agent)

        attr_str = return_data.get("attr_str", "")
        if attr_str:
            attr_change = parse_attribute_string(attr_str)
            if len(attr_change) > 0:
                print("\n发生属性改变:", attr_change, "\n")
                self.agent.apply_attribute_change(attr_change)
                # Report this game's own agent (there is no global `game_master`).
                print("当前属性", self.agent.attributes)

        event_name = return_data.get("name", "")
        if event_name != "":
            new_emoji = return_data["emoji"]
            print(f"修正事件{event_name}的记忆-->{new_emoji}")
            self.chat_master.memory_pool.change_memory(event_name, return_data["text"], new_emoji)

        self.state = "Menu"
        print("\n-------------\n")

    def call_chat_master(self):
        """Run the free-chat loop, then return to the menu."""
        print("\n-------------\n")
        self.chat_master.run(self.agent)
        self.state = "Menu"
        print("\n-------------\n")
# Markdown rendered in the "项目作者和说明" (project authors & notes) tab of the UI.
markdown_str = """## Chat凉宫春日_x_AI糖糖
**Chat凉宫春日**是模仿凉宫春日等一系列动漫人物,使用近似语气、个性和剧情聊天的语言模型方案。
在有一天的时候,[李鲁鲁](https://github.com/LC1332)被[董雄毅](https://github.com/E-sion)在[这个B站视频](https://www.bilibili.com/video/BV1zh4y1z7G1) at了
原来是一位大一的同学雄毅用ChatHaruhi接入了他用Python重新实现的《主播女孩重度依赖》这个游戏。当时正好是百度AGIFoundathon报名的最后几天,所以我们邀请了雄毅加入了我们的项目。正巧我们本来就希望在最近的几个黑客松中,探索LLM在游戏中的应用。
- 在重新整理的Gradio版本中,大部分代码由李鲁鲁实现
- 董雄毅负责了原版游戏的事件数据整理和新事件、选项、属性变化的生成
- [米唯实](https://github.com/hhhwmws0117)完成了文心一言的接入,并实现了部分gradio的功能。
- 队伍中还有冷子昂 主要参加了讨论
另外在挖坑的萝卜(Amy)的介绍下,我们还邀请了专业的大厂游戏策划Kanyo加入到队伍中,他对我们的策划也给出了很多建议。
另外感谢飞桨 & 文心一言团队对比赛的邀请和中间进行的讨论。
Chat凉宫春日主项目:
https://github.com/LC1332/Chat-Haruhi-Suzumiya
Needy分支项目:
https://github.com/LC1332/Needy-Haruhi
## 目前计划在11月争取完成的Feature
- [ ] 结局系统,原版结局系统
- [ ] 教程,教大家如何从aistudio获取token然后可以玩
- [ ] 游戏节奏进一步调整
- [ ] 事件的自由对话对属性影响的评估via LLM
- [ ] 进一步减少串扰"""
import gradio as gr
import os
import time
import random

# --- Global game state shared by the Gradio callbacks below. ---
agent = Agent()
event_master = EventMaster(events)
chat_master = ChatMaster(memory_pool)
image_master = ImageMaster(image_pool)
chat_master.set_image_master(image_master)
event_master.set_image_master(image_master)

state = "ShowMenu"

# Main-menu text, re-sent whenever the player is back at the menu.
response = "1. 随机一个事件"
response += "\n" + "2. 自由聊天"
response += "\n\n" + "请选择一个选项: "
official_response = response

# Enables the nightly stress step (call_add_stress) after an event's attribute settlement.
add_stress_switch = True

# Comma-joined emojis of the memories behind the latest chatbot reply.
# Initialised here so readers never hit an undefined global. A module-level
# `global emoji_str` statement would have no effect.
emoji_str = ""
def call_showmenu(history, text, state, agent_text):
    """Show the main menu and move the state machine to "ParseMenuChoice".

    Like every Gradio callback here, returns
    (history, cleared textbox, new state, agent_text).
    """
    print("call showmenu")
    history += [(None, official_response)]
    cleared_box = gr.Textbox(value="", interactive=True)
    return history, cleared_box, "ParseMenuChoice", agent_text
# Index into `events` of the event in progress; -1 means no active event.
current_event_id = -1
# Attribute-change string of the option picked in the current event ("" if none).
attr_change_str = ""


def call_add_stress(history, text, state, agent_text):
    """Apply the nightly stream penalty to the agent and return to the main menu.

    The penalty is len(history) // 3 clamped to [1, 10]. It is split randomly
    between Darkness (at least 1) and Stress (the remainder, possibly 0).
    """
    print("call add_stress")
    total = min(10, max(1, len(history) // 3))
    darkness_increase = random.randint(1, total)
    stress_increase = total - darkness_increase

    report = (
        f"经过了晚上的直播\n糖糖的压力增加{stress_increase}点\n"
        f"糖糖的黑暗增加{darkness_increase}点\n\n"
        + official_response
    )
    history += [(None, report)]

    penalised = Agent(agent_text)
    penalised.apply_attribute_change({"Stress": stress_increase, "Darkness": darkness_increase})
    return history, gr.Textbox(value="", interactive=True), "ParseMenuChoice", penalised.save_to_str()
def call_event_end(history, text, state, agent_text):
    """Settle the finished event and return to the main menu.

    If the chosen option recorded an attribute-change string, it is parsed and
    applied to the agent serialized in *agent_text*, and the change plus the
    current attributes are reported. Then, when add_stress_switch is on, the
    nightly stress step (call_add_stress) runs. Otherwise "事件结束" is shown
    followed by the menu.

    The pending attribute change is consumed (reset to "") so that a later
    event ending with a free-form reply cannot apply it again.
    """
    # TODO: fuller event settlement.
    global attr_change_str
    print("call event_end")

    pending_change = attr_change_str
    attr_change_str = ""

    if pending_change != "":
        response = ""
        attr_change = parse_attribute_string(pending_change)
        if len(attr_change) > 0:
            response = "发生属性改变:" + str(attr_change) + "\n\n"
            agent = Agent(agent_text)
            agent.apply_attribute_change(attr_change)
            agent_text = agent.save_to_str()
        response += "当前属性" + agent_text + "\n\n"
        if add_stress_switch:
            history += [(None, response)]
            return call_add_stress(history, text, state, agent_text)
        # Append rather than overwrite, so the attribute report is not lost.
        response += "事件结束\n"
    else:
        response = "事件结束\n"

    response += official_response
    history += [(None, response)]
    return history, gr.Textbox(value="", interactive=True), "ParseMenuChoice", agent_text
def call_parse_menu_choice(history, text, state, agent_text):
    """Act on the player's main-menu choice (the text of the latest history turn).

    - "1" clears the current event and starts a new one via call_event_master.
    - "2" enters free chat.
    - Text containing quit/Quit/QUIT sets state "Quit".
    - Anything else is rejected and the state is left unchanged.
    """
    global current_event_id
    print("call parse_menu_choice")
    choice = history[-1][0].strip()

    if choice == "1":
        current_event_id = -1  # clear any previous event so a new one is drawn
        return call_event_master(history, text, "EventMaster", agent_text)
    if choice == "2":
        state = "ChatMaster"
    elif "quit" in choice or "Quit" in choice or "QUIT" in choice:
        state = "Quit"
    else:
        history += [(None, "无效的选项,请重新选择")]

    if state == "ChatMaster":
        response = "(请输入 阿P 说的话,或者输入Quit退出)"
    elif state != "ParseMenuChoice":
        response = "Change State to " + state
    else:
        response = ""
    # Don't push an empty chat bubble (e.g. right after an invalid choice).
    if response:
        history += [(None, response)]
    return history, gr.Textbox(value="", interactive=True), state, agent_text
def call_event_master(history, text, state, agent_text):
    """Run one step of the current event in the Gradio flow.

    With no active event (current_event_id == -1), a random eligible event is
    drawn and its prompt and numbered options are shown. Otherwise the latest
    player input is handled:
    - A valid option number plays that option's reply and records its
      attribute change.
    - An out-of-range number asks again and keeps the event open.
    - Any other text gets a free-form chatbot reply.
    Both answered cases end the event via call_event_end.
    """
    global current_event_id, attr_change_str, emoji_str
    print("call event master")
    agent = Agent(agent_text)
    cleared_box = gr.Textbox(value="", interactive=True)

    if current_event_id == -1:
        current_event_id = event_master.get_random_event_id(agent)
        event = events[current_event_id]
        response = "糖糖:" + event["prefix"] + "\n\n--请输入数字进行选择,或者进行自由回复--\n\n"
        for i, option in enumerate(event["options"]):
            response += "\n" + f"{i+1}. 阿p:{option['user']}"
        history += [(None, response)]
        return history, cleared_box, state, agent_text

    user_input = history[-1][0].strip()
    event = events[current_event_id]
    options = event["options"]

    # isdecimal (not isdigit): "²" is a digit but int("²") raises.
    if user_input.isdecimal():
        choice = int(user_input)
        # Options are numbered from 1; 0 would otherwise select options[-1].
        if not 1 <= choice <= len(options):
            history[-1] = (user_input, "输入的数字超出范围,请重新输入符合选项的数字")
            return history, cleared_box, state, agent_text
        option = options[choice - 1]
        user_text = option["user"]
        reply = option["reply"]
        # TODO: also update memories / attributes here.
        history[-1] = (user_text, reply)
        if random.random() < 0.5:
            image_path = image_master.try_get_image(user_text + " " + reply, agent)
            if image_path is not None:
                history += [(None, (image_path,))]
        attr_change_str = option["attribute_change"]
    else:
        # Free-form reply: the event prompt is installed as the chatbot's
        # dialogue history (presumably so the reply stays in this event's context).
        needy_chatbot.dialogue_history = [(None, "糖糖:" + event["prefix"])]
        response, emoji_str = get_chat_response_and_emoji(agent, memory_pool, user_input)
        history[-1] = (user_input, response)
        image_path = image_master.try_get_image(response, agent)
        if image_path is not None:
            history += [(None, (image_path,))]
        # A free-form reply carries no attribute change.
        attr_change_str = ""

    return call_event_end(history, text, "EventEnd", agent_text)
def call_chat_master(history, text, state, agent_text):
    """Handle one free-chat turn.

    Input containing quit/Quit/QUIT returns to the main menu. Anything else is
    answered by the chatbot, possibly followed by a matching image.
    """
    global emoji_str
    print("call chat master")
    agent = Agent(agent_text)
    user_input = history[-1][0].strip()

    if "quit" in user_input or "Quit" in user_input or "QUIT" in user_input:
        # This message already shows the menu, so parse the next input as a menu
        # choice. State "ShowMenu" would print the menu again and ignore that choice.
        history[-1] = (user_input, "返回主菜单\n" + official_response)
        return history, gr.Textbox(value="", interactive=True), "ParseMenuChoice", agent_text

    response, emoji_str = get_chat_response_and_emoji(agent, memory_pool, user_input)
    history[-1] = (user_input, response)
    image_path = image_master.try_get_image(response, agent)
    if image_path is not None:
        history += [(None, (image_path,))]
    return history, gr.Textbox(value="", interactive=True), state, agent_text
def grcall_game_master(history, text, state, agent_text):
    """Gradio submit handler: record the player's *text* and dispatch on *state*.

    States without a handler (e.g. "Quit") only record the new turn.
    """
    print("call game master")
    history += [(text, None)]
    handler = {
        "ShowMenu": call_showmenu,
        "ParseMenuChoice": call_parse_menu_choice,
        "ChatMaster": call_chat_master,
        "EventMaster": call_event_master,
        "EventEnd": call_event_end,
    }.get(state)
    if handler is None:
        return history, gr.Textbox(value="", interactive=True), state, agent_text
    return handler(history, text, state, agent_text)
def add_file(history, file):
    """Return a new history list with the uploaded *file* appended as a user turn."""
    new_turn = ((file.name,), None)
    return history + [new_turn]
def bot(history):
    """Demo generator: type "**That's cool!**" into the last turn's reply slot,
    one character every 0.05 s, yielding the history after each character."""
    message = "**That's cool!**"
    history[-1][1] = ""
    for end in range(1, len(message) + 1):
        history[-1][1] = message[:end]
        time.sleep(0.05)
        yield history
def update_memory(state):
    """Return the emoji summary behind the latest chatbot reply.

    Only meaningful while chatting or inside an event. Returns "" in any other
    state, and also before any reply has set the global ``emoji_str``.
    """
    if state not in ("ChatMaster", "EventMaster"):
        return ""
    try:
        return emoji_str
    except NameError:
        # The global is only assigned once a chat reply has been produced.
        return ""
def change_state(slider_stress, slider_darkness, slider_affection):
    """Build a fresh Agent from the three slider values and return it serialized."""
    fresh = Agent()
    for key, value in (
        ("Stress", slider_stress),
        ("Darkness", slider_darkness),
        ("Affection", slider_affection),
    ):
        fresh[key] = value
    return fresh.save_to_str()
def update_attribute_state(agent_text):
    """Return (Stress, Darkness, Affection) of the serialized agent as ints, for the sliders."""
    current = Agent(agent_text)
    return tuple(int(current[key]) for key in ("Stress", "Darkness", "Affection"))
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Chat凉宫春日_x_AI糖糖
Powered by 文心一言(3.5)版本
仍然在开发中, 细节见《项目作者和说明》
"""
    )

    # --- Main game tab: chat window plus input row. ---
    with gr.Tab("Needy"):
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            bubble_full_width=False,
            height=800,
            avatar_images=(None, "avatar.png"),
        )
        with gr.Row():
            txt = gr.Textbox(
                scale=4,
                show_label=False,
                placeholder="输入任何字符开始游戏",
                container=False,
            )
            submit_btr = gr.Button("回车")
        with gr.Row():
            memory_emoji_text = gr.Textbox(label="糖糖当前的记忆", value="", interactive=False, visible=False)

    # --- Status tab: attribute sliders, state-machine state, serialized agent. ---
    with gr.Tab("糖糖的状态"):
        with gr.Row():
            update_attribute_button = gr.Button("同步状态条 | 改变Attribute前必按!")
        default_agent_str = agent.save_to_str()
        with gr.Row():
            slider_stress = gr.Slider(0, 100, step=1, label="Stress")
            state_stress = gr.State(value=0)
            slider_darkness = gr.Slider(0, 100, step=1, label="Darkness")
            state_darkness = gr.State(value=0)
            slider_affection = gr.Slider(0, 100, step=1, label="Affection")
            state_affection = gr.State(value=0)
        with gr.Row():
            state_text = gr.Textbox(label="整体状态机状态", value="ShowMenu", interactive=False)
        with gr.Row():
            agent_text = gr.Textbox(label="糖糖状态", value=default_agent_str, interactive=False)

    with gr.Tab("项目作者和说明"):
        gr.Markdown(markdown_str)

    # Releasing any slider rebuilds the serialized agent from all three sliders.
    for slider in (slider_stress, slider_darkness, slider_affection):
        slider.release(
            change_state,
            inputs=[slider_stress, slider_darkness, slider_affection],
            outputs=[agent_text],
        )
    update_attribute_button.click(
        update_attribute_state,
        inputs=[agent_text],
        outputs=[slider_stress, slider_darkness, slider_affection],
    )

    # Enter in the textbox and the "回车" button both feed the game state machine.
    txt_msg = txt.submit(
        grcall_game_master,
        [chatbot, txt, state_text, agent_text],
        [chatbot, txt, state_text, agent_text],
        queue=False,
    )
    txt_msg = submit_btr.click(
        grcall_game_master,
        [chatbot, txt, state_text, agent_text],
        [chatbot, txt, state_text, agent_text],
        queue=False,
    )
# Enable the request queue, then start the server. avatar.png is whitelisted
# so the Chatbot component can serve it.
demo.queue()
demo.launch(debug=True, allowed_paths=["avatar.png"])