class BaseChunker(ABC):
    """Abstract base class that every chunker in this module implements."""

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into a list of chunk strings."""


logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


class TextSplitter(BaseChunker, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, in the unit that
                ``length_function`` measures.
            chunk_overlap: Overlap between consecutive chunks, in the same
                unit as ``chunk_size``.
            length_function: Function that measures the length of a piece of
                text (``len`` by default, i.e. characters).
            keep_separator: Whether to keep the separator in the chunks.
            add_start_index: Stored on the instance; not read by any method
                of this class.
            strip_whitespace: If ``True``, strips whitespace from the start
                and end of every merged chunk.

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join ``docs`` with ``separator``; return ``None`` if nothing is left.

        The joined text is stripped first when ``strip_whitespace`` is set, so
        a chunk made only of whitespace also yields ``None``.
        """
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        return None if text == "" else text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily pack small ``splits`` into chunks of at most ``chunk_size``.

        Pieces are accumulated until adding the next one (plus a separator)
        would exceed ``chunk_size``; the accumulated pieces are then emitted as
        one chunk and leading pieces are dropped until the retained tail is no
        longer than ``chunk_overlap`` and the next piece fits.

        Args:
            splits: Pieces of text, in order.
            separator: String placed between pieces when they are joined.

        Returns:
            The merged chunks. A single piece that is itself longer than
            ``chunk_size`` is still emitted (and a warning is logged).
        """
        separator_len = self._length_function(separator)

        docs: List[str] = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    # Lazy %-style arguments: the message is only formatted
                    # when the warning is actually emitted.
                    logger.warning(
                        "Created a chunk of size %s, "
                        "which is longer than the specified %s",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    # @classmethod
    # def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
    #     """Text splitter that uses HuggingFace tokenizer to count length."""
    #     try:
    #         from transformers import PreTrainedTokenizerBase
    #
    #         if not isinstance(tokenizer, PreTrainedTokenizerBase):
    #             raise ValueError(
    #                 "Tokenizer received was not an instance of PreTrainedTokenizerBase"
    #             )
    #
    #         def _huggingface_tokenizer_length(text: str) -> int:
    #             return len(tokenizer.encode(text))
    #
    #     except ImportError:
    #         raise ValueError(
    #             "Could not import transformers python package. "
    #             "Please install it with `pip install transformers`."
    #         )
    #     return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        # NOTE: the shared default set is only read, never mutated.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: tiktoken encoding to load when ``model_name`` is
                not given.
            model_name: If given, the encoding is looked up for this model and
                ``encoding_name`` is ignored.
            allowed_special: Forwarded to the encoder's ``encode`` call.
            disallowed_special: Forwarded to the encoder's ``encode`` call.
            **kwargs: Passed through to the ``cls`` constructor.

        Returns:
            An instance of ``cls`` whose length function counts tiktoken
            tokens.

        Raises:
            ImportError: If the ``tiktoken`` package is not installed.
        """
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        # FixedTokenChunker builds its own encoder, so it must receive the
        # same encoder settings that were used for the length function.
        if issubclass(cls, FixedTokenChunker):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
