# Author: yunusserhat
# Commit 894bc0c (verified): "Create APP"
import pandas as pd
import torch
from torch import nn
from models.networks.utils import UnormGPS
class Random(nn.Module):
    """Baseline that predicts uniformly random GPS coordinates.

    Every output component is drawn uniformly from [-1, 1) and the result is
    mapped through ``UnormGPS``.
    """

    def __init__(self, num_output):
        """Create the baseline.

        Args:
            num_output: number of values sampled per batch element
                (presumably 2, i.e. one per GPS coordinate -- TODO confirm
                against what ``UnormGPS`` expects).
        """
        super().__init__()
        self.num_output = num_output
        self.unorm = UnormGPS()

    def forward(self, x):
        """Predict random GPS coordinates for a batch.

        Args:
            x: a torch.Tensor of features, or a batch dict holding that tensor
                under the ``"img"`` key. Only its first dimension (batch size)
                and its device are used; its values are ignored.

        Returns:
            dict with key ``"gps"``: ``UnormGPS`` applied to a
            ``(batch, num_output)`` tensor of uniform samples on ``x.device``.
        """
        # Accept the same batch-dict input as the other baselines as well as a
        # bare tensor; a dict previously failed on ``.shape``.
        if isinstance(x, dict):
            x = x["img"]
        # torch.rand samples from [0, 1); rescale to [-1, 1).
        gps = torch.rand((x.shape[0], self.num_output), device=x.device) * 2 - 1
        return {"gps": self.unorm(gps)}
class RandomCoords(nn.Module):
    """Baseline that predicts GPS coordinates drawn from a fixed reference set.

    Each forward call picks, independently for every batch element and with
    replacement, one row of the coordinate table loaded from ``coords_path``
    and maps it through ``UnormGPS``.
    """

    def __init__(self, coords_path: str):
        """Load and scale the reference coordinates.

        Args:
            coords_path: path to a CSV file with ``latitude`` and ``longitude``
                columns (presumably in degrees, given the /90 and /180
                scaling below -- TODO confirm).

        Raises:
            KeyError: if either column is missing (raised by pandas).
            ValueError: if the CSV contains no coordinate rows.
        """
        super().__init__()
        table = pd.read_csv(coords_path)
        # Latitude is divided by 90 and longitude by 180, so values that lie in
        # the ranges [-90, 90] and [-180, 180] map into [-1, 1].
        latitudes = table["latitude"].to_numpy() / 90
        longitudes = table["longitude"].to_numpy() / 180
        if len(latitudes) == 0:
            # Fail at construction instead of on every forward call, where
            # torch.randint(0, 0, ...) would raise.
            raise ValueError(f"No coordinates found in {coords_path!r}")
        # Rows are [latitude, longitude]. This is a plain attribute (not a
        # registered buffer), so it stays on the CPU when the module is moved;
        # forward copies only the sampled rows to the input's device.
        # NOTE(review): pandas yields float64 here, so these samples are
        # float64 while ``Random`` produces float32 -- confirm that is intended.
        self.coordinates = torch.stack(
            [torch.tensor(latitudes), torch.tensor(longitudes)],
            dim=-1,
        )
        self.N = len(self.coordinates)
        self.unorm = UnormGPS()

    def forward(self, x):
        """Predict GPS coordinates by sampling the reference set.

        Args:
            x: a batch dict holding a feature tensor under the ``"img"`` key,
                or that tensor directly. Only its first dimension (batch size)
                and its device are used.

        Returns:
            dict with key ``"gps"``: ``UnormGPS`` applied to a ``(batch, 2)``
            tensor of sampled ``[latitude, longitude]`` rows on ``x.device``.
        """
        # A bare tensor previously failed on ``x["img"]``; accept both forms.
        if not torch.is_tensor(x):
            x = x["img"]
        # Uniformly choose one row per batch element, with replacement.
        idx = torch.randint(0, self.N, (x.shape[0],))
        return {"gps": self.unorm(self.coordinates[idx].to(x.device))}