"""Base embeddings file.""" import asyncio from abc import abstractmethod from enum import Enum from typing import Callable, Coroutine, List, Optional, Tuple import numpy as np from gpt_index.utils import globals_helper # TODO: change to numpy array EMB_TYPE = List DEFAULT_EMBED_BATCH_SIZE = 10 class SimilarityMode(str, Enum): """Modes for similarity/distance.""" DEFAULT = "cosine" DOT_PRODUCT = "dot_product" EUCLIDEAN = "euclidean" def mean_agg(embeddings: List[List[float]]) -> List[float]: """Mean aggregation for embeddings.""" return list(np.array(embeddings).mean(axis=0)) def similarity( embedding1: EMB_TYPE, embedding2: EMB_TYPE, mode: SimilarityMode = SimilarityMode.DEFAULT, ) -> float: """Get embedding similarity.""" if mode == SimilarityMode.EUCLIDEAN: return float(np.linalg.norm(np.array(embedding1) - np.array(embedding2))) elif mode == SimilarityMode.DOT_PRODUCT: product = np.dot(embedding1, embedding2) return product else: product = np.dot(embedding1, embedding2) norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2) return product / norm class BaseEmbedding: """Base class for embeddings.""" def __init__(self, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE) -> None: """Init params.""" self._total_tokens_used = 0 self._last_token_usage: Optional[int] = None self._tokenizer: Callable = globals_helper.tokenizer # list of tuples of id, text self._text_queue: List[Tuple[str, str]] = [] if embed_batch_size <= 0: raise ValueError("embed_batch_size must be > 0") self._embed_batch_size = embed_batch_size @abstractmethod def _get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" def get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" query_embedding = self._get_query_embedding(query) query_tokens_count = len(self._tokenizer(query)) self._total_tokens_used += query_tokens_count return query_embedding def get_agg_embedding_from_queries( self, queries: List[str], agg_fn: Optional[Callable[..., List[float]]] = None, ) -> List[float]: """Get aggregated embedding from multiple queries.""" query_embeddings = [self.get_query_embedding(query) for query in queries] agg_fn = agg_fn or mean_agg return agg_fn(query_embeddings) @abstractmethod def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" async def _aget_text_embedding(self, text: str) -> List[float]: """Asynchronously get text embedding. By default, this falls back to _get_text_embedding. Meant to be overriden if there is a true async implementation. """ return self._get_text_embedding(text) def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Get text embeddings. By default, this is a wrapper around _get_text_embedding. Meant to be overriden for batch queries. """ result = [self._get_text_embedding(text) for text in texts] return result async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Asynchronously get text embeddings. By default, this is a wrapper around _aget_text_embedding. Meant to be overriden for batch queries. """ result = await asyncio.gather( *[self._aget_text_embedding(text) for text in texts] ) return result def get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" text_embedding = self._get_text_embedding(text) text_tokens_count = len(self._tokenizer(text)) self._total_tokens_used += text_tokens_count return text_embedding def queue_text_for_embeddding(self, text_id: str, text: str) -> None: """Queue text for embedding. 
        Used for batching texts during embedding calls.

        """
        self._text_queue.append((text_id, text))

    def get_queued_text_embeddings(self) -> Tuple[List[str], List[List[float]]]:
        """Get queued text embeddings.

        Call embedding API to get embeddings for all queued texts.

        """
        text_queue = self._text_queue
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == len(text_queue) - 1 or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [text_id for text_id, _ in cur_batch]
                cur_batch_texts = [text for _, text in cur_batch]
                embeddings = self._get_text_embeddings(cur_batch_texts)
                result_ids.extend(cur_batch_ids)
                result_embeddings.extend(embeddings)
                # reset the current batch so texts are not embedded twice
                cur_batch = []

        # reset queue
        self._text_queue = []
        return result_ids, result_embeddings

    async def aget_queued_text_embeddings(
        self, text_queue: List[Tuple[str, str]]
    ) -> Tuple[List[str], List[List[float]]]:
        """Asynchronously get a list of text embeddings.

        Call async embedding API to get embeddings for all queued texts in parallel.
        Argument `text_queue` must be passed in to avoid updating it async.

        """
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        embeddings_coroutines: List[Coroutine] = []
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == len(text_queue) - 1 or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [text_id for text_id, _ in cur_batch]
                cur_batch_texts = [text for _, text in cur_batch]
                embeddings_coroutines.append(
                    self._aget_text_embeddings(cur_batch_texts)
                )
                result_ids.extend(cur_batch_ids)
                # reset the current batch so texts are not embedded twice
                cur_batch = []

        # flatten the results of asyncio.gather, which is a list of embeddings lists
        result_embeddings = [
            embedding
            for embeddings in await asyncio.gather(*embeddings_coroutines)
            for embedding in embeddings
        ]
        return result_ids, result_embeddings

    def similarity(
        self,
        embedding1: EMB_TYPE,
        embedding2: EMB_TYPE,
        mode: SimilarityMode = SimilarityMode.DEFAULT,
    ) -> float:
        """Get embedding similarity."""
        return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)

    @property
    def total_tokens_used(self) -> int:
        """Get the total tokens used so far."""
        return self._total_tokens_used

    @property
    def last_token_usage(self) -> int:
        """Get the last token usage."""
        if self._last_token_usage is None:
            return 0
        return self._last_token_usage

    @last_token_usage.setter
    def last_token_usage(self, value: int) -> None:
        """Set the last token usage."""
        self._last_token_usage = value
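

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the library API: a minimal example of
# how a subclass might implement the two abstract hooks and drive the batching
# queue above. The class name `_ToyEmbedding` and its character-code "vectors"
# are hypothetical placeholders standing in for a real embedding backend.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyEmbedding(BaseEmbedding):
        """Hypothetical subclass using deterministic, hash-like vectors."""

        def _get_query_embedding(self, query: str) -> List[float]:
            # Queries are embedded the same way as texts in this sketch.
            return self._get_text_embedding(query)

        def _get_text_embedding(self, text: str) -> List[float]:
            # Stand-in for a real model/API call: bucket byte values into a
            # fixed-size vector and L2-normalize it.
            vec = np.zeros(8)
            for i, byte in enumerate(text.encode("utf-8")):
                vec[i % 8] += byte
            norm = np.linalg.norm(vec)
            return list(vec / norm) if norm else list(vec)

    embed_model = _ToyEmbedding(embed_batch_size=2)
    embed_model.queue_text_for_embeddding("id_a", "hello world")
    embed_model.queue_text_for_embeddding("id_b", "hello there")
    ids, embeddings = embed_model.get_queued_text_embeddings()
    # Cosine similarity (SimilarityMode.DEFAULT) between the two queued texts.
    print(ids, embed_model.similarity(embeddings[0], embeddings[1]))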