import copy
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast

import numpy as np
from langchain_community.utils.math import (
    cosine_similarity,
)
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.embeddings import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter


def calculate_cosine_distances(sentence_embeddings: List[List[float]]) -> np.ndarray:
    """Calculate cosine distances between sentences.

    Args:
        sentence_embeddings: List of sentence embeddings to calculate distances for.

    Returns:
        Distance between each pair of adjacent sentences.
    """
    # The full pairwise distance matrix is flattened row by row; starting at index 1
    # with a stride of n + 1 picks out the superdiagonal, i.e. the distance between
    # each sentence and the one that follows it.
    return (1 - cosine_similarity(sentence_embeddings, sentence_embeddings)).flatten()[
        1 :: len(sentence_embeddings) + 1
    ]


BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
    "percentile": 95,
    "standard_deviation": 3,
    "interquartile": 1.5,
}


class BoundedSemanticChunker(BaseDocumentTransformer):
    """First splits the text using semantic chunking according to the specified
    'breakpoint_threshold_amount', but then uses a RecursiveCharacterTextSplitter
    to split all chunks that are larger than 'max_chunk_size'.

    Adapted from langchain_experimental.text_splitter.SemanticChunker
    """

    def __init__(
        self,
        embeddings: Embeddings,
        buffer_size: int = 1,
        add_start_index: bool = False,
        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
        breakpoint_threshold_amount: Optional[float] = None,
        number_of_chunks: Optional[int] = None,
        max_chunk_size: int = 500,
    ):
        self._add_start_index = add_start_index
        self.embeddings = embeddings
        self.buffer_size = buffer_size
        self.breakpoint_threshold_type = breakpoint_threshold_type
        self.number_of_chunks = number_of_chunks
        if breakpoint_threshold_amount is None:
            self.breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[
                breakpoint_threshold_type
            ]
        else:
            self.breakpoint_threshold_amount = breakpoint_threshold_amount
        self.max_chunk_size = max_chunk_size
        # Splitting the text on '.', '?', and '!'
        self.sentence_split_regex = re.compile(r"(?<=[.?!])\s+")
        assert self.breakpoint_threshold_type == "percentile", (
            "only breakpoint_threshold_type 'percentile' is currently supported"
        )
        assert self.buffer_size == 1, "combining sentences is not supported yet"

    def _calculate_sentence_distances(self, sentences: List[str]) -> np.ndarray:
        """Embed the sentences and return the distances between adjacent sentences."""
        embeddings = self.embeddings.embed_documents(sentences)
        return calculate_cosine_distances(embeddings)

    def _calculate_breakpoint_threshold(
        self,
        distances: np.ndarray,
        alt_breakpoint_threshold_amount: Optional[float] = None,
    ) -> float:
        if alt_breakpoint_threshold_amount is None:
            breakpoint_threshold_amount = self.breakpoint_threshold_amount
        else:
            breakpoint_threshold_amount = alt_breakpoint_threshold_amount
        if self.breakpoint_threshold_type == "percentile":
            return cast(
                float,
                np.percentile(distances, breakpoint_threshold_amount),
            )
        elif self.breakpoint_threshold_type == "standard_deviation":
            return cast(
                float,
                np.mean(distances) + breakpoint_threshold_amount * np.std(distances),
            )
        elif self.breakpoint_threshold_type == "interquartile":
            q1, q3 = np.percentile(distances, [25, 75])
            iqr = q3 - q1
            return np.mean(distances) + breakpoint_threshold_amount * iqr
        else:
            raise ValueError(
                f"Got unexpected `breakpoint_threshold_type`: "
                f"{self.breakpoint_threshold_type}"
            )

    def _threshold_from_clusters(self, distances: List[float]) -> float:
        """
        Calculate the threshold based on the number of chunks.
        Inverse of percentile method.
        """
        if self.number_of_chunks is None:
            raise ValueError(
                "This should never be called if `number_of_chunks` is None."
            )
        x1, y1 = len(distances), 0.0
        x2, y2 = 1.0, 100.0
        x = max(min(self.number_of_chunks, x1), x2)
        # Linear interpolation formula
        y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
        y = min(max(y, 0), 100)
        return cast(float, np.percentile(distances, y))

    def split_text(
        self,
        text: str,
    ) -> List[str]:
        sentences = self.sentence_split_regex.split(text)
        # having len(sentences) == 1 would cause the following
        # np.percentile to fail.
        if len(sentences) == 1:
            return sentences
        bad_sentences: List[str] = []
        distances = self._calculate_sentence_distances(sentences)
        if self.number_of_chunks is not None:
            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
        else:
            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
                distances
            )

        indices_above_thresh = [
            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
        ]

        chunks = []
        start_index = 0

        # Iterate through the breakpoints to slice the sentences
        for index in indices_above_thresh:
            # The end index is the current breakpoint
            end_index = index

            # Slice the sentences from the current start index to the end index
            group = sentences[start_index : end_index + 1]
            combined_text = " ".join(group)
            if len(combined_text) <= self.max_chunk_size:
                chunks.append(combined_text)
            else:
                # The semantic chunk is too large: keep the leading sentences that
                # still fit and hand the rest over to the character splitter below.
                sent_lengths = np.array([len(sd) for sd in group])
                good_indices = np.flatnonzero(
                    np.cumsum(sent_lengths) <= self.max_chunk_size
                )
                smaller_group = [group[i] for i in good_indices]
                if smaller_group:
                    combined_text = " ".join(smaller_group)
                    chunks.append(combined_text)
                    group = group[good_indices[-1] + 1 :]
                bad_sentences.extend(group)

            # Update the start index for the next group
            start_index = index + 1

        # The last group, if any sentences remain
        if start_index < len(sentences):
            group = sentences[start_index:]
            combined_text = " ".join(group)
            if len(combined_text) <= self.max_chunk_size:
                chunks.append(combined_text)
            else:
                sent_lengths = np.array([len(sd) for sd in group])
                good_indices = np.flatnonzero(
                    np.cumsum(sent_lengths) <= self.max_chunk_size
                )
                smaller_group = [group[i] for i in good_indices]
                if smaller_group:
                    combined_text = " ".join(smaller_group)
                    chunks.append(combined_text)
                    group = group[good_indices[-1] + 1 :]
                bad_sentences.extend(group)

        # If pure semantic chunking wasn't able to fit all of the text into bounded
        # chunks, split the remaining problematic text using a recursive character
        # splitter instead.
        if len(bad_sentences) > 0:
            recursive_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.max_chunk_size,
                chunk_overlap=10,
                separators=["\n\n", "\n", ".", ", ", " ", ""],
            )
            remaining_text = " ".join(bad_sentences)
            chunks.extend(recursive_splitter.split_text(remaining_text))
        return chunks

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))
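

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the chunker itself). It shows how
# the class above would typically be driven: pick an embeddings backend, build
# the chunker, and split a text. The HuggingFaceEmbeddings model named below is
# an assumption, not a requirement; any LangChain `Embeddings` implementation
# can be passed in its place.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_community.embeddings import HuggingFaceEmbeddings

    # Hypothetical model choice; swap in whatever embeddings backend you use.
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    chunker = BoundedSemanticChunker(
        embeddings=embedding_model,
        breakpoint_threshold_type="percentile",
        breakpoint_threshold_amount=95,
        max_chunk_size=500,
    )
    sample_text = (
        "Semantic chunking groups adjacent sentences whose embeddings are similar. "
        "When a group grows beyond max_chunk_size, the overflow is re-split with a "
        "RecursiveCharacterTextSplitter. This keeps every chunk under the size bound."
    )
    for chunk in chunker.split_text(sample_text):
        print(len(chunk), repr(chunk))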