Spaces:
Runtime error
Runtime error
File size: 7,616 Bytes
8a58cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
"""Base embeddings file."""
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Callable, Coroutine, List, Optional, Tuple
import numpy as np
from gpt_index.utils import globals_helper
# TODO: change to numpy array
EMB_TYPE = List
DEFAULT_EMBED_BATCH_SIZE = 10
class SimilarityMode(str, Enum):
    """Modes for similarity/distance."""
    # Cosine similarity: dot product of the two vectors divided by the
    # product of their L2 norms (higher = more similar).
    DEFAULT = "cosine"
    # Raw (unnormalized) dot product (higher = more similar).
    DOT_PRODUCT = "dot_product"
    # Euclidean (L2) distance between the vectors (lower = more similar).
    EUCLIDEAN = "euclidean"
def mean_agg(embeddings: List[List[float]]) -> List[float]:
    """Aggregate several embeddings into one by taking the element-wise mean."""
    stacked = np.array(embeddings)
    return list(stacked.mean(axis=0))
def similarity(
    embedding1: EMB_TYPE,
    embedding2: EMB_TYPE,
    mode: SimilarityMode = SimilarityMode.DEFAULT,
) -> float:
    """Get embedding similarity.

    Args:
        embedding1: First embedding vector.
        embedding2: Second embedding vector.
        mode: Similarity mode — cosine (default), dot product, or
            euclidean distance.

    Returns:
        The score as a plain Python float. For EUCLIDEAN the value is a
        distance (lower = closer); for the other modes higher = closer.
    """
    if mode == SimilarityMode.EUCLIDEAN:
        return float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
    elif mode == SimilarityMode.DOT_PRODUCT:
        # Cast so every branch returns a Python float, matching the
        # annotation (previously this branch returned np.float64).
        return float(np.dot(embedding1, embedding2))
    else:
        product = np.dot(embedding1, embedding2)
        norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        return float(product / norm)
class BaseEmbedding:
    """Base class for embeddings.

    Subclasses must implement `_get_query_embedding` and
    `_get_text_embedding`. The batched and async variants have default
    implementations that fall back to the per-text methods and are meant
    to be overridden when the backing API supports batching/async.
    """

    def __init__(self, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE) -> None:
        """Init params.

        Args:
            embed_batch_size: Number of queued texts sent to the embedding
                API per call.

        Raises:
            ValueError: If `embed_batch_size` is not positive.
        """
        self._total_tokens_used = 0
        self._last_token_usage: Optional[int] = None
        self._tokenizer: Callable = globals_helper.tokenizer
        # list of tuples of id, text
        self._text_queue: List[Tuple[str, str]] = []
        if embed_batch_size <= 0:
            raise ValueError("embed_batch_size must be > 0")
        self._embed_batch_size = embed_batch_size

    @abstractmethod
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""

    def get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding, recording token usage."""
        query_embedding = self._get_query_embedding(query)
        query_tokens_count = len(self._tokenizer(query))
        self._total_tokens_used += query_tokens_count
        return query_embedding

    def get_agg_embedding_from_queries(
        self,
        queries: List[str],
        agg_fn: Optional[Callable[..., List[float]]] = None,
    ) -> List[float]:
        """Get aggregated embedding from multiple queries.

        Args:
            queries: Queries to embed.
            agg_fn: Aggregation over the per-query embeddings;
                defaults to the element-wise mean.
        """
        query_embeddings = [self.get_query_embedding(query) for query in queries]
        agg_fn = agg_fn or mean_agg
        return agg_fn(query_embeddings)

    @abstractmethod
    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding.

        By default, this falls back to _get_text_embedding.
        Meant to be overriden if there is a true async implementation.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Meant to be overriden for batch queries.
        """
        return [self._get_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        By default, this is a wrapper around _aget_text_embedding.
        Meant to be overriden for batch queries.
        """
        result = await asyncio.gather(
            *[self._aget_text_embedding(text) for text in texts]
        )
        return result

    def get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding, recording token usage."""
        text_embedding = self._get_text_embedding(text)
        text_tokens_count = len(self._tokenizer(text))
        self._total_tokens_used += text_tokens_count
        return text_embedding

    def queue_text_for_embedding(self, text_id: str, text: str) -> None:
        """Queue text for embedding.

        Used for batching texts during embedding calls.
        """
        self._text_queue.append((text_id, text))

    # NOTE: misspelled name kept as an alias for backward compatibility.
    queue_text_for_embeddding = queue_text_for_embedding

    def get_queued_text_embeddings(self) -> Tuple[List[str], List[List[float]]]:
        """Get queued text embeddings.

        Call embedding API to get embeddings for all queued texts,
        flushing one API call per batch of `self._embed_batch_size` texts.
        Clears the queue and returns the ids and embeddings in queue order.
        """
        text_queue = self._text_queue
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == len(text_queue) - 1 or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [text_id for text_id, _ in cur_batch]
                cur_batch_texts = [text for _, text in cur_batch]
                embeddings = self._get_text_embeddings(cur_batch_texts)
                result_ids.extend(cur_batch_ids)
                result_embeddings.extend(embeddings)
                # BUGFIX: start a fresh batch after each flush. Previously
                # flushed texts stayed in cur_batch, so with more than one
                # batch the size check never fired again and the final
                # flush re-embedded everything, duplicating ids/results.
                cur_batch = []
        # reset queue
        self._text_queue = []
        return result_ids, result_embeddings

    async def aget_queued_text_embeddings(
        self, text_queue: List[Tuple[str, str]]
    ) -> Tuple[List[str], List[List[float]]]:
        """Asynchronously get a list of text embeddings.

        Call async embedding API to get embeddings for all queued texts in
        parallel (one coroutine per batch of `self._embed_batch_size`).
        Argument `text_queue` must be passed in to avoid updating it async.
        """
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        embeddings_coroutines: List[Coroutine] = []
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == len(text_queue) - 1 or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [text_id for text_id, _ in cur_batch]
                cur_batch_texts = [text for _, text in cur_batch]
                embeddings_coroutines.append(
                    self._aget_text_embeddings(cur_batch_texts)
                )
                result_ids.extend(cur_batch_ids)
                # BUGFIX: reset the batch so already-scheduled texts are not
                # sent again (and their ids duplicated) by the final flush.
                cur_batch = []
        # flatten the results of asyncio.gather, which is a list of embeddings lists
        result_embeddings = [
            embedding
            for embeddings in await asyncio.gather(*embeddings_coroutines)
            for embedding in embeddings
        ]
        return result_ids, result_embeddings

    def similarity(
        self,
        embedding1: EMB_TYPE,
        embedding2: EMB_TYPE,
        mode: SimilarityMode = SimilarityMode.DEFAULT,
    ) -> float:
        """Get embedding similarity (delegates to the module-level helper)."""
        return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)

    @property
    def total_tokens_used(self) -> int:
        """Get the total tokens used so far."""
        return self._total_tokens_used

    @property
    def last_token_usage(self) -> int:
        """Get the last token usage (0 if never set)."""
        if self._last_token_usage is None:
            return 0
        return self._last_token_usage

    @last_token_usage.setter
    def last_token_usage(self, value: int) -> None:
        """Set the last token usage."""
        self._last_token_usage = value
|