# routers/wiki_search.py

import os
import pickle
import re
import torch
from typing import Dict, List
from sentence_transformers import util
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

try:
    from .embedding import EMBEDDING_CTX
except ImportError:
    from embedding import EMBEDDING_CTX

router = APIRouter()

MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
BASE_URL = "https://docs.blender.org/manual/en/dev"

G_data = None


class _Data(dict):
    cache_path = "routers/embedding/embeddings_manual.pkl"

    @staticmethod
    def reduce_text(text):
        # Remove repeated heading-marker characters
        text = re.sub(r'%{2,}', '', text)   # Title
        text = re.sub(r'#{2,}', '', text)   # Title
        text = re.sub(r'\*{3,}', '', text)  # Title
        text = re.sub(r'={3,}', '', text)   # Topic
        text = re.sub(r'\^{3,}', '', text)
        text = re.sub(r'-{3,}', '', text)

        # Collapse runs of blank lines and surrounding whitespace
        text = re.sub(r'(\s*\n\s*)+', '\n', text)
        return text
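
    # For example, reduce_text("*****\nTitle\n*****\n\nBody\n") strips the
    # asterisk rules and collapses the blank lines, returning "\nTitle\nBody\n".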

    @classmethod
    def parse_file_recursive(cls, filedir, filename):
        with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
            content = file.read()

        parsed_data = {}

        if not filename.endswith('index.rst'):
            body = content.strip()
        else:
            parts = content.split(".. toctree::")
            body = parts[0].strip()

            if len(parts) > 1:
                parsed_data["toctree"] = {}
                for part in parts[1:]:
                    # The first line is the remainder of the directive line;
                    # the toctree entries follow on the lines after it.
                    toctree_entries = part.split('\n')
                    for entry in toctree_entries[1:]:
                        entry = entry.strip()
                        if not entry:
                            continue

                        if entry.startswith('/'):
                            # Absolute path; only relative entries are handled here.
                            continue

                        if not entry.endswith('.rst'):
                            continue

                        if entry.endswith('/index.rst'):
                            entry_name = entry[:-10]
                            filedir_ = os.path.join(filedir, entry_name)
                            filename_ = 'index.rst'
                        else:
                            entry_name = entry[:-4]
                            filedir_ = filedir
                            filename_ = entry

                        parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
                            filedir_, filename_)

        # The trailing '\n' keeps the regex patterns matching at the end of the file
        parsed_data['body'] = body + '\n'

        return parsed_data

    @staticmethod
    def split_into_topics(text: str, prefix: str = '') -> Dict[str, List[str]]:
        """
        Splits a text into sections based on titles and subtitles, and organizes them into a dictionary.

        Args:
            text (str): The input text to be split. The text should contain titles marked by asterisks (***)
                or subtitles marked by equal signs (===).
            prefix (str): Prefix prepended to each title and subtitle key.

        Returns:
            Dict[str, List[str]]: A dictionary where keys are section titles or subtitles, and values are lists of
                strings corresponding to the content under each title or subtitle.

        Example:
            text = '''
            *********************
            The Blender Community
            *********************

            Being freely available from the start.

            Independent Sites
            =================

            There are `several independent websites.

            Getting Support
            ===============

            Blender's community is one of its greatest features.
            '''

            result = split_into_topics(text)
            # result will be:
            # {
            #     "# The Blender Community": [
            #         "Being freely available from the start."
            #     ],
            #     "# The Blender Community | Independent Sites": [
            #         "There are `several independent websites."
            #     ],
            #     "# The Blender Community | Getting Support": [
            #         "Blender's community is one of its greatest features."
            #     ]
            # }
        """
        # Remove directive blocks (".. word::") and roles (":word:")
        text = re.sub(r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', text)

        # Regular expression to find titles and subtitles
        pattern = r'([*#%]{3,}\n[^\n]+\n[*#%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'

        # Split the text by the found patterns, keeping the delimiters
        sections = re.split(pattern, text)

        # Discard sections that are empty or whitespace-only
        sections = [section for section in sections if section.strip()]

        # Organize the sections into a dictionary
        topics = {}
        current_title = ''
        current_topic = prefix

        for section in sections:
            if match := re.match(r'[*#%]{3,}\n([^\n]+)\n[*#%]{3,}', section):
                current_topic = current_title = f'{prefix}# {match.group(1)}'
                topics[current_topic] = []
            elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
                current_topic = current_title + ' | ' + match.group(1)
                topics[current_topic] = []
            else:
                if current_topic == prefix:
                    raise ValueError("Found content before the first title or subtitle")
                topics[current_topic].append(section)

        return topics

    @classmethod
    def split_into_many(cls, page_body, prefix=''):
        """
        Split the page text into chunks that fit within the embedding model's
        maximum number of tokens.
        """
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
        topics = cls.split_into_topics(page_body, prefix)

        for topic, content_list in topics.items():
            title = topic + ':\n'
            title_tokens_len = len(tokenizer.tokenize(title))

            content_list_new = []
            for content in content_list:
                content_reduced = cls.reduce_text(content)
                content_tokens_len = len(tokenizer.tokenize(content_reduced))
                if title_tokens_len + content_tokens_len <= max_tokens:
                    content_list_new.append(content_reduced)
                    continue

                # Split the text into paragraph-sized pieces (on ".\n")
                paragraphs = content_reduced.split('.\n')
                sentences = ''
                tokens_so_far = title_tokens_len

                # Accumulate pieces until the chunk would exceed the token limit
                for sentence in paragraphs:
                    sentence += '.\n'

                    # Number of tokens in the current piece
                    n_tokens = len(tokenizer.tokenize(sentence))

                    # If adding this piece would exceed the maximum number of
                    # tokens, close the current chunk and start a new one.
                    if tokens_so_far + n_tokens > max_tokens:
                        content_list_new.append(sentences)
                        sentences = ''
                        tokens_so_far = title_tokens_len

                    sentences += sentence
                    tokens_so_far += n_tokens

                if sentences:
                    content_list_new.append(sentences)

            # Replace the original content with the chunked content
            content_list.clear()
            content_list.extend(content_list_new)

        result = []
        for topic, content_list in topics.items():
            for content in content_list:
                result.append(topic + ':\n' + content)

        return result
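
    # Note: each string produced above has the form
    # "<path># <Title>[ | <Subtitle>]:\n<chunk>", so the page path and the
    # topic heading are embedded together with every chunk.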

    @classmethod
    def get_texts_recursive(cls, page, path=''):
        result = cls.split_into_many(page['body'], path)

        try:
            for key in page['toctree'].keys():
                page_child = page['toctree'][key]
                result.extend(cls.get_texts_recursive(page_child, f'{path}/{key}'))
        except KeyError:
            pass

        return result

    def _embeddings_generate(self):
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as file:
                data = pickle.load(file)
            self.update(data)
            return self

        # Generate the embeddings from the manual sources
        manual = self.parse_file_recursive(MANUAL_DIR, 'index.rst')
        manual['toctree']["copyright"] = self.parse_file_recursive(
            MANUAL_DIR, 'copyright.rst')

        # Create the list of texts to embed
        texts = self.get_texts_recursive(manual)

        print("Embedding Texts...")
        self['texts'] = texts
        self['embeddings'] = EMBEDDING_CTX.encode(texts)

        with open(self.cache_path, "wb") as file:
            # Move the embeddings to the CPU, since the virtual machine
            # currently in use only supports the CPU.
            self['embeddings'] = self['embeddings'].to(torch.device('cpu'))
            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)

        return self

    def _sort_similarity(self, text_to_search, limit):
        results = []

        query_emb = EMBEDDING_CTX.encode([text_to_search])
        ret = util.semantic_search(
            query_emb, self['embeddings'], top_k=limit, score_function=util.dot_score)

        texts = self['texts']
        for score in ret[0]:
            corpus_id = score['corpus_id']
            text = texts[corpus_id]
            results.append(text)

        return results


G_data = _Data()


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(query: str = "") -> str:
    data = G_data._embeddings_generate()
    texts = data._sort_similarity(query, 5)

    result = f'BASE_URL: {BASE_URL}\n'
    for text in texts:
        index = text.find('#')
        result += f'''---
{text[:index] + '.html'}
{text[index:]}
'''
    return result


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual", "Bisect Object"]
    result = wiki_search(tests[0])
    print(result)
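
# A minimal sketch of how this router could be mounted in a FastAPI app; the
# application module and the "/v1" prefix below are assumptions, not part of
# this file:
#
#   from fastapi import FastAPI
#   from routers.wiki_search import router as wiki_search_router
#
#   app = FastAPI()
#   app.include_router(wiki_search_router, prefix="/v1")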