Spaces:
Runtime error
Runtime error
"""Embedding utils for queries.""" | |
from typing import Callable, Dict, List, Optional, Tuple | |
from gpt_index.data_structs.data_structs import Node | |
from gpt_index.embeddings.base import similarity as default_similarity_fn | |
def get_top_k_embeddings( | |
query_embedding: List[float], | |
embeddings: List[List[float]], | |
similarity_fn: Optional[Callable[..., float]] = None, | |
similarity_top_k: Optional[int] = None, | |
embedding_ids: Optional[List] = None, | |
similarity_cutoff: Optional[float] = None, | |
) -> Tuple[List[float], List]: | |
"""Get top nodes by similarity to the query.""" | |
if embedding_ids is None: | |
embedding_ids = [i for i in range(len(embeddings))] | |
similarity_fn = similarity_fn or default_similarity_fn | |
similarities = [] | |
for emb in embeddings: | |
similarity = similarity_fn(query_embedding, emb) | |
similarities.append(similarity) | |
sorted_tups = sorted( | |
zip(similarities, embedding_ids), key=lambda x: x[0], reverse=True | |
) | |
if similarity_cutoff is not None: | |
sorted_tups = [tup for tup in sorted_tups if tup[0] > similarity_cutoff] | |
similarity_top_k = similarity_top_k or len(sorted_tups) | |
result_tups = sorted_tups[:similarity_top_k] | |
result_similarities = [s for s, _ in result_tups] | |
result_ids = [n for _, n in result_tups] | |
return result_similarities, result_ids | |
class SimilarityTracker: | |
"""Helper class to manage node similarities during lifecycle of a single query.""" | |
# TODO: smarter way to store this information | |
lookup: Dict[str, float] = {} | |
def _hash(self, node: Node) -> str: | |
"""Generate a unique key for each node.""" | |
# TODO: Better way to get unique identifier of a node | |
return str(abs(hash(node.get_text()))) | |
def add(self, node: Node, similarity: float) -> None: | |
"""Add a node and its similarity score.""" | |
node_hash = self._hash(node) | |
self.lookup[node_hash] = similarity | |
def find(self, node: Node) -> Optional[float]: | |
"""Find a node's similarity score.""" | |
node_hash = self._hash(node) | |
if node_hash not in self.lookup: | |
return None | |
return self.lookup[node_hash] | |
def get_zipped_nodes(self, nodes: List[Node]) -> List[Tuple[Node, Optional[float]]]: | |
"""Get a zipped list of nodes and their corresponding scores.""" | |
similarities = [self.find(node) for node in nodes] | |
return list(zip(nodes, similarities)) | |