from collections import Counter
from dataclasses import dataclass

import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import haversine_distances
from PIL import Image
import pandas as pd

from geoguessr_bot.guessr import AbstractGuessr
from geoguessr_bot.interfaces import Coordinate
from geoguessr_bot.retriever import AbstractImageEmbedder
from geoguessr_bot.retriever import Retriever

# Label scikit-learn's DBSCAN assigns to samples that belong to no cluster.
_DBSCAN_NOISE_LABEL = -1


def haversine_distance(x, y) -> float:
    """Compute the haversine distance between two coordinates.

    Both ``x`` and ``y`` are single ``[latitude, longitude]`` pairs expressed in
    radians, as required by ``sklearn.metrics.pairwise.haversine_distances``.
    The result is the angular distance on the unit sphere (radians); multiply
    by the sphere radius to obtain a length.
    """
    # haversine_distances works on 2D arrays, so wrap each point as a 1-row matrix
    # and unwrap the single entry of the resulting 1x1 distance matrix.
    return haversine_distances(np.array(x).reshape(1, -1), np.array(y).reshape(1, -1))[0][0]


def _largest_cluster_mask(labels: np.ndarray) -> np.ndarray:
    """Return a boolean mask selecting the members of the most populated cluster.

    Noise samples are ignored when ranking clusters. If DBSCAN found no cluster
    at all (every sample is noise), every sample is kept so that the caller
    still has candidates to choose from. Ties between equally populated
    clusters go to the one whose label is encountered first in ``labels``.
    """
    labels = np.asarray(labels)
    cluster_sizes = Counter(
        label for label in labels.tolist() if label != _DBSCAN_NOISE_LABEL
    )
    if not cluster_sizes:
        return np.ones(labels.shape[0], dtype=bool)
    biggest_cluster, _ = cluster_sizes.most_common(1)[0]
    return labels == biggest_cluster


@dataclass
class AverageNeighborsEmbedderGuessr(AbstractGuessr):
    """Guesses a coordinate using an Embedder and a retriever followed by NN.

    The query image is embedded, its nearest neighbors are retrieved, the
    neighbors' coordinates are clustered with DBSCAN (haversine metric), and
    the guess is the coordinate of the best-ranked neighbor inside the most
    populated cluster.
    """
    # Turns an image into an embedding vector.
    embedder: AbstractImageEmbedder
    # Returns the nearest stored images (and their distances) for an embedding.
    retriever: Retriever
    # CSV file with "path", "latitude" and "longitude" columns.
    metadata_path: str
    # Number of neighbors requested from the retriever.
    n_neighbors: int = 50
    # DBSCAN neighborhood radius, in radians because the metric is haversine
    # on the unit sphere (0.05 rad is roughly 320 km on Earth).
    dbscan_eps: float = 0.05

    def __post_init__(self):
        """Load metadata and prepare the clustering model."""
        metadata = pd.read_csv(self.metadata_path)
        # Index coordinates by file name only (last "/"-separated path component).
        self.image_to_coordinate = {
            image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
            for image, latitude, longitude in zip(
                metadata["path"], metadata["latitude"], metadata["longitude"]
            )
        }
        # DBSCAN is used to find the biggest cluster among the N neighbors, using Haversine.
        # NOTE: a callable metric is evaluated pair by pair in Python; acceptable here
        # because only ``n_neighbors`` points are clustered per guess.
        self.dbscan = DBSCAN(eps=self.dbscan_eps, metric=haversine_distance)

    def guess(self, image: np.ndarray) -> Coordinate:
        """Guess a coordinate from an image.

        Args:
            image: Image as an array accepted by ``PIL.Image.fromarray``.

        Returns:
            The coordinate of the retrieved neighbor with the smallest retrieval
            distance among the members of the most populated DBSCAN cluster
            (or among all neighbors when no cluster is found).
        """
        # Embed image; [None, :] adds a leading batch axis for the retriever.
        pil_image = Image.fromarray(image)
        image_embedding = self.embedder.embed(pil_image)[None, :]

        # Retrieve nearest neighbors; results are batched, keep the first (only) query.
        nearest_neighbors, distances = self.retriever.retrieve(image_embedding, self.n_neighbors)
        nearest_neighbors = nearest_neighbors[0]
        distances = distances[0]

        # Get coordinates of neighbors as an (n, 2) array of [latitude, longitude] in radians.
        neighbors_coordinates = [self.image_to_coordinate[nn].to_radians() for nn in nearest_neighbors]
        neighbors_coordinates = np.array([[nn.latitude, nn.longitude] for nn in neighbors_coordinates])

        # Use DBSCAN to find the biggest cluster and potentially remove outliers.
        clustering = self.dbscan.fit(neighbors_coordinates)
        in_biggest_cluster = _largest_cluster_mask(clustering.labels_)
        neighbors_coordinates = neighbors_coordinates[in_biggest_cluster]
        distances = distances[in_biggest_cluster]

        # Guess coordinate as the closest image among the cluster regarding retrieving distance
        guess_coordinate = neighbors_coordinates[np.argmin(distances)]
        guess_coordinate = Coordinate.from_radians(guess_coordinate[0], guess_coordinate[1])
        return guess_coordinate