AlexanderKazakov commited on
Commit
d7fdb42
·
1 Parent(s): 8b1c859

improve markdown chunking

Browse files
prep_scripts/lancedb_setup.py CHANGED
@@ -11,14 +11,10 @@ import numpy as np
11
 
12
  from sentence_transformers import SentenceTransformer
13
 
 
14
  from settings import *
15
 
16
 
17
- emb_sizes = {
18
- "sentence-transformers/all-MiniLM-L6-v2": 384,
19
- "thenlper/gte-large": 0
20
- }
21
-
22
  shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
23
  db = lancedb.connect(LANCEDB_DIRECTORY)
24
  batch_size = 32
@@ -33,42 +29,60 @@ elif torch.cuda.is_available():
33
  else:
34
  device = "cpu"
35
 
36
- schema = pa.schema(
37
- [
38
- pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMB_MODEL_NAME])),
39
- pa.field(TEXT_COLUMN_NAME, pa.string())
40
- ])
41
  tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")
42
 
43
- input_dir = Path(TEXT_CHUNKS_DIR)
44
  files = list(input_dir.rglob("*"))
45
 
46
- sentences = []
47
  for file in files:
48
- with open(file, encoding='utf-8') as f:
49
- sentences.append(f.read())
50
-
51
- for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / batch_size)))):
52
- try:
53
- batch = [sent for sent in sentences[i * batch_size:(i + 1) * batch_size] if len(sent) > 0]
54
- encoded = model.encode(batch, normalize_embeddings=True, device=device)
55
- encoded = [list(vec) for vec in encoded]
56
-
57
- df = pd.DataFrame({
58
- VECTOR_COLUMN_NAME: encoded,
59
- TEXT_COLUMN_NAME: batch
60
- })
61
 
62
- tbl.add(df)
 
 
 
63
 
64
- except:
65
- print(f"batch {i} was skipped: {traceback.format_exc()}")
66
-
67
-
68
- '''
69
- create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/
70
- with the size of the transformer docs, index is not really needed
71
- but we'll do it for demonstration purposes
72
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
74
 
 
11
 
12
  from sentence_transformers import SentenceTransformer
13
 
14
+ from markdown_to_text import *
15
  from settings import *
16
 
17
 
 
 
 
 
 
18
  shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
19
  db = lancedb.connect(LANCEDB_DIRECTORY)
20
  batch_size = 32
 
29
  else:
30
  device = "cpu"
31
 
32
+ schema = pa.schema([
33
+ pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMB_MODEL_NAME])),
34
+ pa.field(TEXT_COLUMN_NAME, pa.string()),
35
+ pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
36
+ ])
37
  tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")
38
 
39
+ input_dir = Path(MARKDOWN_SOURCE_DIR)
40
  files = list(input_dir.rglob("*"))
41
 
42
+ chunks = []
43
  for file in files:
44
+ if not os.path.isfile(file):
45
+ continue
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ file_path, file_ext = os.path.splitext(os.path.relpath(file, input_dir))
48
+ if file_ext != '.md':
49
+ print(f'Skipped {file_ext} extension: {file}')
50
+ continue
51
 
52
+ doc_header = ' / '.join(split_path(file_path)) + ':\n\n'
53
+ with open(file, encoding='utf-8') as f:
54
+ f = f.read()
55
+ f = remove_comments(f)
56
+ f = split_markdown(f)
57
+ chunks.extend((doc_header + chunk, os.path.abspath(file)) for chunk in f)
58
+
59
+ from matplotlib import pyplot as plt
60
+ plt.hist([len(c) for c, d in chunks], bins=100)
61
+ plt.show()
62
+
63
+ for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
64
+ texts, doc_paths = [], []
65
+ for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
66
+ if len(text) > 0:
67
+ texts.append(text)
68
+ doc_paths.append(doc_path)
69
+
70
+ encoded = model.encode(texts, normalize_embeddings=True, device=device)
71
+ encoded = [list(vec) for vec in encoded]
72
+
73
+ df = pd.DataFrame({
74
+ VECTOR_COLUMN_NAME: encoded,
75
+ TEXT_COLUMN_NAME: texts,
76
+ DOCUMENT_PATH_COLUMN_NAME: doc_paths,
77
+ })
78
+
79
+ tbl.add(df)
80
+
81
+
82
+ # '''
83
+ # create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/
84
+ # with the size of the transformer docs, index is not really needed
85
+ # but we'll do it for demonstration purposes
86
+ # '''
87
  # tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
88
 
prep_scripts/markdown_to_text.py CHANGED
@@ -1,50 +1,94 @@
1
- import shutil
2
-
3
- from bs4 import BeautifulSoup
4
- from markdown import markdown
5
  import os
6
  import re
7
- from pathlib import Path
8
 
9
  from settings import *
10
 
11
 
