AlexanderKazakov
committed on
Commit
·
d7fdb42
1
Parent(s):
8b1c859
improve markdown chunking
Browse files- prep_scripts/lancedb_setup.py +49 -35
- prep_scripts/markdown_to_text.py +87 -43
- settings.py +10 -2
prep_scripts/lancedb_setup.py
CHANGED
@@ -11,14 +11,10 @@ import numpy as np
|
|
11 |
|
12 |
from sentence_transformers import SentenceTransformer
|
13 |
|
|
|
14 |
from settings import *
|
15 |
|
16 |
|
17 |
-
emb_sizes = {
|
18 |
-
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
19 |
-
"thenlper/gte-large": 0
|
20 |
-
}
|
21 |
-
|
22 |
shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
|
23 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
24 |
batch_size = 32
|
@@ -33,42 +29,60 @@ elif torch.cuda.is_available():
|
|
33 |
else:
|
34 |
device = "cpu"
|
35 |
|
36 |
-
schema = pa.schema(
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")
|
42 |
|
43 |
-
input_dir = Path(
|
44 |
files = list(input_dir.rglob("*"))
|
45 |
|
46 |
-
|
47 |
for file in files:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / batch_size)))):
|
52 |
-
try:
|
53 |
-
batch = [sent for sent in sentences[i * batch_size:(i + 1) * batch_size] if len(sent) > 0]
|
54 |
-
encoded = model.encode(batch, normalize_embeddings=True, device=device)
|
55 |
-
encoded = [list(vec) for vec in encoded]
|
56 |
-
|
57 |
-
df = pd.DataFrame({
|
58 |
-
VECTOR_COLUMN_NAME: encoded,
|
59 |
-
TEXT_COLUMN_NAME: batch
|
60 |
-
})
|
61 |
|
62 |
-
|
|
|
|
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
# tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
|
74 |
|
|
|
11 |
|
12 |
from sentence_transformers import SentenceTransformer
|
13 |
|
14 |
+
from markdown_to_text import *
|
15 |
from settings import *
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
18 |
shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
|
19 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
20 |
batch_size = 32
|
|
|
29 |
else:
|
30 |
device = "cpu"
|
31 |
|
32 |
+
# Build the LanceDB table schema: one fixed-size embedding vector column
# (width taken from the configured model), the chunk text, and the path of
# the source document the chunk came from.
schema = pa.schema([
    pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMB_MODEL_NAME])),
    pa.field(TEXT_COLUMN_NAME, pa.string()),
    pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
])
tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")

input_dir = Path(MARKDOWN_SOURCE_DIR)
files = list(input_dir.rglob("*"))

# Collect (chunk_text, absolute_source_path) pairs from every markdown file.
chunks = []
for file in files:
    if not os.path.isfile(file):
        continue

    file_path, file_ext = os.path.splitext(os.path.relpath(file, input_dir))
    if file_ext != '.md':
        print(f'Skipped {file_ext} extension: {file}')
        continue

    # Prefix every chunk with its document path ("dir / sub / name:") so the
    # embedding carries document-level context.
    doc_header = ' / '.join(split_path(file_path)) + ':\n\n'
    # NOTE: read into a separate name instead of rebinding the file handle
    # (the original shadowed `f` with its own .read() result).
    with open(file, encoding='utf-8') as fh:
        text = fh.read()
    text = remove_comments(text)
    for chunk in split_markdown(text):
        chunks.append((doc_header + chunk, os.path.abspath(file)))

# Optional debugging aid: histogram of chunk lengths. Guarded behind an env
# var so the setup script can run headless/unattended (the original called
# plt.show() unconditionally, which blocks the script).
if os.environ.get('PLOT_CHUNK_SIZES'):
    from matplotlib import pyplot as plt
    plt.hist([len(c) for c, _ in chunks], bins=100)
    plt.show()

