# routers/wiki_search.py
import os
import pickle
import re
import torch
from typing import Dict, List
from sentence_transformers import util
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

try:
    from .embedding import EMBEDDING_CTX
except ImportError:
    from embedding import EMBEDDING_CTX
router = APIRouter()
MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
BASE_URL = "https://docs.blender.org/manual/en/dev"
G_data = None
class _Data(dict):
    cache_path = "routers/embedding/embeddings_manual.pkl"

    @staticmethod
    def reduce_text(text):
        # Remove repeated characters
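        # Illustrative sketch (not taken from the source): an RST heading block
        # such as "*****\nTitle\n*****\n\nBody." is reduced by the substitutions
        # below to "\nTitle\nBody."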
        text = re.sub(r'%{2,}', '', text)   # Title
        text = re.sub(r'#{2,}', '', text)   # Title
        text = re.sub(r'\*{3,}', '', text)  # Title
        text = re.sub(r'={3,}', '', text)   # Topic
        text = re.sub(r'\^{3,}', '', text)
        text = re.sub(r'-{3,}', '', text)
        text = re.sub(r'(\s*\n\s*)+', '\n', text)
        return text

    @classmethod
    def parse_file_recursive(cls, filedir, filename):
        with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
            content = file.read()

        parsed_data = {}
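        # parsed_data ends up as {'body': <page text>} plus, for index pages, a
        # nested 'toctree' dict mapping each child entry name to its own parsed
        # page; a rough sketch of the shape: {'body': '...', 'toctree': {'<entry>': {...}}}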
        if not filename.endswith('index.rst'):
            body = content.strip()
        else:
            parts = content.split(".. toctree::")
            body = parts[0].strip()

            if len(parts) > 1:
                parsed_data["toctree"] = {}
                for part in parts[1:]:
                    toctree_entries = part.split('\n')
                    line = toctree_entries[0]
                    for entry in toctree_entries[1:]:
                        entry = entry.strip()
                        if not entry:
                            continue

                        if entry.startswith('/'):
                            # relative path.
                            continue

                        if not entry.endswith('.rst'):
                            continue

                        if entry.endswith('/index.rst'):
                            entry_name = entry[:-10]
                            filedir_ = os.path.join(filedir, entry_name)
                            filename_ = 'index.rst'
                        else:
                            entry_name = entry[:-4]
                            filedir_ = filedir
                            filename_ = entry

                        parsed_data['toctree'][entry_name] = cls.parse_file_recursive(
                            filedir_, filename_)

        # The trailing '\n' keeps the regex patterns matching at the end of the body
        parsed_data['body'] = body + '\n'
        return parsed_data

    @staticmethod
    def split_into_topics(text: str, prefix: str = '') -> Dict[str, List[str]]:
        """
        Splits a text into sections based on titles and subtitles, and organizes them into a dictionary.

        Args:
            text (str): The input text to be split. The text should contain titles marked by asterisks (***)
                        or subtitles marked by equal signs (===).
            prefix (str): Prefix prepended to titles and subtitles.

        Returns:
            Dict[str, List[str]]: A dictionary where keys are section titles or subtitles, and values are lists of
                                  strings corresponding to the content under each title or subtitle.

        Example:
            text = '''
            *********************
            The Blender Community
            *********************
            Being freely available from the start.

            Independent Sites
            =================
            There are `several independent websites.

            Getting Support
            ===============
            Blender's community is one of its greatest features.
            '''

            result = split_into_topics(text)

            # result will be:
            # {
            #     "# The Blender Community": [
            #         "Being freely available from the start."
            #     ],
            #     "# The Blender Community | Independent Sites": [
            #         "There are `several independent websites."
            #     ],
            #     "# The Blender Community | Getting Support": [
            #         "Blender's community is one of its greatest features."
            #     ]
            # }
        """
        # Remove patterns ".. word::" and ":word:"
        text = re.sub(r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', text)
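        # For example, directives such as ".. note::" (with their indented
        # bodies) and inline roles such as ":ref:" or ":kbd:" are stripped here.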

        # Regular expression to find titles and subtitles
        pattern = r'([\*|#|%]{3,}\n[^\n]+\n[\*|#|%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'

        # Split text by found patterns
        sections = re.split(pattern, text)

        # Drop sections that are empty or whitespace-only
        sections = [section for section in sections if section.strip()]

        # Separate sections into a dictionary
        topics = {}
        current_title = ''
        current_topic = prefix
        for section in sections:
            if match := re.match(r'[\*|#|%]{3,}\n([^\n]+)\n[\*|#|%]{3,}', section):
                current_topic = current_title = f'{prefix}# {match.group(1)}'
                topics[current_topic] = []
            elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
                current_topic = current_title + ' | ' + match.group(1)
                topics[current_topic] = []
            else:
                if current_topic == prefix:
                    raise ValueError("Encountered content before the first title")
                topics[current_topic].append(section)

        return topics

    @classmethod
    def split_into_many(cls, page_body, prefix=''):
        """
        Split the text into chunks that fit within the model's maximum number of tokens.
        """
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
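        # max_seq_length is the embedding model's token window (e.g. 256 or 512
        # for common sentence-transformers models); chunks are kept within this
        # budget so nothing is truncated at encode time.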

        topics = cls.split_into_topics(page_body, prefix)

        for topic, content_list in topics.items():
            title = topic + ':\n'
            title_tokens_len = len(tokenizer.tokenize(title))

            content_list_new = []
            for content in content_list:
                content_reduced = cls.reduce_text(content)
                content_tokens_len = len(tokenizer.tokenize(content_reduced))
                if title_tokens_len + content_tokens_len <= max_tokens:
                    content_list_new.append(content_reduced)
                    continue

                # Split the text into sentence-like pieces (on '.\n')
                paragraphs = content_reduced.split('.\n')
                sentences = ''
                tokens_so_far = title_tokens_len

                # Accumulate sentences into chunks that stay within the token budget
                for sentence in paragraphs:
                    sentence += '.\n'

                    # Get the number of tokens for each sentence
                    n_tokens = len(tokenizer.tokenize(sentence))

                    # If adding this sentence would exceed the limit, flush the
                    # current chunk and start a new one (the title is always counted)
                    if tokens_so_far + n_tokens > max_tokens:
                        content_list_new.append(sentences)
                        sentences = ''
                        tokens_so_far = title_tokens_len

                    sentences += sentence
                    tokens_so_far += n_tokens

                if sentences:
                    content_list_new.append(sentences)

            # Replace content_list with the re-chunked contents
            content_list.clear()
            content_list.extend(content_list_new)

        result = []
        for topic, content_list in topics.items():
            for content in content_list:
                result.append(topic + ':\n' + content)

        return result

    @classmethod
    def get_texts_recursive(cls, page, path=''):
        result = cls.split_into_many(page['body'], path)

        try:
            for key in page['toctree'].keys():
                page_child = page['toctree'][key]
                result.extend(cls.get_texts_recursive(
                    page_child, f'{path}/{key}'))
        except KeyError:
            pass

        return result

    def _embeddings_generate(self):
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as file:
                data = pickle.load(file)
            self.update(data)
            return self

        # Generate the embeddings from the manual sources
        manual = self.parse_file_recursive(MANUAL_DIR, 'index.rst')
        manual['toctree']["copyright"] = self.parse_file_recursive(
            MANUAL_DIR, 'copyright.rst')

        # Collect the text chunks of every page
        texts = self.get_texts_recursive(manual)

        print("Embedding Texts...")
        self['texts'] = texts
        self['embeddings'] = EMBEDDING_CTX.encode(texts)

        with open(self.cache_path, "wb") as file:
            # Move the embeddings to the CPU, as the virtual machine in use
            # currently only supports the CPU.
            self['embeddings'] = self['embeddings'].to(torch.device('cpu'))
            pickle.dump(dict(self), file, protocol=pickle.HIGHEST_PROTOCOL)

        return self

    def _sort_similarity(self, text_to_search, limit):
        results = []

        query_emb = EMBEDDING_CTX.encode([text_to_search])
        ret = util.semantic_search(
            query_emb, self['embeddings'], top_k=limit, score_function=util.dot_score)
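        # util.semantic_search returns one list of hits per query; each hit is a
        # dict with 'corpus_id' and 'score', sorted by decreasing score.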

        texts = self['texts']
        for score in ret[0]:
            corpus_id = score['corpus_id']
            text = texts[corpus_id]
            results.append(text)

        return results
G_data = _Data()
@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(query: str = "") -> str:
    data = G_data._embeddings_generate()
    texts = G_data._sort_similarity(query, 5)
    result = f'BASE_URL: {BASE_URL}\n'
    for text in texts:
        index = text.find('#')
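        # Each chunk starts with its page path followed by "# Title | Subtitle",
        # so the text before '#' maps to the page URL and the rest is the heading.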
        result += f'''---
{text[:index] + '.html'}
{text[index:]}
'''

    return result
if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual", "Bisect Object"]
    result = wiki_search(tests[0])
    print(result)