File size: 9,385 Bytes
9afd745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import copy
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast

import numpy as np
from langchain_community.utils.math import (
    cosine_similarity,
)
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.embeddings import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter


def calculate_cosine_distances(sentence_embeddings) -> np.ndarray:
    """Calculate cosine distances between adjacent sentences.

    Only neighbouring pairs are compared, so this costs O(n * d) time and
    memory instead of building the full n x n similarity matrix.

    Args:
        sentence_embeddings: Sequence of n equal-length embedding vectors
            (anything ``np.asarray`` turns into an (n, d) array).

    Returns:
        1-D float array of length max(n - 1, 0) whose i-th entry is
        ``1 - cos(embedding[i], embedding[i + 1])``. A pair involving a
        zero-norm vector is given similarity 0 (distance 1) instead of NaN.
    """
    vectors = np.asarray(sentence_embeddings, dtype=float)
    if len(vectors) < 2:
        # Fewer than two sentences means there are no adjacent pairs.
        return np.empty(0)

    current, following = vectors[:-1], vectors[1:]
    dots = np.sum(current * following, axis=1)
    norms = np.linalg.norm(current, axis=1) * np.linalg.norm(following, axis=1)
    # Zero-norm vectors yield 0/0 or x/0; silence the warnings and map the
    # resulting NaN/inf to a similarity of 0 below.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarities = dots / norms
    similarities[~np.isfinite(similarities)] = 0.0
    return 1.0 - similarities


# Strategies for turning the distribution of adjacent-sentence distances into
# a single breakpoint threshold.
BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]

# Default 'breakpoint_threshold_amount' per strategy: a percentile (0-100),
# a number of standard deviations above the mean, or a multiplier of the
# interquartile range added to the mean.
BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = dict(
    percentile=95,
    standard_deviation=3,
    interquartile=1.5,
)


