diff --git "a/setup.py" "b/setup.py" new file mode 100644--- /dev/null +++ "b/setup.py" @@ -0,0 +1,1768 @@ +import collections +import sys + +import numpy as np +import pandas as pd +import random +import torch +import time +import os +import json +import tifffile +import h3 +import setup +from sklearn.linear_model import RidgeCV +from sklearn.preprocessing import MinMaxScaler +from torch.utils.data import Subset +import utils +import models +import datasets +from calendar import monthrange +from torch.nn.functional import logsigmoid, softmax +import torch.nn as nn +from tqdm import tqdm +import csv + +def format_tensor(tensor): + # Convert tensor to list, then flatten to string + tensor_list = tensor.tolist() # Converts the tensor to a Python list + return str(tensor_list).replace('\n', '').replace(' ', '') + +class EvaluatorSNT: + def __init__(self, train_params, eval_params): + self.train_params = train_params + self.eval_params = eval_params + with open('paths.json', 'r') as f: + paths = json.load(f) + D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True) + D = D.item() + self.loc_indices_per_species = D['loc_indices_per_species'] + self.labels_per_species = D['labels_per_species'] + self.taxa = D['taxa'] + self.obs_locs = D['obs_locs'] + self.obs_locs_idx = D['obs_locs_idx'] + self.pos_eval_data_loc = os.path.join(paths['data'], 'positive_eval_data.npz') + self.background_eval_data_loc = os.path.join(paths['data'], '10000_background_negs.npz') + + def get_labels(self, species): + species = str(species) + lat = [] + lon = [] + gt = [] + for hx in self.data: + cur_lat, cur_lon = h3.h3_to_geo(hx) + if species in self.data[hx]: + cur_label = int(len(self.data[hx][species]) > 0) + gt.append(cur_label) + lat.append(cur_lat) + lon.append(cur_lon) + lat = np.array(lat).astype(np.float32) + lon = np.array(lon).astype(np.float32) + obs_locs = np.vstack((lon, lat)).T + gt = np.array(gt).astype(np.float32) + return obs_locs, gt + + @torch.no_grad() + 
def run_evaluation(self, model, enc, extra_input=None): + results = {} + + # set seeds: + np.random.seed(self.eval_params['seed']) + random.seed(self.eval_params['seed']) + + # evaluate the geo model for each taxon + results['per_species_average_precision_all'] = np.zeros((len(self.taxa)), dtype=np.float32) + + # get eval locations and apply input encoding + obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device']) + loc_feat = torch.cat([enc.encode(obs_locs), extra_input.expand(obs_locs.shape[0], -1)], dim=1) if extra_input is not None else enc.encode(obs_locs) + + # get classes to eval + classes_of_interest = torch.zeros(len(self.taxa), dtype=torch.int64) + for tt_id, tt in enumerate(self.taxa): + class_of_interest = np.where(np.array(self.train_params['class_to_taxa']) == tt)[0] + if len(class_of_interest) != 0: + classes_of_interest[tt_id] = torch.from_numpy(class_of_interest) + + if self.eval_params['extract_pos']: + assert 'HyperNet' in self.train_params['model'] + model = model.pos_enc + self.train_params['model'] = 'ResidualFCNet' + + if ('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model']): + with torch.no_grad(): + dummy_context_mask = None + dummy_context_sequence = None + + # generate model predictions for classes of interest at eval locations + loc_emb = model(x=loc_feat, context_sequence=dummy_context_sequence, context_mask=dummy_context_mask, + class_ids=classes_of_interest, return_feats=True) + + classes_of_interest = classes_of_interest.to(self.eval_params["device"]) + + wt = model.get_eval_embeddings(classes_of_interest) + + pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) + + elif self.train_params['model'] == 'VariableInputModel': + with torch.no_grad(): + loc_emb = model.get_loc_emb(x=loc_feat) + + classes_of_interest = classes_of_interest.to(self.eval_params["device"]) + + wt = model.get_eval_embeddings(classes_of_interest) + + pred_mtx = torch.matmul(loc_emb, 
torch.transpose(wt, 0, 1)) + + elif 'HyperNet' not in self.train_params['model'] and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + with torch.no_grad(): + # generate model predictions for classes of interest at eval locations + loc_emb = model(loc_feat, return_feats=True) + wt = model.class_emb.weight[classes_of_interest, :] + pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) + elif (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + if self.train_params['model'] == 'ResidualFCNet': + import datasets + # from sklearn.linear_model import LogisticRegression + # with open('paths.json', 'r') as f: + # paths = json.load(f) + # data_dir = paths['train'] + # obs_file = os.path.join(data_dir, self.train_params['obs_file']) + # taxa_file = os.path.join(data_dir, self.train_params['taxa_file']) + # taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') + + # taxa_of_interest = datasets.get_taxa_of_interest(self.train_params['species_set'], self.train_params['num_aux_species'], + # self.train_params['aux_species_seed'], self.train_params['taxa_file'], taxa_file_snt) + obs_file = self.pos_eval_data_loc + locs, labels, _, dates, _, _ = datasets.load_eval_inat_data(obs_file) + unique_taxa, class_ids = np.unique(labels, return_inverse=True) + class_to_taxa = unique_taxa.tolist() + # idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) + idx_ss = datasets.get_idx_subsample_observations_eval(labels=labels, hard_cap=self.eval_params['num_samples']) + locs = torch.from_numpy(np.array(locs)) + labels = torch.from_numpy(np.array(class_ids)) + locs = locs[idx_ss] + labels = labels[idx_ss] + with torch.no_grad(): + pos_examples = {} + for tt in self.taxa: + c = class_to_taxa.index(tt) + pos_examples[tt] = locs[labels == c] + pos_examples[tt] = model(enc.encode(pos_examples[tt].to(self.eval_params['device'])), return_feats=True).cpu() + + # MAX 
VERSION # MAX VERSION # MAX VERSION + # random negs + neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') + obs_file = self.background_eval_data_loc + neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(obs_file) + neg_locs = torch.from_numpy(neg_locs) + if extra_input is not None: + raise NotImplementedError('extra_input provided') + # add target negs + neg_examples = model(torch.cat([enc.encode(neg_examples, normalize=False), enc.encode( + neg_locs[torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]].clone().to( + self.eval_params['device']), normalize=True)]), return_feats=True).cpu() + loc_emb = model(loc_feat, return_feats=True) + elif self.train_params['model'] == 'HyperNet': + import datasets + # from sklearn.linear_model import LogisticRegression + # with open('paths.json', 'r') as f: + # paths = json.load(f) + # data_dir = paths['train'] + # obs_file = os.path.join(data_dir, self.train_params['obs_file']) + # taxa_file = os.path.join(data_dir, self.train_params['taxa_file']) + # taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') + # + # taxa_of_interest = datasets.get_taxa_of_interest(self.train_params['species_set'], self.train_params['num_aux_species'], + # self.train_params['aux_species_seed'], self.train_params['taxa_file'], taxa_file_snt) + # + obs_file = self.pos_eval_data_loc + locs, labels, _, dates, _, _ = datasets.load_eval_inat_data(obs_file) + unique_taxa, class_ids = np.unique(labels, return_inverse=True) + class_to_taxa = unique_taxa.tolist() + + if self.eval_params['num_samples'] > 0: + # idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) + idx_ss = datasets.get_idx_subsample_observations_eval(labels=labels, hard_cap=self.eval_params['num_samples']) + locs = torch.from_numpy(np.array(locs)[idx_ss]) + labels = torch.from_numpy(np.array(class_ids)[idx_ss]) + with torch.no_grad(): + pos_examples = {} + for tt 
in self.taxa: + c = class_to_taxa.index(tt) + pos_examples[tt] = locs[labels == c] + pos_examples[tt] = model.pos_enc(enc.encode(pos_examples[tt].to(self.eval_params['device']))).cpu() + # random negs + neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') + obs_file = self.background_eval_data_loc + neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(obs_file) + neg_locs = torch.from_numpy(neg_locs) + if extra_input is not None: + raise NotImplementedError('extra_input provided') + neg_examples = model.pos_enc(torch.cat([enc.encode(neg_examples, normalize=False), enc.encode(neg_locs[torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]].clone().to(self.eval_params['device']), normalize=True)])).cpu() + loc_emb = model.pos_enc(loc_feat) + #embs = torch.load(self.train_params['text_emb_path']) #TODO + #embs1 = torch.load('experiments/gpt_data.pt', weights_only=False) + embs1 = torch.load('experiments/gpt_data.pt', map_location='cpu') + #embs1 = torch.load('ldsdm_data.pt') + emb_ids1 = embs1['taxon_id'].tolist() + keys1 = embs1['keys'] + embs1 = embs1['data'] + # embs2 doesn't even do anything. 
Could just remove the whole thing, but that is how it is in Max's code + # MINE MINE MINE MINE MINE + embs2 = torch.load('experiments/wiki_data_v4.pt') + # MAX MAX MAX MAX + # embs2 = torch.load('wiki_data_v3.pt') + emb_ids2 = embs2['taxon_id'].tolist() + keys2 = embs2['keys'] + embs2 = embs2['data'] + else: + raise NotImplementedError('Eval for zero-shot not implemented') + # if self.eval_params['num_samples'] == -1 and not (('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model'] or )): + if self.eval_params['num_samples'] == -1 and not (self.train_params['model'] in ['CombinedModel', 'MultiInputModel', 'VariableInputModel', 'ResidualFCNet']): + loc_emb = model.pos_enc(loc_feat) + elif self.eval_params['num_samples'] == -1 and not (self.train_params['model'] in ['CombinedModel', 'MultiInputModel', 'VariableInputModel']): + loc_emb = model.forward(loc_feat, return_feats=True) + split_rng = np.random.default_rng(self.eval_params['split_seed']) + write_gt_once = False + #TODO: tt is the iNat taxa id for the taxa we are calculating AP for rn, tt_id is the index in the dictionary + #ap_csv = "per_species_average_precision_valid.csv" + #taxa_id_csv = "per_species_taxa_id_valid.csv" + # with open(taxa_id_csv, mode='w', newline='') as csv_file: + # writer = csv.writer(csv_file) + + # # If the array is multi-dimensional (e.g., 2D), iterate over rows + # if isinstance(self.taxa, np.ndarray): + # for value in self.taxa: + # writer.writerow([value]) + # else: + # # If it's a flat array, directly write the values + # writer.writerow(per_species_average_precision_valid) + + range, range_locs = [], [] + for tt_id, tt in tqdm(enumerate(self.taxa)): + + class_of_interest = np.where(np.array(self.train_params['class_to_taxa']) == tt)[0] + if len(class_of_interest) == 0 and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + # taxa of interest is not in the model + 
results['per_species_average_precision_all'][tt_id] = np.nan + # this only effects my models + elif self.train_params['model'] == 'VariableInputModel': + # generate ground truth labels for current taxa + cur_loc_indices = np.array(self.loc_indices_per_species[tt_id]) + cur_labels = np.array(self.labels_per_species[tt_id]) + + # apply per-species split: + assert self.eval_params['split'] in ['all', 'val', 'test'] + if self.eval_params['split'] != 'all': + num_val = np.floor(len(cur_labels) * self.eval_params['val_frac']).astype(int) + idx_rand = split_rng.permutation(len(cur_labels)) + if self.eval_params['split'] == 'val': + idx_sel = idx_rand[:num_val] + elif self.eval_params['split'] == 'test': + idx_sel = idx_rand[num_val:] + cur_loc_indices = cur_loc_indices[idx_sel] + cur_labels = cur_labels[idx_sel] + cur_labels = (torch.from_numpy(cur_labels).to(self.eval_params['device']) > 0).float() + + with torch.no_grad(): + logits = pred_mtx[:, tt_id] + preds = torch.sigmoid(logits) + #TODO metric value is calcuated + #this is how we get the predictions, just matching the hexs for the spots we are interested in. + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer( + cur_labels, + preds[cur_loc_indices]).item() + continue + elif self.train_params['model'] == 'MultiInputModel': + # generate ground truth labels for current taxa + #todo: ask max, are the loc_indices the h3 indices at res 5? + #these are the inidices of the locations of where we have evaluations + cur_loc_indices = np.array(self.loc_indices_per_species[tt_id]) + #loc_indices_per_species_array = np.array(self.loc_indices_per_species[tt_id]) + #this is the answer key + cur_labels = np.array(self.labels_per_species[tt_id]) #87373 "0." 
+ #labels_per_species_array = np.array(self.labels_per_species[tt_id]) #174746 '0' + + # apply per-species split: + assert self.eval_params['split'] in ['all', 'val', 'test'] + if self.eval_params['split'] != 'all': + num_val = np.floor(len(cur_labels) * self.eval_params['val_frac']).astype(int) + idx_rand = split_rng.permutation(len(cur_labels)) + if self.eval_params['split'] == 'val': + idx_sel = idx_rand[:num_val] + elif self.eval_params['split'] == 'test': + idx_sel = idx_rand[num_val:] + cur_loc_indices = cur_loc_indices[idx_sel] + cur_labels = cur_labels[idx_sel] + cur_labels = (torch.from_numpy(cur_labels).to(self.eval_params['device']) > 0).float() + #print('printing location testing') + #matching_locations = obs_locs[loc_indices_per_species_array[labels_per_species_array == 1]]#21737 this is bigger because we take out the all and val locations + matching_locations = obs_locs[cur_loc_indices[cur_labels == 1]] #10849 + range_locs.append(matching_locations) + #print(f'matching locations len: {len(matching_locations)}') + range.append(cur_labels) + #print(f'range cur labels len: {cur_labels.sum()}') + #print(f'number of locations matches: matching locations: {np.shape(matching_locations)} and cur_labels: {cur_labels.sum()}') + + # if not write_gt_once: + # snt_labels_csv = f"data/plot/taxa_locs/snt_locations_{tt}.csv" + # with open(snt_labels_csv, mode='w', newline='') as csv_file: + # writer = csv.writer(csv_file) + + # # If the array is multi-dimensional (e.g., 2D), iterate over rows + # if isinstance(matching_locations, np.ndarray): + # for value in matching_locations: + # writer.writerow([value]) + # else: + # # If it's a flat array, directly write the values + # writer.writerow(matching_locations) + #print(f'current labels snt: {np.shape(cur_labels)}') + + with torch.no_grad(): + logits = pred_mtx[:, tt_id] + preds = torch.sigmoid(logits) + #TODO metric value is calcuated + results['per_species_average_precision_all'][tt_id] = 
utils.average_precision_score_fasterer( + cur_labels, + preds[cur_loc_indices]).item() + continue + + # MINE MINE MINE MINE MINE MINE + # elif self.eval_params['num_samples'] == -1: + # gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + # gt[self.data['taxa_presence'][str(tt)]] = 1.0 + # species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] + # preds = loc_emb @ species_w.detach() + # results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() + # continue + else: + # generate ground truth labels for current taxa + cur_loc_indices = np.array(self.loc_indices_per_species[tt_id]) + cur_labels = np.array(self.labels_per_species[tt_id]) + + # apply per-species split: + assert self.eval_params['split'] in ['all', 'val', 'test'] + if self.eval_params['split'] != 'all': + num_val = np.floor(len(cur_labels) * self.eval_params['val_frac']).astype(int) + idx_rand = split_rng.permutation(len(cur_labels)) + if self.eval_params['split'] == 'val': + idx_sel = idx_rand[:num_val] + elif self.eval_params['split'] == 'test': + idx_sel = idx_rand[num_val:] + cur_loc_indices = cur_loc_indices[idx_sel] + cur_labels = cur_labels[idx_sel] + cur_labels = (torch.from_numpy(cur_labels).to(self.eval_params['device']) > 0).float() + + ########################################################################################## + # + ########################################################################################## + if self.eval_params['num_samples'] == -1 and self.train_params['model'] == 'HyperNet': + species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] + preds = loc_emb @ species_w.detach() + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(cur_labels, + preds[cur_loc_indices]).item() + continue + elif self.eval_params['num_samples'] == -1 and self.train_params['model'] == 'ResidualFCNet': + preds = 
model.eval_single_class(x=loc_emb, class_of_interest=self.train_params['class_to_taxa'].index(tt)).detach() + # species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] + # preds = loc_emb @ species_w.detach() + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(cur_labels, + preds[cur_loc_indices]).item() + continue + if 'HyperNet' not in self.train_params['model'] and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + # extract model predictions for current taxa from prediction matrix + pred = pred_mtx[cur_loc_indices, tt_id] + elif self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0: + if self.train_params['model'] == 'ResidualFCNet': + if self.eval_params['num_samples'] == 0: + X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) + w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=self.eval_params['device'])) + nn.init.xavier_uniform_(w) + pred = torch.sigmoid(((loc_emb @ w)))[cur_loc_indices].flatten() + else: + X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) + y = torch.zeros(X.shape[0], dtype=torch.long, device=self.eval_params['device']) + y[:pos_examples[tt].shape[0]] = 1 + # MINE MINE MINE MINE MINE MINE + # clf = LogisticRegression(class_weight='balanced', fit_intercept=False, C=0.05, max_iter=200, random_state=0).fit(X.numpy(), y.numpy()) + # #pred = torch.from_numpy(clf.predict_proba(loc_emb.cpu()))[:,1] + # pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1))[cur_loc_indices] + # MAX MAX MAX MAX MAX MAX MAX MAX MAX + + #clf = LogisticRegression(class_weight='balanced', fit_intercept=False, C=0.05, max_iter=200, random_state=0).fit(X.numpy(), y.numpy()) + + C = 0.05 + w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=self.eval_params['device'])) + opt = 
torch.optim.Rprop([w], lr=0.001) + crit = torch.nn.BCEWithLogitsLoss() + crit2 = torch.nn.MSELoss() + with torch.set_grad_enabled(True): + for i in range(40): + opt.zero_grad() + output = X @ w + yhat = y.float()[:, None] + loss = 0.5 * crit(output[yhat == 0], yhat[yhat == 0]) + 0.5 * crit(output[yhat == 1], + yhat[ + yhat == 1]) + 1 / ( + C * len(pos_examples[tt])) * crit2(w, 0 * w) + loss.backward() + opt.step() + #pred = torch.from_numpy(clf.predict_proba(loc_emb.cpu()))[:,1] + # pred = torch.sigmoid(((loc_emb @ w.cuda())))[cur_loc_indices].flatten() + pred = torch.sigmoid(((loc_emb @ w)))[cur_loc_indices].flatten() + #pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1))[cur_loc_indices] + elif self.train_params['model'] == 'HyperNet': + if tt in emb_ids1: + embs = embs1 + emb_ids = emb_ids1 + keys = keys1 + else: + print('yes') + results['per_species_average_precision_all'][tt_id] = 0.0 + continue + embs = embs2 + emb_ids = emb_ids2 + keys = keys2 + if tt not in emb_ids: + results['per_species_average_precision_all'][tt_id] = 0.0 + continue + with torch.no_grad(): + sec_ind = emb_ids.index(tt) + sections = [i for i,x in enumerate(keys) if x[0] == sec_ind] + def get_feat(x): + species = model.species_enc(model.species_emb.zero_shot(x)) + species_w, species_b = species[..., :-1], species[..., -1:] + if self.eval_params['num_samples'] == 0: + out = loc_emb @ (species_w.detach()).T + return out + X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) + y = torch.zeros(X.shape[0], dtype=torch.long, device=self.eval_params['device']) + y[:pos_examples[tt].shape[0]] = 1 + C = 0.05 + w = torch.nn.Parameter(torch.zeros_like(species_w,device=self.eval_params['device'])) + opt = torch.optim.Rprop([w], lr=0.001) + crit = torch.nn.BCEWithLogitsLoss() + crit2 = torch.nn.MSELoss() + with torch.set_grad_enabled(True): + for i in range(40): + 
opt.zero_grad() + output = (X @ (w + species_w.detach()).T) + 0*species_b.squeeze(-1) + yhat = y.float()[:, None].repeat(1, w.shape[0]) + loss = 0.5*crit(output[yhat == 0], yhat[yhat == 0]) + 0.5*crit(output[yhat == 1], yhat[yhat == 1]) + \ + 1/(C*len(pos_examples[tt])) * crit2(w, 0*w) + loss.backward() + opt.step() + #print(i, loss.item()) + #print(' ') + out = loc_emb @ (w.data + species_w.detach()).T + out = (out + 0*species_b.squeeze(-1)) + return out + # average precision score: + yfeats = torch.cat([embs[section][None].to(self.eval_params['device']) for section in sections]) + preds = get_feat(yfeats) + if len(sections) > 1:#'habitat', 'overview_summary' + kws = ['text', 'range', 'distribution', 'habitat'] if len(keys) == len(keys2) else [self.eval_params['text_section']] + best_sections = [i for i,s in enumerate(sections) if any((x in keys[s][1].lower() for x in kws))] + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(cur_labels, preds[cur_loc_indices][:,best_sections].mean(dim=1)).item() + else: + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(cur_labels, preds[cur_loc_indices][:,0].mean(dim=1)).item() + continue + else: + raise NotImplementedError('Eval for hypernet not implemented') + pred = preds[:,tt_id%32] + # compute the AP for each taxa + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(cur_labels, pred).item() + + valid_taxa = ~np.isnan(results['per_species_average_precision_all']) + + # store results + #TODO: this will have AP values for every species + #tt_id + per_species_average_precision_valid = results['per_species_average_precision_all'][valid_taxa] + + results['mean_average_precision'] = per_species_average_precision_valid.mean() + results['num_eval_species_w_valid_ap'] = valid_taxa.sum() + results['num_eval_species_total'] = len(self.taxa) + + taxas_and_ap_csv = "taxas_ap_range.csv" + #ap_csv = 
"per_species_taxa_id_valid.csv" + print(list(map(lambda row:len(row) ,range))) + zipped_data = zip(self.taxa, per_species_average_precision_valid, list(map(lambda row:int(row.sum()),range)), range_locs) + with open(taxas_and_ap_csv, mode='w', newline='') as csv_file: + writer = csv.writer(csv_file) + + # Write the header (optional) + writer.writerow(['Taxa ID', 'Average Precision','Range Size', 'Range']) + + # Write the zipped data + for taxa, ap, range_size, tensor_range in zipped_data: + # Flatten tensor to a single-line string + tensor_range_str = format_tensor(tensor_range) + writer.writerow([taxa, ap, range_size, tensor_range_str]) + # with open(ap_csv, mode='w', newline='') as csv_file: + # writer = csv.writer(csv_file) + + # # Write the zipped data + # writer.writerows(per_species_average_precision_valid) + + return results + + def report(self, results): + for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']: + print(f'{field}: {results[field]}') + +class EvaluatorIUCN: + + def __init__(self, train_params, eval_params): + self.train_params = train_params + print(train_params['text_num_layers'],train_params['text_batchnorm'],train_params['text_hidden_dim'])#TODO + self.eval_params = eval_params + with open('paths.json', 'r') as f: + paths = json.load(f) + with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: + self.data = json.load(f) + self.obs_locs = np.array(self.data['locs'], dtype=np.float32) + self.taxa = [int(tt) for tt in self.data['taxa_presence'].keys()] + self.pos_eval_data_loc = os.path.join(paths['data'], 'positive_eval_data.npz') + self.background_eval_data_loc = os.path.join(paths['data'], '10000_background_negs.npz') + + @torch.no_grad() + def run_evaluation(self, model, enc, extra_input=None): + results = {} + #self.train_params['model'] = 'ResidualFCNet' + #m = model + #model = lambda x, return_feats=True: m.pos_enc(x) + results['per_species_average_precision_all'] = 
np.zeros(len(self.taxa), dtype=np.float32) + # get eval locations and apply input encoding + obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device']) + loc_feat = torch.cat([enc.encode(obs_locs), extra_input.expand(obs_locs.shape[0], -1)], dim=1) if extra_input is not None else enc.encode(obs_locs) + + # get classes to eval + # classes_of_interest = torch.zeros(len(self.taxa), dtype=torch.int64) + classes_of_interest = np.zeros(len(self.taxa)) + array_class_to_taxa = np.array(self.train_params['class_to_taxa']) + for tt_id, tt in enumerate(self.taxa): + class_of_interest = np.where(array_class_to_taxa == tt)[0] + if len(class_of_interest) != 0: + classes_of_interest[tt_id] = class_of_interest + classes_of_interest = torch.from_numpy(classes_of_interest).to(dtype=torch.long, device=self.eval_params['device']) + + # MINE MINE MINE + # classes_of_interest = classes_of_interest.to(self.eval_params['device']) + + if self.eval_params['extract_pos']: + assert 'HyperNet' in self.train_params['model'] + model = model.pos_enc + self.train_params['model'] = 'ResidualFCNet' + # Should only effect mine + if ('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model']): + with torch.no_grad(): + dummy_context_mask = None + dummy_context_sequence = None + # generate model predictions for classes of interest at eval locations + loc_emb = model(x=loc_feat, context_sequence=dummy_context_sequence, context_mask=dummy_context_mask, + class_ids=classes_of_interest, return_feats=True) + wt = model.get_eval_embeddings(classes_of_interest) + print("Creating IUCN prediction matrix") + pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) + + elif self.train_params['model'] == 'VariableInputModel': + with torch.no_grad(): + loc_emb = model.get_loc_emb(x=loc_feat) + + classes_of_interest = classes_of_interest.to(self.eval_params["device"]) + + wt = model.get_eval_embeddings(classes_of_interest) + + wt2 = 
model.get_ema_embeddings(classes_of_interest) + # technically with my mock transformer I could just directly access the class embeddings but + # I will need to use the emas when I move to the true transformer model (I think) + + # wt = model.class_emb.weight[classes_of_interest, :] + + pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) + + elif 'HyperNet' not in self.train_params['model'] and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + # generate model predictions for classes of interest at eval locations + loc_emb = model(loc_feat, return_feats=True) + wt = model.class_emb.weight[classes_of_interest, :] + pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) + elif (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + if self.train_params['model'] == 'ResidualFCNet': + import datasets + # from sklearn.linear_model import LogisticRegression + # with open('paths.json', 'r') as f: + # paths = json.load(f) + # data_dir = paths['train'] + # obs_file = os.path.join(data_dir, self.train_params['obs_file']) + # taxa_file = os.path.join(data_dir, self.train_params['taxa_file']) + # taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') + # + # taxa_of_interest = datasets.get_taxa_of_interest(self.train_params['species_set'], self.train_params['num_aux_species'], + # self.train_params['aux_species_seed'], self.train_params['taxa_file'], taxa_file_snt) + obs_file = self.pos_eval_data_loc + locs, labels, _, dates, _, _ = datasets.load_eval_inat_data(obs_file) + unique_taxa, class_ids = np.unique(labels, return_inverse=True) + class_to_taxa = unique_taxa.tolist() + # idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) + idx_ss = datasets.get_idx_subsample_observations_eval(labels=labels, hard_cap=self.eval_params['num_samples']) + locs = torch.from_numpy(np.array(locs)) + labels = torch.from_numpy(np.array(class_ids)) + locs = 
locs[idx_ss] + labels = labels[idx_ss] + # MINE MINE MINE MINE MINE MINE MINE MINE MINE + # with torch.no_grad(): + # pos_examples = {} + # for tt in self.taxa: + # c = class_to_taxa.index(tt) + # pos_examples[tt] = locs[labels == c] + # pos_examples[tt] = model(enc.encode(pos_examples[tt].to(self.eval_params['device'])), return_feats=True).cpu() + # + # if self.eval_params['target_background']: + # target_background_dataset = datasets.get_train_data(params=self.train_params) + # # print("CHECK IF THIS TARGET NEGS THING IS WORKING PROPERLY WHEN SERVER WORKS") + # # print("IT MAY INCLUDE EVAL SPECIES / ONLY EVAL SPECIES") # it only includes the backbone species currently - good + # + # random_negs = utils.rand_samples(5000, self.eval_params['device'], rand_type='spherical') + # + # # Get the total number of locations + # total_locs = len(target_background_dataset.locs) + # + # # If there are more than 5000 locations, sample 5000 + # if total_locs > 5000: + # indices = np.random.choice(total_locs, 5000, replace=False) + # target_negs = target_background_dataset.locs[indices].to(self.eval_params['device']) + # else: + # target_negs = target_background_dataset.locs.to(self.eval_params['device']) + # # print('CHECK THE FORMAT OF THESE TARGET LOCS COMPARED TO NEG LOCS') # look good + # + # neg_examples = torch.vstack((random_negs, target_negs)) + # + # del target_background_dataset + # + # else: + # neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') + # if extra_input is not None: + # raise NotImplementedError('extra_input provided') + # neg_examples = model(enc.encode(neg_examples, normalize=False), return_feats=True).cpu() + # print("You can probably speed eval back up once the server is available by changing this shit back") + # + # # Function to process data in batches + # def process_in_batches(model, loc_feat, batch_size=64): + # loc_emb = [] + # for i in range(0, len(loc_feat), batch_size): + # batch = loc_feat[i:i + 
batch_size] + # with torch.no_grad(): + # batch_emb = model(batch, return_feats=True) + # loc_emb.append(batch_emb) + # return torch.cat(loc_emb, dim=0) # Concatenate the results + # + # # loc_emb = model(loc_feat, return_feats=True) + # loc_emb = process_in_batches(model, loc_feat, batch_size=2048) + pos_examples = {} + for tt in self.taxa: + c = class_to_taxa.index(tt) + pos_examples[tt] = locs[labels == c] + pos_examples[tt] = model(enc.encode(pos_examples[tt].to(self.eval_params['device'])), return_feats=True).cpu() + obs_file = self.background_eval_data_loc + neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(obs_file) + neg_locs = torch.from_numpy(neg_locs) + #random negs + neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') + if extra_input is not None: + raise NotImplementedError('extra_input provided') + # add target negs + neg_examples = model(torch.cat([enc.encode(neg_examples, normalize=False), enc.encode(neg_locs[torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]].clone().to(self.eval_params['device']), normalize=True)]), return_feats=True).cpu() + loc_emb = model(loc_feat, return_feats=True) + elif self.train_params['model'] == 'HyperNet': + import datasets + # from sklearn.linear_model import LogisticRegression + # with open('paths.json', 'r') as f: + # paths = json.load(f) + # data_dir = paths['train'] + # obs_file = os.path.join(data_dir, self.train_params['obs_file']) + # taxa_file = os.path.join(data_dir, self.train_params['taxa_file']) + # taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') + # + # taxa_of_interest = datasets.get_taxa_of_interest(self.train_params['species_set'], self.train_params['num_aux_species'], + # self.train_params['aux_species_seed'], self.train_params['taxa_file'], taxa_file_snt) + obs_file = self.pos_eval_data_loc + locs, labels, _, dates, _, _ = datasets.load_eval_inat_data(obs_file) + # MINE MINE MINE MINE + # unique_taxa, class_ids = np.unique(labels, 
return_inverse=True) + # class_to_taxa = unique_taxa.tolist() + # idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) + # locs = torch.from_numpy(np.array(locs)) + # labels = torch.from_numpy(np.array(class_ids)) + # locs = locs[idx_ss] + # labels = labels[idx_ss] + # with torch.no_grad(): + # MAX MAX MAX MAX MAX MAX MAX + unique_taxa, class_ids, class_counts = np.unique(labels, return_inverse=True, return_counts=True) + class_counts = class_counts.clip(max=1000) + if self.eval_params['num_samples'] > 0: + class_to_taxa = unique_taxa.tolist() + idx_ss = datasets.get_idx_subsample_observations_eval(labels=labels, hard_cap=self.eval_params['num_samples']) + # idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) + locs = torch.from_numpy(np.array(locs)) + labels = torch.from_numpy(np.array(class_ids)) + locs = locs[idx_ss] + labels = labels[idx_ss] + pos_examples = {} + for tt in self.taxa: + c = class_to_taxa.index(tt) + pos_examples[tt] = locs[labels == c] + pos_examples[tt] = model.pos_enc(enc.encode(pos_examples[tt].to(self.eval_params['device']))).cpu() + # MINE MINE MINE MINE MINE MINE MINE MINE + # if self.eval_params['target_background']: + # + # target_background_dataset = datasets.get_train_data(params=self.train_params) + # # print("CHECK IF THIS TARGET NEGS THING IS WORKING PROPERLY WHEN SERVER WORKS") + # # print("IT MAY INCLUDE EVAL SPECIES / ONLY EVAL SPECIES") + # + # random_negs = utils.rand_samples(5000, self.eval_params['device'], rand_type='spherical') + # + # # Get the total number of locations + # total_locs = len(target_background_dataset.locs) + # + # # If there are more than 5000 locations, sample 5000 + # if total_locs > 5000: + # indices = np.random.choice(total_locs, 5000, replace=False) + # target_negs = target_background_dataset.locs[indices].to(self.eval_params['device']) + # else: + # 
target_negs = target_background_dataset.locs.to(self.eval_params['device']) + # # print('CHECK THE FORMAT OF THESE TARGET LOCS COMPARED TO NEG LOCS') + # + # neg_examples = torch.vstack((random_negs, target_negs)) + # + # del target_background_dataset + # + # else: + # neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') + # MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX + obs_file = self.background_eval_data_loc + neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(obs_file) + neg_locs = torch.from_numpy(neg_locs) + # random negs + neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') + if extra_input is not None: + raise NotImplementedError('extra_input provided') + # MINE MINE MINE + # neg_examples = model.pos_enc(enc.encode(neg_examples, normalize=False)).cpu() + # MAX MAX MAX MAX MAX MAX MAX + # add target negs + neg_examples = model.pos_enc(torch.cat([enc.encode(neg_examples, normalize=False), enc.encode(neg_locs[torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]].clone().to(self.eval_params['device']), normalize=True)])).cpu() + + #embs = torch.load(self.train_params['text_emb_path']) #TODO + embs = torch.load('gpt_data.pt', weights_only=False) + #embs = torch.load('ldsdm_data.pt') + emb_ids = embs['taxon_id'].tolist() + keys = embs['keys'] + embs = embs['data'] + # embs2 doesn't even do anything. Could just remove the whole thing, but that is how it is in Max's code + # MINE MINE MINE + embs2 = torch.load('wiki_data_v4.pt', weights_only=False) + # MAX MAX MAX + # embs2 = torch.load('wiki_data_v3.pt') + emb_ids2 = embs2['taxon_id'].tolist() + keys2 = embs2['keys'] + embs2 = embs2['data'] + loc_emb = model.pos_enc(loc_feat) + else: + raise NotImplementedError('Eval for zero-shot not implemented') + # MINE - my version - why am I stopping residualFCnets doing this? 
+ # if self.eval_params['num_samples'] == -1 and not (('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model']) or ('ResidualFCNet' in self.train_params['model'])): + # MAX - a variant of Maxs - only difference should now be my model types + #if self.eval_params['num_samples'] == -1 and not (('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model'])): + if self.eval_params['num_samples'] == -1 and not (self.train_params['model'] in ['CombinedModel', 'MultiInputModel', 'VariableInputModel', 'ResidualFCNet']): + loc_emb = model.pos_enc(loc_feat) + if self.eval_params['num_samples'] == -1 and not (self.train_params['model'] in ['CombinedModel', 'MultiInputModel', 'VariableInputModel']): + loc_emb = model.forward(loc_feat, return_feats=True) + for tt_id, tt in tqdm(enumerate(self.taxa)): + class_of_interest = np.where(array_class_to_taxa == tt)[0] + if len(class_of_interest) == 0 and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + # taxa of interest is not in the model + results['per_species_average_precision_all'][tt_id] = np.nan + else: + # Only effects my models + if self.train_params['model'] == 'MultiInputModel': + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + with torch.no_grad(): + logits = pred_mtx[:, tt_id] + preds = torch.sigmoid(logits) + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() + continue + elif self.train_params['model'] == 'VariableInputModel': + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + with torch.no_grad(): + logits = pred_mtx[:, tt_id] + preds = torch.sigmoid(logits) + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() 
+ continue + # MINE MINE MINE + # elif (self.train_params['model'] == 'ResidualFCNet') and (self.eval_params['num_samples'] <= 0): + # gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + # gt[self.data['taxa_presence'][str(tt)]] = 1.0 + # with torch.no_grad(): + # logits = pred_mtx[:, tt_id] + # preds = torch.sigmoid(logits) + # results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() + # continue + if self.eval_params['num_samples'] == -1 and self.train_params['model'] == 'HyperNet': + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] + preds = loc_emb @ species_w.detach() + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() + continue + elif self.eval_params['num_samples'] == -1 and self.train_params['model'] == 'ResidualFCNet': + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + preds = model.eval_single_class(x=loc_emb, class_of_interest=self.train_params['class_to_taxa'].index(tt)).detach() + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() + continue + # MINE MINE MINE MINE MINE MINE MINE - seems un needed? 
+ # elif (self.eval_params['num_samples'] == -1) and ('Hypernet' in self.train_params['model']): + # gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + # gt[self.data['taxa_presence'][str(tt)]] = 1.0 + # species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] + # preds = loc_emb @ species_w.detach() + # results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() + # continue + # extract model predictions for current taxa from prediction matrix + if 'HyperNet' not in self.train_params['model'] and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + pred = pred_mtx[:, tt_id] + elif (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): + if self.train_params['model'] == 'ResidualFCNet': + if self.eval_params['num_samples'] == 0: + X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) + w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=self.eval_params['device'])) + nn.init.xavier_uniform_(w) + pred = torch.sigmoid(((loc_emb @ w)))[cur_loc_indices].flatten() + else: + X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) + y = torch.zeros(X.shape[0], dtype=torch.long, device=self.eval_params['device']) + y[:pos_examples[tt].shape[0]] = 1 + # MINE MINE MINE + # clf = LogisticRegression(class_weight='balanced', fit_intercept=False, C=0.05, max_iter=200, random_state=0).fit(X.numpy(), y.numpy()) + # #pred = torch.from_numpy(clf.predict_proba(loc_emb.cpu()))[:,1] + # pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).to(self.eval_params['device']).float().T)) + torch.from_numpy(clf.intercept_).to(self.eval_params['device']).float()).squeeze(-1)) + # # pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1)) + # MAX MAX MAX MAX MAX MAX MAX MAX MAX 
MAX MAX + + #clf = LogisticRegression(class_weight='balanced', fit_intercept=False, C=0.05, max_iter=200, random_state=0).fit(X.numpy(), y.numpy()) + + C = 0.05 + w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=self.eval_params['device'])) + opt = torch.optim.Rprop([w], lr=0.001) + crit = torch.nn.BCEWithLogitsLoss() + crit2 = torch.nn.MSELoss() + with torch.set_grad_enabled(True): + for i in range(40): + opt.zero_grad() + output = X @ w + yhat = y.float()[:, None] + loss = 0.5 * crit(output[yhat == 0], yhat[yhat == 0]) + 0.5 * crit(output[yhat == 1], + yhat[ + yhat == 1]) + 1 / ( + C * len(pos_examples[tt])) * crit2(w, 0 * w) + loss.backward() + opt.step() + + pred = torch.sigmoid(((loc_emb @ w))).flatten() + #pred = torch.from_numpy(clf.predict_proba(loc_emb.cpu()))[:,1] + #pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1)) + #locs = torch.from_numpy(utils.coord_grid((1000,2000))).to(self.eval_params['device']) + #locs = model(enc.encode(locs), return_feats=True) + #img = torch.sigmoid(((locs @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1)) + #plt.imshow(img.detach().cpu()) + elif self.train_params['model'] == 'HyperNet': + if tt not in emb_ids and tt not in emb_ids2: + results['per_species_average_precision_all'][tt_id] = 0.0 + continue + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + if self.eval_params['num_samples'] == -1: + species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] + preds = loc_emb @ species_w.detach() + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt,preds).item() + continue + with torch.no_grad(): + if tt in emb_ids: + em = embs + emi = emb_ids + ky = keys + else: + 
results['per_species_average_precision_all'][tt_id] = 0.0 + continue + em = embs2 + emi = emb_ids2 + ky = keys2 + sec_ind = emi.index(tt) + sections = [i for i,x in enumerate(ky) if x[0] == sec_ind] + order = ['distribution', 'range', 'text'] + best_section = None + order_ind = 0 + while best_section is None and order_ind < len(order): + for section in sections: + if order[order_ind] in ky[section][1].lower(): + best_section = section + break + order_ind += 1 + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + def get_feat(x): + species = model.species_enc(model.species_emb.zero_shot(x)) + species_w, species_b = species[..., :-1], species[..., -1:] + if self.eval_params['num_samples'] == 0: + out = loc_emb @ (species_w.detach()).T + return out + + X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) + y = torch.zeros(X.shape[0], dtype=torch.long, device=self.eval_params['device']) + y[:pos_examples[tt].shape[0]] = 1 + C = 0.05 + + w = torch.nn.Parameter(torch.zeros_like(species_w, device=self.eval_params['device'])) + opt = torch.optim.Rprop([w], lr=0.001) + crit = torch.nn.BCEWithLogitsLoss() + crit2 = torch.nn.MSELoss() + with torch.set_grad_enabled(True): + for i in range(40): + opt.zero_grad() + output = (X @ (w + species_w.detach()).T) + 0*species_b.squeeze(-1) + yhat = y.float()[:, None].repeat(1, w.shape[0]) + loss = 0.5 * crit(output[yhat == 0], yhat[yhat == 0]) + 0.5 * crit( + output[yhat == 1], yhat[yhat == 1]) + 1 / ( + C * len(pos_examples[tt])) * crit2(w, 0 * w) + + loss.backward() + opt.step() + '''out = loc_emb @ (w.data + species_w.detach()).T + gt = torch.zeros(out.shape[0], dtype=torch.float32, + device=self.eval_params['device']) + gt[self.data['taxa_presence'][str(tt)]] = 1.0 + print(utils.average_precision_score_fasterer(gt, out[:, 0]).item())''' + + out = loc_emb @ (w.data + species_w.detach()).T + out = (out + 
0*species_b.squeeze(-1)) + return out + # average precision score: + yfeats = torch.cat([em[section][None].to(self.eval_params['device']) for section in sections]) + preds = get_feat(yfeats) + if len(sections) > 1:#'habitat', 'overview_summary' + kws = [self.eval_params['text_section']] if len(ky) == len(keys) else ['text', 'range','distribution','habitat'] + best_sections = [i for i,s in enumerate(sections) if any((x in ky[s][1].lower() for x in kws))] + #yfeats2 = torch.cat( + # [em[section][None].to(self.eval_params['device']) for section in best_sections]).mean(dim=0, keepdim=True) + #pred2 = get_feat(yfeats2) + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds[:, best_sections].mean(dim=1)).item() + else: + # MINE MINE MINE MINE + # sigmoid_preds = torch.sigmoid(preds[:, 0]) + # results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, sigmoid_preds).item() + results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds[:, 0]).item() + continue + else: + if tt_id % 32 == 0: + # MINE MINE MINE MINE + # with torch.no_grad(): + # preds = torch.empty(loc_feat.shape[0], classes_of_interest[tt_id:tt_id+32].shape[0], device=self.eval_params['device']) + # for i in range(0,preds.shape[0],50000): + # xbatch = loc_feat[i:i+50000] + # ybatch = classes_of_interest[tt_id:tt_id+32].to(self.eval_params['device']).expand(xbatch.shape[0], -1) + # preds[i:i+50000] = model(xbatch, ybatch) + preds = torch.empty(loc_feat.shape[0], classes_of_interest[tt_id:tt_id+32].shape[0], device=self.eval_params['device']) + for i in range(0,preds.shape[0],50000): + xbatch = loc_feat[i:i+50000] + ybatch = classes_of_interest[tt_id:tt_id+32].to(self.eval_params['device']).expand(xbatch.shape[0], -1) + preds[i:i+50000] = model(xbatch, ybatch) + pred = preds[:,tt_id%32] + gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) + 
# NOTE(review): `report` and `batched_matmul` are methods of the evaluator class
# whose header lies above this chunk; re-indent them under that class when the
# file's line structure is restored.
def report(self, results):
    """Print the summary metrics stored in ``results``.

    Args:
        results: mapping that contains the keys 'mean_average_precision',
            'num_eval_species_w_valid_ap' and 'num_eval_species_total'.
    """
    for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
        print(f'{field}: {results[field]}')


