Spaces:
Sleeping
Sleeping
File size: 13,513 Bytes
9ff98d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 |
import matplotlib.pyplot as plt
import torch
import numpy as np
import math
import datetime
#from h3.unstable import vect
import h3
class CoordEncoder:
    """Encode (lon, lat) locations as model input features.

    Supported ``input_enc`` values:
        'none'        -- the (normalized) coordinates multiplied by [180, 90]
        'sin_cos'     -- sinusoidal features from ``encode_loc``
        'env'         -- values of ``raster`` bilinearly interpolated at each location
        'sin_cos_env' -- 'sin_cos' features followed by 'env' features
        'satclip'     -- embeddings from a SatCLIP checkpoint, loaded on first use
    """

    def __init__(self, input_enc, raster=None, input_dim=0):
        """
        Args:
            input_enc: one of the encodings listed on the class.
            raster: tensor sampled by 'env' / 'sin_cos_env' (indexed as H x W x C
                by ``bilinear_interpolate``); unused by the other encodings.
            input_dim: size argument forwarded to ``encode_loc`` / ``encode_loc_fast``.
        """
        self.input_enc = input_enc
        self.raster = raster
        self.input_dim = input_dim

    def _raster_on(self, device):
        """Return ``self.raster`` moved to ``device``, or None if no raster was given."""
        # Passing None through lets bilinear_interpolate's own `data is not None`
        # check report a missing raster.
        return None if self.raster is None else self.raster.to(device)

    def encode(self, locs, normalize=True):
        """Encode a batch of locations.

        Args:
            locs: N x 2 tensor of [lon, lat]. In degrees (lon in [-180, 180],
                lat in [-90, 90]) when ``normalize`` is True, otherwise already
                scaled to [-1, 1].
            normalize: whether to scale ``locs`` with ``normalize_coords`` first.

        Returns:
            Feature tensor with one row per location.

        Raises:
            NotImplementedError: for an unknown ``input_enc``.
        """
        if normalize:
            locs = normalize_coords(locs)
        if self.input_enc == 'none':
            loc_feats = locs * torch.tensor([[180.0, 90.0]], device=locs.device)
        elif self.input_enc == 'sin_cos':  # sinusoidal encoding
            loc_feats = encode_loc(locs, input_dim=self.input_dim)
        elif self.input_enc == 'env':  # bioclim variables
            # Put the raster on the input's device (as 'sin_cos_env' already did),
            # so GPU inputs do not index a CPU raster.
            loc_feats = bilinear_interpolate(locs, self._raster_on(locs.device))
        elif self.input_enc == 'sin_cos_env':  # sinusoidal encoding & bioclim variables
            loc_feats = encode_loc(locs, input_dim=self.input_dim)
            context_feats = bilinear_interpolate(locs, self._raster_on(locs.device))
            loc_feats = torch.cat((loc_feats, context_feats), 1)
        elif self.input_enc == 'satclip':  # SatCLIP embedding
            if not hasattr(self, 'model'):
                # Lazy import: SatCLIP is only needed for this encoding.
                import sys
                sys.path.append('./satclip/satclip')
                from satclip.satclip.load import get_satclip
                self.model = get_satclip('satclip/satclip-vit16-l10.ckpt', device="cpu")
                self.model.eval()
            self.model = self.model.to(locs.device)
            # Scale back to degrees before calling the model.
            locs = locs*torch.tensor([[180.0, 90.0]], device=locs.device)
            # Run in chunks (presumably to bound peak memory); 256 features per location.
            max_batch = 1000000
            loc_feats = torch.empty(locs.shape[0], 256, device=locs.device)
            with torch.no_grad():
                for i in range(0, locs.shape[0], max_batch):
                    loc_feats[i:i+max_batch] = self.model(locs[i:i+max_batch].double()).float()
        else:
            raise NotImplementedError('Unknown input encoding.')
        return loc_feats

    def encode_fast(self, loc: list[float], normalize=True):
        """Encode a single location given as plain floats.

        Args:
            loc: [lon, lat]; any extra elements are passed through unscaled.
            normalize: if True, ``loc`` is in degrees and is scaled to [-1, 1]
                here, mirroring ``normalize_coords``.

        Returns:
            List of floats from ``encode_loc_fast``.

        Raises:
            NotImplementedError: for any encoding other than 'sin_cos'.
        """
        if normalize:
            # Scale here instead of asserting `not normalize`. With the assert, the
            # default call always failed, and under `python -O` the assert vanished
            # and degrees were encoded unscaled.
            loc = [loc[0] / 180.0, loc[1] / 90.0, *loc[2:]]
        if self.input_enc == 'sin_cos':
            loc_feats = encode_loc_fast(loc, input_dim=self.input_dim)
        else:
            raise NotImplementedError('Unknown input encoding.')
        return loc_feats
