# routers/tool_wiki_search.py

import base64
import os
import pickle
import re
import torch

from enum import Enum
from typing import Dict, List

from sentence_transformers import util

from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get

try:
    from .embedding import EMBEDDING_CTX
except ImportError:
    from embedding import EMBEDDING_CTX

MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"


class Group(str, Enum):
    wiki = "wiki"
    manual = "manual"
    all = "all"


class _Data(dict):
    cache_path = "routers/embedding/embeddings_manual_wiki.pkl"

    def __init__(self):
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as file:
                data = pickle.load(file)
                self.update(data)
                return

        # Generate
        print("Embedding Texts...")
        for grp in list(Group)[:-1]:  # Skip `Group.all`.
            self[grp] = {}

            # Build the list of text chunks to embed.
            if grp == Group.manual:
                texts = self.manual_get_texts_to_embed()
            else:
                texts = self.wiki_get_texts_to_embed()

            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

        with open(self.cache_path, "wb") as file:
            # Move the embeddings to the CPU, as the virtual machine in use
            # currently only supports the CPU.
            for val in self.values():
                val['embeddings'] = val['embeddings'].to(torch.device('cpu'))

            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def reduce_text(text):
        # Strip repeated heading adornment characters.
        text = re.sub(r'%{2,}', '', text)   # Title
        text = re.sub(r'#{2,}', '', text)   # Title
        text = re.sub(r'\*{3,}', '', text)  # Title
        text = re.sub(r'={3,}', '', text)   # Topic
        text = re.sub(r'\^{3,}', '', text)
        text = re.sub(r'-{3,}', '', text)

        # Collapse runs of blank lines and surrounding whitespace.
        text = re.sub(r'(\s*\n\s*)+', '\n', text)
        return text

    @classmethod
    def parse_file_recursive(cls, filepath):
        with open(filepath, 'r', encoding='utf-8') as file:
            content = file.read()

        parsed_data = {}

        if filepath.endswith('index.rst'):
            filedir = os.path.dirname(filepath)
            parts = content.split(".. toctree::")
            if len(parts) > 1:
                parsed_data["toctree"] = {}
                for part in parts[1:]:
                    toctree_entries = part.splitlines()[1:]
                    for entry in toctree_entries:
                        entry = entry.strip()
                        if not entry:
                            continue
                        if entry.startswith('/'):
                            # relative path.
                            continue
                        if not entry.endswith('.rst'):
                            continue
                        entry_name = entry[:-4]  # remove '.rst'
                        filepath_iter = os.path.join(filedir, entry)
                        parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
                            filepath_iter)

        parsed_data['body'] = content
        return parsed_data

    @staticmethod
    def split_into_topics(text: str, prefix: str = '') -> Dict[str, List[str]]:
        """
        Splits a text into sections based on titles and subtitles,
        and organizes them into a dictionary.

        Args:
            text (str): The input text to be split. The text should contain
                titles marked by asterisks (***) or subtitles marked by
                equal signs (===).
            prefix (str): prefix to titles and subtitles

        Returns:
            Dict[str, List[str]]: A dictionary where keys are section titles
                or subtitles, and values are lists of strings corresponding
                to the content under each title or subtitle.

        Example:
            text = '''
        *********************
        The Blender Community
        *********************

        Being freely available from the start.

        Independent Sites
        =================

        There are `several independent websites.

        Getting Support
        ===============

        Blender's community is one of its greatest features.
        '''

            result = split_into_topics(text)
            # result will be:
            # {
            #     "# The Blender Community": [
            #         "Being freely available from the start."
            #     ],
            #     "# The Blender Community | Independent Sites": [
            #         "There are `several independent websites."
            #     ],
            #     "# The Blender Community | Getting Support": [
            #         "Blender's community is one of its greatest features."
            #     ]
            # }
        """
        # Remove patterns ".. word::" and ":word:"
        text = re.sub(r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', text)

        # Regular expression to find titles and subtitles
        pattern = r'([\*|#|%]{3,}\n[^\n]+\n[\*|#|%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'

        # Split text by found patterns
        sections = re.split(pattern, text)

        # Drop sections that are empty or whitespace-only
        sections = [section for section in sections if section.strip()]

        # Separate sections into a dictionary
        topics = {}
        current_title = ''
        current_topic = prefix

        for section in sections:
            if match := re.match(r'[\*|#|%]{3,}\n([^\n]+)\n[\*|#|%]{3,}', section):
                current_topic = current_title = f'{prefix}# {match.group(1)}'
                topics[current_topic] = []
            elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
                current_topic = current_title + ' | ' + match.group(1)
                topics[current_topic] = []
            else:
                if current_topic == prefix:
                    raise ValueError("Content found before the first title")
                topics[current_topic].append(section)

        return topics

    @classmethod
    def split_into_many(cls, page_body, prefix=''):
        """
        Split the text into chunks of a maximum number of tokens.
        """
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
        topics = cls.split_into_topics(page_body, prefix)

        for topic, content_list in topics.items():
            title = topic + ':\n'
            title_tokens_len = len(tokenizer.tokenize(title))
            content_list_new = []
            for content in content_list:
                content_reduced = cls.reduce_text(content)
                content_tokens_len = len(tokenizer.tokenize(content_reduced))
                if title_tokens_len + content_tokens_len <= max_tokens:
                    content_list_new.append(content_reduced)
                    continue

                # Split the text into sentences
                paragraphs = content_reduced.split('.\n')
                sentences = ''
                tokens_so_far = title_tokens_len

                # Loop through the sentences, accumulating a chunk until it
                # would exceed the token budget
                for sentence in paragraphs:
                    sentence += '.\n'

                    # Get the number of tokens for each sentence
                    n_tokens = len(tokenizer.tokenize(sentence))

                    # If the number of tokens so far plus the number of tokens
                    # in the current sentence is greater than the max number of
                    # tokens, then add the chunk to the list of chunks and
                    # reset the chunk and tokens so far
                    if tokens_so_far + n_tokens > max_tokens:
                        content_list_new.append(sentences)
                        sentences = ''
                        tokens_so_far = title_tokens_len

                    sentences += sentence
                    tokens_so_far += n_tokens

                if sentences:
                    content_list_new.append(sentences)

            # Replace content_list
            content_list.clear()
            content_list.extend(content_list_new)

        result = []
        for topic, content_list in topics.items():
            for content in content_list:
                result.append(topic + ':\n' + content)

        return result

    @classmethod
    def get_texts_recursive(cls, page, path='index'):
        result = cls.split_into_many(page['body'], path)

        try:
            for key in page['toctree'].keys():
                page_child = page['toctree'][key]
                result.extend(cls.get_texts_recursive(
                    page_child, path.replace('index', key)))
        except KeyError:
            pass

        return result

    @classmethod
    def manual_get_texts_to_embed(cls):
        manual = cls.parse_file_recursive(
            os.path.join(MANUAL_DIR, 'index.rst'))
        manual['toctree']["copyright"] = cls.parse_file_recursive(
            os.path.join(MANUAL_DIR, 'copyright.rst'))

        return cls.get_texts_recursive(manual)

    @classmethod
    def wiki_get_texts_to_embed(cls):
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length

        texts = []
        owner = "blender"
        repo = "blender"
        pages = gitea_wiki_pages_get(owner, repo)
        for page_name in pages:
            page_name_title = page_name["title"]
            page = gitea_wiki_page_get(owner, repo, page_name_title)
            prefix = f'/{page["sub_url"]}\n# {page_name_title}:'
            text = base64.b64decode(page["content_base64"]).decode('utf-8')
            text = text.replace(
                'https://projects.blender.org/blender/blender', '')

            tokens_prefix_len = len(tokenizer.tokenize(prefix))
            tokens_so_far = tokens_prefix_len
            text_so_far = prefix
            text_parts = text.split('\n#')
            for part in text_parts:
                part = '\n#' + part
                part_tokens_len = len(tokenizer.tokenize(part))
                if tokens_so_far + part_tokens_len > max_tokens:
                    texts.append(text_so_far)
                    text_so_far = prefix
                    tokens_so_far = tokens_prefix_len
                text_so_far += part
                tokens_so_far += part_tokens_len

            if tokens_so_far != tokens_prefix_len:
                texts.append(text_so_far)

        return texts

    def _sort_similarity(self, text_to_search, group: Group = Group.all, limit=4):
        result = []
        query_emb = EMBEDDING_CTX.encode([text_to_search])
        ret = {}
        for grp in list(Group)[:-1]:  # Skip `Group.all`.
            if group in {grp, Group.all}:
                ret[grp] = util.semantic_search(
                    query_emb, self[grp]['embeddings'],
                    top_k=limit, score_function=util.dot_score)

        # Pick the group whose best hit has the highest score.
        score_best = 0.0
        group_best = None
        for grp, val in ret.items():
            score_curr = val[0][0]['score']
            if score_curr > score_best:
                score_best = score_curr
                group_best = grp

        texts = self[group_best]['texts']
        for score in ret[group_best][0]:
            corpus_id = score['corpus_id']
            text = texts[corpus_id]
            result.append(text)

        return result, group_best


G_data = _Data()

router = APIRouter()


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(query: str = "", group: Group = Group.all) -> str:
    base_url = {
        Group.wiki: "https://projects.blender.org/blender/blender",
        Group.manual: "https://docs.blender.org/manual/en/dev"
    }
    texts, group_best = G_data._sort_similarity(query, group)
    result = f'BASE_URL: {base_url[group_best]}\n'
    for text in texts:
        if group_best == Group.wiki:
            result += f'''---
{text}

'''
        else:
            index = text.find('#')
            result += f'''---
{text[:index] + '.html'}
{text[index:]}

'''
    return result


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual",
             "Bisect Object", "Who are the Triagers"]
    result = wiki_search(tests[1], Group.all)
    print(result)