""" Demo that takes an iNaturalist taxa ID as input and generates a prediction for each location on the globe and saves the ouput as an image. """ import torch import numpy as np import matplotlib.pyplot as plt import os import json import argparse import utils import datasets import eval import create_inputs_to_fs_sinr text_model = './experiments/gpt_data.pt' def extract_grit_token(model, text:str): def gritlm_instruction(instruction): return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n" d_rep = model.encode([text], instruction=gritlm_instruction("")) d_rep = torch.from_numpy(d_rep) return d_rep def choose_context_points_from_map(eval_params): context_points = [] if False: def onclick(event): if event.xdata is not None and event.ydata is not None: # Convert image coordinates to normalized geographical coordinates lon = event.xdata / mask.shape[1] * 2 - 1 lat = 1 - event.ydata / mask.shape[0] * 2 context_points.append((lon, lat)) print(f"Added context point: ({lon}, {lat})") # Load ocean mask with open('paths.json', 'r') as f: paths = json.load(f) if eval_params['high_res']: mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy')) else: mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy')) mask_inds = np.where(mask.reshape(-1) == 1)[0] # # Generate input features # locs = utils.coord_grid(mask.shape) # if not eval_params['disable_ocean_mask']: # locs = locs[mask_inds, :] # locs = torch.from_numpy(locs) # Reshape and create masked array for visualization op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # Set to NaN op_im[mask_inds] = 0 # Placeholder for the mask visualization op_im = op_im.reshape((mask.shape[0], mask.shape[1])) op_im = np.ma.masked_invalid(op_im) # Set color for masked values cmap = plt.cm.plasma cmap.set_bad(color='none') plt.ioff() # Display the map and capture context points fig, ax = plt.subplots(figsize=(6, 3), dpi=334) # Define the figure size ax.imshow(op_im, cmap=cmap, interpolation='nearest') # Display the image ax.axis('off') # Turn off the axis # Connect the onclick event to the handler cid = fig.canvas.mpl_connect('button_press_event', onclick) plt.show(block=True) # Block execution until the window is closed print(f"Context points collected: {context_points}") else: #USA #TODO: 37.541170, -92.003293 1. flip order, then 2. normalize so divide by 180 and 90 context_points = [(-0.5884012559178662, 0.46394662490802496), (-0.5451199953511522, 0.4504212309809269), (-0.5437674559584422, 0.5342786733289353), (-0.589753795310576, 0.5342786733289353)] print(f"Context points collected: {context_points}") return context_points def main(eval_params): # load params with open('paths.json', 'r') as f: paths = json.load(f) ckp_name = os.path.split(eval_params['model_path'])[-1] experiment_name = os.path.split(os.path.split(eval_params['model_path'])[-2])[-1] eval_overrides = {'ckp_name':ckp_name, 'experiment_name':experiment_name, 'device':eval_params['device']} train_overrides = {'dataset': 'eval_transformer'} #grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding") #grit_gpt = torch.load(text_model, map_location='cpu') #context_model = torch.load("experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt", map_location=torch.device('cpu')) context_data = np.load('data/positive_eval_data.npz') text_type_value = 0 for pt in eval_params['context_pt_trial']: number_of_context_points = pt if eval_params['choose_context_points'] == 1: #context_points = choose_context_points_from_map(eval_params) text_emb, text_type_value = create_inputs_to_fs_sinr.use_pregenerated_textemb_fromchris(taxon_id=eval_params['test_taxa'], text_type=eval_params['text_type']) context_points = create_inputs_to_fs_sinr.get_eval_context_points(taxa_id=eval_params['test_taxa'], context_data=context_data, size=number_of_context_points) model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embedding_from_given_points( context_points=context_points, overrides=eval_overrides, taxa_of_interest=eval_params['taxa_id'], train_overrides=train_overrides, text_emb=text_emb) #TODO: why is taxa_id updated to 'selected pts'?? eval_params['taxa_id'] = 'selected_points' else: model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embeddings( overrides=eval_overrides, taxa_of_interest=eval_params['taxa_id'], num_context=eval_params['num_context'], train_overrides=train_overrides) if train_params['params']['input_enc'] in ['env', 'sin_cos_env']: raster = datasets.load_env() else: raster = None enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim']) enc_time = utils.CoordEncoder('sin_cos', raster=None, input_dim=2 * train_params['params']['input_time_dim']) # load ocean mask if eval_params['high_res']: mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy')) else: mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy')) #mask = 0*mask+1 mask_inds = np.where(mask.reshape(-1) == 1)[0] # generate input features locs = utils.coord_grid(mask.shape) if not eval_params['disable_ocean_mask']: locs = locs[mask_inds, :] locs = torch.from_numpy(locs) locs_enc = enc.encode(locs).to(eval_params['device']) if train_params['params']['input_time_dim'] > 0: extra_input = torch.cat([enc_time.encode(torch.tensor([[0.0]]), normalize=False), torch.tensor([[1.0]])], dim=1).to(eval_params['device']) locs_enc = torch.cat((locs_enc, extra_input.repeat(locs_enc.shape[0], 1)), dim=1) with torch.no_grad(): # Here if we set eval to False we will see what the ema embeddings look like (currently as ema is 1.0 this is just the last training example seen) preds = model.embedding_forward(x=locs_enc, class_ids=None, return_feats=False, class_of_interest=class_of_interest, eval=True).cpu().numpy() # threshold predictions if eval_params['threshold'] > 0: print(f'Applying threshold of {eval_params["threshold"]} to the predictions.') preds[preds=eval_params['threshold']] = 1.0 # mask data if not eval_params['disable_ocean_mask']: op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # set to NaN op_im[mask_inds] = preds else: op_im = preds # reshape and create masked array for visualization op_im = op_im.reshape((mask.shape[0], mask.shape[1])) op_im = np.ma.masked_invalid(op_im) # set color for masked values cmap = plt.cm.plasma cmap.set_bad(color='none') if eval_params['set_max_cmap_to_1']: vmax = 1.0 else: vmax = np.max(op_im) # # Display the image # if eval_params['show_map'] == 1: # fig, ax = plt.subplots() # cax = ax.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap) # fig.colorbar(cax) # plt.show(block=True) # Set block=True to block code execution until the window is closed if eval_params['show_map'] == 1: # Display the image fig, ax = plt.subplots(figsize=(6,3), dpi=334) plt.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap, interpolation='nearest') # Display the image plt.axis('off') # Turn off the axis if eval_params['show_context_points'] == 1: # Convert the tensor to numpy array if it's not already context_locs = context_locs_of_interest.numpy() if isinstance(context_locs_of_interest, torch.Tensor) else context_locs_of_interest # Convert context locations directly to image coordinates #delete our dumby context point (at 0,0) image_x = (context_locs[1:, 0] + 1) / 2 * op_im.shape[1] # Scale longitude from [-1, 1] to [0, image width] image_y = (1 - (context_locs[1:, 1] + 1) / 2) * op_im.shape[ 0] # Scale latitude from [-1, 1] to [0, image height] from matplotlib.offsetbox import OffsetImage, AnnotationBbox # Plot the context locations def getImage(path): return OffsetImage(plt.imread(path), zoom=.04) for x0, y0 in zip(image_x, image_y): ab = AnnotationBbox(getImage('black_circle.png'), (x0, y0), frameon=False) ax.add_artist(ab) #plt.scatter(image_x, image_y, c='green', s=30, marker=r'$\checkmark$') # Adjust color and size of the point #plt.show(block=True) # Block execution until the window is closed exp_name = eval_params['model_path'].split(os.path.sep)[-2] # save image #save_loc = os.path.join(eval_params['op_path'], exp_name + '_' + str(eval_params['taxa_id']) + '_' + eval_params['additional_save_name'] +'_map.png') #save_loc = os.path.join(eval_params['op_path'], exp_name + '_' + str(eval_params['taxa_id']) + '_' + eval_params['additional_save_name'] +'_map.png') #save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + eval_params['taxa_id'] + ')_'+ eval_params['text_type'] + '(' + str(text_type_value) + ')_' + str(number_of_context_points) +'.png' save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + eval_params['taxa_id'] + ')_'+ eval_params['text_type'] + '_' + str(number_of_context_points) +'.png' print(f'Saving image to {save_loc}') plt.savefig(save_loc, bbox_inches='tight', pad_inches=0, dpi=334) # plt.imsave(fname=save_loc, arr=op_im, vmin=0, vmax=vmax, cmap=cmap) plt.show(block=False) # Block execution until the window is closed return True if __name__ == '__main__': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') info_str = '\nDemo that takes an iNaturalist taxa ID as input and ' + \ 'generates a predicted range for each location on the globe ' + \ 'and saves the ouput as an image.\n\n' + \ 'Warning: these estimated ranges should be validated before use.' parser = argparse.ArgumentParser(usage=info_str) # parser.add_argument('--model_path', type=str, default='./pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt') # parser.add_argument('--model_path', type=str, default='./experiments/transformer_ema_1.0/model_10.pt') # parser.add_argument('--model_path', type=str, default='./experiments/03_08_coord_multihead.pt/model.pt') # parser.add_argument('--model_path', type=str, default='./experimentvs/coord_context_20_without_registry/model_best.pt') # parser.add_argument('--model_path', type=str, default='./experiments/coord_sinr_inputs_context_20_without_registry/model_best.pt') parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt') #parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt') # parser.add_argument('--taxa_id', type=int, default=144575, help='iNaturalist taxon ID.') # parser.add_argument('--taxa_id', type=int, default=9083, help='iNaturalist taxon ID.') parser.add_argument('--taxa_id', type=int, default=3352, help='iNaturalist taxon ID.') parser.add_argument('--threshold', type=float, default=-1, help='Threshold the range map [0, 1].') parser.add_argument('--op_path', type=str, default='./images/', help='Location where the output image will be saved.') parser.add_argument('--rand_taxa', action='store_true', help='Select a random taxa.') parser.add_argument('--high_res', action='store_true', help='Generate higher resolution output.') parser.add_argument('--disable_ocean_mask', action='store_true', help='Do not use an ocean mask.') parser.add_argument('--set_max_cmap_to_1', action='store_true', help='Consistent maximum intensity ouput.') parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda') #parser.add_argument('--device', type=str, default='cuda:3', help='cpu or cuda') parser.add_argument('--show_map', type=int, default=1, help='shows the map if 1') parser.add_argument('--show_context_points', type=int, default=1, help='also plots context points if 1') parser.add_argument('--prefix', type=str, default='') parser.add_argument('--num_context', type=int, default=5) parser.add_argument('--choose_context_points', type=int, default=1) parser.add_argument('--additional_save_name', type=str, default="") #taxas: black&whitewarbler(10286), hyacinth macaw(18938), yellow baboon(67683) # bawnswallow (11901), pika(43188), loon(4626), eurorobin(13094) # southernflyingsquirrel (46272) parser.add_argument('--taxa_name', type=str, default='sfs', help='Name of the taxon.') parser.add_argument('--test_taxa', type=int, default=46272, help='Taxon ID to test.') parser.add_argument('--text_type', type=str, default='range', help='Type of text for input.') parser.add_argument('--context_pt_trial', type=int, nargs='+', default=[0, 1, 2, 5, 10, 20], help='List of context points for trial.') eval_params = vars(parser.parse_args()) if not os.path.isdir(eval_params['op_path']): os.makedirs(eval_params['op_path']) eval_params['high_res'] = True main(eval_params)