def batched_matmul(self, loc_emb, wt):
    """Compute ``loc_emb @ wt.T`` in row batches and return it as a NumPy array.

    Rows of ``loc_emb`` are moved to ``self.eval_params['device']`` one batch
    of ``self.eval_params['batch_size']`` rows at a time, multiplied there,
    and the result is copied into a preallocated float32 host array.

    Args:
        loc_emb: 2-D tensor with one row per sample.
        wt: 2-D tensor with one row per output column.

    Returns:
        ``np.ndarray`` of dtype float32 and shape
        ``(loc_emb.size(0), wt.size(0))``.
    """
    batch_size = self.eval_params["batch_size"]
    device = self.eval_params['device']
    num_samples = loc_emb.size(0)

    # Preallocate the result array; each batch is written straight into its
    # final rows, so no intermediate staging buffer is required.
    pred_mtx = np.empty((num_samples, wt.size(0)), dtype=np.float32)

    # Tensor.numpy() refuses tensors that require grad, so run the whole
    # product without autograd tracking.
    with torch.no_grad():
        # Keep the weights on the same device as the batches they multiply.
        wt_T = wt.t().to(device)
        for start_idx in tqdm(range(0, num_samples, batch_size)):
            end_idx = min(start_idx + batch_size, num_samples)
            loc_emb_batch = loc_emb[start_idx:end_idx].to(device)
            pred_mtx[start_idx:end_idx] = torch.matmul(loc_emb_batch, wt_T).cpu().numpy()

    return pred_mtx
