from dataclasses import dataclass
from typing import Any, List

import numpy as np
from tqdm.auto import tqdm

from semantic_router.encoders.base import BaseEncoder

from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.logger import logger
from semantic_chunkers.utils.text import tiktoken_length


@dataclass
class ChunkStatistics:
    total_documents: int
    total_chunks: int
    chunks_by_threshold: int
    chunks_by_max_chunk_size: int
    chunks_by_last_split: int
    min_token_size: int
    max_token_size: int
    chunks_by_similarity_ratio: float

    def __str__(self):
        return (
            f"Chunking Statistics:\n"
            f" - Total Documents: {self.total_documents}\n"
            f" - Total Chunks: {self.total_chunks}\n"
            f" - Chunks by Threshold: {self.chunks_by_threshold}\n"
            f" - Chunks by Max Chunk Size: {self.chunks_by_max_chunk_size}\n"
            f" - Chunks by Last Split: {self.chunks_by_last_split}\n"
            f" - Minimum Token Size of Chunk: {self.min_token_size}\n"
            f" - Maximum Token Size of Chunk: {self.max_token_size}\n"
            f" - Similarity Chunk Ratio: {self.chunks_by_similarity_ratio:.2f}"
        )


class StatisticalChunker(BaseChunker):
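    """Chunk documents by merging sentence-level splits until the cosine
    similarity between a split and the rolling window of its predecessors drops
    below a threshold (optionally tuned per batch via ``_find_optimal_threshold``).

    Illustrative usage sketch (assumes semantic_router's ``OpenAIEncoder`` is
    installed and an OpenAI API key is configured in the environment)::

        from semantic_router.encoders import OpenAIEncoder

        encoder = OpenAIEncoder()
        chunker = StatisticalChunker(encoder=encoder)
        docs = ["A long document to be chunked ..."]
        # NOTE: metadatas is indexed by split position inside `_split_documents`,
        # so provide at least one dict per sentence-level split.
        metadatas = [{"source": "example"}] * 100
        # Returns one list of Chunk objects per input document.
        chunks_per_doc = chunker(docs=docs, metadatas=metadatas)
    """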

    def __init__(
        self,
        encoder: BaseEncoder,
        splitter: BaseSplitter = RegexSplitter(),
        name: str = "statistical_chunker",
        threshold_adjustment: float = 0.01,
        dynamic_threshold: bool = True,
        window_size: int = 5,
        min_split_tokens: int = 100,
        max_split_tokens: int = 300,
        split_tokens_tolerance: int = 10,
        plot_chunks: bool = False,
        enable_statistics: bool = False,
    ):
        super().__init__(name=name, encoder=encoder, splitter=splitter)
        self.calculated_threshold: float
        self.encoder = encoder
        self.threshold_adjustment = threshold_adjustment
        self.dynamic_threshold = dynamic_threshold
        self.window_size = window_size
        self.plot_chunks = plot_chunks
        self.min_split_tokens = min_split_tokens
        self.max_split_tokens = max_split_tokens
        self.split_tokens_tolerance = split_tokens_tolerance
        self.enable_statistics = enable_statistics
        self.statistics: ChunkStatistics

    def _chunk(
        self,
        splits: List[Any],
        metadatas: List[dict],
        batch_size: int = 64,
        enforce_max_tokens: bool = False,
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity, with optional
        enforcement of a maximum token limit per chunk.

        :param splits: Splits to be merged into chunks.
        :param metadatas: Metadata dicts aligned with the splits; each chunk copies
            the metadata at its final split index.
        :param batch_size: Number of splits to process in one batch.
        :param enforce_max_tokens: If True, further split any input that exceeds the
            maximum token limit before merging.

        :return: List of chunks.
        """
        if enforce_max_tokens:
            new_splits = []
            for split in splits:
                token_count = tiktoken_length(split)
                if token_count > self.max_split_tokens:
                    logger.info(
                        f"Single document exceeds the maximum token limit "
                        f"of {self.max_split_tokens}. "
                        "Splitting to sentences before semantically merging."
                    )
                    _splits = self._split(split)
                    new_splits.extend(_splits)
                else:
                    new_splits.append(split)

            splits = [split for split in new_splits if split and split.strip()]

        chunks = []
        last_split = None
        for i in tqdm(range(0, len(splits), batch_size)):
            batch_splits = splits[i : i + batch_size]
            if last_split is not None:
                # Prepend the unfinished chunk from the previous batch so it can
                # keep growing across batch boundaries.
                batch_splits = last_split.splits + batch_splits

            encoded_splits = self._encode_documents(batch_splits)
            similarities = self._calculate_similarity_scores(encoded_splits)
            if self.dynamic_threshold:
                self._find_optimal_threshold(batch_splits, similarities)
            else:
                self.calculated_threshold = self.encoder.score_threshold
            split_indices = self._find_split_indices(similarities=similarities)
            doc_chunks = self._split_documents(
                batch_splits, metadatas, split_indices, similarities
            )

            if len(doc_chunks) > 1:
                chunks.extend(doc_chunks[:-1])
                last_split = doc_chunks[-1]
            else:
                last_split = doc_chunks[0]

            if self.plot_chunks:
                self.plot_similarity_scores(similarities, split_indices, doc_chunks)

            if self.enable_statistics:
                print(self.statistics)

        if last_split:
            chunks.append(last_split)

        return chunks

    def __call__(
        self, docs: List[str], metadatas: List[dict], batch_size: int = 64
    ) -> List[List[Chunk]]:
        """Split documents into smaller chunks based on semantic similarity.

        :param docs: List of text documents to be split. To split a single
            document, pass it as a list with a single element.
        :param metadatas: Metadata dicts that are carried through to the
            resulting chunks.
        :param batch_size: Number of splits to encode per batch.

        :return: A list containing one list of Chunk objects per input document.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        all_chunks = []
        for doc in docs:
            token_count = tiktoken_length(doc)
            if token_count > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
            if isinstance(doc, str):
                splits = self._split(doc)
                doc_chunks = self._chunk(splits, metadatas, batch_size=batch_size)
                all_chunks.append(doc_chunks)
            else:
                raise ValueError("The document must be a string.")
        return all_chunks

    def _encode_documents(self, docs: List[str]) -> np.ndarray:
        """Encode a batch of documents with the configured encoder.

        :param docs: List of text documents to be encoded.
        :return: A numpy array of embeddings for the given documents.
        """
        return np.array(self.encoder(docs))

    def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
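        """Compute the cosine similarity between each encoded split and the mean
        embedding of up to ``window_size`` preceding splits (the small epsilon
        guards against division by zero).

        :param encoded_docs: Embeddings of the splits, one row per split.
        :return: ``len(encoded_docs) - 1`` scores, where the score at index ``i``
            compares split ``i + 1`` with the window ending at split ``i``.
        """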
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities

    def _find_split_indices(self, similarities: List[float]) -> List[int]:
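        """Return the positions at which to cut: whenever the similarity score at
        index ``idx`` falls below ``self.calculated_threshold``, ``idx + 1`` is
        recorded so the results can be used directly as slice boundaries between
        chunks.
        """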
        split_indices = []
        for idx, score in enumerate(similarities):
            logger.debug(f"Similarity score at index {idx}: {score}")
            if score < self.calculated_threshold:
                logger.debug(
                    f"Adding to split_indices due to score < threshold: "
                    f"{score} < {self.calculated_threshold}"
                )
                split_indices.append(idx + 1)
        return split_indices

    def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
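        """Binary-search a similarity threshold whose resulting splits have a
        median token count within ``min_split_tokens``/``max_split_tokens``
        (plus or minus ``split_tokens_tolerance``). The search range is bounded
        by the median similarity score plus or minus one standard deviation, and
        ``self.calculated_threshold`` is updated in place.
        """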
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)

        # Use the distribution of similarity scores to set the initial search bounds.
        median_score = np.median(similarity_scores)
        std_dev = np.std(similarity_scores)

        low = max(0.0, float(median_score - std_dev))
        high = min(1.0, float(median_score + std_dev))

        iteration = 0
        median_tokens = 0
        while low <= high:
            self.calculated_threshold = (low + high) / 2
            split_indices = self._find_split_indices(similarity_scores)
            logger.debug(
                f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
            )

            # Token counts of the splits implied by the candidate threshold.
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            median_tokens = np.median(split_token_counts)
            logger.debug(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.debug("Median tokens in target range. Stopping iteration.")
                break
            elif median_tokens < self.min_split_tokens:
                # Splits are too small: lower the threshold to produce fewer cuts.
                high = self.calculated_threshold - self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                # Splits are too large: raise the threshold to produce more cuts.
                low = self.calculated_threshold + self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1

        logger.debug(
            f"Optimal threshold {self.calculated_threshold} found "
            f"with median tokens ({median_tokens}) in target range "
            f"({self.min_split_tokens}-{self.max_split_tokens})."
        )

        return self.calculated_threshold

    def _split_documents(
        self,
        docs: List[str],
        metadatas: List[dict],
        split_indices: List[int],
        similarities: List[float],
    ) -> List[Chunk]:
        """Assemble Chunk objects from the documents.

        This method iterates through each document, appending it to the current
        split until it either reaches a split point (determined by split_indices)
        or exceeds the maximum token limit for a split (self.max_split_tokens).
        When a document causes the current token count to exceed this limit,
        or when a split point is reached and the minimum token requirement is met,
        the current split is finalized and added to the list of chunks.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        chunks, current_split = [], []
        current_tokens_count = 0

        # Statistics counters for the resulting ChunkStatistics.
        chunks_by_threshold = 0
        chunks_by_max_chunk_size = 0
        chunks_by_last_split = 0

        for doc_idx, doc in enumerate(docs):
            doc_token_count = token_counts[doc_idx]
            logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
            logger.debug(f"Document token count: {doc_token_count} tokens")

            # Finalize the chunk at a similarity split point, provided the token
            # count lands inside the configured window.
            if doc_idx + 1 in split_indices:
                if (
                    self.min_split_tokens
                    <= current_tokens_count + doc_token_count
                    < self.max_split_tokens
                ):
                    # Include the current document before splitting.
                    current_split.append(doc)
                    current_tokens_count += doc_token_count

                    triggered_score = (
                        similarities[doc_idx] if doc_idx < len(similarities) else None
                    )
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=True,
                            triggered_score=triggered_score,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy(),
                        )
                    )
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"threshold {self.calculated_threshold}."
                    )
                    current_split, current_tokens_count = [], 0
                    chunks_by_threshold += 1
                    continue

            # Finalize the chunk early if adding this document would exceed the
            # maximum token limit.
            if current_tokens_count + doc_token_count > self.max_split_tokens:
                if current_tokens_count >= self.min_split_tokens:
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=False,
                            triggered_score=None,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy(),
                        )
                    )
                    chunks_by_max_chunk_size += 1
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"exceeding token limit of {self.max_split_tokens}."
                    )
                    current_split, current_tokens_count = [], 0

            current_split.append(doc)
            current_tokens_count += doc_token_count

        # Handle any remaining documents as the final chunk.
        if current_split:
            chunks.append(
                Chunk(
                    splits=current_split.copy(),
                    is_triggered=False,
                    triggered_score=None,
                    token_count=current_tokens_count,
                    metadata=metadatas[doc_idx].copy(),
                )
            )
            chunks_by_last_split += 1
            logger.debug(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )

        # Sanity check: no tokens should be lost or duplicated during chunking.
        original_token_count = sum(token_counts)
        split_token_count = sum(
            [tiktoken_length(doc) for split in chunks for doc in split.splits]
        )
        if original_token_count != split_token_count:
            logger.error(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
            raise ValueError(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )

        # Gather statistics about how the chunks were produced.
        total_chunks = len(chunks)
        chunks_by_similarity_ratio = (
            chunks_by_threshold / total_chunks if total_chunks else 0
        )
        min_token_size = max_token_size = 0
        if chunks:
            token_counts = [
                split.token_count for split in chunks if split.token_count is not None
            ]
            min_token_size, max_token_size = min(token_counts, default=0), max(
                token_counts, default=0
            )

        self.statistics = ChunkStatistics(
            total_documents=len(docs),
            total_chunks=total_chunks,
            chunks_by_threshold=chunks_by_threshold,
            chunks_by_max_chunk_size=chunks_by_max_chunk_size,
            chunks_by_last_split=chunks_by_last_split,
            min_token_size=min_token_size,
            max_token_size=max_token_size,
            chunks_by_similarity_ratio=chunks_by_similarity_ratio,
        )

        return chunks

    def plot_similarity_scores(
        self,
        similarities: List[float],
        split_indices: List[int],
        chunks: List[Chunk],
    ):
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning(
                "Plotting is disabled. Please `pip install "
                "semantic-router[processing]`."
            )
            return

        _, axs = plt.subplots(2, 1, figsize=(12, 12))

        # Plot similarity scores with split points and the threshold.
        axs[0].plot(similarities, label="Similarity Scores", marker="o")
        for split_index in split_indices:
            axs[0].axvline(
                x=split_index - 1,
                color="r",
                linestyle="--",
                label="Chunk" if split_index == split_indices[0] else "",
            )
        axs[0].axhline(
            y=self.calculated_threshold,
            color="g",
            linestyle="-.",
            label="Threshold Similarity Score",
        )

        # Annotate each similarity score.
        for i, score in enumerate(similarities):
            axs[0].annotate(
                f"{score:.2f}",
                (i, score),
                textcoords="offset points",
                xytext=(0, 10),
                ha="center",
            )

        axs[0].set_xlabel("Document Segment Index")
        axs[0].set_ylabel("Similarity Score")
        axs[0].set_title(
            f"Threshold: {self.calculated_threshold} |"
            f" Window Size: {self.window_size}",
            loc="right",
            fontsize=10,
        )
        axs[0].legend()

        # Plot the token counts of the resulting chunks.
        token_counts = [split.token_count for split in chunks]
        axs[1].bar(range(len(token_counts)), token_counts, color="lightblue")
        axs[1].set_title("Chunk Token Sizes")
        axs[1].set_xlabel("Chunk Index")
        axs[1].set_ylabel("Token Count")
        axs[1].set_xticks(range(len(token_counts)))
        axs[1].set_xticklabels([str(i) for i in range(len(token_counts))])
        axs[1].grid(True)

        # Label each bar with its token count.
        for idx, token_count in enumerate(token_counts):
            if not token_count:
                continue
            axs[1].text(
                idx, token_count + 0.01, str(token_count), ha="center", va="bottom"
            )

        plt.tight_layout()
        plt.show()

    def plot_sentence_similarity_scores(
        self, docs: List[str], threshold: float, window_size: int
    ):
        """Compute similarity scores between the average of the last
        `window_size` sentences and the next one, plot a graph of these
        similarity scores, and print the sentence that follows each
        similarity score below the specified threshold.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning("Plotting is disabled. Please `pip install matplotlib`.")
            return
        sentences = [sentence for doc in docs for sentence in self._split(doc)]
        encoded_sentences = self._encode_documents(sentences)
        similarity_scores = []

        for i in range(window_size, len(encoded_sentences)):
            window_avg_encoding = np.mean(
                encoded_sentences[i - window_size : i], axis=0
            )
            sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / (
                np.linalg.norm(window_avg_encoding)
                * np.linalg.norm(encoded_sentences[i])
                + 1e-10
            )
            similarity_scores.append(sim_score)

        plt.figure(figsize=(10, 8))
        plt.plot(similarity_scores, marker="o", linestyle="-", color="b")
        plt.title("Sliding Window Sentence Similarity Scores")
        plt.xlabel("Sentence Index")
        plt.ylabel("Similarity Score")
        plt.grid(True)
        plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold")
        plt.show()

        for i, score in enumerate(similarity_scores):
            if score < threshold:
                print(
                    f"First sentence after similarity score "
                    f"below {threshold}: {sentences[i + window_size]}"
                )