|
|
|
|
|
|
|
|
|
|
|
from attr import dataclass |
|
|
|
import numpy as np |
|
import tiktoken |
|
|
|
from chunking_evaluation.utils import Language |
|
import re |
|
|
|
from abc import ABC, abstractmethod |
|
from enum import Enum |
|
import logging |
|
from typing import ( |
|
AbstractSet, |
|
Any, |
|
Callable, |
|
Collection, |
|
Iterable, |
|
List, |
|
Literal, |
|
Optional, |
|
Sequence, |
|
Type, |
|
TypeVar, |
|
Union, |
|
) |
|
|
|
class BaseChunker(ABC):
    """Abstract interface shared by every chunker in this module."""

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Break ``text`` into an ordered list of chunk strings."""
        ...
|
|
|
# Module-level logger, named after this module per the stdlib convention.
logger = logging.getLogger(__name__)


# Type variable bound to TextSplitter so alternate constructors such as
# ``from_tiktoken_encoder`` are typed as returning the invoking subclass.
TS = TypeVar("TS", bound="TextSplitter")
|
class TextSplitter(BaseChunker, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return, measured with
                ``length_function`` (characters by default).
            chunk_overlap: Overlap between consecutive chunks, in the same
                units as ``chunk_size``.
            length_function: Function that measures the length of given chunks.
            keep_separator: Whether to keep the separator in the chunks.
            add_start_index: If `True`, includes chunk's start index in metadata.
            strip_whitespace: If `True`, strips whitespace from the start and
                end of every document.

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        # Join the accumulated pieces into one chunk; return None when the
        # result is empty (after optional stripping) so callers can drop it.
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily pack ``splits`` into chunks of at most ``_chunk_size``.

        Maintains a sliding window ``current_doc`` of pieces; when adding the
        next piece would overflow the budget, the window is flushed as one
        chunk and then shrunk from the front until it fits within
        ``_chunk_overlap``, preserving trailing context as overlap for the
        next chunk.
        """
        # The separator itself contributes to the measured chunk length.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0  # running measured length of current_doc, separators included
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    # Can happen when a single split exceeds the budget.
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep popping from the front while:
                    # - the window is still larger than the desired overlap, or
                    # - adding the new piece would still overflow the budget.
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever is left in the window as the final chunk.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        # Prefer the model-specific encoding when a model name is supplied.
        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            # Length = number of tokens under the chosen encoding.
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        # FixedTokenChunker needs the tokenizer configuration itself, not just
        # the derived length function, so forward it explicitly.
        if issubclass(cls, FixedTokenChunker):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
|
|
|
class FixedTokenChunker(TextSplitter):
    """Split text into fixed-size, optionally overlapping token windows.

    Chunk size and overlap are measured in tokens of the configured
    tiktoken encoding (rather than characters).
    """

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new FixedTokenChunker.

        Args:
            encoding_name: tiktoken encoding to use when ``model_name`` is None.
            model_name: If given, use the encoding registered for this model.
            chunk_size: Maximum chunk length, in tokens.
            chunk_overlap: Number of tokens shared by consecutive chunks.
            allowed_special: Special tokens permitted in the input text.
            disallowed_special: Special tokens that raise an error if present.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If the ``tiktoken`` package is not installed.
        """
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
        try:
            import tiktoken
        except ImportError:
            # Fixed message: previously read "in order to for FixedTokenChunker".
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order for FixedTokenChunker. "
                "Please install it with `pip install tiktoken`."
            )

        # Prefer the model-specific encoding when a model name is supplied.
        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Tokenize ``text`` and decode it back as overlapping token windows."""

        def _encode(_text: str) -> List[int]:
            # Bind the configured special-token policy into the encoder.
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
|
|
@dataclass(frozen=True)
class Tokenizer:
    """Immutable bundle of the settings needed to window text by tokens.

    Attributes:
        chunk_overlap: Number of tokens shared between consecutive chunks.
        tokens_per_chunk: Maximum number of tokens in a single chunk.
        decode: Turns a list of token ids back into a string.
        encode: Turns a string into a list of token ids.
    """

    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[List[int]], str]
    encode: Callable[[str], List[int]]