# Embed the chunks batch by batch and append them to the table.
for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
    texts, doc_paths = [], []
    for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
        if len(text) > 0:
            texts.append(text)
            doc_paths.append(doc_path)

    # Skip batches where every chunk was empty; encoding an empty list is
    # pointless and adding an empty frame to the table is wasted work.
    if not texts:
        continue

    encoded = model.encode(texts, normalize_embeddings=True, device=device)
    encoded = [list(vec) for vec in encoded]

    df = pd.DataFrame({
        VECTOR_COLUMN_NAME: encoded,
        TEXT_COLUMN_NAME: texts,
        DOCUMENT_PATH_COLUMN_NAME: doc_paths,
    })

    tbl.add(df)


# '''
# create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/
# with the size of the transformer docs, index is not really needed
# but we'll do it for demonstration purposes
# '''
# tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
prep_scripts/markdown_to_text.py
CHANGED
@@ -1,50 +1,94 @@
|
|
1 |
-
import shutil
|
2 |
-
|
3 |
-
from bs4 import BeautifulSoup
|
4 |
-
from markdown import markdown
|
5 |
import os
|
6 |
import re
|
7 |
-
from pathlib import Path
|
8 |
|
9 |
from settings import *
|
10 |
|
11 |
|
12 |
-
def
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import re
|
|
|
3 |
|
4 |
from settings import *
|
5 |
|
6 |
|
7 |
+
def split_path(path):
    """Split *path* into a list of its components, e.g. 'a/b/c' -> ['a', 'b', 'c'].

    Uses :class:`pathlib.PurePath`, which also handles a trailing separator
    correctly ('a/b/' -> ['a', 'b']); the previous os.path.split loop stopped
    early there and returned the fused remainder ['a/b'].

    :param path: a (relative or absolute) file-system path string
    :return: list of path components, empty for an empty path
    """
    from pathlib import PurePath
    return list(PurePath(path).parts)
|
18 |
+
|
19 |
+
|
20 |
+
def remove_comments(md):
    """Strip HTML comments (``<!-- ... -->``) from markdown text.

    The quantifier is non-greedy: the previous greedy pattern
    ``<!--((.|\\n)*)-->`` matched from the first ``<!--`` to the LAST ``-->``
    in the document, deleting all content between separate comments.

    :param md: markdown source text
    :return: the text with every comment removed
    """
    return re.sub(r'<!--.*?-->', '', md, flags=re.DOTALL)
|
22 |
+
|
23 |
+
|
24 |
+
# Matches an ATX header of level 1-3 ('#', '##', '###') that sits on its own
# line surrounded by blank lines; group(1) captures the leading hashes, so
# len(group(1)) is the header level.
header_pattern = re.compile(r'\n\s*\n(#{1,3})\s.*\n\s*\n')
|
25 |
+
|
26 |
+
|
27 |
+
def split_content(content, chunk_size=None):
    """Greedily pack the paragraphs of *content* into chunks of roughly
    *chunk_size* characters.

    The text is split on blank lines; any paragraph longer than twice the
    chunk size is further split on single newlines.  Consecutive pieces are
    merged (re-joined with blank lines) while the running chunk stays under
    the limit.  A very short trailing chunk is folded into its predecessor
    when that does not make the predecessor oversized.

    :param content: markdown/plain text to split
    :param chunk_size: target chunk length in characters; defaults to the
        project-wide ``TEXT_CHUNK_SIZE`` setting
    :return: list of chunk strings (each ends with a blank line)
    """
    if chunk_size is None:
        chunk_size = TEXT_CHUNK_SIZE

    _parts = content.split('\n\n')
    parts = []
    for p in _parts:
        if len(p) < 2 * chunk_size:
            parts.append(p)
        else:
            # Oversized paragraph: fall back to line granularity.
            parts.extend(p.split('\n'))

    res = ['']
    for p in parts:
        if len(res[-1]) + len(p) < chunk_size:
            res[-1] += p + '\n\n'
        else:
            res.append(p + '\n\n')

    # Fold a tiny tail chunk into the previous one when it fits.
    if (
        len(res) >= 2 and
        len(res[-1]) < chunk_size / 4 and
        len(res[-2]) < chunk_size
    ):
        res[-2] += res[-1]
        res.pop()

    return res
