# This script is adapted from the LangChain package, developed by LangChain AI.
# Original code can be found at:
# https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/character.py
# https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/base.py
# License: MIT License
from dataclasses import dataclass
import numpy as np
import tiktoken
# openai_token_count is used by ClusterSemanticChunker below
from chunking_evaluation.utils import Language, openai_token_count
import re
from abc import ABC, abstractmethod
from enum import Enum
import logging
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)
class BaseChunker(ABC):
    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""


logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


class TextSplitter(BaseChunker, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace
    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""
    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text
    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
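    # Illustrative trace of the sliding-window merge above (editor's note,
    # not part of the upstream LangChain code). Assuming chunk_size=10,
    # chunk_overlap=5, length_function=len, separator=" " and
    # splits=["aaaa", "bbbb", "cccc"]:
    #   - "aaaa" fits (total=4); "bbbb" fits with its separator (total=9).
    #   - "cccc" would raise the total to 14 > 10, so "aaaa bbbb" is emitted,
    #     then leading pieces are popped until total <= chunk_overlap,
    #     leaving ["bbbb"] to start the next window.
    #   - Final output: ["aaaa bbbb", "bbbb cccc"]; consecutive chunks
    #     share "bbbb" as overlap.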
    # @classmethod
    # def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
    #     """Text splitter that uses HuggingFace tokenizer to count length."""
    #     try:
    #         from transformers import PreTrainedTokenizerBase
    #
    #         if not isinstance(tokenizer, PreTrainedTokenizerBase):
    #             raise ValueError(
    #                 "Tokenizer received was not an instance of PreTrainedTokenizerBase"
    #             )
    #
    #         def _huggingface_tokenizer_length(text: str) -> int:
    #             return len(tokenizer.encode(text))
    #
    #     except ImportError:
    #         raise ValueError(
    #             "Could not import transformers python package. "
    #             "Please install it with `pip install transformers`."
    #         )
    #     return cls(length_function=_huggingface_tokenizer_length, **kwargs)
    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, FixedTokenChunker):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
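
# A minimal usage sketch for the classmethod above (editor's illustration,
# not part of the original file). It assumes tiktoken is installed; the
# helper name _demo_from_tiktoken_encoder is invented for this example.
def _demo_from_tiktoken_encoder() -> None:
    # Build a FixedTokenChunker whose length function counts cl100k_base
    # tokens rather than characters.
    chunker = FixedTokenChunker.from_tiktoken_encoder(
        encoding_name="cl100k_base",
        chunk_size=256,
        chunk_overlap=32,
    )
    print(chunker.split_text("Some long document text..."))
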
class FixedTokenChunker(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for FixedTokenChunker. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special
    def split_text(self, text: str) -> List[str]:
        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
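
# Hedged usage sketch for FixedTokenChunker (editor's illustration, not part
# of the original file). Here chunk_size and chunk_overlap are measured in
# *tokens*, and tiktoken must be installed.
def _demo_fixed_token_chunker() -> None:
    chunker = FixedTokenChunker(
        encoding_name="cl100k_base",
        chunk_size=128,    # 128 tokens per chunk
        chunk_overlap=16,  # consecutive chunks share 16 tokens
    )
    chunks = chunker.split_text("A long passage of text to be chunked...")
    for i, chunk in enumerate(chunks):
        print(i, repr(chunk[:60]))
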
@dataclass
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
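
# A self-contained sketch of split_text_on_tokens using a toy whitespace
# "tokenizer" instead of tiktoken (editor's illustration, not part of the
# original file). With 4 tokens per chunk and an overlap of 1, the window
# advances 3 tokens at a time.
def _demo_split_text_on_tokens() -> None:
    words = "one two three four five six seven eight".split()
    vocab = {w: i for i, w in enumerate(words)}
    inverse = {i: w for w, i in vocab.items()}
    toy = Tokenizer(
        chunk_overlap=1,
        tokens_per_chunk=4,
        decode=lambda ids: " ".join(inverse[i] for i in ids),
        encode=lambda s: [vocab[w] for w in s.split()],
    )
    print(split_text_on_tokens(text=" ".join(words), tokenizer=toy))
    # -> ['one two three four', 'four five six seven', 'seven eight']
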
def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
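
# Pure-stdlib sketch of how keep_separator changes the output of
# _split_text_with_regex (editor's illustration, not part of the original
# file):
def _demo_split_text_with_regex() -> None:
    text = "one. two. three"
    # keep_separator=True prepends each delimiter to the piece that follows
    # it, which is why _split_text merges with an empty separator in that case.
    print(_split_text_with_regex(text, re.escape(". "), keep_separator=True))
    # -> ['one', '. two', '. three']
    # keep_separator=False drops the delimiter entirely.
    print(_split_text_with_regex(text, re.escape(". "), keep_separator=False))
    # -> ['one', 'two', 'three']
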
class RecursiveTokenChunker(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            keep_separator=keep_separator,
            **kwargs,
        )
        self._separators = separators or ["\n\n", "\n", ".", "?", "!", " ", ""]
        self._is_separator_regex = is_separator_regex
    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)
    # @classmethod
    # def from_language(
    #     cls, language: Language, **kwargs: Any
    # ) -> RecursiveCharacterTextSplitter:
    #     separators = cls.get_separators_for_language(language)
    #     return cls(separators=separators, is_separator_regex=True, **kwargs)
    @staticmethod
    def get_separators_for_language(language: Language) -> List[str]:
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.KOTLIN:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\ninternal ",
                "\ncompanion ",
                "\nfun ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nwhen ",
                "\ncase ",
                "\nelse ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.TS:
            return [
                "\nenum ",
                "\ninterface ",
                "\nnamespace ",
                "\ntype ",
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n",
                "\n-+\n",
                "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                "\nconst ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (levels 1-6)
                "\n#{1,6} ",
                # Note: the alternative "setext" heading syntax, e.g.
                #   Heading level 2
                #   ---------------
                # is not handled here
                # End of code block
                "```\n",
                # Horizontal lines (three or more of ***, ---, or ___)
                "\n\\*\\*\\*+\n",
                "\n---+\n",
                "\n___+\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along Latex sections
                "\n\\\\chapter{",
                "\n\\\\section{",
                "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}",
                "\n\\\\begin{itemize}",
                "\n\\\\begin{description}",
                "\n\\\\begin{list}",
                "\n\\\\begin{quote}",
                "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}",
                "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}",
                "$$",
                "$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body",
                "<div",
                "<p",
                "<br",
                "<li",
                "<h1",
                "<h2",
                "<h3",
                "<h4",
                "<h5",
                "<h6",
                "<span",
                "<table",
                "<tr",
                "<td",
                "<th",
                "<ul",
                "<ol",
                "<header",
                "<footer",
                "<nav",
                # Head
                "<head",
                "<style",
                "<script",
                "<meta",
                "<title",
                "",
            ]
        elif language == Language.CSHARP:
            return [
                "\ninterface ",
                "\nenum ",
                "\nimplements ",
                "\ndelegate ",
                "\nevent ",
                # Split along class definitions
                "\nclass ",
                "\nabstract ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                "\nreturn ",
                # Split along control flow statements
                "\nif ",
                "\ncontinue ",
                "\nfor ",
                "\nforeach ",
                "\nwhile ",
                "\nswitch ",
                "\nbreak ",
                "\ncase ",
                "\nelse ",
                # Split by exceptions
                "\ntry ",
                "\nthrow ",
                "\nfinally ",
                "\ncatch ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.COBOL:
            return [
                # Split along divisions
                "\nIDENTIFICATION DIVISION.",
                "\nENVIRONMENT DIVISION.",
                "\nDATA DIVISION.",
                "\nPROCEDURE DIVISION.",
                # Split along sections within DATA DIVISION
                "\nWORKING-STORAGE SECTION.",
                "\nLINKAGE SECTION.",
                "\nFILE SECTION.",
                # Split along sections within ENVIRONMENT DIVISION
                "\nINPUT-OUTPUT SECTION.",
                # Split along paragraphs and common statements
                "\nOPEN ",
                "\nCLOSE ",
                "\nREAD ",
                "\nWRITE ",
                "\nIF ",
                "\nELSE ",
                "\nMOVE ",
                "\nPERFORM ",
                "\nUNTIL ",
                "\nVARYING ",
                "\nACCEPT ",
                "\nDISPLAY ",
                "\nSTOP RUN.",
                # Split by the normal type of lines
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
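
# Hedged usage sketch for RecursiveTokenChunker (editor's illustration, not
# part of the original file). Character-based lengths via len keep the demo
# dependency-free; with openai_token_count the sizes would be in tokens.
def _demo_recursive_token_chunker() -> None:
    chunker = RecursiveTokenChunker(
        chunk_size=80,
        chunk_overlap=0,
        length_function=len,
        separators=RecursiveTokenChunker.get_separators_for_language(Language.PYTHON),
        is_separator_regex=False,
    )
    code = "class A:\n    pass\n\nclass B:\n    pass\n"
    print(chunker.split_text(code))
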
class ClusterSemanticChunker(BaseChunker):
    def __init__(
        self,
        embedding_function=None,
        max_chunk_size=400,
        min_chunk_size=50,
        length_function=openai_token_count,
    ):
        # Use the caller-supplied length_function (rather than hardcoding
        # openai_token_count) so character- or token-based lengths both work.
        self.splitter = RecursiveTokenChunker(
            chunk_size=min_chunk_size,
            chunk_overlap=0,
            length_function=length_function,
            separators=["\n\n", "\n", ".", "?", "!", " ", ""],
        )
        self._chunk_size = max_chunk_size
        self.max_cluster = max_chunk_size // min_chunk_size
        self.embedding_function = embedding_function
    def _get_similarity_matrix(self, embedding_function, sentences):
        BATCH_SIZE = 500
        N = len(sentences)
        embedding_matrix = None

        for i in range(0, N, BATCH_SIZE):
            batch_sentences = sentences[i:i + BATCH_SIZE]
            embeddings = embedding_function(batch_sentences)

            # Convert embeddings list of lists to numpy array
            batch_embedding_matrix = np.array(embeddings)

            # Append the batch embedding matrix to the main embedding matrix
            if embedding_matrix is None:
                embedding_matrix = batch_embedding_matrix
            else:
                embedding_matrix = np.concatenate((embedding_matrix, batch_embedding_matrix), axis=0)

        # Pairwise dot products; this equals cosine similarity only when the
        # embedding function returns unit-normalized vectors.
        similarity_matrix = np.dot(embedding_matrix, embedding_matrix.T)

        return similarity_matrix
    def _calculate_reward(self, matrix, start, end):
        # Reward of a cluster = sum of all pairwise similarities inside it
        sub_matrix = matrix[start:end + 1, start:end + 1]
        return np.sum(sub_matrix)
    def _optimal_segmentation(self, matrix, max_cluster_size, window_size=3):
        # Subtract the mean off-diagonal similarity so that rewards are
        # measured relative to an average pair of sentences
        mean_value = np.mean(matrix[np.triu_indices(matrix.shape[0], k=1)])
        matrix = matrix - mean_value  # Normalize the matrix
        np.fill_diagonal(matrix, 0)  # Set diagonal to 0 so self-similarity adds no reward

        n = matrix.shape[0]
        dp = np.zeros(n)
        segmentation = np.zeros(n, dtype=int)

        # dp[i] holds the best total reward for segmenting sentences 0..i;
        # segmentation[i] records the start of the cluster that ends at i
        for i in range(n):
            for size in range(1, max_cluster_size + 1):
                if i - size + 1 >= 0:
                    # local_density = calculate_local_density(matrix, i, window_size)
                    reward = self._calculate_reward(matrix, i - size + 1, i)
                    # Adjust reward based on local density
                    adjusted_reward = reward
                    if i - size >= 0:
                        adjusted_reward += dp[i - size]
                    if adjusted_reward > dp[i]:
                        dp[i] = adjusted_reward
                        segmentation[i] = i - size + 1

        # Backtrack from the last sentence to recover the clusters
        clusters = []
        i = n - 1
        while i >= 0:
            start = segmentation[i]
            clusters.append((start, i))
            i = start - 1

        clusters.reverse()
        return clusters
    def split_text(self, text: str) -> List[str]:
        sentences = self.splitter.split_text(text)
        similarity_matrix = self._get_similarity_matrix(self.embedding_function, sentences)
        clusters = self._optimal_segmentation(similarity_matrix, max_cluster_size=self.max_cluster)
        docs = [' '.join(sentences[start:end + 1]) for start, end in clusters]
        return docs
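
# Hedged usage sketch for ClusterSemanticChunker (editor's illustration, not
# part of the original file). The stub embedding function below is invented
# for the demo; in practice you would pass a real embedding function (e.g. a
# sentence-transformers or OpenAI embedding wrapper) that maps a list of
# strings to a list of unit-normalized vectors.
def _demo_cluster_semantic_chunker() -> None:
    def stub_embedding_function(texts: List[str]) -> List[List[float]]:
        # Deterministic pseudo-embeddings, normalized so that dot products
        # behave like cosine similarities.
        rng = np.random.default_rng(0)
        vectors = rng.normal(size=(len(texts), 8))
        vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors.tolist()

    chunker = ClusterSemanticChunker(
        embedding_function=stub_embedding_function,
        max_chunk_size=400,
        min_chunk_size=50,
        length_function=len,  # character lengths keep the demo dependency-free
    )
    print(chunker.split_text("First topic sentence. Second topic sentence. " * 20))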