Spaces:
Sleeping
Sleeping
import collections | |
import sys | |
import numpy as np | |
import pandas as pd | |
import random | |
import torch | |
import time | |
import os | |
import json | |
import tifffile | |
import h3 | |
import setup | |
from sklearn.linear_model import RidgeCV | |
from sklearn.preprocessing import MinMaxScaler | |
from torch.utils.data import Subset | |
import utils | |
import models | |
import datasets | |
from calendar import monthrange | |
from torch.nn.functional import logsigmoid, softmax | |
import torch.nn as nn | |
from tqdm import tqdm | |
import csv | |
def format_tensor(tensor):
    """Render ``tensor`` as a compact, single-line string.

    The value is converted with ``tensor.tolist()`` and printed with ``str``;
    every space and newline character is then removed from the text.
    """
    # One translation pass deletes both characters at once.
    strip_table = str.maketrans('', '', '\n ')
    rendered = str(tensor.tolist())
    return rendered.translate(strip_table)
class EvaluatorSNT:
    """Per-species average-precision evaluation on the ``snt_res_5.npy`` data.

    The constructor loads the evaluation locations, the per-species location
    indices and labels, and the list of taxa from the ``.npy`` file found under
    ``paths['snt']`` (``paths.json`` is read from the working directory).

    NOTE(review): the source this class was recovered from had lost all
    indentation; block nesting below was reconstructed from the control flow
    and should be confirmed against the upstream repository.
    """

    def __init__(self, train_params, eval_params):
        """Store the parameter dicts and load the evaluation data.

        Args:
            train_params: dict read for ``'model'``, ``'class_to_taxa'`` and
                (lazily) ``'zero_shot'``.
            eval_params: dict read for ``'seed'``, ``'device'``,
                ``'extract_pos'``, ``'num_samples'``, ``'split'``,
                ``'split_seed'``, ``'val_frac'`` and ``'text_section'``.
        """
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # point paths['snt'] at a trusted file.
        D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
        D = D.item()
        self.loc_indices_per_species = D['loc_indices_per_species']
        self.labels_per_species = D['labels_per_species']
        self.taxa = D['taxa']
        self.obs_locs = D['obs_locs']
        self.obs_locs_idx = D['obs_locs_idx']
        self.pos_eval_data_loc = os.path.join(paths['data'], 'positive_eval_data.npz')
        self.background_eval_data_loc = os.path.join(paths['data'], '10000_background_negs.npz')

    def get_labels(self, species):
        """Return ``(obs_locs, gt)`` for ``species`` from ``self.data``.

        ``obs_locs`` is a float32 array of ``(lon, lat)`` rows and ``gt`` the
        matching float32 0/1 labels, one entry per key of ``self.data`` that
        contains ``species``.

        NOTE(review): ``self.data`` is not assigned by ``__init__`` of this
        class, so this method raises ``AttributeError`` unless the attribute
        is attached by the caller.
        """
        species = str(species)
        lat = []
        lon = []
        gt = []
        for hx in self.data:
            if species in self.data[hx]:
                # Only resolve coordinates for cells that are actually kept.
                cur_lat, cur_lon = h3.h3_to_geo(hx)
                gt.append(int(len(self.data[hx][species]) > 0))
                lat.append(cur_lat)
                lon.append(cur_lon)
        lat = np.array(lat).astype(np.float32)
        lon = np.array(lon).astype(np.float32)
        obs_locs = np.vstack((lon, lat)).T
        gt = np.array(gt).astype(np.float32)
        return obs_locs, gt

    def _species_ground_truth(self, tt_id, split_rng):
        """Return ``(loc_indices, is_present, labels)`` for taxon ``tt_id``.

        ``loc_indices`` is a NumPy array of indices into the evaluation
        locations, ``is_present`` a NumPy bool mask (label > 0) and ``labels``
        the same mask as a float tensor on the evaluation device. For the
        ``'val'`` / ``'test'`` splits one permutation is drawn from
        ``split_rng`` and the first ``floor(n * val_frac)`` entries form the
        validation part.
        """
        loc_indices = np.array(self.loc_indices_per_species[tt_id])
        labels = np.array(self.labels_per_species[tt_id])
        split = self.eval_params['split']
        assert split in ['all', 'val', 'test']
        if split != 'all':
            num_val = np.floor(len(labels) * self.eval_params['val_frac']).astype(int)
            idx_rand = split_rng.permutation(len(labels))
            idx_sel = idx_rand[:num_val] if split == 'val' else idx_rand[num_val:]
            loc_indices = loc_indices[idx_sel]
            labels = labels[idx_sel]
        is_present = labels > 0
        labels_t = (torch.from_numpy(labels).to(self.eval_params['device']) > 0).float()
        return loc_indices, is_present, labels_t

    def _build_support_sets(self, embed, enc, extra_input):
        """Embed the positive and negative examples used to fit the probes.

        Args:
            embed: callable mapping encoded locations to location features.
            enc: input encoder providing ``encode``.
            extra_input: must be ``None``; anything else is unsupported.

        Returns:
            ``(pos_examples, neg_examples)``: a dict taxon -> CPU feature
            tensor, and one CPU tensor holding the features of 10000 random
            samples followed by up to 10000 background locations.

        Raises:
            NotImplementedError: if ``extra_input`` is given.
            ValueError: if a taxon has no entry in the positive data.
        """
        if extra_input is not None:
            raise NotImplementedError('extra_input provided')
        device = self.eval_params['device']
        locs, labels, _, _, _, _ = datasets.load_eval_inat_data(self.pos_eval_data_loc)
        unique_taxa, class_ids = np.unique(labels, return_inverse=True)
        class_to_taxa = unique_taxa.tolist()
        idx_ss = datasets.get_idx_subsample_observations_eval(
            labels=labels, hard_cap=self.eval_params['num_samples'])
        locs = torch.from_numpy(np.array(locs)[idx_ss])
        labels = torch.from_numpy(np.array(class_ids)[idx_ss])
        with torch.no_grad():
            pos_examples = {}
            for tt in self.taxa:
                c = class_to_taxa.index(tt)
                pos_examples[tt] = embed(enc.encode(locs[labels == c].to(device))).cpu()
            # random negatives
            rand_negs = utils.rand_samples(10000, device, rand_type='spherical')
            neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(self.background_eval_data_loc)
            neg_locs = torch.from_numpy(neg_locs)
            rand_feats = enc.encode(rand_negs, normalize=False)
            # background ("target") negatives
            perm = torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]
            target_feats = enc.encode(neg_locs[perm].clone().to(device), normalize=True)
            neg_examples = embed(torch.cat([rand_feats, target_feats])).cpu()
        return pos_examples, neg_examples

    @staticmethod
    def _fit_probe(w, logits_fn, targets, num_pos, C=0.05, steps=40):
        """Optimise ``w`` in place with Rprop on a class-balanced BCE loss.

        The loss is the mean of the BCE over negatives and the BCE over
        positives plus ``1 / (C * num_pos)`` times the mean squared value of
        ``w``. ``logits_fn(w)`` must return logits shaped like ``targets``.
        """
        opt = torch.optim.Rprop([w], lr=0.001)
        crit = torch.nn.BCEWithLogitsLoss()
        crit2 = torch.nn.MSELoss()
        neg_mask = targets == 0
        pos_mask = targets == 1
        # Callers may be inside torch.no_grad(); gradients are needed here.
        with torch.set_grad_enabled(True):
            for _ in range(steps):
                opt.zero_grad()
                output = logits_fn(w)
                loss = 0.5 * crit(output[neg_mask], targets[neg_mask]) \
                    + 0.5 * crit(output[pos_mask], targets[pos_mask]) \
                    + 1 / (C * num_pos) * crit2(w, 0 * w)
                loss.backward()
                opt.step()

    def run_evaluation(self, model, enc, extra_input=None):
        """Compute the average precision of ``model`` for every taxon.

        Args:
            model: trained model; which attributes are used depends on
                ``train_params['model']``.
            enc: input encoder providing ``encode``.
            extra_input: optional tensor expanded and concatenated to the
                encoded locations; rejected by the zero/few-shot paths.

        Returns:
            dict with ``per_species_average_precision_all`` (float32 array,
            NaN for taxa unknown to the model), ``mean_average_precision``,
            ``num_eval_species_w_valid_ap`` and ``num_eval_species_total``.

        Side effects: seeds ``numpy.random`` and ``random``; rewrites
        ``train_params['model']`` when ``eval_params['extract_pos']`` is set;
        writes ``taxas_ap_range.csv`` to the working directory.
        """
        device = self.eval_params['device']
        num_samples = self.eval_params['num_samples']
        results = {}
        # set seeds:
        np.random.seed(self.eval_params['seed'])
        random.seed(self.eval_params['seed'])
        ap_all = np.zeros((len(self.taxa)), dtype=np.float32)
        results['per_species_average_precision_all'] = ap_all

        # get eval locations and apply input encoding
        obs_locs = torch.from_numpy(self.obs_locs).to(device)
        if extra_input is not None:
            loc_feat = torch.cat([enc.encode(obs_locs), extra_input.expand(obs_locs.shape[0], -1)], dim=1)
        else:
            loc_feat = enc.encode(obs_locs)

        # get classes to eval (taxa unknown to the model keep index 0)
        class_to_taxa_arr = np.array(self.train_params['class_to_taxa'])
        classes_of_interest = torch.zeros(len(self.taxa), dtype=torch.int64)
        for tt_id, tt in enumerate(self.taxa):
            class_of_interest = np.where(class_to_taxa_arr == tt)[0]
            if len(class_of_interest) != 0:
                classes_of_interest[tt_id] = torch.from_numpy(class_of_interest)

        if self.eval_params['extract_pos']:
            assert 'HyperNet' in self.train_params['model']
            model = model.pos_enc
            self.train_params['model'] = 'ResidualFCNet'
        model_name = self.train_params['model']

        def probing():
            # Evaluated lazily: 'zero_shot' is only read on the code paths
            # that also read it in the original control flow.
            return self.train_params['zero_shot'] or num_samples > 0

        if ('CombinedModel' in model_name) or ('MultiInputModel' in model_name):
            with torch.no_grad():
                # generate model predictions for classes of interest at eval locations
                loc_emb = model(x=loc_feat, context_sequence=None, context_mask=None,
                                class_ids=classes_of_interest, return_feats=True)
                classes_of_interest = classes_of_interest.to(device)
                wt = model.get_eval_embeddings(classes_of_interest)
                pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1))
        elif model_name == 'VariableInputModel':
            with torch.no_grad():
                loc_emb = model.get_loc_emb(x=loc_feat)
                classes_of_interest = classes_of_interest.to(device)
                wt = model.get_eval_embeddings(classes_of_interest)
                pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1))
        elif 'HyperNet' not in model_name and not probing():
            with torch.no_grad():
                # generate model predictions for classes of interest at eval locations
                loc_emb = model(loc_feat, return_feats=True)
                wt = model.class_emb.weight[classes_of_interest, :]
                pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1))
        elif probing():
            if model_name == 'ResidualFCNet':
                def embed(x):
                    return model(x, return_feats=True)
                pos_examples, neg_examples = self._build_support_sets(embed, enc, extra_input)
                with torch.no_grad():
                    loc_emb = embed(loc_feat)
            elif model_name == 'HyperNet':
                # NOTE(review): support sets are only consumed when
                # num_samples > 0, so they are only built in that case.
                if num_samples > 0:
                    pos_examples, neg_examples = self._build_support_sets(model.pos_enc, enc, extra_input)
                with torch.no_grad():
                    loc_emb = model.pos_enc(loc_feat)
                text_data = torch.load('experiments/gpt_data.pt', map_location='cpu')
                text_ids = text_data['taxon_id'].tolist()
                text_keys = text_data['keys']
                text_embs = text_data['data']
                # Only the number of keys of the second file is used below
                # (to pick the section keywords), so it is loaded on the CPU.
                wiki_keys = torch.load('experiments/wiki_data_v4.pt', map_location='cpu')['keys']
            else:
                raise NotImplementedError('Eval for zero-shot not implemented')

        if num_samples == -1 and model_name not in ['CombinedModel', 'MultiInputModel', 'VariableInputModel', 'ResidualFCNet']:
            loc_emb = model.pos_enc(loc_feat)
        elif num_samples == -1 and model_name not in ['CombinedModel', 'MultiInputModel', 'VariableInputModel']:
            loc_emb = model.forward(loc_feat, return_feats=True)

        split_rng = np.random.default_rng(self.eval_params['split_seed'])
        # (taxon id, index into self.taxa, label tensor, presence locations);
        # filled for 'MultiInputModel' only and exported to CSV at the end.
        range_rows = []
        # tt is the taxon id, tt_id its index in self.taxa
        for tt_id, tt in tqdm(enumerate(self.taxa)):
            class_of_interest = np.where(class_to_taxa_arr == tt)[0]
            if len(class_of_interest) == 0 and not probing():
                # taxa of interest is not in the model
                ap_all[tt_id] = np.nan
                continue

            if model_name in ('VariableInputModel', 'MultiInputModel'):
                cur_loc_indices, is_present, cur_labels = self._species_ground_truth(tt_id, split_rng)
                if model_name == 'MultiInputModel':
                    # The presence mask is a NumPy array so this also works
                    # when the labels live on a CUDA device.
                    range_rows.append((tt, tt_id, cur_labels, obs_locs[cur_loc_indices[is_present]]))
                with torch.no_grad():
                    preds = torch.sigmoid(pred_mtx[:, tt_id])
                ap_all[tt_id] = utils.average_precision_score_fasterer(
                    cur_labels, preds[cur_loc_indices]).item()
                continue

            # generate ground truth labels for current taxa
            cur_loc_indices, _, cur_labels = self._species_ground_truth(tt_id, split_rng)

            if num_samples == -1 and model_name == 'HyperNet':
                species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)]
                preds = loc_emb @ species_w.detach()
                ap_all[tt_id] = utils.average_precision_score_fasterer(
                    cur_labels, preds[cur_loc_indices]).item()
                continue
            if num_samples == -1 and model_name == 'ResidualFCNet':
                preds = model.eval_single_class(
                    x=loc_emb, class_of_interest=self.train_params['class_to_taxa'].index(tt)).detach()
                ap_all[tt_id] = utils.average_precision_score_fasterer(
                    cur_labels, preds[cur_loc_indices]).item()
                continue

            if 'HyperNet' not in model_name and not probing():
                # extract model predictions for current taxa from prediction matrix
                pred = pred_mtx[cur_loc_indices, tt_id]
            elif probing():
                if model_name == 'ResidualFCNet':
                    X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(device)
                    w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=device))
                    if num_samples == 0:
                        # no support examples: score with a random direction
                        nn.init.xavier_uniform_(w)
                    else:
                        y = torch.zeros(X.shape[0], dtype=torch.long, device=device)
                        y[:pos_examples[tt].shape[0]] = 1
                        self._fit_probe(w, lambda p: X @ p, y.float()[:, None], len(pos_examples[tt]))
                    pred = torch.sigmoid(loc_emb @ w)[cur_loc_indices].flatten()
                elif model_name == 'HyperNet':
                    if tt not in text_ids:
                        print('yes')
                        ap_all[tt_id] = 0.0
                        continue
                    with torch.no_grad():
                        sec_ind = text_ids.index(tt)
                        sections = [i for i, x in enumerate(text_keys) if x[0] == sec_ind]
                        yfeats = torch.cat([text_embs[section][None].to(device) for section in sections])
                        species = model.species_enc(model.species_emb.zero_shot(yfeats))
                        species_w, species_b = species[..., :-1], species[..., -1:]
                        if num_samples == 0:
                            preds = loc_emb @ (species_w.detach()).T
                        else:
                            X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(device)
                            y = torch.zeros(X.shape[0], dtype=torch.long, device=device)
                            y[:pos_examples[tt].shape[0]] = 1
                            # w is a learned offset added to the text-derived weights
                            w = torch.nn.Parameter(torch.zeros_like(species_w, device=device))
                            self._fit_probe(
                                w,
                                lambda p: (X @ (p + species_w.detach()).T) + 0 * species_b.squeeze(-1),
                                y.float()[:, None].repeat(1, w.shape[0]),
                                len(pos_examples[tt]))
                            preds = loc_emb @ (w.data + species_w.detach()).T
                            preds = preds + 0 * species_b.squeeze(-1)
                        if len(sections) > 1:
                            kws = ['text', 'range', 'distribution', 'habitat'] if len(text_keys) == len(wiki_keys) else [self.eval_params['text_section']]
                            best_sections = [i for i, s in enumerate(sections)
                                             if any(x in text_keys[s][1].lower() for x in kws)]
                            scores = preds[cur_loc_indices][:, best_sections].mean(dim=1)
                        else:
                            # Single section: the column is already 1-D, so
                            # there is no second dimension to average over.
                            scores = preds[cur_loc_indices][:, 0]
                        ap_all[tt_id] = utils.average_precision_score_fasterer(cur_labels, scores).item()
                    continue
            else:
                raise NotImplementedError('Eval for hypernet not implemented')
            # compute the AP for each taxa
            ap_all[tt_id] = utils.average_precision_score_fasterer(cur_labels, pred).item()

        valid_taxa = ~np.isnan(ap_all)
        # store results
        results['mean_average_precision'] = ap_all[valid_taxa].mean()
        results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
        results['num_eval_species_total'] = len(self.taxa)

        print([len(labels) for _, _, labels, _ in range_rows])
        with open("taxas_ap_range.csv", mode='w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['Taxa ID', 'Average Precision', 'Range Size', 'Range'])
            # Each row looks its AP up by taxon index, so rows stay aligned
            # even when some taxa were skipped with NaN.
            for taxa, taxa_idx, labels, presence_locs in range_rows:
                writer.writerow([taxa, ap_all[taxa_idx], int(labels.sum()), format_tensor(presence_locs)])
        return results

    def report(self, results):
        """Print the three summary fields of a ``run_evaluation`` result."""
        for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
            print(f'{field}: {results[field]}')
class EvaluatorIUCN: | |
def __init__(self, train_params, eval_params): | |
self.train_params = train_params | |
print(train_params['text_num_layers'],train_params['text_batchnorm'],train_params['text_hidden_dim'])#TODO | |
self.eval_params = eval_params | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
self.data = json.load(f) | |
self.obs_locs = np.array(self.data['locs'], dtype=np.float32) | |
self.taxa = [int(tt) for tt in self.data['taxa_presence'].keys()] | |
self.pos_eval_data_loc = os.path.join(paths['data'], 'positive_eval_data.npz') | |
self.background_eval_data_loc = os.path.join(paths['data'], '10000_background_negs.npz') | |
def run_evaluation(self, model, enc, extra_input=None): | |
results = {} | |
#self.train_params['model'] = 'ResidualFCNet' | |
#m = model | |
#model = lambda x, return_feats=True: m.pos_enc(x) | |
results['per_species_average_precision_all'] = np.zeros(len(self.taxa), dtype=np.float32) | |
# get eval locations and apply input encoding | |
obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device']) | |
loc_feat = torch.cat([enc.encode(obs_locs), extra_input.expand(obs_locs.shape[0], -1)], dim=1) if extra_input is not None else enc.encode(obs_locs) | |
# get classes to eval | |
# classes_of_interest = torch.zeros(len(self.taxa), dtype=torch.int64) | |
classes_of_interest = np.zeros(len(self.taxa)) | |
array_class_to_taxa = np.array(self.train_params['class_to_taxa']) | |
for tt_id, tt in enumerate(self.taxa): | |
class_of_interest = np.where(array_class_to_taxa == tt)[0] | |
if len(class_of_interest) != 0: | |
classes_of_interest[tt_id] = class_of_interest | |
classes_of_interest = torch.from_numpy(classes_of_interest).to(dtype=torch.long, device=self.eval_params['device']) | |
# MINE MINE MINE | |
# classes_of_interest = classes_of_interest.to(self.eval_params['device']) | |
if self.eval_params['extract_pos']: | |
assert 'HyperNet' in self.train_params['model'] | |
model = model.pos_enc | |
self.train_params['model'] = 'ResidualFCNet' | |
# Should only effect mine | |
if ('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model']): | |
with torch.no_grad(): | |
dummy_context_mask = None | |
dummy_context_sequence = None | |
# generate model predictions for classes of interest at eval locations | |
loc_emb = model(x=loc_feat, context_sequence=dummy_context_sequence, context_mask=dummy_context_mask, | |
class_ids=classes_of_interest, return_feats=True) | |
wt = model.get_eval_embeddings(classes_of_interest) | |
print("Creating IUCN prediction matrix") | |
pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) | |
elif self.train_params['model'] == 'VariableInputModel': | |
with torch.no_grad(): | |
loc_emb = model.get_loc_emb(x=loc_feat) | |
classes_of_interest = classes_of_interest.to(self.eval_params["device"]) | |
wt = model.get_eval_embeddings(classes_of_interest) | |
wt2 = model.get_ema_embeddings(classes_of_interest) | |
# technically with my mock transformer I could just directly access the class embeddings but | |
# I will need to use the emas when I move to the true transformer model (I think) | |
# wt = model.class_emb.weight[classes_of_interest, :] | |
pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) | |
elif 'HyperNet' not in self.train_params['model'] and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): | |
# generate model predictions for classes of interest at eval locations | |
loc_emb = model(loc_feat, return_feats=True) | |
wt = model.class_emb.weight[classes_of_interest, :] | |
pred_mtx = torch.matmul(loc_emb, torch.transpose(wt, 0, 1)) | |
elif (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): | |
if self.train_params['model'] == 'ResidualFCNet': | |
import datasets | |
# from sklearn.linear_model import LogisticRegression | |
# with open('paths.json', 'r') as f: | |
# paths = json.load(f) | |
# data_dir = paths['train'] | |
# obs_file = os.path.join(data_dir, self.train_params['obs_file']) | |
# taxa_file = os.path.join(data_dir, self.train_params['taxa_file']) | |
# taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
# | |
# taxa_of_interest = datasets.get_taxa_of_interest(self.train_params['species_set'], self.train_params['num_aux_species'], | |
# self.train_params['aux_species_seed'], self.train_params['taxa_file'], taxa_file_snt) | |
obs_file = self.pos_eval_data_loc | |
locs, labels, _, dates, _, _ = datasets.load_eval_inat_data(obs_file) | |
unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
class_to_taxa = unique_taxa.tolist() | |
# idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) | |
idx_ss = datasets.get_idx_subsample_observations_eval(labels=labels, hard_cap=self.eval_params['num_samples']) | |
locs = torch.from_numpy(np.array(locs)) | |
labels = torch.from_numpy(np.array(class_ids)) | |
locs = locs[idx_ss] | |
labels = labels[idx_ss] | |
# MINE MINE MINE MINE MINE MINE MINE MINE MINE | |
# with torch.no_grad(): | |
# pos_examples = {} | |
# for tt in self.taxa: | |
# c = class_to_taxa.index(tt) | |
# pos_examples[tt] = locs[labels == c] | |
# pos_examples[tt] = model(enc.encode(pos_examples[tt].to(self.eval_params['device'])), return_feats=True).cpu() | |
# | |
# if self.eval_params['target_background']: | |
# target_background_dataset = datasets.get_train_data(params=self.train_params) | |
# # print("CHECK IF THIS TARGET NEGS THING IS WORKING PROPERLY WHEN SERVER WORKS") | |
# # print("IT MAY INCLUDE EVAL SPECIES / ONLY EVAL SPECIES") # it only includes the backbone species currently - good | |
# | |
# random_negs = utils.rand_samples(5000, self.eval_params['device'], rand_type='spherical') | |
# | |
# # Get the total number of locations | |
# total_locs = len(target_background_dataset.locs) | |
# | |
# # If there are more than 5000 locations, sample 5000 | |
# if total_locs > 5000: | |
# indices = np.random.choice(total_locs, 5000, replace=False) | |
# target_negs = target_background_dataset.locs[indices].to(self.eval_params['device']) | |
# else: | |
# target_negs = target_background_dataset.locs.to(self.eval_params['device']) | |
# # print('CHECK THE FORMAT OF THESE TARGET LOCS COMPARED TO NEG LOCS') # look good | |
# | |
# neg_examples = torch.vstack((random_negs, target_negs)) | |
# | |
# del target_background_dataset | |
# | |
# else: | |
# neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') | |
# if extra_input is not None: | |
# raise NotImplementedError('extra_input provided') | |
# neg_examples = model(enc.encode(neg_examples, normalize=False), return_feats=True).cpu() | |
# print("You can probably speed eval back up once the server is available by changing this shit back") | |
# | |
# # Function to process data in batches | |
# def process_in_batches(model, loc_feat, batch_size=64): | |
# loc_emb = [] | |
# for i in range(0, len(loc_feat), batch_size): | |
# batch = loc_feat[i:i + batch_size] | |
# with torch.no_grad(): | |
# batch_emb = model(batch, return_feats=True) | |
# loc_emb.append(batch_emb) | |
# return torch.cat(loc_emb, dim=0) # Concatenate the results | |
# | |
# # loc_emb = model(loc_feat, return_feats=True) | |
# loc_emb = process_in_batches(model, loc_feat, batch_size=2048) | |
pos_examples = {} | |
for tt in self.taxa: | |
c = class_to_taxa.index(tt) | |
pos_examples[tt] = locs[labels == c] | |
pos_examples[tt] = model(enc.encode(pos_examples[tt].to(self.eval_params['device'])), return_feats=True).cpu() | |
obs_file = self.background_eval_data_loc | |
neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(obs_file) | |
neg_locs = torch.from_numpy(neg_locs) | |
#random negs | |
neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') | |
if extra_input is not None: | |
raise NotImplementedError('extra_input provided') | |
# add target negs | |
neg_examples = model(torch.cat([enc.encode(neg_examples, normalize=False), enc.encode(neg_locs[torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]].clone().to(self.eval_params['device']), normalize=True)]), return_feats=True).cpu() | |
loc_emb = model(loc_feat, return_feats=True) | |
elif self.train_params['model'] == 'HyperNet': | |
import datasets | |
# from sklearn.linear_model import LogisticRegression | |
# with open('paths.json', 'r') as f: | |
# paths = json.load(f) | |
# data_dir = paths['train'] | |
# obs_file = os.path.join(data_dir, self.train_params['obs_file']) | |
# taxa_file = os.path.join(data_dir, self.train_params['taxa_file']) | |
# taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
# | |
# taxa_of_interest = datasets.get_taxa_of_interest(self.train_params['species_set'], self.train_params['num_aux_species'], | |
# self.train_params['aux_species_seed'], self.train_params['taxa_file'], taxa_file_snt) | |
obs_file = self.pos_eval_data_loc | |
locs, labels, _, dates, _, _ = datasets.load_eval_inat_data(obs_file) | |
# MINE MINE MINE MINE | |
# unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
# class_to_taxa = unique_taxa.tolist() | |
# idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) | |
# locs = torch.from_numpy(np.array(locs)) | |
# labels = torch.from_numpy(np.array(class_ids)) | |
# locs = locs[idx_ss] | |
# labels = labels[idx_ss] | |
# with torch.no_grad(): | |
# MAX MAX MAX MAX MAX MAX MAX | |
unique_taxa, class_ids, class_counts = np.unique(labels, return_inverse=True, return_counts=True) | |
class_counts = class_counts.clip(max=1000) | |
if self.eval_params['num_samples'] > 0: | |
class_to_taxa = unique_taxa.tolist() | |
idx_ss = datasets.get_idx_subsample_observations_eval(labels=labels, hard_cap=self.eval_params['num_samples']) | |
# idx_ss = datasets.get_idx_subsample_observations(labels, self.eval_params['num_samples'], random.randint(0,2**32), None, -1) | |
locs = torch.from_numpy(np.array(locs)) | |
labels = torch.from_numpy(np.array(class_ids)) | |
locs = locs[idx_ss] | |
labels = labels[idx_ss] | |
pos_examples = {} | |
for tt in self.taxa: | |
c = class_to_taxa.index(tt) | |
pos_examples[tt] = locs[labels == c] | |
pos_examples[tt] = model.pos_enc(enc.encode(pos_examples[tt].to(self.eval_params['device']))).cpu() | |
# MINE MINE MINE MINE MINE MINE MINE MINE | |
# if self.eval_params['target_background']: | |
# | |
# target_background_dataset = datasets.get_train_data(params=self.train_params) | |
# # print("CHECK IF THIS TARGET NEGS THING IS WORKING PROPERLY WHEN SERVER WORKS") | |
# # print("IT MAY INCLUDE EVAL SPECIES / ONLY EVAL SPECIES") | |
# | |
# random_negs = utils.rand_samples(5000, self.eval_params['device'], rand_type='spherical') | |
# | |
# # Get the total number of locations | |
# total_locs = len(target_background_dataset.locs) | |
# | |
# # If there are more than 5000 locations, sample 5000 | |
# if total_locs > 5000: | |
# indices = np.random.choice(total_locs, 5000, replace=False) | |
# target_negs = target_background_dataset.locs[indices].to(self.eval_params['device']) | |
# else: | |
# target_negs = target_background_dataset.locs.to(self.eval_params['device']) | |
# # print('CHECK THE FORMAT OF THESE TARGET LOCS COMPARED TO NEG LOCS') | |
# | |
# neg_examples = torch.vstack((random_negs, target_negs)) | |
# | |
# del target_background_dataset | |
# | |
# else: | |
# neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') | |
# MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX | |
obs_file = self.background_eval_data_loc | |
neg_locs, _, _, _, _, _ = datasets.load_eval_inat_data(obs_file) | |
neg_locs = torch.from_numpy(neg_locs) | |
# random negs | |
neg_examples = utils.rand_samples(10000, self.eval_params['device'], rand_type='spherical') | |
if extra_input is not None: | |
raise NotImplementedError('extra_input provided') | |
# MINE MINE MINE | |
# neg_examples = model.pos_enc(enc.encode(neg_examples, normalize=False)).cpu() | |
# MAX MAX MAX MAX MAX MAX MAX | |
# add target negs | |
neg_examples = model.pos_enc(torch.cat([enc.encode(neg_examples, normalize=False), enc.encode(neg_locs[torch.randperm(neg_locs.shape[0], device=locs.device)[:10000]].clone().to(self.eval_params['device']), normalize=True)])).cpu() | |
#embs = torch.load(self.train_params['text_emb_path']) #TODO | |
embs = torch.load('gpt_data.pt', weights_only=False) | |
#embs = torch.load('ldsdm_data.pt') | |
emb_ids = embs['taxon_id'].tolist() | |
keys = embs['keys'] | |
embs = embs['data'] | |
# embs2 doesn't even do anything. Could just remove the whole thing, but that is how it is in Max's code | |
# MINE MINE MINE | |
embs2 = torch.load('wiki_data_v4.pt', weights_only=False) | |
# MAX MAX MAX | |
# embs2 = torch.load('wiki_data_v3.pt') | |
emb_ids2 = embs2['taxon_id'].tolist() | |
keys2 = embs2['keys'] | |
embs2 = embs2['data'] | |
loc_emb = model.pos_enc(loc_feat) | |
else: | |
raise NotImplementedError('Eval for zero-shot not implemented') | |
# MINE - my version - why am I stopping residualFCnets doing this? | |
# if self.eval_params['num_samples'] == -1 and not (('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model']) or ('ResidualFCNet' in self.train_params['model'])): | |
# MAX - a variant of Maxs - only difference should now be my model types | |
#if self.eval_params['num_samples'] == -1 and not (('CombinedModel' in self.train_params['model']) or ('MultiInputModel' in self.train_params['model'])): | |
if self.eval_params['num_samples'] == -1 and not (self.train_params['model'] in ['CombinedModel', 'MultiInputModel', 'VariableInputModel', 'ResidualFCNet']): | |
loc_emb = model.pos_enc(loc_feat) | |
if self.eval_params['num_samples'] == -1 and not (self.train_params['model'] in ['CombinedModel', 'MultiInputModel', 'VariableInputModel']): | |
loc_emb = model.forward(loc_feat, return_feats=True) | |
for tt_id, tt in tqdm(enumerate(self.taxa)): | |
class_of_interest = np.where(array_class_to_taxa == tt)[0] | |
if len(class_of_interest) == 0 and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): | |
# taxa of interest is not in the model | |
results['per_species_average_precision_all'][tt_id] = np.nan | |
else: | |
# Only effects my models | |
if self.train_params['model'] == 'MultiInputModel': | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
with torch.no_grad(): | |
logits = pred_mtx[:, tt_id] | |
preds = torch.sigmoid(logits) | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() | |
continue | |
elif self.train_params['model'] == 'VariableInputModel': | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
with torch.no_grad(): | |
logits = pred_mtx[:, tt_id] | |
preds = torch.sigmoid(logits) | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() | |
continue | |
# MINE MINE MINE | |
# elif (self.train_params['model'] == 'ResidualFCNet') and (self.eval_params['num_samples'] <= 0): | |
# gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
# gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
# with torch.no_grad(): | |
# logits = pred_mtx[:, tt_id] | |
# preds = torch.sigmoid(logits) | |
# results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() | |
# continue | |
if self.eval_params['num_samples'] == -1 and self.train_params['model'] == 'HyperNet': | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] | |
preds = loc_emb @ species_w.detach() | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() | |
continue | |
elif self.eval_params['num_samples'] == -1 and self.train_params['model'] == 'ResidualFCNet': | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
preds = model.eval_single_class(x=loc_emb, class_of_interest=self.train_params['class_to_taxa'].index(tt)).detach() | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() | |
continue | |
# MINE MINE MINE MINE MINE MINE MINE - seems un needed? | |
# elif (self.eval_params['num_samples'] == -1) and ('Hypernet' in self.train_params['model']): | |
# gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
# gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
# species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] | |
# preds = loc_emb @ species_w.detach() | |
# results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds).item() | |
# continue | |
# extract model predictions for current taxa from prediction matrix | |
if 'HyperNet' not in self.train_params['model'] and not (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): | |
pred = pred_mtx[:, tt_id] | |
elif (self.train_params['zero_shot'] or self.eval_params['num_samples'] > 0): | |
if self.train_params['model'] == 'ResidualFCNet': | |
if self.eval_params['num_samples'] == 0: | |
X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) | |
w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=self.eval_params['device'])) | |
nn.init.xavier_uniform_(w) | |
pred = torch.sigmoid(((loc_emb @ w)))[cur_loc_indices].flatten() | |
else: | |
X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) | |
y = torch.zeros(X.shape[0], dtype=torch.long, device=self.eval_params['device']) | |
y[:pos_examples[tt].shape[0]] = 1 | |
# MINE MINE MINE | |
# clf = LogisticRegression(class_weight='balanced', fit_intercept=False, C=0.05, max_iter=200, random_state=0).fit(X.numpy(), y.numpy()) | |
# #pred = torch.from_numpy(clf.predict_proba(loc_emb.cpu()))[:,1] | |
# pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).to(self.eval_params['device']).float().T)) + torch.from_numpy(clf.intercept_).to(self.eval_params['device']).float()).squeeze(-1)) | |
# # pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1)) | |
# MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX MAX | |
#clf = LogisticRegression(class_weight='balanced', fit_intercept=False, C=0.05, max_iter=200, random_state=0).fit(X.numpy(), y.numpy()) | |
C = 0.05 | |
w = torch.nn.Parameter(torch.zeros(X.shape[1], 1, device=self.eval_params['device'])) | |
opt = torch.optim.Rprop([w], lr=0.001) | |
crit = torch.nn.BCEWithLogitsLoss() | |
crit2 = torch.nn.MSELoss() | |
with torch.set_grad_enabled(True): | |
for i in range(40): | |
opt.zero_grad() | |
output = X @ w | |
yhat = y.float()[:, None] | |
loss = 0.5 * crit(output[yhat == 0], yhat[yhat == 0]) + 0.5 * crit(output[yhat == 1], | |
yhat[ | |
yhat == 1]) + 1 / ( | |
C * len(pos_examples[tt])) * crit2(w, 0 * w) | |
loss.backward() | |
opt.step() | |
pred = torch.sigmoid(((loc_emb @ w))).flatten() | |
#pred = torch.from_numpy(clf.predict_proba(loc_emb.cpu()))[:,1] | |
#pred = torch.sigmoid(((loc_emb @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1)) | |
#locs = torch.from_numpy(utils.coord_grid((1000,2000))).to(self.eval_params['device']) | |
#locs = model(enc.encode(locs), return_feats=True) | |
#img = torch.sigmoid(((locs @ (torch.from_numpy(clf.coef_).cuda().float().T)) + torch.from_numpy(clf.intercept_).cuda().float()).squeeze(-1)) | |
#plt.imshow(img.detach().cpu()) | |
elif self.train_params['model'] == 'HyperNet': | |
if tt not in emb_ids and tt not in emb_ids2: | |
results['per_species_average_precision_all'][tt_id] = 0.0 | |
continue | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
if self.eval_params['num_samples'] == -1: | |
species_w = model.species_params[self.train_params['class_to_taxa'].index(tt)] | |
preds = loc_emb @ species_w.detach() | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt,preds).item() | |
continue | |
with torch.no_grad(): | |
if tt in emb_ids: | |
em = embs | |
emi = emb_ids | |
ky = keys | |
else: | |
results['per_species_average_precision_all'][tt_id] = 0.0 | |
continue | |
em = embs2 | |
emi = emb_ids2 | |
ky = keys2 | |
sec_ind = emi.index(tt) | |
sections = [i for i,x in enumerate(ky) if x[0] == sec_ind] | |
order = ['distribution', 'range', 'text'] | |
best_section = None | |
order_ind = 0 | |
while best_section is None and order_ind < len(order): | |
for section in sections: | |
if order[order_ind] in ky[section][1].lower(): | |
best_section = section | |
break | |
order_ind += 1 | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
def get_feat(x): | |
species = model.species_enc(model.species_emb.zero_shot(x)) | |
species_w, species_b = species[..., :-1], species[..., -1:] | |
if self.eval_params['num_samples'] == 0: | |
out = loc_emb @ (species_w.detach()).T | |
return out | |
X = torch.cat([pos_examples[tt], neg_examples], dim=0).to(self.eval_params['device']) | |
y = torch.zeros(X.shape[0], dtype=torch.long, device=self.eval_params['device']) | |
y[:pos_examples[tt].shape[0]] = 1 | |
C = 0.05 | |
w = torch.nn.Parameter(torch.zeros_like(species_w, device=self.eval_params['device'])) | |
opt = torch.optim.Rprop([w], lr=0.001) | |
crit = torch.nn.BCEWithLogitsLoss() | |
crit2 = torch.nn.MSELoss() | |
with torch.set_grad_enabled(True): | |
for i in range(40): | |
opt.zero_grad() | |
output = (X @ (w + species_w.detach()).T) + 0*species_b.squeeze(-1) | |
yhat = y.float()[:, None].repeat(1, w.shape[0]) | |
loss = 0.5 * crit(output[yhat == 0], yhat[yhat == 0]) + 0.5 * crit( | |
output[yhat == 1], yhat[yhat == 1]) + 1 / ( | |
C * len(pos_examples[tt])) * crit2(w, 0 * w) | |
loss.backward() | |
opt.step() | |
'''out = loc_emb @ (w.data + species_w.detach()).T | |
gt = torch.zeros(out.shape[0], dtype=torch.float32, | |
device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
print(utils.average_precision_score_fasterer(gt, out[:, 0]).item())''' | |
out = loc_emb @ (w.data + species_w.detach()).T | |
out = (out + 0*species_b.squeeze(-1)) | |
return out | |
# average precision score: | |
yfeats = torch.cat([em[section][None].to(self.eval_params['device']) for section in sections]) | |
preds = get_feat(yfeats) | |
if len(sections) > 1:#'habitat', 'overview_summary' | |
kws = [self.eval_params['text_section']] if len(ky) == len(keys) else ['text', 'range','distribution','habitat'] | |
best_sections = [i for i,s in enumerate(sections) if any((x in ky[s][1].lower() for x in kws))] | |
#yfeats2 = torch.cat( | |
# [em[section][None].to(self.eval_params['device']) for section in best_sections]).mean(dim=0, keepdim=True) | |
#pred2 = get_feat(yfeats2) | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds[:, best_sections].mean(dim=1)).item() | |
else: | |
# MINE MINE MINE MINE | |
# sigmoid_preds = torch.sigmoid(preds[:, 0]) | |
# results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, sigmoid_preds).item() | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, preds[:, 0]).item() | |
continue | |
else: | |
if tt_id % 32 == 0: | |
# MINE MINE MINE MINE | |
# with torch.no_grad(): | |
# preds = torch.empty(loc_feat.shape[0], classes_of_interest[tt_id:tt_id+32].shape[0], device=self.eval_params['device']) | |
# for i in range(0,preds.shape[0],50000): | |
# xbatch = loc_feat[i:i+50000] | |
# ybatch = classes_of_interest[tt_id:tt_id+32].to(self.eval_params['device']).expand(xbatch.shape[0], -1) | |
# preds[i:i+50000] = model(xbatch, ybatch) | |
preds = torch.empty(loc_feat.shape[0], classes_of_interest[tt_id:tt_id+32].shape[0], device=self.eval_params['device']) | |
for i in range(0,preds.shape[0],50000): | |
xbatch = loc_feat[i:i+50000] | |
ybatch = classes_of_interest[tt_id:tt_id+32].to(self.eval_params['device']).expand(xbatch.shape[0], -1) | |
preds[i:i+50000] = model(xbatch, ybatch) | |
pred = preds[:,tt_id%32] | |
gt = torch.zeros(obs_locs.shape[0], dtype=torch.float32, device=self.eval_params['device']) | |
gt[self.data['taxa_presence'][str(tt)]] = 1.0 | |
# average precision score: | |
results['per_species_average_precision_all'][tt_id] = utils.average_precision_score_fasterer(gt, pred).item() | |
valid_taxa = ~np.isnan(results['per_species_average_precision_all']) | |
# store results | |
per_species_average_precision_valid = results['per_species_average_precision_all'][valid_taxa] | |
results['mean_average_precision'] = per_species_average_precision_valid.mean() | |
results['num_eval_species_w_valid_ap'] = valid_taxa.sum() | |
results['num_eval_species_total'] = len(self.taxa) | |
return results | |
def report(self, results):
    """Print the headline metrics stored in ``results``, one per line.

    Args:
        results: mapping that contains the keys
            'mean_average_precision', 'num_eval_species_w_valid_ap' and
            'num_eval_species_total'.
    """
    summary_fields = (
        'mean_average_precision',
        'num_eval_species_w_valid_ap',
        'num_eval_species_total',
    )
    for name in summary_fields:
        value = results[name]
        print(f'{name}: {value}')
