Germano Cavalcante committed
Commit: 1b8973e
Parent(s): af4d94e
Wiki Search: Updates

- Remove wiki
- Deduplicate code
- Rename docs to dev_docs
routers/embedding/__init__.py CHANGED

```diff
@@ -1,10 +1,12 @@
 # routers/embedding/__init__.py
 
 import os
+import re
 import sys
 import threading
 import torch
 from sentence_transformers import SentenceTransformer, util
+from typing import Dict, List, Tuple, Set, LiteralString
 
 
 class EmbeddingContext:
@@ -111,4 +113,114 @@ class EmbeddingContext:
         return tokens
 
 
+class SplitDocs:
+    def split_in_topics(self,
+                        filedir: LiteralString = None,
+                        *,
+                        pattern_filename=r'(?<!navigation)\.(md|rst)',
+                        pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
+                        patterns_titles=(
+                            r'^# (.+)', r'^## (.+)', r'^### (.+)'),
+                        ) -> List[Tuple[str, str]]:
+        def matches_pattern(filename):
+            return re.search(pattern_filename, filename) is not None
+
+        def split_patterns_recursive(patterns, text, index=-1):
+            sections = re.split(patterns[0], text, flags=re.MULTILINE)
+            for i, section in enumerate(sections):
+                if not section.strip():
+                    continue
+                is_match = bool(i & 1)
+                if is_match:
+                    yield (index, section)
+                elif len(patterns) > 1:
+                    for j, section_j in split_patterns_recursive(patterns[1:], section, index + 1):
+                        yield (j, section_j)
+                else:
+                    yield (-1, section)
+
+        for root, _, files in os.walk(filedir):
+            for name in files:
+                if not matches_pattern(name):
+                    continue
+
+                full_path = os.path.join(root, name)
+                with open(full_path, 'r', encoding='utf-8') as file:
+                    content = file.read()
+
+                if pattern_content_sub:
+                    content = re.sub(pattern_content_sub, '', content)
+
+                rel_path = full_path.replace(filedir, '').replace('\\', '/')
+
+                # Protect code parts
+                patterns = (r'(```[\s\S]+?```)', *patterns_titles)
+
+                last_titles = []
+                last_titles_index = []
+                content_accum = ''
+                for i, section in split_patterns_recursive(patterns, content):
+                    if i < 0:
+                        content_accum += section
+                        continue
+                    if content_accum:
+                        yield rel_path, last_titles, content_accum
+                        content_accum = ''
+                    if not last_titles_index or i > last_titles_index[-1]:
+                        last_titles_index.append(i)
+                        last_titles.append(section)
+                        continue
+                    while len(last_titles_index) > 1 and i < last_titles_index[-1]:
+                        last_titles_index.pop()
+                        last_titles.pop()
+                    # Replace
+                    last_titles_index[-1] = i
+                    last_titles[-1] = section
+                if content_accum or i != -1:
+                    yield rel_path, last_titles, content_accum
+
+    def reduce_text(_self, text):
+        text = re.sub(r'^\n+', '', text)  # Strip
+        text = re.sub(r'<.*?>', '', text)  # Remove HTML tags
+        text = re.sub(r':\S*: ', '', text)  # Remove [:...:] patterns
+        text = re.sub(r'\s*\n+', '\n', text)
+        return text
+
+    def embedding_header(_self, rel_path, titles):
+        return f"{rel_path}\n# {' | '.join(titles)}\n\n"
+
+    def split_for_embedding(self,
+                            filedir: LiteralString = None,
+                            *,
+                            pattern_filename=r'(?<!navigation)\.(md|rst)',
+                            pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
+                            patterns_titles=(
+                                r'^# (.+)', r'^## (.+)', r'^### (.+)'),
+                            ):
+        tokenizer = EMBEDDING_CTX.model.tokenizer
+        max_tokens = EMBEDDING_CTX.model.max_seq_length
+        texts = []
+
+        for rel_path, titles, content in self.split_in_topics(
+                filedir, pattern_filename=pattern_filename, pattern_content_sub=pattern_content_sub, patterns_titles=patterns_titles):
+            header = self.embedding_header(rel_path, titles)
+            tokens_pre_len = len(tokenizer.tokenize(header))
+            tokens_so_far = tokens_pre_len
+            text_so_far = header
+            for part in self.reduce_text(content).splitlines():
+                part += '\n'
+                part_tokens_len = len(tokenizer.tokenize(part))
+                if tokens_so_far + part_tokens_len > max_tokens:
+                    texts.append(text_so_far)
+                    text_so_far = header
+                    tokens_so_far = tokens_pre_len
+                text_so_far += part
+                tokens_so_far += part_tokens_len
+
+            if tokens_so_far != tokens_pre_len:
+                texts.append(text_so_far)
+
+        return texts
+
+
 EMBEDDING_CTX = EmbeddingContext()
```
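The "Deduplicate code" point of this commit lands here: the topic-splitting and token-budgeted chunking that previously lived in `routers/tool_wiki_search.py` now sits in the reusable `SplitDocs` base class. A minimal usage sketch follows, assuming the fallback import style used by `tool_wiki_search.py`; the class name `SplitChangelog` and the directory path are illustrative, not part of the commit:

```python
from embedding import SplitDocs, EMBEDDING_CTX  # matches tool_wiki_search.py's fallback import

# Hypothetical subclass: only reduce_text is overridden, mirroring how
# SplitManual and SplitBlenderDocs specialize SplitDocs in this commit.
class SplitChangelog(SplitDocs):
    def reduce_text(_self, text):
        text = super().reduce_text(text)
        return text.replace('.md', '.html')  # illustrative post-processing

# split_in_topics() yields (rel_path, titles, content) per section;
# split_for_embedding() packs those sections into chunks that fit the
# embedding model's max_seq_length, each prefixed by a path/title header.
texts = SplitChangelog().split_for_embedding("/path/to/docs")  # path is an assumption
embeddings = EMBEDDING_CTX.encode(texts)
```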
routers/embedding/{embeddings_manual_wiki.pkl → embeddings_dev_docs.pkl} RENAMED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e94dbc62cda6258367836eaec82dfda7f35183b1debdc980541e2ceb22d52637
+size 15328541
```
routers/embedding/embeddings_manual.pkl ADDED

```diff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4f9759847d9fb0948eb2f550a3e5512b7a5bbdf8fd70118bb75fe670395ee7c
+size 22522382
```
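These LFS pointers line up with the cache template `_Data` gains in `routers/tool_wiki_search.py` below (`cache_path = "routers/embedding/embeddings_{}.pkl"`): one pickle per `Group` member. A small sketch of that lookup, assuming the working directory is the repository root:

```python
import os
import pickle

cache_path = "routers/embedding/embeddings_{}.pkl"
for name in ("dev_docs", "manual"):      # Group member names in this commit
    path = cache_path.format(name)       # embeddings_dev_docs.pkl, embeddings_manual.pkl
    if os.path.exists(path):
        with open(path, 'rb') as file:
            data = pickle.load(file)     # dict holding 'texts' and 'embeddings'
```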
routers/tool_wiki_search.py CHANGED

```diff
@@ -6,133 +6,43 @@ import pickle
 import re
 import torch
 from enum import Enum
-from
-from sentence_transformers import util
-from fastapi import APIRouter
+from fastapi import APIRouter, Query
 from fastapi.responses import PlainTextResponse
+from heapq import nlargest
+from sentence_transformers import util
+from typing import Dict, List, Tuple, Set, LiteralString
 
 try:
-    from .embedding import EMBEDDING_CTX
+    from .embedding import SplitDocs, EMBEDDING_CTX
     from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
 except:
-    from embedding import EMBEDDING_CTX
+    from embedding import SplitDocs, EMBEDDING_CTX
     from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
 
-
+
+MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
 DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"
 
 
 class Group(str, Enum):
-
-    wiki = "wiki"
+    dev_docs = "dev_docs"
+    # wiki = "wiki"
     manual = "manual"
-    all = "all"
-
-
-class Split:
-    filedir = None
-    filetype = '.md'
-
-    def __init__(self, filedir=None, filetype='.md'):
-        self.filedir = filedir
-        self.filetype = filetype
-
-    def split_in_topics(self) -> List[Tuple[str, str]]:
-        for root, _dirs, files in os.walk(self.filedir):
-            for name in files:
-                if not name.endswith(self.filetype) or name == 'navigation.md':
-                    continue
-
-                full_path = os.path.join(root, name)
-                with open(full_path, 'r', encoding='utf-8') as file:
-                    content = file.read()
-
-                prefix = full_path.replace(self.filedir, '')
-                prefix = re.sub(r'(index)?.md', '', prefix)
-                prefix = prefix.replace('\\', '/')
-
-                # Protect code parts
-                parts = ['']
-                is_first = True
-                is_in_code_block = False
-                for line in content.splitlines():
-                    if not line:
-                        continue
-                    line += '\n'
-                    is_in_code_block = is_in_code_block != line.strip().startswith('```')
-                    if not is_in_code_block and line.startswith('## '):
-                        if not is_first:
-                            parts.append(line)
-                            continue
-
-                    is_first = False
-                    parts[-1] += line
-
-                title_main = ''
-                for topic in parts:
-                    topic = topic.strip()
-                    if not topic or topic.startswith('---\nhide'):
-                        continue
-
-                    try:
-                        title, body = topic.split('\n', 1)
-                    except ValueError:
-                        # ignore non content
-                        continue
-
-                    if not title_main:
-                        title_main = title
-                    else:
-                        title = title_main + ' | ' + title
-
-                    yield (prefix + '\n' + title, body)
-
-    def reduce_text(_self, text):
-        text = re.sub(r'<.*?>', '', text)  # Remove HTML tags
-        text = re.sub(r':\S*: ', '', text)  # Remove [:...:] patterns
-        text = re.sub(r'(index)?.md', '', text)  # Remove .md
-        return re.sub(r'(\s*\n\s*)+', '\n', text)
-
-    def split_for_embedding(self):
-        tokenizer = EMBEDDING_CTX.model.tokenizer
-        max_tokens = EMBEDDING_CTX.model.max_seq_length
-        texts = []
-
-        for prefix, content in self.split_in_topics():
-            prefix += '\n\n'
-            tokens_prefix_len = len(tokenizer.tokenize(prefix))
-            tokens_so_far = tokens_prefix_len
-            text_so_far = prefix
-            for part in self.reduce_text(content).splitlines():
-                part += '\n'
-                part_tokens_len = len(tokenizer.tokenize(part))
-                if tokens_so_far + part_tokens_len > max_tokens:
-                    texts.append(text_so_far)
-                    text_so_far = prefix
-                    tokens_so_far = tokens_prefix_len
-                text_so_far += part
-                tokens_so_far += part_tokens_len
-
-            if tokens_so_far != tokens_prefix_len:
-                texts.append(text_so_far)
-
-        return texts
 
 
 class _Data(dict):
-    cache_path = "routers/embedding/
+    cache_path = "routers/embedding/embeddings_{}.pkl"
 
     def __init__(self):
-
-
-
-
-
-
-        # Generate
+        for grp in list(Group):
+            cache_path = self.cache_path.format(grp.name)
+            if os.path.exists(cache_path):
+                with open(cache_path, 'rb') as file:
+                    self[grp.name] = pickle.load(file)
+                continue
 
-
-
+            # Generate
+            print("Embedding Texts for", grp.name)
             self[grp.name] = {}
 
             # Create a list to store the text files
@@ -146,166 +56,115 @@ class _Data(dict):
             self[grp]['texts'] = texts
             self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)
 
-
-
-
-
-
-            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
-
-    @classmethod
-    def parse_file_recursive(cls, filepath):
-        with open(filepath, 'r', encoding='utf-8') as file:
-            content = file.read()
-
-        parsed_data = {}
-
-        if filepath.endswith('index.rst'):
-            filedir = os.path.dirname(filepath)
-            parts = content.split(".. toctree::")
-            if len(parts) > 1:
-                parsed_data["toctree"] = {}
-                for part in parts[1:]:
-                    toctree_entries = part.splitlines()[1:]
-                    for entry in toctree_entries:
-                        entry = entry.strip()
-                        if not entry:
-                            continue
-
-                        if entry.startswith('/'):
-                            # relative path.
-                            continue
-
-                        if not entry.endswith('.rst'):
-                            continue
+            with open(cache_path, "wb") as file:
+                # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
+                self[grp]['embeddings'] = self[grp]['embeddings'].to(
+                    torch.device('cpu'))
 
-
-                        filepath_iter = os.path.join(filedir, entry)
-                        parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
-                            filepath_iter)
-
-        parsed_data['body'] = content
-
-        return parsed_data
+                pickle.dump(self[grp], file, protocol=pickle.HIGHEST_PROTOCOL)
 
     @classmethod
     def manual_get_texts_to_embed(cls):
-        class SplitManual(
-            def split_in_topics(_self):
-                def get_topics_recursive(page, path='/index.html'):
-                    # Remove patterns ".. word::" and ":word:"
-                    text = re.sub(
-                        r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', page['body'])
-
-                    # Regular expression to find titles and subtitles
-                    pattern = r'([\*|#|%]{3,}\n[^\n]+\n[\*|#|%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'
-
-                    # Split text by found patterns
-                    sections = re.split(pattern, text)
-
-                    # Remove possible white spaces at the beginning and end of each section
-                    sections = [
-                        section for section in sections if section.strip()]
-
-                    # Separate sections into a dictionary
-                    topics = []
-                    current_title = ''
-                    current_topic = path
-
-                    for section in sections:
-                        if match := re.match(r'[\*|#|%]{3,}\n([^\n]+)\n[\*|#|%]{3,}', section):
-                            current_topic = current_title = f'{path}\n# {match.group(1)}:'
-                        elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
-                            current_topic = f'{current_title} | {match.group(1)}'
-                        else:
-                            if current_topic == path:
-                                raise
-                            topics.append((current_topic, section))
-
-                    try:
-                        for key in page['toctree'].keys():
-                            page_child = page['toctree'][key]
-                            topics.extend(get_topics_recursive(
-                                page_child, path.replace('index', key)))
-                    except KeyError:
-                        pass
-
-                    return topics
-
-                manual = cls.parse_file_recursive(
-                    os.path.join(MANUAL_DIR, 'index.rst'))
-                manual['toctree']["copyright"] = cls.parse_file_recursive(
-                    os.path.join(MANUAL_DIR, 'copyright.rst'))
-
-                return get_topics_recursive(manual)
-
+        class SplitManual(SplitDocs):
             def reduce_text(_self, text):
                 # Remove repeated characters
-                text = re.sub(r'%{2,}', '', text)  # Title
-                text = re.sub(r'#{2,}', '', text)  # Title
-                text = re.sub(r'\*{3,}', '', text)  # Title
-                text = re.sub(r'={3,}', '', text)  # Topic
                 text = re.sub(r'\^{3,}', '', text)
                 text = re.sub(r'-{3,}', '', text)
 
-                text =
+                text = text.replace('.rst', '.html')
+                text = super().reduce_text(text)
                 return text
 
-
+            def embedding_header(self, rel_path, titles):
+                rel_path = rel_path.replace('.rst', '.html')
+                return super().embedding_header(rel_path, titles)
+
+        # Remove patterns ".. word::" and ":word:"
+        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
+        patterns_titles = (
+            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}', r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')
+
+        return SplitManual().split_for_embedding(
+            MANUAL_DIR,
+            pattern_content_sub=pattern_content_sub,
+            patterns_titles=patterns_titles,
+        )
 
     @staticmethod
     def wiki_get_texts_to_embed():
-        class SplitWiki(
-            def split_in_topics(_self
+        class SplitWiki(SplitDocs):
+            def split_in_topics(_self,
+                                filedir: LiteralString = None,
+                                *,
+                                pattern_filename=None,
+                                pattern_content_sub=None,
+                                patterns_titles=None):
                 owner = "blender"
                 repo = "blender"
                 pages = gitea_wiki_pages_get(owner, repo)
                 for page_name in pages:
                     page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
-
+                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
+                    titles = [page_name_title]
                     text = base64.b64decode(
                         page["content_base64"]).decode('utf-8')
-                    yield (
+                    yield (rel_dir, titles, text)
 
             def reduce_text(_self, text):
-
+                text = super().reduce_text(text)
+                text = text.replace('https://projects.blender.org', '')
+                return text
 
         return SplitWiki().split_for_embedding()
 
     @staticmethod
     def docs_get_texts_to_embed():
-
-
-
-
+        class SplitBlenderDocs(SplitDocs):
+            def reduce_text(_self, text):
+                text = super().reduce_text(text)
+                # Remove .md or index.md
+                text = re.sub(r'(index)?.md', '', text)
+                return text
 
+            def embedding_header(_self, rel_path, titles):
+                rel_path = re.sub(r'(index)?.md', '', rel_path)
+                return super().embedding_header(rel_path, titles)
+
+        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)
+
+    def _sort_similarity(
+            self,
+            text_to_search: str,
+            groups: Set[Group] = Query(
+                default={Group.dev_docs, Group.manual}),
+            limit: int = 5) -> List[str]:
+        base_url: Dict[Group, str] = {
+            Group.dev_docs: "https://developer.blender.org/docs",
+            # Group.wiki: "https://projects.blender.org",
+            Group.manual: "https://docs.blender.org/manual/en/dev"
+        }
         query_emb = EMBEDDING_CTX.encode([text_to_search])
-
-        ret = {}
-
+        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
-            if not
+            if grp not in self:
                 continue
 
-
+            search_results = util.semantic_search(
                 query_emb, self[grp]['embeddings'], top_k=limit, score_function=util.dot_score)
 
-
-
-
-
-
-
-
+            for score in search_results[0]:
+                corpus_id = score['corpus_id']
+                text = self[grp]['texts'][corpus_id]
+                results.append((score['score'], text, grp))
+
+        # Keep only the top `limit` results
+        top_results = nlargest(limit, results, key=lambda x: x[0])
 
-            texts
-            for
-                corpus_id = score['corpus_id']
-                text = texts[corpus_id]
-                result.append(text)
+        # Extract sorted texts with base URL
+        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]
 
-        return
+        return sorted_texts
 
 
 G_data = _Data()
@@ -314,23 +173,12 @@ router = APIRouter()
 
 
 @router.get("/wiki_search", response_class=PlainTextResponse)
-def wiki_search(
-
-
-
-
-
-
-    if group is Group.all:
-        groups = {Group.docs, Group.wiki, Group.manual}
-    elif group is Group.wiki:
-        groups = {Group.docs, Group.wiki}
-    else:
-        groups = {group}
-
-    texts, group_best = G_data._sort_similarity(query, groups)
-
-    result = f'BASE_URL: {base_url[group_best]}\n'
+def wiki_search(
+        query: str = "",
+        groups: Set[Group] = Query(default={Group.dev_docs, Group.manual})
+) -> str:
+    texts = G_data._sort_similarity(query, groups)
+    result: str = ''
     for text in texts:
         result += f'\n---\n{text}'
     return result
@@ -339,5 +187,5 @@ def wiki_search(query: str = "", group: Group = Group.all) -> str:
 if __name__ == '__main__':
     tests = ["Set Snap Base", "Building the Manual",
              "Bisect Object", "Who are the Triagers", "4.3 Release Notes Motion Paths"]
-    result = wiki_search(tests[
+    result = wiki_search(tests[0], {Group.dev_docs, Group.manual})
     print(result)
```
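With the signature change, the single `group: Group = Group.all` parameter becomes a repeatable `groups` query parameter, so clients pick sources explicitly. A hedged example of calling the endpoint follows; the host and any router prefix are assumptions, while the `/wiki_search` path and the `query`/`groups` parameter names come from the diff above:

```python
import requests

# Base URL is an assumption; point it at wherever this FastAPI app is served.
resp = requests.get(
    "http://localhost:8000/wiki_search",
    params=[("query", "Set Snap Base"),
            ("groups", "dev_docs"),   # Set[Group] query params repeat once per value
            ("groups", "manual")],
)
print(resp.text)  # plain-text response; results are separated by '\n---\n'
```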