# routers/tool_wiki_search.py
import base64
import os
import pickle
import re
import torch
from enum import Enum
from fastapi import APIRouter, Query, params
from fastapi.responses import PlainTextResponse
from heapq import nlargest
from sentence_transformers import util
from typing import Dict, List, Tuple, Set, LiteralString

try:
    from .rag import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except ImportError:
    from rag import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get

MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"
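# NOTE: local checkout paths of the blender-manual and blender-developer-docs
# repositories; adjust them to match your environment.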


class Group(str, Enum):
    dev_docs = "dev_docs"
    # wiki = "wiki"
    manual = "manual"


GROUPS_DEFAULT = {Group.dev_docs, Group.manual}


class _Data(dict):
    # `Group` is a str Enum whose values match the member names, so
    # `self[grp]` and `self[grp.name]` address the same key.
    cache_path = "routers/rag/embeddings_{}.pkl"

    def __init__(self):
        for grp in list(Group):
            cache_path = self.cache_path.format(grp.name)
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as file:
                    self[grp.name] = pickle.load(file)
                continue

            # Generate
            print("Embedding Texts for", grp.name)
            self[grp.name] = {}

            # Collect the texts to embed for this group
            if grp is Group.dev_docs:
                texts = self.docs_get_texts_to_embed()
            # elif grp is Group.wiki:
            #     texts = self.wiki_get_texts_to_embed()
            else:
                texts = self.manual_get_texts_to_embed()

            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

            with open(cache_path, "wb") as file:
                # Convert the embeddings to CPU tensors, as the virtual
                # machine currently in use only supports the CPU.
                self[grp]['embeddings'] = self[grp]['embeddings'].to(
                    torch.device('cpu'))
                pickle.dump(self[grp], file, protocol=pickle.HIGHEST_PROTOCOL)
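
    # Cache layout sketch: each "routers/rag/embeddings_<group>.pkl" file
    # written above is a pickled dict of the form
    #   {'texts': [str, ...], 'embeddings': <torch.Tensor on CPU>}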

    @classmethod
    def manual_get_texts_to_embed(cls):
        class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Remove repeated characters (RST heading underlines)
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)
                text = text.replace('.rst', '.html')
                text = super().reduce_text(text)
                return text

            def embedding_header(self, rel_path, titles):
                rel_path = rel_path.replace('.rst', '.html')
                return super().embedding_header(rel_path, titles)

        # Remove patterns ".. word::" and ":word:"
        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
        patterns_titles = (
            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}', r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')

        return SplitManual().split_for_embedding(
            MANUAL_DIR,
            pattern_content_sub=pattern_content_sub,
            patterns_titles=patterns_titles,
        )

    @staticmethod
    def wiki_get_texts_to_embed():
        class SplitWiki(SplitDocs):
            def split_in_topics(_self,
                                filedir: LiteralString = None,
                                *,
                                pattern_filename=None,
                                pattern_content_sub=None,
                                patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
                    titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
                    yield (rel_dir, titles, text)

            def reduce_text(_self, text):
                text = super().reduce_text(text)
                text = text.replace('https://projects.blender.org', '')
                return text

        return SplitWiki().split_for_embedding()

    @staticmethod
    def docs_get_texts_to_embed():
        class SplitBlenderDocs(SplitDocs):
            def reduce_text(_self, text):
                text = super().reduce_text(text)
                # Remove ".md" or "index.md" suffixes
                text = re.sub(r'(index)?\.md', '', text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = re.sub(r'(index)?\.md', '', rel_path)
                return super().embedding_header(rel_path, titles)

        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)

    def _sort_similarity(
            self,
            text_to_search: str,
            groups: Set[Group] = Query(
                default={Group.dev_docs, Group.manual}),
            limit: int = 5) -> List[str]:
        base_url: Dict[Group, str] = {
            Group.dev_docs: "https://developer.blender.org/docs",
            # Group.wiki: "https://projects.blender.org",
            Group.manual: "https://docs.blender.org/manual/en/dev"
        }

        query_emb = EMBEDDING_CTX.encode([text_to_search])
        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
            if grp not in self:
                continue

            search_results = util.semantic_search(
                query_emb, self[grp]['embeddings'], top_k=limit,
                score_function=util.dot_score)
            for score in search_results[0]:
                corpus_id = score['corpus_id']
                text = self[grp]['texts'][corpus_id]
                results.append((score['score'], text, grp))

        # Keep only the top `limit` results across all groups
        top_results = nlargest(limit, results, key=lambda x: x[0])

        # Prefix each text with its group's base URL
        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]
        return sorted_texts
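
# Illustrative (assumes the embedding caches above are loaded): a call such as
#   G_data._sort_similarity("Bisect Object", {Group.manual}, limit=3)
# returns up to `limit` strings, each prefixed with its group's base URL and
# ordered by descending dot-product score.
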
G_data = _Data()
router = APIRouter()
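
# Minimal mounting sketch (assumption: the serving application simply
# includes this router):
#   from fastapi import FastAPI
#   app = FastAPI()
#   app.include_router(router)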


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(
        query: str = "",
        groups: Set[Group] = Query(default=GROUPS_DEFAULT)
) -> str:
    try:
        groups = GROUPS_DEFAULT.intersection(groups)
        if not groups:
            groups = GROUPS_DEFAULT
    except TypeError:
        # `groups` is the bare `Query` default (not iterable) when this
        # function is called directly instead of through FastAPI.
        groups = GROUPS_DEFAULT

    texts = G_data._sort_similarity(query, groups)
    result: str = ''
    for text in texts:
        result += f'\n---\n{text}'
    return result
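
# Example request (assuming the router is mounted at the application root):
#   GET /wiki_search?query=Set%20Snap%20Base&groups=manual
# The plain-text response lists the best-matching chunks separated by "---".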


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual",
             "Bisect Object", "Who are the Triagers",
             "4.3 Release Notes Motion Paths"]
    result = wiki_search(tests[0])
    print(result)