class EvaluatorGeoPrior:
    """Evaluate a geo model as a prior combined with vision-model predictions.

    The vision predictions, their top-k indices, the ground-truth labels and
    the vision model's taxa are read from ``geo_prior_model_preds.npz``;
    observation coordinates and dates come from ``geo_prior_model_meta.csv``.
    Both files live in the ``geo_prior`` directory named in ``paths.json``.
    """

    def __init__(self, train_params, eval_params):
        """Load the evaluation data and build the vision-to-geo taxon mapping.

        Args:
            train_params: training parameters; 'class_to_taxa' and
                'input_time' are read here.
            eval_params: evaluation parameters; 'device' is read here.
        """
        # store parameters:
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        # load vision model predictions:
        self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
        print(self.data['probs'].shape[0], 'total test observations')
        # load locations:
        meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
        self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)
        # Split the first ten characters of each date string into single bytes
        # so month and day can be sliced out by position (characters 5:7, 8:10).
        temp = np.array(meta['observed_on'].values, dtype='S10')
        temp = temp.view('S1').reshape((temp.size, -1))
        months = temp[:, 5:7].view('S2').astype(int)[:, 0]
        days = temp[:, 8:10].view('S2').astype(int)[:, 0]
        # Cumulative day offset of each month, using the month lengths of 2018.
        days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
        dates = days_per_month[months - 1] + days - 1
        self.dates = np.round((dates) / 365.0, 4).astype(np.float32)
        # taxonomic mapping (read the archive member once and reuse it):
        model_to_taxa = self.data['model_to_taxa']
        self.taxon_map = self.find_mapping_between_models(model_to_taxa, self.train_params['class_to_taxa'])
        self.time_enc = utils.TimeEncoder() if train_params['input_time'] else None
        print(self.taxon_map.shape[0], 'out of', len(model_to_taxa), 'taxa in both vision and geo models')

        cs = torch.load('class_counts.pt')
        cs = cs.sum() / cs
        cs = cs.to(self.eval_params['device'])
        self.C = cs[None]
        self.pdf = utils.DataPDFH3(device=self.eval_params['device'])

    def find_mapping_between_models(self, vision_taxa, geo_taxa):
        """Map vision-model class indices to geo-model class indices.

        Args:
            vision_taxa: 1-D array of taxa in vision-model order.
            geo_taxa: sequence of taxa in geo-model order.

        Returns:
            int32 array of shape ``(N_overlap, 2)``. Column 0 holds the index
            in ``vision_taxa`` and column 1 the index of the first matching
            entry in ``geo_taxa``; taxa missing from ``geo_taxa`` are dropped.
        """
        vision_list = np.asarray(vision_taxa).tolist()
        geo_list = np.asarray(geo_taxa).tolist()
        # One pass over the geo taxa replaces a full array scan per vision
        # taxon; setdefault keeps the first occurrence of duplicated taxa.
        first_geo_index = {}
        for geo_idx, taxon in enumerate(geo_list):
            first_geo_index.setdefault(taxon, geo_idx)
        pairs = [
            (vision_idx, first_geo_index[taxon])
            for vision_idx, taxon in enumerate(vision_list)
            if taxon in first_geo_index
        ]
        return np.array(pairs, dtype=np.int32).reshape(-1, 2)

    def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map, k=1.0):
        """Expand sparse vision predictions and reorder geo predictions.

        Args:
            geo_pred_ip: geo predictions, one row per observation, columns in
                geo-model class order.
            vision_top_k_prob: top-k vision probabilities per observation.
            vision_top_k_inds: vision class indices matching
                ``vision_top_k_prob``.
            vision_taxa: taxa of the vision model; only its length is used.
            taxon_map: output of ``find_mapping_between_models``.
            k: value used for vision classes without a geo counterpart.

        Returns:
            ``(geo_pred, vision_pred)``, both float32 arrays of shape
            ``(num_observations, len(vision_taxa))``.
        """
        # this is slow as we turn the sparse input back into the same size as the dense one
        num_obs = geo_pred_ip.shape[0]
        vision_pred = np.zeros((num_obs, len(vision_taxa)), dtype=np.float32)
        geo_pred = k*np.ones((num_obs, len(vision_taxa)), dtype=np.float32)
        row_index = np.arange(num_obs)[..., np.newaxis]
        vision_pred[row_index, vision_top_k_inds] = vision_top_k_prob

        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]

        return geo_pred, vision_pred

    def run_evaluation(self, model, enc, extra_input=None):
        """Compute top-1 accuracy of the vision model with and without the geo prior.

        Besides returning the metrics, this writes ``correct_{noise_level}.pt``
        and ``abt_{noise_level}.pt`` to the working directory.

        Args:
            model: geo model, called on the encoded location features.
            enc: coordinate encoder exposing ``encode``.
            extra_input: extra features concatenated to the encoded locations;
                ignored when a time encoder is configured, because per-batch
                time features are used instead.

        Returns:
            dict with 'vision_only_top_1' and 'vision_geo_top_1'.
        """
        results = {}
        device = self.eval_params['device']
        batch_size = self.eval_params['batch_size']

        # self.data is an .npz archive whose members are read from disk on
        # every key access, so fetch each array once instead of per batch.
        probs = self.data['probs']
        top_k_inds = self.data['inds']
        labels = self.data['labels']
        model_to_taxa = self.data['model_to_taxa']
        num_obs = probs.shape[0]

        # Defined before the loop: it is also used in the output file names,
        # which must work even when there are no batches.
        noise_level = 1.0

        # loop over in batches
        batch_start = np.hstack((np.arange(0, num_obs, batch_size), num_obs))
        correct_pred = np.zeros(num_obs)
        for bb in tqdm(range(len(batch_start) - 1)):
            batch_inds = np.arange(batch_start[bb], batch_start[bb + 1])

            vision_probs = probs[batch_inds, :]
            vision_inds = top_k_inds[batch_inds, :]
            gt = labels[batch_inds]
            dates = torch.from_numpy(self.dates[batch_inds])

            obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(device)
            if self.time_enc is not None:
                time_input = torch.cat([dates[..., None], torch.full((*dates.shape, 1), noise_level)], dim=1)
                batch_extra = self.time_enc.encode(time_input).to(device)
            else:
                batch_extra = extra_input
            loc_feat = enc.encode(obs_locs_batch)
            if batch_extra is not None:
                loc_feat = torch.cat([loc_feat, batch_extra], 1)

            with torch.no_grad():
                geo_pred = model(loc_feat).cpu().numpy()

            geo_pred, vision_pred = self.convert_to_inat_vision_order(
                geo_pred, vision_probs, vision_inds, model_to_taxa, self.taxon_map, k=1.0)
            comb_pred = np.argmax(vision_pred*geo_pred, 1)
            correct_pred[batch_inds] = (comb_pred == gt)

        accuracy_by_taxa = np.zeros(len(model_to_taxa))
        for tt_id, tt in enumerate(model_to_taxa):
            # NOTE(review): labels are compared with class indices above but
            # with taxon values here - confirm which one is intended.
            inds = np.where(labels == tt)[0]
            accuracy_by_taxa[tt_id] = float((correct_pred[inds].mean()))
        torch.save(correct_pred, f'correct_{noise_level}.pt')
        torch.save(accuracy_by_taxa, f'abt_{noise_level}.pt')
        results['vision_only_top_1'] = float((top_k_inds[:, -1] == labels).mean())
        results['vision_geo_top_1'] = float(correct_pred.mean())
        return results

    def report(self, results):
        """Print the vision-only accuracy, the combined accuracy and the gain."""
        print('Overall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model ', round(results['vision_geo_top_1'], 3))
        print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