|
|
|
|
|
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Encode ``text``, window the token ids with overlap, and decode each window.

    Windows are ``tokens_per_chunk`` long and consecutive windows share
    ``chunk_overlap`` tokens. Returns an empty list for empty input.
    """
    token_ids = tokenizer.encode(text)
    n_tokens = len(token_ids)
    stride = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap

    chunks: List[str] = []
    start = 0
    while start < n_tokens:
        end = min(start + tokenizer.tokens_per_chunk, n_tokens)
        chunks.append(tokenizer.decode(token_ids[start:end]))
        # The window that reaches the end of the ids is the last chunk.
        if end == n_tokens:
            break
        start += stride
    return chunks
|
|
|
def _split_text_with_regex( |
|
text: str, separator: str, keep_separator: bool |
|
) -> List[str]: |
|
|
|
if separator: |
|
if keep_separator: |
|
|
|
_splits = re.split(f"({separator})", text) |
|
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] |
|
if len(_splits) % 2 == 0: |
|
splits += _splits[-1:] |
|
splits = [_splits[0]] + splits |
|
else: |
|
splits = re.split(separator, text) |
|
else: |
|
splits = list(text) |
|
return [s for s in splits if s != ""] |
|
|
|
class RecursiveTokenChunker(TextSplitter):
    """Split text by recursively trying a ranked list of separators.

    The first separator (in order) that occurs in the text is used; pieces
    that are still too long are re-split with the remaining separators.
    """

    # Ranked separator lists per language: structural boundaries first
    # (classes/functions), then control flow, then whitespace fallbacks.
    # Values are templates; get_separators_for_language returns a copy.
    _SEPARATORS_BY_LANGUAGE = {
        Language.CPP: [
            "\nclass ",
            "\nvoid ", "\nint ", "\nfloat ", "\ndouble ",
            "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
            "\n\n", "\n", " ", "",
        ],
        Language.GO: [
            "\nfunc ", "\nvar ", "\nconst ", "\ntype ",
            "\nif ", "\nfor ", "\nswitch ", "\ncase ",
            "\n\n", "\n", " ", "",
        ],
        Language.JAVA: [
            "\nclass ",
            "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ",
            "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
            "\n\n", "\n", " ", "",
        ],
        Language.KOTLIN: [
            "\nclass ",
            "\npublic ", "\nprotected ", "\nprivate ", "\ninternal ",
            "\ncompanion ", "\nfun ", "\nval ", "\nvar ",
            "\nif ", "\nfor ", "\nwhile ", "\nwhen ", "\ncase ", "\nelse ",
            "\n\n", "\n", " ", "",
        ],
        Language.JS: [
            "\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ",
            "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ",
            "\n\n", "\n", " ", "",
        ],
        Language.TS: [
            "\nenum ", "\ninterface ", "\nnamespace ", "\ntype ",
            "\nclass ",
            "\nfunction ", "\nconst ", "\nlet ", "\nvar ",
            "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ",
            "\n\n", "\n", " ", "",
        ],
        Language.PHP: [
            "\nfunction ",
            "\nclass ",
            "\nif ", "\nforeach ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
            "\n\n", "\n", " ", "",
        ],
        Language.PROTO: [
            "\nmessage ", "\nservice ", "\nenum ",
            "\noption ", "\nimport ", "\nsyntax ",
            "\n\n", "\n", " ", "",
        ],
        Language.PYTHON: [
            "\nclass ", "\ndef ", "\n\tdef ",
            "\n\n", "\n", " ", "",
        ],
        Language.RST: [
            # Section underlines, then directives.
            "\n=+\n", "\n-+\n", "\n\\*+\n",
            "\n\n.. *\n\n",
            "\n\n", "\n", " ", "",
        ],
        Language.RUBY: [
            "\ndef ", "\nclass ",
            "\nif ", "\nunless ", "\nwhile ", "\nfor ", "\ndo ",
            "\nbegin ", "\nrescue ",
            "\n\n", "\n", " ", "",
        ],
        Language.RUST: [
            # NOTE: a duplicate "\nconst " entry in the original list was removed.
            "\nfn ", "\nconst ", "\nlet ",
            "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ",
            "\n\n", "\n", " ", "",
        ],
        Language.SCALA: [
            "\nclass ", "\nobject ",
            "\ndef ", "\nval ", "\nvar ",
            "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ",
            "\n\n", "\n", " ", "",
        ],
        Language.SWIFT: [
            "\nfunc ",
            "\nclass ", "\nstruct ", "\nenum ",
            "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
            "\n\n", "\n", " ", "",
        ],
        Language.MARKDOWN: [
            "\n#{1,6} ",  # headings
            "```\n",  # fenced code blocks
            "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n",  # horizontal rules
            "\n\n", "\n", " ", "",
        ],
        Language.LATEX: [
            # Bug fix: the begin{align} entry had three backslashes
            # ("\n\\\begin{align}"), which embeds a literal backspace (\b)
            # and could never match LaTeX source; normalized to four,
            # matching every other entry in this list.
            "\n\\\\chapter{", "\n\\\\section{",
            "\n\\\\subsection{", "\n\\\\subsubsection{",
            "\n\\\\begin{enumerate}", "\n\\\\begin{itemize}",
            "\n\\\\begin{description}", "\n\\\\begin{list}",
            "\n\\\\begin{quote}", "\n\\\\begin{quotation}",
            "\n\\\\begin{verse}", "\n\\\\begin{verbatim}",
            "\n\\\\begin{align}",
            "$$", "$",
            " ", "",
        ],
        Language.HTML: [
            "<body", "<div", "<p", "<br", "<li",
            "<h1", "<h2", "<h3", "<h4", "<h5", "<h6",
            "<span", "<table", "<tr", "<td", "<th", "<ul", "<ol",
            "<header", "<footer", "<nav",
            "<head", "<style", "<script", "<meta", "<title",
            "",
        ],
        Language.CSHARP: [
            "\ninterface ", "\nenum ", "\nimplements ", "\ndelegate ", "\nevent ",
            "\nclass ", "\nabstract ",
            "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nreturn ",
            "\nif ", "\ncontinue ", "\nfor ", "\nforeach ", "\nwhile ",
            "\nswitch ", "\nbreak ", "\ncase ", "\nelse ",
            "\ntry ", "\nthrow ", "\nfinally ", "\ncatch ",
            "\n\n", "\n", " ", "",
        ],
        Language.SOL: [
            "\npragma ", "\nusing ",
            "\ncontract ", "\ninterface ", "\nlibrary ",
            "\nconstructor ", "\ntype ", "\nfunction ", "\nevent ",
            "\nmodifier ", "\nerror ", "\nstruct ", "\nenum ",
            "\nif ", "\nfor ", "\nwhile ", "\ndo while ", "\nassembly ",
            "\n\n", "\n", " ", "",
        ],
        Language.COBOL: [
            "\nIDENTIFICATION DIVISION.", "\nENVIRONMENT DIVISION.",
            "\nDATA DIVISION.", "\nPROCEDURE DIVISION.",
            "\nWORKING-STORAGE SECTION.", "\nLINKAGE SECTION.", "\nFILE SECTION.",
            "\nINPUT-OUTPUT SECTION.",
            "\nOPEN ", "\nCLOSE ", "\nREAD ", "\nWRITE ",
            "\nIF ", "\nELSE ", "\nMOVE ", "\nPERFORM ", "\nUNTIL ",
            "\nVARYING ", "\nACCEPT ", "\nDISPLAY ", "\nSTOP RUN.",
            "\n", " ", "",
        ],
    }

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new RecursiveTokenChunker.

        Args:
            chunk_size: Maximum chunk length, measured by the length function.
            chunk_overlap: Overlap between consecutive chunks.
            separators: Ordered separators to try; defaults to paragraph,
                line, sentence punctuation, space, and character fallbacks.
            keep_separator: Whether separators stay attached to the chunks.
            is_separator_regex: Treat separators as raw regexes instead of
                literal strings.
            **kwargs: Forwarded to ``TextSplitter.__init__``.
        """
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", ".", "?", "!", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Pick the first separator that occurs in the text ("" always applies);
        # the separators after it are kept for recursive re-splitting.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Merge short pieces back together; recurse into pieces that are
        # still over budget when finer separators remain.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # No finer separator left: emit the oversized piece as-is.
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` using the configured separator hierarchy."""
        return self._split_text(text, self._separators)

    @staticmethod
    def get_separators_for_language(language: Language) -> List[str]:
        """Return a fresh copy of the ranked separators for ``language``.

        Raises:
            ValueError: If the language has no registered separator list.
        """
        separators = RecursiveTokenChunker._SEPARATORS_BY_LANGUAGE.get(language)
        if separators is None:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
        # Copy so callers can mutate their list without corrupting the template.
        return list(separators)
|
|
|
class ClusterSemanticChunker(BaseChunker):
    """Chunker that merges adjacent pieces into semantically coherent chunks.

    The text is first cut into small base pieces (about ``min_chunk_size``
    units each), the pieces are embedded, and a dynamic program groups
    consecutive pieces into clusters (bounded by ``max_chunk_size``) that
    maximize total within-cluster embedding similarity.
    """

    def __init__(self, embedding_function=None, max_chunk_size=400, min_chunk_size=50, length_function=openai_token_count):
        """Create a new ClusterSemanticChunker.

        Args:
            embedding_function: Callable mapping a list of strings to a list
                of embedding vectors (one per string).
            max_chunk_size: Maximum chunk length, as measured by
                ``length_function``.
            min_chunk_size: Target length of the small base pieces.
            length_function: Length measure used by the underlying splitter.
        """
        self.splitter = RecursiveTokenChunker(
            chunk_size=min_chunk_size,
            chunk_overlap=0,
            # Bug fix: openai_token_count was hard-coded here, silently
            # ignoring a caller-supplied length_function. The default is
            # unchanged, so existing callers behave identically.
            length_function=length_function,
            separators=["\n\n", "\n", ".", "?", "!", " ", ""],
        )

        self._chunk_size = max_chunk_size
        # A single cluster may span at most this many base pieces.
        self.max_cluster = max_chunk_size // min_chunk_size
        self.embedding_function = embedding_function

    def _get_similarity_matrix(self, embedding_function, sentences):
        """Embed ``sentences`` in batches and return the N x N dot-product matrix."""
        BATCH_SIZE = 500
        batches = []
        for i in range(0, len(sentences), BATCH_SIZE):
            embeddings = embedding_function(sentences[i:i + BATCH_SIZE])
            batches.append(np.array(embeddings))

        if not batches:
            # No sentences: return an empty matrix instead of crashing on None.
            return np.zeros((0, 0))

        # One concatenate over all batches instead of growing the matrix
        # incrementally (which re-copies everything on every batch).
        embedding_matrix = np.concatenate(batches, axis=0)

        return np.dot(embedding_matrix, embedding_matrix.T)

    def _calculate_reward(self, matrix, start, end):
        """Sum of pairwise similarities within the inclusive span [start, end]."""
        sub_matrix = matrix[start:end + 1, start:end + 1]
        return np.sum(sub_matrix)

    def _optimal_segmentation(self, matrix, max_cluster_size, window_size=3):
        """Choose cluster boundaries maximizing total within-cluster similarity.

        Args:
            matrix: N x N similarity matrix.
            max_cluster_size: Maximum number of base pieces per cluster.
            window_size: Unused; retained for backward compatibility.

        Returns:
            List of inclusive ``(start, end)`` index pairs covering 0..N-1
            in order.
        """
        # Center on the mean off-diagonal similarity so rewards can go
        # negative and the DP is not biased toward ever-larger clusters.
        mean_value = np.mean(matrix[np.triu_indices(matrix.shape[0], k=1)])
        matrix = matrix - mean_value
        np.fill_diagonal(matrix, 0)  # self-similarity carries no signal

        n = matrix.shape[0]
        dp = np.zeros(n)  # dp[i]: best total reward covering pieces 0..i
        segmentation = np.zeros(n, dtype=int)  # start of the cluster ending at i

        for i in range(n):
            for size in range(1, max_cluster_size + 1):
                if i - size + 1 >= 0:
                    reward = self._calculate_reward(matrix, i - size + 1, i)

                    adjusted_reward = reward
                    if i - size >= 0:
                        adjusted_reward += dp[i - size]
                    if adjusted_reward > dp[i]:
                        dp[i] = adjusted_reward
                        segmentation[i] = i - size + 1

        # Walk the segmentation backwards to recover the chosen clusters.
        clusters = []
        i = n - 1
        while i >= 0:
            start = segmentation[i]
            clusters.append((start, i))
            i = start - 1

        clusters.reverse()
        return clusters

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into semantically clustered chunks."""
        sentences = self.splitter.split_text(text)

        similarity_matrix = self._get_similarity_matrix(self.embedding_function, sentences)

        clusters = self._optimal_segmentation(similarity_matrix, max_cluster_size=self.max_cluster)

        # Each cluster of consecutive base pieces becomes one output chunk.
        docs = [' '.join(sentences[start:end + 1]) for start, end in clusters]

        return docs