Spaces:
Runtime error
Runtime error
File size: 4,787 Bytes
4222961 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from tqdm import tqdm
from util import float_array_to_base64, base64_to_float_array
from util import get_bge_embedding_zh
import json
import torch
# Pick the torch device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Cosine similarity between two vectors, evaluated on the module-level device.
def get_cosine_similarity(v1, v2):
    """Return the cosine similarity of ``v1`` and ``v2`` as a Python scalar.

    Both inputs are converted with ``torch.tensor`` and moved to ``device``;
    the similarity is taken along dimension 0 and unwrapped with ``.item()``.
    """
    first, second = (torch.tensor(vector).to(device) for vector in (v1, v2))
    score = torch.cosine_similarity(first, second, dim=0)
    return score.item()
class MemoryPool:
    """Collection of named memories that can be saved, loaded and retrieved.

    ``self.memories`` maps a memory name to a dict with the keys ``name``,
    ``text``, ``embedding``, ``memory_attribute`` (an
    ``(attribute_name, value)`` pair) and ``emoji``.
    """

    def __init__(self):
        """Create an empty pool that embeds text with ``get_bge_embedding_zh``."""
        self.memories = {}
        # retrieve() keeps a memory only when the agent's value for the
        # memory's attribute is within this distance of the memory's value.
        self.diff_threshold = 20
        # Upper bound on the number of memories retrieve() returns.
        self.top_k = 7
        self.set_embedding(get_bge_embedding_zh)

    def set_embedding(self, embedding):
        """Replace the callable used to turn a text into an embedding."""
        self.embedding = embedding

    def load_from_events(self, events):
        """Build one memory per event and store it under the event's name.

        Each event must support ``event[key]`` for ``options``, ``condition``
        and ``name``. When ``options`` is non-empty the text and emoji come
        from ``event.most_neutral_output()``; otherwise they are read from
        ``prefix`` and ``prefix_emoji``. An existing memory with the same
        name is overwritten.
        """
        for event in tqdm(events):
            if len(event["options"]) > 0:
                text, emoji = event.most_neutral_output()
            else:
                text = event["prefix"]
                emoji = event["prefix_emoji"]
            embedding = self.embedding(text)

            condition = event["condition"]
            if condition is None:
                # Events without a condition get this fixed default attribute.
                memory_attribute = ("Stress", 10)
            else:
                # Attribute name plus the floor midpoint of condition[1] and
                # condition[2] (presumably a value range -- confirm with the
                # event definition).
                memory_attribute = (
                    condition[0],
                    (condition[1] + condition[2]) // 2,
                )

            name = event["name"]
            self.memories[name] = {
                "name": name,
                "text": text,
                "embedding": embedding,
                "memory_attribute": memory_attribute,
                "emoji": emoji,  # TODO
            }

    def save(self, file_name):
        """Write every memory to ``file_name`` as UTF-8 JSON Lines.

        In each written record the ``embedding`` field is replaced by a
        ``bge_zh_base64`` field holding ``float_array_to_base64(embedding)``.
        The in-memory dicts are left untouched, so the pool stays usable
        (e.g. for ``retrieve``) after saving.
        """
        with open(file_name, 'w', encoding='utf-8') as file:
            for memory in tqdm(self.memories.values()):
                # Serialise a shallow copy: removing 'embedding' from the live
                # dict would break later retrieve()/save() calls.
                record = {
                    key: value
                    for key, value in memory.items()
                    if key != 'embedding'
                }
                if 'embedding' in memory:
                    record['bge_zh_base64'] = float_array_to_base64(
                        memory['embedding']
                    )
                file.write(json.dumps(record, ensure_ascii=False) + '\n')

    def load(self, file_name):
        """Read memories from a UTF-8 JSON Lines file written by ``save``.

        A ``bge_zh_base64`` field is decoded back into ``embedding`` with
        ``base64_to_float_array``. Blank lines are skipped. Loaded memories
        are added to the pool, overwriting entries with the same name.
        """
        with open(file_name, 'r', encoding='utf-8') as file:
            for line in tqdm(file):
                line = line.strip()
                if not line:
                    # Tolerate empty or trailing blank lines.
                    continue
                memory = json.loads(line)
                if 'bge_zh_base64' in memory:
                    memory['embedding'] = base64_to_float_array(
                        memory.pop('bge_zh_base64')
                    )
                # JSON has no tuple type; restore the pair to the tuple form
                # that load_from_events() produces.
                attribute = memory.get('memory_attribute')
                if isinstance(attribute, list):
                    memory['memory_attribute'] = tuple(attribute)
                self.memories[memory['name']] = memory

    def change_memory(self, memory_name, new_text, new_emoji=None):
        """Replace the text of an existing memory and re-embed it.

        The emoji is replaced only when ``new_emoji`` is truthy. Unknown
        memory names are ignored.
        """
        if memory_name not in self.memories:
            return
        memory = self.memories[memory_name]
        memory["text"] = new_text
        memory["embedding"] = self.embedding(new_text)
        if new_emoji:
            memory["emoji"] = new_emoji

    def retrieve(self, agent, query_text):
        """Return up to ``top_k`` memories most similar to ``query_text``.

        A memory is a candidate only when ``agent[attribute]`` is within
        ``diff_threshold`` of the value in its ``memory_attribute`` pair.
        Candidates are ranked by cosine similarity between the query
        embedding and the memory embedding, highest first.
        """
        query_embedding = self.embedding(query_text)

        scored = []
        for name, memory in self.memories.items():
            attribute, value = memory["memory_attribute"]
            if abs(agent[attribute] - value) > self.diff_threshold:
                continue
            similarity = get_cosine_similarity(
                query_embedding, memory["embedding"]
            )
            scored.append((similarity, name))

        # Highest similarity first; the sort is stable, so ties keep
        # insertion order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        limit = max(self.top_k, 0)
        return [self.memories[name] for _, name in scored[:limit]]