AbeerTrial's picture
Upload folder using huggingface_hub
8a58cf3
"""Base embeddings file."""
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Callable, Coroutine, List, Optional, Tuple
import numpy as np
from gpt_index.utils import globals_helper
# Type alias used to annotate embedding arguments in this module.
# TODO: change to numpy array
EMB_TYPE = List
# Default for the ``embed_batch_size`` argument of ``BaseEmbedding.__init__``.
DEFAULT_EMBED_BATCH_SIZE = 10
class SimilarityMode(str, Enum):
    """Supported ways of comparing two embedding vectors.

    The class mixes in ``str``, so every member is also a plain string that
    compares equal to its value.
    """

    # Mode used when a caller does not pick one explicitly.
    DEFAULT = "cosine"
    DOT_PRODUCT = "dot_product"
    EUCLIDEAN = "euclidean"
def mean_agg(embeddings: List[List[float]]) -> List[float]:
    """Average a collection of embeddings element-wise.

    The embeddings are stacked into a single array and averaged along
    axis 0, i.e. across embeddings, one value per vector position.
    """
    stacked = np.array(embeddings)
    averaged = stacked.mean(axis=0)
    return list(averaged)
def similarity(
    embedding1: EMB_TYPE,
    embedding2: EMB_TYPE,
    mode: SimilarityMode = SimilarityMode.DEFAULT,
) -> float:
    """Get embedding similarity.

    Args:
        embedding1: First embedding vector.
        embedding2: Second embedding vector.
        mode: Comparison to perform. ``EUCLIDEAN`` returns the norm of the
            difference of the two vectors, ``DOT_PRODUCT`` returns their dot
            product, and any other mode (including the default) returns the
            dot product divided by the product of the two norms.

    Returns:
        The score as a built-in ``float`` for every mode.

    NOTE(review): in ``EUCLIDEAN`` mode the value is a distance, so a smaller
    number means the vectors are closer; callers that rank by "higher is
    better" should confirm this is what they expect.
    """
    if mode == SimilarityMode.EUCLIDEAN:
        return float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))

    product = np.dot(embedding1, embedding2)
    if mode == SimilarityMode.DOT_PRODUCT:
        # np.dot yields a NumPy scalar (an integer one for integer inputs);
        # coerce so the declared return type holds in every branch.
        return float(product)

    norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    # Division stays on NumPy scalars, then the result is coerced to float.
    return float(product / norm)