class FixedTokenChunker(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        # NOTE: the shared default set is only read, never mutated.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new FixedTokenChunker.

        Args:
            encoding_name: tiktoken encoding to load when ``model_name`` is
                not given.
            model_name: If given, the encoding is looked up for this model and
                ``encoding_name`` is ignored.
            chunk_size: Number of tokens per chunk.
            chunk_overlap: Number of tokens shared by consecutive chunks.
            allowed_special: Forwarded to the encoder's ``encode`` call.
            disallowed_special: Forwarded to the encoder's ``encode`` call.
            **kwargs: Passed through to ``TextSplitter.__init__``.

        Raises:
            ImportError: If the ``tiktoken`` package is not installed.
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``
                (raised by the base class).
        """
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)

        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for FixedTokenChunker. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Encode ``text`` to tokens and return the decoded token windows.

        Windows hold ``chunk_size`` tokens and consecutive windows share
        ``chunk_overlap`` tokens; the windowing itself is done by
        ``split_text_on_tokens``.
        """

        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class."""

    # Overlap in tokens between chunks
    chunk_overlap: int
    # Maximum number of tokens per chunk
    tokens_per_chunk: int
    # Function to decode a list of token ids to a string
    decode: Callable[[List[int]], str]
    # Function to encode a string to a list of token ids
    encode: Callable[[str], List[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded once, then cut into windows of
    ``tokenizer.tokens_per_chunk`` token ids. Each window starts
    ``tokens_per_chunk - chunk_overlap`` ids after the previous one, and every
    window is decoded back to a string.

    Args:
        text: The text to split.
        tokenizer: Encode/decode callables plus the window size and overlap.

    Returns:
        The decoded chunks, in order. Empty when ``text`` encodes to no
        tokens.

    Raises:
        ValueError: If more than one window is needed but the stride
            (``tokens_per_chunk - chunk_overlap``) is not positive; advancing
            by such a stride would never reach the end of the input.
    """
    input_ids = tokenizer.encode(text)
    total = len(input_ids)
    window = tokenizer.tokens_per_chunk
    stride = window - tokenizer.chunk_overlap

    splits: List[str] = []
    start_idx = 0
    while start_idx < total:
        end_idx = min(start_idx + window, total)
        splits.append(tokenizer.decode(input_ids[start_idx:end_idx]))
        if end_idx == total:
            break
        # Only reject the configuration when we actually have to advance, so
        # inputs that fit in a single window keep working as before.
        if stride <= 0:
            raise ValueError(
                f"tokens_per_chunk ({window}) must be greater than "
                f"chunk_overlap ({tokenizer.chunk_overlap}) to split text "
                "longer than one chunk."
            )
        start_idx += stride
    return splits
def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    """Split ``text`` on the regex ``separator`` and drop empty pieces.

    With an empty ``separator`` the text is split into single characters.
    With ``keep_separator`` each matched separator is kept, attached to the
    front of the piece that follows it.
    """
    if not separator:
        pieces = list(text)
    elif not keep_separator:
        pieces = re.split(separator, text)
    else:
        # The parentheses in the pattern keep the delimiters in the result,
        # so the raw list alternates text, separator, text, separator, ...
        raw = re.split(f"({separator})", text)
        pieces = [raw[0]]
        pieces.extend(raw[i] + raw[i + 1] for i in range(1, len(raw), 2))
        if len(raw) % 2 == 0:
            pieces.extend(raw[-1:])
    return [piece for piece in pieces if piece != ""]


class RecursiveTokenChunker(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new RecursiveTokenChunker.

        Args:
            chunk_size: Maximum chunk size, forwarded to the base class.
            chunk_overlap: Overlap between chunks, forwarded to the base class.
            separators: Separators to try, in priority order. A falsy value
                selects the built-in default list.
            keep_separator: Whether matched separators stay in the chunks.
            is_separator_regex: If ``True`` the separators are used as regex
                patterns as-is; otherwise they are escaped first.
            **kwargs: Passed through to the base class constructor.
        """
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            keep_separator=keep_separator,
            **kwargs,
        )
        if separators:
            self._separators = separators
        else:
            self._separators = ["\n\n", "\n", ".", "?", "!", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks.

        The first separator that occurs in ``text`` is used for splitting.
        Pieces shorter than the chunk size are merged together; any other
        piece is split again with the separators that come after the one just
        used, or emitted unchanged when none are left.
        """

        def as_pattern(sep: str) -> str:
            return sep if self._is_separator_regex else re.escape(sep)

        # Pick the separator: fall back to the last one when nothing matches.
        chosen = separators[-1]
        remaining: List[str] = []
        for position, candidate in enumerate(separators):
            if candidate == "":
                chosen = candidate
                break
            if re.search(as_pattern(candidate), text):
                chosen = candidate
                remaining = separators[position + 1 :]
                break

        pieces = _split_text_with_regex(text, as_pattern(chosen), self._keep_separator)

        # When separators are kept they are already part of the pieces, so
        # nothing is inserted between pieces on merge.
        joiner = "" if self._keep_separator else chosen

        chunks: List[str] = []
        pending: List[str] = []
        for piece in pieces:
            if self._length_function(piece) < self._chunk_size:
                pending.append(piece)
                continue
            # Flush the small pieces collected so far before handling the
            # oversized one, so output order follows the text.
            if pending:
                chunks.extend(self._merge_splits(pending, joiner))
                pending = []
            if remaining:
                chunks.extend(self._split_text(piece, remaining))
            else:
                chunks.append(piece)
        if pending:
            chunks.extend(self._merge_splits(pending, joiner))
        return chunks

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` using the separators configured on this instance."""
        return self._split_text(text, self._separators)

    # @classmethod
    # def from_language(
    #     cls, language: Language, **kwargs: Any
    # ) -> RecursiveCharacterTextSplitter:
    #     separators = cls.get_separators_for_language(language)
    #     return cls(separators=separators, is_separator_regex=True, **kwargs)
"\npublic ", "\nprotected ", "\nprivate ", "\ninternal ", "\ncompanion ", "\nfun ", "\nval ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nwhen ", "\ncase ", "\nelse ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.JS: return [ # Split along function definitions "\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.TS: return [ "\nenum ", "\ninterface ", "\nnamespace ", "\ntype ", # Split along class definitions "\nclass ", # Split along function definitions "\nfunction ", "\nconst ", "\nlet ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.PHP: return [ # Split along function definitions "\nfunction ", # Split along class definitions "\nclass ", # Split along control flow statements "\nif ", "\nforeach ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.PROTO: return [ # Split along message definitions "\nmessage ", # Split along service definitions "\nservice ", # Split along enum definitions "\nenum ", # Split along option definitions "\noption ", # Split along import statements "\nimport ", # Split along syntax declarations "\nsyntax ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.PYTHON: return [ # First, try to split along class definitions "\nclass ", "\ndef ", "\n\tdef ", # Now split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.RST: return [ # Split along section titles "\n=+\n", "\n-+\n", "\n\\*+\n", # Split along directive markers "\n\n.. 
*\n\n", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.RUBY: return [ # Split along method definitions "\ndef ", "\nclass ", # Split along control flow statements "\nif ", "\nunless ", "\nwhile ", "\nfor ", "\ndo ", "\nbegin ", "\nrescue ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.RUST: return [ # Split along function definitions "\nfn ", "\nconst ", "\nlet ", # Split along control flow statements "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ", "\nconst ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.SCALA: return [ # Split along class definitions "\nclass ", "\nobject ", # Split along method definitions "\ndef ", "\nval ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.SWIFT: return [ # Split along function definitions "\nfunc ", # Split along class definitions "\nclass ", "\nstruct ", "\nenum ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] elif language == Language.MARKDOWN: return [ # First, try to split along Markdown headings (starting with level 2) "\n#{1,6} ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 # --------------- # End of code block "```\n", # Horizontal lines "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", # Note that this splitter doesn't handle horizontal lines defined # by *three or more* of ***, ---, or ___, but this is not handled "\n\n", "\n", " ", "", ] elif language == Language.LATEX: return [ # First, try to split along Latex sections "\n\\\\chapter{", "\n\\\\section{", "\n\\\\subsection{", "\n\\\\subsubsection{", # Now split by environments "\n\\\\begin{enumerate}", "\n\\\\begin{itemize}", 
"\n\\\\begin{description}", "\n\\\\begin{list}", "\n\\\\begin{quote}", "\n\\\\begin{quotation}", "\n\\\\begin{verse}", "\n\\\\begin{verbatim}", # Now split by math environments "\n\\\begin{align}", "$$", "$", # Now split by the normal type of lines " ", "", ] elif language == Language.HTML: return [ # First, try to split along HTML tags "= 0: # local_density = calculate_local_density(matrix, i, window_size) reward = self._calculate_reward(matrix, i - size + 1, i) # Adjust reward based on local density adjusted_reward = reward if i - size >= 0: adjusted_reward += dp[i - size] if adjusted_reward > dp[i]: dp[i] = adjusted_reward segmentation[i] = i - size + 1 clusters = [] i = n - 1 while i >= 0: start = segmentation[i] clusters.append((start, i)) i = start - 1 clusters.reverse() return clusters def split_text(self, text: str) -> List[str]: sentences = self.splitter.split_text(text) similarity_matrix = self._get_similarity_matrix(self.embedding_function, sentences) clusters = self._optimal_segmentation(similarity_matrix, max_cluster_size=self.max_cluster) docs = [' '.join(sentences[start:end+1]) for start, end in clusters] return docs