from dataclasses import dataclass from typing import Any, List import numpy as np from semantic_router.encoders.base import BaseEncoder from semantic_chunkers.schema import Chunk from semantic_chunkers.chunkers.base import BaseChunker from semantic_chunkers.splitters.base import BaseSplitter from semantic_chunkers.splitters.sentence import RegexSplitter from semantic_chunkers.utils.text import tiktoken_length from semantic_chunkers.utils.logger import logger from tqdm.auto import tqdm @dataclass class ChunkStatistics: total_documents: int total_chunks: int chunks_by_threshold: int chunks_by_max_chunk_size: int chunks_by_last_split: int min_token_size: int max_token_size: int chunks_by_similarity_ratio: float def __str__(self): return ( f"Chunking Statistics:\n" f" - Total Documents: {self.total_documents}\n" f" - Total Chunks: {self.total_chunks}\n" f" - Chunks by Threshold: {self.chunks_by_threshold}\n" f" - Chunks by Max Chunk Size: {self.chunks_by_max_chunk_size}\n" f" - Last Chunk: {self.chunks_by_last_split}\n" f" - Minimum Token Size of Chunk: {self.min_token_size}\n" f" - Maximum Token Size of Chunk: {self.max_token_size}\n" f" - Similarity Chunk Ratio: {self.chunks_by_similarity_ratio:.2f}" ) class StatisticalChunker(BaseChunker): def __init__( self, encoder: BaseEncoder, splitter: BaseSplitter = RegexSplitter(), name="statistical_chunker", threshold_adjustment=0.01, dynamic_threshold: bool = True, window_size=5, min_split_tokens=100, max_split_tokens=300, split_tokens_tolerance=10, plot_chunks=False, enable_statistics=False, ): super().__init__(name=name, encoder=encoder, splitter=splitter) self.calculated_threshold: float self.encoder = encoder self.threshold_adjustment = threshold_adjustment self.dynamic_threshold = dynamic_threshold self.window_size = window_size self.plot_chunks = plot_chunks self.min_split_tokens = min_split_tokens self.max_split_tokens = max_split_tokens self.split_tokens_tolerance = split_tokens_tolerance self.enable_statistics = enable_statistics self.statistics: ChunkStatistics def _chunk( self, splits: List[Any], metadatas: List[dict], batch_size: int = 64, enforce_max_tokens: bool = False ) -> List[Chunk]: """Merge splits into chunks using semantic similarity, with optional enforcement of maximum token limits per chunk. :param splits: Splits to be merged into chunks. :param batch_size: Number of splits to process in one batch. :param enforce_max_tokens: If True, further split chunks that exceed the maximum token limit. :return: List of chunks. """ # Split the docs that already exceed max_split_tokens to smaller chunks if enforce_max_tokens: new_splits = [] for split in splits: token_count = tiktoken_length(split) if token_count > self.max_split_tokens: logger.info( f"Single document exceeds the maximum token limit " f"of {self.max_split_tokens}. " "Splitting to sentences before semantically merging." ) _splits = self._split(split) new_splits.extend(_splits) else: new_splits.append(split) splits = [split for split in new_splits if split and split.strip()] chunks = [] last_split = None for i in tqdm(range(0, len(splits), batch_size)): batch_splits = splits[i : i + batch_size] if last_split is not None: batch_splits = last_split.splits + batch_splits encoded_splits = self._encode_documents(batch_splits) similarities = self._calculate_similarity_scores(encoded_splits) if self.dynamic_threshold: self._find_optimal_threshold(batch_splits, similarities) else: self.calculated_threshold = self.encoder.score_threshold split_indices = self._find_split_indices(similarities=similarities) doc_chunks = self._split_documents( batch_splits, metadatas, split_indices, similarities ) if len(doc_chunks) > 1: chunks.extend(doc_chunks[:-1]) last_split = doc_chunks[-1] else: last_split = doc_chunks[0] if self.plot_chunks: self.plot_similarity_scores(similarities, split_indices, doc_chunks) if self.enable_statistics: print(self.statistics) if last_split: chunks.append(last_split) return chunks def __call__(self, docs: List[str], metadatas: List[dict], batch_size: int = 64) -> List[List[Chunk]]: """Split documents into smaller chunks based on semantic similarity. :param docs: list of text documents to be split, if only wanted to split a single document, pass it as a list with a single element. :return: list of Chunk objects containing the split documents. """ if not docs: raise ValueError("At least one document is required for splitting.") all_chunks = [] for doc in docs: token_count = tiktoken_length(doc) if token_count > self.max_split_tokens: logger.info( f"Single document exceeds the maximum token limit " f"of {self.max_split_tokens}. " "Splitting to sentences before semantically merging." ) if isinstance(doc, str): splits = self._split(doc) doc_chunks = self._chunk(splits, metadatas, batch_size=batch_size) all_chunks.append(doc_chunks) else: raise ValueError("The document must be a string.") return all_chunks def _encode_documents(self, docs: List[str]) -> np.ndarray: """ :param docs: List of text documents to be encoded. :return: A numpy array of embeddings for the given documents. """ return np.array(self.encoder(docs)) def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]: raw_similarities = [] for idx in range(1, len(encoded_docs)): window_start = max(0, idx - self.window_size) cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0) curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / ( np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx]) + 1e-10 ) raw_similarities.append(curr_sim_score) return raw_similarities def _find_split_indices(self, similarities: List[float]) -> List[int]: split_indices = [] for idx, score in enumerate(similarities): logger.debug(f"Similarity score at index {idx}: {score}") if score < self.calculated_threshold: logger.debug( f"Adding to split_indices due to score < threshold: " f"{score} < {self.calculated_threshold}" ) # Chunk after the document at idx split_indices.append(idx + 1) return split_indices def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]): token_counts = [tiktoken_length(doc) for doc in docs] cumulative_token_counts = np.cumsum([0] + token_counts) # Analyze the distribution of similarity scores to set initial bounds median_score = np.median(similarity_scores) std_dev = np.std(similarity_scores) # Set initial bounds based on median and standard deviation low = max(0.0, float(median_score - std_dev)) high = min(1.0, float(median_score + std_dev)) iteration = 0 median_tokens = 0 while low <= high: self.calculated_threshold = (low + high) / 2 split_indices = self._find_split_indices(similarity_scores) logger.debug( f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}" ) # Calculate the token counts for each split using the cumulative sums split_token_counts = [ cumulative_token_counts[end] - cumulative_token_counts[start] for start, end in zip( [0] + split_indices, split_indices + [len(token_counts)] ) ] # Calculate the median token count for the chunks median_tokens = np.median(split_token_counts) logger.debug( f"Iteration {iteration}: Median tokens per split: {median_tokens}" ) if ( self.min_split_tokens - self.split_tokens_tolerance <= median_tokens <= self.max_split_tokens + self.split_tokens_tolerance ): logger.debug("Median tokens in target range. Stopping iteration.") break elif median_tokens < self.min_split_tokens: high = self.calculated_threshold - self.threshold_adjustment logger.debug(f"Iteration {iteration}: Adjusting high to {high}") else: low = self.calculated_threshold + self.threshold_adjustment logger.debug(f"Iteration {iteration}: Adjusting low to {low}") iteration += 1 logger.debug( f"Optimal threshold {self.calculated_threshold} found " f"with median tokens ({median_tokens}) in target range " f"({self.min_split_tokens}-{self.max_split_tokens})." ) return self.calculated_threshold def _split_documents( self, docs: List[str], metadatas: List[dict], split_indices: List[int], similarities: List[float] ) -> List[Chunk]: """ This method iterates through each document, appending it to the current split until it either reaches a split point (determined by split_indices) or exceeds the maximum token limit for a split (self.max_split_tokens). When a document causes the current token count to exceed this limit, or when a split point is reached and the minimum token requirement is met, the current split is finalized and added to the List of chunks. """ token_counts = [tiktoken_length(doc) for doc in docs] chunks, current_split = [], [] current_tokens_count = 0 # Statistics chunks_by_threshold = 0 chunks_by_max_chunk_size = 0 chunks_by_last_split = 0 for doc_idx, doc in enumerate(docs): doc_token_count = token_counts[doc_idx] logger.debug(f"Accumulative token count: {current_tokens_count} tokens") logger.debug(f"Document token count: {doc_token_count} tokens") # Check if current index is a split point based on similarity if doc_idx + 1 in split_indices: if ( self.min_split_tokens <= current_tokens_count + doc_token_count < self.max_split_tokens ): # Include the current document before splitting # if it doesn't exceed the max limit current_split.append(doc) current_tokens_count += doc_token_count triggered_score = ( similarities[doc_idx] if doc_idx < len(similarities) else None ) chunks.append( Chunk( splits=current_split.copy(), is_triggered=True, triggered_score=triggered_score, token_count=current_tokens_count, metadata=metadatas[doc_idx].copy() ) ) logger.debug( f"Chunk finalized with {current_tokens_count} tokens due to " f"threshold {self.calculated_threshold}." ) current_split, current_tokens_count = [], 0 chunks_by_threshold += 1 continue # Move to the next document after splitting # Check if adding the current document exceeds the max token limit if current_tokens_count + doc_token_count > self.max_split_tokens: if current_tokens_count >= self.min_split_tokens: chunks.append( Chunk( splits=current_split.copy(), is_triggered=False, triggered_score=None, token_count=current_tokens_count, metadata=metadatas[doc_idx].copy() ) ) chunks_by_max_chunk_size += 1 logger.debug( f"Chink finalized with {current_tokens_count} tokens due to " f"exceeding token limit of {self.max_split_tokens}." ) current_split, current_tokens_count = [], 0 current_split.append(doc) current_tokens_count += doc_token_count # Handle the last split if current_split: chunks.append( Chunk( splits=current_split.copy(), is_triggered=False, triggered_score=None, token_count=current_tokens_count, metadata=metadatas[doc_idx].copy() ) ) chunks_by_last_split += 1 logger.debug( f"Final split added with {current_tokens_count} " "tokens due to remaining documents." ) # Validation to ensure no tokens are lost during the split original_token_count = sum(token_counts) split_token_count = sum( [tiktoken_length(doc) for split in chunks for doc in split.splits] ) if original_token_count != split_token_count: logger.error( f"Token count mismatch: {original_token_count} != {split_token_count}" ) raise ValueError( f"Token count mismatch: {original_token_count} != {split_token_count}" ) # Statistics total_chunks = len(chunks) chunks_by_similarity_ratio = ( chunks_by_threshold / total_chunks if total_chunks else 0 ) min_token_size = max_token_size = 0 if chunks: token_counts = [ split.token_count for split in chunks if split.token_count is not None ] min_token_size, max_token_size = min(token_counts, default=0), max( token_counts, default=0 ) self.statistics = ChunkStatistics( total_documents=len(docs), total_chunks=total_chunks, chunks_by_threshold=chunks_by_threshold, chunks_by_max_chunk_size=chunks_by_max_chunk_size, chunks_by_last_split=chunks_by_last_split, min_token_size=min_token_size, max_token_size=max_token_size, chunks_by_similarity_ratio=chunks_by_similarity_ratio, ) return chunks def plot_similarity_scores( self, similarities: List[float], split_indices: List[int], chunks: list[Chunk], ): try: from matplotlib import pyplot as plt except ImportError: logger.warning( "Plotting is disabled. Please `pip install " "semantic-router[processing]`." ) return _, axs = plt.subplots(2, 1, figsize=(12, 12)) # Adjust for two plots # Plot 1: Similarity Scores axs[0].plot(similarities, label="Similarity Scores", marker="o") for split_index in split_indices: axs[0].axvline( x=split_index - 1, color="r", linestyle="--", label="Chunk" if split_index == split_indices[0] else "", ) axs[0].axhline( y=self.calculated_threshold, color="g", linestyle="-.", label="Threshold Similarity Score", ) # Annotating each similarity score for i, score in enumerate(similarities): axs[0].annotate( f"{score:.2f}", # Formatting to two decimal places (i, score), textcoords="offset points", xytext=(0, 10), # Positioning the text above the point ha="center", ) # Center-align the text axs[0].set_xlabel("Document Segment Index") axs[0].set_ylabel("Similarity Score") axs[0].set_title( f"Threshold: {self.calculated_threshold} |" f" Window Size: {self.window_size}", loc="right", fontsize=10, ) axs[0].legend() # Plot 2: Chunk Token Size Distribution token_counts = [split.token_count for split in chunks] axs[1].bar(range(len(token_counts)), token_counts, color="lightblue") axs[1].set_title("Chunk Token Sizes") axs[1].set_xlabel("Chunk Index") axs[1].set_ylabel("Token Count") axs[1].set_xticks(range(len(token_counts))) axs[1].set_xticklabels([str(i) for i in range(len(token_counts))]) axs[1].grid(True) # Annotate each bar with the token size for idx, token_count in enumerate(token_counts): if not token_count: continue axs[1].text( idx, token_count + 0.01, str(token_count), ha="center", va="bottom" ) plt.tight_layout() plt.show() def plot_sentence_similarity_scores( self, docs: List[str], threshold: float, window_size: int ): try: from matplotlib import pyplot as plt except ImportError: logger.warning("Plotting is disabled. Please `pip install matplotlib`.") return """ Computes similarity scores between the average of the last 'window_size' sentences and the next one, plots a graph of these similarity scores, and prints the first sentence after a similarity score below a specified threshold. """ sentences = [sentence for doc in docs for sentence in self._split(doc)] encoded_sentences = self._encode_documents(sentences) similarity_scores = [] for i in range(window_size, len(encoded_sentences)): window_avg_encoding = np.mean( encoded_sentences[i - window_size : i], axis=0 ) sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / ( np.linalg.norm(window_avg_encoding) * np.linalg.norm(encoded_sentences[i]) + 1e-10 ) similarity_scores.append(sim_score) plt.figure(figsize=(10, 8)) plt.plot(similarity_scores, marker="o", linestyle="-", color="b") plt.title("Sliding Window Sentence Similarity Scores") plt.xlabel("Sentence Index") plt.ylabel("Similarity Score") plt.grid(True) plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold") plt.show() for i, score in enumerate(similarity_scores): if score < threshold: print( f"First sentence after similarity score " f"below {threshold}: {sentences[i + window_size]}" )