# rag_time/backend/utils/data_chunking.py
import pathlib
import typing

import nltk
from transformers import AutoTokenizer

def fixed_strategy(tokenizer, data: str, max_length: int) -> typing.List[str]:
    """Split text into chunks of at most max_length tokens, ignoring sentence boundaries."""
    tokens = tokenizer(data)['input_ids']
    # Slice the token ids into consecutive windows of max_length tokens.
    token_chunks = [tokens[idx: idx + max_length] for idx in range(0, len(tokens), max_length)]
    # Decode each window back to text, dropping special tokens such as [CLS]/[SEP].
    chunks = [tokenizer.decode(token_chunk, skip_special_tokens=True) for token_chunk in token_chunks]
    return chunks
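
# Illustrative use of fixed_strategy (a sketch, run manually; "some_doc.txt" is a
# hypothetical file, and exact chunk boundaries depend on the tokenizer's
# subword vocabulary):
#
#   tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#   chunks = fixed_strategy(tok, open("some_doc.txt").read(), max_length=256)
#   # each chunk decodes from a window of at most 256 token ids
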
def content_aware_strategy(tokenizer, data: str, max_length: int) -> typing.List[str]:
    """Greedily pack whole sentences into chunks of at most max_length tokens."""
    sentences = nltk.sent_tokenize(data)
    chunks = []
    current_chunk = None
    current_chunk_length = 0
    for sentence in sentences:
        sentence_length = len(tokenizer(sentence)['input_ids'])
        if current_chunk is None:
            current_chunk = sentence
            current_chunk_length = sentence_length
        elif current_chunk_length + sentence_length > max_length:
            # The sentence does not fit: close the current chunk and start a new one.
            # Note: a single sentence longer than max_length still becomes its own chunk.
            chunks.append(current_chunk)
            current_chunk = sentence
            current_chunk_length = sentence_length
        else:
            # Keep a space between sentences when concatenating.
            current_chunk += " " + sentence
            current_chunk_length += sentence_length
    if current_chunk is not None:
        chunks.append(current_chunk)
    return chunks
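
# How the two strategies differ (illustrative, run manually): fixed_strategy
# may cut mid-sentence, while content_aware_strategy only cuts at sentence
# boundaries, so its chunks are usually shorter than max_length.
#
#   tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#   content_aware_strategy(tok, "First sentence. Second one. Third one.", max_length=16)
#   # -> sentences packed greedily; a new chunk starts once 16 tokens would be exceeded
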
class DataChunker:
    def __init__(self, model_name: str, max_length: int):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length

    def chunk_folder(self, input_dir: str, output_dir: str, strategy: typing.Callable):
        """Chunk every .txt file in input_dir and write one file per chunk to output_dir."""
        output_path = pathlib.Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        input_path = pathlib.Path(input_dir)
        for input_file_path in input_path.glob("*.txt"):
            with open(input_file_path, 'r') as f:
                data = f.read()
            chunks = strategy(self.tokenizer, data, self.max_length)
            # Write each chunk as <stem>_<index>.txt in the output directory.
            for i, chunk in enumerate(chunks):
                new_file_path = output_path / f'{input_file_path.stem}_{i}.txt'
                with open(new_file_path, 'w') as fw:
                    fw.write(chunk)
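
# Example wiring (hypothetical paths): chunk every .txt under ./docs and write
# files named like docs_chunked/report_0.txt, docs_chunked/report_1.txt, ...
#
#   chunker = DataChunker("sentence-transformers/all-MiniLM-L6-v2", max_length=256)
#   chunker.chunk_folder("./docs", "./docs_chunked", content_aware_strategy)
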
if __name__ == "__main__":
    # Sentence tokenizer models required by nltk.sent_tokenize.
    nltk.download('punkt')
    model_names = ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-large-en-v1.5"]
    max_length = 512
    for model_name in model_names:
        data_chunker = DataChunker(model_name, max_length)
        model_suffix = model_name.split("/")[-1]
        data_chunker.chunk_folder("../docs", f"../docs_chunked_{model_suffix}", fixed_strategy)
        data_chunker.chunk_folder("../docs", f"../docs_chunked_ca_{model_suffix}", content_aware_strategy)