devve1 committed on
Commit
f767050
1 Parent(s): 3e97fe5

Create fixed_token_chunker.py

Files changed (1)
  1. fixed_token_chunker.py +262 -0
fixed_token_chunker.py ADDED
@@ -0,0 +1,262 @@
# This script is adapted from the LangChain package, developed by LangChain AI.
# Original code can be found at: https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/base.py
# License: MIT License

from abc import ABC, abstractmethod
from enum import Enum
import logging
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

from attr import dataclass

from base_chunker import BaseChunker

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


class TextSplitter(BaseChunker, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace
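
    # For example (illustrative): constructing any splitter with chunk_size=400
    # and chunk_overlap=500 raises the ValueError above, since the overlap can
    # never exceed the chunk size.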

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
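
    # Illustrative trace of `_merge_splits` (not executed here): with
    # length_function=len, chunk_size=7, chunk_overlap=3 and separator " ",
    # the splits ["foo", "bar", "baz", "qux"] merge into
    # ["foo bar", "bar baz", "baz qux"]. Whenever the next piece would push the
    # running total past 7 characters, the current chunk is emitted and leading
    # pieces are popped until the total falls to the 3-character overlap budget,
    # which is why consecutive chunks share one word.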

    # @classmethod
    # def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
    #     """Text splitter that uses HuggingFace tokenizer to count length."""
    #     try:
    #         from transformers import PreTrainedTokenizerBase

    #         if not isinstance(tokenizer, PreTrainedTokenizerBase):
    #             raise ValueError(
    #                 "Tokenizer received was not an instance of PreTrainedTokenizerBase"
    #             )

    #         def _huggingface_tokenizer_length(text: str) -> int:
    #             return len(tokenizer.encode(text))

    #     except ImportError:
    #         raise ValueError(
    #             "Could not import transformers python package. "
    #             "Please install it with `pip install transformers`."
    #         )
    #     return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, FixedTokenChunker):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
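
    # Example (illustrative; requires `pip install tiktoken`):
    #
    #     splitter = FixedTokenChunker.from_tiktoken_encoder(
    #         encoding_name="cl100k_base", chunk_size=512, chunk_overlap=64
    #     )
    #
    # Because FixedTokenChunker is a subclass, the encoding settings above are
    # forwarded to its constructor via `extra_kwargs`.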


class FixedTokenChunker(TextSplitter):
    """Split text into fixed-size token chunks using a model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "cl100k_base",
        model_name: Optional[str] = None,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new FixedTokenChunker."""
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for FixedTokenChunker. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
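
# Example usage (illustrative; assumes tiktoken is installed and a long string
# `document` is available):
#
#     chunker = FixedTokenChunker(
#         encoding_name="cl100k_base", chunk_size=400, chunk_overlap=50
#     )
#     chunks = chunker.split_text(document)
#     # Consecutive chunks share roughly 50 tokens, so context is preserved
#     # across chunk boundaries.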


@dataclass(frozen=True)
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
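
# Worked example of the sliding window in `split_text_on_tokens` (illustrative,
# using a toy character-level "tokenizer" rather than tiktoken):
#
#     toy = Tokenizer(
#         chunk_overlap=2,
#         tokens_per_chunk=5,
#         decode=lambda ids: "".join(chr(i) for i in ids),
#         encode=lambda s: [ord(c) for c in s],
#     )
#     split_text_on_tokens(text="abcdefgh", tokenizer=toy)
#     # -> ["abcde", "defgh"]: the window advances by
#     #    tokens_per_chunk - chunk_overlap = 3 tokens, so each chunk repeats
#     #    the last 2 tokens of the previous one.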