# routers/tool_wiki_search.py

import base64
import os
import pickle
import re
import torch
from enum import Enum
from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
from heapq import nlargest
from sentence_transformers import util
from typing import Dict, List, Tuple, Set, LiteralString
try:
    from .rag import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except ImportError:
    from rag import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get

MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"

class Group(str, Enum):
    dev_docs = "dev_docs"
    # wiki = "wiki"
    manual = "manual"


GROUPS_DEFAULT = {Group.dev_docs, Group.manual}

class _Data(dict):
    cache_path = "routers/rag/embeddings_{}.pkl"

    def __init__(self):
        for grp in Group:
            cache_path = self.cache_path.format(grp.name)
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as file:
                    self[grp] = pickle.load(file)
                continue

            # No cache yet: generate the embeddings for this group.
            print("Embedding Texts for", grp.name)
            self[grp] = {}

            # Build the list of texts to embed.
            if grp is Group.dev_docs:
                texts = self.docs_get_texts_to_embed()
            # elif grp is Group.wiki:
            #     texts = self.wiki_get_texts_to_embed()
            else:
                texts = self.manual_get_texts_to_embed()

            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

            with open(cache_path, "wb") as file:
                # Move the embeddings to the CPU before pickling, as the
                # virtual machine currently in use only supports the CPU.
                self[grp]['embeddings'] = self[grp]['embeddings'].to(
                    torch.device('cpu'))
                pickle.dump(self[grp], file,
                            protocol=pickle.HIGHEST_PROTOCOL)
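
    # Cache layout per group (what gets pickled above):
    #   self[grp] = {
    #       'texts': List[str],          # one chunk per entry, each starting
    #                                    # with its relative-path header
    #                                    # (see `embedding_header` below)
    #       'embeddings': torch.Tensor,  # one row per chunk, kept on the CPU
    #   }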

    @classmethod
    def manual_get_texts_to_embed(cls):
        class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Remove runs of repeated characters (RST underlines).
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)
                text = text.replace('.rst', '.html')
                text = super().reduce_text(text)
                return text

            def embedding_header(self, rel_path, titles):
                rel_path = rel_path.replace('.rst', '.html')
                return super().embedding_header(rel_path, titles)

        # Remove patterns ".. word::" and ":word:".
        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
        patterns_titles = (
            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}',
            r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')

        return SplitManual().split_for_embedding(
            MANUAL_DIR,
            pattern_content_sub=pattern_content_sub,
            patterns_titles=patterns_titles,
        )
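
    # For reference, `patterns_titles` matches the two reStructuredText
    # heading styles found in the manual sources: titles over- and underlined
    # with runs of '*', '#' or '%', and titles underlined (with an optional
    # overline) with runs of '=' or '+'.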

    @staticmethod
    def wiki_get_texts_to_embed():
        class SplitWiki(SplitDocs):
            def split_in_topics(_self,
                                filedir: LiteralString = None,
                                *,
                                pattern_filename=None,
                                pattern_content_sub=None,
                                patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
                    titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
                    yield (rel_dir, titles, text)

            def reduce_text(_self, text):
                text = super().reduce_text(text)
                text = text.replace('https://projects.blender.org', '')
                return text

        return SplitWiki().split_for_embedding()
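
    # Unlike the other splitters, `SplitWiki` fetches its topics from the
    # Gitea wiki API rather than walking a local directory, which is why
    # `split_for_embedding()` is called without a path. (The wiki group is
    # currently disabled above.)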

    @staticmethod
    def docs_get_texts_to_embed():
        class SplitBlenderDocs(SplitDocs):
            def reduce_text(_self, text):
                text = super().reduce_text(text)
                # Remove ".md" or "index.md".
                text = re.sub(r'(index)?\.md', '', text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = re.sub(r'(index)?\.md', '', rel_path)
                return super().embedding_header(rel_path, titles)

        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)
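
    # Stripping "index.md"/".md" appears to mirror how the published
    # developer-docs site serves its pages ("foo/index.md" at ".../foo/"),
    # so the reduced path can be appended directly to the base URL when
    # reporting results.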

    def _sort_similarity(
            self,
            text_to_search: str,
            groups: Set[Group] = GROUPS_DEFAULT,
            limit: int = 5) -> List[str]:
        base_url: Dict[Group, str] = {
            Group.dev_docs: "https://developer.blender.org/docs",
            # Group.wiki: "https://projects.blender.org",
            Group.manual: "https://docs.blender.org/manual/en/dev",
        }
        query_emb = EMBEDDING_CTX.encode([text_to_search])
        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
            if grp not in self:
                continue
            search_results = util.semantic_search(
                query_emb, self[grp]['embeddings'], top_k=limit,
                score_function=util.dot_score)
            for score in search_results[0]:
                corpus_id = score['corpus_id']
                text = self[grp]['texts'][corpus_id]
                results.append((score['score'], text, grp))

        # Keep only the top `limit` results across all groups.
        top_results = nlargest(limit, results, key=lambda x: x[0])

        # Prepend each group's base URL so results link to published pages.
        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]
        return sorted_texts
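
    # Note on `_sort_similarity`: `util.dot_score` returns raw dot products,
    # so ranking scores from different groups together in `nlargest` is only
    # meaningful if EMBEDDING_CTX produces normalized embeddings (where dot
    # product equals cosine similarity), which this code assumes.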

G_data = _Data()

router = APIRouter()

# The route registration is assumed here: `router` and PlainTextResponse are
# otherwise unused, and the path is taken from the function name.
@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(
        query: str = "",
        groups: Set[Group] = Query(default=GROUPS_DEFAULT)
) -> str:
    try:
        # Restrict the requested groups to the supported ones.
        groups = GROUPS_DEFAULT.intersection(groups)
        if not groups:
            raise ValueError
    except (TypeError, ValueError):
        groups = GROUPS_DEFAULT

    texts = G_data._sort_similarity(query, groups)
    result: str = ''
    for text in texts:
        result += f'\n---\n{text}'
    return result

if __name__ == '__main__':
    tests = ["Set Snap Base",
             "Building the Manual",
             "Bisect Object",
             "Who are the Triagers",
             "4.3 Release Notes Motion Paths"]
    result = wiki_search(tests[0])
    print(result)
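
# Hypothetical usage sketch: mounting this router in a FastAPI app. The
# module path and the "/tools" prefix are assumptions, not part of this file.
#
#   from fastapi import FastAPI
#   from routers.tool_wiki_search import router
#
#   app = FastAPI()
#   app.include_router(router, prefix="/tools")
#
# The search would then be reachable at GET /tools/wiki_search?query=...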