from dataclasses import dataclass
from typing import Any, List

import numpy as np
from semantic_router.encoders.base import BaseEncoder
from tqdm.auto import tqdm

from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.logger import logger
from semantic_chunkers.utils.text import tiktoken_length


@dataclass
class ChunkStatistics:
    total_documents: int
    total_chunks: int
    chunks_by_threshold: int
    chunks_by_max_chunk_size: int
    chunks_by_last_split: int
    min_token_size: int
    max_token_size: int
    chunks_by_similarity_ratio: float

    def __str__(self):
        return (
            f"Chunking Statistics:\n"
            f"  - Total Documents: {self.total_documents}\n"
            f"  - Total Chunks: {self.total_chunks}\n"
            f"  - Chunks by Threshold: {self.chunks_by_threshold}\n"
            f"  - Chunks by Max Chunk Size: {self.chunks_by_max_chunk_size}\n"
            f"  - Last Chunk: {self.chunks_by_last_split}\n"
            f"  - Minimum Token Size of Chunk: {self.min_token_size}\n"
            f"  - Maximum Token Size of Chunk: {self.max_token_size}\n"
            f"  - Similarity Chunk Ratio: {self.chunks_by_similarity_ratio:.2f}"
        )


class StatisticalChunker(BaseChunker):
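    """Chunk documents by encoding sentence-level splits and merging consecutive
    splits whose similarity to a rolling window of preceding splits stays above
    a threshold, which is optionally tuned dynamically to hit a target token
    range per chunk.
    """
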
    def __init__(
        self,
        encoder: BaseEncoder,
        splitter: BaseSplitter = RegexSplitter(),
        name="statistical_chunker",
        threshold_adjustment=0.01,
        dynamic_threshold: bool = True,
        window_size=5,
        min_split_tokens=100,
        max_split_tokens=300,
        split_tokens_tolerance=10,
        plot_chunks=False,
        enable_statistics=False,
    ):
        super().__init__(name=name, encoder=encoder, splitter=splitter)
        self.calculated_threshold: float
        self.encoder = encoder
        self.threshold_adjustment = threshold_adjustment
        self.dynamic_threshold = dynamic_threshold
        self.window_size = window_size
        self.plot_chunks = plot_chunks
        self.min_split_tokens = min_split_tokens
        self.max_split_tokens = max_split_tokens
        self.split_tokens_tolerance = split_tokens_tolerance
        self.enable_statistics = enable_statistics
        self.statistics: ChunkStatistics

    def _chunk(
        self,
        splits: List[Any],
        metadatas: List[dict],
        batch_size: int = 64,
        enforce_max_tokens: bool = False,
    ) -> List[Chunk]:
        """Merge splits into chunks using semantic similarity, optionally
        enforcing a maximum token limit per chunk.

        :param splits: Splits to be merged into chunks.
        :param metadatas: Metadata dicts to attach to the resulting chunks.
        :param batch_size: Number of splits to process in one batch.
        :param enforce_max_tokens: If True, further split any split that exceeds
            the maximum token limit before semantically merging.

        :return: List of chunks.
        """
        # Split the docs that already exceed max_split_tokens to smaller chunks
        if enforce_max_tokens:
            new_splits = []
            for split in splits:
                token_count = tiktoken_length(split)
                if token_count > self.max_split_tokens:
                    logger.info(
                        f"Single document exceeds the maximum token limit "
                        f"of {self.max_split_tokens}. "
                        "Splitting to sentences before semantically merging."
                    )
                    _splits = self._split(split)
                    new_splits.extend(_splits)
                else:
                    new_splits.append(split)

            splits = [split for split in new_splits if split and split.strip()]

        chunks = []
        last_split = None
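        # Process the splits in batches; the trailing chunk of each batch is
        # carried into the next batch so that chunks can span batch boundaries.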
        for i in tqdm(range(0, len(splits), batch_size)):
            batch_splits = splits[i : i + batch_size]
            if last_split is not None:
                batch_splits = last_split.splits + batch_splits

            encoded_splits = self._encode_documents(batch_splits)
            similarities = self._calculate_similarity_scores(encoded_splits)
            if self.dynamic_threshold:
                self._find_optimal_threshold(batch_splits, similarities)
            else:
                self.calculated_threshold = self.encoder.score_threshold
            split_indices = self._find_split_indices(similarities=similarities)
            doc_chunks = self._split_documents(
                batch_splits, metadatas, split_indices, similarities
            )

            if len(doc_chunks) > 1:
                chunks.extend(doc_chunks[:-1])
                last_split = doc_chunks[-1]
            else:
                last_split = doc_chunks[0]

            if self.plot_chunks:
                self.plot_similarity_scores(similarities, split_indices, doc_chunks)

            if self.enable_statistics:
                print(self.statistics)

        if last_split:
            chunks.append(last_split)

        return chunks

    def __call__(
        self, docs: List[str], metadatas: List[dict], batch_size: int = 64
    ) -> List[List[Chunk]]:
        """Split documents into smaller chunks based on semantic similarity.

        :param docs: List of text documents to be split. To split a single
            document, pass it as a list with a single element.
        :param metadatas: Metadata dicts to attach to the resulting chunks.
        :param batch_size: Number of splits to encode per batch.

        :return: One list of Chunk objects per input document.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        all_chunks = []
        for doc in docs:
            if not isinstance(doc, str):
                raise ValueError("The document must be a string.")
            token_count = tiktoken_length(doc)
            if token_count > self.max_split_tokens:
                logger.info(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically merging."
                )
            splits = self._split(doc)
            doc_chunks = self._chunk(splits, metadatas, batch_size=batch_size)
            all_chunks.append(doc_chunks)
        return all_chunks

    def _encode_documents(self, docs: List[str]) -> np.ndarray:
        """
        :param docs: List of text documents to be encoded.
        :return: A numpy array of embeddings for the given documents.
        """
        return np.array(self.encoder(docs))

    def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
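        """Score each split (after the first) by the cosine similarity between
        its embedding and the mean embedding of up to `window_size` preceding
        splits.

        :param encoded_docs: Embeddings of the splits, in order.
        :return: One similarity score per comparison (len(encoded_docs) - 1).
        """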
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities

    def _find_split_indices(self, similarities: List[float]) -> List[int]:
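        """Return the indices at which a new chunk should start: one past each
        position whose similarity score falls below the calculated threshold.
        """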
        split_indices = []
        for idx, score in enumerate(similarities):
            logger.debug(f"Similarity score at index {idx}: {score}")
            if score < self.calculated_threshold:
                logger.debug(
                    f"Adding to split_indices due to score < threshold: "
                    f"{score} < {self.calculated_threshold}"
                )
                # Chunk after the document at idx
                split_indices.append(idx + 1)
        return split_indices

    def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
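        """Binary-search a similarity threshold such that the median token count
        of the resulting chunks falls within the configured target range.

        :param docs: Splits in the current batch.
        :param similarity_scores: Rolling-window similarity scores for the splits.
        :return: The calculated threshold (also stored on the instance).
        """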
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)

        # Analyze the distribution of similarity scores to set initial bounds
        median_score = np.median(similarity_scores)
        std_dev = np.std(similarity_scores)

        # Set initial bounds based on median and standard deviation
        low = max(0.0, float(median_score - std_dev))
        high = min(1.0, float(median_score + std_dev))

        iteration = 0
        median_tokens = 0
        while low <= high:
            self.calculated_threshold = (low + high) / 2
            split_indices = self._find_split_indices(similarity_scores)
            logger.debug(
                f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
            )

            # Calculate the token counts for each split using the cumulative sums
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            # Calculate the median token count for the chunks
            median_tokens = np.median(split_token_counts)
            logger.debug(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.debug("Median tokens in target range. Stopping iteration.")
                break
            elif median_tokens < self.min_split_tokens:
                high = self.calculated_threshold - self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                low = self.calculated_threshold + self.threshold_adjustment
                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1

        logger.debug(
            f"Optimal threshold {self.calculated_threshold} found "
            f"with median tokens ({median_tokens}) in target range "
            f"({self.min_split_tokens}-{self.max_split_tokens})."
        )

        return self.calculated_threshold

    def _split_documents(
        self,
        docs: List[str],
        metadatas: List[dict],
        split_indices: List[int],
        similarities: List[float],
    ) -> List[Chunk]:
        """
        This method iterates through each document, appending it to the current split
        until it either reaches a split point (determined by split_indices) or exceeds
        the maximum token limit for a split (self.max_split_tokens).
        When a document causes the current token count to exceed this limit,
        or when a split point is reached and the minimum token requirement is met,
        the current split is finalized and added to the List of chunks.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        chunks, current_split = [], []
        current_tokens_count = 0

        # Statistics
        chunks_by_threshold = 0
        chunks_by_max_chunk_size = 0
        chunks_by_last_split = 0

        for doc_idx, doc in enumerate(docs):
            doc_token_count = token_counts[doc_idx]
            logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
            logger.debug(f"Document token count: {doc_token_count} tokens")
            # Check if current index is a split point based on similarity
            if doc_idx + 1 in split_indices:
                if (
                    self.min_split_tokens
                    <= current_tokens_count + doc_token_count
                    < self.max_split_tokens
                ):
                    # Include the current document before splitting
                    # if it doesn't exceed the max limit
                    current_split.append(doc)
                    current_tokens_count += doc_token_count

                    triggered_score = (
                        similarities[doc_idx] if doc_idx < len(similarities) else None
                    )
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=True,
                            triggered_score=triggered_score,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy()
                        )
                    )
                    logger.debug(
                        f"Chunk finalized with {current_tokens_count} tokens due to "
                        f"threshold {self.calculated_threshold}."
                    )
                    current_split, current_tokens_count = [], 0
                    chunks_by_threshold += 1
                    continue  # Move to the next document after splitting

            # Check if adding the current document exceeds the max token limit
            if current_tokens_count + doc_token_count > self.max_split_tokens:
                if current_tokens_count >= self.min_split_tokens:
                    chunks.append(
                        Chunk(
                            splits=current_split.copy(),
                            is_triggered=False,
                            triggered_score=None,
                            token_count=current_tokens_count,
                            metadata=metadatas[doc_idx].copy()
                        )
                    )
                    chunks_by_max_chunk_size += 1
                    logger.debug(
                        f"Chink finalized with {current_tokens_count} tokens due to "
                        f"exceeding token limit of {self.max_split_tokens}."
                    )
                    current_split, current_tokens_count = [], 0

            current_split.append(doc)
            current_tokens_count += doc_token_count

        # Handle the last split
        if current_split:
            chunks.append(
                Chunk(
                    splits=current_split.copy(),
                    is_triggered=False,
                    triggered_score=None,
                    token_count=current_tokens_count,
                    metadata=metadatas[doc_idx].copy()
                )
            )
            chunks_by_last_split += 1
            logger.debug(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )

        # Validation to ensure no tokens are lost during the split
        original_token_count = sum(token_counts)
        split_token_count = sum(
            [tiktoken_length(doc) for split in chunks for doc in split.splits]
        )
        if original_token_count != split_token_count:
            logger.error(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
            raise ValueError(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )

        # Statistics
        total_chunks = len(chunks)
        chunks_by_similarity_ratio = (
            chunks_by_threshold / total_chunks if total_chunks else 0
        )
        min_token_size = max_token_size = 0
        if chunks:
            token_counts = [
                split.token_count for split in chunks if split.token_count is not None
            ]
            min_token_size, max_token_size = min(token_counts, default=0), max(
                token_counts, default=0
            )

        self.statistics = ChunkStatistics(
            total_documents=len(docs),
            total_chunks=total_chunks,
            chunks_by_threshold=chunks_by_threshold,
            chunks_by_max_chunk_size=chunks_by_max_chunk_size,
            chunks_by_last_split=chunks_by_last_split,
            min_token_size=min_token_size,
            max_token_size=max_token_size,
            chunks_by_similarity_ratio=chunks_by_similarity_ratio,
        )

        return chunks

    def plot_similarity_scores(
        self,
        similarities: List[float],
        split_indices: List[int],
        chunks: List[Chunk],
    ):
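        """Plot the similarity scores with split points and the calculated
        threshold, alongside the token-size distribution of the resulting
        chunks. Requires matplotlib; logs a warning if it is not installed.
        """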
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning(
                "Plotting is disabled. Please `pip install "
                "semantic-router[processing]`."
            )
            return

        _, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots

        # Plot 1: Similarity Scores
        axs[0].plot(similarities, label="Similarity Scores", marker="o")
        for split_index in split_indices:
            axs[0].axvline(
                x=split_index - 1,
                color="r",
                linestyle="--",
                label="Chunk" if split_index == split_indices[0] else "",
            )
        axs[0].axhline(
            y=self.calculated_threshold,
            color="g",
            linestyle="-.",
            label="Threshold Similarity Score",
        )

        # Annotating each similarity score
        for i, score in enumerate(similarities):
            axs[0].annotate(
                f"{score:.2f}",  # Formatting to two decimal places
                (i, score),
                textcoords="offset points",
                xytext=(0, 10),  # Positioning the text above the point
                ha="center",
            )  # Center-align the text

        axs[0].set_xlabel("Document Segment Index")
        axs[0].set_ylabel("Similarity Score")
        axs[0].set_title(
            f"Threshold: {self.calculated_threshold} |"
            f" Window Size: {self.window_size}",
            loc="right",
            fontsize=10,
        )
        axs[0].legend()

        # Plot 2: Chunk Token Size Distribution
        token_counts = [split.token_count for split in chunks]
        axs[1].bar(range(len(token_counts)), token_counts, color="lightblue")
        axs[1].set_title("Chunk Token Sizes")
        axs[1].set_xlabel("Chunk Index")
        axs[1].set_ylabel("Token Count")
        axs[1].set_xticks(range(len(token_counts)))
        axs[1].set_xticklabels([str(i) for i in range(len(token_counts))])
        axs[1].grid(True)

        # Annotate each bar with the token size
        for idx, token_count in enumerate(token_counts):
            if not token_count:
                continue
            axs[1].text(
                idx, token_count + 0.01, str(token_count), ha="center", va="bottom"
            )

        plt.tight_layout()
        plt.show()

    def plot_sentence_similarity_scores(
        self, docs: List[str], threshold: float, window_size: int
    ):
        """Compute similarity scores between the average of the last
        `window_size` sentences and the next one, plot these similarity
        scores, and print the first sentence after any similarity score
        that falls below the specified threshold.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning("Plotting is disabled. Please `pip install matplotlib`.")
            return
        sentences = [sentence for doc in docs for sentence in self._split(doc)]
        encoded_sentences = self._encode_documents(sentences)
        similarity_scores = []

        for i in range(window_size, len(encoded_sentences)):
            window_avg_encoding = np.mean(
                encoded_sentences[i - window_size : i], axis=0
            )
            sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / (
                np.linalg.norm(window_avg_encoding)
                * np.linalg.norm(encoded_sentences[i])
                + 1e-10
            )
            similarity_scores.append(sim_score)

        plt.figure(figsize=(10, 8))
        plt.plot(similarity_scores, marker="o", linestyle="-", color="b")
        plt.title("Sliding Window Sentence Similarity Scores")
        plt.xlabel("Sentence Index")
        plt.ylabel("Similarity Score")
        plt.grid(True)
        plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold")
        plt.show()

        for i, score in enumerate(similarity_scores):
            if score < threshold:
                print(
                    f"First sentence after similarity score "
                    f"below {threshold}: {sentences[i + window_size]}"
                )
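

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# semantic-router's OpenAIEncoder is installed and OPENAI_API_KEY is set; any
# other BaseEncoder subclass works the same way. Note that the current
# implementation indexes `metadatas` per sentence-level split, so one metadata
# dict is supplied per split below (reusing `_split` only to size that list).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from semantic_router.encoders import OpenAIEncoder  # assumed to be installed

    encoder = OpenAIEncoder()
    chunker = StatisticalChunker(encoder=encoder)

    doc = (
        "Semantic chunking groups sentences that discuss the same topic. "
        "When the topic shifts, similarity to the preceding window drops. "
        "That drop is where a new chunk begins."
    )
    metadatas = [{"source": "example"}] * len(chunker._split(doc))

    chunks = chunker([doc], metadatas=metadatas)
    for chunk in chunks[0]:
        print(f"{chunk.token_count} tokens: {' '.join(chunk.splits)[:80]}")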