File size: 13,513 Bytes
9ff98d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import matplotlib.pyplot as plt
import torch
import numpy as np
import math
import datetime
#from h3.unstable import vect
import h3

class CoordEncoder:
    """Encodes (lon, lat) coordinates into model input features.

    Supported encodings: 'none' (raw degrees), 'sin_cos' (sinusoidal
    positional encoding), 'env' (bilinear lookup into a bioclim raster),
    'sin_cos_env' (both concatenated), and 'satclip' (pretrained SatCLIP
    location embeddings, loaded lazily on first use).
    """

    def __init__(self, input_enc, raster=None, input_dim=0):
        # input_enc: one of the encoding names accepted by encode()
        # raster: H x W x C environmental raster; required for 'env' variants
        # input_dim: output dimensionality for the sinusoidal encodings
        self.input_enc = input_enc
        self.raster = raster
        self.input_dim = input_dim

    def encode(self, locs, normalize=True):
        """Return a feature tensor for an (N, 2) tensor of [lon, lat].

        With normalize=True, expects lon in [-180, 180] and lat in [-90, 90];
        otherwise inputs must already be scaled to [-1, 1].
        NOTE(review): normalize_coords mutates `locs` in place, so the
        caller's tensor is rescaled as a side effect — confirm intended.
        """
        # assumes lon, lat in range [-180, 180] and [-90, 90]
        if normalize:
            locs = normalize_coords(locs)
        if self.input_enc == 'none':
            # undo normalization: features are raw degrees
            loc_feats = locs * torch.tensor([[180.0,90.0]], device=locs.device)
        elif self.input_enc == 'sin_cos': # sinusoidal encoding
            loc_feats = encode_loc(locs, input_dim=self.input_dim)
        elif self.input_enc == 'env': # bioclim variables
            # NOTE(review): unlike 'sin_cos_env' below, the raster is NOT
            # moved to locs.device here — confirm whether that is deliberate
            loc_feats = bilinear_interpolate(locs, self.raster)
        elif self.input_enc == 'sin_cos_env': # sinusoidal encoding & bioclim variables
            loc_feats = encode_loc(locs, input_dim=self.input_dim)
            context_feats = bilinear_interpolate(locs, self.raster.to(locs.device))
            loc_feats = torch.cat((loc_feats, context_feats), 1)
        elif self.input_enc == 'satclip': #SatClip Embedding
            # lazily load the SatCLIP checkpoint once and cache it on self
            if not hasattr(self, 'model'):
                import sys
                sys.path.append('./satclip/satclip')
                from satclip.satclip.load import get_satclip
                self.model = get_satclip('satclip/satclip-vit16-l10.ckpt', device="cpu")
                self.model.eval()
            self.model = self.model.to(locs.device)
            # SatCLIP expects coordinates in degrees; undo the normalization
            locs = locs*torch.tensor([[180.0, 90.0]], device=locs.device)
            # chunk large batches to bound peak memory; embeddings are 256-d
            max_batch = 1000000
            loc_feats = torch.empty(locs.shape[0], 256, device=locs.device)
            with torch.no_grad():
                for i in range(0, locs.shape[0], max_batch):
                    loc_feats[i:i+max_batch] = self.model(locs[i:i+max_batch].double()).float()
        else:
            raise NotImplementedError('Unknown input encoding.')
        return loc_feats

    def encode_fast(self, loc: list[float], normalize=True):
        """Scalar (list-based) encoding of a single pre-normalized location.

        Only 'sin_cos' is supported; `loc` must already be in [-1, 1]
        (normalize=True is rejected by the assert).
        """
        assert not normalize
        if self.input_enc == 'sin_cos':
            loc_feats = encode_loc_fast(loc, input_dim=self.input_dim)
        else:
            raise NotImplementedError('Unknown input encoding.')
        return loc_feats


