Germano Cavalcante committed
Commit
1b8973e
1 Parent(s): af4d94e

Wiki Search: Updates


- remove wiki
- Deduplicate code
- rename docs to dev_docs

routers/embedding/__init__.py CHANGED
@@ -1,10 +1,12 @@
# routers/embedding/__init__.py

import os
+ import re
import sys
import threading
import torch
from sentence_transformers import SentenceTransformer, util
+ from typing import Dict, List, Tuple, Set, LiteralString


class EmbeddingContext:
@@ -111,4 +113,114 @@ class EmbeddingContext:
        return tokens


+ class SplitDocs:
+     def split_in_topics(self,
+                         filedir: LiteralString = None,
+                         *,
+                         pattern_filename=r'(?<!navigation)\.(md|rst)',
+                         pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
+                         patterns_titles=(
+                             r'^# (.+)', r'^## (.+)', r'^### (.+)'),
+                         ) -> List[Tuple[str, str]]:
+         def matches_pattern(filename):
+             return re.search(pattern_filename, filename) is not None
+
+         def split_patterns_recursive(patterns, text, index=-1):
+             sections = re.split(patterns[0], text, flags=re.MULTILINE)
+             for i, section in enumerate(sections):
+                 if not section.strip():
+                     continue
+                 is_match = bool(i & 1)
+                 if is_match:
+                     yield (index, section)
+                 elif len(patterns) > 1:
+                     for j, section_j in split_patterns_recursive(patterns[1:], section, index + 1):
+                         yield (j, section_j)
+                 else:
+                     yield (-1, section)
+
+         for root, _, files in os.walk(filedir):
+             for name in files:
+                 if not matches_pattern(name):
+                     continue
+
+                 full_path = os.path.join(root, name)
+                 with open(full_path, 'r', encoding='utf-8') as file:
+                     content = file.read()
+
+                 if pattern_content_sub:
+                     content = re.sub(pattern_content_sub, '', content)
+
+                 rel_path = full_path.replace(filedir, '').replace('\\', '/')
+
+                 # Protect code parts
+                 patterns = (r'(```[\s\S]+?```)', *patterns_titles)
+
+                 last_titles = []
+                 last_titles_index = []
+                 content_accum = ''
+                 for i, section in split_patterns_recursive(patterns, content):
+                     if i < 0:
+                         content_accum += section
+                         continue
+                     if content_accum:
+                         yield rel_path, last_titles, content_accum
+                         content_accum = ''
+                     if not last_titles_index or i > last_titles_index[-1]:
+                         last_titles_index.append(i)
+                         last_titles.append(section)
+                         continue
+                     while len(last_titles_index) > 1 and i < last_titles_index[-1]:
+                         last_titles_index.pop()
+                         last_titles.pop()
+                     # Replace
+                     last_titles_index[-1] = i
+                     last_titles[-1] = section
+                 if content_accum or i != -1:
+                     yield rel_path, last_titles, content_accum
+
+     def reduce_text(_self, text):
+         text = re.sub(r'^\n+', '', text)  # Strip
+         text = re.sub(r'<.*?>', '', text)  # Remove HTML tags
+         text = re.sub(r':\S*: ', '', text)  # Remove [:...:] patterns
+         text = re.sub(r'\s*\n+', '\n', text)
+         return text
+
+     def embedding_header(_self, rel_path, titles):
+         return f"{rel_path}\n# {' | '.join(titles)}\n\n"
+
+     def split_for_embedding(self,
+                             filedir: LiteralString = None,
+                             *,
+                             pattern_filename=r'(?<!navigation)\.(md|rst)',
+                             pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
+                             patterns_titles=(
+                                 r'^# (.+)', r'^## (.+)', r'^### (.+)'),
+                             ):
+         tokenizer = EMBEDDING_CTX.model.tokenizer
+         max_tokens = EMBEDDING_CTX.model.max_seq_length
+         texts = []
+
+         for rel_path, titles, content in self.split_in_topics(
+                 filedir, pattern_filename=pattern_filename, pattern_content_sub=pattern_content_sub, patterns_titles=patterns_titles):
+             header = self.embedding_header(rel_path, titles)
+             tokens_pre_len = len(tokenizer.tokenize(header))
+             tokens_so_far = tokens_pre_len
+             text_so_far = header
+             for part in self.reduce_text(content).splitlines():
+                 part += '\n'
+                 part_tokens_len = len(tokenizer.tokenize(part))
+                 if tokens_so_far + part_tokens_len > max_tokens:
+                     texts.append(text_so_far)
+                     text_so_far = header
+                     tokens_so_far = tokens_pre_len
+                 text_so_far += part
+                 tokens_so_far += part_tokens_len
+
+             if tokens_so_far != tokens_pre_len:
+                 texts.append(text_so_far)
+
+         return texts
+
+
EMBEDDING_CTX = EmbeddingContext()
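
For reference, a minimal sketch of how the new SplitDocs splitter might be exercised on a local docs checkout (the import path and directory are assumptions taken from the cache paths and the DOCS_DIR constant used elsewhere in this commit; loading the module also loads the embedding model, since split_for_embedding reads the tokenizer from EMBEDDING_CTX):

# Hypothetical local check of SplitDocs; not part of this commit.
from routers.embedding import SplitDocs

chunks = SplitDocs().split_for_embedding("D:/BlenderDev/blender-developer-docs/docs")
print(len(chunks), "chunks")
# Each chunk starts with the embedding header: "<rel_path>\n# <title | subtitle>\n\n"
print(chunks[0][:200])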
routers/embedding/{embeddings_manual_wiki.pkl → embeddings_dev_docs.pkl} RENAMED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
- oid sha256:e12f37bd8b14982fa070b5db9d9c468c0bc858fc65ab136cc714bf8fcce48d69
- size 31873812
+ oid sha256:e94dbc62cda6258367836eaec82dfda7f35183b1debdc980541e2ceb22d52637
+ size 15328541
routers/embedding/embeddings_manual.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4f9759847d9fb0948eb2f550a3e5512b7a5bbdf8fd70118bb75fe670395ee7c
+ size 22522382
routers/tool_wiki_search.py CHANGED
@@ -6,133 +6,43 @@ import pickle
import re
import torch
from enum import Enum
- from typing import Dict, List, Tuple, Set
- from sentence_transformers import util
- from fastapi import APIRouter
+ from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
+ from heapq import nlargest
+ from sentence_transformers import util
+ from typing import Dict, List, Tuple, Set, LiteralString

try:
-     from .embedding import EMBEDDING_CTX
+     from .embedding import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except:
-     from embedding import EMBEDDING_CTX
+     from embedding import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get

- MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
+
+ MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"


class Group(str, Enum):
-     docs = "docs"
-     wiki = "wiki"
+     dev_docs = "dev_docs"
+     # wiki = "wiki"
    manual = "manual"
-     all = "all"
-
-
- class Split:
-     filedir = None
-     filetype = '.md'
-
-     def __init__(self, filedir=None, filetype='.md'):
-         self.filedir = filedir
-         self.filetype = filetype
-
-     def split_in_topics(self) -> List[Tuple[str, str]]:
-         for root, _dirs, files in os.walk(self.filedir):
-             for name in files:
-                 if not name.endswith(self.filetype) or name == 'navigation.md':
-                     continue
-
-                 full_path = os.path.join(root, name)
-                 with open(full_path, 'r', encoding='utf-8') as file:
-                     content = file.read()
-
-                 prefix = full_path.replace(self.filedir, '')
-                 prefix = re.sub(r'(index)?.md', '', prefix)
-                 prefix = prefix.replace('\\', '/')
-
-                 # Protect code parts
-                 parts = ['']
-                 is_first = True
-                 is_in_code_block = False
-                 for line in content.splitlines():
-                     if not line:
-                         continue
-                     line += '\n'
-                     is_in_code_block = is_in_code_block != line.strip().startswith('```')
-                     if not is_in_code_block and line.startswith('## '):
-                         if not is_first:
-                             parts.append(line)
-                             continue
-
-                     is_first = False
-                     parts[-1] += line
-
-                 title_main = ''
-                 for topic in parts:
-                     topic = topic.strip()
-                     if not topic or topic.startswith('---\nhide'):
-                         continue
-
-                     try:
-                         title, body = topic.split('\n', 1)
-                     except ValueError:
-                         # ignore non content
-                         continue
-
-                     if not title_main:
-                         title_main = title
-                     else:
-                         title = title_main + ' | ' + title
-
-                     yield (prefix + '\n' + title, body)
-
-     def reduce_text(_self, text):
-         text = re.sub(r'<.*?>', '', text)  # Remove HTML tags
-         text = re.sub(r':\S*: ', '', text)  # Remove [:...:] patterns
-         text = re.sub(r'(index)?.md', '', text)  # Remove .md
-         return re.sub(r'(\s*\n\s*)+', '\n', text)
-
-     def split_for_embedding(self):
-         tokenizer = EMBEDDING_CTX.model.tokenizer
-         max_tokens = EMBEDDING_CTX.model.max_seq_length
-         texts = []
-
-         for prefix, content in self.split_in_topics():
-             prefix += '\n\n'
-             tokens_prefix_len = len(tokenizer.tokenize(prefix))
-             tokens_so_far = tokens_prefix_len
-             text_so_far = prefix
-             for part in self.reduce_text(content).splitlines():
-                 part += '\n'
-                 part_tokens_len = len(tokenizer.tokenize(part))
-                 if tokens_so_far + part_tokens_len > max_tokens:
-                     texts.append(text_so_far)
-                     text_so_far = prefix
-                     tokens_so_far = tokens_prefix_len
-                 text_so_far += part
-                 tokens_so_far += part_tokens_len
-
-             if tokens_so_far != tokens_prefix_len:
-                 texts.append(text_so_far)
-
-         return texts


class _Data(dict):
-     cache_path = "routers/embedding/embeddings_manual_wiki.pkl"
+     cache_path = "routers/embedding/embeddings_{}.pkl"

    def __init__(self):
-         if os.path.exists(self.cache_path):
-             with open(self.cache_path, 'rb') as file:
-                 data = pickle.load(file)
-                 self.update(data)
-             return
-
-         # Generate
+         for grp in list(Group):
+             cache_path = self.cache_path.format(grp.name)
+             if os.path.exists(cache_path):
+                 with open(cache_path, 'rb') as file:
+                     self[grp.name] = pickle.load(file)
+                 continue

-         print("Embedding Texts...")
-         for grp in list(Group)[:-1]:
+             # Generate
+             print("Embedding Texts for", grp.name)
            self[grp.name] = {}

            # Create a list to store the text files
@@ -146,166 +56,115 @@ class _Data(dict):
            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

-         with open(self.cache_path, "wb") as file:
-             # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
-             for val in self.values():
-                 val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
-
-             pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
-
-     @classmethod
-     def parse_file_recursive(cls, filepath):
-         with open(filepath, 'r', encoding='utf-8') as file:
-             content = file.read()
-
-         parsed_data = {}
-
-         if filepath.endswith('index.rst'):
-             filedir = os.path.dirname(filepath)
-             parts = content.split(".. toctree::")
-             if len(parts) > 1:
-                 parsed_data["toctree"] = {}
-                 for part in parts[1:]:
-                     toctree_entries = part.splitlines()[1:]
-                     for entry in toctree_entries:
-                         entry = entry.strip()
-                         if not entry:
-                             continue
-
-                         if entry.startswith('/'):
-                             # relative path.
-                             continue
-
-                         if not entry.endswith('.rst'):
-                             continue
-
-                         entry_name = entry[:-4]  # remove '.rst'
-                         filepath_iter = os.path.join(filedir, entry)
-                         parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
-                             filepath_iter)
-
-         parsed_data['body'] = content
-
-         return parsed_data
+             with open(cache_path, "wb") as file:
+                 # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
+                 self[grp]['embeddings'] = self[grp]['embeddings'].to(
+                     torch.device('cpu'))
+
+                 pickle.dump(self[grp], file, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def manual_get_texts_to_embed(cls):
-         class SplitManual(Split):
-             def split_in_topics(_self):
-                 def get_topics_recursive(page, path='/index.html'):
-                     # Remove patterns ".. word::" and ":word:"
-                     text = re.sub(
-                         r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', page['body'])
-
-                     # Regular expression to find titles and subtitles
-                     pattern = r'([\*|#|%]{3,}\n[^\n]+\n[\*|#|%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'
-
-                     # Split text by found patterns
-                     sections = re.split(pattern, text)
-
-                     # Remove possible white spaces at the beginning and end of each section
-                     sections = [
-                         section for section in sections if section.strip()]
-
-                     # Separate sections into a dictionary
-                     topics = []
-                     current_title = ''
-                     current_topic = path
-
-                     for section in sections:
-                         if match := re.match(r'[\*|#|%]{3,}\n([^\n]+)\n[\*|#|%]{3,}', section):
-                             current_topic = current_title = f'{path}\n# {match.group(1)}:'
-                         elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
-                             current_topic = f'{current_title} | {match.group(1)}'
-                         else:
-                             if current_topic == path:
-                                 raise
-                             topics.append((current_topic, section))
-
-                     try:
-                         for key in page['toctree'].keys():
-                             page_child = page['toctree'][key]
-                             topics.extend(get_topics_recursive(
-                                 page_child, path.replace('index', key)))
-                     except KeyError:
-                         pass
-
-                     return topics
-
-                 manual = cls.parse_file_recursive(
-                     os.path.join(MANUAL_DIR, 'index.rst'))
-                 manual['toctree']["copyright"] = cls.parse_file_recursive(
-                     os.path.join(MANUAL_DIR, 'copyright.rst'))
-
-                 return get_topics_recursive(manual)
-
+         class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Remove repeated characters
-                 text = re.sub(r'%{2,}', '', text)  # Title
-                 text = re.sub(r'#{2,}', '', text)  # Title
-                 text = re.sub(r'\*{3,}', '', text)  # Title
-                 text = re.sub(r'={3,}', '', text)  # Topic
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)

-                 text = re.sub(r'(\s*\n\s*)+', '\n', text)
+                 text = text.replace('.rst', '.html')
+                 text = super().reduce_text(text)
                return text

-         return SplitManual().split_for_embedding()
+             def embedding_header(self, rel_path, titles):
+                 rel_path = rel_path.replace('.rst', '.html')
+                 return super().embedding_header(rel_path, titles)
+
+         # Remove patterns ".. word::" and ":word:"
+         pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
+         patterns_titles = (
+             r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}', r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')
+
+         return SplitManual().split_for_embedding(
+             MANUAL_DIR,
+             pattern_content_sub=pattern_content_sub,
+             patterns_titles=patterns_titles,
+         )

    @staticmethod
    def wiki_get_texts_to_embed():
-         class SplitWiki(Split):
-             def split_in_topics(_self):
+         class SplitWiki(SplitDocs):
+             def split_in_topics(_self,
+                                 filedir: LiteralString = None,
+                                 *,
+                                 pattern_filename=None,
+                                 pattern_content_sub=None,
+                                 patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
-                     prefix = f'/{owner}/{repo}/{page["sub_url"]}\n# {page_name_title}:\n'
+                     rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
+                     titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
-                     yield (prefix, text)
+                     yield (rel_dir, titles, text)

            def reduce_text(_self, text):
-                 return super().reduce_text(text).replace('https://projects.blender.org', '')
+                 text = super().reduce_text(text)
+                 text = text.replace('https://projects.blender.org', '')
+                 return text

        return SplitWiki().split_for_embedding()

    @staticmethod
    def docs_get_texts_to_embed():
-         return Split(DOCS_DIR).split_for_embedding()
-
-     def _sort_similarity(self, text_to_search, groups: Set[Group] = {Group.docs, Group.wiki, Group.manual}, limit=5):
-         result = []
+         class SplitBlenderDocs(SplitDocs):
+             def reduce_text(_self, text):
+                 text = super().reduce_text(text)
+                 # Remove .md or index.md
+                 text = re.sub(r'(index)?.md', '', text)
+                 return text

+             def embedding_header(_self, rel_path, titles):
+                 rel_path = re.sub(r'(index)?.md', '', rel_path)
+                 return super().embedding_header(rel_path, titles)
+
+         return SplitBlenderDocs().split_for_embedding(DOCS_DIR)
+
+     def _sort_similarity(
+             self,
+             text_to_search: str,
+             groups: Set[Group] = Query(
+                 default={Group.dev_docs, Group.manual}),
+             limit: int = 5) -> List[str]:
+         base_url: Dict[Group, str] = {
+             Group.dev_docs: "https://developer.blender.org/docs",
+             # Group.wiki: "https://projects.blender.org",
+             Group.manual: "https://docs.blender.org/manual/en/dev"
+         }
        query_emb = EMBEDDING_CTX.encode([text_to_search])
-
-         ret = {}
-
+         results: List[Tuple[float, str, Group]] = []
        for grp in groups:
-             if not grp in self:
+             if grp not in self:
                continue

-             ret[grp] = util.semantic_search(
+             search_results = util.semantic_search(
                query_emb, self[grp]['embeddings'], top_k=limit, score_function=util.dot_score)

-         score_best = 0.0
-         group_best = None
-         for grp, val in ret.items():
-             score_curr = val[0][0]['score']
-             if score_curr > score_best:
-                 score_best = score_curr
-                 group_best = grp
-
-         texts = self[group_best]['texts']
-         for score in ret[group_best][0]:
-             corpus_id = score['corpus_id']
-             text = texts[corpus_id]
-             result.append(text)
-
-         return result, group_best
+             for score in search_results[0]:
+                 corpus_id = score['corpus_id']
+                 text = self[grp]['texts'][corpus_id]
+                 results.append((score['score'], text, grp))
+
+         # Keep only the top `limit` results
+         top_results = nlargest(limit, results, key=lambda x: x[0])

+         # Extract sorted texts with base URL
+         sorted_texts = [base_url[grp] + text for _, text, grp in top_results]

+         return sorted_texts


G_data = _Data()
@@ -314,23 +173,12 @@ router = APIRouter()


@router.get("/wiki_search", response_class=PlainTextResponse)
- def wiki_search(query: str = "", group: Group = Group.all) -> str:
-     base_url = {
-         "docs": "https://developer.blender.org/docs",
-         "wiki": "https://projects.blender.org",
-         "manual": "https://docs.blender.org/manual/en/dev"
-     }
-
-     if group is Group.all:
-         groups = {Group.docs, Group.wiki, Group.manual}
-     elif group is Group.wiki:
-         groups = {Group.docs, Group.wiki}
-     else:
-         groups = {group}
-
-     texts, group_best = G_data._sort_similarity(query, groups)
-
-     result = f'BASE_URL: {base_url[group_best]}\n'
+ def wiki_search(
+         query: str = "",
+         groups: Set[Group] = Query(default={Group.dev_docs, Group.manual})
+ ) -> str:
+     texts = G_data._sort_similarity(query, groups)
+     result: str = ''
    for text in texts:
        result += f'\n---\n{text}'
    return result
@@ -339,5 +187,5 @@ def wiki_search(query: str = "", group: Group = Group.all) -> str:
if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual",
             "Bisect Object", "Who are the Triagers", "4.3 Release Notes Motion Paths"]
-     result = wiki_search(tests[4], Group.wiki)
+     result = wiki_search(tests[0], {Group.dev_docs, Group.manual})
    print(result)
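
Usage note: the endpoint now takes a repeatable `groups` query parameter (a set of Group values) instead of the old single `group`, and the matching base URL is prepended to each result rather than emitted once as a BASE_URL line. A minimal client sketch, assuming the FastAPI app is served locally on port 8000 with this router mounted at the root (host, port and mount path are assumptions):

# Hypothetical client call against the updated /wiki_search endpoint.
import requests

resp = requests.get(
    "http://localhost:8000/wiki_search",
    # Repeated "groups" parameters map to the Set[Group] query argument.
    params={"query": "Set Snap Base", "groups": ["dev_docs", "manual"]},
)
resp.raise_for_status()
print(resp.text)  # plain-text chunks separated by '\n---\n'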