# MINE MINE MINE MINE but shouldn't affect things too much
def batched_matmul(self, loc_emb, wt):
    """Compute ``loc_emb @ wt.T`` in batches and return it as a NumPy array.

    Each batch of ``loc_emb`` rows is moved to ``self.eval_params['device']``,
    multiplied by ``wt`` transposed, and written straight into its slot of
    the preallocated float32 output.

    Args:
        loc_emb: 2-D tensor whose rows are multiplied against ``wt``.
        wt: 2-D tensor; one output column is produced per row of ``wt``.
            It is not moved by this method, so it is presumably already on
            the evaluation device -- TODO confirm against the caller.

    Returns:
        ``np.ndarray`` of shape ``(loc_emb.size(0), wt.size(0))`` and dtype
        float32.
    """
    batch_size = self.eval_params["batch_size"]
    device = self.eval_params['device']
    num_samples = loc_emb.size(0)
    # Preallocate the full result; every batch is copied into it exactly once.
    pred_mtx = np.empty((num_samples, wt.size(0)), dtype=np.float32)
    # no_grad: inputs that track gradients (e.g. model parameters) would
    # otherwise make ``.numpy()`` raise, and no autograd graph is needed here.
    with torch.no_grad():
        wt_T = wt.t()
        for start_idx in tqdm(range(0, num_samples, batch_size)):
            end_idx = min(start_idx + batch_size, num_samples)
            loc_emb_batch = loc_emb[start_idx:end_idx].to(device)
            pred_mtx[start_idx:end_idx] = torch.matmul(loc_emb_batch, wt_T).cpu().numpy()
    return pred_mtx