class TimeEncoder:
    """Encode (time, width) intervals as 3 features.

    The interval centre t is placed on the unit circle via sin/cos of 2*pi*t, and
    ``2 * width - 1`` is appended:
        'conical'     -- the circle is scaled by ``1 - width``
        'cylindrical' -- the circle is left unscaled
    Both time and width are assumed to lie in [0, 1].
    """

    def __init__(self, input_enc='conical'):
        self.input_enc = input_enc

    def encode(self, intervals):
        """Encode a tensor of [t_center, t_width] rows into N x 3 features.

        Raises:
            NotImplementedError: for an unknown ``input_enc`` (this used to fall
                through and raise UnboundLocalError).
        """
        t_center = intervals[:, :1]
        t_width = intervals[:, 1:]
        angle = 2 * torch.pi * t_center
        if self.input_enc == 'conical':
            radius = 1 - t_width
            return torch.cat([radius * torch.sin(angle), radius * torch.cos(angle),
                              2 * t_width - 1], dim=1)
        if self.input_enc == 'cylindrical':
            return torch.cat([torch.sin(angle), torch.cos(angle), 2 * t_width - 1], dim=1)
        raise NotImplementedError('Unknown time encoding.')

    def encode_fast(self, intervals):
        """Encode one (t_center, t_width) pair of floats into a length-3 tensor.

        Raises:
            NotImplementedError: for an unknown ``input_enc``.
        """
        t_center, t_width = intervals
        angle = 2 * math.pi * t_center
        if self.input_enc == 'conical':
            radius = 1 - t_width
            return torch.tensor([radius * math.sin(angle), radius * math.cos(angle),
                                 2 * t_width - 1])
        if self.input_enc == 'cylindrical':
            return torch.tensor([math.sin(angle), math.cos(angle), 2 * t_width - 1])
        raise NotImplementedError('Unknown time encoding.')
def normalize_coords(locs):
    """Scale [lon, lat] coordinates from degrees to [-1, 1].

    Column 0 (longitude, [-180, 180]) is divided by 180 and column 1 (latitude,
    [-90, 90]) by 90. Any further columns are returned unchanged.

    The input is copied first so the caller's tensor/array is left untouched. The
    previous in-place version rescaled the argument itself, so normalizing the
    same tensor twice silently divided it twice.

    Args:
        locs: torch tensor or numpy array with rows [lon, lat, ...].

    Returns:
        A new tensor (for tensor input) or numpy array with columns 0 and 1 rescaled.
    """
    if isinstance(locs, torch.Tensor):
        locs = locs.clone()
    else:
        locs = np.array(locs, copy=True)
    locs[:, 0] /= 180.0
    locs[:, 1] /= 90.0
    return locs
