Spaces:
Building
on
T4
Building
on
T4
# This script is adapted from the LangChain package, developed by LangChain AI.
# Original code can be found at: https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/base.py
# License: MIT License
from abc import ABC, abstractmethod | |
from enum import Enum | |
import logging | |
from typing import ( | |
AbstractSet, | |
Any, | |
Callable, | |
Collection, | |
Iterable, | |
List, | |
Literal, | |
Optional, | |
Sequence, | |
Type, | |
TypeVar, | |
Union, | |
) | |
from base_chunker import BaseChunker | |
from attr import dataclass | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type variable bound (by forward reference) to TextSplitter.
TS = TypeVar("TS", bound="TextSplitter")
class TextSplitter(BaseChunker, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement ``split_text``. This base class stores the shared
    configuration and provides helpers that merge small text pieces into
    chunks no longer than ``chunk_size`` (as measured by ``length_function``),
    carrying up to ``chunk_overlap`` worth of trailing pieces into the next
    chunk.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, measured with
                ``length_function``.
            chunk_overlap: Overlap between consecutive chunks, measured with
                ``length_function``. Must be between 0 and ``chunk_size``
                (inclusive).
            length_function: Function that measures the length of given chunks.
            keep_separator: Whether to keep the separator in the chunks.
                Stored on the instance; not read by ``_join_docs`` or
                ``_merge_splits``.
            add_start_index: If `True`, includes chunk's start index in
                metadata. Stored on the instance; not read by ``_join_docs``
                or ``_merge_splits``.
            strip_whitespace: If `True`, strips whitespace from the start and
                end of every document.

        Raises:
            ValueError: If ``chunk_overlap`` is negative or larger than
                ``chunk_size``.
        """
        # A negative overlap would let the trimming loop in `_merge_splits`
        # keep popping from an already-empty window, so reject it here.
        if chunk_overlap < 0:
            raise ValueError(f"chunk_overlap must be >= 0, got {chunk_overlap}")
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components.

        Concrete subclasses must override this method.
        """

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join pieces with ``separator`` into a single chunk.

        Args:
            docs: Pieces to concatenate, in order.
            separator: String inserted between consecutive pieces.

        Returns:
            The joined text (stripped of surrounding whitespace when
            ``strip_whitespace`` was enabled), or ``None`` if the result is
            empty.
        """
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        # An empty chunk carries no content; signal that with None.
        return text or None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Combine small pieces into chunks bounded by the configured size.

        Pieces are accumulated in a sliding window. When adding the next piece
        would push the window past ``chunk_size``, the window is emitted as a
        chunk and then trimmed from the front until it is no larger than
        ``chunk_overlap`` and the next piece fits (or the window is empty).

        Args:
            splits: Pieces of text, in order.
            separator: String used to join pieces inside a chunk; its length
                counts towards the chunk size.

        Returns:
            The merged chunks, in order. Chunks that are empty after joining
            are omitted.
        """
        separator_len = self._length_function(separator)

        docs: List[str] = []
        current_doc: List[str] = []
        # Measured length of `current_doc` joined by `separator`.
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if current_doc else 0)
                > self._chunk_size
            ):
                # The window alone is already oversized, which happens when a
                # single piece is longer than the chunk size.
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, "
                        "which is longer than the specified %s",
                        total,
                        self._chunk_size,
                    )
                if current_doc:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - the window is larger than the chunk overlap
                    # - or the window is non-empty and the next piece still
                    #   would not fit
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if current_doc else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever remains in the window as the final chunk.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs