# This script is adapted from the LangChain package, developed by LangChain AI.
# Original code can be found at:
# License: MIT License
import logging
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import tiktoken
from attr import dataclass

from chunking_evaluation.utils import Language
class BaseChunker(ABC):
    """Abstract base class shared by the chunkers defined in this module.

    Concrete chunkers implement :meth:`split_text`.
    """

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split ``text`` and return the resulting chunks as a list of strings."""
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Type variable bound to TextSplitter (forward reference: the class is defined below).
TS = TypeVar("TS", bound="TextSplitter")
class TextSplitter(BaseChunker, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap between chunks, measured with
                ``length_function``
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join ``docs`` with ``separator``; return ``None`` for an empty result.

        Leading/trailing whitespace is stripped first when the splitter was
        created with ``strip_whitespace=True``.
        """
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        return text or None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily pack small pieces into chunks of at most ``chunk_size``.

        Pieces are accumulated (joined by ``separator``) until adding the next
        one would exceed the chunk size. The accumulated chunk is then emitted
        and pieces are dropped from its front until at most ``chunk_overlap``
        remains, so consecutive chunks share a tail/head overlap.

        Args:
            splits: Pieces to merge, in document order.
            separator: String inserted between pieces of one chunk.

        Returns:
            The merged chunks; empty chunks are omitted.
        """
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    # A single piece was already longer than the limit.
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever is left after the last piece.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        # The set() default is shared between calls but is never mutated here.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: tiktoken encoding used when ``model_name`` is None.
            model_name: If given, the encoding is looked up for this model.
            allowed_special: Forwarded to the encoder's ``encode`` call.
            disallowed_special: Forwarded to the encoder's ``encode`` call.
            **kwargs: Forwarded to the constructor of ``cls``.

        Returns:
            An instance of ``cls`` whose length function counts tokens.

        Raises:
            ImportError: If the ``tiktoken`` package cannot be imported.
        """
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, FixedTokenChunker):
            # FixedTokenChunker builds its own tokenizer, so it needs the same
            # encoder settings passed through to its constructor.
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
class FixedTokenChunker(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        # The set() default is shared between instances but is never mutated here.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            encoding_name: tiktoken encoding used when ``model_name`` is None.
            model_name: If given, the encoding is looked up for this model.
            chunk_size: Maximum number of tokens per chunk.
            chunk_overlap: Number of tokens shared by consecutive chunks.
            allowed_special: Forwarded to the tokenizer's ``encode`` call.
            disallowed_special: Forwarded to the tokenizer's ``encode`` call.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If the ``tiktoken`` package cannot be imported.
        """
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to for FixedTokenChunker. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks of at most ``chunk_size`` tokens.

        The text is encoded with the configured tokenizer, cut into token
        windows that overlap by ``chunk_overlap`` tokens, and each window is
        decoded back to a string.
        """

        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class.

    Bundles the encode/decode callables with the window settings consumed by
    ``split_text_on_tokens``.
    """

    # Overlap in tokens between chunks
    chunk_overlap: int
    # Maximum number of tokens per chunk
    tokens_per_chunk: int
    # Function to decode a list of token ids to a string
    decode: Callable[[List[int]], str]
    # Function to encode a string to a list of token ids
    encode: Callable[[str], List[int]]
def split_text_on_tokens(*, text: str, tokenizer: "Tokenizer") -> List[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded once, then cut into windows of
    ``tokenizer.tokens_per_chunk`` token ids; each window starts
    ``tokens_per_chunk - chunk_overlap`` ids after the previous one and is
    decoded back to a string.

    Args:
        text: Text to split.
        tokenizer: Object providing ``encode``, ``decode``, ``tokens_per_chunk``
            and ``chunk_overlap``.

    Returns:
        The decoded chunks in document order; ``[]`` when ``text`` encodes to
        no tokens.

    Raises:
        ValueError: If more than one chunk is needed but ``chunk_overlap`` is
            not smaller than ``tokens_per_chunk`` (the window could never
            advance).
    """
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    total = len(input_ids)
    step = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap

    start_idx = 0
    while start_idx < total:
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, total)
        splits.append(tokenizer.decode(input_ids[start_idx:cur_idx]))
        if cur_idx == total:
            break
        if step <= 0:
            # Without this guard the window would never move forward and the
            # loop would not terminate.
            raise ValueError(
                f"chunk_overlap ({tokenizer.chunk_overlap}) must be smaller than "
                f"tokens_per_chunk ({tokenizer.tokens_per_chunk})."
            )
        start_idx += step
    return splits
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
splits = re.split(separator, text)
splits = list(text)
return [s for s in splits if s != ""]
class RecursiveTokenChunker(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return.
            chunk_overlap: Overlap between consecutive chunks.
            separators: Separators to try, in order of preference; defaults to
                paragraph, line, sentence-punctuation, space, then character.
            keep_separator: Whether to keep the separator in the chunks.
            is_separator_regex: If true, separators are used as regex patterns;
                otherwise they are escaped and matched literally.
            **kwargs: Forwarded to ``TextSplitter.__init__``.
        """
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            keep_separator=keep_separator,
            **kwargs,
        )
        self._separators = separators or ["\n\n", "\n", ".", "?", "!", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks.

        Uses the first separator in ``separators`` that occurs in ``text``.
        Pieces shorter than the chunk size are merged; longer pieces are split
        again with the remaining separators, or kept as-is when none remain.
        """
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        # When separators are kept they are already part of each piece, so
        # nothing extra is inserted on merge.
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` using the separators configured on this instance."""
        return self._split_text(text, self._separators)
# Return the list of separator strings used to split text written in
# ``language``: each branch returns the separators for one ``Language`` member,
# and the final statement raises ValueError naming the valid members.
# NOTE(review): in this copy the `return [` lists have no closing bracket, some
# branches (e.g. HTML, LATEX, COBOL) keep their comments but few or no entries,
# and indentation is absent -- the function looks truncated; restore it from the
# upstream source before relying on it.
def get_separators_for_language(language: Language) -> List[str]:
if language == Language.CPP:
return [
# Split along class definitions
"\nclass ",
# Split along function definitions
"\nvoid ",
"\nint ",
"\nfloat ",
"\ndouble ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
# Split by the normal type of lines
" ",
elif language == Language.GO:
return [
# Split along function definitions
"\nfunc ",
"\nvar ",
"\nconst ",
"\ntype ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nswitch ",
"\ncase ",
# Split by the normal type of lines
" ",
elif language == Language.JAVA:
return [
# Split along class definitions
"\nclass ",
# Split along method definitions
"\npublic ",
"\nprotected ",
"\nprivate ",
"\nstatic ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
# Split by the normal type of lines
" ",
elif language == Language.KOTLIN:
return [
# Split along class definitions
"\nclass ",
# Split along method definitions
"\npublic ",
"\nprotected ",
"\nprivate ",
"\ninternal ",
"\ncompanion ",
"\nfun ",
"\nval ",
"\nvar ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nwhen ",
"\ncase ",
"\nelse ",
# Split by the normal type of lines
" ",
elif language == Language.JS:
return [
# Split along function definitions
"\nfunction ",
"\nconst ",
"\nlet ",
"\nvar ",
"\nclass ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
"\ndefault ",
# Split by the normal type of lines
" ",
elif language == Language.TS:
return [
"\nenum ",
"\ninterface ",
"\nnamespace ",
"\ntype ",
# Split along class definitions
"\nclass ",
# Split along function definitions
"\nfunction ",
"\nconst ",
"\nlet ",
"\nvar ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
"\ndefault ",
# Split by the normal type of lines
" ",
elif language == Language.PHP:
return [
# Split along function definitions
"\nfunction ",
# Split along class definitions
"\nclass ",
# Split along control flow statements
"\nif ",
"\nforeach ",
"\nwhile ",
"\ndo ",
"\nswitch ",
"\ncase ",
# Split by the normal type of lines
" ",
elif language == Language.PROTO:
return [
# Split along message definitions
"\nmessage ",
# Split along service definitions
"\nservice ",
# Split along enum definitions
"\nenum ",
# Split along option definitions
"\noption ",
# Split along import statements
"\nimport ",
# Split along syntax declarations
"\nsyntax ",
# Split by the normal type of lines
" ",
elif language == Language.PYTHON:
return [
# First, try to split along class definitions
"\nclass ",
"\ndef ",
"\n\tdef ",
# Now split by the normal type of lines
" ",
elif language == Language.RST:
return [
# Split along section titles
# Split along directive markers
"\n\n.. *\n\n",
# Split by the normal type of lines
" ",
elif language == Language.RUBY:
return [
# Split along method definitions
"\ndef ",
"\nclass ",
# Split along control flow statements
"\nif ",
"\nunless ",
"\nwhile ",
"\nfor ",
"\ndo ",
"\nbegin ",
"\nrescue ",
# Split by the normal type of lines
" ",
elif language == Language.RUST:
return [
# Split along function definitions
"\nfn ",
"\nconst ",
"\nlet ",
# Split along control flow statements
"\nif ",
"\nwhile ",
"\nfor ",
"\nloop ",
"\nmatch ",
"\nconst ",
# Split by the normal type of lines
" ",
elif language == Language.SCALA:
return [
# Split along class definitions
"\nclass ",
"\nobject ",
# Split along method definitions
"\ndef ",
"\nval ",
"\nvar ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nmatch ",
"\ncase ",
# Split by the normal type of lines
" ",
elif language == Language.SWIFT:
return [
# Split along function definitions
"\nfunc ",
# Split along class definitions
"\nclass ",
"\nstruct ",
"\nenum ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\ndo ",
"\nswitch ",
"\ncase ",
# Split by the normal type of lines
" ",
elif language == Language.MARKDOWN:
return [
# First, try to split along Markdown headings (starting with level 2)
"\n#{1,6} ",
# Note the alternative syntax for headings (below) is not handled here
# Heading level 2
# ---------------
# End of code block
# Horizontal lines
# Note that this splitter doesn't handle horizontal lines defined
# by *three or more* of ***, ---, or ___, but this is not handled
" ",
elif language == Language.LATEX:
return [
# First, try to split along Latex sections
# Now split by environments
# Now split by math environments
# Now split by the normal type of lines
" ",
elif language == Language.HTML:
return [
# First, try to split along HTML tags
# Head
elif language == Language.CSHARP:
return [
"\ninterface ",
"\nenum ",
"\nimplements ",
"\ndelegate ",
"\nevent ",
# Split along class definitions
"\nclass ",
"\nabstract ",
# Split along method definitions
"\npublic ",
"\nprotected ",
"\nprivate ",
"\nstatic ",
"\nreturn ",
# Split along control flow statements
"\nif ",
"\ncontinue ",
"\nfor ",
"\nforeach ",
"\nwhile ",
"\nswitch ",
"\nbreak ",
"\ncase ",
"\nelse ",
# Split by exceptions
"\ntry ",
"\nthrow ",
"\nfinally ",
"\ncatch ",
# Split by the normal type of lines
" ",
elif language == Language.SOL:
return [
# Split along compiler information definitions
"\npragma ",
"\nusing ",
# Split along contract definitions
"\ncontract ",
"\ninterface ",
"\nlibrary ",
# Split along method definitions
"\nconstructor ",
"\ntype ",
"\nfunction ",
"\nevent ",
"\nmodifier ",
"\nerror ",
"\nstruct ",
"\nenum ",
# Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\ndo while ",
"\nassembly ",
# Split by the normal type of lines
" ",
elif language == Language.COBOL:
return [
# Split along divisions
# Split along sections within DATA DIVISION
# Split along sections within PROCEDURE DIVISION
# Split along paragraphs and common statements
"\nOPEN ",
"\nCLOSE ",
"\nREAD ",
"\nWRITE ",
"\nIF ",
"\nELSE ",
"\nMOVE ",
"\nUNTIL ",
"\nACCEPT ",
"\nSTOP RUN.",
# Split by the normal type of lines
" ",
raise ValueError(
f"Language {language} is not supported! "
f"Please choose from {list(Language)}"
class ClusterSemanticChunker(BaseChunker):
    """Chunker that groups adjacent small pieces by embedding similarity.

    The text is first cut into small pieces with a ``RecursiveTokenChunker``.
    The pieces are embedded, a pairwise dot-product similarity matrix is
    built, and a dynamic program picks the partition into runs of consecutive
    pieces (at most ``max_cluster`` pieces each) with the highest total
    within-run similarity. Each run is joined with spaces into one chunk.
    """

    def __init__(self, embedding_function=None, max_chunk_size=400, min_chunk_size=50, length_function=openai_token_count):
        """Configure the pre-splitter and the clustering limits.

        Args:
            embedding_function: Callable taking a list of strings and returning
                one embedding vector per string. It is called as-is by
                ``split_text``, so it must be set before splitting.
            max_chunk_size: Stored as ``_chunk_size``; together with
                ``min_chunk_size`` it bounds how many pieces a cluster may hold.
            min_chunk_size: Divisor of ``max_chunk_size`` when computing
                ``max_cluster``.
            length_function: Function measuring the length of a string.
        """
        # NOTE(review): only the `separators` argument of this call survived in
        # the source; chunk_size, chunk_overlap and length_function below are
        # reconstructed (overlap must not exceed chunk_size, and max_cluster
        # assumes pieces of about min_chunk_size) -- confirm against upstream.
        self.splitter = RecursiveTokenChunker(
            chunk_size=min_chunk_size,
            chunk_overlap=0,
            length_function=length_function,
            separators=["\n\n", "\n", ".", "?", "!", " ", ""],
        )
        self._chunk_size = max_chunk_size
        self.max_cluster = max_chunk_size // min_chunk_size
        self.embedding_function = embedding_function

    def _get_similarity_matrix(self, embedding_function, sentences):
        """Return the matrix of pairwise dot products of sentence embeddings.

        Sentences are embedded in batches of ``BATCH_SIZE``. An empty
        ``sentences`` list yields an empty ``(0, 0)`` matrix.
        """
        # NOTE(review): BATCH_SIZE is not defined in the visible source; it
        # must be provided at module level -- confirm.
        batch_matrices = [
            # Convert embeddings list of lists to numpy array
            np.array(embedding_function(sentences[i:i + BATCH_SIZE]))
            for i in range(0, len(sentences), BATCH_SIZE)
        ]
        if not batch_matrices:
            return np.zeros((0, 0))
        # Concatenate once rather than re-copying the growing matrix per batch.
        embedding_matrix = np.concatenate(batch_matrices, axis=0)
        return np.dot(embedding_matrix, embedding_matrix.T)

    def _calculate_reward(self, matrix, start, end):
        """Return the sum of ``matrix`` over rows and columns ``start..end`` inclusive."""
        sub_matrix = matrix[start:end + 1, start:end + 1]
        return np.sum(sub_matrix)

    def _optimal_segmentation(self, matrix, max_cluster_size, window_size=3):
        """Partition indices ``0..n-1`` into consecutive clusters.

        Args:
            matrix: Square similarity matrix; it is not modified.
            max_cluster_size: Maximum number of items per cluster (values
                below 1 are treated as 1).
            window_size: Unused; kept so existing callers keep working.

        Returns:
            A list of ``(start, end)`` index pairs (both inclusive) in
            ascending order, covering every index exactly once.
        """
        n = matrix.shape[0]
        if n < 2:
            # No off-diagonal entries to average; each item is its own cluster.
            return [(i, i) for i in range(n)]
        max_cluster_size = max(1, max_cluster_size)

        # Centre the similarities on the mean off-diagonal value so that
        # below-average pairs count against grouping.
        mean_value = np.mean(matrix[np.triu_indices(n, k=1)])
        matrix = matrix - mean_value  # Normalize the matrix (creates a copy)
        np.fill_diagonal(matrix, 0)  # Zero self-similarity so it adds no reward

        # dp[i] is the best total reward for items 0..i and segmentation[i] is
        # the start of the last cluster in that best partition. Starting from
        # -inf guarantees that segmentation[i] always comes from an evaluated
        # candidate, never from the array's initial 0.
        dp = np.full(n, -np.inf)
        segmentation = np.zeros(n, dtype=int)

        for i in range(n):
            for size in range(1, min(max_cluster_size, i + 1) + 1):
                start = i - size + 1
                score = self._calculate_reward(matrix, start, i)
                if start > 0:
                    score += dp[start - 1]
                if score > dp[i]:
                    dp[i] = score
                    segmentation[i] = start

        # Walk back from the last item, collecting clusters last-to-first.
        clusters = []
        i = n - 1
        while i >= 0:
            start = int(segmentation[i])
            clusters.append((start, i))
            i = start - 1

        # Return clusters in document order.
        clusters.reverse()
        return clusters

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks of semantically similar adjacent pieces.

        Returns:
            The chunks in document order; ``[]`` when the pre-splitter yields
            no pieces.
        """
        sentences = self.splitter.split_text(text)
        if not sentences:
            return []

        similarity_matrix = self._get_similarity_matrix(self.embedding_function, sentences)
        clusters = self._optimal_segmentation(similarity_matrix, max_cluster_size=self.max_cluster)
        return [' '.join(sentences[start:end + 1]) for start, end in clusters]