class BaseEmbedding:
    """Base class for embeddings.

    Subclasses supply ``_get_query_embedding`` and ``_get_text_embedding``
    and may override the batch / async variants. The public methods defined
    here wrap those hooks with token counting and with batching of texts
    that were queued for embedding.
    """

    def __init__(self, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE) -> None:
        """Init params.

        Args:
            embed_batch_size: Maximum number of texts handed to a single
                ``_get_text_embeddings`` / ``_aget_text_embeddings`` call
                when processing queued texts.

        Raises:
            ValueError: If ``embed_batch_size`` is not strictly positive.
        """
        self._total_tokens_used = 0
        self._last_token_usage: Optional[int] = None
        self._tokenizer: Callable = globals_helper.tokenizer
        # list of tuples of id, text
        self._text_queue: List[Tuple[str, str]] = []
        if embed_batch_size <= 0:
            raise ValueError("embed_batch_size must be > 0")
        self._embed_batch_size = embed_batch_size

    @abstractmethod
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""

    def get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding.

        Delegates to ``_get_query_embedding`` and adds the query's token
        count (as measured by the tokenizer) to the running total.
        """
        query_embedding = self._get_query_embedding(query)
        query_tokens_count = len(self._tokenizer(query))
        self._total_tokens_used += query_tokens_count
        return query_embedding

    def get_agg_embedding_from_queries(
        self,
        queries: List[str],
        agg_fn: Optional[Callable[..., List[float]]] = None,
    ) -> List[float]:
        """Get aggregated embedding from multiple queries.

        Args:
            queries: Queries to embed, each through ``get_query_embedding``.
            agg_fn: Function that reduces the list of query embeddings to a
                single embedding. Falls back to ``mean_agg`` when omitted.

        Returns:
            The aggregated embedding produced by ``agg_fn``.
        """
        query_embeddings = [self.get_query_embedding(query) for query in queries]
        agg_fn = agg_fn or mean_agg
        return agg_fn(query_embeddings)

    @abstractmethod
    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding.

        By default, this falls back to _get_text_embedding.
        Meant to be overridden if there is a true async implementation.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Meant to be overridden for batch queries.
        """
        result = [self._get_text_embedding(text) for text in texts]
        return result

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        By default, this is a wrapper around _aget_text_embedding.
        Meant to be overridden for batch queries.
        """
        result = await asyncio.gather(
            *[self._aget_text_embedding(text) for text in texts]
        )
        return result

    def get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding.

        Delegates to ``_get_text_embedding`` and adds the text's token count
        (as measured by the tokenizer) to the running total.
        """
        text_embedding = self._get_text_embedding(text)
        text_tokens_count = len(self._tokenizer(text))
        self._total_tokens_used += text_tokens_count
        return text_embedding

    def queue_text_for_embeddding(self, text_id: str, text: str) -> None:
        """Queue text for embedding.

        Used for batching texts during embedding calls.
        The method name (including its spelling) is kept as-is because it is
        part of the public interface.
        """
        self._text_queue.append((text_id, text))

    def get_queued_text_embeddings(self) -> Tuple[List[str], List[List[float]]]:
        """Get queued text embeddings.

        Call embedding API to get embeddings for all queued texts, in
        batches of at most ``embed_batch_size`` texts. Each queued text is
        embedded exactly once, and its token count is added to the running
        total. The internal queue is emptied afterwards.

        Returns:
            A ``(ids, embeddings)`` pair built in queue order.
        """
        text_queue = self._text_queue
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        last_idx = len(text_queue) - 1
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == last_idx or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [batch_id for batch_id, _ in cur_batch]
                cur_batch_texts = [batch_text for _, batch_text in cur_batch]
                embeddings = self._get_text_embeddings(cur_batch_texts)
                result_ids.extend(cur_batch_ids)
                result_embeddings.extend(embeddings)
                # Start a fresh batch; without this, already-flushed texts
                # would be sent again (and duplicated) by the final flush.
                cur_batch = []

        # reset queue
        self._text_queue = []
        return result_ids, result_embeddings

    async def aget_queued_text_embeddings(
        self, text_queue: List[Tuple[str, str]]
    ) -> Tuple[List[str], List[List[float]]]:
        """Asynchronously get a list of text embeddings.

        Call async embedding API to get embeddings for all queued texts in parallel.
        Argument `text_queue` must be passed in to avoid updating it async.

        Texts are grouped into batches of at most ``embed_batch_size``; one
        coroutine is created per batch and all of them are awaited together.
        Each text's token count is added to the running total. This method
        does not read or modify the instance's own queue.

        Args:
            text_queue: ``(id, text)`` pairs to embed.

        Returns:
            A ``(ids, embeddings)`` pair; ids follow ``text_queue`` order and
            embeddings are the per-batch results flattened in batch order.
        """
        cur_batch: List[Tuple[str, str]] = []
        result_ids: List[str] = []
        result_embeddings: List[List[float]] = []
        embeddings_coroutines: List[Coroutine] = []
        last_idx = len(text_queue) - 1
        for idx, (text_id, text) in enumerate(text_queue):
            cur_batch.append((text_id, text))
            text_tokens_count = len(self._tokenizer(text))
            self._total_tokens_used += text_tokens_count
            if idx == last_idx or len(cur_batch) == self._embed_batch_size:
                # flush
                cur_batch_ids = [batch_id for batch_id, _ in cur_batch]
                cur_batch_texts = [batch_text for _, batch_text in cur_batch]
                embeddings_coroutines.append(
                    self._aget_text_embeddings(cur_batch_texts)
                )
                result_ids.extend(cur_batch_ids)
                # Start a fresh batch; without this, already-flushed texts
                # would be sent again (and duplicated) by the final flush.
                cur_batch = []

        # flatten the results of asyncio.gather, which is a list of embeddings lists
        result_embeddings = [
            embedding
            for embeddings in await asyncio.gather(*embeddings_coroutines)
            for embedding in embeddings
        ]
        return result_ids, result_embeddings

    def similarity(
        self,
        embedding1: EMB_TYPE,
        embedding2: EMB_TYPE,
        mode: SimilarityMode = SimilarityMode.DEFAULT,
    ) -> float:
        """Get embedding similarity.

        Thin wrapper that forwards to the module-level ``similarity``
        function with the same arguments.
        """
        return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode)

    @property
    def total_tokens_used(self) -> int:
        """Get the total tokens used so far."""
        return self._total_tokens_used

    @property
    def last_token_usage(self) -> int:
        """Get the last token usage.

        Returns 0 when no value has been set yet.
        """
        if self._last_token_usage is None:
            return 0
        return self._last_token_usage

    @last_token_usage.setter
    def last_token_usage(self, value: int) -> None:
        """Set the last token usage."""
        self._last_token_usage = value