Germano Cavalcante committed • Commit 0576e6d • 1 Parent(s): 086be5c

Add wiki to the rag system
routers/embedding/{embeddings_manual.pkl → embeddings_manual_wiki.pkl}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3c4a71f60f1878e528b190c3c43f744611c90efdea4c2ef333962773fd2fd637
+size 19670346
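The renamed LFS blob is the regenerated embeddings cache, which now holds both corpora side by side. A minimal inspection sketch, assuming the layout written by _Data.__init__ below (a plain dict keyed by group name, each entry holding "texts" and "embeddings"); this script is not part of the commit:

# Hypothetical inspection script (not in this commit). Assumes the cache is a
# plain dict {"wiki": {...}, "manual": {...}} whose entries hold "texts"
# (a list of str) and "embeddings" (a torch tensor), as pickled by
# _Data.__init__ in routers/tool_wiki_search.py.
import pickle

with open("routers/embedding/embeddings_manual_wiki.pkl", "rb") as file:
    cache = pickle.load(file)

for group_name, group_data in cache.items():
    print(group_name,
          "texts:", len(group_data["texts"]),
          "embeddings shape:", tuple(group_data["embeddings"].shape))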
routers/tool_wiki_search.py
CHANGED
@@ -1,28 +1,60 @@
-# routers/
+# routers/tool_wiki_search.py
 
+import base64
 import os
 import pickle
 import re
 import torch
+from enum import Enum
 from typing import Dict, List
 from sentence_transformers import util
 from fastapi import APIRouter
 from fastapi.responses import PlainTextResponse
+from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
 
 try:
     from .embedding import EMBEDDING_CTX
 except:
     from embedding import EMBEDDING_CTX
 
-router = APIRouter()
-
 MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
-
-
+
+
+class Group(str, Enum):
+    wiki = "wiki"
+    manual = "manual"
+    all = "all"
 
 
 class _Data(dict):
-    cache_path = "routers/embedding/
+    cache_path = "routers/embedding/embeddings_manual_wiki.pkl"
+
+    def __init__(self):
+        if os.path.exists(self.cache_path):
+            with open(self.cache_path, 'rb') as file:
+                data = pickle.load(file)
+                self.update(data)
+                return
+
+        # Generate
+
+        print("Embedding Texts...")
+        for grp in list(Group)[:-1]:
+            self[grp.name] = {}
+
+            # Create a list to store the text files
+            texts = self.manual_get_texts_to_embed(
+            ) if grp == Group.manual else self.wiki_get_texts_to_embed()
+
+            self[grp]['texts'] = texts
+            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)
+
+        with open(self.cache_path, "wb") as file:
+            # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
+            for val in self.values():
+                val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
+
+            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
 
     @staticmethod
     def reduce_text(text):
@@ -38,24 +70,20 @@ class _Data(dict):
         return text
 
     @classmethod
-    def parse_file_recursive(cls,
-        with open(
+    def parse_file_recursive(cls, filepath):
+        with open(filepath, 'r', encoding='utf-8') as file:
             content = file.read()
 
         parsed_data = {}
 
-        if
-
-        else:
+        if filepath.endswith('index.rst'):
+            filedir = os.path.dirname(filepath)
             parts = content.split(".. toctree::")
-            body = parts[0].strip()
-
             if len(parts) > 1:
                 parsed_data["toctree"] = {}
                 for part in parts[1:]:
-                    toctree_entries = part.
-
-                    for entry in toctree_entries[1:]:
+                    toctree_entries = part.splitlines()[1:]
+                    for entry in toctree_entries:
                         entry = entry.strip()
                         if not entry:
                             continue
@@ -67,20 +95,12 @@ class _Data(dict):
                         if not entry.endswith('.rst'):
                             continue
 
-
-
-                            filedir_ = os.path.join(filedir, entry_name)
-                            filename_ = 'index.rst'
-                        else:
-                            entry_name = entry[:-4]
-                            filedir_ = filedir
-                            filename_ = entry
-
+                        entry_name = entry[:-4]  # remove '.rst'
+                        filepath_iter = os.path.join(filedir, entry)
                         parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
-
+                            filepath_iter)
 
-
-        parsed_data['body'] = body + '\n'
+        parsed_data['body'] = content
 
         return parsed_data
 
@@ -221,82 +241,122 @@ class _Data(dict):
         return result
 
     @classmethod
-    def get_texts_recursive(cls, page, path=''):
+    def get_texts_recursive(cls, page, path='index'):
         result = cls.split_into_many(page['body'], path)
 
         try:
             for key in page['toctree'].keys():
                 page_child = page['toctree'][key]
                 result.extend(cls.get_texts_recursive(
-                    page_child,
+                    page_child, path.replace('index', key)))
         except KeyError:
             pass
 
         return result
 
-
-
-
-
-
-
-
-        # Generate
+    @classmethod
+    def manual_get_texts_to_embed(cls):
+        manual = cls.parse_file_recursive(
+            os.path.join(MANUAL_DIR, 'index.rst'))
+        manual['toctree']["copyright"] = cls.parse_file_recursive(
+            os.path.join(MANUAL_DIR, 'copyright.rst'))
 
-
-        manual['toctree']["copyright"] = self.parse_file_recursive(
-            MANUAL_DIR, 'copyright.rst')
+        return cls.get_texts_recursive(manual)
 
-
-
+    @classmethod
+    def wiki_get_texts_to_embed(cls):
+        tokenizer = EMBEDDING_CTX.model.tokenizer
+        max_tokens = EMBEDDING_CTX.model.max_seq_length
 
-
-
-
+        texts = []
+        owner = "blender"
+        repo = "blender"
+        pages = gitea_wiki_pages_get(owner, repo)
+        for page_name in pages:
+            page_name_title = page_name["title"]
+            page = gitea_wiki_page_get(owner, repo, page_name_title)
+            prefix = f'/{page["sub_url"]}\n# {page_name_title}:'
+            text = base64.b64decode(page["content_base64"]).decode('utf-8')
+            text = text.replace(
+                'https://projects.blender.org/blender/blender', '')
+            tokens_prefix_len = len(tokenizer.tokenize(prefix))
+            tokens_so_far = tokens_prefix_len
+            text_so_far = prefix
+            text_parts = text.split('\n#')
+            for part in text_parts:
+                part = '\n#' + part
+                part_tokens_len = len(tokenizer.tokenize(part))
+                if tokens_so_far + part_tokens_len > max_tokens:
+                    texts.append(text_so_far)
+                    text_so_far = prefix
+                    tokens_so_far = tokens_prefix_len
+                text_so_far += part
+                tokens_so_far += part_tokens_len
+
+            if tokens_so_far != tokens_prefix_len:
+                texts.append(text_so_far)
+
+        return texts
+
+    def _sort_similarity(self, text_to_search, group: Group = Group.all, limit=4):
+        result = []
 
-
-            # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
-            self['embeddings'] = self['embeddings'].to(torch.device('cpu'))
-            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
+        query_emb = EMBEDDING_CTX.encode([text_to_search])
 
-
+        ret = {}
 
-
-
+        for grp in list(Group)[:-1]:
+            if group in {grp, Group.all}:
+                ret[grp] = util.semantic_search(
+                    query_emb, self[grp]['embeddings'], top_k=limit, score_function=util.dot_score)
 
-
-
-
+        score_best = 0.0
+        group_best = None
+        for grp, val in ret.items():
+            score_curr = val[0][0]['score']
+            if score_curr > score_best:
+                score_best = score_curr
+                group_best = grp
 
-        texts = self['texts']
-        for score in ret[0]:
+        texts = self[group_best]['texts']
+        for score in ret[group_best][0]:
             corpus_id = score['corpus_id']
             text = texts[corpus_id]
-
+            result.append(text)
 
-        return
+        return result, group_best
 
 
 G_data = _Data()
 
+router = APIRouter()
 
-@router.get("/wiki_search", response_class=PlainTextResponse)
-def wiki_search(query: str = "") -> str:
-    data = G_data._embeddings_generate()
-    texts = G_data._sort_similarity(query, 5)
 
-
+@router.get("/wiki_search", response_class=PlainTextResponse)
+def wiki_search(query: str = "", group: Group = Group.all) -> str:
+    base_url = {
+        Group.wiki: "https://projects.blender.org/blender/blender",
+        Group.manual: "https://docs.blender.org/manual/en/dev"
+    }
+    texts, group_best = G_data._sort_similarity(query, group)
+
+    result = f'BASE_URL: {base_url[group_best]}\n'
     for text in texts:
-
-
+        if group_best == Group.wiki:
+            result += f'''---
+{text}
+'''
+        else:
+            index = text.find('#')
+            result += f'''---
 {text[:index] + '.html'}
 {text[index:]}
-
 '''
     return result
 
 
 if __name__ == '__main__':
-    tests = ["Set Snap Base", "Building the Manual",
-
+    tests = ["Set Snap Base", "Building the Manual",
+             "Bisect Object", "Who are the Triagers"]
+    result = wiki_search(tests[1], Group.all)
     print(result)
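For a quick sanity check of the new group parameter, a request along these lines could be issued once the Space is running; the localhost address, port and bare mount path of the router are assumptions, not taken from this commit:

# Hypothetical local call to the new endpoint; host, port and mount path are assumed.
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({"query": "Who are the Triagers", "group": "wiki"})
with urllib.request.urlopen(f"http://localhost:8000/wiki_search?{params}") as response:
    print(response.read().decode("utf-8"))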
routers/utils_gitea.py
CHANGED
@@ -30,7 +30,7 @@ def url_json_get(url, data=None):
 def url_json_get_all_pages(url, item_filter=None, limit=50, exclude=set(), verbose=False):
     assert limit <= 50, "50 is the maximum limit of items per page"
 
-    url_for_page = f"{url}
+    url_for_page = f"{url}?limit={limit}&page="
 
     with urllib.request.urlopen(url_for_page + '1') as response:
         headers_first = response.info()
@@ -82,7 +82,6 @@ def gitea_fetch_issues(owner, repo, state='all', labels='', issue_attr_filter=No
     if since:
         query_params['since'] = since
 
-    BASE_API_URL = "https://projects.blender.org/api/v1"
     base_url = f"{BASE_API_URL}/repos/{owner}/{repo}/issues"
     encoded_query_params = urllib.parse.urlencode(query_params)
     issues_url = f"{base_url}?{encoded_query_params}"
@@ -108,3 +107,20 @@ def gitea_issues_body_updated_at_get(issues, verbose=True):
     all_results = [future.result() for future in as_completed(futures)]
 
     return all_results
+
+
+def gitea_wiki_page_get(owner, repo, page_name, verbose=True):
+    """
+    Get a wiki page.
+    """
+    encoded_page_name = urllib.parse.quote(page_name, safe='')
+    base_url = f"{BASE_API_URL}/repos/{owner}/{repo}/wiki/page/{encoded_page_name}"
+    return url_json_get(base_url)
+
+
+def gitea_wiki_pages_get(owner, repo, verbose=True):
+    """
+    Get all wiki pages.
+    """
+    base_url = f"{BASE_API_URL}/repos/{owner}/{repo}/wiki/pages"
+    return url_json_get_all_pages(base_url)
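The two new helpers wrap Gitea's wiki endpoints. As a rough standalone sketch of the same calls (the BASE_API_URL value is the one used elsewhere in utils_gitea.py; the response field names follow their use in wiki_get_texts_to_embed):

# Standalone sketch of the Gitea wiki endpoints wrapped above; not part of the commit.
import base64
import json
import urllib.parse
import urllib.request

BASE_API_URL = "https://projects.blender.org/api/v1"
owner, repo = "blender", "blender"

# List all wiki pages, then fetch the first one by its title.
with urllib.request.urlopen(f"{BASE_API_URL}/repos/{owner}/{repo}/wiki/pages") as resp:
    pages = json.load(resp)

title = pages[0]["title"]
encoded = urllib.parse.quote(title, safe='')
with urllib.request.urlopen(f"{BASE_API_URL}/repos/{owner}/{repo}/wiki/page/{encoded}") as resp:
    page = json.load(resp)

print(page["sub_url"])
print(base64.b64decode(page["content_base64"]).decode('utf-8')[:200])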