File size: 1,413 Bytes
944c93a
bd2d69d
 
d78053e
bd2d69d
4388025
944c93a
4388025
 
 
 
 
1791df2
 
4388025
 
 
d78053e
 
 
 
 
 
 
 
 
 
 
4388025
 
 
 
 
989941c
d78053e
 
1791df2
4388025
1791df2
4388025
 
d78053e
4388025
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from dataclasses import dataclass

from PIL import Image
import pandas as pd

from geoguessr_bot.guessr import AbstractGuessr
from geoguessr_bot.interfaces import Coordinate
from geoguessr_bot.retriever import AbstractImageEmbedder
from geoguessr_bot.retriever import Retriever


@dataclass
class NearestNeighborEmbedderGuessr(AbstractGuessr):
    """Guess a coordinate by embedding an image and looking up its nearest neighbor.

    The query image is embedded with ``embedder``, ``retriever`` is asked for
    the closest stored items, and the coordinate recorded in the metadata CSV
    for the top hit is returned as the guess.

    Attributes:
        embedder: Produces an embedding from a PIL image via ``embed``.
        retriever: Returns the stored items nearest to a query embedding via
            ``retrieve``.
        metadata_path: Path to a CSV file with ``path``, ``latitude`` and
            ``longitude`` columns describing the indexed images.
    """
    embedder: AbstractImageEmbedder
    retriever: Retriever
    metadata_path: str

    def __post_init__(self):
        """Build the image-name -> Coordinate lookup table from the metadata CSV.

        Only the file-name part of each ``path`` entry (text after the last
        ``/``) is used as the key, so lookups do not depend on the directory
        the images were stored in.

        Raises:
            KeyError: If the CSV lacks a ``path``, ``latitude`` or
                ``longitude`` column.
        """
        metadata = pd.read_csv(self.metadata_path)
        # NOTE(review): rows that share a file name silently overwrite each
        # other here (last row wins) — confirm file names are unique.
        self.image_to_coordinate = {
            self._image_key(path): Coordinate(latitude=latitude, longitude=longitude)
            for path, latitude, longitude in zip(
                metadata["path"], metadata["latitude"], metadata["longitude"]
            )
        }

    @staticmethod
    def _image_key(path: str) -> str:
        """Return the lookup key for ``path``: its text after the last ``/``."""
        return path.split("/")[-1]

    def guess(self, image) -> Coordinate:
        """Guess the coordinate at which ``image`` was taken.

        Args:
            image: The query image, either a ``PIL.Image.Image`` or an
                array-like accepted by ``PIL.Image.fromarray``.

        Returns:
            The metadata coordinate of the retriever's top hit.

        Raises:
            KeyError: If the retrieved item has no entry in the metadata.
        """
        # Convert array input to PIL; a PIL image is used as-is rather than
        # being pushed through fromarray, which is meant for arrays.
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

        # Prepend a length-1 axis to the embedding; presumably the retriever
        # expects a batch of queries — TODO confirm against Retriever.
        image_embedding = self.embedder.embed(pil_image)[None, :]

        # NOTE(review): assumes retrieve() returns a nested sequence whose
        # [0][0][0] element identifies the top hit by image name — confirm.
        nearest_neighbors = self.retriever.retrieve(image_embedding)
        nearest_neighbor = nearest_neighbors[0][0][0]

        # Keys were stored as bare file names; normalise a returned path the
        # same way so a full path still resolves. Stored keys never contain
        # "/", so this only fixes lookups that would otherwise fail.
        if isinstance(nearest_neighbor, str):
            nearest_neighbor = self._image_key(nearest_neighbor)

        return self.image_to_coordinate[nearest_neighbor]