class TimeEncoder:
    """Encodes a time interval (center, width), both in [0, 1], into features.

    'conical': periodic sin/cos whose radius shrinks as the interval widens,
    plus the width mapped to [-1, 1].
    'cylindrical': periodic sin/cos at constant radius plus the mapped width.
    """

    def __init__(self, input_enc='conical'):
        # input_enc: 'conical' or 'cylindrical'
        self.input_enc = input_enc

    def encode(self, intervals):
        """Encode an (N, 2) tensor of [t_center, t_width] rows in [0, 1].

        Returns an (N, 3) feature tensor.
        Raises NotImplementedError for unknown encodings (previously an
        unknown name fell through to an UnboundLocalError).
        """
        # assumes time, width in range [0, 1]
        t_center = intervals[:, :1]
        t_width = intervals[:, 1:]
        if self.input_enc == 'conical':
            t_feats = torch.cat([(1 - t_width) * torch.sin(2 * torch.pi * t_center),
                           (1 - t_width) * torch.cos(2 * torch.pi * t_center), 2 * t_width - 1], dim=1)
        elif self.input_enc == 'cylindrical':
            t_feats = torch.cat([torch.sin(2 * torch.pi * t_center), torch.cos(2 * torch.pi * t_center), 2 * t_width - 1], dim=1)
        else:
            raise NotImplementedError('Unknown time encoding.')
        return t_feats

    def encode_fast(self, intervals):
        """Scalar version of encode(): intervals is a (t_center, t_width) pair.

        Returns a 1-D tensor of 3 features.
        Raises NotImplementedError for unknown encodings.
        """
        # assumes time, width in range [0, 1]
        t_center, t_width = intervals
        if self.input_enc == 'conical':
            t_feats = torch.tensor([(1 - t_width) * math.sin(2 * math.pi * t_center),
                                 (1 - t_width) * math.cos(2 * math.pi * t_center), 2 * t_width - 1])
        elif self.input_enc == 'cylindrical':
            t_feats = torch.tensor([math.sin(2 * math.pi * t_center),
                                    math.cos(2 * math.pi * t_center), 2 * t_width - 1])
        else:
            raise NotImplementedError('Unknown time encoding.')
        return t_feats


def normalize_coords(locs):
    """Scale [lon, lat] columns from degrees to [-1, 1], in place.

    locs: (N, 2) tensor with lon in [-180, 180] and lat in [-90, 90].
    Mutates `locs` and returns the same tensor.
    """
    locs[:, 0] = locs[:, 0] / 180.0
    locs[:, 1] = locs[:, 1] / 90.0
    return locs

def encode_loc(loc_ip, concat_dim=1, input_dim=0):
    """Sinusoidal positional encoding of [lon, lat] inputs in [-1, 1].

    Produces input_dim//4 octaves; each octave contributes sin and cos of
    the input at frequency pi * 2**octave, concatenated along concat_dim.
    """
    num_octaves = input_dim // 4
    parts = []
    for octave in range(num_octaves):
        scaled = math.pi * (2 ** octave) * loc_ip
        parts.append(torch.sin(scaled))
        parts.append(torch.cos(scaled))
    return torch.cat(parts, concat_dim)