|
52 |
+
|
53 |
+
|
54 |
+
def split_markdown(md):
    """Split a markdown document into text chunks, each prefixed with the
    header path (levels 1-3) it falls under.

    Fixes over the previous version:
    - when a header of level N is found, all deeper levels of the hierarchy
      are reset, so chunks under a new section no longer inherit stale
      sub-headers from the previous section;
    - text before the first header is emitted *before* that header is
      recorded, so the preamble is no longer wrongly prefixed with it.

    :param md: markdown source text (run :func:`remove_comments` first)
    :return: list of chunk strings
    """

    def construct_chunks(content):
        # Split a section body into pieces and emit one chunk per piece.
        for part in split_content(content):
            construct_chunk(part)

    def construct_chunk(content):
        content = content.strip()
        if len(content) == 0:
            return

        # Prepend the current header path (h1 / h2 / h3) to the chunk text.
        chunk = ''
        for level in sorted(name_hierarchy):
            if len(name_hierarchy[level]) != 0:
                chunk += name_hierarchy[level] + '\n\n'

        chunk += content
        res.append(chunk.strip())

    md = f'\n\n{md}'  # to find a header at the top of a file
    headers = list(header_pattern.finditer(md))
    name_hierarchy = {i: '' for i in (1, 2, 3)}
    res = []
    for i, header in enumerate(headers):
        if i == 0 and header.start() != 0:
            # Preamble before the first header: it has no header context,
            # so emit it before updating the hierarchy.
            construct_chunks(md[:header.start()])

        level = len(header.group(1))
        name_hierarchy[level] = header.group().strip()
        # A new section invalidates all previously seen deeper headers.
        for deeper in range(level + 1, 4):
            name_hierarchy[deeper] = ''

        start = header.end()
        end = headers[i + 1].start() if i + 1 < len(headers) else None
        construct_chunks(md[start:end])

    if len(headers) == 0:
        construct_chunks(md)

    return res
|
94 |
|
settings.py
CHANGED
@@ -1,13 +1,21 @@
|
|
1 |
-
|
2 |
-
TEXT_CHUNKS_DIR = "data/docs_dump"
|
3 |
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
4 |
LANCEDB_DIRECTORY = "data/lancedb"
|
5 |
LANCEDB_TABLE_NAME = "table"
|
6 |
VECTOR_COLUMN_NAME = "embedding"
|
7 |
TEXT_COLUMN_NAME = "text"
|
|
|
8 |
HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
|
9 |
OPENAI_LLM_NAME = "gpt-3.5-turbo"
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
context_lengths = {
|
12 |
"mistralai/Mistral-7B-Instruct-v0.1": 4096,
|
13 |
"gpt-3.5-turbo": 4096,
|
|
|
1 |
+
# Project-wide configuration constants for the RAG prep scripts and app.
MARKDOWN_SOURCE_DIR = "data/transformers/docs/source/en/"
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LANCEDB_DIRECTORY = "data/lancedb"
LANCEDB_TABLE_NAME = "table"
VECTOR_COLUMN_NAME = "embedding"
TEXT_COLUMN_NAME = "text"
DOCUMENT_PATH_COLUMN_NAME = "document_path"
HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
OPENAI_LLM_NAME = "gpt-3.5-turbo"

# Target chunk length in characters (approximate, headers not counted).
TEXT_CHUNK_SIZE = 1000

# Embedding dimensionality per model; this fixes the width of the LanceDB
# vector column, so it must match the model's actual output size.
emb_sizes = {
    "sentence-transformers/all-MiniLM-L6-v2": 384,
    "thenlper/gte-large": 1024,  # was 0, which would make the vector column unusable
}
|
18 |
+
|
19 |
context_lengths = {
|
20 |
"mistralai/Mistral-7B-Instruct-v0.1": 4096,
|
21 |
"gpt-3.5-turbo": 4096,
|