Multipurpose-AI-Agent-Development / statistical_chunker.py
devve1's picture
Create statistical_chunker.py
d8b1734 verified
raw
history blame
No virus
20.4 kB
from dataclasses import dataclass
from typing import Any, List
import numpy as np
from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.logger import logger
from tqdm.auto import tqdm
@dataclass
class ChunkStatistics:
    """Counters summarising one chunking pass, with a human-readable ``str``.

    Instances are built by ``StatisticalChunker._split_documents`` after each
    batch of splits has been grouped into chunks.
    """

    total_documents: int
    total_chunks: int
    chunks_by_threshold: int
    chunks_by_max_chunk_size: int
    chunks_by_last_split: int
    min_token_size: int
    max_token_size: int
    chunks_by_similarity_ratio: float

    def __str__(self):
        # One report line per counter; the ratio is shown with two decimals.
        report_lines = [
            "Chunking Statistics:",
            f" - Total Documents: {self.total_documents}",
            f" - Total Chunks: {self.total_chunks}",
            f" - Chunks by Threshold: {self.chunks_by_threshold}",
            f" - Chunks by Max Chunk Size: {self.chunks_by_max_chunk_size}",
            f" - Last Chunk: {self.chunks_by_last_split}",
            f" - Minimum Token Size of Chunk: {self.min_token_size}",
            f" - Maximum Token Size of Chunk: {self.max_token_size}",
            f" - Similarity Chunk Ratio: {self.chunks_by_similarity_ratio:.2f}",
        ]
        return "\n".join(report_lines)