def encode_loc(loc_ip, concat_dim=1, input_dim=0):
    """Multi-frequency sinusoidal encoding of normalized coordinates.

    For each octave k = 0 .. input_dim // 4 - 1, the pieces
    sin(pi * 2**k * loc_ip) and cos(pi * 2**k * loc_ip) are appended in that
    order, and everything is concatenated along ``concat_dim``. For an N x 2
    (lon, lat) input with ``concat_dim=1``, the result is N x (4 * (input_dim // 4)).

    Args:
        loc_ip: tensor of locations, assumed to lie in [-1, 1].
        concat_dim: dimension along which the pieces are joined.
        input_dim: sets the number of octaves (input_dim // 4). Values below 4
            leave nothing to concatenate, and torch.cat raises.
    """
    pieces = []
    for octave in range(input_dim // 4):
        scaled = math.pi * (2 ** octave) * loc_ip
        pieces.extend((torch.sin(scaled), torch.cos(scaled)))
    return torch.cat(pieces, concat_dim)
def encode_loc_fast(loc_ip: list[float], input_dim=0):
    """Pure-Python sinusoidal encoding of one normalized location.

    Features are laid out in the same order as one row of ``encode_loc``: for
    octave k, first sin(pi * 2**k * x) for every coordinate x, then cos for every
    coordinate, then the next octave. Output is truncated to input_dim // 2 values.

    NOTE(review): for a 2-element ``loc_ip`` this returns input_dim // 2 values,
    while ``encode_loc`` returns 4 * (input_dim // 4) columns for N x 2 input. The
    halving was labelled "needed to make it compatible with encode_loc", so callers
    presumably account for it. Confirm.

    Args:
        loc_ip: coordinates (lon, lat) assumed to lie in [-1, 1].
        input_dim: twice the number of features returned.
    """
    n_coords = len(loc_ip)
    feats = []
    for idx in range(input_dim // 2):
        octave, slot = divmod(idx, 2 * n_coords)
        trig = math.sin if slot < n_coords else math.cos
        feats.append(trig(math.pi * (2 ** octave) * loc_ip[slot % n_coords]))
    return feats
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
    """Sample a raster at normalized locations with bilinear interpolation.

    Args:
        loc_ip: N x 2 tensor, each row [lon, lat] in [-1, 1]. Not modified.
        data: H x W x C tensor (height x width x channel). Row 0 corresponds to
            lat = +1 (top), column 0 to lon = -1 (left).
        remove_nans_raster: if True, NaNs in ``data`` are overwritten with 0.0
            IN PLACE (0 is the mean after normalization), so the caller's raster
            is modified.

    Returns:
        N x C tensor of interpolated values.

    Raises:
        AssertionError: if ``data`` is None or any location is NaN.
    """
    assert data is not None

    # Map [-1, 1] -> [0, 1]. Latitude is flipped because it runs from +90 on the
    # top row down to -90, whereas longitude runs from -180 (left) to +180 (right).
    pix = (loc_ip.clone() + 1) / 2.0
    pix[:, 1] = 1 - pix[:, 1]
    assert not torch.any(torch.isnan(pix))

    if remove_nans_raster:
        data[torch.isnan(data)] = 0.0

    # Continuous pixel coordinates: x along the width, y along the height.
    height, width = data.shape[0], data.shape[1]
    pix[:, 0] *= (width - 1)
    pix[:, 1] *= (height - 1)

    corner = torch.floor(pix)
    x0 = corner[:, 0].long()
    y0 = corner[:, 1].long()
    # Neighbours one pixel right / down, clamped so the last column/row pairs with itself.
    x1 = (x0 + 1).clamp(max=width - 1)
    y1 = (y0 + 1).clamp(max=height - 1)

    frac = pix - corner
    dx = frac[:, 0].unsqueeze(1)
    dy = frac[:, 1].unsqueeze(1)

    upper = data[y0, x0, :]*(1-dx)*(1-dy) + data[y0, x1, :]*dx*(1-dy)
    return upper + data[y1, x0, :]*(1-dx)*dy + data[y1, x1, :]*dx*dy
def rand_samples(batch_size, device, rand_type='uniform'):
    """Draw random background locations as a (batch_size, 2) tensor of [lon, lat].

    Values are drawn with ``torch.rand`` and then moved with ``.to(device)``.

    Args:
        batch_size: number of locations to draw.
        device: device the result is moved to.
        rand_type: 'uniform' -- uniform over the normalized square [-1, 1)^2;
            'spherical' -- area-uniform on the sphere (polar angle drawn as
            acos(2u - 1)), then mapped to normalized lon/lat.

    Raises:
        NotImplementedError: for an unknown ``rand_type`` (this used to fall
            through and raise UnboundLocalError).
    """
    if rand_type == 'spherical':
        rand_loc = torch.rand(batch_size, 2).to(device)
        theta1 = 2.0*math.pi*rand_loc[:, 0]
        theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
        lat = 1.0 - 2.0*theta2/math.pi
        lon = (theta1/math.pi) - 1.0
        rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
    elif rand_type == 'uniform':
        rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0
    else:
        raise NotImplementedError('Unknown random sampling type.')
    return rand_loc
def get_time_stamp():
    """Return the current local time as 'YYYY-MM-DD-HH-MM-SS'.

    Built from ``str(datetime.datetime.now())``; any fractional seconds are
    discarded.
    """
    day, clock = str(datetime.datetime.now()).split(' ')
    hour, minute, seconds = clock.split(':')
    whole_seconds = seconds.split('.')[0]
    return '{}-{}-{}-{}'.format(day, hour, minute, whole_seconds)
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
    """Build an evenly spaced grid of [lon, lat] coordinates in degrees.

    Columns run from lon = -180 to 180 and rows from lat = 90 down to -90. The
    values are stored as float32.

    Args:
        grid_size: (rows, cols) of the grid.
        split_ids: optional rows x cols array of split labels.
        split_of_interest: label to select from ``split_ids``.

    Returns:
        (rows * cols) x 2 array of all grid points when either ``split_ids`` or
        ``split_of_interest`` is None. Otherwise an N_subset x 2 array of the points
        whose label equals ``split_of_interest``, in row-major order.
    """
    n_rows, n_cols = grid_size[0], grid_size[1]
    lon_mesh, lat_mesh = np.meshgrid(np.linspace(-180, 180, n_cols),
                                     np.linspace(90, -90, n_rows))
    feats = np.stack((lon_mesh, lat_mesh), axis=-1).astype(np.float32)
    if split_ids is not None and split_of_interest is not None:
        ind_y, ind_x = np.where(split_ids == split_of_interest)
        return feats[ind_y, ind_x, :]
    return feats.reshape(n_rows * n_cols, 2)
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
    """Create a checkerboard train/test split over a raster's pixel grid.

    Labels: 0 = invalid, 1 = train, 2 = test. The grid is tiled with
    ``cell_size`` x ``cell_size`` squares alternating between train and test.
    Every other band of rows is offset by one square. The labels are then
    multiplied by ``mask``, so pixels where the mask is 0 become 0.

    Args:
        raster: only its first two dimensions (H, W) are used.
        mask: array broadcastable to H x W (presumably a 0/1 validity mask).
        train_amt: fraction of train pixels to keep. When < 1, a random subset of
            train pixels, drawn with the global ``np.random`` state, is set to 0.
        cell_size: side length of each checkerboard square, in pixels.

    Returns:
        Float H x W array of split labels.
    """
    split_ids = np.ones((raster.shape[0], raster.shape[1]))
    for band, row in enumerate(np.arange(0, split_ids.shape[0], cell_size)):
        col_start = 0 if band % 2 == 0 else cell_size
        for col in np.arange(col_start, split_ids.shape[1], cell_size*2):
            split_ids[row:row+cell_size, col:col+cell_size] = 2
    split_ids = split_ids*mask
    if train_amt < 1.0:
        train_rows, train_cols = np.where(split_ids == 1)
        n_drop = int(len(train_rows)*(1.0-train_amt))
        drop = np.random.choice(len(train_rows), n_drop, replace=False)
        split_ids[train_rows[drop], train_cols[drop]] = 0
    return split_ids
def average_precision_score_faster(y_true, y_scores):
    """Average precision of ``y_scores`` against binary labels ``y_true`` (numpy).

    Intended as a drop-in replacement for sklearn's average_precision_score;
    results are comparable up to floating point differences. Computes
    AP = sum_k (R_k - R_{k-1}) * P_k over the ranking by descending score, with
    R_0 = 0.

    NOTE(review): tied scores are ordered by argsort rather than grouped into a
    single threshold. Confirm this is acceptable for the intended comparison.
    """
    ranking = np.argsort(y_scores)[::-1]
    hits = y_true[ranking]
    tp = np.cumsum(hits)
    fp = np.cumsum(1.0 - hits)
    recall = tp / y_true.sum()
    # tp + fp = number of items ranked so far; the floor only guards against 0.
    precision = tp / np.maximum(tp + fp, np.finfo(np.float32).eps)
    recall_gain = np.diff(np.hstack((0, recall)))
    return (recall_gain*precision).sum()
# NOTE: non-float labels (including bool) are now cast to float inside the function
# below, which resolves the earlier TODO about casting them to 1.0 / 0.0.
# TODO(review): an older note read "y_true are the same as the ones"; intent unclear,
# confirm whether anything else was meant.
def average_precision_score_fasterer(y_true, y_scores):
    """Torch version of ``average_precision_score_faster``.

    Intended as a drop-in replacement for sklearn's average_precision_score;
    results are comparable up to floating point differences. Computes
    AP = sum_k (R_k - R_{k-1}) * P_k over the ranking by descending score, with
    R_0 = 0.

    Args:
        y_true: 1-D tensor of binary labels (bool, integer or floating point).
        y_scores: 1-D tensor of scores, same length and device as ``y_true``.

    Returns:
        0-d tensor holding the average precision.
    """
    # `1.0 - bool_tensor` is not supported by torch, so non-float labels are cast
    # to the default float dtype first. Float labels keep their own dtype.
    if not torch.is_floating_point(y_true):
        y_true = y_true.to(torch.get_default_dtype())
    num_positives = y_true.sum()
    order = torch.argsort(y_scores, descending=True)
    ranked = y_true[order]
    true_pos_c = torch.cumsum(ranked, dim=0)
    false_pos_c = torch.cumsum(1.0 - ranked, dim=0)
    recall = true_pos_c / num_positives
    # true_pos_c + false_pos_c = number of items ranked so far; the clip guards
    # against a zero denominator.
    precision = true_pos_c / (true_pos_c + false_pos_c).clip(min=np.finfo(np.float32).eps)
    # Recall gained at each rank: successive differences of [0, recall...].
    recall_e = torch.cat([torch.zeros(1, dtype=recall.dtype, device=recall.device), recall])
    recall_gain = recall_e[1:] - recall_e[:-1]
    return (recall_gain*precision).sum()
class DataPDFH3:
    """Cumulative count table over (row, H3 cell) with windowed look-ups.

    The tensor loaded from ``data`` is cumulatively summed along dim 0 and given a
    leading zero row, so ``self.data[k] - self.data[j]`` is the sum of rows
    j .. k-1 of the loaded tensor. The indexing in ``_sample`` (``365 * time``,
    clamps at 364) suggests dim 0 is day-of-year with 365 rows, and the last dim
    holds one column per entry of ``inds_h3.pt``. TODO confirm.
    """
    def __init__(self, data='data_pdf_h3.pt', device='cpu'):
        super(DataPDFH3, self).__init__()
        # Running totals along dim 0.
        self.data = torch.cumsum(torch.load(data, map_location=device), dim=0)
        # Prepend a zero row so self.data[k] is the sum of the first k loaded rows.
        self.data = torch.cat([torch.zeros_like(self.data[:1]), self.data], dim=0)
        # NOTE(review): loaded without map_location, unlike `data` above.
        inds = torch.load('inds_h3.pt')
        # Keep bits 30..51 of each 64-bit index (4194303 == 2**22 - 1). Presumably a
        # compact key for resolution-5 H3 cells; confirm against the H3 bit layout.
        inds = ((inds >> 30) & 4194303)
        # Lookup table: compact key -> column index; -1 marks keys not in the table.
        self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
        self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)
        # Cumulative totals summed over the last dim (all cells), per row.
        self.cum_counts = self.data.sum(dim=-1)
    def _sample(self, pos, time, noise_level):
        """Return (counts, totals) summed over a row window for each position.

        Shapes are inferred from the indexing below. TODO confirm:
            pos: N x 2 tensor of [lon, lat]. It is multiplied by 180 / 90 before the
                H3 lookup, so it is presumably normalized to [-1, 1].
            time: length-N tensor, window centre (presumably a fraction of the year).
            noise_level: length-N tensor, window width in the same units as time.

        Returns:
            counts: window sums of self.data for each position's cell, set to 0
                for positions whose cell is not in the table.
            totals: the same window sums of self.cum_counts (all cells).
        """
        pos = pos.cpu()
        time = time.cpu()
        noise_level = noise_level.cpu()
        # Window bounds as integer row indices; .int() truncates toward zero.
        t_low = (365*(time - 0.5*(noise_level))).int()
        t_high = (365*(time + 0.5*(noise_level))).int()
        # Windows starting before row 0 are shifted forward by 365. t_high must be
        # updated first: both masks are taken from the still-unshifted t_low.
        t_high[t_low < 0] += 365
        t_low[t_low < 0] += 365
        # NOTE(review): whole tensors are passed to h3.latlng_to_cell and the result
        # is used with .astype, so this relies on a vectorized, numpy-returning h3
        # API (cf. the commented-out `h3.unstable.vect` import). Confirm which h3 is used.
        pos_ind = torch.from_numpy((h3.latlng_to_cell(90*pos[:, 1], 180*pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
        # Compact key -> column index (-1 if the cell is not in the table).
        pos_ind = self.ind_map[pos_ind]
        # Sum of rows t_low .. min(t_high, 364).
        counts = self.data[t_high.clamp(max=364)+1, pos_ind] - self.data[t_low, pos_ind]
        # Windows past row 364 wrap to the start: add rows 0 .. min(t_high - 365, 364),
        # and again rows 0 .. t_high - 730 when t_high > 729.
        counts[t_high > 364] += self.data[(t_high[t_high > 364] - 365).clamp(max=364) + 1, pos_ind[t_high > 364]]
        counts[t_high > 729] += self.data[(t_high[t_high > 729] - 730).clamp(max=364) + 1, pos_ind[t_high > 729]]
        # Same window arithmetic on the all-cell totals.
        totals = self.cum_counts[t_high.clamp(max=364)+1] - self.cum_counts[t_low]
        totals[t_high > 364] += self.cum_counts[(t_high[t_high > 364] - 365).clamp(max=364) + 1]
        totals[t_high > 729] += self.cum_counts[(t_high[t_high > 729] - 730).clamp(max=364) + 1]
        # pos_ind == -1 indexed the last column above; zero those counts instead.
        counts[pos_ind < 0] = 0
        return counts, totals
    def sample(self, pos, time, noise_level):
        """Fraction of the window's total counts that fall in each position's cell.

        NOTE(review): positions whose window total is 0 give nan (0/0) or inf.
        """
        counts, totals = self._sample(pos, time, noise_level)
        return counts/totals
    def sample_log(self, pos, time, noise_level, eps=1e-2):
        """Return log(counts) - log(totals + eps) for each position.

        NOTE(review): eps is added only to totals, so counts == 0 (including
        positions outside the table) yields -inf. Confirm this is intended.
        """
        counts, totals = self._sample(pos, time, noise_level)
        return torch.log(counts)-torch.log(totals+eps)
class LowRankModel:
    """Per-cell probabilities reconstructed from two low-rank factors.

    ``self.data`` holds the normalized product of the factors loaded from ``data``.
    Its columns are looked up through ``self.ind_map`` (built from ``inds_h3.pt``),
    and ``sample`` turns the columns for a batch of positions into distributions
    over rows, using the prior ``self.pc``. Rows presumably index classes (cf. the
    ``class_counts_locs_h3.pt`` file name). TODO confirm.
    """

    def __init__(self, data='nmf_256.pt', device='cpu'):
        super(LowRankModel, self).__init__()
        # Normalization mode. It is hard-coded, so only the `dim == -1` branch runs.
        # NOTE(review): the dim 0/1 branches pass `x1 @ x2` straight to torch.softmax,
        # whereas the live branch wraps the product in torch.from_numpy (numpy
        # factors). The dead branches are presumably stale.
        dim = -1
        x1, x2 = torch.load(data, map_location=device)
        m = torch.load('class_counts_locs_h3.pt').float()
        # Column order by total count, largest first (no columns are dropped).
        chosen_inds = m.sum(dim=0).to_dense().sort(descending=True).indices
        if dim == 0:
            n = m.to_dense()[:, chosen_inds].sum(dim=dim, keepdim=True)
            self.data = n*torch.softmax(x1 @ x2, dim=dim)
            self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
        elif dim == 1:
            self.data = torch.softmax(x1 @ x2, dim=dim)
        elif dim == -1:
            self.data = torch.from_numpy(x1 @ x2)
            # Each row now sums to 1 across the columns.
            self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
        # NOTE(review): ind_map below indexes columns in chosen_inds order, but the
        # factor product is not permuted here. Presumably the factorization was fit
        # on the already-sorted matrix; confirm.
        m = m.to_dense()[:, chosen_inds]
        # Prior over rows: each row's share of the total count.
        self.pc = m.sum(dim=1, keepdim=True) / m.sum()
        inds = torch.load('inds_h3.pt')[chosen_inds]
        # Keep bits 30..51 of each 64-bit index (4194303 == 2**22 - 1) as a compact key.
        inds = ((inds >> 30) & 4194303)
        # Lookup table: compact key -> column index; -1 marks keys not in the table.
        self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
        self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)

    def sample(self, pos):
        """Return a (rows x N) tensor with one distribution over rows per position.

        Args:
            pos: N x 2 tensor. Column 1 is passed to h3 as latitude and column 0 as
                longitude, without scaling. NOTE(review): DataPDFH3._sample first
                multiplies by 90 / 180, so this presumably expects degrees rather
                than normalized coordinates; confirm with callers.

        Returns:
            Column j is ``self.data[:, cell_j] * self.pc`` renormalized to sum to 1.
            Positions whose cell is not in the table get a uniform column.
        """
        pos = pos.cpu()
        # NOTE(review): relies on a vectorized, numpy-returning h3 API (whole tensors
        # in, `.astype` on the result). Confirm which h3 build is used.
        pos_ind = torch.from_numpy((h3.latlng_to_cell(pos[:, 1], pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
        pos_ind = self.ind_map[pos_ind]
        # Advanced indexing copies, so the in-place scaling leaves self.data intact.
        out = self.data[:, pos_ind]
        out *= self.pc
        out = out/torch.sum(out, dim=0, keepdim=True)
        # pos_ind == -1 picked the last column above; replace with a uniform distribution.
        out[:, pos_ind < 0] = 1.0/out.shape[0]
        return out