|
|
|
|
|
|
|
|
|
import logging
from abc import ABC, abstractmethod
from collections import deque
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Deque,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from attr import dataclass

from base_chunker import BaseChunker
|
|
|
# Module-level logger; `_merge_splits` emits its oversized-chunk warning here.
logger = logging.getLogger(__name__)

# Type variable bound to TextSplitter (or a subclass). It is not referenced in
# the visible part of this file -- presumably it annotates alternate
# constructors that return the calling subclass; confirm against the module.
TS = TypeVar("TS", bound="TextSplitter")
|
class TextSplitter(BaseChunker, ABC): |
|
"""Interface for splitting text into chunks.""" |
|
|
|
def __init__( |
|
self, |
|
chunk_size: int = 4000, |
|
chunk_overlap: int = 200, |
|
length_function: Callable[[str], int] = len, |
|
keep_separator: bool = False, |
|
add_start_index: bool = False, |
|
strip_whitespace: bool = True, |
|
) -> None: |
|
"""Create a new TextSplitter. |
|
|
|
Args: |
|
chunk_size: Maximum size of chunks to return |
|
chunk_overlap: Overlap in characters between chunks |
|
length_function: Function that measures the length of given chunks |
|
keep_separator: Whether to keep the separator in the chunks |
|
add_start_index: If `True`, includes chunk's start index in metadata |
|
strip_whitespace: If `True`, strips whitespace from the start and end of |
|
every document |
|
""" |
|
if chunk_overlap > chunk_size: |
|
raise ValueError( |
|
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " |
|
f"({chunk_size}), should be smaller." |
|
) |
|
self._chunk_size = chunk_size |
|
self._chunk_overlap = chunk_overlap |
|
self._length_function = length_function |
|
self._keep_separator = keep_separator |
|
self._add_start_index = add_start_index |
|
self._strip_whitespace = strip_whitespace |
|
|
|
@abstractmethod
def split_text(self, text: str) -> List[str]:
    """Split ``text`` into a list of chunks.

    Abstract: concrete splitters must override this.

    Args:
        text: The text to split.

    Returns:
        The resulting chunks, as strings.
    """
|
|
|
def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: |
|
text = separator.join(docs) |
|
if self._strip_whitespace: |
|
text = text.strip() |
|
if text == "": |
|
return None |
|
else: |
|
return text |
|
|
|
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: |
|
|
|
|
|
separator_len = self._length_function(separator) |
|
|
|
docs = [] |
|
current_doc: List[str] = [] |
|
total = 0 |
|
for d in splits: |
|
_len = self._length_function(d) |
|
if ( |
|
total + _len + (separator_len if len(current_doc) > 0 else 0) |
|
> self._chunk_size |
|
): |
|
if total > self._chunk_size: |
|
logger.warning( |
|
f"Created a chunk of size {total}, " |
|
f"which is longer than the specified {self._chunk_size}" |
|
) |
|
if len(current_doc) > 0: |
|
doc = self._join_docs(current_doc, separator) |
|
if doc is not None: |
|
docs.append(doc) |
|
|
|
|
|
|
|
while total > self._chunk_overlap or ( |
|
total + _len + (separator_len if len(current_doc) > 0 else 0) |
|
> self._chunk_size |
|
and total > 0 |
|
): |
|
total -= self._length_function(current_doc[0]) + ( |
|
separator_len if len(current_doc) > 1 else 0 |
|
) |
|
current_doc = current_doc[1:] |
|
current_doc.append(d) |
|
total += _len + (separator_len if len(current_doc) > 1 else 0) |
|
doc = self._join_docs(current_doc, separator) |
|
if doc is not None: |
|
docs.append(doc) |
|
return docs |