# routers/tool_wiki_search.py
import os
import pickle
import re
import torch
from typing import Dict, List
from sentence_transformers import util
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
try:
    from .embedding import EMBEDDING_CTX
except ImportError:
    from embedding import EMBEDDING_CTX
router = APIRouter()
MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
BASE_URL = "https://docs.blender.org/manual/en/dev"
G_data = None
class _Data(dict):
cache_path = "routers/embedding/embeddings_manual.pkl"
@staticmethod
def reduce_text(text):
# Remove repeated characters
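        # Illustrative example (not taken from the manual): reduce_text("Intro\n*****\n\n\ntext")
        # collapses the marker line and the blank lines, yielding "Intro\ntext".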
text = re.sub(r'%{2,}', '', text) # Title
text = re.sub(r'#{2,}', '', text) # Title
text = re.sub(r'\*{3,}', '', text) # Title
text = re.sub(r'={3,}', '', text) # Topic
text = re.sub(r'\^{3,}', '', text)
text = re.sub(r'-{3,}', '', text)
text = re.sub(r'(\s*\n\s*)+', '\n', text)
return text
@classmethod
def parse_file_recursive(cls, filedir, filename):
with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
content = file.read()
parsed_data = {}
if not filename.endswith('index.rst'):
body = content.strip()
else:
parts = content.split(".. toctree::")
body = parts[0].strip()
if len(parts) > 1:
parsed_data["toctree"] = {}
for part in parts[1:]:
toctree_entries = part.split('\n')
line = toctree_entries[0]
for entry in toctree_entries[1:]:
entry = entry.strip()
if not entry:
continue
if entry.startswith('/'):
                            # Entry is rooted at the manual source root rather than the
                            # current directory; not handled here, skip it.
continue
if not entry.endswith('.rst'):
continue
if entry.endswith('/index.rst'):
entry_name = entry[:-10]
filedir_ = os.path.join(filedir, entry_name)
filename_ = 'index.rst'
else:
entry_name = entry[:-4]
filedir_ = filedir
filename_ = entry
parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
filedir_, filename_)
        # A trailing '\n' ensures the title/topic regex patterns also match at the end of the file
parsed_data['body'] = body + '\n'
return parsed_data
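    # The method above returns a nested mapping shaped like (illustrative):
    #   {'body': '<page text>\n', 'toctree': {'<entry name>': {'body': ..., 'toctree': ...}}}
    # where 'toctree' is only present for index pages that reference child documents.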
@staticmethod
def split_into_topics(text: str, prefix: str = '') -> Dict[str, List[str]]:
"""
Splits a text into sections based on titles and subtitles, and organizes them into a dictionary.
Args:
text (str): The input text to be split. The text should contain titles marked by asterisks (***)
or subtitles marked by equal signs (===).
            prefix (str): Prefix prepended to the titles and subtitles.
Returns:
Dict[str, List[str]]: A dictionary where keys are section titles or subtitles, and values are lists of
strings corresponding to the content under each title or subtitle.
Example:
text = '''
*********************
The Blender Community
*********************
Being freely available from the start.
Independent Sites
=================
There are `several independent websites.
Getting Support
===============
Blender's community is one of its greatest features.
'''
            result = split_into_topics(text)
# result will be:
# {
# "# The Blender Community": [
# "Being freely available from the start."
# ],
# "# The Blender Community | Independent Sites": [
# "There are `several independent websites."
# ],
# "# The Blender Community | Getting Support": [
# "Blender's community is one of its greatest features."
# ]
# }
"""
        # Remove directives like ".. word::" (including their indented bodies) and ":word:" roles
text = re.sub(r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', text)
# Regular expression to find titles and subtitles
        pattern = r'([*#%]{3,}\n[^\n]+\n[*#%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'
# Split text by found patterns
sections = re.split(pattern, text)
# Remove possible white spaces at the beginning and end of each section
sections = [section for section in sections if section.strip()]
# Separate sections into a dictionary
topics = {}
current_title = ''
current_topic = prefix
for section in sections:
            if match := re.match(r'[*#%]{3,}\n([^\n]+)\n[*#%]{3,}', section):
current_topic = current_title = f'{prefix}# {match.group(1)}'
topics[current_topic] = []
elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
current_topic = current_title + ' | ' + match.group(1)
topics[current_topic] = []
else:
if current_topic == prefix:
                    raise ValueError("Content found before the first title or subtitle")
topics[current_topic].append(section)
return topics
@classmethod
def split_into_many(cls, page_body, prefix=''):
"""
# Function to split the text into chunks of a maximum number of tokens
"""
tokenizer = EMBEDDING_CTX.model.tokenizer
max_tokens = EMBEDDING_CTX.model.max_seq_length
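        # EMBEDDING_CTX.model is presumably a SentenceTransformer, so its tokenizer and
        # max_seq_length define the token budget each embedded chunk has to fit into.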
topics = cls.split_into_topics(page_body, prefix)
for topic, content_list in topics.items():
title = topic + ':\n'
title_tokens_len = len(tokenizer.tokenize(title))
content_list_new = []
for content in content_list:
content_reduced = cls.reduce_text(content)
content_tokens_len = len(tokenizer.tokenize(content_reduced))
if title_tokens_len + content_tokens_len <= max_tokens:
content_list_new.append(content_reduced)
continue
                # Split the text into paragraphs at sentence-ending line breaks
paragraphs = content_reduced.split('.\n')
sentences = ''
tokens_so_far = title_tokens_len
                # Loop through the paragraphs, accumulating them into chunks within the token budget
for sentence in paragraphs:
sentence += '.\n'
# Get the number of tokens for each sentence
n_tokens = len(tokenizer.tokenize(sentence))
                    # If adding this paragraph would exceed the token budget, flush the
                    # current chunk and start a new one
if tokens_so_far + n_tokens > max_tokens:
content_list_new.append(sentences)
sentences = ''
tokens_so_far = title_tokens_len
sentences += sentence
tokens_so_far += n_tokens
if sentences:
content_list_new.append(sentences)
# Replace content_list
content_list.clear()
content_list.extend(content_list_new)
result = []
for topic, content_list in topics.items():
for content in content_list:
result.append(topic + ':\n' + content)
return result
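    # Each element returned above has the form '<topic>:\n<chunk>' (see split_into_topics for
    # how topics are named), with the chunk sized so that topic plus chunk fit in max_tokens.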
@classmethod
def get_texts_recursive(cls, page, path=''):
result = cls.split_into_many(page['body'], path)
try:
for key in page['toctree'].keys():
page_child = page['toctree'][key]
result.extend(cls.get_texts_recursive(
page_child, f'{path}/{key}'))
except KeyError:
pass
return result
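    # Illustrative: for a page reached through the toctree keys 'getting_started' -> 'about'
    # (hypothetical entries), the chunks are prefixed with '/getting_started/about# <Title>',
    # which wiki_search later maps to the page URL '<BASE_URL>/getting_started/about.html'.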
def _embeddings_generate(self):
if os.path.exists(self.cache_path):
with open(self.cache_path, 'rb') as file:
data = pickle.load(file)
self.update(data)
return self
        # Generate the embeddings from the manual sources
manual = self.parse_file_recursive(MANUAL_DIR, 'index.rst')
manual['toctree']["copyright"] = self.parse_file_recursive(
MANUAL_DIR, 'copyright.rst')
        # Collect the text chunks from every page of the manual
texts = self.get_texts_recursive(manual)
print("Embedding Texts...")
self['texts'] = texts
self['embeddings'] = EMBEDDING_CTX.encode(texts)
with open(self.cache_path, "wb") as file:
            # Move the embeddings to the CPU, as the virtual machine in use currently only supports the CPU.
self['embeddings'] = self['embeddings'].to(torch.device('cpu'))
pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)
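        # Later runs load this pickle instead of re-parsing the manual; deleting the file at
        # cache_path forces the embeddings to be regenerated.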
        return self
def _sort_similarity(self, text_to_search, limit):
results = []
query_emb = EMBEDDING_CTX.encode([text_to_search])
ret = util.semantic_search(
query_emb, self['embeddings'], top_k=limit, score_function=util.dot_score)
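        # dot_score ranks by raw dot product; it only matches cosine similarity if the
        # embeddings are normalized, which EMBEDDING_CTX.encode is assumed to do here.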
texts = self['texts']
for score in ret[0]:
corpus_id = score['corpus_id']
text = texts[corpus_id]
results.append(text)
return results
G_data = _Data()
@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(query: str = "") -> str:
    data = G_data._embeddings_generate()
    texts = data._sort_similarity(query, 5)
result = f'BASE_URL: {BASE_URL}\n'
for text in texts:
index = text.find('#')
result += f'''---
{text[:index] + '.html'}
{text[index:]}
'''
return result
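# Illustrative response (the page path shown is hypothetical):
#   BASE_URL: https://docs.blender.org/manual/en/dev
#   ---
#   /modeling/meshes/editing/mesh/bisect.html
#   # Bisect | ...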
if __name__ == '__main__':
tests = ["Set Snap Base", "Building the Manual", "Bisect Object"]
result = wiki_search(tests[0])
print(result)
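    # Optional sketch: run the remaining test queries as well (assumes the embeddings cache
    # at _Data.cache_path exists or can be built from a local checkout of the manual).
    for test in tests[1:]:
        print(wiki_search(test))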