def encode_loc_fast(loc_ip: list[float], input_dim=0):
    """Plain-Python sinusoidal encoding of one location in [-1, 1].

    Emits features in the order sin(f*lon), sin(f*lat), cos(f*lon),
    cos(f*lat) per frequency octave, matching encode_loc's layout.
    """
    input_dim //= 2 # needed to make it compatible with encode_loc
    num_coords = len(loc_ip)
    cycle = 2 * num_coords  # one sin-run plus one cos-run per octave
    feats = []
    for i in range(input_dim):
        trig = math.sin if i % cycle < num_coords else math.cos
        freq = 2 ** (i // cycle)
        feats.append(trig(math.pi * freq * loc_ip[i % num_coords]))
    return feats


def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
    """Bilinearly sample a raster at normalized locations.

    loc_ip: (N, 2) tensor of [lon, lat], each in [-1, 1].
    data:   (H, W, C) raster tensor.
    Returns an (N, C) tensor of interpolated features.
    NOTE: when remove_nans_raster is True, NaNs in `data` are overwritten
    with 0.0 in place (0 is the mean post-normalization).
    """
    assert data is not None

    # map [-1, 1] -> [0, 1]; flip y because row 0 of the raster is +90
    # latitude while longitude runs left-to-right
    loc = (loc_ip.clone() + 1) / 2.0
    loc[:, 1] = 1 - loc[:, 1]

    assert not torch.any(torch.isnan(loc))

    if remove_nans_raster:
        data[torch.isnan(data)] = 0.0

    # scale into pixel coordinates
    height, width = data.shape[0], data.shape[1]
    loc[:, 0] *= (width - 1)
    loc[:, 1] *= (height - 1)

    base = torch.floor(loc).long()  # integer pixel coordinates
    x0 = base[:, 0]
    y0 = base[:, 1]
    x1 = torch.clamp(x0 + 1, max=width - 1)
    y1 = torch.clamp(y0 + 1, max=height - 1)

    frac = loc - torch.floor(loc)  # fractional offsets within the pixel
    dx = frac[:, 0].unsqueeze(1)
    dy = frac[:, 1].unsqueeze(1)

    # interpolate along x on each row, then blend the two rows along y
    top = data[y0, x0, :] * (1 - dx) + data[y0, x1, :] * dx
    bottom = data[y1, x0, :] * (1 - dx) + data[y1, x1, :] * dx
    return top * (1 - dy) + bottom * dy

def rand_samples(batch_size, device, rand_type='uniform'):
    """Randomly sample background locations in normalized coordinates.

    rand_type='uniform' samples uniformly over the [-1, 1]^2 square;
    'spherical' samples uniformly by surface area on the sphere (via
    inverse-CDF sampling of the colatitude) and maps back to normalized
    lon/lat. Returns a (batch_size, 2) tensor of [lon, lat].
    Raises NotImplementedError for an unknown rand_type (previously an
    unknown name fell through to an UnboundLocalError).
    """
    if rand_type == 'spherical':
        rand_loc = torch.rand(batch_size, 2).to(device)
        theta1 = 2.0*math.pi*rand_loc[:, 0]           # longitude angle
        theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0) # colatitude, area-uniform
        lat = 1.0 - 2.0*theta2/math.pi
        lon = (theta1/math.pi) - 1.0
        rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)

    elif rand_type == 'uniform':
        rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0

    else:
        raise NotImplementedError('Unknown rand type.')

    return rand_loc

def get_time_stamp():
    """Return the current local time formatted as 'YYYY-MM-DD-HH-MM-SS'."""
    return datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

def coord_grid(grid_size, split_ids=None, split_of_interest=None):
    """Build a lon/lat grid spaced evenly in coordinate space.

    grid_size: (rows, cols). Longitude spans [-180, 180] left to right;
    latitude spans [90, -90] top to bottom.
    Without split arguments, returns an (rows*cols, 2) float32 array of
    [lon, lat]; otherwise returns only the cells where
    split_ids == split_of_interest, as an (N_subset, 2) array.
    """
    rows, cols = grid_size[0], grid_size[1]
    lon_grid, lat_grid = np.meshgrid(np.linspace(-180, 180, cols),
                                     np.linspace(90, -90, rows))
    feats = np.stack([lon_grid, lat_grid], axis=-1).astype(np.float32)
    if split_ids is not None and split_of_interest is not None:
        # keep only the selected subset of cells
        ind_y, ind_x = np.where(split_ids == split_of_interest)
        return feats[ind_y, ind_x, :]
    # flatten the full grid into an N x 2 array
    return feats.reshape(rows * cols, 2)

def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
    """Generate a checkerboard-style train/test split over a raster.

    Returns an array shaped like the raster where 0 is invalid, 1 is
    train, and 2 is test; cell_size is in pixels. With train_amt < 1.0,
    a random fraction of the train pixels is dropped to 0.
    """
    split_ids = np.ones((raster.shape[0], raster.shape[1]))
    for band, ii in enumerate(np.arange(0, split_ids.shape[0], cell_size)):
        # alternate bands of rows start their test cells one cell over,
        # producing the checkerboard pattern
        col_start = (band % 2) * cell_size
        for jj in np.arange(col_start, split_ids.shape[1], cell_size * 2):
            split_ids[ii:ii + cell_size, jj:jj + cell_size] = 2
    split_ids = split_ids * mask  # zero out invalid pixels
    if train_amt < 1.0:
        # randomly demote a fraction of the train pixels to invalid
        tr_y, tr_x = np.where(split_ids == 1)
        drop = np.random.choice(len(tr_y), int(len(tr_y) * (1.0 - train_amt)), replace=False)
        split_ids[tr_y[drop], tr_x[drop]] = 0
    return split_ids

