Germano Cavalcante committed on
Commit 0576e6d • 1 parent: 086be5c

Add wiki to the RAG system

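After this change, /wiki_search takes an optional group parameter ("wiki", "manual", or "all") and prefixes its results with the base URL of whichever group matched best. A hypothetical client call, assuming the FastAPI app is served locally on port 8000 (host and port are illustrative, not part of this commit):

    import urllib.parse
    import urllib.request

    # Hypothetical call to the updated endpoint; host/port are assumptions.
    params = urllib.parse.urlencode({"query": "Who are the Triagers", "group": "wiki"})
    with urllib.request.urlopen(f"http://localhost:8000/wiki_search?{params}") as response:
        print(response.read().decode("utf-8"))
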
routers/embedding/{embeddings_manual.pkl → embeddings_manual_wiki.pkl} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9ed7475fc8ffda0d9e9deb6480b7152b53657f0fe6a6140bcb60360e425e7a01
-size 18659241
+oid sha256:3c4a71f60f1878e528b190c3c43f744611c90efdea4c2ef333962773fd2fd637
+size 19670346
routers/tool_wiki_search.py CHANGED
@@ -1,28 +1,60 @@
-# routers/wiki_search.py
+# routers/tool_wiki_search.py
 
+import base64
 import os
 import pickle
 import re
 import torch
+from enum import Enum
 from typing import Dict, List
 from sentence_transformers import util
 from fastapi import APIRouter
 from fastapi.responses import PlainTextResponse
+from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
 
 try:
     from .embedding import EMBEDDING_CTX
 except:
     from embedding import EMBEDDING_CTX
 
-router = APIRouter()
-
 MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
-BASE_URL = "https://docs.blender.org/manual/en/dev"
-G_data = None
+
+
+class Group(str, Enum):
+    wiki = "wiki"
+    manual = "manual"
+    all = "all"
 
 
 class _Data(dict):
-    cache_path = "routers/embedding/embeddings_manual.pkl"
+    cache_path = "routers/embedding/embeddings_manual_wiki.pkl"
+
+    def __init__(self):
+        if os.path.exists(self.cache_path):
+            with open(self.cache_path, 'rb') as file:
+                data = pickle.load(file)
+                self.update(data)
+            return
+
+        # Generate
+
+        print("Embedding Texts...")
+        for grp in list(Group)[:-1]:
+            self[grp.name] = {}
+
+            # Create a list to store the text files
+            texts = self.manual_get_texts_to_embed(
+            ) if grp == Group.manual else self.wiki_get_texts_to_embed()
+
+            self[grp]['texts'] = texts
+            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)
+
+        with open(self.cache_path, "wb") as file:
+            # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
+            for val in self.values():
+                val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
+
+            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
 
     @staticmethod
     def reduce_text(text):
@@ -38,24 +70,20 @@ class _Data(dict):
         return text
 
     @classmethod
-    def parse_file_recursive(cls, filedir, filename):
-        with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
+    def parse_file_recursive(cls, filepath):
+        with open(filepath, 'r', encoding='utf-8') as file:
             content = file.read()
 
         parsed_data = {}
 
-        if not filename.endswith('index.rst'):
-            body = content.strip()
-        else:
+        if filepath.endswith('index.rst'):
+            filedir = os.path.dirname(filepath)
             parts = content.split(".. toctree::")
-            body = parts[0].strip()
-
             if len(parts) > 1:
                 parsed_data["toctree"] = {}
                 for part in parts[1:]:
-                    toctree_entries = part.split('\n')
-                    line = toctree_entries[0]
-                    for entry in toctree_entries[1:]:
+                    toctree_entries = part.splitlines()[1:]
+                    for entry in toctree_entries:
                         entry = entry.strip()
                         if not entry:
                             continue
@@ -67,20 +95,12 @@ class _Data(dict):
                         if not entry.endswith('.rst'):
                             continue
 
-                        if entry.endswith('/index.rst'):
-                            entry_name = entry[:-10]
-                            filedir_ = os.path.join(filedir, entry_name)
-                            filename_ = 'index.rst'
-                        else:
-                            entry_name = entry[:-4]
-                            filedir_ = filedir
-                            filename_ = entry
-
+                        entry_name = entry[:-4]  # remove '.rst'
+                        filepath_iter = os.path.join(filedir, entry)
                         parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
-                            filedir_, filename_)
+                            filepath_iter)
 
-        # The '\n' at the end of the file resolves regex patterns
-        parsed_data['body'] = body + '\n'
+        parsed_data['body'] = content
 
         return parsed_data
 
@@ -221,82 +241,122 @@ class _Data(dict):
         return result
 
     @classmethod
-    def get_texts_recursive(cls, page, path=''):
+    def get_texts_recursive(cls, page, path='index'):
         result = cls.split_into_many(page['body'], path)
 
         try:
             for key in page['toctree'].keys():
                 page_child = page['toctree'][key]
                 result.extend(cls.get_texts_recursive(
-                    page_child, f'{path}/{key}'))
+                    page_child, path.replace('index', key)))
         except KeyError:
             pass
 
         return result
 
-    def _embeddings_generate(self):
-        if os.path.exists(self.cache_path):
-            with open(self.cache_path, 'rb') as file:
-                data = pickle.load(file)
-                self.update(data)
-            return self
-
-        # Generate
-
-        manual = self.parse_file_recursive(MANUAL_DIR, 'index.rst')
-        manual['toctree']["copyright"] = self.parse_file_recursive(
-            MANUAL_DIR, 'copyright.rst')
-
-        # Create a list to store the text files
-        texts = self.get_texts_recursive(manual)
-
-        print("Embedding Texts...")
-        self['texts'] = texts
-        self['embeddings'] = EMBEDDING_CTX.encode(texts)
-
-        with open(self.cache_path, "wb") as file:
-            # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
-            self['embeddings'] = self['embeddings'].to(torch.device('cpu'))
-            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
-
-        return G_data
-
-    def _sort_similarity(self, text_to_search, limit):
-        results = []
-
-        query_emb = EMBEDDING_CTX.encode([text_to_search])
-        ret = util.semantic_search(
-            query_emb, self['embeddings'], top_k=limit, score_function=util.dot_score)
-
-        texts = self['texts']
-        for score in ret[0]:
+    @classmethod
+    def manual_get_texts_to_embed(cls):
+        manual = cls.parse_file_recursive(
+            os.path.join(MANUAL_DIR, 'index.rst'))
+        manual['toctree']["copyright"] = cls.parse_file_recursive(
+            os.path.join(MANUAL_DIR, 'copyright.rst'))
+
+        return cls.get_texts_recursive(manual)
+
+    @classmethod
+    def wiki_get_texts_to_embed(cls):
+        tokenizer = EMBEDDING_CTX.model.tokenizer
+        max_tokens = EMBEDDING_CTX.model.max_seq_length
+
+        texts = []
+        owner = "blender"
+        repo = "blender"
+        pages = gitea_wiki_pages_get(owner, repo)
+        for page_name in pages:
+            page_name_title = page_name["title"]
+            page = gitea_wiki_page_get(owner, repo, page_name_title)
+            prefix = f'/{page["sub_url"]}\n# {page_name_title}:'
+            text = base64.b64decode(page["content_base64"]).decode('utf-8')
+            text = text.replace(
+                'https://projects.blender.org/blender/blender', '')
+            tokens_prefix_len = len(tokenizer.tokenize(prefix))
+            tokens_so_far = tokens_prefix_len
+            text_so_far = prefix
+            text_parts = text.split('\n#')
+            for part in text_parts:
+                part = '\n#' + part
+                part_tokens_len = len(tokenizer.tokenize(part))
+                if tokens_so_far + part_tokens_len > max_tokens:
+                    texts.append(text_so_far)
+                    text_so_far = prefix
+                    tokens_so_far = tokens_prefix_len
+                text_so_far += part
+                tokens_so_far += part_tokens_len
+
+            if tokens_so_far != tokens_prefix_len:
+                texts.append(text_so_far)
+
+        return texts
+
+    def _sort_similarity(self, text_to_search, group: Group = Group.all, limit=4):
+        result = []
+
+        query_emb = EMBEDDING_CTX.encode([text_to_search])
+
+        ret = {}
+
+        for grp in list(Group)[:-1]:
+            if group in {grp, Group.all}:
+                ret[grp] = util.semantic_search(
+                    query_emb, self[grp]['embeddings'], top_k=limit, score_function=util.dot_score)
+
+        score_best = 0.0
+        group_best = None
+        for grp, val in ret.items():
+            score_curr = val[0][0]['score']
+            if score_curr > score_best:
+                score_best = score_curr
+                group_best = grp
+
+        texts = self[group_best]['texts']
+        for score in ret[group_best][0]:
             corpus_id = score['corpus_id']
             text = texts[corpus_id]
-            results.append(text)
+            result.append(text)
 
-        return results
+        return result, group_best
 
 
 G_data = _Data()
 
+router = APIRouter()
+
 
-@router.get("/wiki_search", response_class=PlainTextResponse)
-def wiki_search(query: str = "") -> str:
-    data = G_data._embeddings_generate()
-    texts = G_data._sort_similarity(query, 5)
+@router.get("/wiki_search", response_class=PlainTextResponse)
+def wiki_search(query: str = "", group: Group = Group.all) -> str:
+    base_url = {
+        Group.wiki: "https://projects.blender.org/blender/blender",
+        Group.manual: "https://docs.blender.org/manual/en/dev"
+    }
+    texts, group_best = G_data._sort_similarity(query, group)
 
-    result = f'BASE_URL: {BASE_URL}\n'
+    result = f'BASE_URL: {base_url[group_best]}\n'
    for text in texts:
-        index = text.find('#')
-        result += f'''---
+        if group_best == Group.wiki:
+            result += f'''---
+{text}
+'''
+        else:
+            index = text.find('#')
+            result += f'''---
 {text[:index] + '.html'}
 {text[index:]}
-
 '''
     return result
 
 
 if __name__ == '__main__':
-    tests = ["Set Snap Base", "Building the Manual", "Bisect Object"]
-    result = wiki_search(tests[0])
+    tests = ["Set Snap Base", "Building the Manual",
+             "Bisect Object", "Who are the Triagers"]
+    result = wiki_search(tests[1], Group.all)
     print(result)
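The new wiki_get_texts_to_embed() packs each wiki page into embedding-sized pieces: it splits the page body on '\n#' headings, accumulates parts until the model's max_seq_length token budget would overflow, flushes the chunk, and starts the next one with the same page-URL prefix so every chunk can be traced back to its page. A minimal standalone sketch of that accumulate-and-flush pattern (the whitespace tokenize() and the max_tokens value are stand-ins for the SentenceTransformer tokenizer and its sequence limit):

    def tokenize(text):
        # Assumption: whitespace tokens stand in for the model's subword tokens.
        return text.split()

    def chunk_sections(prefix, sections, max_tokens=384):
        # Accumulate heading-delimited parts; flush when the budget would overflow.
        chunks = []
        tokens_prefix = len(tokenize(prefix))
        text_so_far, tokens_so_far = prefix, tokens_prefix
        for part in sections:
            part_tokens = len(tokenize(part))
            if tokens_so_far + part_tokens > max_tokens:
                chunks.append(text_so_far)
                text_so_far, tokens_so_far = prefix, tokens_prefix
            text_so_far += part
            tokens_so_far += part_tokens
        if tokens_so_far != tokens_prefix:
            chunks.append(text_so_far)  # flush the non-empty remainder
        return chunks

Repeating the prefix costs a few tokens per chunk but keeps every retrieved chunk self-identifying, which is what lets wiki_search print a usable URL for each hit.
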
routers/utils_gitea.py CHANGED
@@ -30,7 +30,7 @@ def url_json_get(url, data=None):
 def url_json_get_all_pages(url, item_filter=None, limit=50, exclude=set(), verbose=False):
     assert limit <= 50, "50 is the maximum limit of items per page"
 
-    url_for_page = f"{url}&limit={limit}&page="
+    url_for_page = f"{url}?limit={limit}&page="
 
     with urllib.request.urlopen(url_for_page + '1') as response:
         headers_first = response.info()
@@ -82,7 +82,6 @@ def gitea_fetch_issues(owner, repo, state='all', labels='', issue_attr_filter=No
     if since:
         query_params['since'] = since
 
-    BASE_API_URL = "https://projects.blender.org/api/v1"
     base_url = f"{BASE_API_URL}/repos/{owner}/{repo}/issues"
     encoded_query_params = urllib.parse.urlencode(query_params)
     issues_url = f"{base_url}?{encoded_query_params}"
@@ -108,3 +107,20 @@ def gitea_issues_body_updated_at_get(issues, verbose=True):
     all_results = [future.result() for future in as_completed(futures)]
 
     return all_results
+
+
+def gitea_wiki_page_get(owner, repo, page_name, verbose=True):
+    """
+    Get a wiki page.
+    """
+    encoded_page_name = urllib.parse.quote(page_name, safe='')
+    base_url = f"{BASE_API_URL}/repos/{owner}/{repo}/wiki/page/{encoded_page_name}"
+    return url_json_get(base_url)
+
+
+def gitea_wiki_pages_get(owner, repo, verbose=True):
+    """
+    Get all wiki pages.
+    """
+    base_url = f"{BASE_API_URL}/repos/{owner}/{repo}/wiki/pages"
+    return url_json_get_all_pages(base_url)
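
Both helpers wrap Gitea's wiki API (GET /repos/{owner}/{repo}/wiki/pages and GET /repos/{owner}/{repo}/wiki/page/{pageName}); page bodies arrive base64-encoded in content_base64, which is why tool_wiki_search.py decodes them before chunking. A hypothetical end-to-end use of the new helpers (needs network access to projects.blender.org):

    import base64
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get

    pages = gitea_wiki_pages_get("blender", "blender")   # list of page stubs with "title"
    page = gitea_wiki_page_get("blender", "blender", pages[0]["title"])
    body = base64.b64decode(page["content_base64"]).decode("utf-8")
    print(page["sub_url"], body[:80])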