# routers/tool_wiki_search.py
"""Semantic search over the Blender Manual and Blender Developer Documentation.

Exposes a ``/wiki_search`` endpoint that embeds the query with ``EMBEDDING_CTX``
and returns the most similar documentation sections as plain text, each
prefixed with the public URL of its source page.
"""
import base64
import os
import pickle
import re
import torch
from enum import Enum
from fastapi import APIRouter, Query, params
from fastapi.responses import PlainTextResponse
from heapq import nlargest
from sentence_transformers import util
from typing import Dict, List, Tuple, Set, LiteralString

try:
    from .rag import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except ImportError:
    from rag import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get

MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"


class Group(str, Enum):
    dev_docs = "dev_docs"
    # wiki = "wiki"
    manual = "manual"


GROUPS_DEFAULT = {Group.dev_docs, Group.manual}


class _Data(dict):
    """Per-group embedding store: maps each group to
    ``{'texts': List[str], 'embeddings': torch.Tensor}``, cached on disk as pickles."""
    cache_path = "routers/rag/embeddings_{}.pkl"

    def __init__(self):
        for grp in list(Group):
            cache_path = self.cache_path.format(grp.name)
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as file:
                    self[grp] = pickle.load(file)
                continue

            # No cache found: generate the embeddings for this group.
            print("Embedding Texts for", grp.name)
            self[grp] = {}

            # Collect the texts to embed for this group.
            if grp is Group.dev_docs:
                texts = self.docs_get_texts_to_embed()
            # elif grp is Group.wiki:
            #     texts = self.wiki_get_texts_to_embed()
            else:
                texts = self.manual_get_texts_to_embed()

            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

            with open(cache_path, "wb") as file:
                # Move the embeddings to the CPU before pickling, as the
                # virtual machine currently in use only supports the CPU.
                self[grp]['embeddings'] = self[grp]['embeddings'].to(
                    torch.device('cpu'))
                pickle.dump(self[grp], file, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def manual_get_texts_to_embed():
        class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Strip RST underline runs such as "^^^" and "---".
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)
                text = text.replace('.rst', '.html')
                text = super().reduce_text(text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = rel_path.replace('.rst', '.html')
                return super().embedding_header(rel_path, titles)

        # Remove patterns ".. word::" and ":word:"
        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
        patterns_titles = (
            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}',
            r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')

        return SplitManual().split_for_embedding(
            MANUAL_DIR,
            pattern_content_sub=pattern_content_sub,
            patterns_titles=patterns_titles,
        )

    @staticmethod
    def wiki_get_texts_to_embed():
        class SplitWiki(SplitDocs):
            def split_in_topics(_self,
                                filedir: LiteralString = None,
                                *,
                                pattern_filename=None,
                                pattern_content_sub=None,
                                patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
                    titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
                    yield (rel_dir, titles, text)

            def reduce_text(_self, text):
                text = super().reduce_text(text)
                text = text.replace('https://projects.blender.org', '')
                return text

        return SplitWiki().split_for_embedding()

    @staticmethod
    def docs_get_texts_to_embed():
        class SplitBlenderDocs(SplitDocs):
            def reduce_text(_self, text):
                text = super().reduce_text(text)
                # Remove ".md" or "index.md"
                text = re.sub(r'(index)?\.md', '', text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = re.sub(r'(index)?\.md', '', rel_path)
                return super().embedding_header(rel_path, titles)

        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)

    def _sort_similarity(
            self,
            text_to_search: str,
            groups: Set[Group] = GROUPS_DEFAULT,
            limit: int = 5) -> List[str]:
        base_url: Dict[Group, str] = {
            Group.dev_docs: "https://developer.blender.org/docs",
            # Group.wiki: "https://projects.blender.org",
            Group.manual: "https://docs.blender.org/manual/en/dev"
        }

        query_emb = EMBEDDING_CTX.encode([text_to_search])
        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
            if grp not in self:
                continue

            search_results = util.semantic_search(
                query_emb, self[grp]['embeddings'],
                top_k=limit, score_function=util.dot_score)

            for score in search_results[0]:
                corpus_id = score['corpus_id']
                text = self[grp]['texts'][corpus_id]
                results.append((score['score'], text, grp))

        # Keep only the top `limit` results across all groups.
        top_results = nlargest(limit, results, key=lambda x: x[0])

        # Prefix each text with the base URL of its group.
        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]

        return sorted_texts


G_data = _Data()

router = APIRouter()


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(
        query: str = "",
        groups: Set[Group] = Query(default=GROUPS_DEFAULT)) -> str:
    if isinstance(groups, params.Query):
        # Called directly (e.g. from the `__main__` test below) rather than
        # through FastAPI, so `groups` is still the raw `Query` default.
        groups = GROUPS_DEFAULT
    else:
        groups = GROUPS_DEFAULT.intersection(groups) or GROUPS_DEFAULT

    texts = G_data._sort_similarity(query, groups)
    result: str = ''
    for text in texts:
        result += f'\n---\n{text}'
    return result


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual", "Bisect Object",
             "Who are the Triagers", "4.3 Release Notes Motion Paths"]
    result = wiki_search(tests[0])
    print(result)
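
# Example (sketch): mounting this router in a FastAPI application. The module
# path and app layout below are assumptions, not part of this file;
# `FastAPI()` and `include_router()` are standard FastAPI calls.
#
#   from fastapi import FastAPI
#   from routers.tool_wiki_search import router as wiki_search_router
#
#   app = FastAPI()
#   app.include_router(wiki_search_router)
#
#   # GET /wiki_search?query=Bisect%20Object&groups=manual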