# routers/tool_wiki_search.py
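"""Semantic search over the Blender user manual and developer documentation.

Each document group is split into chunks, embedded, and cached to disk as a
pickle; queries are served as plain text through the ``/wiki_search`` endpoint.
"""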

import base64
import os
import pickle
import re
import torch
from enum import Enum
from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
from heapq import nlargest
from sentence_transformers import util
from typing import Dict, List, Tuple, Set, LiteralString

try:
    from .rag import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except ImportError:
    # Fall back to absolute imports when the module is run as a standalone script.
    from rag import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get


# Local checkouts of the documentation sources to embed.
MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"


# Document collections available for search; the wiki group is currently disabled.
class Group(str, Enum):
    dev_docs = "dev_docs"
    # wiki = "wiki"
    manual = "manual"


GROUPS_DEFAULT = {Group.dev_docs, Group.manual}


class _Data(dict):
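    """Per-group store mapping a group name to its chunked texts and embeddings."""
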
    cache_path = "routers/rag/embeddings_{}.pkl"

    def __init__(self):
        for grp in list(Group):
            cache_path = self.cache_path.format(grp.name)
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as file:
                    self[grp.name] = pickle.load(file)
                continue

            # Generate the embeddings and cache them on disk.
            print("Embedding Texts for", grp.name)
            self[grp.name] = {}

            # Collect the texts to embed for this group.
            if grp is Group.dev_docs:
                texts = self.docs_get_texts_to_embed()
            # elif grp is Group.wiki:
                # texts = self.wiki_get_texts_to_embed()
            else:
                texts = self.manual_get_texts_to_embed()

            self[grp.name]['texts'] = texts
            self[grp.name]['embeddings'] = EMBEDDING_CTX.encode(texts)

            with open(cache_path, "wb") as file:
                # Move the embeddings to the CPU before pickling, as the
                # virtual machine currently in use only supports the CPU.
                self[grp.name]['embeddings'] = self[grp.name]['embeddings'].to(
                    torch.device('cpu'))

                pickle.dump(self[grp.name], file,
                            protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def manual_get_texts_to_embed(cls):
        class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Strip runs of '^' and '-' (RST heading underlines and transitions)
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)

                text = text.replace('.rst', '.html')
                text = super().reduce_text(text)
                return text

            def embedding_header(self, rel_path, titles):
                rel_path = rel_path.replace('.rst', '.html')
                return super().embedding_header(rel_path, titles)

        # Remove patterns ".. word::" and ":word:"
        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
        patterns_titles = (
            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}', r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')

        return SplitManual().split_for_embedding(
            MANUAL_DIR,
            pattern_content_sub=pattern_content_sub,
            patterns_titles=patterns_titles,
        )

    @staticmethod
    def wiki_get_texts_to_embed():
        class SplitWiki(SplitDocs):
            def split_in_topics(_self,
                                filedir: LiteralString = None,
                                *,
                                pattern_filename=None,
                                pattern_content_sub=None,
                                patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
                    titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
                    yield (rel_dir, titles, text)

            def reduce_text(_self, text):
                text = super().reduce_text(text)
                text = text.replace('https://projects.blender.org', '')
                return text

        return SplitWiki().split_for_embedding()

    @staticmethod
    def docs_get_texts_to_embed():
        class SplitBlenderDocs(SplitDocs):
            def reduce_text(_self, text):
                text = super().reduce_text(text)
                # Remove ".md" or "index.md" so paths match the published URLs
                text = re.sub(r'(index)?\.md', '', text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = re.sub(r'(index)?\.md', '', rel_path)
                return super().embedding_header(rel_path, titles)

        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)

    def _sort_similarity(
            self,
            text_to_search: str,
            groups: Set[Group] = GROUPS_DEFAULT,
            limit: int = 5) -> List[str]:
        base_url: Dict[Group, str] = {
            Group.dev_docs: "https://developer.blender.org/docs",
            # Group.wiki: "https://projects.blender.org",
            Group.manual: "https://docs.blender.org/manual/en/dev"
        }
        query_emb = EMBEDDING_CTX.encode([text_to_search])
        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
            if grp.name not in self:
                continue

            # Retrieve the `limit` best matching chunks for this group.
            search_results = util.semantic_search(
                query_emb, self[grp.name]['embeddings'], top_k=limit,
                score_function=util.dot_score)

            for score in search_results[0]:
                corpus_id = score['corpus_id']
                text = self[grp.name]['texts'][corpus_id]
                results.append((score['score'], text, grp))

        # Keep only the top `limit` results
        top_results = nlargest(limit, results, key=lambda x: x[0])

        # Extract sorted texts with base URL
        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]

        return sorted_texts


# Load (or generate) all group embeddings once, at import time.
G_data = _Data()

router = APIRouter()


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(
    query: str = "",
    groups: Set[Group] = Query(default=GROUPS_DEFAULT)
) -> str:
    try:
        groups = GROUPS_DEFAULT.intersection(groups)
        if not groups:
            groups = GROUPS_DEFAULT
    except TypeError:
        # `groups` is the raw Query default when the function is called
        # directly (outside FastAPI), e.g. in the test block below.
        groups = GROUPS_DEFAULT

    texts = G_data._sort_similarity(query, groups)
    result: str = ''
    for text in texts:
        result += f'\n---\n{text}'
    return result
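

# Example request (assuming the router is mounted at the application root):
#   GET /wiki_search?query=Set%20Snap%20Base&groups=manual&groups=dev_docs
# The response is plain text; each matching chunk is preceded by a "---" separator.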


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual",
             "Bisect Object", "Who are the Triagers", "4.3 Release Notes Motion Paths"]
    result = wiki_search(tests[0])
    print(result)
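
# Minimal sketch of serving this router from a FastAPI app (assuming this
# module is importable as ``routers.tool_wiki_search``):
#
#   from fastapi import FastAPI
#   from routers.tool_wiki_search import router
#
#   app = FastAPI()
#   app.include_router(router)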