# routers/wiki_search.py
import os
import pickle
import re
import torch
from typing import Dict, List
from sentence_transformers import util
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

try:
    from .embedding import EMBEDDING_CTX
except ImportError:
    from embedding import EMBEDDING_CTX

router = APIRouter()

MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
BASE_URL = "https://docs.blender.org/manual/en/dev"


class _Data(dict):
    cache_path = "routers/embedding/embeddings_manual.pkl"

    @staticmethod
    def reduce_text(text):
        # Strip reStructuredText title/section adornment runs.
        text = re.sub(r'%{2,}', '', text)
        text = re.sub(r'#{2,}', '', text)
        text = re.sub(r'\*{3,}', '', text)
        text = re.sub(r'={3,}', '', text)
        text = re.sub(r'\^{3,}', '', text)
        text = re.sub(r'-{3,}', '', text)
        # Collapse runs of blank lines and surrounding whitespace.
        text = re.sub(r'(\s*\n\s*)+', '\n', text)
        return text

    @classmethod
    def parse_file_recursive(cls, filedir, filename):
        with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
            content = file.read()

        parsed_data = {}

        if not filename.endswith('index.rst'):
            body = content.strip()
        else:
            parts = content.split(".. toctree::")
            body = parts[0].strip()

            if len(parts) > 1:
                parsed_data["toctree"] = {}
                for part in parts[1:]:
                    toctree_entries = part.split('\n')
                    for entry in toctree_entries[1:]:
                        entry = entry.strip()
                        if not entry:
                            continue

                        if entry.startswith('/'):
                            # Path given relative to the documentation root; skip it.
                            continue

                        if not entry.endswith('.rst'):
                            continue

                        if entry.endswith('/index.rst'):
                            entry_name = entry[:-len('/index.rst')]
                            filedir_ = os.path.join(filedir, entry_name)
                            filename_ = 'index.rst'
                        else:
                            entry_name = entry[:-len('.rst')]
                            filedir_ = filedir
                            filename_ = entry

                        parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
                            filedir_, filename_)

        # A trailing '\n' keeps the title/section regexes matching at the end of the file.
        parsed_data['body'] = body + '\n'

        return parsed_data
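
    # Illustrative shape of the mapping returned by parse_file_recursive() for an
    # index.rst that contains a toctree (the page names below are hypothetical):
    #
    #   {
    #       'body': 'Text of index.rst before the first ".. toctree::"\n',
    #       'toctree': {
    #           'getting_started': {'body': '...\n', 'toctree': {...}},
    #           'editors': {'body': '...\n'},
    #       },
    #   }
    #
    # Pages without a toctree only carry the 'body' key.
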
    @staticmethod
    def split_into_topics(text: str, prefix: str = '') -> Dict[str, List[str]]:
        """
        Split a text into sections based on titles and subtitles and organize them into a dictionary.

        Args:
            text (str): The input text to be split. Titles are marked by adornment lines of
                        ``*``, ``#`` or ``%``; subtitles are underlined with ``=``.
            prefix (str): Prefix added to every title and subtitle key.

        Returns:
            Dict[str, List[str]]: A dictionary where keys are section titles or subtitles, and
                                  values are lists of strings with the content under each one.

        Example:
            text = '''
            *********************
            The Blender Community
            *********************

            Being freely available from the start.

            Independent Sites
            =================

            There are `several independent websites.

            Getting Support
            ===============

            Blender's community is one of its greatest features.
            '''

            result = split_into_topics(text)
            # result will be:
            # {
            #     "# The Blender Community": [
            #         "Being freely available from the start."
            #     ],
            #     "# The Blender Community | Independent Sites": [
            #         "There are `several independent websites."
            #     ],
            #     "# The Blender Community | Getting Support": [
            #         "Blender's community is one of its greatest features."
            #     ]
            # }
        """
        # Remove directive blocks (".. word::" plus indented body) and ":word:" roles.
        text = re.sub(r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', text)

        # Regular expression to find titles and subtitles.
        pattern = r'([*#%]{3,}\n[^\n]+\n[*#%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'

        # Split the text by the patterns found.
        sections = re.split(pattern, text)

        # Discard sections that are only whitespace.
        sections = [section for section in sections if section.strip()]

        # Organize the sections into a dictionary.
        topics = {}
        current_title = ''
        current_topic = prefix

        for section in sections:
            if match := re.match(r'[*#%]{3,}\n([^\n]+)\n[*#%]{3,}', section):
                current_topic = current_title = f'{prefix}# {match.group(1)}'
                topics[current_topic] = []
            elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
                current_topic = current_title + ' | ' + match.group(1)
                topics[current_topic] = []
            else:
                if current_topic == prefix:
                    raise ValueError("Content found before the first title")
                topics[current_topic].append(section)

        return topics

    @classmethod
    def split_into_many(cls, page_body, prefix=''):
        """
        Split the page text into chunks that fit within the embedding model's maximum number of tokens.
        """
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
        topics = cls.split_into_topics(page_body, prefix)

        for topic, content_list in topics.items():
            title = topic + ':\n'
            title_tokens_len = len(tokenizer.tokenize(title))
            content_list_new = []
            for content in content_list:
                content_reduced = cls.reduce_text(content)
                content_tokens_len = len(tokenizer.tokenize(content_reduced))
                if title_tokens_len + content_tokens_len <= max_tokens:
                    content_list_new.append(content_reduced)
                    continue

                # Content is too long: split it at sentence-ending newlines.
                paragraphs = content_reduced.split('.\n')
                sentences = ''
                tokens_so_far = title_tokens_len

                # Accumulate paragraphs into chunks that stay under the token limit.
                for sentence in paragraphs:
                    sentence += '.\n'

                    # Number of tokens in the current paragraph.
                    n_tokens = len(tokenizer.tokenize(sentence))

                    # If adding this paragraph would exceed the limit,
                    # close the current chunk and start a new one.
                    if tokens_so_far + n_tokens > max_tokens:
                        content_list_new.append(sentences)
                        sentences = ''
                        tokens_so_far = title_tokens_len

                    sentences += sentence
                    tokens_so_far += n_tokens

                if sentences:
                    content_list_new.append(sentences)

            # Replace the original content with the re-chunked content.
            content_list.clear()
            content_list.extend(content_list_new)

        result = []
        for topic, content_list in topics.items():
            for content in content_list:
                result.append(topic + ':\n' + content)

        return result

    @classmethod
    def get_texts_recursive(cls, page, path=''):
        result = cls.split_into_many(page['body'], path)

        for key, page_child in page.get('toctree', {}).items():
            result.extend(cls.get_texts_recursive(page_child, f'{path}/{key}'))

        return result
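
    # Each string produced by get_texts_recursive() is one embedding chunk of the
    # form below (the path and headings are placeholders, not real manual pages):
    #
    #   '/path/to/page# Page Title | Section Title:\n<section text>'
    #
    # Everything before the first '#' is the page path that wiki_search() later
    # turns into a '.html' link relative to BASE_URL.
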
    def _embeddings_generate(self):
        # Reuse the cached embeddings if they were generated on a previous run.
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as file:
                data = pickle.load(file)
            self.update(data)
            return self

        # Generate the embeddings from the manual sources.
        manual = self.parse_file_recursive(MANUAL_DIR, 'index.rst')
        manual['toctree']["copyright"] = self.parse_file_recursive(
            MANUAL_DIR, 'copyright.rst')

        # Collect the text chunks of every page.
        texts = self.get_texts_recursive(manual)

        print("Embedding Texts...")
        self['texts'] = texts
        self['embeddings'] = EMBEDDING_CTX.encode(texts)

        with open(self.cache_path, "wb") as file:
            # Move the embeddings to the CPU before pickling, as the virtual
            # machine currently in use only supports the CPU.
            self['embeddings'] = self['embeddings'].to(torch.device('cpu'))
            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)

        return self

    def _sort_similarity(self, text_to_search, limit):
        results = []

        query_emb = EMBEDDING_CTX.encode([text_to_search])
        ret = util.semantic_search(
            query_emb, self['embeddings'], top_k=limit, score_function=util.dot_score)

        texts = self['texts']
        for score in ret[0]:
            corpus_id = score['corpus_id']
            text = texts[corpus_id]
            results.append(text)

        return results


G_data = _Data()


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(query: str = "") -> str:
    G_data._embeddings_generate()
    texts = G_data._sort_similarity(query, 5)

    result = f'BASE_URL: {BASE_URL}\n'
    for text in texts:
        index = text.find('#')
        result += f'''---
{text[:index] + '.html'}

{text[index:]}

'''

    return result


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual", "Bisect Object"]
    result = wiki_search(tests[0])
    print(result)
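
# Example request and response shape (illustrative; assumes this router is
# mounted at the application root):
#
#   GET /wiki_search?query=Bisect%20Object
#
#   BASE_URL: https://docs.blender.org/manual/en/dev
#   ---
#   /path/to/page.html
#
#   # Page Title | Section Title:
#   <matching section text>
#   ...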