class BoundedSemanticChunker(BaseDocumentTransformer):
    """Semantic chunker with a hard upper bound on chunk length.

    First splits the text using semantic chunking according to the specified
    'breakpoint_threshold_amount' (or 'number_of_chunks'), then bounds every
    resulting group to 'max_chunk_size' characters: the longest run of leading
    sentences that fits is kept as one chunk and the rest of that group is
    split with a RecursiveCharacterTextSplitter. Chunks are returned in
    document order.

    Adapted from langchain_experimental.text_splitter.SemanticChunker"""

    # Character overlap and separators for the fallback splitter that handles
    # text which semantic chunking alone could not bring under max_chunk_size.
    _FALLBACK_CHUNK_OVERLAP = 10
    _FALLBACK_SEPARATORS = ("\n\n", "\n", ".", ", ", " ", "")

    def __init__(
        self,
        embeddings: Embeddings,
        buffer_size: int = 1,
        add_start_index: bool = False,
        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
        breakpoint_threshold_amount: Optional[float] = None,
        number_of_chunks: Optional[int] = None,
        max_chunk_size: int = 500,
    ):
        """Configure the chunker.

        Args:
            embeddings: Model used to embed each sentence.
            buffer_size: Must be 1; combining neighbouring sentences before
                embedding is not implemented.
            add_start_index: If True, create_documents stores each chunk's
                offset in its source text as metadata["start_index"].
            breakpoint_threshold_type: Must be "percentile" for now.
            breakpoint_threshold_amount: Threshold parameter; defaults to
                BREAKPOINT_DEFAULTS[breakpoint_threshold_type].
            number_of_chunks: If given, the threshold is derived from the
                desired number of chunks instead of
                breakpoint_threshold_amount.
            max_chunk_size: Maximum chunk length in characters.

        Raises:
            ValueError: If an unsupported option is requested or
                max_chunk_size is not positive.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if breakpoint_threshold_type != "percentile":
            raise ValueError(
                "only breakpoint_threshold_type 'percentile' is currently supported"
            )
        if buffer_size != 1:
            raise ValueError("combining sentences is not supported yet")
        if max_chunk_size <= 0:
            raise ValueError(f"max_chunk_size must be positive, got {max_chunk_size}")

        self._add_start_index = add_start_index
        self.embeddings = embeddings
        self.buffer_size = buffer_size
        self.breakpoint_threshold_type = breakpoint_threshold_type
        self.number_of_chunks = number_of_chunks
        if breakpoint_threshold_amount is None:
            breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[breakpoint_threshold_type]
        self.breakpoint_threshold_amount = breakpoint_threshold_amount
        self.max_chunk_size = max_chunk_size
        # Split after '.', '?' or '!', consuming the whitespace that follows.
        self.sentence_split_regex = re.compile(r"(?<=[.?!])\s+")

    def _calculate_sentence_distances(self, sentences: List[str]) -> np.ndarray:
        """Embed every sentence and return distances between neighbours.

        The i-th entry is the cosine distance between sentences[i] and
        sentences[i + 1], so the result has len(sentences) - 1 entries.
        """
        embeddings = self.embeddings.embed_documents(sentences)
        return calculate_cosine_distances(embeddings)

    def _calculate_breakpoint_threshold(
        self,
        distances: np.ndarray,
        alt_breakpoint_threshold_amount: Optional[float] = None,
    ) -> float:
        """Turn the distance distribution into a single breakpoint threshold.

        Uses alt_breakpoint_threshold_amount when given, otherwise
        self.breakpoint_threshold_amount, interpreted according to
        self.breakpoint_threshold_type.

        Raises:
            ValueError: If breakpoint_threshold_type is not recognised.
        """
        if alt_breakpoint_threshold_amount is None:
            breakpoint_threshold_amount = self.breakpoint_threshold_amount
        else:
            breakpoint_threshold_amount = alt_breakpoint_threshold_amount
        if self.breakpoint_threshold_type == "percentile":
            return cast(
                float,
                np.percentile(distances, breakpoint_threshold_amount),
            )
        elif self.breakpoint_threshold_type == "standard_deviation":
            return cast(
                float,
                np.mean(distances)
                + breakpoint_threshold_amount * np.std(distances),
            )
        elif self.breakpoint_threshold_type == "interquartile":
            q1, q3 = np.percentile(distances, [25, 75])
            iqr = q3 - q1
            return cast(float, np.mean(distances) + breakpoint_threshold_amount * iqr)
        else:
            raise ValueError(
                f"Got unexpected `breakpoint_threshold_type`: "
                f"{self.breakpoint_threshold_type}"
            )

    def _threshold_from_clusters(self, distances: np.ndarray) -> float:
        """Derive the distance threshold from the requested number of chunks.

        Inverse of the percentile method: number_of_chunks (clamped to
        [1, len(distances)]) is mapped linearly from 1 -> 100th percentile to
        len(distances) -> 0th percentile.

        Raises:
            ValueError: If number_of_chunks is None.
        """
        if self.number_of_chunks is None:
            raise ValueError(
                "This should never be called if `number_of_chunks` is None."
            )
        x1, y1 = len(distances), 0.0
        x2, y2 = 1.0, 100.0

        x = max(min(self.number_of_chunks, x1), x2)

        # Linear interpolation. With a single distance both anchors coincide
        # and the formula would divide by zero.
        if x1 == x2:
            y = y2
        else:
            y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
        y = min(max(y, 0), 100)

        return cast(float, np.percentile(distances, y))

    def _fallback_splitter(self) -> RecursiveCharacterTextSplitter:
        """Build the character splitter used for text that does not fit semantically."""
        return RecursiveCharacterTextSplitter(
            chunk_size=self.max_chunk_size,
            chunk_overlap=self._FALLBACK_CHUNK_OVERLAP,
            separators=list(self._FALLBACK_SEPARATORS),
        )

    def _chunk_group(self, group: List[str]) -> List[str]:
        """Turn one group of consecutive sentences into bounded chunks.

        The group is joined with single spaces. If the result fits in
        max_chunk_size it is returned as one chunk. Otherwise the longest
        leading run of sentences whose joined text fits becomes one chunk, and
        the remaining sentences of this group are split by the fallback
        character splitter. Output preserves the sentences' order.
        """
        combined_text = " ".join(group)
        if len(combined_text) <= self.max_chunk_size:
            return [combined_text]

        # joined_lengths[k] == len(" ".join(group[:k + 1])): sentence lengths
        # plus one separating space between neighbours. The sequence is
        # strictly increasing, so the entries that fit form a prefix.
        joined_lengths = np.cumsum([len(s) + 1 for s in group]) - 1
        n_fit = int(np.count_nonzero(joined_lengths <= self.max_chunk_size))

        chunks: List[str] = []
        if n_fit:
            chunks.append(" ".join(group[:n_fit]))
        # The whole group did not fit, so group[n_fit:] is never empty here.
        chunks.extend(self._fallback_splitter().split_text(" ".join(group[n_fit:])))
        return chunks

    def split_text(
        self,
        text: str,
    ) -> List[str]:
        """Split text into semantically grouped chunks of bounded length.

        Sentences are grouped by cutting wherever the distance between two
        adjacent sentences exceeds the breakpoint threshold; each group is
        then bounded by max_chunk_size (see _chunk_group). Because sentences
        are re-joined with single spaces, chunks may not appear verbatim in
        text when the original whitespace between sentences differed.
        """
        sentences = self.sentence_split_regex.split(text)

        # A single sentence yields no distances, which np.percentile cannot
        # handle; it must still respect max_chunk_size.
        if len(sentences) == 1:
            return self._chunk_group(sentences)

        distances = self._calculate_sentence_distances(sentences)

        if self.number_of_chunks is not None:
            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
        else:
            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
                distances
            )

        # distances[i] compares sentences i and i + 1, so a breakpoint at i
        # ends a group after sentence i.
        breakpoints = [
            i for i, d in enumerate(distances) if d > breakpoint_distance_threshold
        ]

        chunks: List[str] = []
        start_index = 0
        for end_index in breakpoints:
            chunks.extend(self._chunk_group(sentences[start_index : end_index + 1]))
            start_index = end_index + 1

        # The last group, if any sentences remain after the final breakpoint.
        if start_index < len(sentences):
            chunks.extend(self._chunk_group(sentences[start_index:]))
        return chunks

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Each chunk of texts[i] becomes a Document carrying a deep copy of
        metadatas[i] (or an empty dict). With add_start_index, the chunk's
        offset found via str.find is stored as "start_index" (-1 when the
        chunk text does not occur verbatim in the source text).
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            search_from = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, search_from)
                    metadata["start_index"] = index
                    # Only advance on a hit: a miss must not reset the search
                    # back to the start of the text.
                    if index >= 0:
                        search_from = index + 1
                documents.append(Document(page_content=chunk, metadata=metadata))
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents, keeping each document's metadata on its chunks."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))