"""Base embeddings file."""
import asyncio | |
from abc import abstractmethod | |
from enum import Enum | |
from typing import Callable, Coroutine, List, Optional, Tuple | |
import numpy as np | |
from gpt_index.utils import globals_helper | |
# TODO: change to numpy array
EMB_TYPE = List  # type alias for an embedding vector (list of floats)
DEFAULT_EMBED_BATCH_SIZE = 10  # default number of texts sent per embedding batch call
class SimilarityMode(str, Enum):
    """Modes for similarity/distance.

    Inherits from ``str`` so members compare equal to their string values.
    """

    DEFAULT = "cosine"  # cosine similarity: dot product over the norms
    DOT_PRODUCT = "dot_product"  # raw dot product, no normalization
    EUCLIDEAN = "euclidean"  # euclidean *distance* (lower means closer)
def mean_agg(embeddings: List[List[float]]) -> List[float]:
    """Aggregate a set of embeddings into one by element-wise mean."""
    stacked = np.asarray(embeddings)
    return list(np.mean(stacked, axis=0))
def similarity(
    embedding1: EMB_TYPE,
    embedding2: EMB_TYPE,
    mode: SimilarityMode = SimilarityMode.DEFAULT,
) -> float:
    """Get embedding similarity.

    Args:
        embedding1: First embedding vector.
        embedding2: Second embedding vector.
        mode: Similarity mode. ``EUCLIDEAN`` returns a *distance* (lower is
            closer); ``DOT_PRODUCT`` and the default (cosine) return a
            similarity (higher is closer).

    Returns:
        The similarity (or distance, for ``EUCLIDEAN``) as a builtin float.
    """
    if mode == SimilarityMode.EUCLIDEAN:
        return float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
    elif mode == SimilarityMode.DOT_PRODUCT:
        # Cast so every branch returns a builtin float as annotated
        # (np.dot yields np.float64; the euclidean branch already casts).
        return float(np.dot(embedding1, embedding2))
    else:
        product = np.dot(embedding1, embedding2)
        norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        return float(product / norm)
class BaseEmbedding:
    """Base class for embeddings.

    Concrete subclasses implement ``_get_query_embedding`` and
    ``_get_text_embedding`` (and may override the batch/async variants);
    this base class layers token accounting and batched/queued embedding
    on top of them.
    """

    def __init__(self, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE) -> None:
        """Init params.

        Args:
            embed_batch_size: Max number of texts sent per embedding call.

        Raises:
            ValueError: If ``embed_batch_size`` is not positive.
        """
        self._total_tokens_used = 0
        self._last_token_usage: Optional[int] = None
        self._tokenizer: Callable = globals_helper.tokenizer
        # list of tuples of id, text
        self._text_queue: List[Tuple[str, str]] = []
        if embed_batch_size <= 0:
            raise ValueError("embed_batch_size must be > 0")
        self._embed_batch_size = embed_batch_size

    @abstractmethod
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding (must be implemented by subclasses)."""

    def get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding, adding its token count to the running total."""
        query_embedding = self._get_query_embedding(query)
        query_tokens_count = len(self._tokenizer(query))
        self._total_tokens_used += query_tokens_count
        return query_embedding

    def get_agg_embedding_from_queries(
        self,
        queries: List[str],
        agg_fn: Optional[Callable[..., List[float]]] = None,
    ) -> List[float]:
        """Get aggregated embedding from multiple queries.

        Args:
            queries: Queries to embed individually and then aggregate.
            agg_fn: Aggregation function; defaults to ``mean_agg``.
        """
        query_embeddings = [self.get_query_embedding(query) for query in queries]
        agg_fn = agg_fn or mean_agg
        return agg_fn(query_embeddings)

    @abstractmethod
    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding (must be implemented by subclasses)."""

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding.

        By default, this falls back to _get_text_embedding.
        Meant to be overriden if there is a true async implementation.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Meant to be overriden for batch queries.
        """
        return [self._get_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        By default, this is a wrapper around _aget_text_embedding.
        Meant to be overriden for batch queries.
        """
        result = await asyncio.gather(
            *[self._aget_text_embedding(text) for text in texts]
        )
        return result

    def get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding, adding its token count to the running total."""
        text_embedding = self._get_text_embedding(text)
        text_tokens_count = len(self._tokenizer(text))
        self._total_tokens_used += text_tokens_count
        return text_embedding

    # NOTE: name keeps its historical misspelling ("embeddding") so
    # existing callers do not break.
    def queue_text_for_embeddding(self, text_id: str, text: str) -> None:
        """Queue text for embedding.

        Used for batching texts during embedding calls.
        """
        self._text_queue.append((text_id, text))

    def get_queued_text_embeddings(self) -> Tuple[List[str], List[List[float]]]:
        """Get queued text embeddings.

        Call embedding API to get embeddings for all queued texts,
        flushing in batches of ``embed_batch_size``.
        """
        text_queue = self._text_queue
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == len(text_queue) - 1 or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [text_id for text_id, _ in cur_batch]
                cur_batch_texts = [text for _, text in cur_batch]
                embeddings = self._get_text_embeddings(cur_batch_texts)
                result_ids.extend(cur_batch_ids)
                result_embeddings.extend(embeddings)
                # BUGFIX: start a fresh batch; previously the flushed texts
                # stayed in cur_batch and were embedded again on later flushes.
                cur_batch = []
        # reset queue
        self._text_queue = []
        return result_ids, result_embeddings

    async def aget_queued_text_embeddings(
        self, text_queue: List[Tuple[str, str]]
    ) -> Tuple[List[str], List[List[float]]]:
        """Asynchronously get a list of text embeddings.

        Call async embedding API to get embeddings for all queued texts in parallel.
        Argument `text_queue` must be passed in to avoid updating it async.
        """
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        embeddings_coroutines: List[Coroutine] = []
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == len(text_queue) - 1 or len(cur_batch) == self._embed_batch_size:
                # flush: record the ids now; the embeddings are gathered later
                cur_batch_ids = [text_id for text_id, _ in cur_batch]
                cur_batch_texts = [text for _, text in cur_batch]
                embeddings_coroutines.append(
                    self._aget_text_embeddings(cur_batch_texts)
                )
                result_ids.extend(cur_batch_ids)
                # BUGFIX: start a fresh batch; previously the flushed texts
                # stayed in cur_batch and were embedded again on later flushes.
                cur_batch = []
        # flatten the results of asyncio.gather, which is a list of embeddings lists
        result_embeddings = [
            embedding
            for embeddings in await asyncio.gather(*embeddings_coroutines)
            for embedding in embeddings
        ]
        return result_ids, result_embeddings

    def similarity(
        self,
        embedding1: EMB_TYPE,
        embedding2: EMB_TYPE,
        mode: SimilarityMode = SimilarityMode.DEFAULT,
    ) -> float:
        """Get embedding similarity (delegates to the module-level helper)."""
        return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)

    @property
    def total_tokens_used(self) -> int:
        """Get the total tokens used so far."""
        return self._total_tokens_used

    # BUGFIX: ``last_token_usage`` was defined twice, so the second plain
    # ``def`` silently shadowed the first and the getter was unreachable —
    # the ``@property``/``@...setter`` decorators had evidently been lost.
    @property
    def last_token_usage(self) -> int:
        """Get the last token usage (0 if never set)."""
        if self._last_token_usage is None:
            return 0
        return self._last_token_usage

    @last_token_usage.setter
    def last_token_usage(self, value: int) -> None:
        """Set the last token usage."""
        self._last_token_usage = value