""" | |
Demo that takes an iNaturalist taxa ID as input and generates a prediction | |
for each location on the globe and saves the ouput as an image. | |
""" | |
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import argparse

import utils
import datasets
import eval
import create_inputs_to_fs_sinr

# Checkpoint holding precomputed text data (loaded by the commented-out GritLM
# code in main() below).
text_model = './experiments/gpt_data.pt'


def extract_grit_token(model, text: str):
    """Embed `text` with a GritLM embedding model and return a torch tensor."""
    def gritlm_instruction(instruction):
        return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
    d_rep = model.encode([text], instruction=gritlm_instruction(""))
    d_rep = torch.from_numpy(d_rep)
    return d_rep
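
# Example usage (a sketch only; assumes the `gritlm` package is installed and
# mirrors the commented-out loading code in main() below):
#   from gritlm import GritLM
#   grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
#   emb = extract_grit_token(grit, "A nocturnal gliding mammal of eastern US forests.")
#   emb.shape  # (1, embedding_dim)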


def choose_context_points_from_map(eval_params):
    context_points = []
    # The interactive branch below is disabled; flip the condition to True to
    # click context points on a map instead of using the hardcoded USA points.
    if False:
        def onclick(event):
            if event.xdata is not None and event.ydata is not None:
                # Convert image coordinates to normalized geographical coordinates
                lon = event.xdata / mask.shape[1] * 2 - 1
                lat = 1 - event.ydata / mask.shape[0] * 2
                context_points.append((lon, lat))
                print(f"Added context point: ({lon}, {lat})")

        # Load ocean mask
        with open('paths.json', 'r') as f:
            paths = json.load(f)
        if eval_params['high_res']:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
        else:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
        mask_inds = np.where(mask.reshape(-1) == 1)[0]
        # # Generate input features
        # locs = utils.coord_grid(mask.shape)
        # if not eval_params['disable_ocean_mask']:
        #     locs = locs[mask_inds, :]
        # locs = torch.from_numpy(locs)

        # Reshape and create masked array for visualization
        op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan  # initialize to NaN
        op_im[mask_inds] = 0  # placeholder for the mask visualization
        op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
        op_im = np.ma.masked_invalid(op_im)
        # Set color for masked values; copy so set_bad does not mutate the
        # registered colormap (mutating it errors on newer matplotlib).
        cmap = plt.cm.plasma.copy()
        cmap.set_bad(color='none')
        plt.ioff()
        # Display the map and capture context points
        fig, ax = plt.subplots(figsize=(6, 3), dpi=334)
        ax.imshow(op_im, cmap=cmap, interpolation='nearest')
        ax.axis('off')
        # Connect the click event to the handler
        cid = fig.canvas.mpl_connect('button_press_event', onclick)
        plt.show(block=True)  # block execution until the window is closed
        print(f"Context points collected: {context_points}")
    else:
        # Hardcoded context points over the USA in normalized (lon, lat) coords.
        # TODO: to add a point such as 37.541170, -92.003293: 1. flip the order,
        # then 2. normalize, i.e. divide lon by 180 and lat by 90 (see the
        # sketch after this function).
        context_points = [(-0.5884012559178662, 0.46394662490802496), (-0.5451199953511522, 0.4504212309809269),
                          (-0.5437674559584422, 0.5342786733289353), (-0.589753795310576, 0.5342786733289353)]
        print(f"Context points collected: {context_points}")
    return context_points
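

# A minimal sketch of the conversion described in the TODO above (hypothetical
# helper, not called anywhere in this script): flip a (lat, lon) pair given in
# degrees to (lon, lat), then normalize lon by 180 and lat by 90.
def latlon_deg_to_normalized(lat_deg, lon_deg):
    return (lon_deg / 180.0, lat_deg / 90.0)
# e.g. latlon_deg_to_normalized(37.541170, -92.003293) ~= (-0.5111, 0.4171)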


def main(eval_params):
    # load params
    with open('paths.json', 'r') as f:
        paths = json.load(f)
    ckp_name = os.path.split(eval_params['model_path'])[-1]
    experiment_name = os.path.split(os.path.split(eval_params['model_path'])[-2])[-1]
    eval_overrides = {'ckp_name': ckp_name,
                      'experiment_name': experiment_name,
                      'device': eval_params['device']}
    train_overrides = {'dataset': 'eval_transformer'}
    # grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
    # grit_gpt = torch.load(text_model, map_location='cpu')
    # context_model = torch.load("experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt", map_location=torch.device('cpu'))
    context_data = np.load('data/positive_eval_data.npz')
    text_type_value = 0
    for pt in eval_params['context_pt_trial']:
        number_of_context_points = pt
        if eval_params['choose_context_points'] == 1:
            # context_points = choose_context_points_from_map(eval_params)
            text_emb, text_type_value = create_inputs_to_fs_sinr.use_pregenerated_textemb_fromchris(
                taxon_id=eval_params['test_taxa'],
                text_type=eval_params['text_type'])
            context_points = create_inputs_to_fs_sinr.get_eval_context_points(
                taxa_id=eval_params['test_taxa'],
                context_data=context_data,
                size=number_of_context_points)
            model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embedding_from_given_points(
                context_points=context_points,
                overrides=eval_overrides,
                taxa_of_interest=eval_params['taxa_id'],
                train_overrides=train_overrides,
                text_emb=text_emb)
            # TODO: why is taxa_id updated to 'selected_points' here?
            eval_params['taxa_id'] = 'selected_points'
        else:
            model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embeddings(
                overrides=eval_overrides,
                taxa_of_interest=eval_params['taxa_id'],
                num_context=eval_params['num_context'],
                train_overrides=train_overrides)

        if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
            raster = datasets.load_env()
        else:
            raster = None
        enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim'])
        enc_time = utils.CoordEncoder('sin_cos', raster=None, input_dim=2 * train_params['params']['input_time_dim'])
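        # For intuition: a 'sin_cos' encoder of this kind typically maps normalized
        # coordinates x in [-1, 1] to [sin(pi * x), cos(pi * x)] per input dimension,
        # i.e. roughly torch.cat([torch.sin(math.pi * x), torch.cos(math.pi * x)], dim=1).
        # (An assumption about utils.CoordEncoder, noted here for reference.)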

        # load ocean mask
        if eval_params['high_res']:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
        else:
            mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
        # mask = 0 * mask + 1  # uncomment to treat every cell as land
        mask_inds = np.where(mask.reshape(-1) == 1)[0]

        # generate input features
        locs = utils.coord_grid(mask.shape)
        if not eval_params['disable_ocean_mask']:
            locs = locs[mask_inds, :]
        locs = torch.from_numpy(locs)
        locs_enc = enc.encode(locs).to(eval_params['device'])
        if train_params['params']['input_time_dim'] > 0:
            extra_input = torch.cat([enc_time.encode(torch.tensor([[0.0]]), normalize=False), torch.tensor([[1.0]])],
                                    dim=1).to(eval_params['device'])
            locs_enc = torch.cat((locs_enc, extra_input.repeat(locs_enc.shape[0], 1)), dim=1)

        with torch.no_grad():
            # With eval=False we would instead see the EMA embeddings (since ema
            # is 1.0, that is just the last training example seen).
            preds = model.embedding_forward(x=locs_enc, class_ids=None, return_feats=False,
                                            class_of_interest=class_of_interest, eval=True).cpu().numpy()

        # threshold predictions
        if eval_params['threshold'] > 0:
            print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
            preds[preds < eval_params['threshold']] = 0.0
            preds[preds >= eval_params['threshold']] = 1.0

        # mask data
        if not eval_params['disable_ocean_mask']:
            op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan  # initialize to NaN
            op_im[mask_inds] = preds
        else:
            op_im = preds

        # reshape and create masked array for visualization
        op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
        op_im = np.ma.masked_invalid(op_im)

        # set color for masked values; copy so set_bad does not mutate the registered colormap
        cmap = plt.cm.plasma.copy()
        cmap.set_bad(color='none')
        if eval_params['set_max_cmap_to_1']:
            vmax = 1.0
        else:
            vmax = np.max(op_im)

        # # Display the image with a colorbar
        # if eval_params['show_map'] == 1:
        #     fig, ax = plt.subplots()
        #     cax = ax.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap)
        #     fig.colorbar(cax)
        #     plt.show(block=True)  # block until the window is closed
        if eval_params['show_map'] == 1:
            # Display the image
            fig, ax = plt.subplots(figsize=(6, 3), dpi=334)
            plt.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap, interpolation='nearest')
            plt.axis('off')  # turn off the axis
            if eval_params['show_context_points'] == 1:
                # Convert the tensor to a numpy array if it is not one already
                context_locs = context_locs_of_interest.numpy() if isinstance(context_locs_of_interest, torch.Tensor) else context_locs_of_interest
                # Convert context locations directly to image coordinates,
                # skipping the dummy context point at index 0 (placed at (0, 0)).
                image_x = (context_locs[1:, 0] + 1) / 2 * op_im.shape[1]  # lon: [-1, 1] -> [0, width]
                image_y = (1 - (context_locs[1:, 1] + 1) / 2) * op_im.shape[0]  # lat: [-1, 1] -> [0, height]
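                # e.g. lon = -0.5 lands at 0.25 * width from the left, and
                # lat = 0.5 lands at 0.25 * height from the top.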
                from matplotlib.offsetbox import OffsetImage, AnnotationBbox

                # Plot a marker image at each context location
                def getImage(path):
                    return OffsetImage(plt.imread(path), zoom=.04)

                for x0, y0 in zip(image_x, image_y):
                    ab = AnnotationBbox(getImage('black_circle.png'), (x0, y0), frameon=False)
                    ax.add_artist(ab)
                # plt.scatter(image_x, image_y, c='green', s=30, marker=r'$\checkmark$')  # alternative: plain markers
                # plt.show(block=True)  # block until the window is closed

        exp_name = eval_params['model_path'].split(os.path.sep)[-2]
        # save image
        # save_loc = os.path.join(eval_params['op_path'], exp_name + '_' + str(eval_params['taxa_id']) + '_' + eval_params['additional_save_name'] + '_map.png')
        # save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + str(eval_params['taxa_id']) + ')_' + eval_params['text_type'] + '(' + str(text_type_value) + ')_' + str(number_of_context_points) + '.png'
        save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + str(eval_params['taxa_id']) + ')_' + eval_params['text_type'] + '_' + str(number_of_context_points) + '.png'
        print(f'Saving image to {save_loc}')
        plt.savefig(save_loc, bbox_inches='tight', pad_inches=0, dpi=334)
        # plt.imsave(fname=save_loc, arr=op_im, vmin=0, vmax=vmax, cmap=cmap)
        plt.show(block=False)  # do not block; continue with the next context-point count
    return True


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    info_str = '\nDemo that takes an iNaturalist taxa ID as input, ' + \
               'generates a predicted range for each location on the globe, ' + \
               'and saves the output as an image.\n\n' + \
               'Warning: these estimated ranges should be validated before use.'
    parser = argparse.ArgumentParser(usage=info_str)
    # parser.add_argument('--model_path', type=str, default='./pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt')
    # parser.add_argument('--model_path', type=str, default='./experiments/transformer_ema_1.0/model_10.pt')
    # parser.add_argument('--model_path', type=str, default='./experiments/03_08_coord_multihead.pt/model.pt')
    # parser.add_argument('--model_path', type=str, default='./experimentvs/coord_context_20_without_registry/model_best.pt')
    # parser.add_argument('--model_path', type=str, default='./experiments/coord_sinr_inputs_context_20_without_registry/model_best.pt')
    parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
    # parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
    # parser.add_argument('--taxa_id', type=int, default=144575, help='iNaturalist taxon ID.')
    # parser.add_argument('--taxa_id', type=int, default=9083, help='iNaturalist taxon ID.')
    parser.add_argument('--taxa_id', type=int, default=3352, help='iNaturalist taxon ID.')
    parser.add_argument('--threshold', type=float, default=-1, help='Threshold the range map [0, 1]; negative disables thresholding.')
    parser.add_argument('--op_path', type=str, default='./images/', help='Location where the output image will be saved.')
    parser.add_argument('--rand_taxa', action='store_true', help='Select a random taxon.')
    parser.add_argument('--high_res', action='store_true', help='Generate higher resolution output.')
    parser.add_argument('--disable_ocean_mask', action='store_true', help='Do not use an ocean mask.')
    parser.add_argument('--set_max_cmap_to_1', action='store_true', help='Consistent maximum intensity output.')
    parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda')
    # parser.add_argument('--device', type=str, default='cuda:3', help='cpu or cuda')
    parser.add_argument('--show_map', type=int, default=1, help='Show the map if 1.')
    parser.add_argument('--show_context_points', type=int, default=1, help='Also plot context points if 1.')
    parser.add_argument('--prefix', type=str, default='')
    parser.add_argument('--num_context', type=int, default=5)
    parser.add_argument('--choose_context_points', type=int, default=1)
    parser.add_argument('--additional_save_name', type=str, default="")
#taxas: black&whitewarbler(10286), hyacinth macaw(18938), yellow baboon(67683) | |
# bawnswallow (11901), pika(43188), loon(4626), eurorobin(13094) | |
# southernflyingsquirrel (46272) | |
parser.add_argument('--taxa_name', type=str, default='sfs', help='Name of the taxon.') | |
parser.add_argument('--test_taxa', type=int, default=46272, help='Taxon ID to test.') | |
parser.add_argument('--text_type', type=str, default='range', help='Type of text for input.') | |
parser.add_argument('--context_pt_trial', type=int, nargs='+', default=[0, 1, 2, 5, 10, 20], help='List of context points for trial.') | |

    eval_params = vars(parser.parse_args())
    if not os.path.isdir(eval_params['op_path']):
        os.makedirs(eval_params['op_path'])
    eval_params['high_res'] = True  # force high-res output, overriding the --high_res flag
    main(eval_params)
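
# Example invocation (script filename hypothetical; the flags are the ones
# defined above):
#   python demo_fs_sinr.py --test_taxa 46272 --taxa_name sfs --text_type range \
#       --context_pt_trial 0 1 2 5 --device cpu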