LLMsearch / semantic_chunker.py
Ascol57's picture
Upload 18 files
9afd745 verified
raw
history blame contribute delete
No virus
9.39 kB
import copy
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast
import numpy as np
from langchain_community.utils.math import (
cosine_similarity,
)
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.embeddings import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
def calculate_cosine_distances(sentence_embeddings) -> np.ndarray:
    """Calculate cosine distances between adjacent sentences.

    Args:
        sentence_embeddings: Sequence of equal-length sentence embedding
            vectors (anything ``np.asarray`` turns into a 2-D float array).

    Returns:
        1-D array of length ``len(sentence_embeddings) - 1`` (empty when fewer
        than two embeddings are given) whose i-th entry is
        ``1 - cos(e[i], e[i + 1])``. A pair involving a zero-norm vector is
        treated as similarity 0, i.e. distance 1.
    """
    embeddings = np.asarray(sentence_embeddings, dtype=float)
    if len(embeddings) < 2:
        return np.empty(0)
    left, right = embeddings[:-1], embeddings[1:]
    # Only adjacent pairs are needed, so take row-wise dot products instead of
    # materialising the full N x N similarity matrix.
    dots = np.einsum("ij,ij->i", left, right)
    norms = np.linalg.norm(left, axis=1) * np.linalg.norm(right, axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        similarities = dots / norms
    # Zero-norm vectors yield nan/inf; treat them as dissimilar (similarity 0).
    similarities[~np.isfinite(similarities)] = 0.0
    return 1.0 - similarities
# Names of the supported strategies for turning sentence distances into a
# breakpoint threshold.
BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]

