from typing import Dict, List, Tuple

import faiss
import numpy as np


class Retriever:
    """Exact nearest-neighbour search over a set of named embedding vectors.

    Embeddings are loaded from a ``.npy`` file holding a pickled
    ``{image_name: vector}`` dict, stacked into one ``(n_images, dim)``
    float32 matrix and added to a brute-force ``faiss.IndexFlatL2``.
    Row ``i`` of the index corresponds to ``index_to_image[i]``.
    """

    def __init__(self, embeddings_path: str):
        """Load embeddings from *embeddings_path* and build the Faiss index.

        Args:
            embeddings_path: Path to a ``.npy`` file produced by saving a
                ``{image_name: vector}`` dict with ``np.save``.

        Raises:
            ValueError: If the file holds no embeddings, or the vectors do
                not stack into a 2-D ``(n_images, dim)`` matrix.
        """
        named_embeddings: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
        if not named_embeddings:
            raise ValueError(f"No embeddings found in {embeddings_path!r}")

        # Keep track of image names. Both maps follow the dict's iteration
        # order, which is also the row order of the stacked matrix below,
        # so Faiss row ids translate directly back to image names.
        self.image_to_index: Dict[str, int] = {
            image_name: i for i, image_name in enumerate(named_embeddings)
        }
        self.index_to_image: Dict[int, str] = {
            i: image_name for i, image_name in enumerate(named_embeddings)
        }

        # Faiss operates on C-contiguous float32 matrices; convert explicitly
        # instead of relying on the Python binding (older versions reject
        # float64 input outright).
        self.embeddings: np.ndarray = np.ascontiguousarray(
            np.array(list(named_embeddings.values())), dtype=np.float32
        )
        if self.embeddings.ndim != 2:
            raise ValueError(
                "Embeddings must stack into a 2-D (n_images, dim) matrix, "
                f"got shape {self.embeddings.shape}"
            )

        # Build Faiss index (exact, brute-force L2 search).
        self.dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dim)
        self.index.add(self.embeddings)

    @staticmethod
    def load_embeddings(embeddings_path: str) -> Dict[str, np.ndarray]:
        """Load a ``{image_name: vector}`` dict from a .npy file.

        The file is expected to contain a 0-d object array wrapping the dict
        (what ``np.save(path, some_dict)`` writes); ``.item()`` unwraps it.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects and can run
        # code embedded in the file. Only load embedding files from trusted
        # sources.
        return np.load(embeddings_path, allow_pickle=True).item()

    def retrieve(
        self, queries: np.ndarray, n_neighbors: int = 5
    ) -> Tuple[List[List[str]], np.ndarray]:
        """Retrieve the nearest indexed images for each query vector.

        Args:
            queries: Query vectors of shape ``(n_queries, dim)``; a single
                1-D vector of length ``dim`` is treated as one query.
            n_neighbors: Number of neighbours to return per query. Values
                larger than the number of indexed images are clamped to it.

        Returns:
            A pair ``(names, distances)``: ``names[q]`` lists image names
            for query ``q`` from nearest to farthest, and ``distances`` is
            the ``(n_queries, k)`` array returned by Faiss
            (``IndexFlatL2`` reports squared L2 distances).

        Raises:
            ValueError: If ``n_neighbors`` is smaller than 1.
        """
        if n_neighbors < 1:
            raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")

        queries = np.ascontiguousarray(np.atleast_2d(queries), dtype=np.float32)

        # Asking Faiss for more neighbours than there are indexed vectors
        # pads the result with label -1, which is not a key of
        # index_to_image; clamp k so every returned label is a real row.
        k = min(n_neighbors, self.index.ntotal)
        distances, indexes = self.index.search(queries, k)

        names = [[self.index_to_image[int(i)] for i in row] for row in indexes]
        return names, distances