def average_precision_score_faster(y_true, y_scores):
    """Average precision over numpy arrays.

    Drop-in replacement for sklearn's average_precision_score, comparable
    up to floating point differences.
    """
    total_pos = y_true.sum()
    order = np.argsort(y_scores)[::-1]  # rank by descending score
    labels = y_true[order]

    tp = np.cumsum(labels)
    fp = np.cumsum(1.0 - labels)
    recall = tp / total_pos
    # eps guard avoids division by zero at rank 0
    denom = np.maximum(tp + fp, np.finfo(np.float32).eps)
    precision = tp / denom

    # integrate precision against recall increments (step interpolation)
    padded = np.hstack((0, recall, 1))
    d_recall = (padded[1:] - padded[:-1])[:-1]
    return (d_recall * precision).sum()

# TODO: y_true may be castable to float so entries become exactly 1.0 or 0.0
# TODO: confirm y_true here matches the labels passed to average_precision_score_faster
def average_precision_score_fasterer(y_true, y_scores):
    """Average precision over torch tensors.

    Drop-in replacement for sklearn's average_precision_score, comparable
    up to floating point differences.
    """
    total_pos = y_true.sum()
    order = torch.argsort(y_scores, descending=True)
    labels = y_true[order]

    tp = torch.cumsum(labels, dim=0)
    fp = torch.cumsum(1.0 - labels, dim=0)
    recall = tp / total_pos
    # eps guard avoids division by zero at rank 0
    denom = (tp + fp).clip(min=np.finfo(np.float32).eps)
    precision = tp / denom

    # integrate precision against recall increments (step interpolation)
    zero = torch.zeros(1, device=recall.device)
    one = torch.ones(1, device=recall.device)
    padded = torch.cat([zero, recall, one])
    d_recall = (padded[1:] - padded[:-1])[:-1]
    return (d_recall * precision).sum()