class EvaluatorGeoFeature:
    """Evaluate location features by regressing geospatial rasters from them.

    For every raster in ``raster_names`` a ridge regression is fitted on the
    model features of the training cells of a spatial split and scored with
    R^2 on the training and test cells.
    """

    def __init__(self, train_params, eval_params):
        """Store the parameters and load the country mask.

        Args:
            train_params: training parameters; 'model' is read in
                ``run_evaluation``.
            eval_params: evaluation parameters; 'device' and 'cell_size' are
                read by the other methods.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        self.data_path = paths['geo_feature']
        self.country_mask = tifffile.imread(os.path.join(paths['masks'], 'USA_MASK.tif')) == 1
        self.raster_names = ['ABOVE_GROUND_CARBON', 'ELEVATION', 'LEAF_AREA_INDEX', 'NON_TREE_VEGITATED', 'NOT_VEGITATED', 'POPULATION_DENSITY', 'SNOW_COVER', 'SOIL_MOISTURE', 'TREE_COVER']
        self.raster_names_log_transform = ['POPULATION_DENSITY']

    @staticmethod
    def _rescale_valid(raster, valid_mask, log_transform=False):
        """Rescale ``raster[valid_mask]`` to [0, 1] in place and return ``raster``.

        With ``log_transform`` the valid values are first replaced by
        ``log1p(value - min)``. A raster whose valid values are all equal is
        mapped to zeros instead of being divided by zero, and a raster without
        any valid value is returned unchanged.
        """
        values = raster[valid_mask]
        if values.size == 0:
            return raster
        if log_transform:
            values = np.log1p(values - values.min())
        values = values - values.min()
        peak = values.max()
        # Skip the division for a constant raster: 0 / 0 would fill it with NaN.
        if peak > 0:
            values = values / peak
        raster[valid_mask] = values
        return raster

    def load_raster(self, raster_name, log_transform=False):
        """Load ``<raster_name>.tif`` and rescale its valid cells to [0, 1].

        Args:
            raster_name: file name without extension inside ``self.data_path``.
            log_transform: apply a log scaling before the 0/1 scaling.

        Returns:
            ``(raster, valid_mask)``: the float32 raster and a boolean mask of
            cells that are not NaN and lie inside the country mask.
        """
        raster = tifffile.imread(os.path.join(self.data_path, raster_name + '.tif')).astype(np.float32)
        valid_mask = ~np.isnan(raster) & self.country_mask
        self._rescale_valid(raster, valid_mask, log_transform)
        return raster, valid_mask

    def get_split_labels(self, raster, split_ids, split_of_interest):
        """Return the raster values of the cells assigned to ``split_of_interest``."""
        # get the GT labels for a subset
        inds_y, inds_x = np.where(split_ids == split_of_interest)
        return raster[inds_y, inds_x]

    def get_split_feats(self, model, enc, split_ids, split_of_interest, extra_input=None):
        """Return model features (NumPy array) for the cells of one split.

        Args:
            model: callable accepting ``(encoded_locs, return_feats=True)``.
            enc: coordinate encoder exposing ``encode``.
            split_ids: per-cell split assignment.
            split_of_interest: split id whose cells are evaluated.
            extra_input: optional tensor expanded to one row per location and
                concatenated to the encoded coordinates.
        """
        locs = utils.coord_grid(self.country_mask.shape, split_ids=split_ids, split_of_interest=split_of_interest)
        locs = torch.from_numpy(locs).to(self.eval_params['device'])
        locs_enc = enc.encode(locs)
        if extra_input is not None:
            locs_enc = torch.cat([locs_enc, extra_input.expand(locs.shape[0], -1)], 1)
        with torch.no_grad():
            feats = model(locs_enc, return_feats=True).cpu().numpy()
        return feats

    def run_evaluation(self, model2, enc, extra_input=None):
        """Fit and score a ridge regression per raster.

        Args:
            model2: trained model; used directly for 'ResidualFCNet', through
                its ``pos_enc`` for 'HyperNet'.
            enc: coordinate encoder exposing ``encode``.
            extra_input: optional extra features, see ``get_split_feats``.

        Returns:
            dict with 'train_r2_<name>', 'test_r2_<name>' and 'alpha_<name>'
            for every raster name.

        Raises:
            NotImplementedError: for any other ``train_params['model']``.
        """
        model_name = self.train_params['model']
        if model_name == 'ResidualFCNet':
            model = model2
        elif model_name == 'HyperNet':
            # Give the position encoder the same call signature as the other model.
            def model(x, return_feats=True):
                return model2.pos_enc(x)
        else:
            raise NotImplementedError()
        results = {}
        for raster_name in self.raster_names:
            do_log_transform = raster_name in self.raster_names_log_transform
            raster, valid_mask = self.load_raster(raster_name, do_log_transform)
            split_ids = utils.create_spatial_split(raster, valid_mask, cell_size=self.eval_params['cell_size'])
            feats_train = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=1, extra_input=extra_input)
            feats_test = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=2, extra_input=extra_input)
            labels_train = self.get_split_labels(raster, split_ids, 1)
            labels_test = self.get_split_labels(raster, split_ids, 2)
            # The scaler is fitted on the training features only.
            scaler = MinMaxScaler()
            feats_train_scaled = scaler.fit_transform(feats_train)
            feats_test_scaled = scaler.transform(feats_test)
            clf = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=10, fit_intercept=True, scoring='r2').fit(feats_train_scaled, labels_train)
            train_score = clf.score(feats_train_scaled, labels_train)
            test_score = clf.score(feats_test_scaled, labels_test)
            results[f'train_r2_{raster_name}'] = float(train_score)
            results[f'test_r2_{raster_name}'] = float(test_score)
            results[f'alpha_{raster_name}'] = float(clf.alpha_)
        return results

    def report(self, results):
        """Print every 'test_r2' entry of ``results`` followed by their mean."""
        report_fields = [x for x in results if 'test_r2' in x]
        for field in report_fields:
            print(f'{field}: {results[field]}')
        print(np.mean([results[field] for field in report_fields]))
None: + for key, value in train_overrides.items(): + #print(f'updating train param {key}') + train_params['params'][key] = value + + model = models.get_model(train_params['params'], inference_only=True) + model.load_state_dict(train_params['state_dict'], strict=False) + model = model.to(eval_params['device']) + model.eval() + + # create input encoder: + if train_params['params']['input_enc'] in ['env', 'sin_cos_env', 'sh_env']: + raster = datasets.load_env().to(eval_params['device']) + else: + raster = None + enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim']) + if train_params['params']['input_time']: + time_enc = utils.TimeEncoder(input_enc='conical') if train_params['params']['input_time'] else None + extra_input = torch.cat([time_enc.encode(torch.tensor([[0.0, 1.0]]))], dim=1).to(eval_params['device']) + else: + extra_input = None + + # This should only effect my models + # This is where I create the eval "species tokens" from the specified number of context points + # TODO just use the existing train params and some if statements to get the right dataset without having to use train overides + if train_params['params']['model'] == 'MultiInputModel': + + train_dataset = datasets.get_train_data(train_params['params']) + + if 'text' in train_params['params']['dataset']: + if eval_params['text_section'] != '': + train_dataset.select_text_section(eval_params['text_section']) + print(f'Using {eval_params["text_section"]} text for evaluation') + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_params['params']['batch_size'], + shuffle=True, + num_workers=8, + collate_fn=getattr(train_dataset, 'collate_fn', None)) + + # if len(train_params['params']['class_to_taxa']) != train_dataset.class_to_taxa: + + # Create new embedding layers for the expanded classes + num_new_classes = len(train_dataset.class_to_taxa) + embedding_dim = model.ema_embeddings.embedding_dim + 
new_ema_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) + new_eval_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) + nn.init.xavier_uniform_(new_ema_embeddings.weight) + nn.init.xavier_uniform_(new_eval_embeddings.weight) + + # Convert lists to numpy arrays for indexing + class_to_taxa_np = np.array(train_params['params']['class_to_taxa']) + class_to_taxa_expanded_np = np.array(train_dataset.class_to_taxa) + + # Find common taxa and their indices + common_taxa, original_indices, expanded_indices = np.intersect1d( + class_to_taxa_np, class_to_taxa_expanded_np, return_indices=True) + + # Update new embeddings for the common taxa + new_ema_embeddings.weight.data[expanded_indices] = model.ema_embeddings.weight.data[original_indices] + new_eval_embeddings.weight.data[expanded_indices] = model.eval_embeddings.weight.data[original_indices] + + # Replace old embeddings with new embeddings + model.ema_embeddings = new_ema_embeddings + model.eval_embeddings = new_eval_embeddings + + # Print to verify + #print("Updating EMA Embeddings: ", model.ema_embeddings.weight.size()) + #print("Updating Eval Embeddings: ", model.eval_embeddings.weight.size()) + + train_params['params']['class_to_taxa'] = train_dataset.class_to_taxa + + for _, batch in tqdm(enumerate(train_loader)): + if train_params['params']['use_text_inputs']: + loc_feat, _, class_id, context_feats, _, context_mask, embs = batch + loc_feat = loc_feat.to(eval_params['device']) + class_id = class_id.to(eval_params['device']) + context_feats = context_feats.to(eval_params['device']) + context_mask = context_mask.to(eval_params['device']) + embs = embs.to(eval_params['device']) + # Don't need to do anything with these probs - I am just updating the "eval embeddings" + probs = model.forward( + x=loc_feat, + context_sequence=context_feats, + context_mask=context_mask, + class_ids=class_id, + 
return_feats=False, + return_class_embeddings=False, + class_of_interest=None, + use_eval_embeddings=True, + text_emb = embs) + else: + loc_feat, _, class_id, context_feats, _, context_mask = batch + loc_feat = loc_feat.to(eval_params['device']) + class_id = class_id.to(eval_params['device']) + context_feats = context_feats.to(eval_params['device']) + context_mask = context_mask.to(eval_params['device']) + # Don't need to do anything with these probs - I am just updating the "eval embeddings" + probs = model.forward( + x=loc_feat, + context_sequence=context_feats, + context_mask=context_mask, + class_ids=class_id, + return_feats=False, + return_class_embeddings=False, + class_of_interest=None, + use_eval_embeddings=True + ) + print('eval embeddings generated!') + + elif train_params['params']['model'] == 'VariableInputModel': + + train_dataset = datasets.get_train_data(train_params['params']) + + if train_dataset.use_text: + if eval_params['text_section'] != '': + train_dataset.select_text_section(eval_params['text_section']) + print(f'Using {eval_params["text_section"]} text for evaluation') + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_params['params']['batch_size'], + shuffle=True, + num_workers=8, + collate_fn=getattr(train_dataset, 'collate_fn', None)) + + # if len(train_params['params']['class_to_taxa']) != train_dataset.class_to_taxa: + + # Create new embedding layers for the expanded classes + num_new_classes = len(train_dataset.class_to_taxa) + embedding_dim = model.ema_embeddings.embedding_dim + new_ema_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) + new_eval_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) + nn.init.xavier_uniform_(new_ema_embeddings.weight) + nn.init.xavier_uniform_(new_eval_embeddings.weight) + + # Convert lists to numpy arrays for indexing + class_to_taxa_np = 
np.array(train_params['params']['class_to_taxa']) + class_to_taxa_expanded_np = np.array(train_dataset.class_to_taxa) + + # Find common taxa and their indices + common_taxa, original_indices, expanded_indices = np.intersect1d( + class_to_taxa_np, class_to_taxa_expanded_np, return_indices=True) + + # Update new embeddings for the common taxa + new_ema_embeddings.weight.data[expanded_indices] = model.ema_embeddings.weight.data[original_indices] + new_eval_embeddings.weight.data[expanded_indices] = model.eval_embeddings.weight.data[original_indices] + + # Replace old embeddings with new embeddings + model.ema_embeddings = new_ema_embeddings + model.eval_embeddings = new_eval_embeddings + + # Print to verify + #print("Updating EMA Embeddings: ", model.ema_embeddings.weight.size()) + #print("Updating Eval Embeddings: ", model.eval_embeddings.weight.size()) + + train_params['params']['class_to_taxa'] = train_dataset.class_to_taxa + + for _, batch in tqdm(enumerate(train_loader)): + loc_feat, _, class_id, context_feats, _, context_mask, text_emb, image_emb, env_emb = batch + # print('DO I NEED THE BELOW LINES? 
DO THEY SLOW THINGS DOWN') + # return padded_sequences, padded_locs, class_ids, sequence_mask + loc_feat = loc_feat.to(eval_params['device']) + class_id = class_id.to(eval_params['device']) + context_feats = context_feats.to(eval_params['device']) + context_mask = context_mask.to(eval_params['device']) + text_emb = text_emb.to(eval_params['device']) + image_emb = image_emb.to(eval_params['device']) + if env_emb is not None: + env_emb = env_emb.to(eval_params['device']) + # Don't need to do anything with these probs - I am just updating the "eval embeddings" + + probs = model.forward(x=loc_feat, + context_sequence=context_feats, + context_mask=context_mask, + class_ids=class_id, + text_emb=text_emb, + image_emb=image_emb, + env_emb=env_emb, + return_feats=False, + return_class_embeddings=False, + class_of_interest=None, + use_eval_embeddings=True) + + print('eval embeddings generated!') + + print('\n' + eval_params['eval_type']) + t = time.time() + if eval_params['eval_type'] == 'snt': + eval_params['split'] = 'test' # val, test, all + eval_params['val_frac'] = 0.50 + eval_params['split_seed'] = 7499 + evaluator = EvaluatorSNT(train_params['params'], eval_params) + results = evaluator.run_evaluation(model, enc, extra_input=extra_input) + evaluator.report(results) + elif eval_params['eval_type'] == 'iucn': + evaluator = EvaluatorIUCN(train_params['params'], eval_params) + results = evaluator.run_evaluation(model, enc, extra_input=extra_input) + evaluator.report(results) + elif eval_params['eval_type'] == 'geo_prior': + evaluator = EvaluatorGeoPrior(train_params['params'], eval_params) + results = evaluator.run_evaluation(model, enc, extra_input=extra_input) + evaluator.report(results) + elif eval_params['eval_type'] == 'geo_feature': + evaluator = EvaluatorGeoFeature(train_params['params'], eval_params) + results = evaluator.run_evaluation(model, enc, extra_input=extra_input) + evaluator.report(results) + else: + raise NotImplementedError('Eval type not 
class EvaluatorGeoPriorLowRank:
    """Score a geo model used as a prior over vision-model predictions.

    The evaluator loads pre-computed vision predictions (``probs``, ``inds``,
    ``labels``, ``model_to_taxa``) and per-observation metadata from the
    ``geo_prior`` directory named in ``paths.json``. ``run_evaluation``
    multiplies the vision probabilities by a softmax over the geo model's
    biased log-output and reports top-1 accuracy with and without the prior.
    """

    # Constant added to the geo log-output before the softmax. The original
    # code looped over ``range(11 + 5, 12 + 5)``, i.e. the single value 16.
    GEO_LOG_BIAS = 16

    def __init__(self, train_params, eval_params):
        """Load predictions, observation locations/dates and the taxon mapping.

        Args:
            train_params: mapping that must contain ``'class_to_taxa'``.
            eval_params: mapping read later for ``'batch_size'`` and
                ``'device'``.
        """
        # store parameters:
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        # load vision model predictions:
        self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
        print(self.data['probs'].shape[0], 'total test observations')
        # load locations (stored as one (longitude, latitude) row per observation):
        meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
        self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)
        # Split the fixed-width date strings into characters so month and day
        # can be sliced out by position.
        # NOTE(review): assumes 'observed_on' is formatted 'YYYY-MM-DD' - confirm.
        temp = np.array(meta['observed_on'].values, dtype='S10')
        temp = temp.view('S1').reshape((temp.size, -1))
        months = temp[:, 5:7].view('S2').astype(int)[:, 0]
        days = temp[:, 8:10].view('S2').astype(int)[:, 0]
        # Cumulative day offset of each month, using 2018's month lengths.
        days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
        dates = days_per_month[months - 1] + days - 1
        # Day of year expressed as a fraction of a 365-day year.
        self.dates = np.round((dates) / 365.0, 4).astype(np.float32)
        # taxonomic mapping:
        self.taxon_map = self.find_mapping_between_models(self.data['model_to_taxa'], self.train_params['class_to_taxa'])
        print(self.taxon_map.shape[0], 'out of', len(self.data['model_to_taxa']), 'taxa in both vision and geo models')

    def find_mapping_between_models(self, vision_taxa, geo_taxa):
        """Map vision-model class indices to geo-model class indices.

        Args:
            vision_taxa: 1-D array of taxon ids in vision-model order.
            geo_taxa: sequence of taxon ids in geo-model order.

        Returns:
            ``int32`` array of shape ``(N_overlap, 2)``. Column 0 is the index
            in the vision model, column 1 the index of the first occurrence of
            the same taxon in the geo model. Taxa absent from the geo model
            are dropped.
        """
        # One pass over the geo taxa replaces a full array scan per vision taxon.
        geo_index = {}
        for idx, taxon in enumerate(np.asarray(geo_taxa).tolist()):
            geo_index.setdefault(taxon, idx)  # keep the first occurrence

        taxon_map = np.ones((vision_taxa.shape[0], 2), dtype=np.int32) * -1
        taxon_map[:, 0] = np.arange(vision_taxa.shape[0])
        for row, taxon in enumerate(np.asarray(vision_taxa).tolist()):
            taxon_map[row, 1] = geo_index.get(taxon, -1)
        return taxon_map[taxon_map[:, 1] > -1, :]

    def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
        """Expand sparse vision and geo predictions to dense vision-order arrays.

        Args:
            geo_pred_ip: array indexed ``[observation, geo class]``.
            vision_top_k_prob: top-k vision probabilities per observation.
            vision_top_k_inds: vision class index of each top-k probability.
            vision_taxa: taxa of the vision model (only its length is used).
            taxon_map: ``(N_overlap, 2)`` array from
                ``find_mapping_between_models``.

        Returns:
            ``(geo_pred, vision_pred)``, both ``float32`` with one column per
            vision taxon. Taxa missing from the geo model keep a geo value of
            1; classes outside the top-k keep a vision value of 0.
        """
        # this is slow as we turn the sparse input back into the same size as the dense one
        num_obs = geo_pred_ip.shape[0]
        vision_pred = np.zeros((num_obs, len(vision_taxa)), dtype=np.float32)
        geo_pred = np.ones((num_obs, len(vision_taxa)), dtype=np.float32)
        vision_pred[np.arange(num_obs)[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob
        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]
        return geo_pred, vision_pred

    def run_evaluation(self, model):
        """Compute top-1 accuracy of the vision model alone and with the geo prior.

        Args:
            model: object exposing ``sample(locs)``; the log of its output is
                transposed before use, so it is presumably indexed
                ``[class, observation]`` - TODO confirm against the model.

        Returns:
            dict with ``'vision_only_top_1'`` and ``'vision_geo_top_1'``.
        """
        results = {}

        # Read each array once: indexing the object returned by np.load for a
        # .npz file loads the member again on every access.
        vision_probs_all = self.data['probs']
        vision_inds_all = self.data['inds']
        labels = self.data['labels']
        vision_taxa = self.data['model_to_taxa']

        num_obs = vision_probs_all.shape[0]
        batch_size = self.eval_params['batch_size']
        device = self.eval_params['device']
        correct_pred = np.zeros(num_obs)

        # loop over in batches
        for start in tqdm(range(0, num_obs, batch_size)):
            batch_inds = np.arange(start, min(start + batch_size, num_obs))
            obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(device)

            with torch.no_grad():
                geo_log_pdf = torch.log(model.sample(obs_locs_batch)).T
            # Bring the result to host memory before mixing it with NumPy arrays.
            geo_log_pdf = geo_log_pdf.detach().cpu().numpy()

            geo_pred, vision_pred = self.convert_to_inat_vision_order(
                geo_log_pdf + self.GEO_LOG_BIAS,
                vision_probs_all[batch_inds, :],
                vision_inds_all[batch_inds, :],
                vision_taxa,
                self.taxon_map)
            geo_pred = softmax(torch.from_numpy(geo_pred), dim=1).numpy()

            comb_pred = np.argmax(vision_pred * geo_pred, 1)
            correct_pred[batch_inds] = (comb_pred == labels[batch_inds])

        # The last column of 'inds' is compared against the label as the
        # vision model's top-1 prediction.
        results['vision_only_top_1'] = float((vision_inds_all[:, -1] == labels).mean())
        results['vision_geo_top_1'] = float(correct_pred.mean())
        return results

    def report(self, results):
        """Print the accuracies returned by ``run_evaluation`` and their difference."""
        print('Overall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model ', round(results['vision_geo_top_1'], 3))
        print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
# Helpers for low-shot plotting. They could live in a separate module.
def _resize_class_embeddings(model, old_class_to_taxa, new_class_to_taxa, device):
    """Replace the model's EMA/eval embeddings with tables sized for new classes.

    New tables are Xavier-initialised; rows for taxa present in both class
    lists are copied over from the existing tables.
    """
    num_classes = len(new_class_to_taxa)
    embedding_dim = model.ema_embeddings.embedding_dim
    new_ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=embedding_dim).to(device)
    new_eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=embedding_dim).to(device)
    nn.init.xavier_uniform_(new_ema_embeddings.weight)
    nn.init.xavier_uniform_(new_eval_embeddings.weight)

    # Indices of the taxa shared by the old and the new class lists.
    _, original_indices, expanded_indices = np.intersect1d(
        np.array(old_class_to_taxa), np.array(new_class_to_taxa), return_indices=True)

    new_ema_embeddings.weight.data[expanded_indices] = model.ema_embeddings.weight.data[original_indices]
    new_eval_embeddings.weight.data[expanded_indices] = model.eval_embeddings.weight.data[original_indices]

    model.ema_embeddings = new_ema_embeddings
    model.eval_embeddings = new_eval_embeddings


def generate_eval_embeddings(overrides, taxa_of_interest, num_context, train_overrides=None):
    """Load a checkpoint on CPU and run one forward pass for a single taxon.

    The forward pass is made with ``use_eval_embeddings=True`` using the first
    ``num_context`` context rows stored in the training dataset for the taxon.
    Its return value is discarded; it is presumably called for its effect on
    the model's eval embeddings - confirm against the model implementation.

    Args:
        overrides: passed to ``setup.get_default_params_eval``.
        taxa_of_interest: taxon id; must be in the dataset's ``class_to_taxa``.
        num_context: number of leading context rows to use.
        train_overrides: optional mapping of training parameters that replace
            the values stored in the checkpoint.

    Returns:
        ``(model, context_locs_of_interest, train_params, class_of_interest)``.

    Raises:
        ValueError: if the taxon is not in ``class_to_taxa`` or has no
            training observation.
    """
    eval_params = setup.get_default_params_eval(overrides)

    # set up model:
    eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
    eval_params['device'] = 'cpu'
    device = eval_params['device']
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    train_params = torch.load(eval_params['model_path'], map_location='cpu')
    train_params['params']['device'] = 'cpu'
    # Fill in parameters that the checkpoint does not define.
    default_params = setup.get_default_params_train()
    for key in default_params:
        if key not in train_params['params']:
            train_params['params'][key] = default_params[key]

    if train_overrides is not None:
        for key, value in train_overrides.items():
            train_params['params'][key] = value

    train_dataset = datasets.get_train_data(train_params['params'])

    model = models.get_model(train_params['params'], inference_only=True)
    # strict=False: keys missing from or unexpected in the checkpoint are tolerated.
    model.load_state_dict(train_params['state_dict'], strict=False)
    model = model.to(device)
    model.eval()

    # Resize the embedding tables to the dataset's class list, then record
    # that list as the model's class order.
    _resize_class_embeddings(model, train_params['params']['class_to_taxa'], train_dataset.class_to_taxa, device)
    train_params['params']['class_to_taxa'] = train_dataset.class_to_taxa

    class_of_interest = train_dataset.class_to_taxa.index(taxa_of_interest)

    # Use the first observation labelled with the class as the query location.
    # Calling .item() on the full index tensor fails when there are several.
    matching = (train_dataset.labels == class_of_interest).nonzero(as_tuple=True)[0]
    if matching.numel() == 0:
        raise ValueError(f'no training observation found for taxa {taxa_of_interest}')
    loc_index_of_interest = matching[0].item()
    loc_of_interest = train_dataset.loc_feats[loc_index_of_interest]

    context_feats_of_interest = train_dataset.per_class_loc_feats[class_of_interest][:num_context, :]
    context_locs_of_interest = train_dataset.per_class_locs[class_of_interest][:num_context, :]

    # Rows whose coordinates all equal -10 are masked (presumably padding - confirm).
    context_mask = (context_locs_of_interest == -10).all(dim=-1).to(device).unsqueeze(0)

    # Inputs go to the device the model lives on; train_overrides could have
    # changed train_params['params']['device'].
    with torch.no_grad():
        model.forward(
            x=loc_of_interest.to(device),
            context_sequence=context_feats_of_interest.to(device),
            context_mask=context_mask,
            class_ids=class_of_interest,
            return_feats=False,
            return_class_embeddings=False,
            class_of_interest=None,
            use_eval_embeddings=True
        )

    return model, context_locs_of_interest, train_params, class_of_interest
def generate_eval_embedding_from_given_points(context_points, overrides, taxa_of_interest, train_overrides=None, text_emb=None):
    """Load a checkpoint and run one forward pass on caller-supplied context points.

    The points are encoded with the model's context encoder and passed,
    unmasked, as the context sequence for class id 0 with
    ``use_eval_embeddings=True``. The forward result is discarded; the call is
    presumably made for its effect on the model's eval embeddings - confirm
    against the model implementation.

    Args:
        context_points: array-like of locations, converted to a float32 tensor.
            NOTE(review): presumably (longitude, latitude) rows - confirm.
        overrides: passed to ``setup.get_default_params_eval``.
        taxa_of_interest: unused; kept for interface compatibility.
        train_overrides: optional mapping of training parameters that replace
            the values stored in the checkpoint.
        text_emb: optional text embedding forwarded to the model.

    Returns:
        ``(model, context_locs_of_interest, train_params, class_of_interest)``
        where ``context_locs_of_interest`` is the tensor built from
        ``context_points`` and ``class_of_interest`` is always 0.
    """
    eval_params = setup.get_default_params_eval(overrides)
    device = eval_params['device']

    # set up model:
    eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    train_params = torch.load(eval_params['model_path'], map_location='cpu')
    # Fill in parameters that the checkpoint does not define.
    default_params = setup.get_default_params_train()
    for key in default_params:
        if key not in train_params['params']:
            train_params['params'][key] = default_params[key]

    # create input encoder:
    if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
        raster = datasets.load_env().to(device)
    else:
        raster = None
    enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim'])

    if train_overrides is not None:
        for key, value in train_overrides.items():
            train_params['params'][key] = value

    # create context point encoder
    transformer_input_enc = train_params['params']['transformer_input_enc']
    if transformer_input_enc in ['env', 'sin_cos_env']:
        transformer_raster = datasets.load_env().to(device)
    else:
        transformer_raster = None
    token_dim = train_params['params']['species_dim']

    if transformer_input_enc == 'sinr':
        # 'sinr' reuses the location encoder for the context points.
        transformer_enc = enc
    else:
        transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim)

    # load model
    model = models.get_model(train_params['params'], inference_only=True)
    # strict=False: keys missing from or unexpected in the checkpoint are tolerated.
    model.load_state_dict(train_params['state_dict'], strict=False)
    model = model.to(device)
    model.eval()

    # Re-wrap the eval embeddings in a fresh nn.Embedding of the same size.
    # The new module's weight data is the existing weight tensor, not a copy.
    embedding_dim = model.ema_embeddings.embedding_dim
    new_eval_embeddings = nn.Embedding(num_embeddings=model.eval_embeddings.weight.size()[0], embedding_dim=embedding_dim).to(device)
    new_eval_embeddings.weight.data = model.eval_embeddings.weight.data
    model.eval_embeddings = new_eval_embeddings

    class_of_interest = 0

    # Fixed (0, 0) query location.
    just_loc = torch.from_numpy(np.array([[0.0, 0.0]]).astype(np.float32))
    loc_of_interest = enc.encode(just_loc, normalize=False)

    context_points = torch.from_numpy(np.array(context_points).astype(np.float32))
    context_feats_of_interest = transformer_enc.encode(context_points, normalize=False)
    context_locs_of_interest = context_points

    # No context point is masked. The mask is created on the model's device so
    # it matches the other inputs.
    context_mask = torch.zeros((1, context_feats_of_interest.shape[0]), dtype=torch.bool, device=device)

    # Keep an optional tensor text embedding on the same device as well.
    if torch.is_tensor(text_emb):
        text_emb = text_emb.to(device)

    with torch.no_grad():
        model.forward(
            x=loc_of_interest.to(device),
            context_sequence=context_feats_of_interest.to(device),
            context_mask=context_mask,
            class_ids=class_of_interest,
            return_feats=False,
            return_class_embeddings=False,
            class_of_interest=None,
            use_eval_embeddings=True,
            text_emb=text_emb
        )

    return model, context_locs_of_interest, train_params, class_of_interest