from dataclasses import dataclass
from typing import Any, List

import numpy as np
from tqdm.auto import tqdm

from semantic_router.encoders.base import BaseEncoder

from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.logger import logger


@dataclass
class ChunkStatistics:
    total_documents: int
    total_chunks: int
    chunks_by_threshold: int
    chunks_by_max_chunk_size: int
    chunks_by_last_split: int
    min_token_size: int
    max_token_size: int
    chunks_by_similarity_ratio: float

    def __str__(self):
        return (
            f"Chunking Statistics:\n"
            f" - Total Documents: {self.total_documents}\n"
            f" - Total Chunks: {self.total_chunks}\n"
            f" - Chunks by Threshold: {self.chunks_by_threshold}\n"
            f" - Chunks by Max Chunk Size: {self.chunks_by_max_chunk_size}\n"
            f" - Last Chunk: {self.chunks_by_last_split}\n"
            f" - Minimum Token Size of Chunk: {self.min_token_size}\n"
            f" - Maximum Token Size of Chunk: {self.max_token_size}\n"
            f" - Similarity Chunk Ratio: {self.chunks_by_similarity_ratio:.2f}"
        )
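

# StatisticalChunker merges sentence-level splits back into larger chunks: each
# split is embedded with the provided encoder, the cosine similarity between a
# rolling window of previous embeddings and the next embedding is computed, and
# a new chunk starts wherever that similarity drops below a threshold (either a
# fixed encoder threshold or one tuned dynamically per batch), subject to the
# min/max token limits configured below.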
class StatisticalChunker(BaseChunker):
    def __init__(
        self,
        encoder: BaseEncoder,
        splitter: BaseSplitter = RegexSplitter(),
        name: str = "statistical_chunker",
        threshold_adjustment: float = 0.01,
        dynamic_threshold: bool = True,
        window_size: int = 5,
        min_split_tokens: int = 100,
        max_split_tokens: int = 300,
        split_tokens_tolerance: int = 10,
        plot_chunks: bool = False,
        enable_statistics: bool = False,
    ):
        super().__init__(name=name, encoder=encoder, splitter=splitter)
        self.calculated_threshold: float
        self.encoder = encoder
        self.threshold_adjustment = threshold_adjustment
        self.dynamic_threshold = dynamic_threshold
        self.window_size = window_size
        self.plot_chunks = plot_chunks
        self.min_split_tokens = min_split_tokens
        self.max_split_tokens = max_split_tokens
        self.split_tokens_tolerance = split_tokens_tolerance
        self.enable_statistics = enable_statistics
        self.statistics: ChunkStatistics
    def _chunk(
        self,
        splits: List[Any],
        metadatas: List[dict],
        batch_size: int = 64,
        enforce_max_tokens: bool = False,
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity, with optional
        enforcement of a maximum token limit per chunk.

        :param splits: Splits to be merged into chunks.
        :param metadatas: Metadata dicts consumed per split when building chunks.
        :param batch_size: Number of splits to process in one batch.
        :param enforce_max_tokens: If True, further split any split that exceeds
            the maximum token limit before merging.
        :return: List of chunks.
        """
        # Break up any splits that already exceed max_split_tokens into sentences
        if enforce_max_tokens:
            new_splits = []
            for split in splits:
                token_count = tiktoken_length(split)
                if token_count > self.max_split_tokens:
                    logger.info(
                        f"Single document exceeds the maximum token limit "
                        f"of {self.max_split_tokens}. "
                        "Splitting to sentences before semantically merging."
                    )
                    _splits = self._split(split)
                    new_splits.extend(_splits)
                else:
                    new_splits.append(split)

            splits = [split for split in new_splits if split and split.strip()]

        chunks = []
        last_split = None
        for i in tqdm(range(0, len(splits), batch_size)):
            batch_splits = splits[i : i + batch_size]
            # Carry the unfinished chunk from the previous batch into this one so
            # that chunk boundaries are not forced at batch boundaries.
            if last_split is not None:
                batch_splits = last_split.splits + batch_splits

            encoded_splits = self._encode_documents(batch_splits)
            similarities = self._calculate_similarity_scores(encoded_splits)
            if self.dynamic_threshold:
                self._find_optimal_threshold(batch_splits, similarities)
            else:
                self.calculated_threshold = self.encoder.score_threshold
            split_indices = self._find_split_indices(similarities=similarities)
            doc_chunks = self._split_documents(
                batch_splits, metadatas, split_indices, similarities
            )

            if len(doc_chunks) > 1:
                chunks.extend(doc_chunks[:-1])
                last_split = doc_chunks[-1]
            else:
                last_split = doc_chunks[0]

            if self.plot_chunks:
                self.plot_similarity_scores(similarities, split_indices, doc_chunks)

            if self.enable_statistics:
                print(self.statistics)

        if last_split:
            chunks.append(last_split)

        return chunks
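
    # Public entry point: each document is split into sentences with the
    # configured splitter, and the resulting splits are merged back into
    # semantically coherent chunks by `_chunk`. Note that `metadatas` is indexed
    # per split (not per document) inside `_split_documents`.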
    def __call__(
        self, docs: List[str], metadatas: List[dict], batch_size: int = 64
    ) -> List[List[Chunk]]:
        """Split documents into smaller chunks based on semantic similarity.

        :param docs: List of text documents to be split. To split a single
            document, pass it as a list with one element.
        :param metadatas: Metadata dicts consumed per split when building chunks.
        :param batch_size: Number of splits to encode per batch.
        :return: List of lists of Chunk objects, one inner list per document.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        all_chunks = []
        for doc in docs:
            token_count = tiktoken_length(doc)
            if token_count > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
            if isinstance(doc, str):
                splits = self._split(doc)
                doc_chunks = self._chunk(splits, metadatas, batch_size=batch_size)
                all_chunks.append(doc_chunks)
            else:
                raise ValueError("The document must be a string.")
        return all_chunks
    def _encode_documents(self, docs: List[str]) -> np.ndarray:
        """Encode a batch of documents with the configured encoder.

        :param docs: List of text documents to be encoded.
        :return: A numpy array of embeddings for the given documents.
        """
        return np.array(self.encoder(docs))

    def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
        """Cosine similarity between the mean of up to `window_size` previous
        embeddings and each subsequent embedding."""
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            # Cosine similarity; the small epsilon guards against division by zero.
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities
    def _find_split_indices(self, similarities: List[float]) -> List[int]:
        """Return the indices after which a new chunk should start, i.e. wherever
        the similarity score drops below the calculated threshold."""
        split_indices = []
        for idx, score in enumerate(similarities):
            logger.debug(f"Similarity score at index {idx}: {score}")
            if score < self.calculated_threshold:
                logger.debug(
                    f"Adding to split_indices due to score < threshold: "
                    f"{score} < {self.calculated_threshold}"
                )
                # Chunk after the document at idx
                split_indices.append(idx + 1)
        return split_indices
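
    # The dynamic threshold is tuned with a binary search: starting from bounds
    # derived from the median and standard deviation of the similarity scores,
    # the threshold is raised or lowered by `threshold_adjustment` until the
    # median token count of the resulting chunks falls within
    # [min_split_tokens, max_split_tokens] (widened by split_tokens_tolerance)
    # or the search interval is exhausted.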
    def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)

        # Analyze the distribution of similarity scores to set initial bounds
        median_score = np.median(similarity_scores)
        std_dev = np.std(similarity_scores)

        # Set initial bounds based on median and standard deviation
        low = max(0.0, float(median_score - std_dev))
        high = min(1.0, float(median_score + std_dev))

        iteration = 0
        median_tokens = 0
        while low <= high:
            self.calculated_threshold = (low + high) / 2
            split_indices = self._find_split_indices(similarity_scores)
            logger.debug(
                f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
            )

            # Calculate the token counts for each split using the cumulative sums
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            # Calculate the median token count for the chunks
            median_tokens = np.median(split_token_counts)
            logger.debug(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.debug("Median tokens in target range. Stopping iteration.")
                break
            elif median_tokens < self.min_split_tokens:
                high = self.calculated_threshold - self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                low = self.calculated_threshold + self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1

        logger.debug(
            f"Optimal threshold {self.calculated_threshold} found "
            f"with median tokens ({median_tokens}) in target range "
            f"({self.min_split_tokens}-{self.max_split_tokens})."
        )
        return self.calculated_threshold
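
    # Chunks are finalized for one of three reasons, each tracked separately in
    # ChunkStatistics: the similarity threshold fired (`chunks_by_threshold`),
    # the maximum token limit was reached (`chunks_by_max_chunk_size`), or
    # documents remained after the loop (`chunks_by_last_split`).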
    def _split_documents(
        self,
        docs: List[str],
        metadatas: List[dict],
        split_indices: List[int],
        similarities: List[float],
    ) -> List[Chunk]:
        """Iterate through the documents, appending each one to the current split
        until a split point (from split_indices) is reached or the maximum token
        limit for a split (self.max_split_tokens) would be exceeded. When a
        document pushes the token count past this limit, or when a split point is
        reached and the minimum token requirement is met, the current split is
        finalized and added to the list of chunks.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        chunks, current_split = [], []
        current_tokens_count = 0

        # Statistics
        chunks_by_threshold = 0
        chunks_by_max_chunk_size = 0
        chunks_by_last_split = 0
        for doc_idx, doc in enumerate(docs):
            doc_token_count = token_counts[doc_idx]
            logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
            logger.debug(f"Document token count: {doc_token_count} tokens")

            # Check if current index is a split point based on similarity
            if doc_idx + 1 in split_indices:
                if (
                    self.min_split_tokens
                    <= current_tokens_count + doc_token_count
                    < self.max_split_tokens
                ):
                    # Include the current document before splitting
                    # if it doesn't exceed the max limit
                    current_split.append(doc)
                    current_tokens_count += doc_token_count

                    triggered_score = (
                        similarities[doc_idx] if doc_idx < len(similarities) else None
                    )
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=True,
                            triggered_score=triggered_score,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy(),
                        )
                    )
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"threshold {self.calculated_threshold}."
                    )
                    current_split, current_tokens_count = [], 0
                    chunks_by_threshold += 1
                    continue  # Move to the next document after splitting
            # Check if adding the current document exceeds the max token limit
            if current_tokens_count + doc_token_count > self.max_split_tokens:
                if current_tokens_count >= self.min_split_tokens:
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=False,
                            triggered_score=None,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy(),
                        )
                    )
                    chunks_by_max_chunk_size += 1
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"exceeding token limit of {self.max_split_tokens}."
                    )
                    current_split, current_tokens_count = [], 0

            current_split.append(doc)
            current_tokens_count += doc_token_count
        # Handle the last split
        if current_split:
            chunks.append(
                Chunk(
                    splits=current_split.copy(),
                    is_triggered=False,
                    triggered_score=None,
                    token_count=current_tokens_count,
                    metadata=metadatas[doc_idx].copy(),
                )
            )
            chunks_by_last_split += 1
            logger.debug(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )

        # Validation to ensure no tokens are lost during the split
        original_token_count = sum(token_counts)
        split_token_count = sum(
            [tiktoken_length(doc) for split in chunks for doc in split.splits]
        )
        if original_token_count != split_token_count:
            logger.error(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
            raise ValueError(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
        # Statistics
        total_chunks = len(chunks)
        chunks_by_similarity_ratio = (
            chunks_by_threshold / total_chunks if total_chunks else 0
        )
        min_token_size = max_token_size = 0
        if chunks:
            token_counts = [
                split.token_count for split in chunks if split.token_count is not None
            ]
            min_token_size, max_token_size = min(token_counts, default=0), max(
                token_counts, default=0
            )

        self.statistics = ChunkStatistics(
            total_documents=len(docs),
            total_chunks=total_chunks,
            chunks_by_threshold=chunks_by_threshold,
            chunks_by_max_chunk_size=chunks_by_max_chunk_size,
            chunks_by_last_split=chunks_by_last_split,
            min_token_size=min_token_size,
            max_token_size=max_token_size,
            chunks_by_similarity_ratio=chunks_by_similarity_ratio,
        )

        return chunks
    def plot_similarity_scores(
        self,
        similarities: List[float],
        split_indices: List[int],
        chunks: List[Chunk],
    ):
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning(
                "Plotting is disabled. Please `pip install "
                "semantic-router[processing]`."
            )
            return

        _, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots

        # Plot 1: Similarity Scores
        axs[0].plot(similarities, label="Similarity Scores", marker="o")
        for split_index in split_indices:
            axs[0].axvline(
                x=split_index - 1,
                color="r",
                linestyle="--",
                label="Chunk" if split_index == split_indices[0] else "",
            )
        axs[0].axhline(
            y=self.calculated_threshold,
            color="g",
            linestyle="-.",
            label="Threshold Similarity Score",
        )

        # Annotate each similarity score
        for i, score in enumerate(similarities):
            axs[0].annotate(
                f"{score:.2f}",  # Format to two decimal places
                (i, score),
                textcoords="offset points",
                xytext=(0, 10),  # Position the text above the point
                ha="center",  # Center-align the text
            )
        axs[0].set_xlabel("Document Segment Index")
        axs[0].set_ylabel("Similarity Score")
        axs[0].set_title(
            f"Threshold: {self.calculated_threshold} |"
            f" Window Size: {self.window_size}",
            loc="right",
            fontsize=10,
        )
        axs[0].legend()

        # Plot 2: Chunk Token Size Distribution
        token_counts = [split.token_count for split in chunks]
        axs[1].bar(range(len(token_counts)), token_counts, color="lightblue")
        axs[1].set_title("Chunk Token Sizes")
        axs[1].set_xlabel("Chunk Index")
        axs[1].set_ylabel("Token Count")
        axs[1].set_xticks(range(len(token_counts)))
        axs[1].set_xticklabels([str(i) for i in range(len(token_counts))])
        axs[1].grid(True)

        # Annotate each bar with the token size
        for idx, token_count in enumerate(token_counts):
            if not token_count:
                continue
            axs[1].text(
                idx, token_count + 0.01, str(token_count), ha="center", va="bottom"
            )

        plt.tight_layout()
        plt.show()
    def plot_sentence_similarity_scores(
        self, docs: List[str], threshold: float, window_size: int
    ):
        """Compute similarity scores between the average of the last
        `window_size` sentences and the next one, plot a graph of these
        similarity scores, and print the first sentence after each similarity
        score below the specified threshold.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning("Plotting is disabled. Please `pip install matplotlib`.")
            return

        sentences = [sentence for doc in docs for sentence in self._split(doc)]
        encoded_sentences = self._encode_documents(sentences)
        similarity_scores = []

        for i in range(window_size, len(encoded_sentences)):
            window_avg_encoding = np.mean(
                encoded_sentences[i - window_size : i], axis=0
            )
            sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / (
                np.linalg.norm(window_avg_encoding)
                * np.linalg.norm(encoded_sentences[i])
                + 1e-10
            )
            similarity_scores.append(sim_score)

        plt.figure(figsize=(10, 8))
        plt.plot(similarity_scores, marker="o", linestyle="-", color="b")
        plt.title("Sliding Window Sentence Similarity Scores")
        plt.xlabel("Sentence Index")
        plt.ylabel("Similarity Score")
        plt.grid(True)
        plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold")
        plt.show()

        for i, score in enumerate(similarity_scores):
            if score < threshold:
                print(
                    f"First sentence after similarity score "
                    f"below {threshold}: {sentences[i + window_size]}"
                )