class DataPDFH3:
    """Empirical observation-density lookup keyed by H3 cell and day of year.

    Loads a per-day, per-cell count table (presumably (366, num_cells) —
    TODO confirm against the data file) and converts it to cumulative sums
    so a date-range count is a difference of two rows.
    """

    def __init__(self, data='data_pdf_h3.pt', device='cpu'):
        super(DataPDFH3, self).__init__()
        # cumulative sum over the day axis; the prepended zero row makes
        # self.data[t] = counts for days strictly before t
        self.data = torch.cumsum(torch.load(data, map_location=device), dim=0)
        self.data = torch.cat([torch.zeros_like(self.data[:1]), self.data], dim=0)
        inds = torch.load('inds_h3.pt')
        # compress each 64-bit H3 index to a 22-bit key (bits 30..51)
        # NOTE(review): assumes these bits uniquely identify every cell in
        # inds_h3.pt (all at one resolution) — confirm
        inds = ((inds >> 30) & 4194303)
        # dense key -> column map; -1 marks cells not present in the table
        self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
        self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)
        # cumulative total observations per day threshold across all cells
        self.cum_counts = self.data.sum(dim=-1)

    def _sample(self, pos, time, noise_level):
        """Count observations in a time window at each position.

        pos: (N, 2) normalized [lon, lat] (scaled by 180/90 below).
        time, noise_level: normalized day-of-year center and window width.
        Returns (counts, totals): per-cell counts in the window and the
        global totals over the same window.
        """
        pos = pos.cpu()
        time = time.cpu()
        noise_level = noise_level.cpu()
        # window endpoints in days; center +/- half the noise width
        t_low = (365*(time - 0.5*(noise_level))).int()
        t_high = (365*(time + 0.5*(noise_level))).int()
        # shift windows that start before day 0 forward one year so both
        # endpoints are non-negative (t_high is updated first because the
        # mask depends on the original t_low)
        t_high[t_low < 0] += 365
        t_low[t_low < 0] += 365

        # map positions to H3 res-5 cells, then to table columns
        # NOTE(review): h3.latlng_to_cell is called with array arguments and
        # .astype — presumably a vectorized h3 build (cf. the commented
        # 'h3.unstable' import at the top of the file); confirm
        pos_ind = torch.from_numpy((h3.latlng_to_cell(90*pos[:, 1], 180*pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
        pos_ind = self.ind_map[pos_ind]
        # cumulative-sum difference gives counts in [t_low, t_high]
        counts = self.data[t_high.clamp(max=364)+1, pos_ind] - self.data[t_low, pos_ind]
        # windows that run past day 364 wrap around: add counts from the
        # start of the following year(s)
        counts[t_high > 364] += self.data[(t_high[t_high > 364] - 365).clamp(max=364) + 1, pos_ind[t_high > 364]]
        counts[t_high > 729] += self.data[(t_high[t_high > 729] - 730).clamp(max=364) + 1, pos_ind[t_high > 729]]
        # same windowing for the global totals
        totals = self.cum_counts[t_high.clamp(max=364)+1] - self.cum_counts[t_low]
        totals[t_high > 364] += self.cum_counts[(t_high[t_high > 364] - 365).clamp(max=364) + 1]
        totals[t_high > 729] += self.cum_counts[(t_high[t_high > 729] - 730).clamp(max=364) + 1]
        # unknown cells (ind_map returned -1) contribute zero counts
        counts[pos_ind < 0] = 0
        return counts, totals

    def sample(self, pos, time, noise_level):
        """Fraction of the window's global observations that fall in each cell."""
        counts, totals = self._sample(pos, time, noise_level)
        return counts/totals

    def sample_log(self, pos, time, noise_level, eps=1e-2):
        """Log of the same fraction; eps on the denominator avoids log(0) there.

        NOTE(review): counts can still be 0, giving -inf — confirm callers
        tolerate that.
        """
        counts, totals = self._sample(pos, time, noise_level)
        return torch.log(counts)-torch.log(totals+eps)


class LowRankModel:
    """Location -> class-distribution model backed by a low-rank factorization.

    Loads factor matrices (x1, x2) whose product approximates a
    class-by-cell table, plus the raw count matrix used for class priors
    and cell indexing.
    """

    def __init__(self, data='nmf_256.pt', device='cpu'):
        super(LowRankModel, self).__init__()
        # dim selects which normalization branch below runs; fixed to -1,
        # so the other branches are currently dead code
        dim=-1
        x1, x2 = torch.load(data, map_location=device)
        # sparse class-by-cell count matrix
        m = torch.load('class_counts_locs_h3.pt').float()
        # order cells by total count, descending ([:] keeps all of them,
        # so this is a reordering, not a selection)
        chosen_inds = m.sum(dim=0).to_dense().sort(descending=True).indices[:]
        if dim == 0:
            # per-cell totals times a softmax over classes, renormalized per row
            n = m.to_dense()[:, chosen_inds].sum(dim=dim, keepdim=True)
            self.data = n*torch.softmax(x1 @ x2, dim=dim)
            self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
        elif dim == 1:
            self.data = torch.softmax(x1 @ x2, dim=dim)
        elif dim == -1:
            # active branch: x1 @ x2 must be a numpy array here; rows are
            # normalized to sum to 1
            self.data = torch.from_numpy(x1 @ x2)
            self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
        m = m.to_dense()[:, chosen_inds]
        #self.data = m.to_dense().float()/torch.sum(m.to_dense(), dim=1, keepdim=True)
        # class prior p(class): each class's share of all observations
        self.pc = m.sum(dim=1, keepdim=True) / m.sum()
        # 22-bit key (bits 30..51 of the 64-bit H3 index) per cell, in the
        # same order as the columns of self.data — same scheme as DataPDFH3
        inds = torch.load('inds_h3.pt')[chosen_inds]
        inds = ((inds >> 30) & 4194303)
        # dense key -> column map; -1 marks cells not in the table
        self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
        self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)

    def sample(self, pos):#, time, noise_level):
        """Return a distribution over classes (rows) for each position (columns).

        NOTE(review): pos is passed to h3.latlng_to_cell without the 90/180
        degree scaling that DataPDFH3._sample applies — confirm callers pass
        degrees here rather than normalized coordinates.
        NOTE(review): array-valued h3.latlng_to_cell presumably requires a
        vectorized h3 build (cf. the commented 'h3.unstable' import).
        """
        pos = pos.cpu()
        pos_ind = torch.from_numpy((h3.latlng_to_cell(pos[:, 1], pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
        pos_ind = self.ind_map[pos_ind]
        # advanced indexing copies, so the in-place scaling below does not
        # modify self.data
        out = self.data[:, pos_ind]
        out *= self.pc
        out = out/torch.sum(out, dim=0, keepdim=True)
        # unknown cells fall back to a uniform distribution over classes
        out[:, pos_ind < 0] = 1.0/out.shape[0]
        return out