class EvaluatorGeoPrior:
    """Measure how a geo model changes a vision model's top-1 accuracy.

    Stored vision predictions (top-k probabilities and class indices) are
    multiplied by the geo model's per-location output, and the arg-max of the
    product is compared with the ground-truth label.
    """

    def __init__(self, train_params, eval_params):
        """Load vision predictions, observation metadata and the taxon mapping.

        Args:
            train_params: training parameters; 'class_to_taxa' and
                'input_time' are read here.
            eval_params: evaluation parameters; 'device' is read here.
        """
        # store parameters:
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        # load vision model predictions:
        self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
        print(self.data['probs'].shape[0], 'total test observations')
        # load locations (stored as longitude, latitude):
        meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
        self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)
        # Parse 'observed_on' as fixed-width 10-byte strings: bytes 5:7 hold the
        # month and bytes 8:10 the day. The year is not needed.
        temp = np.array(meta['observed_on'].values, dtype='S10')
        temp = temp.view('S1').reshape((temp.size, -1))
        months = temp[:, 5:7].view('S2').astype(int)[:, 0]
        days = temp[:, 8:10].view('S2').astype(int)[:, 0]
        # Day offset of the first day of each month, using 2018's month lengths.
        days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
        dates = days_per_month[months - 1] + days - 1
        # Day of year expressed as a fraction of 365.
        self.dates = np.round((dates) / 365.0, 4).astype(np.float32)
        # taxonomic mapping:
        model_to_taxa = self.data['model_to_taxa']
        self.taxon_map = self.find_mapping_between_models(model_to_taxa, self.train_params['class_to_taxa'])
        self.time_enc = utils.TimeEncoder() if train_params['input_time'] else None
        print(self.taxon_map.shape[0], 'out of', len(model_to_taxa), 'taxa in both vision and geo models')
        # Inverse relative class frequency with a leading broadcast dimension.
        cs = torch.load('class_counts.pt')
        cs = cs.sum() / cs
        cs = cs.to(self.eval_params['device'])
        self.C = cs[None]
        self.pdf = utils.DataPDFH3(device=self.eval_params['device'])

    def find_mapping_between_models(self, vision_taxa, geo_taxa):
        """Map vision-model class indices to geo-model class indices.

        Args:
            vision_taxa: array of taxon ids in vision-model order.
            geo_taxa: sequence of taxon ids in geo-model order.

        Returns:
            int32 array of shape (N_overlap, 2): column 0 is the vision-model
            index and column 1 the matching geo-model index. Taxa absent from
            ``geo_taxa`` are dropped; the first match wins on duplicates.
        """
        taxon_map = np.ones((vision_taxa.shape[0], 2), dtype=np.int32)*-1
        taxon_map[:, 0] = np.arange(vision_taxa.shape[0])
        geo_taxa_arr = np.array(geo_taxa)
        for tt_id, tt in enumerate(vision_taxa):
            ind = np.where(geo_taxa_arr==tt)[0]
            if len(ind) > 0:
                taxon_map[tt_id, 1] = ind[0]
        # keep only the rows for which a geo-model match was found
        inds = np.where(taxon_map[:, 1]>-1)[0]
        taxon_map = taxon_map[inds, :]
        return taxon_map

    def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map, k=1.0):
        """Expand geo and vision predictions to dense arrays in vision-class order.

        Args:
            geo_pred_ip: geo-model output, one column per geo-model class.
            vision_top_k_prob: top-k vision probabilities per observation.
            vision_top_k_inds: vision class indices matching ``vision_top_k_prob``.
            vision_taxa: vision-model taxa; only its length is used.
            taxon_map: output of ``find_mapping_between_models``.
            k: value used for vision classes the geo model does not cover.

        Returns:
            ``(geo_pred, vision_pred)``, both float32 arrays of shape
            (num_observations, len(vision_taxa)).
        """
        # this is slow as we turn the sparse input back into the same size as the dense one
        vision_pred = np.zeros((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
        geo_pred = k*np.ones((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
        vision_pred[np.arange(vision_pred.shape[0])[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob
        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]
        return geo_pred, vision_pred

    def run_evaluation(self, model, enc, extra_input=None):
        """Score the combined vision x geo prediction on every stored observation.

        Besides returning the metrics, this writes ``correct_{noise_level}.pt``
        (per-observation correctness) and ``abt_{noise_level}.pt`` (accuracy per
        vision class) to the current working directory.

        Args:
            model: callable mapping encoded location features to geo predictions.
            enc: object whose ``encode`` method encodes a batch of locations.
            extra_input: optional extra features concatenated to the location
                encoding; replaced by a per-batch time encoding when the
                evaluator has a time encoder.

        Returns:
            dict with 'vision_only_top_1' and 'vision_geo_top_1'.
        """
        results = {}
        # ``self.data`` is an NpzFile, which reads the array from disk again on
        # every ``[...]`` access, so fetch each array exactly once.
        probs = self.data['probs']
        top_k_inds = self.data['inds']
        labels = self.data['labels']
        model_to_taxa = self.data['model_to_taxa']
        device = self.eval_params['device']
        num_obs = probs.shape[0]
        # Defined before the loop so the output file names below are valid even
        # when there are no observations.
        noise_level = 1.0
        # loop over in batches
        batch_start = np.hstack((np.arange(0, num_obs, self.eval_params['batch_size']), num_obs))
        correct_pred = np.zeros(num_obs)
        for bb in tqdm(range(len(batch_start)-1)):
            batch_inds = np.arange(batch_start[bb], batch_start[bb+1])
            vision_probs = probs[batch_inds, :]
            vision_inds = top_k_inds[batch_inds, :]
            gt = labels[batch_inds]
            obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(device)
            if self.time_enc is not None:
                dates = torch.from_numpy(self.dates[batch_inds])
                extra_input = self.time_enc.encode(
                    torch.cat([dates[..., None], torch.full((*dates.shape, 1), noise_level)], dim=1)).to(device)
            loc_feat = torch.cat([enc.encode(obs_locs_batch), extra_input], 1) if extra_input is not None else enc.encode(obs_locs_batch)
            with torch.no_grad():
                geo_pred = model(loc_feat).cpu().numpy()
            geo_pred, vision_pred = self.convert_to_inat_vision_order(geo_pred, vision_probs, vision_inds,
                                                                      model_to_taxa, self.taxon_map, k=1.0)
            comb_pred = np.argmax(vision_pred*geo_pred, 1)
            correct_pred[batch_inds] = (comb_pred == gt)
        # ``labels`` hold vision class indices (they are compared with arg-max
        # column indices above), so group by the class index, not the taxon id.
        accuracy_by_taxa = np.zeros(len(model_to_taxa))
        for tt_id in range(len(model_to_taxa)):
            inds = np.where(labels == tt_id)[0]
            # classes without any observation get NaN
            accuracy_by_taxa[tt_id] = float(correct_pred[inds].mean()) if len(inds) > 0 else np.nan
        torch.save(correct_pred, f'correct_{noise_level}.pt')
        torch.save(accuracy_by_taxa, f'abt_{noise_level}.pt')
        # the last column of 'inds' is treated as the vision model's top-1 class
        results['vision_only_top_1'] = float((top_k_inds[:, -1] == labels).mean())
        results['vision_geo_top_1'] = float(correct_pred.mean())
        return results

    def report(self, results):
        """Print both accuracies and the gain from adding the geo model."""
        print('Overall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model ', round(results['vision_geo_top_1'], 3))
        print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
class EvaluatorGeoFeature:
    """Probe location features by regressing geospatial rasters from them.

    For each raster in ``raster_names`` a cross-validated ridge regressor is
    fitted on the features of spatial split 1 and scored (R^2) on split 2.
    """

    def __init__(self, train_params, eval_params):
        """Keep the parameters and read data locations from ``paths.json``."""
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as paths_file:
            paths = json.load(paths_file)
        self.data_path = paths['geo_feature']
        mask_path = os.path.join(paths['masks'], 'USA_MASK.tif')
        # True wherever the mask raster equals 1
        self.country_mask = tifffile.imread(mask_path) == 1
        self.raster_names = [
            'ABOVE_GROUND_CARBON',
            'ELEVATION',
            'LEAF_AREA_INDEX',
            'NON_TREE_VEGITATED',
            'NOT_VEGITATED',
            'POPULATION_DENSITY',
            'SNOW_COVER',
            'SOIL_MOISTURE',
            'TREE_COVER',
        ]
        # rasters that are log-scaled before the 0/1 normalisation
        self.raster_names_log_transform = ['POPULATION_DENSITY']

    def load_raster(self, raster_name, log_transform=False):
        """Read ``<raster_name>.tif`` and rescale its valid pixels to [0, 1].

        Args:
            raster_name: file stem of the raster inside ``self.data_path``.
            log_transform: apply ``log1p`` (after shifting the minimum to 0)
                before the 0/1 scaling.

        Returns:
            ``(raster, valid_mask)``: the float32 raster and a boolean mask of
            pixels that are not NaN and lie inside ``self.country_mask``.
        """
        tif_path = os.path.join(self.data_path, raster_name + '.tif')
        raster = tifffile.imread(tif_path).astype(np.float32)
        valid_mask = np.logical_not(np.isnan(raster)) & self.country_mask
        if log_transform:
            valid_vals = raster[valid_mask]
            raster[valid_mask] = np.log1p(valid_vals - valid_vals.min())
        # shift to a zero minimum, then divide by the new maximum
        shifted = raster[valid_mask]
        shifted = shifted - shifted.min()
        raster[valid_mask] = shifted / shifted.max()
        return raster, valid_mask

    def get_split_labels(self, raster, split_ids, split_of_interest):
        """Return the raster values at pixels whose split id matches."""
        rows, cols = np.nonzero(split_ids == split_of_interest)
        return raster[rows, cols]

    def get_split_feats(self, model, enc, split_ids, split_of_interest, extra_input=None):
        """Compute model features for every location of one split.

        Args:
            model: callable accepting ``(encoded_locs, return_feats=True)``.
            enc: object whose ``encode`` method encodes the locations.
            split_ids: per-pixel split assignment.
            split_of_interest: id of the split to extract.
            extra_input: optional tensor expanded to one row per location and
                concatenated to the location encoding.

        Returns:
            NumPy array with the features of the selected locations.
        """
        grid = utils.coord_grid(self.country_mask.shape, split_ids=split_ids, split_of_interest=split_of_interest)
        locs = torch.from_numpy(grid).to(self.eval_params['device'])
        if extra_input is None:
            locs_enc = enc.encode(locs)
        else:
            locs_enc = torch.cat([enc.encode(locs), extra_input.expand(locs.shape[0], -1)], 1)
        with torch.no_grad():
            feats = model(locs_enc, return_feats=True).cpu().numpy()
        return feats

    def run_evaluation(self, model2, enc, extra_input=None):
        """Fit and score a ridge regressor per raster.

        Args:
            model2: trained model; used directly for 'ResidualFCNet', through
                its ``pos_enc`` for 'HyperNet'.
            enc: location encoder passed on to ``get_split_feats``.
            extra_input: optional extra features passed on to ``get_split_feats``.

        Returns:
            dict with 'train_r2_<name>', 'test_r2_<name>' and 'alpha_<name>'
            for every raster name.

        Raises:
            NotImplementedError: for any other model type.
        """
        model_name = self.train_params['model']
        if model_name == 'ResidualFCNet':
            model = model2
        elif model_name == 'HyperNet':
            def model(x, return_feats=True):
                return model2.pos_enc(x)
        else:
            raise NotImplementedError()

        results = {}
        for raster_name in self.raster_names:
            use_log = raster_name in self.raster_names_log_transform
            raster, valid_mask = self.load_raster(raster_name, use_log)
            split_ids = utils.create_spatial_split(raster, valid_mask, cell_size=self.eval_params['cell_size'])

            # split 1 is used for fitting, split 2 for testing
            feats_train = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=1, extra_input=extra_input)
            feats_test = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=2, extra_input=extra_input)
            labels_train = self.get_split_labels(raster, split_ids, 1)
            labels_test = self.get_split_labels(raster, split_ids, 2)

            # the scaler is fitted on the training features only
            scaler = MinMaxScaler()
            feats_train_scaled = scaler.fit_transform(feats_train)
            feats_test_scaled = scaler.transform(feats_test)

            clf = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=10, fit_intercept=True, scoring='r2')
            clf = clf.fit(feats_train_scaled, labels_train)
            results[f'train_r2_{raster_name}'] = float(clf.score(feats_train_scaled, labels_train))
            results[f'test_r2_{raster_name}'] = float(clf.score(feats_test_scaled, labels_test))
            results[f'alpha_{raster_name}'] = float(clf.alpha_)
        return results

    def report(self, results):
        """Print every 'test_r2' entry followed by the mean of those entries."""
        report_fields = [key for key in results if 'test_r2' in key]
        for key in report_fields:
            print(f'{key}: {results[key]}')
        test_scores = [results[key] for key in report_fields]
        print(np.mean(test_scores))
# I need train overrides for some of my stuff but it should have zero impact on other things | |
def launch_eval_run(overrides, train_overrides=None): | |
eval_params = setup.get_default_params_eval(overrides) | |
# set up model: | |
eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name']) | |
#train_params = torch.load(eval_params['model_path'], map_location='cpu', weights_only=False) | |
train_params = torch.load(eval_params['model_path'], map_location='cpu') | |
default_params = setup.get_default_params_train() | |
for key in default_params: | |
if key not in train_params['params']: | |
train_params['params'][key] = default_params[key] | |
# MINE - this is hopefully just for my models - must ensure this - should have zero impact on hypernets | |
if train_overrides != None: | |
for key, value in train_overrides.items(): | |
#print(f'updating train param {key}') | |
train_params['params'][key] = value | |
model = models.get_model(train_params['params'], inference_only=True) | |
model.load_state_dict(train_params['state_dict'], strict=False) | |
model = model.to(eval_params['device']) | |
model.eval() | |
# create input encoder: | |
if train_params['params']['input_enc'] in ['env', 'sin_cos_env', 'sh_env']: | |
raster = datasets.load_env().to(eval_params['device']) | |
else: | |
raster = None | |
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim']) | |
if train_params['params']['input_time']: | |
time_enc = utils.TimeEncoder(input_enc='conical') if train_params['params']['input_time'] else None | |
extra_input = torch.cat([time_enc.encode(torch.tensor([[0.0, 1.0]]))], dim=1).to(eval_params['device']) | |
else: | |
extra_input = None | |
# This should only effect my models | |
# This is where I create the eval "species tokens" from the specified number of context points | |
# TODO just use the existing train params and some if statements to get the right dataset without having to use train overides | |
if train_params['params']['model'] == 'MultiInputModel': | |
train_dataset = datasets.get_train_data(train_params['params']) | |
if 'text' in train_params['params']['dataset']: | |
if eval_params['text_section'] != '': | |
train_dataset.select_text_section(eval_params['text_section']) | |
print(f'Using {eval_params["text_section"]} text for evaluation') | |
train_loader = torch.utils.data.DataLoader( | |
train_dataset, | |
batch_size=train_params['params']['batch_size'], | |
shuffle=True, | |
num_workers=8, | |
collate_fn=getattr(train_dataset, 'collate_fn', None)) | |
# if len(train_params['params']['class_to_taxa']) != train_dataset.class_to_taxa: | |
# Create new embedding layers for the expanded classes | |
num_new_classes = len(train_dataset.class_to_taxa) | |
embedding_dim = model.ema_embeddings.embedding_dim | |
new_ema_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) | |
new_eval_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) | |
nn.init.xavier_uniform_(new_ema_embeddings.weight) | |
nn.init.xavier_uniform_(new_eval_embeddings.weight) | |
# Convert lists to numpy arrays for indexing | |
class_to_taxa_np = np.array(train_params['params']['class_to_taxa']) | |
class_to_taxa_expanded_np = np.array(train_dataset.class_to_taxa) | |
# Find common taxa and their indices | |
common_taxa, original_indices, expanded_indices = np.intersect1d( | |
class_to_taxa_np, class_to_taxa_expanded_np, return_indices=True) | |
# Update new embeddings for the common taxa | |
new_ema_embeddings.weight.data[expanded_indices] = model.ema_embeddings.weight.data[original_indices] | |
new_eval_embeddings.weight.data[expanded_indices] = model.eval_embeddings.weight.data[original_indices] | |
# Replace old embeddings with new embeddings | |
model.ema_embeddings = new_ema_embeddings | |
model.eval_embeddings = new_eval_embeddings | |
# Print to verify | |
#print("Updating EMA Embeddings: ", model.ema_embeddings.weight.size()) | |
#print("Updating Eval Embeddings: ", model.eval_embeddings.weight.size()) | |
train_params['params']['class_to_taxa'] = train_dataset.class_to_taxa | |
for _, batch in tqdm(enumerate(train_loader)): | |
if train_params['params']['use_text_inputs']: | |
loc_feat, _, class_id, context_feats, _, context_mask, embs = batch | |
loc_feat = loc_feat.to(eval_params['device']) | |
class_id = class_id.to(eval_params['device']) | |
context_feats = context_feats.to(eval_params['device']) | |
context_mask = context_mask.to(eval_params['device']) | |
embs = embs.to(eval_params['device']) | |
# Don't need to do anything with these probs - I am just updating the "eval embeddings" | |
probs = model.forward( | |
x=loc_feat, | |
context_sequence=context_feats, | |
context_mask=context_mask, | |
class_ids=class_id, | |
return_feats=False, | |
return_class_embeddings=False, | |
class_of_interest=None, | |
use_eval_embeddings=True, | |
text_emb = embs) | |
else: | |
loc_feat, _, class_id, context_feats, _, context_mask = batch | |
loc_feat = loc_feat.to(eval_params['device']) | |
class_id = class_id.to(eval_params['device']) | |
context_feats = context_feats.to(eval_params['device']) | |
context_mask = context_mask.to(eval_params['device']) | |
# Don't need to do anything with these probs - I am just updating the "eval embeddings" | |
probs = model.forward( | |
x=loc_feat, | |
context_sequence=context_feats, | |
context_mask=context_mask, | |
class_ids=class_id, | |
return_feats=False, | |
return_class_embeddings=False, | |
class_of_interest=None, | |
use_eval_embeddings=True | |
) | |
print('eval embeddings generated!') | |
elif train_params['params']['model'] == 'VariableInputModel': | |
train_dataset = datasets.get_train_data(train_params['params']) | |
if train_dataset.use_text: | |
if eval_params['text_section'] != '': | |
train_dataset.select_text_section(eval_params['text_section']) | |
print(f'Using {eval_params["text_section"]} text for evaluation') | |
train_loader = torch.utils.data.DataLoader( | |
train_dataset, | |
batch_size=train_params['params']['batch_size'], | |
shuffle=True, | |
num_workers=8, | |
collate_fn=getattr(train_dataset, 'collate_fn', None)) | |
# if len(train_params['params']['class_to_taxa']) != train_dataset.class_to_taxa: | |
# Create new embedding layers for the expanded classes | |
num_new_classes = len(train_dataset.class_to_taxa) | |
embedding_dim = model.ema_embeddings.embedding_dim | |
new_ema_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) | |
new_eval_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(eval_params["device"]) | |
nn.init.xavier_uniform_(new_ema_embeddings.weight) | |
nn.init.xavier_uniform_(new_eval_embeddings.weight) | |
# Convert lists to numpy arrays for indexing | |
class_to_taxa_np = np.array(train_params['params']['class_to_taxa']) | |
class_to_taxa_expanded_np = np.array(train_dataset.class_to_taxa) | |
# Find common taxa and their indices | |
common_taxa, original_indices, expanded_indices = np.intersect1d( | |
class_to_taxa_np, class_to_taxa_expanded_np, return_indices=True) | |
# Update new embeddings for the common taxa | |
new_ema_embeddings.weight.data[expanded_indices] = model.ema_embeddings.weight.data[original_indices] | |
new_eval_embeddings.weight.data[expanded_indices] = model.eval_embeddings.weight.data[original_indices] | |
# Replace old embeddings with new embeddings | |
model.ema_embeddings = new_ema_embeddings | |
model.eval_embeddings = new_eval_embeddings | |
# Print to verify | |
#print("Updating EMA Embeddings: ", model.ema_embeddings.weight.size()) | |
#print("Updating Eval Embeddings: ", model.eval_embeddings.weight.size()) | |
train_params['params']['class_to_taxa'] = train_dataset.class_to_taxa | |
for _, batch in tqdm(enumerate(train_loader)): | |
loc_feat, _, class_id, context_feats, _, context_mask, text_emb, image_emb, env_emb = batch | |
# print('DO I NEED THE BELOW LINES? DO THEY SLOW THINGS DOWN') | |
# return padded_sequences, padded_locs, class_ids, sequence_mask | |
loc_feat = loc_feat.to(eval_params['device']) | |
class_id = class_id.to(eval_params['device']) | |
context_feats = context_feats.to(eval_params['device']) | |
context_mask = context_mask.to(eval_params['device']) | |
text_emb = text_emb.to(eval_params['device']) | |
image_emb = image_emb.to(eval_params['device']) | |
if env_emb is not None: | |
env_emb = env_emb.to(eval_params['device']) | |
# Don't need to do anything with these probs - I am just updating the "eval embeddings" | |
probs = model.forward(x=loc_feat, | |
context_sequence=context_feats, | |
context_mask=context_mask, | |
class_ids=class_id, | |
text_emb=text_emb, | |
image_emb=image_emb, | |
env_emb=env_emb, | |
return_feats=False, | |
return_class_embeddings=False, | |
class_of_interest=None, | |
use_eval_embeddings=True) | |
print('eval embeddings generated!') | |
print('\n' + eval_params['eval_type']) | |
t = time.time() | |
if eval_params['eval_type'] == 'snt': | |
eval_params['split'] = 'test' # val, test, all | |
eval_params['val_frac'] = 0.50 | |
eval_params['split_seed'] = 7499 | |
evaluator = EvaluatorSNT(train_params['params'], eval_params) | |
results = evaluator.run_evaluation(model, enc, extra_input=extra_input) | |
evaluator.report(results) | |
elif eval_params['eval_type'] == 'iucn': | |
evaluator = EvaluatorIUCN(train_params['params'], eval_params) | |
results = evaluator.run_evaluation(model, enc, extra_input=extra_input) | |
evaluator.report(results) | |
elif eval_params['eval_type'] == 'geo_prior': | |
evaluator = EvaluatorGeoPrior(train_params['params'], eval_params) | |
results = evaluator.run_evaluation(model, enc, extra_input=extra_input) | |
evaluator.report(results) | |
elif eval_params['eval_type'] == 'geo_feature': | |
evaluator = EvaluatorGeoFeature(train_params['params'], eval_params) | |
results = evaluator.run_evaluation(model, enc, extra_input=extra_input) | |
evaluator.report(results) | |
else: | |
raise NotImplementedError('Eval type not implemented.') | |
print(f'evaluation completed in {np.around((time.time()-t)/60, 1)} min') | |
return results | |
class EvaluatorGeoPriorLowRank:
    """Score a geo model as a prior on top of pre-computed vision predictions.

    The vision model's top-k predictions are read from
    ``geo_prior_model_preds.npz`` and the observation metadata from
    ``geo_prior_model_meta.csv``; both live in the ``geo_prior`` directory
    named in ``paths.json``. ``run_evaluation`` reports top-1 accuracy of the
    vision model alone and of the vision model re-weighted by the geo model.
    """

    # Constant added to the geo model's log-scores before the softmax. Taxa the
    # geo model does not cover keep a fixed score of 1.0 (see
    # ``convert_to_inat_vision_order``), so this offset controls how strongly
    # covered taxa are weighted relative to uncovered ones.
    GEO_LOG_BIAS = 16

    def __init__(self, train_params, eval_params):
        """Load predictions, locations and dates, and build the taxon mapping.

        Args:
            train_params: training parameter mapping; ``class_to_taxa`` is read.
            eval_params: evaluation parameter mapping; ``batch_size`` and
                ``device`` are read later by ``run_evaluation``.
        """
        # store parameters:
        self.train_params = train_params
        self.eval_params = eval_params
        with open('paths.json', 'r') as f:
            paths = json.load(f)

        # load vision model predictions:
        self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
        print(self.data['probs'].shape[0], 'total test observations')

        # load locations (stored as longitude, latitude):
        meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
        self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)

        # Split each date string into single bytes so month and day can be
        # sliced out by position (assumes 'YYYY-MM-DD' formatting - TODO confirm).
        temp = np.array(meta['observed_on'].values, dtype='S10')
        temp = temp.view('S1').reshape((temp.size, -1))
        months = temp[:, 5:7].view('S2').astype(int)[:, 0]
        days = temp[:, 8:10].view('S2').astype(int)[:, 0]

        # Day-of-year offset of each month in a non-leap year, then the
        # fraction of the year elapsed for every observation.
        days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
        dates = days_per_month[months - 1] + days - 1
        self.dates = np.round(dates / 365.0, 4).astype(np.float32)

        # taxonomic mapping (read the taxa array from the archive only once):
        model_to_taxa = self.data['model_to_taxa']
        self.taxon_map = self.find_mapping_between_models(model_to_taxa, self.train_params['class_to_taxa'])
        print(self.taxon_map.shape[0], 'out of', len(model_to_taxa), 'taxa in both vision and geo models')

    def find_mapping_between_models(self, vision_taxa, geo_taxa):
        """Map vision-model class indices to geo-model class indices.

        Args:
            vision_taxa: array of taxon ids in vision-model order.
            geo_taxa: sequence of taxon ids in geo-model order.

        Returns:
            ``int32`` array of shape ``(N_overlap, 2)``. Column 0 holds the
            vision-model index and column 1 the index of the first occurrence
            of the same taxon in ``geo_taxa``. Taxa missing from the geo model
            are dropped.
        """
        # ``setdefault`` keeps the first position of a repeated taxon.
        first_geo_index = {}
        for geo_id, taxon in enumerate(np.asarray(geo_taxa).tolist()):
            first_geo_index.setdefault(taxon, geo_id)

        pairs = [
            (vision_id, first_geo_index[taxon])
            for vision_id, taxon in enumerate(np.asarray(vision_taxa).tolist())
            if taxon in first_geo_index
        ]
        # reshape keeps the (0, 2) shape when nothing overlaps
        return np.array(pairs, dtype=np.int32).reshape(-1, 2)

    def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
        """Expand both predictions to dense arrays in vision-model class order.

        Args:
            geo_pred_ip: ``(batch, num_geo_classes)`` geo scores, as a NumPy
                array or a torch tensor on any device.
            vision_top_k_prob: ``(batch, k)`` vision probabilities.
            vision_top_k_inds: ``(batch, k)`` vision class indices matching
                ``vision_top_k_prob``.
            vision_taxa: taxon ids in vision-model order (only its length is used).
            taxon_map: output of ``find_mapping_between_models``.

        Returns:
            ``(geo_pred, vision_pred)``, both ``float32`` arrays of shape
            ``(batch, len(vision_taxa))``. Classes without a vision top-k entry
            are 0 in ``vision_pred``; classes the geo model does not cover are
            1 in ``geo_pred``.
        """
        # NumPy cannot read a tensor that lives on an accelerator, so bring it
        # to host memory explicitly before indexing.
        if isinstance(geo_pred_ip, torch.Tensor):
            geo_pred_ip = geo_pred_ip.detach().cpu().numpy()

        # this is slow as we turn the sparse input back into the same size as the dense one
        num_obs = geo_pred_ip.shape[0]
        vision_pred = np.zeros((num_obs, len(vision_taxa)), dtype=np.float32)
        geo_pred = np.ones((num_obs, len(vision_taxa)), dtype=np.float32)
        vision_pred[np.arange(num_obs)[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob
        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]
        return geo_pred, vision_pred

    def run_evaluation(self, model):
        """Compute top-1 accuracy with and without the geo prior.

        Args:
            model: object exposing ``sample(locs)``. Its output is passed
                through ``log`` and transposed, so it is presumably shaped
                ``(num_geo_classes, batch)`` - TODO confirm against the model.

        Returns:
            dict with ``vision_only_top_1`` and ``vision_geo_top_1`` floats.
        """
        from tqdm import tqdm

        results = {}

        # An .npz archive is read lazily: every ``self.data[key]`` lookup loads
        # the array again, so fetch each array once instead of once per batch.
        vision_probs_all = self.data['probs']
        vision_inds_all = self.data['inds']
        labels = self.data['labels']
        vision_taxa = self.data['model_to_taxa']

        num_obs = vision_probs_all.shape[0]
        batch_size = self.eval_params['batch_size']
        device = self.eval_params['device']
        correct_pred = np.zeros(num_obs)

        # loop over in batches
        for start in tqdm(range(0, num_obs, batch_size)):
            batch = slice(start, min(start + batch_size, num_obs))
            obs_locs_batch = torch.from_numpy(self.obs_locs[batch, :]).to(device)

            with torch.no_grad():
                geo_log_scores = torch.log(model.sample(obs_locs_batch)).T

            geo_pred, vision_pred = self.convert_to_inat_vision_order(
                geo_log_scores + self.GEO_LOG_BIAS,
                vision_probs_all[batch, :],
                vision_inds_all[batch, :],
                vision_taxa,
                self.taxon_map,
            )
            geo_pred = softmax(torch.from_numpy(geo_pred), dim=1).numpy()

            comb_pred = np.argmax(vision_pred * geo_pred, 1)
            correct_pred[batch] = comb_pred == labels[batch]

        # The last top-k column is compared with the label, i.e. it is treated
        # as the vision model's top-1 prediction.
        results['vision_only_top_1'] = float((vision_inds_all[:, -1] == labels).mean())
        results['vision_geo_top_1'] = float(correct_pred.mean())
        return results

    def report(self, results):
        """Print both accuracies from ``run_evaluation`` and their difference."""
        print('Overall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model ', round(results['vision_geo_top_1'], 3))
        print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
# NOTE: the helpers below exist only to support low-shot plotting and could probably live in a separate module.
def generate_eval_embeddings(overrides, taxa_of_interest, num_context, train_overrides=None):
    """Load a checkpoint on CPU and run one forward pass for a single taxon.

    The model's ``ema_embeddings`` and ``eval_embeddings`` tables are rebuilt
    to match the classes of the (possibly overridden) training dataset. Rows
    for taxa shared with the checkpoint are copied over; other rows keep their
    Xavier initialisation. A forward pass with ``use_eval_embeddings=True`` is
    then run using the first ``num_context`` context points of the requested
    taxon. Its output is discarded; judging by comments elsewhere in this
    file, the call is made for its effect on the model's eval embeddings.

    Args:
        overrides: passed to ``setup.get_default_params_eval``.
        taxa_of_interest: taxon id; must be present in the dataset's
            ``class_to_taxa``.
        num_context: number of leading context points to use.
        train_overrides: optional mapping of training parameters that replace
            the values stored in the checkpoint before the dataset and model
            are built.

    Returns:
        Tuple ``(model, context_locs_of_interest, train_params,
        class_of_interest)``.

    Raises:
        ValueError: if the taxon is not in ``class_to_taxa`` or the dataset
            holds no label for it.
    """
    eval_params = setup.get_default_params_eval(overrides)

    # set up model:
    eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
    eval_params['device'] = 'cpu'
    device = eval_params['device']

    train_params = torch.load(eval_params['model_path'], map_location='cpu')
    params = train_params['params']
    params['device'] = 'cpu'

    # fill in parameters that the checkpoint does not define
    for key, value in setup.get_default_params_train().items():
        if key not in params:
            params[key] = value

    if train_overrides is not None:
        for key, value in train_overrides.items():
            params[key] = value

    train_dataset = datasets.get_train_data(params)

    model = models.get_model(params, inference_only=True)
    model.load_state_dict(train_params['state_dict'], strict=False)
    model = model.to(device)
    model.eval()

    # Create new embedding layers for the expanded classes
    num_new_classes = len(train_dataset.class_to_taxa)
    embedding_dim = model.ema_embeddings.embedding_dim
    new_ema_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(device)
    new_eval_embeddings = nn.Embedding(num_embeddings=num_new_classes, embedding_dim=embedding_dim).to(device)
    nn.init.xavier_uniform_(new_ema_embeddings.weight)
    nn.init.xavier_uniform_(new_eval_embeddings.weight)

    # Find taxa present in both class lists, with their index in each list
    _, original_indices, expanded_indices = np.intersect1d(
        np.array(params['class_to_taxa']),
        np.array(train_dataset.class_to_taxa),
        return_indices=True)

    # Carry the checkpoint's embeddings over for the common taxa
    new_ema_embeddings.weight.data[expanded_indices] = model.ema_embeddings.weight.data[original_indices]
    new_eval_embeddings.weight.data[expanded_indices] = model.eval_embeddings.weight.data[original_indices]

    # Replace old embeddings with new embeddings
    model.ema_embeddings = new_ema_embeddings
    model.eval_embeddings = new_eval_embeddings
    params['class_to_taxa'] = train_dataset.class_to_taxa

    class_of_interest = train_dataset.class_to_taxa.index(taxa_of_interest)

    # Use the first dataset row labelled with this class; ``.item()`` on the
    # full match list would fail whenever the class has more than one row.
    matches = (train_dataset.labels == class_of_interest).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        raise ValueError(f'no training example found for taxa {taxa_of_interest}')
    loc_index_of_interest = matches[0].item()
    loc_of_interest = train_dataset.loc_feats[loc_index_of_interest]

    context_feats_of_interest = train_dataset.per_class_loc_feats[class_of_interest][:num_context, :]
    context_locs_of_interest = train_dataset.per_class_locs[class_of_interest][:num_context, :]

    # A context row whose entries all equal -10 is flagged in the mask
    # (presumably padding - TODO confirm against the dataset).
    context_mask = (context_locs_of_interest == -10).all(dim=-1).to(device).unsqueeze(0)

    # Every input goes to the device the model was moved to above.
    model.forward(
        x=loc_of_interest.to(device),
        context_sequence=context_feats_of_interest.to(device),
        context_mask=context_mask,
        class_ids=class_of_interest,
        return_feats=False,
        return_class_embeddings=False,
        class_of_interest=None,
        use_eval_embeddings=True
    )
    return model, context_locs_of_interest, train_params, class_of_interest
def generate_eval_embedding_from_given_points(context_points, overrides, taxa_of_interest, train_overrides=None, text_emb=None):
    """Load a checkpoint and run one forward pass on caller-supplied context points.

    The context points are encoded and fed to the model together with a dummy
    query location ``(0.0, 0.0)`` and class id ``0``, with
    ``use_eval_embeddings=True``. The output of the forward pass is discarded;
    judging by comments elsewhere in this file, the call is made for its
    effect on the model's eval embeddings.

    Args:
        context_points: sequence of coordinate pairs, converted to a
            ``float32`` tensor (presumably ``(lon, lat)`` like the other
            locations in this file - TODO confirm).
        overrides: passed to ``setup.get_default_params_eval``.
        taxa_of_interest: accepted for signature compatibility; not used.
        train_overrides: optional mapping of training parameters that replace
            the values stored in the checkpoint.
        text_emb: optional text embedding forwarded to the model.

    Returns:
        Tuple ``(model, context_locs_of_interest, train_params,
        class_of_interest)`` where ``class_of_interest`` is always ``0``.
    """
    eval_params = setup.get_default_params_eval(overrides)

    # set up model:
    eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
    device = eval_params['device']

    train_params = torch.load(eval_params['model_path'], map_location='cpu')
    params = train_params['params']

    # fill in parameters that the checkpoint does not define
    for key, value in setup.get_default_params_train().items():
        if key not in params:
            params[key] = value

    # create input encoder (uses the checkpoint values, before any overrides):
    if params['input_enc'] in ['env', 'sin_cos_env']:
        raster = datasets.load_env().to(device)
    else:
        raster = None
    enc = utils.CoordEncoder(params['input_enc'], raster=raster, input_dim=params['input_dim'])

    if train_overrides is not None:
        for key, value in train_overrides.items():
            params[key] = value

    # create context point encoder
    transformer_input_enc = params['transformer_input_enc']
    if transformer_input_enc in ['env', 'sin_cos_env']:
        transformer_raster = datasets.load_env().to(device)
    else:
        transformer_raster = None
    token_dim = params['species_dim']

    if transformer_input_enc == 'sinr':
        # 'sinr' reuses the query-location encoder for the context points
        transformer_enc = enc
    else:
        transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim)

    # load model
    model = models.get_model(params, inference_only=True)
    model.load_state_dict(train_params['state_dict'], strict=False)
    model = model.to(device)
    model.eval()

    # Swap in a fresh eval-embedding module that reuses the existing weights.
    embedding_dim = model.ema_embeddings.embedding_dim
    new_eval_embeddings = nn.Embedding(num_embeddings=model.eval_embeddings.weight.size()[0], embedding_dim=embedding_dim).to(device)
    new_eval_embeddings.weight.data = model.eval_embeddings.weight.data
    model.eval_embeddings = new_eval_embeddings

    class_of_interest = 0

    # dummy query location
    just_loc = torch.from_numpy(np.array([[0.0, 0.0]]).astype(np.float32))
    loc_of_interest = enc.encode(just_loc, normalize=False)

    context_points = torch.from_numpy(np.array(context_points).astype(np.float32))
    context_feats_of_interest = transformer_enc.encode(context_points, normalize=False)
    context_locs_of_interest = context_points

    # No context point is masked. The mask is created on the same device as
    # the other inputs so the forward pass does not mix devices.
    context_mask = torch.zeros((1, context_feats_of_interest.shape[0]), dtype=torch.bool, device=device)
    if isinstance(text_emb, torch.Tensor):
        text_emb = text_emb.to(device)

    model.forward(
        x=loc_of_interest.to(device),
        context_sequence=context_feats_of_interest.to(device),
        context_mask=context_mask,
        class_ids=class_of_interest,
        return_feats=False,
        return_class_embeddings=False,
        class_of_interest=None,
        use_eval_embeddings=True,
        text_emb=text_emb
    )
    return model, context_locs_of_interest, train_params, class_of_interest