import pathlib
import typing

import nltk
from transformers import AutoTokenizer


def fixed_strategy(tokenizer, data: str, max_length: int) -> typing.List[str]:
    """Split `data` into chunks of at most `max_length` tokens, ignoring sentence boundaries."""
    # Note: the ids include special tokens (e.g. [CLS]/[SEP]); pass
    # add_special_tokens=False, or keep a small margin, if the chunks will be
    # re-tokenized with special tokens later.
    tokens = tokenizer(data)['input_ids']
    token_chunks = [tokens[idx: idx + max_length] for idx in range(0, len(tokens), max_length)]
    chunks = [tokenizer.decode(token_chunk, skip_special_tokens=True) for token_chunk in token_chunks]
    return chunks
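
# A minimal variant sketch (not part of the original script): fixed-size
# chunking with a hypothetical `overlap` parameter, so consecutive chunks share
# some tokens of context and a sentence cut at a window boundary is not lost
# from both sides.
def fixed_strategy_with_overlap(tokenizer, data: str, max_length: int,
                                overlap: int = 32) -> typing.List[str]:
    assert 0 <= overlap < max_length, "overlap must be smaller than max_length"
    tokens = tokenizer(data)['input_ids']
    stride = max_length - overlap  # each window starts `stride` tokens after the previous one
    token_chunks = [tokens[idx: idx + max_length] for idx in range(0, len(tokens), stride)]
    return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in token_chunks]
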
def content_aware_strategy(tokenizer, data: str, max_length: int) -> typing.List[str]:
    """Greedily pack whole sentences into chunks of at most `max_length` tokens."""
    sentences = nltk.sent_tokenize(data)
    chunks = []
    current_chunk = None
    current_chunk_length = 0
    for sentence in sentences:
        # Token count per sentence; counting special tokens each time slightly
        # overestimates the packed length, which keeps the limit conservative.
        sentence_length = len(tokenizer(sentence)['input_ids'])
        if current_chunk is None:
            current_chunk = sentence
            current_chunk_length = sentence_length
        elif current_chunk_length + sentence_length > max_length:
            # The sentence does not fit: close the current chunk, start a new one.
            chunks.append(current_chunk)
            current_chunk = sentence
            current_chunk_length = sentence_length
        else:
            # Re-insert the space that sent_tokenize strips between sentences.
            current_chunk += " " + sentence
            current_chunk_length += sentence_length
    if current_chunk is not None:
        chunks.append(current_chunk)
    return chunks
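
# A quick sanity check (hypothetical helper, not part of the original script):
# re-tokenize each chunk and report the indices of any that exceed max_length.
# content_aware_strategy can still exceed the limit when a single sentence is
# itself longer than max_length tokens.
def oversized_chunks(tokenizer, chunks: typing.List[str], max_length: int) -> typing.List[int]:
    return [i for i, chunk in enumerate(chunks)
            if len(tokenizer(chunk)['input_ids']) > max_length]
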
class DataChunker:
    def __init__(self, model_name: str, max_length: int):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length

    def chunk_folder(self, input_dir: str, output_dir: str, strategy: typing.Callable):
        """Apply `strategy` to every .txt file in `input_dir`, writing one output file per chunk."""
        output_path = pathlib.Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        for input_file_path in pathlib.Path(input_dir).glob("*.txt"):
            data = input_file_path.read_text(encoding="utf-8")
            chunks = strategy(self.tokenizer, data, self.max_length)
            for i, chunk in enumerate(chunks):
                new_file_path = output_path / f"{input_file_path.stem}_{i}.txt"
                new_file_path.write_text(chunk, encoding="utf-8")

if __name__ == "__main__":
    nltk.download('punkt')  # NLTK >= 3.9 also needs the 'punkt_tab' resource
    model_names = ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-large-en-v1.5"]
    max_length = 512
    for model_name in model_names:
        data_chunker = DataChunker(model_name, max_length)
        model_suffix = model_name.split("/")[-1]
        data_chunker.chunk_folder("../docs", f"../docs_chunked_{model_suffix}", fixed_strategy)
        data_chunker.chunk_folder("../docs", f"../docs_chunked_ca_{model_suffix}", content_aware_strategy)