class StatisticalChunker(BaseChunker):
    """Merge text splits into chunks wherever embedding similarity drops.

    Each document is split with ``splitter`` and the splits are embedded with
    ``encoder`` in batches. A chunk boundary is proposed before any split
    whose cosine similarity to the mean embedding of the preceding
    ``window_size`` splits falls below a threshold; boundaries are then
    adjusted to respect ``min_split_tokens`` / ``max_split_tokens``. With
    ``dynamic_threshold=True`` the threshold is searched per batch so that the
    median chunk size lands within ``[min_split_tokens -
    split_tokens_tolerance, max_split_tokens + split_tokens_tolerance]``;
    otherwise ``encoder.score_threshold`` is used unchanged.
    """

    def __init__(
        self,
        encoder: BaseEncoder,
        splitter: BaseSplitter = RegexSplitter(),
        name="statistical_chunker",
        threshold_adjustment=0.01,
        dynamic_threshold: bool = True,
        window_size=5,
        min_split_tokens=100,
        max_split_tokens=300,
        split_tokens_tolerance=10,
        plot_chunks=False,
        enable_statistics=False,
    ):
        """Configure the chunker.

        :param encoder: Encoder used to embed splits; its ``score_threshold``
            is the fixed threshold when ``dynamic_threshold`` is False.
        :param splitter: Splitter used to break documents into splits.
            NOTE: the default ``RegexSplitter()`` is created once when the
            class is defined and is shared by every chunker relying on it.
        :param name: Chunker name forwarded to the base class.
        :param threshold_adjustment: Step by which the threshold-search
            bounds move past the current midpoint each iteration; should be
            positive so the search interval keeps shrinking.
        :param dynamic_threshold: Search a threshold per batch instead of
            using ``encoder.score_threshold``.
        :param window_size: Number of preceding splits averaged when scoring
            a split's similarity.
        :param min_split_tokens: Target minimum tokens per chunk.
        :param max_split_tokens: Target maximum tokens per chunk.
        :param split_tokens_tolerance: Slack around the min/max range used
            when judging the threshold search's median chunk size.
        :param plot_chunks: Plot similarity scores and chunk sizes per batch.
        :param enable_statistics: Print ``ChunkStatistics`` after each batch.
        """
        super().__init__(name=name, encoder=encoder, splitter=splitter)
        # Assigned per batch in _chunk / _find_optimal_threshold before use.
        self.calculated_threshold: float
        self.encoder = encoder
        self.threshold_adjustment = threshold_adjustment
        self.dynamic_threshold = dynamic_threshold
        self.window_size = window_size
        self.plot_chunks = plot_chunks
        self.min_split_tokens = min_split_tokens
        self.max_split_tokens = max_split_tokens
        self.split_tokens_tolerance = split_tokens_tolerance
        self.enable_statistics = enable_statistics
        # Assigned by _split_documents after each batch.
        self.statistics: ChunkStatistics

    @staticmethod
    def _align_metadatas(metadatas: List[dict], count: int) -> List[dict]:
        """Return a new list holding exactly ``count`` metadata entries.

        ``None`` or an empty list yields fresh empty dicts; a list of length
        ``count`` is copied as-is; a single entry is broadcast to all
        positions.

        :raises ValueError: If the length is neither 0, 1 nor ``count``.
        """
        if not metadatas:
            return [{} for _ in range(count)]
        if len(metadatas) == count:
            return list(metadatas)
        if len(metadatas) == 1:
            return [metadatas[0]] * count
        raise ValueError(
            f"Expected 1 or {count} metadata entries, got {len(metadatas)}."
        )

    def _enforce_max_tokens(self, splits: List[Any], metadatas: List[dict]):
        """Re-split oversized splits and drop empty ones, keeping metadata aligned.

        :return: ``(new_splits, new_metadatas)``; every sub-split inherits the
            metadata of the split it came from.
        """
        new_splits: List[Any] = []
        new_metadatas: List[dict] = []
        for split, metadata in zip(splits, metadatas):
            if tiktoken_length(split) > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
                sub_splits = self._split(split)
            else:
                sub_splits = [split]
            for sub_split in sub_splits:
                if sub_split and sub_split.strip():
                    new_splits.append(sub_split)
                    new_metadatas.append(metadata)
        return new_splits, new_metadatas

    def _chunk(
        self,
        splits: List[Any],
        metadatas: List[dict],
        batch_size: int = 64,
        enforce_max_tokens: bool = False,
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity.

        :param splits: Splits to be merged into chunks.
        :param metadatas: Metadata for the splits: one dict per split
            (aligned with ``splits``) or a single dict applied to all of
            them. Each chunk receives a copy of its last split's metadata.
        :param batch_size: Number of new splits embedded per batch.
        :param enforce_max_tokens: If True, splits longer than
            ``max_split_tokens`` are re-split first and empty or
            whitespace-only splits are dropped.
        :return: List of chunks.
        :raises ValueError: If ``metadatas`` cannot be aligned with ``splits``.
        """
        metadatas = self._align_metadatas(metadatas, len(splits))
        if enforce_max_tokens:
            splits, metadatas = self._enforce_max_tokens(splits, metadatas)

        chunks: List[Chunk] = []
        last_chunk = None
        carried_metadatas: List[dict] = []
        for start in tqdm(range(0, len(splits), batch_size)):
            batch_splits = splits[start : start + batch_size]
            batch_metadatas = metadatas[start : start + batch_size]
            if last_chunk is not None:
                # The trailing chunk of the previous batch may continue into
                # this batch, so it is re-scored together with the new splits.
                batch_splits = last_chunk.splits + batch_splits
                batch_metadatas = carried_metadatas + batch_metadatas
            encoded_splits = self._encode_documents(batch_splits)
            similarities = self._calculate_similarity_scores(encoded_splits)
            if self.dynamic_threshold:
                self._find_optimal_threshold(batch_splits, similarities)
            else:
                self.calculated_threshold = self.encoder.score_threshold
            split_indices = self._find_split_indices(similarities=similarities)
            doc_chunks = self._split_documents(
                batch_splits, batch_metadatas, split_indices, similarities
            )
            chunks.extend(doc_chunks[:-1])
            last_chunk = doc_chunks[-1]
            # Chunks partition batch_splits in order, so the carried chunk's
            # splits are the final len(last_chunk.splits) entries.
            carried_metadatas = batch_metadatas[
                len(batch_metadatas) - len(last_chunk.splits) :
            ]
            if self.plot_chunks:
                self.plot_similarity_scores(similarities, split_indices, doc_chunks)
            if self.enable_statistics:
                print(self.statistics)
        if last_chunk is not None:
            chunks.append(last_chunk)
        return chunks

    def __call__(
        self, docs: List[str], metadatas: List[dict], batch_size: int = 64
    ) -> List[List[Chunk]]:
        """Split documents into smaller chunks based on semantic similarity.

        :param docs: List of text documents to be split. To split a single
            document, pass it as a list with a single element.
        :param metadatas: One metadata dict per document (aligned with
            ``docs``) or a single dict applied to every document; every chunk
            of ``docs[i]`` receives a copy of ``docs[i]``'s metadata.
        :param batch_size: Number of splits embedded per batch.
        :return: One list of Chunk objects per input document.
        :raises ValueError: If ``docs`` is empty, a document is not a string,
            or ``metadatas`` cannot be aligned with ``docs``.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")
        doc_metadatas = self._align_metadatas(metadatas, len(docs))
        all_chunks = []
        for doc, metadata in zip(docs, doc_metadatas):
            # Type-check before tokenizing so non-string input gets the
            # intended ValueError instead of a tokenizer error.
            if not isinstance(doc, str):
                raise ValueError("The document must be a string.")
            if tiktoken_length(doc) > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
            splits = self._split(doc)
            doc_chunks = self._chunk(
                splits, [metadata] * len(splits), batch_size=batch_size
            )
            all_chunks.append(doc_chunks)
        return all_chunks

    def _encode_documents(self, docs: List[str]) -> np.ndarray:
        """
        :param docs: List of text documents to be encoded.
        :return: A numpy array of embeddings for the given documents.
        """
        return np.array(self.encoder(docs))

    def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
        """Score each split against the mean of its preceding window.

        Entry ``i`` of the result is the cosine similarity between split
        ``i + 1`` and the mean embedding of up to ``window_size`` splits
        before it, so the result has ``len(encoded_docs) - 1`` entries.
        """
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            # 1e-10 guards against division by zero for zero-norm vectors.
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities

    def _find_split_indices(self, similarities: List[float]) -> List[int]:
        """Return ``idx + 1`` for each score below ``calculated_threshold``.

        A value ``k`` means a boundary is proposed before split ``k``.
        """
        split_indices = []
        for idx, score in enumerate(similarities):
            logger.debug(f"Similarity score at index {idx}: {score}")
            if score < self.calculated_threshold:
                logger.debug(
                    f"Adding to split_indices due to score < threshold: "
                    f"{score} < {self.calculated_threshold}"
                )
                # Chunk after the document at idx
                split_indices.append(idx + 1)
        return split_indices

    def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
        """Binary-search a threshold whose median chunk size is in range.

        The search starts from ``median ± std`` of the scores (clamped to
        [0, 1]) and stops once the median token count of the resulting
        chunks is within the tolerance-widened min/max range, or when the
        bounds cross. Sets and returns ``self.calculated_threshold``.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)
        if len(similarity_scores) > 0:
            # Analyze the distribution of similarity scores to set bounds.
            median_score = float(np.median(similarity_scores))
            std_dev = float(np.std(similarity_scores))
            # Clamping both bounds into [0, 1] keeps low <= high, so the loop
            # always runs and assigns calculated_threshold (a median below 0
            # used to leave it unset).
            low = min(max(median_score - std_dev, 0.0), 1.0)
            high = min(max(median_score + std_dev, 0.0), 1.0)
        else:
            # A single split has no scores; avoid NaN stats on an empty list.
            low, high = 0.0, 1.0
        iteration = 0
        median_tokens = 0
        in_range = False
        while low <= high:
            self.calculated_threshold = (low + high) / 2
            split_indices = self._find_split_indices(similarity_scores)
            logger.debug(
                f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
            )
            # Token count of each candidate chunk via the cumulative sums.
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]
            median_tokens = np.median(split_token_counts)
            logger.debug(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.debug("Median tokens in target range. Stopping iteration.")
                in_range = True
                break
            elif median_tokens < self.min_split_tokens:
                # Chunks too small: lower the threshold to split less often.
                high = self.calculated_threshold - self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                # Chunks too large: raise the threshold to split more often.
                low = self.calculated_threshold + self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1
        if in_range:
            logger.debug(
                f"Optimal threshold {self.calculated_threshold} found "
                f"with median tokens ({median_tokens}) in target range "
                f"({self.min_split_tokens}-{self.max_split_tokens})."
            )
        else:
            logger.debug(
                f"Threshold search ended at {self.calculated_threshold} "
                f"with median tokens ({median_tokens}) outside target range "
                f"({self.min_split_tokens}-{self.max_split_tokens})."
            )
        return self.calculated_threshold

    def _split_documents(
        self,
        docs: List[str],
        metadatas: List[dict],
        split_indices: List[int],
        similarities: List[float],
    ) -> List[Chunk]:
        """Group consecutive docs into chunks.

        Docs are appended to the current chunk until either a split point
        (from ``split_indices``) is reached with the chunk size in
        ``[min_split_tokens, max_split_tokens)``, or adding the next doc would
        exceed ``max_split_tokens`` while the chunk already holds at least
        ``min_split_tokens``. Remaining docs form a final chunk. Each chunk
        gets a copy of the metadata of its last doc, so ``metadatas`` must be
        aligned with ``docs``. Also refreshes ``self.statistics``.

        :raises ValueError: If the chunks' token total differs from the docs'.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        # Set membership keeps the per-doc split-point test O(1).
        split_points = set(split_indices)
        chunks, current_split = [], []
        current_tokens_count = 0
        # Statistics
        chunks_by_threshold = 0
        chunks_by_max_chunk_size = 0
        chunks_by_last_split = 0
        for doc_idx, doc in enumerate(docs):
            doc_token_count = token_counts[doc_idx]
            logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
            logger.debug(f"Document token count: {doc_token_count} tokens")
            # Check if current index is a split point based on similarity
            if doc_idx + 1 in split_points:
                if (
                    self.min_split_tokens
                    <= current_tokens_count + doc_token_count
                    < self.max_split_tokens
                ):
                    # Include the current document before splitting
                    # if it doesn't exceed the max limit
                    current_split.append(doc)
                    current_tokens_count += doc_token_count
                    triggered_score = (
                        similarities[doc_idx] if doc_idx < len(similarities) else None
                    )
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=True,
                            triggered_score=triggered_score,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy(),
                        )
                    )
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"threshold {self.calculated_threshold}."
                    )
                    current_split, current_tokens_count = [], 0
                    chunks_by_threshold += 1
                    continue  # Move to the next document after splitting
            # Check if adding the current document exceeds the max token limit.
            # Below min_split_tokens the doc is appended anyway, so a chunk
            # can end up larger than max_split_tokens.
            if current_tokens_count + doc_token_count > self.max_split_tokens:
                if current_split and current_tokens_count >= self.min_split_tokens:
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=False,
                            triggered_score=None,
                            token_count=current_tokens_count,
                            # Last doc actually in this chunk (doc_idx is not).
                            metadata=metadatas[doc_idx - 1].copy(),
                        )
                    )
                    chunks_by_max_chunk_size += 1
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"exceeding token limit of {self.max_split_tokens}."
                    )
                    current_split, current_tokens_count = [], 0
            current_split.append(doc)
            current_tokens_count += doc_token_count
        # Handle the last split
        if current_split:
            chunks.append(
                Chunk(
                    splits=current_split.copy(),
                    is_triggered=False,
                    triggered_score=None,
                    token_count=current_tokens_count,
                    metadata=metadatas[len(docs) - 1].copy(),
                )
            )
            chunks_by_last_split += 1
            logger.debug(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )
        # Validation to ensure no tokens are lost during the split
        original_token_count = sum(token_counts)
        split_token_count = sum(
            [tiktoken_length(doc) for chunk in chunks for doc in chunk.splits]
        )
        if original_token_count != split_token_count:
            logger.error(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
            raise ValueError(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
        # Statistics
        total_chunks = len(chunks)
        chunks_by_similarity_ratio = (
            chunks_by_threshold / total_chunks if total_chunks else 0
        )
        chunk_token_counts = [
            chunk.token_count for chunk in chunks if chunk.token_count is not None
        ]
        min_token_size = min(chunk_token_counts, default=0)
        max_token_size = max(chunk_token_counts, default=0)
        self.statistics = ChunkStatistics(
            total_documents=len(docs),
            total_chunks=total_chunks,
            chunks_by_threshold=chunks_by_threshold,
            chunks_by_max_chunk_size=chunks_by_max_chunk_size,
            chunks_by_last_split=chunks_by_last_split,
            min_token_size=min_token_size,
            max_token_size=max_token_size,
            chunks_by_similarity_ratio=chunks_by_similarity_ratio,
        )
        return chunks

    def plot_similarity_scores(
        self,
        similarities: List[float],
        split_indices: List[int],
        chunks: List[Chunk],
    ):
        """Plot similarity scores with split points, and chunk token sizes.

        Does nothing (after logging a warning) if matplotlib is unavailable.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning(
                "Plotting is disabled. Please `pip install "
                "semantic-router[processing]`."
            )
            return
        _, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots
        # Plot 1: Similarity Scores
        axs[0].plot(similarities, label="Similarity Scores", marker="o")
        for split_index in split_indices:
            # split_index k corresponds to similarity score k - 1.
            axs[0].axvline(
                x=split_index - 1,
                color="r",
                linestyle="--",
                label="Chunk" if split_index == split_indices[0] else "",
            )
        axs[0].axhline(
            y=self.calculated_threshold,
            color="g",
            linestyle="-.",
            label="Threshold Similarity Score",
        )
        # Annotating each similarity score
        for i, score in enumerate(similarities):
            axs[0].annotate(
                f"{score:.2f}",  # Formatting to two decimal places
                (i, score),
                textcoords="offset points",
                xytext=(0, 10),  # Positioning the text above the point
                ha="center",
            )  # Center-align the text
        axs[0].set_xlabel("Document Segment Index")
        axs[0].set_ylabel("Similarity Score")
        axs[0].set_title(
            f"Threshold: {self.calculated_threshold} |"
            f" Window Size: {self.window_size}",
            loc="right",
            fontsize=10,
        )
        axs[0].legend()
        # Plot 2: Chunk Token Size Distribution (missing counts drawn as 0).
        token_counts = [chunk.token_count or 0 for chunk in chunks]
        axs[1].bar(range(len(token_counts)), token_counts, color="lightblue")
        axs[1].set_title("Chunk Token Sizes")
        axs[1].set_xlabel("Chunk Index")
        axs[1].set_ylabel("Token Count")
        axs[1].set_xticks(range(len(token_counts)))
        axs[1].set_xticklabels([str(i) for i in range(len(token_counts))])
        axs[1].grid(True)
        # Annotate each bar with the token size
        for idx, token_count in enumerate(token_counts):
            if not token_count:
                continue
            axs[1].text(
                idx, token_count + 0.01, str(token_count), ha="center", va="bottom"
            )
        plt.tight_layout()
        plt.show()

    def plot_sentence_similarity_scores(
        self, docs: List[str], threshold: float, window_size: int
    ):
        """
        Computes similarity scores between the average of the last
        'window_size' sentences and the next one,
        plots a graph of these similarity scores, and prints the first
        sentence after a similarity score below
        a specified threshold.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning("Plotting is disabled. Please `pip install matplotlib`.")
            return
        sentences = [sentence for doc in docs for sentence in self._split(doc)]
        encoded_sentences = self._encode_documents(sentences)
        similarity_scores = []
        for i in range(window_size, len(encoded_sentences)):
            window_avg_encoding = np.mean(
                encoded_sentences[i - window_size : i], axis=0
            )
            sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / (
                np.linalg.norm(window_avg_encoding)
                * np.linalg.norm(encoded_sentences[i])
                + 1e-10
            )
            similarity_scores.append(sim_score)
        plt.figure(figsize=(10, 8))
        plt.plot(similarity_scores, marker="o", linestyle="-", color="b")
        plt.title("Sliding Window Sentence Similarity Scores")
        plt.xlabel("Sentence Index")
        plt.ylabel("Similarity Score")
        plt.grid(True)
        plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold")
        # Without a legend call the "Threshold" label is never displayed.
        plt.legend()
        plt.show()
        # Score i compares sentence i + window_size with its preceding window.
        for i, score in enumerate(similarity_scores):
            if score < threshold:
                print(
                    f"First sentence after similarity score "
                    f"below {threshold}: {sentences[i + window_size]}"
                )