# Default ``breakpoint_threshold_amount`` per strategy: a percentile (0-100),
# a multiplier of the standard deviation added to the mean, or a multiplier
# of the interquartile range added to the mean.
BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = dict(
    percentile=95,
    standard_deviation=3,
    interquartile=1.5,
)
class BoundedSemanticChunker(BaseDocumentTransformer):
    """Semantic chunker whose chunks are bounded by 'max_chunk_size' characters.

    The text is first split into sentences on whitespace that follows '.', '?'
    or '!'. Adjacent sentences are embedded and compared by cosine distance,
    and a new group starts wherever that distance exceeds a threshold derived
    from 'breakpoint_threshold_amount' (or from 'number_of_chunks'). Each group
    is then bounded: if its space-joined text is longer than 'max_chunk_size',
    the longest leading run of sentences that fits becomes one chunk and the
    rest of that group is split with a RecursiveCharacterTextSplitter. Chunks
    are returned in document order.

    Adapted from langchain_experimental.text_splitter.SemanticChunker"""

    # Settings for the fallback RecursiveCharacterTextSplitter used on text
    # that sentence-level grouping cannot fit into 'max_chunk_size'.
    _RECURSIVE_CHUNK_OVERLAP = 10
    _RECURSIVE_SEPARATORS = ("\n\n", "\n", ".", ", ", " ", "")

    def __init__(
        self,
        embeddings: Embeddings,
        buffer_size: int = 1,
        add_start_index: bool = False,
        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
        breakpoint_threshold_amount: Optional[float] = None,
        number_of_chunks: Optional[int] = None,
        max_chunk_size: int = 500,
    ):
        """Configure the chunker.

        Args:
            embeddings: Model whose ``embed_documents`` embeds each sentence.
            buffer_size: Kept for SemanticChunker compatibility; only 1 is
                accepted (combining neighbouring sentences is not implemented).
            add_start_index: If True, ``create_documents`` stores each chunk's
                position in the source text under metadata ``"start_index"``.
            breakpoint_threshold_type: Threshold strategy; only "percentile"
                is currently accepted.
            breakpoint_threshold_amount: Strategy parameter; defaults to
                ``BREAKPOINT_DEFAULTS[breakpoint_threshold_type]``.
            number_of_chunks: If set, the threshold is derived from this
                target instead of ``breakpoint_threshold_amount``.
            max_chunk_size: Maximum chunk length in characters.

        Raises:
            ValueError: If ``breakpoint_threshold_type`` is not "percentile"
                or ``buffer_size`` is not 1.
        """
        # Explicit checks rather than `assert`, which `python -O` strips.
        if breakpoint_threshold_type != "percentile":
            raise ValueError(
                "only breakpoint_threshold_type 'percentile' is currently supported"
            )
        if buffer_size != 1:
            raise ValueError("combining sentences is not supported yet")

        self._add_start_index = add_start_index
        self.embeddings = embeddings
        self.buffer_size = buffer_size
        self.breakpoint_threshold_type = breakpoint_threshold_type
        self.number_of_chunks = number_of_chunks
        if breakpoint_threshold_amount is None:
            self.breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[
                breakpoint_threshold_type
            ]
        else:
            self.breakpoint_threshold_amount = breakpoint_threshold_amount
        self.max_chunk_size = max_chunk_size
        # Split after '.', '?' and '!', consuming the following whitespace.
        self.sentence_split_regex = re.compile(r"(?<=[.?!])\s+")

    def _calculate_sentence_distances(self, sentences: List[str]) -> np.ndarray:
        """Embed ``sentences`` and return the cosine distance of each adjacent pair.

        The result has ``len(sentences) - 1`` entries; entry i compares
        sentence i with sentence i + 1.
        """
        embeddings = self.embeddings.embed_documents(sentences)
        return calculate_cosine_distances(embeddings)

    def _calculate_breakpoint_threshold(
        self,
        distances: np.ndarray,
        alt_breakpoint_threshold_amount: Optional[float] = None,
    ) -> float:
        """Return the distance above which a new group is started.

        Args:
            distances: Adjacent-sentence distances.
            alt_breakpoint_threshold_amount: Overrides
                ``self.breakpoint_threshold_amount`` when not None.

        Raises:
            ValueError: For an unknown ``breakpoint_threshold_type``.
        """
        if alt_breakpoint_threshold_amount is None:
            breakpoint_threshold_amount = self.breakpoint_threshold_amount
        else:
            breakpoint_threshold_amount = alt_breakpoint_threshold_amount
        # Only "percentile" passes __init__ validation today; the other
        # branches are kept for when those strategies are enabled.
        if self.breakpoint_threshold_type == "percentile":
            return cast(
                float,
                np.percentile(distances, breakpoint_threshold_amount),
            )
        elif self.breakpoint_threshold_type == "standard_deviation":
            return cast(
                float,
                np.mean(distances)
                + breakpoint_threshold_amount * np.std(distances),
            )
        elif self.breakpoint_threshold_type == "interquartile":
            q1, q3 = np.percentile(distances, [25, 75])
            iqr = q3 - q1
            return cast(float, np.mean(distances) + breakpoint_threshold_amount * iqr)
        else:
            raise ValueError(
                f"Got unexpected `breakpoint_threshold_type`: "
                f"{self.breakpoint_threshold_type}"
            )

    def _threshold_from_clusters(self, distances: np.ndarray) -> float:
        """Calculate the threshold based on the number of chunks.

        Inverse of the percentile method: ``number_of_chunks`` is linearly
        mapped from [len(distances), 1] onto percentiles [0, 100].

        Raises:
            ValueError: If ``number_of_chunks`` is None.
        """
        if self.number_of_chunks is None:
            raise ValueError(
                "This should never be called if `number_of_chunks` is None."
            )
        if len(distances) <= 1:
            # Every percentile of a single value is that value; the
            # interpolation below would divide by zero (x2 - x1 == 0).
            return cast(float, np.percentile(distances, 100.0))
        x1, y1 = len(distances), 0.0
        x2, y2 = 1.0, 100.0
        x = max(min(self.number_of_chunks, x1), x2)
        # Linear interpolation formula
        y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
        y = min(max(y, 0), 100)
        return cast(float, np.percentile(distances, y))

    def _recursive_split(self, text: str) -> List[str]:
        """Split ``text`` with a RecursiveCharacterTextSplitter sized to max_chunk_size."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.max_chunk_size,
            chunk_overlap=self._RECURSIVE_CHUNK_OVERLAP,
            separators=list(self._RECURSIVE_SEPARATORS),
        )
        return splitter.split_text(text)

    def _bound_group(self, group: List[str]) -> List[str]:
        """Turn one semantic group of sentences into bounded chunks, in order.

        If the space-joined group fits in ``max_chunk_size`` it is one chunk.
        Otherwise the longest leading run of sentences whose space-joined
        text fits becomes one chunk, and the remaining sentences of this group
        are space-joined and handed to the recursive splitter.
        """
        combined_text = " ".join(group)
        if len(combined_text) <= self.max_chunk_size:
            return [combined_text]

        # Count leading sentences that fit; every sentence after the first
        # also costs one joining space.
        num_fitting = 0
        joined_length = -1
        for sentence in group:
            joined_length += len(sentence) + 1
            if joined_length > self.max_chunk_size:
                break
            num_fitting += 1

        chunks = [" ".join(group[:num_fitting])] if num_fitting else []
        # The whole group did not fit, so at least one sentence remains.
        chunks.extend(self._recursive_split(" ".join(group[num_fitting:])))
        return chunks

    def split_text(
        self,
        text: str,
    ) -> List[str]:
        """Split ``text`` into semantic chunks bounded by ``max_chunk_size``.

        Returns:
            Chunks in document order. Sentences within a chunk are re-joined
            with a single space.
        """
        sentences = self.sentence_split_regex.split(text)
        # A single sentence has no adjacent pair, so no distances (and no
        # percentile) can be computed; it still has to respect the bound.
        if len(sentences) == 1:
            return self._bound_group(sentences)

        distances = self._calculate_sentence_distances(sentences)
        if self.number_of_chunks is not None:
            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
        else:
            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
                distances
            )

        # distances[i] compares sentences[i] and sentences[i + 1], so a
        # breakpoint at i closes a group after sentence i. The last sentence
        # index closes the final group.
        group_ends = [
            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
        ]
        group_ends.append(len(sentences) - 1)

        chunks: List[str] = []
        start_index = 0
        for end_index in group_ends:
            chunks.extend(self._bound_group(sentences[start_index : end_index + 1]))
            start_index = end_index + 1
        return chunks

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Each chunk gets a deep copy of its text's metadata (or ``{}``). With
        ``add_start_index``, ``"start_index"`` is the position found by
        ``str.find`` after the previous chunk's position.
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    # NOTE: chunks re-join sentences with a single space, so if
                    # the original separator was other whitespace the chunk is
                    # not found and -1 is stored.
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents, carrying each document's metadata to its chunks."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))