12
- def markdown_to_text(markdown_string):
13
- """ Converts a markdown string to plaintext """
14
-
15
- # md -> html -> text since BeautifulSoup can extract text cleanly
16
- html = markdown(markdown_string)
17
-
18
- html = re.sub(r'<!--((.|\n)*)-->', '', html)
19
- html = re.sub('<code>bash', '<code>', html)
20
-
21
- # extract text
22
- soup = BeautifulSoup(html, "html.parser")
23
- text = ''.join(soup.findAll(string=True))
24
-
25
- text = re.sub('```(py|diff|python)', '', text)
26
- text = re.sub('```\n', '\n', text)
27
- text = re.sub('- .*', '', text)
28
- text = text.replace('...', '')
29
- text = re.sub('\n(\n)+', '\n\n', text)
30
-
31
- return text
32
-
33
-
34
- dir_to_scrape = Path(MARKDOWN_DIR_TO_SCRAPE)
35
- files = list(dir_to_scrape.rglob("*"))
36
-
37
- shutil.rmtree(TEXT_CHUNKS_DIR, ignore_errors=True)
38
- os.makedirs(TEXT_CHUNKS_DIR)
39
-
40
- for file in files:
41
- parent = file.parent.stem if file.parent.stem != dir_to_scrape.stem else ""
42
- if file.is_file():
43
- with open(file, encoding='utf-8') as f:
44
- md = f.read()
45
-
46
- text = markdown_to_text(md)
47
-
48
- with open(os.path.join(TEXT_CHUNKS_DIR, f"{parent}_{file.stem}.txt"), "w", encoding='utf-8') as f:
49
- f.write(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
 
 
 
 
 
1
  import os
2
  import re
 
3
 
4
  from settings import *
5
 
6
 
7
+ def split_path(path):
8
+ components = []
9
+ while True:
10
+ path, tail = os.path.split(path)
11
+ if tail == "":
12
+ if path != "":
13
+ components.append(path)
14
+ break
15
+ components.append(tail)
16
+ components.reverse()
17
+ return components
18
+
19
+
20
+ def remove_comments(md):
21
+ return re.sub(r'<!--((.|\n)*)-->', '', md)
22
+
23
+
24
+ header_pattern = re.compile(r'\n\s*\n(#{1,3})\s.*\n\s*\n')
25
+
26
+
27
+ def split_content(content):
28
+ _parts = content.split('\n\n')
29
+ parts = []
30
+ for p in _parts:
31
+ if len(p) < 2 * TEXT_CHUNK_SIZE:
32
+ parts.append(p)
33
+ else:
34
+ parts.extend(p.split('\n'))
35
+
36
+ res = ['']
37
+ for p in parts:
38
+ if len(res[-1]) + len(p) < TEXT_CHUNK_SIZE:
39
+ res[-1] += p + '\n\n'
40
+ else:
41
+ res.append(p + '\n\n')
42
+
43
+ if (
44
+ len(res) >= 2 and
45
+ len(res[-1]) < TEXT_CHUNK_SIZE / 4 and
46
+ len(res[-2]) < TEXT_CHUNK_SIZE
47
+ ):
48
+ res[-2] += res[-1]
49
+ res.pop()
50
+
51
+ return res
52
+
53
+
54
+ def split_markdown(md):
55
+ def construct_chunks(content):
56
+ parts = split_content(content)
57
+ for p in parts:
58
+ construct_chunk(p)
59
+
60
+ def construct_chunk(content):
61
+ content = content.strip()
62
+ if len(content) == 0:
63
+ return
64
+
65
+ chunk = ''
66
+ for i in sorted(name_hierarchy):
67
+ if len(name_hierarchy[i]) != 0:
68
+ chunk += name_hierarchy[i] + '\n\n'
69
+
70
+ chunk += content
71
+ chunk = chunk.strip()
72
+ res.append(chunk)
73
+
74
+ md = f'\n\n{md}' # to find a header at the top of a file
75
+ headers = list(header_pattern.finditer(md))
76
+ name_hierarchy = {i: '' for i in (1, 2, 3)}
77
+ res = []
78
+ for i in range(len(headers)):
79
+ header = headers[i]
80
+ level = len(header.group(1))
81
+ name = header.group().strip()
82
+ name_hierarchy[level] = name
83
+ if i == 0 and header.start() != 0:
84
+ construct_chunks(md[:header.start()])
85
+
86
+ start = header.end()
87
+ end = headers[i + 1].start() if i + 1 < len(headers) else None
88
+ construct_chunks(md[start:end])
89
+
90
+ if len(headers) == 0:
91
+ construct_chunks(md)
92
+
93
+ return res
94
 
settings.py CHANGED
@@ -1,13 +1,21 @@
1
- MARKDOWN_DIR_TO_SCRAPE = "data/transformers/docs/source/en/"
2
- TEXT_CHUNKS_DIR = "data/docs_dump"
3
  EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
4
  LANCEDB_DIRECTORY = "data/lancedb"
5
  LANCEDB_TABLE_NAME = "table"
6
  VECTOR_COLUMN_NAME = "embedding"
7
  TEXT_COLUMN_NAME = "text"
 
8
  HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
9
  OPENAI_LLM_NAME = "gpt-3.5-turbo"
10
 
 
 
 
 
 
 
 
 
11
  context_lengths = {
12
  "mistralai/Mistral-7B-Instruct-v0.1": 4096,
13
  "gpt-3.5-turbo": 4096,
 
1
+ MARKDOWN_SOURCE_DIR = "data/transformers/docs/source/en/"
 
2
  EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
3
  LANCEDB_DIRECTORY = "data/lancedb"
4
  LANCEDB_TABLE_NAME = "table"
5
  VECTOR_COLUMN_NAME = "embedding"
6
  TEXT_COLUMN_NAME = "text"
7
+ DOCUMENT_PATH_COLUMN_NAME = "document_path"
8
  HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
9
  OPENAI_LLM_NAME = "gpt-3.5-turbo"
10
 
11
+ """ in symbols, approximate, without headers """
12
+ TEXT_CHUNK_SIZE = 1000
13
+
14
+ emb_sizes = {
15
+ "sentence-transformers/all-MiniLM-L6-v2": 384,
16
+ "thenlper/gte-large": 0
17
+ }
18
+
19
  context_lengths = {
20
  "mistralai/Mistral-7B-Instruct-v0.1": 4096,
21
  "gpt-3.5-turbo": 4096,