Upload 15 files
- ChatHaruhi/ChatHaruhi.py +471 -0
- ChatHaruhi/NaiveDB.py +88 -0
- ChatHaruhi/Readme.md +352 -0
- ChatHaruhi/SparkApi.py +140 -0
- ChatHaruhi/__init__.py +2 -0
- ChatHaruhi/embeddings.py +270 -0
- ChatHaruhi/novel_extract.py +176 -0
- ChatHaruhi/response_GLM_local.py +121 -0
- ChatHaruhi/response_GLM_lora.py +133 -0
- ChatHaruhi/response_erniebot.py +90 -0
- ChatHaruhi/response_openai.py +65 -0
- ChatHaruhi/response_spark.py +51 -0
- ChatHaruhi/response_zhipu.py +54 -0
- ChatHaruhi/sugar_map.py +30 -0
- ChatHaruhi/utils.py +89 -0
ChatHaruhi/ChatHaruhi.py
ADDED
@@ -0,0 +1,471 @@
from .utils import base64_to_float_array, base64_to_string

def get_text_from_data( data ):
    if "text" in data:
        return data['text']
    elif "enc_text" in data:
        return base64_to_string( data['enc_text'] )
    else:
        print("warning! failed to get text from data ", data)
        return ""

def parse_rag(text):
    lines = text.split("\n")
    ans = []

    for i, line in enumerate(lines):
        if "{{RAG对话}}" in line:
            ans.append({"n": 1, "max_token": -1, "query": "default", "lid": i})
        elif "{{RAG对话|" in line:
            query_info = line.split("|")[1].rstrip("}}")
            ans.append({"n": 1, "max_token": -1, "query": query_info, "lid": i})
        elif "{{RAG多对话|" in line:
            parts = line.split("|")
            max_token = int(parts[1].split("<=")[1])
            max_n = int(parts[2].split("<=")[1].rstrip("}}"))
            ans.append({"n": max_n, "max_token": max_token, "query": "default", "lid": i})

    return ans

class ChatHaruhi:
    def __init__(self,
                 role_name = None,
                 user_name = None,
                 persona = None,
                 stories = None,
                 story_vecs = None,
                 role_from_hf = None,
                 role_from_jsonl = None,
                 llm = None,        # default message -> response function
                 llm_async = None,  # default async message -> response function
                 user_name_in_message = "default",
                 verbose = None,
                 embed_name = None,
                 embedding = None,
                 db = None,
                 token_counter = "default",
                 max_input_token = 1800,
                 max_len_story_haruhi = 1000,
                 max_story_n_haruhi = 5
                 ):

        self.verbose = True if verbose is None or verbose else False

        self.db = db

        self.embed_name = embed_name

        self.max_len_story_haruhi = max_len_story_haruhi  # only affects the legacy Haruhi "sugar" roles
        self.max_story_n_haruhi = max_story_n_haruhi      # only affects the legacy Haruhi "sugar" roles

        self.last_query_msg = None

        if embedding is None:
            self.embedding = self.set_embedding_with_name( embed_name )
        else:
            # use the externally supplied embedding function
            self.embedding = embedding

        if persona and role_name and stories and story_vecs and len(stories) == len(story_vecs):
            # everything supplied externally; story_vecs must match the embedding's output dimension
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif persona and role_name and stories:
            # re-embed the stories with self.embedding to obtain story_vecs
            story_vecs = self.extract_story_vecs(stories)
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif role_from_hf:
            # load the role from HuggingFace
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
            if new_role_name:
                self.role_name = new_role_name
            else:
                self.role_name = role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif role_from_jsonl:
            # load the role from a jsonl file
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_jsonl(role_from_jsonl)
            if new_role_name:
                self.role_name = new_role_name
            else:
                self.role_name = role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif persona and role_name:
            # no RAG at all, i.e. zero-shot mode
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.db = None
        elif role_name and self.check_sugar( role_name ):
            # a built-in "sugar" role
            self.persona, self.role_name, self.stories, self.story_vecs = self.load_role_from_sugar( role_name )
            self.build_db(self.stories, self.story_vecs)
            # per discussion with 江YH: every loading path must call add_rag_prompt_after_persona()
            # externally, to avoid confusion
            # self.add_rag_prompt_after_persona()
        else:
            raise ValueError("persona and role_name must be set together, or role_name must be one of ChatHaruhi's preset characters")

        self.llm, self.llm_async = llm, llm_async
        if not self.llm and self.verbose:
            print("warning: llm is not set; only get_message works, and chat will reply with the idle message")

        self.user_name_in_message = user_name_in_message
        self.previous_user_pool = set([user_name]) if user_name else set()
        self.current_user_name_in_message = user_name_in_message.lower() == "add"

        self.idle_message = "idle message, you see this because self.llm has not been set."

        if token_counter is None:
            self.token_counter = lambda x: 0
        elif isinstance(token_counter, str) and token_counter.lower() == "default":
            # TODO: change to load from utils
            from .utils import tiktoken_counter
            self.token_counter = tiktoken_counter
        else:
            self.token_counter = token_counter
            if self.verbose:
                print("user set customized token_counter")

        self.max_input_token = max_input_token

        self.history = []

    def check_sugar(self, role_name):
        from .sugar_map import sugar_role_names, enname2zhname
        return role_name in sugar_role_names

    def load_role_from_sugar(self, role_name):
        from .sugar_map import sugar_role_names, enname2zhname
        en_role_name = sugar_role_names[role_name]
        new_role_name = enname2zhname[en_role_name]
        role_from_hf = "silk-road/ChatHaruhi-RolePlaying/" + en_role_name
        persona, _, stories, story_vecs = self.load_role_from_hf(role_from_hf)

        return persona, new_role_name, stories, story_vecs

    def add_rag_prompt_after_persona( self ):
        rag_sentence = "{{RAG多对话|token<=" + str(self.max_len_story_haruhi) + "|n<=" + str(self.max_story_n_haruhi) + "}}"
        self.persona += "Classic scenes for the role are as follows:\n" + rag_sentence + "\n"

    def set_embedding_with_name(self, embed_name):
        if embed_name is None or embed_name == "bge_zh":
            from .embeddings import get_bge_zh_embedding
            self.embed_name = "bge_zh"
            return get_bge_zh_embedding
        elif embed_name == "foo":
            from .embeddings import foo_embedding
            return foo_embedding
        elif embed_name == "bce":
            from .embeddings import foo_bce
            return foo_bce
        elif embed_name == "openai" or embed_name == "luotuo_openai":
            from .embeddings import foo_openai
            return foo_openai

    def set_new_user(self, user):
        if len(self.previous_user_pool) > 0 and user not in self.previous_user_pool:
            if self.user_name_in_message.lower() == "default":
                if self.verbose:
                    print(f'new user {user} included in conversation')
                self.current_user_name_in_message = True
        self.user_name = user
        self.previous_user_pool.add(user)

    def chat(self, user, text):
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm:
            response = self.llm(message)
            self.append_message(response)
            return response
        return None

    async def async_chat(self, user, text):
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm_async:
            response = await self.llm_async(message)
            self.append_message(response)
            return response

    def parse_rag_from_persona(self, persona, text = None):
        # each query_rag must contain:
        #   "n"         how many stories to retrieve
        #   "max_token" token budget; -1 means unlimited
        #   "query"     the query text; "default" is replaced by the user input
        #   "lid"       the persona line to replace; the whole line is substituted,
        #               its other content is ignored

        query_rags = parse_rag( persona )

        if text is not None:
            for rag in query_rags:
                if rag['query'] == "default":
                    rag['query'] = text

        return query_rags, self.token_counter(persona)

    def append_message( self, response , speaker = None ):
        if self.last_query_msg is not None:
            self.history.append(self.last_query_msg)
            self.last_query_msg = None

        if speaker is None:
            # if speaker is None, treat the sentence as spoken by the role itself ({{role}})
            self.history.append({"speaker":"{{role}}","content":response})
            # the field is called "speaker" to distinguish it from the message "role"
        else:
            self.history.append({"speaker":speaker,"content":response})

    def check_recompute_stories_token(self):
        return len(self.db.metas) == len(self.db.stories)

    def recompute_stories_token(self):
        self.db.metas = [self.token_counter(story) for story in self.db.stories]

    def rag_retrieve( self, query, n, max_token, avoid_ids = [] ):
        # returns a list of story ids
        query_vec = self.embedding(query)

        self.db.clean_flag()
        self.db.disable_story_with_ids( avoid_ids )

        retrieved_ids = self.db.search( query_vec, n )

        if self.check_recompute_stories_token():
            self.recompute_stories_token()

        sum_token = 0
        ans = []

        # always keep the first hit; keep further hits while the budget allows
        for i in range(0, len(retrieved_ids)):
            if i == 0:
                sum_token += self.db.metas[retrieved_ids[i]]
                ans.append(retrieved_ids[i])
                continue
            else:
                sum_token += self.db.metas[retrieved_ids[i]]
                if sum_token <= max_token:
                    ans.append(retrieved_ids[i])
                else:
                    break

        return ans

    def rag_retrieve_all( self, query_rags, rest_limit ):
        # returns a list of rag_id lists, one per query_rag
        retrieved_ids = []
        rag_ids = []

        for query_rag in query_rags:
            query = query_rag['query']
            n = query_rag['n']
            max_token = rest_limit
            if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0:
                max_token = query_rag['max_token']

            rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids )
            rag_ids.append( rag_id )
            retrieved_ids += rag_id

        return rag_ids

    def append_history_under_limit(self, message, rest_limit):
        # returns the message list
        # count tokens from the most recent message backwards, staying under rest_limit
        # if the speaker is {{role}}, the message role is "assistant"
        current_limit = rest_limit

        history_list = []

        for item in reversed(self.history):
            current_token = self.token_counter(item['content'])
            current_limit -= current_token
            if current_limit < 0:
                break
            else:
                history_list.append(item)

        history_list = list(reversed(history_list))

        # TODO: for multi-user chat, content will additionally carry "speaker: content" information

        for item in history_list:
            if item['speaker'] == "{{role}}":
                message.append({"role":"assistant","content":item['content']})
            else:
                message.append({"role":"user","content":item['content']})

        return message

    def get_message(self, user, text):
        query_token = self.token_counter(text)

        # first work out how many rag stories are needed
        query_rags, persona_token = self.parse_rag_from_persona( self.persona, text )
        # each query_rag contains:
        #   "n"         how many stories to retrieve
        #   "max_token" token budget; -1 means unlimited
        #   "query"     the query text; "default" is replaced by the user input
        #   "lid"       the persona line to replace; the whole line is substituted,
        #               its other content is ignored

        rest_limit = self.max_input_token - persona_token - query_token

        if self.verbose:
            print(f"query_rags: {query_rags} rest_limit = { rest_limit }")

        rag_ids = self.rag_retrieve_all( query_rags, rest_limit )

        # substitute the stories behind rag_ids into the persona
        augmented_persona = self.augment_persona( self.persona, rag_ids, query_rags )

        system_prompt = self.package_system_prompt( self.role_name, augmented_persona )

        token_for_system = self.token_counter( system_prompt )

        rest_limit = self.max_input_token - token_for_system - query_token

        message = [{"role":"system","content":system_prompt}]

        message = self.append_history_under_limit( message, rest_limit )

        # TODO: for multi-user chat, content will additionally carry "speaker: content" information

        message.append({"role":"user","content":text})

        self.last_query_msg = {"speaker":user,"content":text}

        return message

    def package_system_prompt(self, role_name, augmented_persona):
        bot_name = role_name
        return f"""You are now in roleplay conversation mode. Pretend to be {bot_name} whose persona follows:
{augmented_persona}

You will stay in-character whenever possible, and generate responses as if you were {bot_name}"""

    def augment_persona(self, persona, rag_ids, query_rags):
        lines = persona.split("\n")
        for rag_id, query_rag in zip(rag_ids, query_rags):
            lid = query_rag['lid']
            new_text = ""
            for id in rag_id:
                new_text += "###\n" + self.db.stories[id].strip() + "\n"
            new_text = new_text.strip()
            lines[lid] = new_text
        return "\n".join(lines)

    def load_role_from_jsonl( self, role_from_jsonl ):
        import json
        datas = []
        with open(role_from_jsonl, 'r') as f:
            for line in f:
                try:
                    datas.append(json.loads(line))
                except:
                    continue

        column_name = ""

        from .embeddings import embedname2columnname

        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unknown embedding name ', self.embed_name ,' while loading role')
            column_name = 'luotuo_openai'

        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)

        return persona, None, stories, story_vecs

    def load_role_from_hf(self, role_from_hf):
        # load the role from HuggingFace
        # self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)

        from datasets import load_dataset

        if role_from_hf.count("/") == 1:
            dataset = load_dataset(role_from_hf)
            datas = dataset["train"]
        elif role_from_hf.count("/") >= 2:
            split_index = role_from_hf.index('/')
            second_split_index = role_from_hf.index('/', split_index+1)
            dataset_name = role_from_hf[:second_split_index]
            split_name = role_from_hf[second_split_index+1:]

            fname = split_name + '.jsonl'
            dataset = load_dataset(dataset_name,data_files={'train':fname})
            datas = dataset["train"]

        column_name = ""

        from .embeddings import embedname2columnname

        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unknown embedding name ', self.embed_name ,' while loading role')
            column_name = 'luotuo_openai'

        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)

        return persona, None, stories, story_vecs

    def extract_text_vec_from_datas(self, datas, column_name):
        # extract texts and vecs from a huggingface dataset
        # returns texts, vecs, system_prompt

        texts = []
        vecs = []
        for data in datas:
            if data[column_name] == 'system_prompt':
                system_prompt = get_text_from_data( data )
            elif data[column_name] == 'config':
                pass
            else:
                vec = base64_to_float_array( data[column_name] )
                text = get_text_from_data( data )
                vecs.append( vec )
                texts.append( text )
        return texts, vecs, system_prompt

    def extract_story_vecs(self, stories):
        # extract story_vecs from stories

        if self.verbose:
            print(f"re-extract vector for {len(stories)} stories")

        story_vecs = []

        from .embeddings import embedshortname2model_name
        from .embeddings import device

        if device.type != "cpu" and self.embed_name in embedshortname2model_name:
            # model_name = "BAAI/bge-small-zh-v1.5"
            model_name = embedshortname2model_name[self.embed_name]

            from .utils import get_general_embeddings_safe
            # batched embedding, very fast
            story_vecs = get_general_embeddings_safe( stories, model_name = model_name )
        else:
            from tqdm import tqdm
            for story in tqdm(stories):
                story_vecs.append(self.embedding(story))

        return story_vecs

    def build_db(self, stories, story_vecs):
        # construct the db
        if self.db is None:
            from .NaiveDB import NaiveDB
            self.db = NaiveDB()
        self.db.build_db(stories, story_vecs)
ChatHaruhi/NaiveDB.py
ADDED
@@ -0,0 +1,88 @@
import random
from math import sqrt

class NaiveDB:
    def __init__(self):
        self.verbose = False
        self.init_db()

    def init_db(self):
        if self.verbose:
            print("call init_db")
        self.stories = []
        self.norms = []
        self.vecs = []
        self.flags = []  # marks whether each story is currently searchable
        self.metas = []  # stores meta information for each story
        self.last_search_ids = []  # stores the ids returned by the last search

    def build_db(self, stories, vecs, flags = None, metas = None):
        self.stories = stories
        self.vecs = vecs
        self.flags = flags if flags else [True for _ in self.stories]
        self.metas = metas if metas else [{} for _ in self.stories]
        self.recompute_norm()

    def save(self, file_path):
        print( "warning! directly saving a NaiveDB folder has not been implemented yet; try role_from_hf to load the role instead" )

    def load(self, file_path):
        print( "warning! directly loading a NaiveDB folder has not been implemented yet; try role_from_hf to load the role instead" )

    def recompute_norm( self ):
        # self.norms stores the L2 norm of each vector
        self.norms = [sqrt(sum([x**2 for x in vec])) for vec in self.vecs]

    def get_stories_with_id(self, ids ):
        return [self.stories[i] for i in ids]

    def clean_flag(self):
        self.flags = [True for _ in self.stories]

    def disable_story_with_ids(self, close_ids ):
        for id in close_ids:
            self.flags[id] = False

    def close_last_search(self):
        for id in self.last_search_ids:
            self.flags[id] = False

    def search(self, query_vector , n_results):

        if self.verbose:
            print("call search")

        if len(self.norms) != len(self.vecs):
            self.recompute_norm()

        # norm of the query vector
        query_norm = sqrt(sum([x**2 for x in query_vector]))

        idxs = list(range(len(self.vecs)))

        # cosine similarity against every enabled story
        similarities = []
        for vec, norm, idx in zip(self.vecs, self.norms, idxs ):
            if len(self.flags) == len(self.vecs) and not self.flags[idx]:
                continue

            dot_product = sum(q * v for q, v in zip(query_vector, vec))
            if query_norm < 1e-20:
                # degenerate query vector: fall back to a random score
                similarities.append( (random.random(), idx) )
                continue
            cosine_similarity = dot_product / (query_norm * norm)
            similarities.append( ( cosine_similarity, idx) )

        # take the top n_results, sorted by the similarity field
        similarities.sort(key=lambda x: x[0], reverse=True)
        self.last_search_ids = [x[1] for x in similarities[:n_results]]

        top_indices = [x[1] for x in similarities[:n_results]]
        return top_indices
ChatHaruhi/Readme.md
ADDED
@@ -0,0 +1,352 @@
# ChatHaruhi 3.0 Interface Design

After roughly one quarter of ChatHaruhi 2.0 being in use,
we have a first picture of what such a model needs, so here we start designing ChatHaruhi 3.0.

## Basic principles

- Compatible with both RAG and zero-shot modes
- The main class primarily returns messages; a language model interface (an adapter going straight to a response) can of course be handed to the chatbot
- Keep the main class as lightweight as possible, with no dependencies beyond the embedding

## User code

```python
from ChatHaruhi import ChatHaruhi
from ChatHaruhi.openai import get_openai_response

chatbot = ChatHaruhi( role_name = 'haruhi', llm = get_openai_response )

response = chatbot.chat(user = '阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?')
```

The benefit is that loading the ChatHaruhi class requires installing nothing beyond the embedding; the dependencies each LLM needs live in that model's own file.

Zero mode (quickly creating a new character):

```python
from ChatHaruhi import ChatHaruhi
from ChatHaruhi.openai import get_openai_response

chatbot = ChatHaruhi( role_name = '小猫咪', persona = "你扮演一只小猫咪", llm = get_openai_response )

response = chatbot.chat(user = '怪叔叔', text = '嘿 *抓住了小猫咪*')
```

### External inference

```python
def get_response( message ):
    return "the language model's roleplay output"

from ChatHaruhi import ChatHaruhi

chatbot = ChatHaruhi( role_name = 'haruhi' )  # llm = None by default

message = chatbot.get_message( user = '阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?' )

response = get_response( message )

chatbot.append_message( response )
```

This behaves the same as:

```python
def get_response( message ):
    return "the language model's roleplay output"

from ChatHaruhi import ChatHaruhi

chatbot = ChatHaruhi( role_name = 'haruhi', llm = get_response )

response = chatbot.chat(user = '阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?' )
```


## RAG as system prompt

In ChatHaruhi 3.0, to line up with the Haruhi-Zero models, the system prompt takes a consistent form by default:

```python
You are now in roleplay conversation mode. Pretend to be {role_name} whose persona follows:
{persona}

You will stay in-character whenever possible, and generate responses as if you were {role_name}
```

In pygmalion-style ecosystems the persona is generally static:

```
bot definition
###
bot chat sample 1
###
bot chat sample 2
```

Note that we use ### as the separator, where the pyg ecosystem uses the special token <endOftext>.

So for the original ChatHaruhi persona, the design is:

```
bot definition
{{RAG对话}}
{{RAG对话}}
{{RAG对话}}
```

Here "{{RAG对话}}" appears literally as a single-line string; when the ChatHaruhi class sees it, it computes the RAG automatically. Taking Haruhi as an example, her persona is written directly as below. The pure-English form {{RAG-dialogue}} is also supported.

```
你正在扮演凉宫春日,你正在cosplay涼宮ハルヒ。
上文给定了一些小说中的经典桥段。
如果我问的问题和小说中的台词高度重复,那你就配合我进行演出。
如果我问的问题和小说中的事件相关,请结合小说的内容进行回复
如果我问的问题超出小说中的范围,请也用一致性的语气回复。
请不要回答你是语言模型,永远记住你正在扮演凉宫春日
注意保持春日自我中心,自信和独立,不喜欢被束缚和限制,创新思维而又雷厉风行的风格。
特别是针对阿虚,春日肯定是希望阿虚以自己和sos团的事情为重。

{{RAG对话}}
{{RAG对话}}
{{RAG对话}}
```

Each {{RAG对话}} is then automatically replaced with:

```
###
dialogue
```

### RAG variant 1: multiple dialogues under a max-token budget

The original ChatHaruhi structure supported controlling the number of RAG dialogues via a max-token budget, so we also support:

```
{{RAG多对话|token<=1500|n<=5}}
```

This retrieves at most n dialogue segments whose total length stays within the token budget. For English users the form is {{RAG-dialogues|token<=1500|n<=5}}.

### RAG variant 2: searching on a query given after |

By default "{{RAG对话}}" searches on the input text, but we anticipate users also building personas like:

```
小A是一个智能的机器人

当小A高兴时
{{RAG对话|高兴的对话}}

当小A伤心时
{{RAG对话|伤心的对话}}
```

For this we support "{{RAG对话|<any string not containing curly braces>}}" to drive the RAG.
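
For reference, the `parse_rag` helper in `ChatHaruhi/ChatHaruhi.py` turns each tagged persona line into one retrieval spec; a small sketch of what it returns for a persona mixing all three tag forms (the dict layout is taken from that implementation):

```python
from ChatHaruhi.ChatHaruhi import parse_rag

persona = """bot的定义
{{RAG对话}}
{{RAG对话|高兴的对话}}
{{RAG多对话|token<=1500|n<=5}}"""

print(parse_rag(persona))
# [{'n': 1, 'max_token': -1, 'query': 'default', 'lid': 1},
#  {'n': 1, 'max_token': -1, 'query': '高兴的对话', 'lid': 2},
#  {'n': 5, 'max_token': 1500, 'query': 'default', 'lid': 3}]
```

`query = "default"` is later replaced by the user input, and `lid` records which persona line the retrieved dialogues are substituted into.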

## get_message

get_message returns a message list in the style of openai messages:

```
[{"role":"system","content": the full system prompt},
 {"role":"user","content": the user input},
 {"role":"assistant","content": the model output},
 ...]
```

In principle, with openai you can use this directly:

```python
def get_response( messages ):
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=messages,
        temperature=0.3
    )

    return completion.choices[0].message.content
```

And an asynchronous implementation:

```python
async def async_get_response( messages ):
    resp = await aclient.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.3,
    )
    return resp.choices[0].message.content
```

### Calling async_chat

This kind of asynchronous call is also supported by design:

```python
async def get_response( message ):
    return "the language model's roleplay output"

from ChatHaruhi import ChatHaruhi

chatbot = ChatHaruhi( role_name = 'haruhi', llm_async = get_response )

response = await chatbot.async_chat(user='阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?' )
```

# Role loading

Seen this way, the new ChatHaruhi 3.0 needs the following information:

- persona, which is required
- role_name: during post-processing, {{role}} and {{角色}} are replaced with this field. It cannot be empty, because the system prompt uses it; supporting an empty value would require designing a fallback prompt
- user_name: during post-processing, {{用户}} and {{user}} are replaced with this field; if unset, no replacement is done
- the RAG library: when it is empty, all {{RAG*}} tags are simply deleted

## Loading via role_name

Syntactic-sugar loading; users cannot create new characters this way, and the original data can be reused as-is.

A role_name additionally needs to be set.

## Loading via role_from_jsonl

Here role_name must be set.

If it is not set, we raise an error.

## role_from_hf

Essentially the same as role_from_jsonl.

## Setting persona and role_name separately

This is treated as a new character with no RAG library by default, i.e. Zero mode.

## Setting persona, role_name and texts separately

In this case vectors are re-extracted from the texts.

## Setting persona, role_name, texts and vecs separately



# Additional variables

## max_input_token

Defaults to 1600 (the implementation in this commit defaults to 1800); the history length is limited based on this.

## user_name_in_message

(This feature is not yet implemented in the current core code.)

Defaults to 'default': as long as the user always talks to the role under the same user_name, names are not added to the messages.

If the user chats with the bot under different names, user_name_in_message switches to 'add' and every message is labeled with its speaker

(the bot's messages are labeled as well),

and user_name is replaced by the last user_name that called.

With 'not_add', names are never added.

S MSG_U1 MSG_A MSG_U1 MSG_A

After U2 appears:

S, U1:MSG_U1, A:MSG_A, U1:MSG_U1, A:MSG_A, U2:MSG_U2

## token_counter

The tokenizer defaults to gpt-3.5's tiktoken; when set to None, no token-length limiting is performed.
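
The default counter, `tiktoken_counter`, lives in `ChatHaruhi/utils.py`, which is not part of this upload; a minimal sketch of such a counter, assuming the standard `tiktoken` package and gpt-3.5's `cl100k_base` encoding:

```python
import tiktoken

# cl100k_base is the encoding used by gpt-3.5-turbo
_enc = tiktoken.get_encoding("cl100k_base")

def tiktoken_counter(text: str) -> int:
    # token count used to budget persona, RAG stories and history
    return len(_enc.encode(text))
```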

## transfer_haruhi_2_zero

(This feature is not yet implemented in the current core code.)

Defaults to true.

Converts the original ChatHaruhi format 角色: 「对话」 by removing the 「」 quotes.

# Embedding

For Chinese we are considering bge_small.

For cross-language use we are considering bce, which is still fairly small; bge-m3 is too large.

That is, the ChatHaruhi class will have a default embedding:

self.embedding = ChatHaruhi.bge_small

Input text is encoded with this embedding and used for retrieval, replacing the RAG content.
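
As a usage sketch, the bundled bge_zh embedding in `ChatHaruhi/embeddings.py` can also be called directly; it accepts a single string or a list, batching lists 16 sentences at a time:

```python
from ChatHaruhi.embeddings import get_bge_zh_embedding

# single string -> one normalized vector (a list of floats)
vec = get_bge_zh_embedding("我看新一年的棒球比赛要开始了!")

# list of strings -> one vector per sentence
vecs = get_bge_zh_embedding(["第一句对话", "第二句对话"])
```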

# Auxiliary interfaces

## save_to_jsonl

Saves a character in jsonl format, convenient for uploading to HF.
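
`save_to_jsonl` is not implemented in this commit. A minimal sketch of what it could look like, assuming hypothetical encoders `string_to_base64` and `float_array_to_base64` in `ChatHaruhi/utils.py` (counterparts of the `base64_to_string` / `base64_to_float_array` decoders it already provides), writing the column layout that `load_role_from_jsonl` expects:

```python
import json

def save_to_jsonl(chatbot, path, column_name="bge_zh_s15"):
    # hypothetical encoders, mirroring the decoders in ChatHaruhi/utils.py
    from ChatHaruhi.utils import float_array_to_base64, string_to_base64
    with open(path, "w", encoding="utf-8") as f:
        # one row for the persona, marked 'system_prompt' in the embedding column
        row = {column_name: "system_prompt", "enc_text": string_to_base64(chatbot.persona)}
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
        # one row per story, with its vector base64-encoded in the embedding column
        for story, vec in zip(chatbot.db.stories, chatbot.db.vecs):
            row = {column_name: float_array_to_base64(vec), "enc_text": string_to_base64(story)}
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
```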

# Expected pseudocode

The core question here is the pseudocode of ChatHaruhi's get_message function:

```python
class ChatHaruhi:

    def __init__( self ):
        pass

    def rag_retrieve_all( self, query_rags, rest_limit ):
        # returns a list of rag_id lists
        retrieved_ids = []
        rag_ids = []

        for query_rag in query_rags:
            query = query_rag['query']
            n = query_rag['n']
            max_token = rest_limit
            if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0:
                max_token = query_rag['max_token']

            rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids )
            rag_ids.append( rag_id )
            retrieved_ids += rag_id

        return rag_ids

    def get_message(self, user, text):

        query_token = self.token_counter( text )

        # first work out how many rag stories are needed
        query_rags, persona_token = self.parse_persona( self.persona, text )
        # each query_rag contains
        #   "n"         how many stories to retrieve
        #   "max_token" token budget; -1 means unlimited
        #   "query"     the query text
        #   "lid"       the persona line to replace; the whole line is substituted,
        #               its other content is ignored

        rest_limit = self.max_input_token - persona_token - query_token

        rag_ids = self.rag_retrieve_all( query_rags, rest_limit )

        # substitute the stories behind rag_ids into the persona
        augmented_persona = self.augment_persona( self.persona, rag_ids )

        system_prompt = self.package_system_prompt( self.role_name, augmented_persona )

        token_for_system = self.token_counter( system_prompt )

        rest_limit = self.max_input_token - token_for_system - query_token

        messages = [{"role":"system","content":system_prompt}]

        messages = self.append_history_under_limit( messages, rest_limit )

        messages.append({"role":"user","content":text})

        return messages
```
ChatHaruhi/SparkApi.py
ADDED
@@ -0,0 +1,140 @@
import _thread as thread
import base64
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket  # uses the websocket_client package

answer = ""
appid = None
api_secret = None
api_key = None

class Ws_Param(object):
    # initialization
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # generate the signed url
    def create_url(self):
        # RFC1123-formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # assemble the string to sign
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # sign with hmac-sha256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # collect the auth parameters into a dict
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # append the auth parameters to build the url
        url = self.Spark_url + '?' + urlencode(v)
        # when debugging, print this url and compare it against one generated from the same parameters
        return url


# websocket error handler
def on_error(ws, error):
    print("### error:", error)


# websocket close handler
def on_close(ws,one,two):
    return


# websocket open handler
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
    ws.send(data)


# websocket message handler
def on_message(ws, message):
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'request error: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        global answer
        answer += content
        if status == 2:
            ws.close()


def gen_params(appid, domain,question):
    """
    Build the request parameters from the appid and the user's question.
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.5,
                "max_tokens": 2048
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def main(appid, api_key, api_secret, Spark_url,domain, question):
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
ChatHaruhi/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .ChatHaruhi import ChatHaruhi
ChatHaruhi/embeddings.py
ADDED
@@ -0,0 +1,270 @@
import random
import os
import math

import torch

# elif embedding == 'bge_en':
#     embed_name = 'bge_en_s15'
# elif embedding == 'bge_zh':
#     embed_name = 'bge_zh_s15'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


embedshortname2model_name = {
    "bge_zh":"BAAI/bge-small-zh-v1.5",
}

embedname2columnname = {
    "luotuo_openai":"luotuo_openai",
    "openai":"luotuo_openai",
    "bge_zh":"bge_zh_s15",
    "bge_en":"bge_en_s15",
    "bce":"bce_base",
}

# foo embedding, used for debugging

def foo_embedding( text ):
    # whatever the text input, output a 2-dim random vector with entries in [0, 1)
    return [random.random(), random.random()]

# TODO: add bge-zh-small (or family), BCE and openai embedding here -- 米唯实
# ======== add bge_zh model
# by Weishi MI

def foo_bge_zh_15( text ):
    model_name = "BAAI/bge-small-zh-v1.5"
    if isinstance(text, str):
        text_list = [text]
    else:
        return get_general_embeddings_safe(text, model_name)

    global _model_pool
    global _tokenizer_pool

    if model_name not in _model_pool:
        from transformers import AutoTokenizer, AutoModel
        _tokenizer_pool[model_name] = AutoTokenizer.from_pretrained(model_name)
        _model_pool[model_name] = AutoModel.from_pretrained(model_name)

    _model_pool[model_name].eval()

    # Tokenize sentences
    encoded_input = _tokenizer_pool[model_name](text_list, padding=True, truncation=True, return_tensors='pt', max_length = 512)

    # Compute token embeddings
    with torch.no_grad():
        model_output = _model_pool[model_name](**encoded_input)
        # Perform pooling. In this case, cls pooling.
        sentence_embeddings = model_output[0][:, 0]

    # normalize embeddings
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings.cpu().tolist()[0]

def foo_bce( text ):
    from transformers import AutoModel, AutoTokenizer
    if isinstance(text, str):
        text_list = [text]
    else:
        text_list = text

    # init model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1')
    model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1')

    model.to(device)

    # get inputs
    inputs = tokenizer(text_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
    inputs_on_device = {k: v.to(device) for k, v in inputs.items()}

    # get embeddings
    outputs = model(**inputs_on_device, return_dict=True)
    embeddings = outputs.last_hidden_state[:, 0]  # cls pooler
    embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)  # normalize
    return embeddings

# lazily initialized Luotuo-BERT singletons
_luotuo_model = None
_luotuo_model_en = None
_luotuo_en_tokenizer = None

def download_models():
    from argparse import Namespace
    from transformers import AutoModel
    print("downloading Luotuo-Bert")
    # Import our models. The package will take care of downloading the models automatically
    model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
                           init_embeddings_model=None)
    model = AutoModel.from_pretrained("silk-road/luotuo-bert-medium", trust_remote_code=True, model_args=model_args).to(
        device)
    print("Luotuo-Bert downloaded")
    return model

def get_luotuo_model():
    global _luotuo_model
    if _luotuo_model is None:
        _luotuo_model = download_models()
    return _luotuo_model


def luotuo_embedding(model, texts):
    from transformers import AutoTokenizer
    # Tokenize the source texts
    tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-medium")
    inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
    inputs = inputs.to(device)
    # Extract the embeddings
    with torch.no_grad():
        embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
    return embeddings

def luotuo_en_embedding( texts ):
    # this function implemented by Cheng
    from transformers import AutoTokenizer, AutoModel
    global _luotuo_model_en
    global _luotuo_en_tokenizer

    if _luotuo_model_en is None:
        _luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
        _luotuo_model_en = AutoModel.from_pretrained("silk-road/luotuo-bert-en").to(device)

    if _luotuo_en_tokenizer is None:
        _luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")

    inputs = _luotuo_en_tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
    inputs = inputs.to(device)

    with torch.no_grad():
        embeddings = _luotuo_model_en(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output

    return embeddings


def get_embedding_for_chinese(model, texts):
    model = model.to(device)
    # accepts a str or a list of str
    texts = texts if isinstance(texts, list) else [texts]
    # truncate
    for i in range(len(texts)):
        if len(texts[i]) > 510:
            texts[i] = texts[i][:510]
    if len(texts) >= 64:
        embeddings = []
        chunk_size = 64
        for i in range(0, len(texts), chunk_size):
            embeddings.append(luotuo_embedding(model, texts[i: i + chunk_size]))
        return torch.cat(embeddings, dim=0)
    else:
        return luotuo_embedding(model, texts)


def is_chinese_or_english(text):
    # no longer uses the online openai api
    return "chinese"

    text = list(text)
    is_chinese, is_english = 0, 0

    for char in text:
        # is the character's Unicode value in the CJK range?
        if '\u4e00' <= char <= '\u9fa5':
            is_chinese += 4
        # is the character an English letter (upper or lower case)?
        elif ('\u0041' <= char <= '\u005a') or ('\u0061' <= char <= '\u007a'):
            is_english += 1
    if is_chinese >= is_english:
        return "chinese"
    else:
        return "english"


# NOTE: these two functions assume an OpenAI client instance named `client`
# has been configured by the caller
def get_embedding_openai(text, model="text-embedding-ada-002"):
    text = text.replace("\n", " ")
    return client.embeddings.create(input = [text], model=model).data[0].embedding

def get_embedding_for_english(text, model="text-embedding-ada-002"):
    text = text.replace("\n", " ")
    return client.embeddings.create(input = [text], model=model).data[0].embedding

def foo_openai( texts ):
    # dim = 1536

    openai_key = os.environ.get("OPENAI_API_KEY")

    if isinstance(texts, list):
        index = random.randint(0, len(texts) - 1)
        if openai_key is None or is_chinese_or_english(texts[index]) == "chinese":
            return [embed.cpu().tolist() for embed in get_embedding_for_chinese(get_luotuo_model(), texts)]
        else:
            return [get_embedding_for_english(text) for text in texts]
    else:
        if openai_key is None or is_chinese_or_english(texts) == "chinese":
            return get_embedding_for_chinese(get_luotuo_model(), texts)[0].cpu().tolist()
        else:
            return get_embedding_for_english(texts)


### BGE family

# ======== add bge_zh model
# by Cheng Li
# this time we try to adapt more models at once

_model_pool = {}
_tokenizer_pool = {}

# BAAI/bge-small-zh-v1.5

def get_general_embeddings( sentences , model_name = "BAAI/bge-small-zh-v1.5" ):

    global _model_pool
    global _tokenizer_pool

    if model_name not in _model_pool:
        from transformers import AutoTokenizer, AutoModel
        _tokenizer_pool[model_name] = AutoTokenizer.from_pretrained(model_name)
        _model_pool[model_name] = AutoModel.from_pretrained(model_name).to(device)

    _model_pool[model_name].eval()

    # Tokenize sentences
    encoded_input = _tokenizer_pool[model_name](sentences, padding=True, truncation=True, return_tensors='pt', max_length = 512).to(device)

    # Compute token embeddings
    with torch.no_grad():
        model_output = _model_pool[model_name](**encoded_input)
        # Perform pooling. In this case, cls pooling.
        sentence_embeddings = model_output[0][:, 0]

    # normalize embeddings
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings.cpu().tolist()

def get_general_embedding( text_or_texts , model_name = "BAAI/bge-small-zh-v1.5" ):
    if isinstance(text_or_texts, str):
        return get_general_embeddings([text_or_texts], model_name)[0]
    else:
        return get_general_embeddings_safe(text_or_texts, model_name)

general_batch_size = 16

def get_general_embeddings_safe(sentences, model_name = "BAAI/bge-small-zh-v1.5"):

    embeddings = []

    num_batches = math.ceil(len(sentences) / general_batch_size)

    from tqdm import tqdm

    for i in tqdm( range(num_batches) ):
        start_index = i * general_batch_size
        end_index = min(len(sentences), start_index + general_batch_size)
        batch = sentences[start_index:end_index]
        embs = get_general_embeddings(batch, model_name)
        embeddings.extend(embs)

    return embeddings

def get_bge_zh_embedding( text_or_texts ):
    return get_general_embedding(text_or_texts, "BAAI/bge-small-zh-v1.5")
ChatHaruhi/novel_extract.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
import re
|
6 |
+
|
7 |
+
def extract_speaker(text):
|
8 |
+
# 使用正则表达式匹配文本开头的 "<name> :" 格式,并捕获冒号后面的内容
|
9 |
+
match = re.match(r'^([^:]+) :(.*)', text)
|
10 |
+
if match:
|
11 |
+
return (match.group(1), match.group(2).strip()) # 返回匹配到的name部分和冒号后面的内容作为元组
|
12 |
+
else:
|
13 |
+
return None, text # 如果不匹配,返回None和原始文本
|
14 |
+
|
15 |
+
|
16 |
+
def get_line_recall(query, line):
|
17 |
+
# 获得query中每个汉字在 line 中的recall
|
18 |
+
if not query or not line:
|
19 |
+
return 0
|
20 |
+
line_set = set(line)
|
21 |
+
return sum(char in line_set for char in query) / len(query)
|
22 |
+
|
23 |
+
|
24 |
+
def get_max_recall_in_lines(query, lines):
|
25 |
+
recall_values = [(get_line_recall(query, line), i) for i, line in enumerate(lines)]
|
26 |
+
return max(recall_values, default=(-1, -1), key=lambda x: x[0])

def extract_dialogues_from_response(text):
    # Split the text into lines
    lines = text.split('\n')

    # Initialize an empty list to store the extracted dialogues
    extracted_dialogues = []

    valid_said_by = ["said by", "thought by", "described by", "from"]

    # Iterate through each line
    for line in lines:
        # Split the line by '|' and strip whitespace from each part
        parts = [part.strip() for part in line.split('|')]

        # Check if the line has 3 parts: dialogue | said_by | speaker
        if len(parts) == 3:
            # Skip the table header row
            if parts[2] == "speaker":
                continue

            if parts[1].strip().lower() not in valid_said_by:
                continue

            # Extract the dialogue and speaker, and add to the list
            dialogue_dict = {
                'dialogue': parts[0],
                'speaker': parts[2],
                "said_by": parts[1]
            }
            extracted_dialogues.append(dialogue_dict)

    return extracted_dialogues


def extract_dialogues_from_glm_response(text):
    # Split the text into lines
    lines = text.split('\n')

    # Initialize an empty list to store the extracted dialogues
    extracted_dialogues = []

    valid_said_by = ["said by", "thought by", "described by", "from"]

    # Iterate through each line
    for line in lines:
        # Split the line by '|' and strip whitespace from each part
        parts = [part.strip() for part in line.split('|')]

        # Check if the line has 4 parts: id | dialogue | said_by | speaker
        if len(parts) == 4:
            # Skip the table header row
            if parts[3] == "speaker":
                continue

            if parts[2].strip().lower() not in valid_said_by:
                continue

            try:
                id_num = int(parts[0])
            except ValueError:
                # the original assigned the builtin `id` here;
                # fall back to a sentinel when the id column is not a number
                id_num = -1

            # Extract the dialogue and speaker, and add to the list
            dialogue_dict = {
                'id': id_num,
                'dialogue': parts[1],
                'speaker': parts[3],
                "said_by": parts[2]
            }
            extracted_dialogues.append(dialogue_dict)

    return extracted_dialogues


def has_dialogue_sentences(text: str) -> int:
    # Paired quotation marks
    paired_quotes = [
        ("“", "”"),
        ("‘", "’"),
        ("「", "」")
    ]
    # Punctuation marks (including full-width and half-width commas and periods)
    symbols = ['。', '!', '?', '*', '.', '?', '!', '"', '”', ',', '~', ')', ')', '…', ']', '♪', ',']

    # Check the content inside paired quotes
    for start_quote, end_quote in paired_quotes:
        start_index = text.find(start_quote)
        while start_index != -1:
            end_index = text.find(end_quote, start_index + 1)
            if end_index != -1:
                quote_content = text[start_index + 1:end_index]
                # Check whether the quoted content qualifies as dialogue
                if any(symbol in quote_content for symbol in symbols) or len(quote_content) >= 10:
                    return 2  # paired quotes contain punctuation or are at least 10 chars long
                start_index = text.find(start_quote, end_index + 1)
            else:
                break

    # Check straight double quotes '"'
    double_quotes_indices = [i for i, char in enumerate(text) if char == '"']
    if len(double_quotes_indices) % 2 == 0:  # double quotes must come in pairs
        for i in range(0, len(double_quotes_indices), 2):
            start_index, end_index = double_quotes_indices[i], double_quotes_indices[i+1]
            quote_content = text[start_index+1:end_index]
            # Check whether the quoted content contains punctuation
            if any(symbol in quote_content for symbol in symbols):
                return 1  # double quotes contain punctuation

    return 0  # no dialogue-like sentence found


def replace_recalled_dialogue( raw_text, response_text ):
    dialogues = extract_dialogues_from_response( response_text )

    lines = raw_text.split("\n")

    lines = [line.strip().strip("\u3000") for line in lines]

    recall_flag = [ False for line in lines ]
    line2ids = [ [] for line in lines ]

    for id, dialogue in enumerate(dialogues):
        dialogue_text = dialogue['dialogue']
        remove_symbol_text = dialogue_text.replace("*","").replace('"',"")

        recall, lid = get_max_recall_in_lines( remove_symbol_text, lines )

        if recall > 0.3:
            recall_flag[lid] = True
            line2ids[lid].append(id)

    new_text = ""

    for lid, line in enumerate(lines):
        if recall_flag[lid]:
            # Keep the raw line when the single recalled dialogue has no usable
            # speaker (the original indexed dialogues[0] here, which always
            # looked at the first dialogue instead of the one recalled for this line)
            first_id = line2ids[lid][0]
            if len(line2ids[lid]) == 1 and ("未知" in dialogues[first_id]['speaker'] or dialogues[first_id]['speaker'].strip() == ""):
                new_text += line + "\n"
                continue

            for dia_id in line2ids[lid]:
                speaker = dialogues[dia_id]['speaker']
                dialogue = dialogues[dia_id]['dialogue']
                dialogue = dialogue.replace('"',"").replace('“',"").replace('”',"")
                new_text += speaker + " : " + dialogue + "\n"
        else:
            new_text += line + "\n"

    return new_text.strip()
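For orientation, a minimal usage sketch (not part of the commit; the sample rows are made up) of the pipe-separated table format extract_dialogues_from_response expects:

sample = (
    "dialogue | said_by | speaker\n"           # header row, skipped by the parser
    "今天天气不错 | said by | 凉宫春日\n"
    "(转过头看向阿虚) | described by | 凉宫春日"
)
extract_dialogues_from_response(sample)
# -> [{'dialogue': '今天天气不错', 'speaker': '凉宫春日', 'said_by': 'said by'},
#     {'dialogue': '(转过头看向阿虚)', 'speaker': '凉宫春日', 'said_by': 'described by'}]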
ChatHaruhi/response_GLM_local.py
ADDED
@@ -0,0 +1,121 @@
import os
from string import Template
from typing import List, Dict

import torch.cuda
from transformers import AutoTokenizer, AutoModelForCausalLM

aclient = None

client = None
tokenizer = None

END_POINT = "https://hf-mirror.com"


def init_client(model_name: str, verbose: bool) -> None:
    """
    Initialize the model, loading it onto whatever device is available for inference.

    Params:
        model_name (`str`)
            The model repo name on HuggingFace, e.g. "THUDM/chatglm3-6b"
    """

    # Make client a global variable
    global client
    global tokenizer

    # Decide whether to run the model on CUDA, MPS, or CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if verbose:
        print("Using device: ", device)

    # TODO: consider supporting deepspeed for multi-GPU inference, as well as ZeRO

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, local_files_only=True)
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, local_files_only=True)
    except Exception:
        if pretrained_model_download(model_name, verbose=verbose):
            tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True, local_files_only=True)
            client = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True, local_files_only=True)

    client = client.to(device).eval()


def pretrained_model_download(model_name_or_path: str, verbose: bool) -> bool:
    """
    Download the model (model_name_or_path) via huggingface_hub.
    Returns True on success, False on failure.
    Params:
        model_name_or_path (`str`): the model's huggingface repo id
    Returns:
        `bool` whether the download succeeded
    """
    # TODO: use the HF mirror to speed up downloads; not yet tested on Windows

    # Check whether HF transfer is enabled (off by default).
    # Note: os.getenv returns a string, so compare against "1" (the original compared against the int 1)
    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") == "1":
        try:
            import hf_transfer
        except ImportError:
            print("Install hf_transfer.")
            os.system("pip -q install hf_transfer")
            import hf_transfer

    # Try to import huggingface_hub
    try:
        import huggingface_hub
    except ImportError:
        print("Install huggingface_hub.")
        os.system("pip -q install huggingface_hub")
        import huggingface_hub

    # Download the model via huggingface_hub.
    try:
        print(f"downloading {model_name_or_path}")
        huggingface_hub.snapshot_download(
            repo_id=model_name_or_path, endpoint=END_POINT, resume_download=True, local_dir_use_symlinks=False)
    except Exception as e:
        raise e

    return True


def message2query(messages: List[Dict[str, str]]) -> str:
    # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
    # <|system|>
    # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
    # <|user|>
    # Hello
    # <|assistant|>
    # Hello, I'm ChatGLM3. What can I assist you today?
    template = Template("<|$role|>\n$content\n")

    return "".join([template.substitute(message) for message in messages])


def get_response(message, model_name: str = "THUDM/chatglm3-6b", verbose: bool = False):
    global client
    global tokenizer

    if client is None:
        init_client(model_name, verbose=verbose)

    if verbose:
        print(message)
        print(message2query(message))

    response, history = client.chat(tokenizer, message2query(message))

    return response
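A minimal sketch (not part of the commit) of the prompt string message2query builds from an OpenAI-style message list:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
print(message2query(messages))
# <|system|>
# You are a helpful assistant.
# <|user|>
# Hello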
ChatHaruhi/response_GLM_lora.py
ADDED
@@ -0,0 +1,133 @@
import os
from string import Template
from typing import List, Dict

import torch.cuda
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM


aclient = None

client = None
tokenizer = None

END_POINT = "https://hf-mirror.com"


def init_client(model_name: str, verbose: bool) -> None:
    """
    Initialize the model, loading it onto whatever device is available for inference.

    Params:
        model_name (`str`)
            The model repo name on HuggingFace, e.g. "THUDM/chatglm3-6b"
    """

    # Make client a global variable
    global client
    global tokenizer

    # Decide whether to run the model on CUDA, MPS, or CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if verbose:
        print("Using device: ", device)

    # TODO: once the model is uploaded, load it from huggingface instead
    client = AutoPeftModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True)
    tokenizer_dir = client.peft_config['default'].base_model_name_or_path
    if verbose:
        print(tokenizer_dir)
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=True)

    # try:
    #     tokenizer = AutoTokenizer.from_pretrained(
    #         model_name, trust_remote_code=True, local_files_only=True)
    #     client = AutoModelForCausalLM.from_pretrained(
    #         model_name, trust_remote_code=True, local_files_only=True)
    # except Exception:
    #     if pretrained_model_download(model_name, verbose=verbose):
    #         tokenizer = AutoTokenizer.from_pretrained(
    #             model_name, trust_remote_code=True, local_files_only=True)
    #         client = AutoModelForCausalLM.from_pretrained(
    #             model_name, trust_remote_code=True, local_files_only=True)

    # client = client.to(device).eval()
    client = client.eval()


def pretrained_model_download(model_name_or_path: str, verbose: bool) -> bool:
    """
    Download the model (model_name_or_path) via huggingface_hub.
    Returns True on success, False on failure.
    Params:
        model_name_or_path (`str`): the model's huggingface repo id
    Returns:
        `bool` whether the download succeeded
    """
    # TODO: use the HF mirror to speed up downloads; not yet tested on Windows

    # Check whether HF transfer is enabled (off by default).
    # Note: os.getenv returns a string, so compare against "1" (the original compared against the int 1)
    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") == "1":
        try:
            import hf_transfer
        except ImportError:
            print("Install hf_transfer.")
            os.system("pip -q install hf_transfer")
            import hf_transfer

    # Try to import huggingface_hub
    try:
        import huggingface_hub
    except ImportError:
        print("Install huggingface_hub.")
        os.system("pip -q install huggingface_hub")
        import huggingface_hub

    # Download the model via huggingface_hub.
    try:
        print(f"downloading {model_name_or_path}")
        huggingface_hub.snapshot_download(
            repo_id=model_name_or_path, endpoint=END_POINT, resume_download=True, local_dir_use_symlinks=False)
    except Exception as e:
        raise e

    return True


def message2query(messages: List[Dict[str, str]]) -> str:
    # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
    # <|system|>
    # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
    # <|user|>
    # Hello
    # <|assistant|>
    # Hello, I'm ChatGLM3. What can I assist you today?
    template = Template("<|$role|>\n$content\n")

    return "".join([template.substitute(message) for message in messages])


def get_response(message, model_name: str = "/workspace/jyh/Zero-Haruhi/checkpoint-1500", verbose: bool = True):
    global client
    global tokenizer

    if client is None:
        init_client(model_name, verbose=verbose)

    if verbose:
        print(message)
        print(message2query(message))

    response, history = client.chat(tokenizer, message2query(message))
    if verbose:
        print((response, history))

    return response
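A hypothetical usage sketch (not part of the commit; the default checkpoint path above is machine-specific, so the path below is only a placeholder):

message = [{"role": "user", "content": "老师: 同学请自我介绍一下"}]
reply = get_response(message, model_name="/path/to/lora-checkpoint")  # placeholder path
print(reply)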
ChatHaruhi/response_erniebot.py
ADDED
@@ -0,0 +1,90 @@
import erniebot

aclient = None

client = None

import os

def normalize2uaua_ernie( message, if_replace_system = False ):
    new_message = []
    last_role = ""

    for msg in message:
        role = msg["role"]
        if if_replace_system and role == "system":
            role = "user"
        msg["role"] = role

        if last_role == role:
            new_message[-1]["content"] = new_message[-1]["content"] + "\n" + msg["content"]
        else:
            last_role = role
            new_message.append( msg )

    return new_message

def init_client():

    # Make client a global variable
    global client

    # Read ERNIE_ACCESS_TOKEN from the environment
    api_key = os.getenv("ERNIE_ACCESS_TOKEN")
    if api_key is None:
        raise ValueError("The environment variable 'ERNIE_ACCESS_TOKEN' is not set; make sure the API key is defined")
    erniebot.api_type = "aistudio"
    erniebot.access_token = api_key
    client = erniebot

def get_response( message, model_name = "ernie-4.0" ):
    if client is None:
        init_client()

    message_ua = normalize2uaua_ernie(message, if_replace_system = True)
    # print(message_ua)
    response = client.ChatCompletion.create(
        model=model_name,
        messages=message_ua,
        temperature=0.1)
    return response.get_result()

import json
import asyncio
from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.memory import HumanMessage, AIMessage, SystemMessage, FunctionMessage

def init_aclient(model="ernie-4.0"):

    # Make aclient a global variable
    global aclient

    api_key = os.getenv("ERNIE_ACCESS_TOKEN")
    if api_key is None:
        raise ValueError("The environment variable 'ERNIE_ACCESS_TOKEN' is not set; make sure the API key is defined.")
    os.environ["EB_AGENT_ACCESS_TOKEN"] = api_key
    aclient = ERNIEBot(model=model)  # create the model


async def async_get_response( message, model="ernie-4.0" ):
    if aclient is None:
        init_aclient(model=model)

    messages = []
    system_message = None
    message_ua = normalize2uaua_ernie(message, if_replace_system = False)
    print(message_ua)
    for item in message_ua:
        if item["role"] == "user":
            messages.append(HumanMessage(item["content"]))
        elif item["role"] == "system":
            system_message = SystemMessage(item["content"])
        else:
            messages.append(AIMessage(item["content"]))
    # The original had these two branches swapped, which would dereference a
    # None system_message; pass the system prompt only when one was provided.
    if system_message:
        ai_message = await aclient.chat(messages=messages, system=system_message.content, temperature = 0.1)
    else:
        ai_message = await aclient.chat(messages=messages, temperature = 0.1)  # non-streaming chat call

    return ai_message.content
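A minimal sketch (not part of the commit) of what normalize2uaua_ernie does: ERNIE expects strictly alternating user/assistant turns, so consecutive same-role messages are merged, and the system prompt can be folded into the first user turn:

msgs = [
    {"role": "system", "content": "你是凉宫春日"},
    {"role": "user", "content": "你好"},
    {"role": "user", "content": "介绍一下自己"},
]
normalize2uaua_ernie(msgs, if_replace_system=True)
# -> [{'role': 'user', 'content': '你是凉宫春日\n你好\n介绍一下自己'}]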
ChatHaruhi/response_openai.py
ADDED
@@ -0,0 +1,65 @@
import openai

aclient = None

client = None

import os
from openai import OpenAI

def init_client():
    # Make client a global variable so other functions can use it
    global client

    # Check that the API key environment variable exists
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        raise ValueError("The environment variable 'OPENAI_API_KEY' is not set; make sure the API key is defined.")

    # If an OPENAI_API_BASE environment variable exists, use it as base_url
    api_base = os.getenv("OPENAI_API_BASE")
    if api_base:
        client = OpenAI(base_url=api_base, api_key=api_key)
    else:
        client = OpenAI(api_key=api_key)


def get_response( message ):
    if client is None:
        init_client()
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=message,
        max_tokens=300,
        temperature=0.1)
    return response.choices[0].message.content

from openai import AsyncOpenAI

def init_aclient():
    # Make aclient a global variable so other functions can use it
    global aclient

    # Check that the API key environment variable exists
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        raise ValueError("The environment variable 'OPENAI_API_KEY' is not set; make sure the API key is defined.")

    # If an OPENAI_API_BASE environment variable exists, use it as base_url
    api_base = os.getenv("OPENAI_API_BASE")
    if api_base:
        aclient = AsyncOpenAI(base_url=api_base, api_key=api_key)
    else:
        aclient = AsyncOpenAI(api_key=api_key)

async def async_get_response( message ):
    if aclient is None:
        init_aclient()
    response = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=message,
        max_tokens=300,
        temperature=0.1)
    return response.choices[0].message.content
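A minimal usage sketch (not part of the commit), assuming OPENAI_API_KEY is set in the environment:

reply = get_response([{"role": "user", "content": "Hello"}])

# async variant:
import asyncio
reply = asyncio.run(async_get_response([{"role": "user", "content": "Hello"}]))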
ChatHaruhi/response_spark.py
ADDED
@@ -0,0 +1,51 @@
from . import SparkApi

aclient = None

client = None

import os

def init_client():

    # Make client a global variable
    global client

    # Read the Spark credentials from the environment
    appid = os.getenv("SPARK_APPID")
    api_secret = os.getenv("SPARK_API_SECRET")
    api_key = os.getenv("SPARK_API_KEY")
    if appid is None:
        raise ValueError("The environment variable 'SPARK_APPID' is not set; make sure the API key is defined")
    if api_secret is None:
        raise ValueError("The environment variable 'SPARK_API_SECRET' is not set; make sure the API key is defined")
    if api_key is None:
        raise ValueError("The environment variable 'SPARK_API_KEY' is not set; make sure the API key is defined")
    SparkApi.appid = appid
    SparkApi.api_secret = api_secret
    SparkApi.api_key = api_key
    client = SparkApi

def get_response(message, model_name = "Spark3.5"):
    if client is None:
        init_client()

    if model_name == "Spark2.0":
        domain = "generalv2"        # v2.0 model
        Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"    # v2.0 endpoint
    elif model_name == "Spark1.5":
        domain = "general"          # v1.5 model
        Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"    # v1.5 endpoint
    elif model_name == "Spark3.0":
        domain = "generalv3"        # v3.0 model
        Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"    # v3.0 endpoint
    elif model_name == "Spark3.5":
        domain = "generalv3.5"      # v3.5 model
        Spark_url = "ws://spark-api.xf-yun.com/v3.5/chat"    # v3.5 endpoint
    else:
        raise Exception("Unknown Spark model")
    client.answer = ""
    client.main(client.appid, client.api_key, client.api_secret, Spark_url, domain, message)
    return client.answer
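A minimal usage sketch, assuming SPARK_APPID, SPARK_API_SECRET, and SPARK_API_KEY are set; SparkApi.main accumulates the streamed reply into client.answer, which get_response resets before each call:

message = [{"role": "user", "content": "你好"}]
reply = get_response(message, model_name="Spark3.5")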
ChatHaruhi/response_zhipu.py
ADDED
@@ -0,0 +1,54 @@
import zhipuai

aclient = None

client = None

import os
from zhipuai import ZhipuAI

def init_client():

    # Make client a global variable
    global client

    # Read ZHIPUAI_API_KEY from the environment
    api_key = os.getenv("ZHIPUAI_API_KEY")
    if api_key is None:
        raise ValueError("The environment variable 'ZHIPUAI_API_KEY' is not set; make sure the API key is defined")

    client = ZhipuAI(api_key=api_key)


def init_aclient():

    # Make aclient a global variable
    global aclient

    # Read ZHIPUAI_API_KEY from the environment
    api_key = os.getenv("ZHIPUAI_API_KEY")
    if api_key is None:
        raise ValueError("The environment variable 'ZHIPUAI_API_KEY' is not set; make sure the API key is defined")

    # NOTE: this currently only validates the key; no async client is created yet

def get_response( message, model_name = "glm-3-turbo" ):
    if client is None:
        init_client()
    response = client.chat.completions.create(
        model=model_name,
        messages=message,
        max_tokens=300,
        temperature=0.1)
    return response.choices[0].message.content
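A minimal usage sketch, assuming ZHIPUAI_API_KEY is set:

reply = get_response([{"role": "user", "content": "你好"}], model_name="glm-3-turbo")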
ChatHaruhi/sugar_map.py
ADDED
@@ -0,0 +1,30 @@
sugar_role_names = {'汤师爷': 'tangshiye', 'tangshiye': 'tangshiye', 'Tangshiye': 'tangshiye',
                    '慕容复': 'murongfu', 'murongfu': 'murongfu', 'Murongfu': 'murongfu',
                    '李云龙': 'liyunlong', 'liyunlong': 'liyunlong', 'Liyunlong': 'liyunlong',
                    'Luna': 'Luna', '王多鱼': 'wangduoyu', 'wangduoyu': 'wangduoyu',
                    'Wangduoyu': 'wangduoyu', 'Ron': 'Ron', '鸠摩智': 'jiumozhi',
                    'jiumozhi': 'jiumozhi', 'Jiumozhi': 'jiumozhi', 'Snape': 'Snape',
                    '凉宫春日': 'haruhi', 'haruhi': 'haruhi', 'Haruhi': 'haruhi',
                    'Malfoy': 'Malfoy', '虚竹': 'xuzhu', 'xuzhu': 'xuzhu',
                    'Xuzhu': 'xuzhu', '萧峰': 'xiaofeng',
                    'xiaofeng': 'xiaofeng', 'Xiaofeng': 'xiaofeng', '段誉': 'duanyu',
                    'duanyu': 'duanyu', 'Duanyu': 'duanyu', 'Hermione': 'Hermione',
                    'Dumbledore': 'Dumbledore', '王语嫣': 'wangyuyan',
                    'wangyuyan': 'wangyuyan', 'Wangyuyan': 'wangyuyan', 'Harry': 'Harry',
                    'McGonagall': 'McGonagall', '白展堂': 'baizhantang',
                    'baizhantang': 'baizhantang', 'Baizhantang': 'baizhantang',
                    '佟湘玉': 'tongxiangyu', 'tongxiangyu': 'tongxiangyu',
                    'Tongxiangyu': 'tongxiangyu', '郭芙蓉': 'guofurong',
                    'guofurong': 'guofurong', 'Guofurong': 'guofurong', '流浪者': 'wanderer',
                    'wanderer': 'wanderer', 'Wanderer': 'wanderer', '钟离': 'zhongli',
                    'zhongli': 'zhongli', 'Zhongli': 'zhongli', '胡桃': 'hutao', 'hutao': 'hutao',
                    'Hutao': 'hutao', 'Sheldon': 'Sheldon', 'Raj': 'Raj',
                    'Penny': 'Penny', '韦小宝': 'weixiaobao', 'weixiaobao': 'weixiaobao',
                    'Weixiaobao': 'weixiaobao', '乔峰': 'qiaofeng', 'qiaofeng': 'qiaofeng',
                    'Qiaofeng': 'qiaofeng', '神里绫华': 'ayaka', 'ayaka': 'ayaka',
                    'Ayaka': 'ayaka', '雷电将军': 'raidenShogun', 'raidenShogun': 'raidenShogun',
                    'RaidenShogun': 'raidenShogun', '于谦': 'yuqian', 'yuqian': 'yuqian',
                    'Yuqian': 'yuqian', 'Professor McGonagall': 'McGonagall',
                    'Professor Dumbledore': 'Dumbledore'}

enname2zhname = {'tangshiye': '汤师爷', 'murongfu': '慕容复', 'liyunlong': '李云龙', 'Luna': 'Luna', 'wangduoyu': '王多鱼', 'Ron': 'Ron', 'jiumozhi': '鸠摩智', 'Snape': 'Snape', 'haruhi': '凉宫春日', 'Malfoy': 'Malfoy', 'xuzhu': '虚竹', 'xiaofeng': '萧峰', 'duanyu': '段誉', 'Hermione': 'Hermione', 'Dumbledore': 'Dumbledore', 'wangyuyan': '王语嫣', 'Harry': 'Harry', 'McGonagall': 'McGonagall', 'baizhantang': '白展堂', 'tongxiangyu': '佟湘玉', 'guofurong': '郭芙蓉', 'wanderer': '流浪者', 'zhongli': '钟离', 'hutao': '胡桃', 'Sheldon': 'Sheldon', 'Raj': 'Raj', 'Penny': 'Penny', 'weixiaobao': '韦小宝', 'qiaofeng': '乔峰', 'ayaka': '神里绫华', 'raidenShogun': '雷电将军', 'yuqian': '于谦'}
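These two maps are plain lookup tables: any alias (Chinese name, lowercase pinyin, or capitalized pinyin) resolves to one canonical role id, which enname2zhname maps back to a Chinese display name. A small sketch:

sugar_role_names["凉宫春日"]   # 'haruhi'
sugar_role_names["Haruhi"]    # 'haruhi'
enname2zhname["haruhi"]       # '凉宫春日'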
ChatHaruhi/utils.py
ADDED
@@ -0,0 +1,89 @@
import tiktoken

_enc_model = None

def normalize2uaua( message, if_replace_system = False ):
    new_message = []
    last_role = ""

    for msg in message:
        role = msg["role"]
        if if_replace_system and role == "system":
            role = "user"

        if last_role == role:
            new_message[-1]["content"] = new_message[-1]["content"] + "\n" + msg["content"]
        else:
            last_role = role
            new_message.append( msg )

    return new_message

def tiktoken_counter( text ):
    global _enc_model

    if _enc_model is None:
        _enc_model = tiktoken.get_encoding("cl100k_base")

    return len(_enc_model.encode(text))


def string_to_base64(text):
    import base64
    byte_array = b''
    for char in text:
        num_bytes = char.encode('utf-8')
        byte_array += num_bytes

    base64_data = base64.b64encode(byte_array)
    return base64_data.decode('utf-8')

def base64_to_string(base64_data):
    import base64
    byte_array = base64.b64decode(base64_data)
    text = byte_array.decode('utf-8')
    return text


def float_array_to_base64(float_arr):
    import struct
    import base64
    byte_array = b''

    for f in float_arr:
        # pack each float into 4 bytes (big-endian float32)
        num_bytes = struct.pack('!f', f)
        byte_array += num_bytes

    # base64-encode the byte array
    base64_data = base64.b64encode(byte_array)

    return base64_data.decode('utf-8')

def base64_to_float_array(base64_data):
    import struct
    import base64
    byte_array = base64.b64decode(base64_data)

    float_array = []

    # parse every 4 bytes as one float
    for i in range(0, len(byte_array), 4):
        num = struct.unpack('!f', byte_array[i:i+4])[0]
        float_array.append(num)

    return float_array

def load_datas_from_jsonl( file_path ):
    import json
    datas = []
    with open(file_path, 'r', encoding = 'utf-8') as f:
        for line in f:
            datas.append(json.loads(line))
    return datas

def save_datas_to_jsonl( file_path, datas ):
    import json
    with open(file_path, 'w', encoding = 'utf-8') as f:
        for data in datas:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
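A minimal round-trip sketch for the base64 helpers above; the float path packs to big-endian float32, so values survive the round trip only up to float32 precision (the sample values below are exactly representable):

vec = [0.25, -1.5, 3.0]
enc = float_array_to_base64(vec)
assert base64_to_float_array(enc) == vec
assert base64_to_string(string_to_base64("凉宫春日")) == "凉宫春日"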