File size: 2,468 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Embedding utils for queries."""

from typing import Callable, Dict, List, Optional, Tuple

from gpt_index.data_structs.data_structs import Node
from gpt_index.embeddings.base import similarity as default_similarity_fn


def get_top_k_embeddings(
    query_embedding: List[float],
    embeddings: List[List[float]],
    similarity_fn: Optional[Callable[..., float]] = None,
    similarity_top_k: Optional[int] = None,
    embedding_ids: Optional[List] = None,
    similarity_cutoff: Optional[float] = None,
) -> Tuple[List[float], List]:
    """Get top nodes by similarity to the query.

    Args:
        query_embedding: Embedding of the query.
        embeddings: Candidate embeddings to score against the query.
        similarity_fn: Callable ``(query_embedding, embedding) -> float``.
            Falls back to ``default_similarity_fn`` when not given.
        similarity_top_k: Maximum number of results to return. ``None``
            means no limit; ``0`` returns no results.
        embedding_ids: Identifiers aligned one-to-one with ``embeddings``.
            Defaults to the positional indices ``0 .. len(embeddings) - 1``.
        similarity_cutoff: If given, only results whose similarity is
            strictly greater than this value are kept.

    Returns:
        A ``(similarities, ids)`` pair, both ordered by descending
        similarity. Equal scores keep the order they had in ``embeddings``.

    Raises:
        ValueError: If ``embedding_ids`` is given and its length differs
            from ``embeddings``.
    """
    if embedding_ids is None:
        embedding_ids = list(range(len(embeddings)))
    elif len(embedding_ids) != len(embeddings):
        # zip() below would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"embedding_ids has {len(embedding_ids)} entries but "
            f"embeddings has {len(embeddings)}"
        )

    similarity_fn = similarity_fn or default_similarity_fn

    similarities = [similarity_fn(query_embedding, emb) for emb in embeddings]

    # sorted() is stable even with reverse=True, so ties keep input order.
    sorted_tups = sorted(
        zip(similarities, embedding_ids), key=lambda x: x[0], reverse=True
    )

    if similarity_cutoff is not None:
        sorted_tups = [tup for tup in sorted_tups if tup[0] > similarity_cutoff]

    # Explicit None check: `similarity_top_k or ...` would treat 0 as "no limit".
    if similarity_top_k is not None:
        sorted_tups = sorted_tups[:similarity_top_k]

    result_similarities = [s for s, _ in sorted_tups]
    result_ids = [n for _, n in sorted_tups]

    return result_similarities, result_ids


class SimilarityTracker:
    """Helper class to manage node similarities during lifecycle of a single query.

    Scores are keyed by ``_hash(node)``, which is derived only from the
    node's text, so two nodes with identical text share one score entry.
    """

    # Maps a node key produced by ``_hash`` to that node's similarity score.
    # TODO: smarter way to store this information
    lookup: Dict[str, float]

    def __init__(self) -> None:
        """Create a tracker with its own empty score table."""
        # Assigned per instance on purpose: a class-level ``= {}`` default is a
        # single dict shared by every tracker, which would leak scores from one
        # query into every other tracker for the life of the process.
        self.lookup = {}

    def _hash(self, node: Node) -> str:
        """Generate a lookup key for ``node`` from its text.

        NOTE(review): ``hash()`` of a ``str`` is salted per process
        (PYTHONHASHSEED), so keys are only stable within one process, and
        ``abs()`` maps hashes that differ only in sign to the same key.
        """
        # TODO: Better way to get unique identifier of a node
        return str(abs(hash(node.get_text())))

    def add(self, node: Node, similarity: float) -> None:
        """Record ``similarity`` for ``node``, overwriting any earlier score."""
        self.lookup[self._hash(node)] = similarity

    def find(self, node: Node) -> Optional[float]:
        """Return the score recorded for ``node``, or ``None`` if there is none."""
        return self.lookup.get(self._hash(node))

    def get_zipped_nodes(self, nodes: List[Node]) -> List[Tuple[Node, Optional[float]]]:
        """Pair each node with its recorded score (``None`` if never added)."""
        return [(node, self.find(node)) for node in nodes]