# routers/tool_wiki_search.py
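"""Semantic search over the Blender developer docs and user manual.

The documentation sources are split into topics, embedded with the shared
EMBEDDING_CTX model, and served through a `/wiki_search` route that returns
the best-matching passages as plain text.
"""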
import base64
import os
import pickle
import re
import torch
from enum import Enum
from fastapi import APIRouter, Query, params
from fastapi.responses import PlainTextResponse
from heapq import nlargest
from sentence_transformers import util
from typing import Dict, List, Tuple, Set, LiteralString
try:
    from .rag import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except ImportError:
    from rag import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"
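# NOTE: assumed to be local checkouts of the documentation sources; they are
# only read when the embedding cache below has to be (re)generated.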


class Group(str, Enum):
    dev_docs = "dev_docs"
    # wiki = "wiki"
    manual = "manual"


GROUPS_DEFAULT = {Group.dev_docs, Group.manual}


class _Data(dict):
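    """In-memory store of documentation texts and their embeddings, keyed by group.

    For each Group, a pickled cache is loaded if present; otherwise the source
    documents are split into topics, embedded and written back to the cache.
    """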
    cache_path = "routers/rag/embeddings_{}.pkl"
    def __init__(self):
        for grp in list(Group):
            cache_path = self.cache_path.format(grp.name)
            if os.path.exists(cache_path):
                # Load the pre-computed embeddings for this group.
                with open(cache_path, 'rb') as file:
                    self[grp] = pickle.load(file)
                continue

            # Generate and cache the embeddings for this group.
            print("Embedding Texts for", grp.name)
            self[grp] = {}

            # Collect the texts to embed.
            if grp is Group.dev_docs:
                texts = self.docs_get_texts_to_embed()
            # elif grp is Group.wiki:
            #     texts = self.wiki_get_texts_to_embed()
            else:
                texts = self.manual_get_texts_to_embed()

            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

            with open(cache_path, "wb") as file:
                # Convert the embeddings to CPU tensors before pickling, as the
                # virtual machine in use currently only supports the CPU.
                self[grp]['embeddings'] = self[grp]['embeddings'].to(
                    torch.device('cpu'))
                pickle.dump(self[grp], file, protocol=pickle.HIGHEST_PROTOCOL)
    @classmethod
    def manual_get_texts_to_embed(cls):
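        """Split the user manual (reStructuredText) into topics for embedding."""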
        class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Remove repeated characters
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)
                text = text.replace('.rst', '.html')
                text = super().reduce_text(text)
                return text

            def embedding_header(self, rel_path, titles):
                rel_path = rel_path.replace('.rst', '.html')
                return super().embedding_header(rel_path, titles)

        # Remove patterns ".. word::" and ":word:"
        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
        patterns_titles = (
            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}',
            r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')

        return SplitManual().split_for_embedding(
            MANUAL_DIR,
            pattern_content_sub=pattern_content_sub,
            patterns_titles=patterns_titles,
        )
    @staticmethod
    def wiki_get_texts_to_embed():
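        """Split the Gitea wiki pages into topics for embedding.

        Currently unused while the `wiki` group is disabled above.
        """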
        class SplitWiki(SplitDocs):
            def split_in_topics(_self,
                                filedir: LiteralString = None,
                                *,
                                pattern_filename=None,
                                pattern_content_sub=None,
                                patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
                    titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
                    yield (rel_dir, titles, text)

            def reduce_text(_self, text):
                text = super().reduce_text(text)
                text = text.replace('https://projects.blender.org', '')
                return text

        return SplitWiki().split_for_embedding()
    @staticmethod
    def docs_get_texts_to_embed():
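        """Split the developer documentation (Markdown) into topics for embedding."""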
        class SplitBlenderDocs(SplitDocs):
            def reduce_text(_self, text):
                text = super().reduce_text(text)
                # Remove ".md" or "index.md" suffixes.
                text = re.sub(r'(index)?\.md', '', text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = re.sub(r'(index)?\.md', '', rel_path)
                return super().embedding_header(rel_path, titles)

        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)
    def _sort_similarity(
            self,
            text_to_search: str,
            groups: Set[Group] = GROUPS_DEFAULT,
            limit: int = 5) -> List[str]:
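        """Return the `limit` best-matching passages across `groups`,
        highest score first, each prefixed with its group's base URL.
        """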
        base_url: Dict[Group, str] = {
            Group.dev_docs: "https://developer.blender.org/docs",
            # Group.wiki: "https://projects.blender.org",
            Group.manual: "https://docs.blender.org/manual/en/dev"
        }

        query_emb = EMBEDDING_CTX.encode([text_to_search])
        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
            if grp not in self:
                continue

            search_results = util.semantic_search(
                query_emb, self[grp]['embeddings'],
                top_k=limit, score_function=util.dot_score)

            for score in search_results[0]:
                corpus_id = score['corpus_id']
                text = self[grp]['texts'][corpus_id]
                results.append((score['score'], text, grp))

        # Keep only the top `limit` results.
        top_results = nlargest(limit, results, key=lambda x: x[0])

        # Prefix each text with the base URL of its group.
        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]

        return sorted_texts


G_data = _Data()
router = APIRouter()
@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(
query: str = "",
groups: Set[Group] = Query(default=GROUPS_DEFAULT)
) -> str:
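    """Search the Blender documentation and return the best matches.

    Matches from the requested groups are concatenated into a single
    plain-text response, separated by `---` lines.
    """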
    # Fall back to the default groups when the requested set is empty or invalid.
    try:
        groups = GROUPS_DEFAULT.intersection(groups)
    except TypeError:
        groups = GROUPS_DEFAULT

    if not groups:
        groups = GROUPS_DEFAULT

    texts = G_data._sort_similarity(query, groups)

    result: str = ''
    for text in texts:
        result += f'\n---\n{text}'

    return result
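

# Basic smoke test when running this module directly.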
if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual", "Bisect Object",
             "Who are the Triagers", "4.3 Release Notes Motion Paths"]
    result = wiki_search(tests[0])
    print(result)