angelazhu96 commited on
Commit
9ff98d7
·
1 Parent(s): dcb7cfe

code for viz

Browse files
Files changed (10) hide show
  1. app.py +87 -0
  2. create_inputs_to_fs_sinr.py +124 -0
  3. eval.py +0 -0
  4. get_gt.py +369 -0
  5. models.py +1434 -0
  6. paths.json +10 -0
  7. requirements.txt +10 -0
  8. setup.py +0 -0
  9. utils.py +326 -0
  10. viz_ls_map.py +283 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from viz_ls_map import main
3
+ from get_gt import generate_ground_truth
4
+
5
+ def predict_species_distribution(taxa_id, taxa_name, text_type, num_context_points):
6
+ """
7
+ Function to predict species distribution and visualize the map.
8
+ """
9
+ isSnt = False
10
+ taxa_id = int(taxa_id)
11
+ #num_context_points = [0, 1, 2, 5, 10, 20]
12
+ num_context_points = [1]
13
+
14
+ # Generate ground truth for the species
15
+ generate_ground_truth(taxa_id, isSnt)
16
+ image_path_gt = f'images/species_presence_hr_{taxa_id}.png'
17
+ output_images = []
18
+ #print(num_context_points)
19
+
20
+ for text_type_i in ['none','range','habitat']:
21
+ # Set up evaluation parameters
22
+ eval_params = {
23
+ 'model_path': './experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt',
24
+ 'taxa_id': taxa_id,
25
+ 'threshold': -1,
26
+ 'op_path': './images/',
27
+ 'rand_taxa': False,
28
+ 'high_res': True,
29
+ 'disable_ocean_mask': False,
30
+ 'set_max_cmap_to_1': False,
31
+ 'device': 'cpu',
32
+ 'show_map': 1,
33
+ 'show_context_points': 1,
34
+ 'prefix': '',
35
+ 'num_context': num_context_points,
36
+ 'choose_context_points': 1,
37
+ 'additional_save_name': "",
38
+ 'taxa_name': taxa_name,
39
+ 'test_taxa': taxa_id,
40
+ 'text_type': text_type_i, # 'none', 'habitat', or 'range'
41
+ 'context_pt_trial': num_context_points,
42
+ }
43
+
44
+ # Run the FS-SINR model with the specified parameters
45
+ main(eval_params)
46
+
47
+ # The output image is saved in './images/' with the predicted range map
48
+ #image_path = f'./images/{taxa_name}_predicted_range.png'
49
+
50
+ for k in num_context_points:
51
+ # Assume image filenames are stored like this
52
+ image_path = f'./images/testenv_{taxa_name}(selected_points)_{text_type_i}_{k}.png'
53
+ output_images.append(image_path)
54
+
55
+
56
+ return [image_path_gt] + output_images
57
+ #return True
58
+
59
+ # Define the Gradio interface
60
+ with gr.Blocks() as demo:
61
+ gr.Markdown("# View Species Distribution Predictions using FS-SINR")
62
+
63
+ # Input fields for the Gradio interface
64
+ taxa_id = gr.Number(label="Taxa ID", value=43188)
65
+ taxa_name = gr.Textbox(label="Taxa Name", value="test_pika")
66
+ text_type = gr.Radio(label="Text Type", choices=['none', 'habitat', 'range'], value='none')
67
+ #num_context_points = gr.Slider(label="Number of Context Points", minimum=1, maximum=20, value=5, step=1)
68
+ num_context_points = gr.CheckboxGroup([0,1,2,3,4,5,10,15,20], label="Number of Context Points")
69
+
70
+ # Button to trigger the prediction
71
+ predict_button = gr.Button("Predict Species Distribution")
72
+
73
+ # Output: predicted range map
74
+ ground_truth = gr.Image(label="Ground Truth Map")
75
+ none_maps = gr.Image(label=f"Map for No Text Input and Context Point {1}")
76
+ range_maps = gr.Image(label=f"Map for Range Text input and Context Point {1}")
77
+ hab_maps = gr.Image(label=f"Map for Habitat Text input and Context Point {1}")
78
+ output_images = [ground_truth, none_maps, range_maps, hab_maps]
79
+
80
+
81
+ # Link the button to the function and inputs
82
+ predict_button.click(fn=predict_species_distribution,
83
+ inputs=[taxa_id, taxa_name, text_type, num_context_points],
84
+ outputs=output_images)
85
+
86
+ # Launch the Gradio interface
87
+ demo.launch()
create_inputs_to_fs_sinr.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import csv
3
+ from gritlm import GritLM
4
+ import pandas as pd
5
+ import ast
6
+ import numpy as np
7
+
8
+ input_text4 = ['The hyacinth macaw prefers semi-open, somewhat wooded habitats. It usually avoids dense, humid forest, and in regions dominated by such habitats, it is generally restricted to the edge or relatively open sections (e.g. along major rivers). In different areas of their range, these parrots are found in savannah grasslands, in dry thorn forests known as caatinga, and in palm stands or swamps, particularly the moriche palm (Mauritia flexuosa).',
9
+ 'The hyacinth macaw occurs today in three main areas in South America: In the Pantanal region of Brazil, and adjacent eastern Bolivia and northeastern Paraguay, in the cerrado regions of the eastern interior of Brazil (Maranhão, Piauí, Bahia, Tocantins, Goiás, Mato Grosso, Mato Grosso do Sul, and Minas Gerais), and in the relatively open areas associated with the Tocantins River, Xingu River, Tapajós River, and the Marajó island in the eastern Amazon Basin of Brazil.',
10
+ 'They are diurnal, terrestrial, and live in complex, mixed-gender social groups of 8 to 200 individuals per troop. They prefer savannas and light forests with a climate that is suitable for their omnivorous diet.',
11
+ 'Yellow baboons inhabit savannas and light forests in eastern Africa, from Kenya and Tanzania to Zimbabwe and Botswana.']
12
+ input_text5 = ['chappell roan', 'europe', 'pawpaw',
13
+ 'sierra nevada', 'great lakes', 'Treaty of Waitangi',
14
+ 'hello kitty', 'disney', 'madagascar', 'Andes', 'africa',
15
+ 'dessert', 'whale', 'moon snail', 'unicorn', 'rainfall',
16
+ 'species occurs above 2000m of elevation', 'froyo', 'desert',
17
+ 'dragon', 'bear', 'selkie', 'loch ness monster']
18
+
19
+ def extract_grit_token(model, text:str):
20
+ def gritlm_instruction(instruction):
21
+ return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
22
+ d_rep = model.encode([text], instruction=gritlm_instruction(""))
23
+ d_rep = torch.from_numpy(d_rep)
24
+ return d_rep
25
+
26
+ def generate_text_embs(text, output_file):
27
+ grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
28
+
29
+ with open(output_file, mode='w') as file:
30
+ writer = csv.writer(file)
31
+ writer.writerow(['Text', 'Embedding'])
32
+ for i in range(0, len(text)):
33
+ text_emb = extract_grit_token(grit, text[i]).to('cpu')
34
+ print(f" {text[i]}: {text_emb} ")
35
+ writer.writerow([text[i], text_emb.tolist()])
36
+
37
+ #TODO: max's generate text using grit
38
+ def generate_text_emb(text):
39
+ grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
40
+ text_emb = extract_grit_token(grit, text)
41
+ return text_emb
42
+
43
+ def use_pregenerated_textemb_fromgpt(taxon_id):
44
+ embs_loaded = torch.load('experiments/gpt_data.pt', map_location='cpu')
45
+
46
+ emb_ids = embs_loaded['taxon_id'].tolist() #(2785,)
47
+ keys1 = embs_loaded['keys'] #(11140, 2)
48
+ embs = embs_loaded['data'] # torch.Size([11140, 4096])
49
+ print(embs_loaded['taxon_id'].size())
50
+
51
+ matching_indices = [i for i, (tid) in enumerate(emb_ids) if tid == taxon_id]
52
+ print(matching_indices)
53
+ taxon_embeddings = embs[matching_indices, :] # Get embeddings for the matching indices
54
+ matching_keys = [keys1[i] for i in matching_indices] # Get the corresponding (taxon_id, text_type) keys
55
+
56
+ print(f"Found {len(matching_keys)} embeddings for taxon ID {taxon_id}:")
57
+ for i, key in enumerate(matching_keys):
58
+ print(f"Text Type: {key[1]}, Embedding: {taxon_embeddings[i, :]}")
59
+
60
+ return taxon_embeddings[i, :]
61
+
62
+ def use_pregenerated_textemb_fromchris(taxon_id, text_type):
63
+ #zero vector is for no text input
64
+ text_embedding = torch.zeros(1,4096)
65
+ if text_type is None or text_type == 'none':
66
+ return text_embedding, 0
67
+
68
+ embs1 = torch.load('experiments/gpt_data.pt', map_location='cpu')
69
+ emb_ids1 = embs1['taxon_id'].tolist()
70
+ keys1 = embs1['keys']
71
+ embs1 = embs1['data']
72
+
73
+ taxa_of_interest = taxon_id
74
+ taxa_index_of_interest = emb_ids1.index(taxa_of_interest) # gets 5
75
+
76
+ #keys_with_taxa_of_interest = [key for key in keys1 if key[0] == taxa_index_of_interest]
77
+ #indices_with_taxa_of_interest = [(key, i) for i, key in enumerate(keys1) if key[0] == taxa_index_of_interest]
78
+ possible_text_embedding_indexes = [i for i, key in enumerate(keys1) if key[0] == taxa_index_of_interest and key[1]==text_type]
79
+
80
+ if len(possible_text_embedding_indexes) != 1:
81
+ return text_embedding, 0
82
+ # take a look and choose what you want
83
+ # for key in indices_with_taxa_of_interest:
84
+ # print(key)
85
+
86
+ # ((5, 'range'), 20)
87
+ # ((5, 'habitat'), 21)
88
+ # ((5, 'species_description'), 22)
89
+ # ((5, 'overview_summary'), 23)
90
+
91
+ #macaw: range: 20, habitat: 21
92
+ #baboon: range: 7928, habitat: 7929
93
+ #black&white warbler: range: 16, habitat: 17
94
+ #barn swallow: range: 1652, habitat: 1653
95
+ #pika: range: 7116, habitat: 7117
96
+ #loon: range: 11056, habitat:11057
97
+ #euro robin: range: 2020, habitat: 2021
98
+ #sfs: range: 7148, habitat: 7149
99
+ text_embedding_index = possible_text_embedding_indexes[0]
100
+ text_embedding = embs1[text_embedding_index].unsqueeze(0)
101
+ #print(text_embedding_index)
102
+ return text_embedding, text_embedding_index
103
+
104
+ def use_pregenerated_textemb_fromcsv(input_text):
105
+ text_data = pd.read_csv('data/text_embs/text_embeddings_fig4.csv')
106
+ result_row = text_data[text_data['Text'] == input_text]
107
+ text_emb = ast.literal_eval(result_row['Embedding'].values[0])
108
+ embedding_tensor = torch.FloatTensor(text_emb)
109
+ return embedding_tensor
110
+
111
+ def get_eval_context_points(taxa_id, context_data, size):
112
+ all_context_pts = context_data['locs'][context_data['labels'] == np.argwhere(context_data['class_to_taxa'] == taxa_id)[0]][1:]
113
+ context_pts = all_context_pts[0:size]
114
+ dummy_classtoken = np.array([[0,0]])
115
+ context_pts = np.vstack((dummy_classtoken, context_pts))
116
+ #print(f"context point shape: {np.shape(context_pts)}")
117
+ normalized_pts = torch.from_numpy(context_pts) * torch.tensor([[1/180,1/90]], device='cpu')
118
+
119
+ return normalized_pts
120
+
121
+ if __name__ == '__main__':
122
+ print('starting to generate text_embs')
123
+ output_file = './data/text_embs/text_embeddings_fig4.csv'
124
+ use_pregenerated_textemb_fromchris()
eval.py ADDED
The diff for this file is too large to render. See raw diff
 
get_gt.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import numpy as np
2
+ # import h3
3
+ # import json
4
+ # import os
5
+ #
6
+ # snt=False
7
+ #
8
+ # def get_labels(species, data):
9
+ # species = str(species)
10
+ # lat = []
11
+ # lon = []
12
+ # gt = []
13
+ # for hx in data:
14
+ # cur_lat, cur_lon = h3.h3_to_geo(hx)
15
+ # if species in data[hx]:
16
+ # cur_label = int(len(data[hx][species]) > 0)
17
+ # gt.append(cur_label)
18
+ # lat.append(cur_lat)
19
+ # lon.append(cur_lon)
20
+ # lat = np.array(lat).astype(np.float32)
21
+ # lon = np.array(lon).astype(np.float32)
22
+ # obs_locs = np.vstack((lon, lat)).T
23
+ # gt = np.array(gt).astype(np.float32)
24
+ # return obs_locs, gt
25
+ #
26
+ # def lonlat_to_pixel(lonlat, grid_width, grid_height):
27
+ # # Convert normalized lon/lat (-1 to 1) to pixel coordinates
28
+ # x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int)
29
+ # y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int)
30
+ # return x_pixel, y_pixel
31
+ #
32
+ # ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
33
+ # # 1002, 2004 pixels
34
+ # # 0 in ocean (needs to be masked out)
35
+ #
36
+ # if snt:
37
+ # with open('paths.json', 'r') as f:
38
+ # paths = json.load(f)
39
+ # D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
40
+ # D = D.item()
41
+ # loc_indices_per_species = D['loc_indices_per_species']
42
+ # labels_per_species = D['labels_per_species']
43
+ # taxa = D['taxa']
44
+ # obs_locs = D['obs_locs']
45
+ # obs_locs_idx = D['obs_locs_idx']
46
+ # else:
47
+ # with open('paths.json', 'r') as f:
48
+ # paths = json.load(f)
49
+ # with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
50
+ # data = json.load(f)
51
+ # obs_locs = np.array(data['locs'], dtype=np.float32)
52
+ # taxa = [int(tt) for tt in data['taxa_presence'].keys()]
53
+ # a = 6
54
+ # # data['taxa_presence'] is a dict where keys are "taxa" and then the values are the indices of "obs_locs" where the species is present
55
+ # # obs locs is in lon, lat with -180 to 180 and -90 to 90
56
+
57
+ import numpy as np
58
+ import h3
59
+ import json
60
+ import os
61
+ import matplotlib.pyplot as plt
62
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
63
+
64
+
65
+ def get_labels(species, data):
66
+ species = str(species)
67
+ lat = []
68
+ lon = []
69
+ gt = []
70
+ for hx in data:
71
+ cur_lat, cur_lon = h3.h3_to_geo(hx)
72
+ if species in data[hx]:
73
+ cur_label = int(len(data[hx][species]) > 0)
74
+ gt.append(cur_label)
75
+ lat.append(cur_lat)
76
+ lon.append(cur_lon)
77
+ lat = np.array(lat).astype(np.float32)
78
+ lon = np.array(lon).astype(np.float32)
79
+ obs_locs = np.vstack((lon, lat)).T
80
+ gt = np.array(gt).astype(np.float32)
81
+ return obs_locs, gt
82
+
83
+ def lonlat_to_pixel(lonlat, grid_width, grid_height):
84
+ # Convert normalized lon/lat (-1 to 1) to pixel coordinates
85
+ x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int)
86
+ y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int)
87
+ return x_pixel, y_pixel
88
+
89
+ # def plot_heatmap(data,save_loc):
90
+ # # Apply mask if provided
91
+ # ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
92
+ # # 1002, 2004 pixels
93
+ # # 0 in ocean (needs to be masked out)
94
+ #
95
+ # # Convert ocean_mask to boolean mask
96
+ # mask = ocean_mask.astype(bool)
97
+ # mask = mask[::2, ::2]
98
+ #
99
+ # if mask is not None:
100
+ # data = np.where(mask, data, 0)
101
+ #
102
+ # # Set NaN values to 0 for plotting
103
+ # data = np.nan_to_num(data, nan=0.0)
104
+ #
105
+ # fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
106
+ # ax.set_xlim(-180, 180)
107
+ # ax.set_ylim(-90, 90)
108
+ # ax.axis('off')
109
+ #
110
+ # # Use 'magma' colormap with two discrete colors
111
+ # cmap = plt.get_cmap('magma', 2)
112
+ # cmap.set_bad(color='none')
113
+ # plt.rcParams['font.family'] = 'serif'
114
+ #
115
+ # cax_im = ax.imshow(data, extent=(-180, 180, -90, 90), origin='upper', cmap=cmap, vmin=0, vmax=1)
116
+ #
117
+ # plt.tight_layout()
118
+ # pdf_save_loc = save_loc + '.pdf'
119
+ # png_save_loc = save_loc + '.png'
120
+ # plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
121
+ # plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
122
+ # plt.close(fig)
123
+
124
+ def plot_heatmap(data, save_loc):
125
+ # Load the ocean mask
126
+ ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
127
+ # 1002, 2004 pixels
128
+ # 0 in ocean (needs to be masked out)
129
+
130
+ # Convert ocean_mask to boolean mask
131
+ mask = ocean_mask.astype(bool)
132
+ # If you need to downsample the mask, uncomment the following line
133
+ mask = mask[::2, ::2]
134
+
135
+ # Set ocean areas to np.nan
136
+ data = np.where(mask, data, np.nan)
137
+
138
+ # Create a masked array where NaNs are masked
139
+ data_masked = np.ma.array(data, mask=np.isnan(data))
140
+
141
+ fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
142
+ ax.set_xlim(-180, 180)
143
+ ax.set_ylim(-90, 90)
144
+ ax.axis('off')
145
+
146
+ # Use 'magma' colormap with two discrete colors
147
+ cmap = plt.get_cmap('plasma', 2)
148
+ # Set color for masked (NaN) values
149
+ cmap.set_bad(color='none') # 'none' makes it transparent; use 'white' for white background
150
+
151
+ # Plot the data
152
+ cax_im = ax.imshow(
153
+ data_masked,
154
+ extent=(-180, 180, -90, 90),
155
+ origin='upper',
156
+ cmap=cmap,
157
+ vmin=0,
158
+ vmax=1,
159
+ interpolation='nearest'
160
+ )
161
+
162
+ plt.tight_layout()
163
+ pdf_save_loc = save_loc + '.pdf'
164
+ png_save_loc = save_loc + '.png'
165
+ plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
166
+ plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
167
+ plt.close(fig)
168
+
169
+ def plot_heatmap_2(data, save_loc):
170
+ # Load the ocean mask
171
+ ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
172
+ # 1002, 2004 pixels
173
+ # 0 in ocean (needs to be masked out)
174
+
175
+ # Convert ocean_mask to boolean mask
176
+ mask = ocean_mask.astype(bool)
177
+ # If you need to downsample the mask, uncomment the following line
178
+
179
+ # Set ocean areas to np.nan
180
+ data = np.where(mask, data, np.nan)
181
+
182
+ # Create a masked array where NaNs are masked
183
+ data_masked = np.ma.array(data, mask=np.isnan(data))
184
+
185
+ fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
186
+ ax.set_xlim(-180, 180)
187
+ ax.set_ylim(-90, 90)
188
+ ax.axis('off')
189
+
190
+ # Use 'magma' colormap with two discrete colors
191
+ cmap = plt.get_cmap('plasma', 2)
192
+ # Set color for masked (NaN) values
193
+ cmap.set_bad(color='none') # 'none' makes it transparent; use 'white' for white background
194
+
195
+ # Plot the data
196
+ cax_im = ax.imshow(
197
+ data_masked,
198
+ extent=(-180, 180, -90, 90),
199
+ origin='upper',
200
+ cmap=cmap,
201
+ vmin=0,
202
+ vmax=1,
203
+ interpolation='nearest'
204
+ )
205
+
206
+ plt.tight_layout()
207
+ pdf_save_loc = save_loc + '.pdf'
208
+ png_save_loc = save_loc + '.png'
209
+ plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
210
+ plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
211
+ plt.show(block=False)
212
+ plt.close(fig)
213
+
214
+ def generate_ground_truth(taxa_id, snt=True, grid_height=501, grid_width=1002):
215
+ print(taxa_id)
216
+ if snt:
217
+ with open('paths.json', 'r') as f:
218
+ paths = json.load(f)
219
+ D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
220
+ D = D.item()
221
+ loc_indices_per_species = D['loc_indices_per_species']
222
+ labels_per_species = D['labels_per_species']
223
+ taxa = D['taxa']
224
+ obs_locs = D['obs_locs']
225
+ obs_locs_idx = D['obs_locs_idx']
226
+ # class_index = np.where(taxa==taxa_id)
227
+ # class_index = class_index[0]
228
+ # class_index = class_index[0]
229
+ # species_loc_indices = loc_indices_per_species[class_index]
230
+ # species_locs = obs_locs[species_loc_indices]
231
+ # presence_indices = labels_per_species[class_index]
232
+ # species_locs = species_locs[presence_indices==1]
233
+
234
+ # Ensure class_index is correctly obtained as an integer index
235
+ class_indices = np.where(taxa == taxa_id)[0]
236
+ if len(class_indices) == 0:
237
+ raise ValueError(f"taxa_id {taxa_id} not found in taxa")
238
+ class_index = class_indices[0]
239
+
240
+ # Convert loc_indices_per_species[class_index] to a NumPy array
241
+ species_loc_indices = np.array(loc_indices_per_species[class_index])
242
+
243
+ # Retrieve the species locations using the indices
244
+ species_locs = obs_locs[species_loc_indices]
245
+
246
+ # Convert labels_per_species[class_index] to a NumPy array
247
+ presence_indices = np.array(labels_per_species[class_index])
248
+
249
+ # Filter species_locs where presence_indices == 1
250
+ species_locs = species_locs[presence_indices == 1]
251
+
252
+ else:
253
+ with open('paths.json', 'r') as f:
254
+ paths = json.load(f)
255
+ with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
256
+ data = json.load(f)
257
+ obs_locs = np.array(data['locs'], dtype=np.float32)
258
+ taxa = [int(tt) for tt in data['taxa_presence'].keys()]
259
+ indices = data['taxa_presence'][str(taxa_id)]
260
+ species_locs = obs_locs[indices] # shape (N, 2)
261
+
262
+
263
+ # Normalize lonlat
264
+ species_locs_normalized = species_locs.copy()
265
+ species_locs_normalized[:, 0] = species_locs_normalized[:, 0] / 180 # lon / 180
266
+ species_locs_normalized[:, 1] = species_locs_normalized[:, 1] / 90 # lat / 90# Get grid dimensions from ocean_mas
267
+
268
+
269
+ # Get pixel coordinates
270
+ x_pixel, y_pixel = lonlat_to_pixel(species_locs_normalized, grid_width, grid_height)
271
+
272
+ # Ensure x_pixel and y_pixel are within bounds
273
+ x_pixel = np.clip(x_pixel, 0, grid_width - 1)
274
+ y_pixel = np.clip(y_pixel, 0, grid_height - 1)
275
+
276
+ # Create data array
277
+ data_array = np.zeros((grid_height, grid_width))
278
+
279
+ # Set pixels where species is present
280
+ data_array[y_pixel, x_pixel] = 1
281
+
282
+ # Now call plot_heatmap
283
+ title = f"Species presence for taxa {taxa_id}"
284
+ save_loc = f"./images/species_presence_{taxa_id}"
285
+ plot_heatmap(data_array, save_loc)
286
+
287
+ grid_height = 1002
288
+ grid_width = 2004
289
+
290
+ if snt:
291
+ with open('paths.json', 'r') as f:
292
+ paths = json.load(f)
293
+ D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
294
+ D = D.item()
295
+ loc_indices_per_species = D['loc_indices_per_species']
296
+ labels_per_species = D['labels_per_species']
297
+ taxa = D['taxa']
298
+ obs_locs = D['obs_locs']
299
+ obs_locs_idx = D['obs_locs_idx']
300
+ # class_index = np.where(taxa==taxa_id)
301
+ # class_index = class_index[0]
302
+ # class_index = class_index[0]
303
+ # species_loc_indices = loc_indices_per_species[class_index]
304
+ # species_locs = obs_locs[species_loc_indices]
305
+ # presence_indices = labels_per_species[class_index]
306
+ # species_locs = species_locs[presence_indices==1]
307
+
308
+ # Ensure class_index is correctly obtained as an integer index
309
+ class_indices = np.where(taxa == taxa_id)[0]
310
+ if len(class_indices) == 0:
311
+ raise ValueError(f"taxa_id {taxa_id} not found in taxa")
312
+ class_index = class_indices[0]
313
+
314
+ # Convert loc_indices_per_species[class_index] to a NumPy array
315
+ species_loc_indices = np.array(loc_indices_per_species[class_index])
316
+
317
+ # Retrieve the species locations using the indices
318
+ species_locs = obs_locs[species_loc_indices]
319
+
320
+ # Convert labels_per_species[class_index] to a NumPy array
321
+ presence_indices = np.array(labels_per_species[class_index])
322
+
323
+ # Filter species_locs where presence_indices == 1
324
+ species_locs = species_locs[presence_indices == 1]
325
+
326
+ else:
327
+ with open('paths.json', 'r') as f:
328
+ paths = json.load(f)
329
+ with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
330
+ data = json.load(f)
331
+ obs_locs = np.array(data['locs'], dtype=np.float32)
332
+ taxa = [int(tt) for tt in data['taxa_presence'].keys()]
333
+ indices = data['taxa_presence'][str(taxa_id)]
334
+ species_locs = obs_locs[indices] # shape (N, 2)
335
+
336
+
337
+ # Normalize lonlat
338
+ species_locs_normalized = species_locs.copy()
339
+ species_locs_normalized[:, 0] = species_locs_normalized[:, 0] / 180 # lon / 180
340
+ species_locs_normalized[:, 1] = species_locs_normalized[:, 1] / 90 # lat / 90# Get grid dimensions from ocean_mas
341
+
342
+
343
+ # Get pixel coordinates
344
+ x_pixel, y_pixel = lonlat_to_pixel(species_locs_normalized, grid_width, grid_height)
345
+
346
+ # Ensure x_pixel and y_pixel are within bounds
347
+ x_pixel = np.clip(x_pixel, 0, grid_width - 1)
348
+ y_pixel = np.clip(y_pixel, 0, grid_height - 1)
349
+
350
+ # Create data array
351
+ data_array = np.zeros((grid_height, grid_width))
352
+
353
+ # Set pixels where species is present
354
+ data_array[y_pixel, x_pixel] = 1
355
+
356
+ # Now call plot_heatmap
357
+ title = f"Species presence for taxa {taxa_id}"
358
+ save_loc = f"./images/species_presence_hr_{taxa_id}"
359
+ plot_heatmap_2(data_array, save_loc)
360
+ return True
361
+
362
+ if __name__ == '__main__':
363
+ snt = True
364
+ grid_height = 501
365
+ grid_width = 1002
366
+ taxa_id = 11901 # Or any taxa id you want to plot, as string
367
+
368
+ #TODO: why snt true? can't generate gt for (hyacinth macaw(18938), yellow baboon(67683), pika(43188), southernflyingsquirrel (46272))
369
+ generate_ground_truth(taxa_id=taxa_id, snt=snt, grid_height=grid_height, grid_width=grid_width)
models.py ADDED
@@ -0,0 +1,1434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data
3
+ import torch.nn as nn
4
+ import math
5
+ import csv
6
+ import numpy as np
7
+ import json
8
+ import os
9
+
10
+
11
+ def get_model(params, inference_only=False):
12
+ if params['model'] == 'ResidualFCNet':
13
+ return ResidualFCNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'] + (20 if 'env' in params['loss'] else 0), params['num_filts'], params['depth'])
14
+ elif params['model'] == 'LinNet':
15
+ return LinNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'])
16
+ elif params['model'] == 'HyperNet':
17
+ return HyperNet(params, params['input_dim'] + (20 if 'env' in params['input_enc'] else 0), params['num_classes'], params['num_filts'], params['depth'],
18
+ params['species_dim'], params['species_enc_depth'], params['species_filts'], params['species_enc'], inference_only=inference_only)
19
+ # chris models
20
+ elif params['model'] == 'MultiInputModel':
21
+ return MultiInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
22
+ num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
23
+ depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
24
+ dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
25
+ batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
26
+ sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False,
27
+ register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
28
+ freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
29
+ text_inputs=params['use_text_inputs'], class_token_transformation=params['class_token_transformation'])
30
+ elif params['model'] == 'VariableInputModel':
31
+ return VariableInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
32
+ num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
33
+ depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
34
+ dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
35
+ batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
36
+ sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False,
37
+ register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
38
+ freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
39
+ text_inputs=params['use_text_inputs'], image_inputs=params['use_image_inputs'],
40
+ env_inputs=params['use_env_inputs'],
41
+ class_token_transformation=params['class_token_transformation'])
42
+
43
+ # class VariableInputModel(nn.Module):
44
+ # def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
45
+ # nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
46
+ # sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
47
+ # text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'):
48
+
49
+
50
+ class ResLayer(nn.Module):
51
+ def __init__(self, linear_size, activation=nn.ReLU, p=0.5):
52
+ super(ResLayer, self).__init__()
53
+ self.l_size = linear_size
54
+ self.nonlin1 = activation()
55
+ self.nonlin2 = activation()
56
+ self.dropout1 = nn.Dropout(p=p)
57
+ self.w1 = nn.Linear(self.l_size, self.l_size)
58
+ self.w2 = nn.Linear(self.l_size, self.l_size)
59
+
60
+ def forward(self, x):
61
+ y = self.w1(x)
62
+ y = self.nonlin1(y)
63
+ y = self.dropout1(y)
64
+ y = self.w2(y)
65
+ y = self.nonlin2(y)
66
+ out = x + y
67
+ return out
68
+
69
+
70
+ class ResidualFCNet(nn.Module):
71
+ def __init__(self, num_inputs, num_classes, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
72
+ super(ResidualFCNet, self).__init__()
73
+ self.inc_bias = False
74
+ if lowrank < num_filts and lowrank != 0:
75
+ l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
76
+ l2 = nn.Linear(lowrank, num_classes, bias=self.inc_bias)
77
+ self.class_emb = nn.Sequential(l1, l2)
78
+ else:
79
+ self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
80
+ if nonlin == 'relu':
81
+ activation = nn.ReLU
82
+ elif nonlin == 'silu':
83
+ activation = nn.SiLU
84
+ else:
85
+ raise NotImplementedError('Invalid nonlinearity specified.')
86
+ layers = []
87
+ if depth != -1:
88
+ layers.append(nn.Linear(num_inputs, num_filts))
89
+ layers.append(activation())
90
+ for i in range(depth):
91
+ layers.append(ResLayer(num_filts, activation=activation))
92
+ else:
93
+ layers.append(nn.Identity())
94
+ self.feats = torch.nn.Sequential(*layers)
95
+
96
+ def forward(self, x, class_of_interest=None, return_feats=False):
97
+ loc_emb = self.feats(x)
98
+ if return_feats:
99
+ return loc_emb
100
+ if class_of_interest is None:
101
+ class_pred = self.class_emb(loc_emb)
102
+ else:
103
+ class_pred = self.eval_single_class(loc_emb, class_of_interest), self.eval_single_class(loc_emb, -1)
104
+ return torch.sigmoid(class_pred[0]), torch.sigmoid(class_pred[1])
105
+ return torch.sigmoid(class_pred)
106
+
107
+ def eval_single_class(self, x, class_of_interest):
108
+ if self.inc_bias:
109
+ return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
110
+ else:
111
+ return x @ self.class_emb.weight[class_of_interest, :]
112
+
113
+
114
+ class SimpleFCNet(ResidualFCNet):
115
+ def forward(self, x, return_feats=True):
116
+ assert return_feats
117
+ loc_emb = self.feats(x)
118
+ class_pred = self.class_emb(loc_emb)
119
+ return class_pred
120
+
121
+
122
+ class MockTransformer(nn.Module):
123
+ def __init__(self, num_classes, num_dims):
124
+ super(MockTransformer, self).__init__()
125
+ self.species_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_dims)
126
+
127
+ def forward(self, class_ids):
128
+ return self.species_emb(class_ids)
129
+
130
+
131
+ class CombinedModel(nn.Module):
132
+ def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1):
133
+ super(CombinedModel, self).__init__()
134
+ self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank)
135
+ if lowrank < num_filts and lowrank != 0:
136
+ self.transformer_model = MockTransformer(num_classes, lowrank)
137
+ else:
138
+ self.transformer_model = MockTransformer(num_classes, num_filts)
139
+ self.ema_factor = ema_factor
140
+ self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=lowrank if (lowrank < num_filts and lowrank != 0) else num_filts)
141
+ self.ema_embeddings.weight.data.copy_(self.transformer_model.species_emb.weight.data) # Initialize EMA with the same values as transformer
142
+ # this will have to change when I start using the actual transformer
143
+
144
+ def forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None):
145
+ # Process input through the headless model to get feature embeddings
146
+ feature_embeddings = self.headless_model(x)
147
+
148
+ if return_feats:
149
+ return feature_embeddings
150
+ else:
151
+ if class_of_interest == None:
152
+ # Get class-specific embeddings based on class_ids
153
+ class_embeddings = self.transformer_model(class_ids)
154
+ if return_class_embeddings:
155
+ return class_embeddings
156
+ else:
157
+ # Update EMA embeddings for these class IDs
158
+ if self.training:
159
+ self.update_ema_embeddings(class_ids, class_embeddings)
160
+
161
+ # Matrix multiplication to produce logits
162
+ logits = feature_embeddings @ class_embeddings.T
163
+
164
+ # Apply sigmoid to convert logits to probabilities
165
+ probabilities = torch.sigmoid(logits)
166
+
167
+ return probabilities
168
+ else:
169
+ device = self.ema_embeddings.weight.device
170
+ class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
171
+ class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
172
+ print(f'using EMA estimate for class {class_of_interest}')
173
+ if return_class_embeddings:
174
+ return class_embedding
175
+ else:
176
+ # Matrix multiplication to produce logits
177
+ logits = feature_embeddings @ class_embedding.T
178
+
179
+ # Apply sigmoid to convert logits to probabilities
180
+ probabilities = torch.sigmoid(logits)
181
+ probabilities = probabilities.squeeze()
182
+
183
+ return probabilities
184
+
185
+ def update_ema_embeddings(self, class_ids, current_embeddings):
186
+ if self.training:
187
+ # Get current EMA embeddings for the class IDs
188
+ ema_current = self.ema_embeddings(class_ids)
189
+
190
+ # Calculate new EMA values
191
+ ema_new = self.ema_factor * current_embeddings + (1 - self.ema_factor) * ema_current
192
+
193
+ # Update the EMA embeddings
194
+ self.ema_embeddings.weight.data[class_ids] = ema_new.detach() # Detach to prevent gradients from flowing here
195
+
196
+ def get_ema_embeddings(self, class_ids):
197
+ # Method to access EMA embeddings
198
+ return self.ema_embeddings(class_ids)
199
+
200
+ class HeadlessSINR(nn.Module):
201
+ def __init__(self, num_inputs, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
202
+ super(HeadlessSINR, self).__init__()
203
+ self.inc_bias = False
204
+ self.low_rank_feats = None
205
+ if lowrank < num_filts and lowrank != 0:
206
+ l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
207
+ self.low_rank_feats = l1
208
+ # else:
209
+ # self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
210
+ if nonlin == 'relu':
211
+ activation = nn.ReLU
212
+ elif nonlin == 'silu':
213
+ activation = nn.SiLU
214
+ else:
215
+ raise NotImplementedError('Invalid nonlinearity specified.')
216
+
217
+ # Create the layers list for feature extraction
218
+ layers = []
219
+ if depth != -1:
220
+ layers.append(nn.Linear(num_inputs, num_filts))
221
+ layers.append(activation())
222
+ for i in range(depth):
223
+ layers.append(ResLayer(num_filts, activation=activation, p=dropout_p))
224
+ else:
225
+ layers.append(nn.Identity())
226
+ # Include low-rank features in the sequential model if it is defined
227
+ if self.low_rank_feats:
228
+ # Apply initial layers then low-rank features
229
+ layers.append(self.low_rank_feats)
230
+ # Set up the features as a sequential model
231
+ self.feats = nn.Sequential(*layers)
232
+
233
+ def forward(self, x):
234
+ loc_emb = self.feats(x)
235
+ return loc_emb
236
+
237
+
238
+ class TransformerEncoderModel(nn.Module):
239
+ def __init__(self, d_model=256, nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, activation='relu',
240
+ batch_first=True, output_dim=256): # BATCH FIRST MIGHT HAVE TO CHANGE
241
+ super(TransformerEncoderModel, self).__init__()
242
+ self.input_layer_norm = nn.LayerNorm(normalized_shape=d_model)
243
+ # Create an encoder layer
244
+ encoder_layer = nn.TransformerEncoderLayer(
245
+ d_model=d_model,
246
+ nhead=nhead,
247
+ dim_feedforward=dim_feedforward,
248
+ dropout=dropout,
249
+ activation=activation,
250
+ batch_first=batch_first
251
+ )
252
+
253
+ # Stack the encoder layers into an encoder module
254
+ self.transformer_encoder = nn.TransformerEncoder(
255
+ encoder_layer=encoder_layer,
256
+ num_layers=num_encoder_layers
257
+ )
258
+
259
+ # Example output layer (modify according to your needs)
260
+ self.output_layer = nn.Linear(d_model, output_dim)
261
+
262
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
263
+ """
264
+ Args:
265
+ src: the sequence to the encoder (shape: [seq_length, batch_size, d_model])
266
+ src_mask: the mask for the src sequence (shape: [seq_length, seq_length])
267
+ src_key_padding_mask: the mask for the padding tokens (shape: [batch_size, seq_length])
268
+
269
+ Returns:
270
+ output of the transformer encoder
271
+ """
272
+ # Pass the input through the transformer encoder
273
+ encoder_input = self.input_layer_norm(src)
274
+ encoder_output = self.transformer_encoder(encoder_input, src_key_padding_mask=src_key_padding_mask, mask=src_mask)
275
+
276
+ # # Pass the encoder output through the output layer
277
+ # output = self.output_layer(encoder_output)
278
+
279
+ # Assuming the class token is the first in the sequence
280
+ # batch_first so we have (batch, sequence, dim)
281
+ if encoder_output.ndim == 2:
282
+ # in situations where we don't have a batch
283
+ encoder_output = encoder_output.unsqueeze(0)
284
+
285
+ class_token_embedding = encoder_output[:, 0, :]
286
+
287
+ output = self.output_layer(class_token_embedding) # Process only the class token embedding
288
+ return output
289
+
290
+
291
+ class MultiInputModel(nn.Module):
292
+ def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
293
+ nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
294
+ sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
295
+ text_inputs=False, class_token_transformation='identity'):
296
+ super(MultiInputModel, self).__init__()
297
+
298
+ self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
299
+ self.ema_factor = ema_factor
300
+ self.class_token_transformation = class_token_transformation
301
+
302
+ # Load pretrained state_dict if use_pretrained_sinr is set to True
303
+ if use_pretrained_sinr:
304
+ #pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
305
+ pretrained_state_dict = torch.load(pretrained_loc, map_location=torch.device('cpu'))['state_dict']
306
+ filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
307
+ self.headless_model.load_state_dict(filtered_state_dict, strict=False)
308
+ #print(f'Using pretrained sinr from {pretrained_loc}')
309
+
310
+ # Freeze the SINR model if freeze_sinr is set to True
311
+ if freeze_sinr:
312
+ for param in self.headless_model.parameters():
313
+ param.requires_grad = False
314
+ print("Freezing SINR model parameters")
315
+
316
+ # self.transformer_model = MockTransformer(num_classes, num_filts)
317
+ self.transformer_model = TransformerEncoderModel(d_model=token_dim,
318
+ nhead=nhead,
319
+ num_encoder_layers=num_encoder_layers,
320
+ dim_feedforward=dim_feedforward,
321
+ dropout=dropout,
322
+ batch_first=batch_first,
323
+ output_dim=num_filts)
324
+
325
+ self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
326
+ # this is just a workaround for now to load eval embeddings - probably not needed long term
327
+ self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
328
+ self.ema_embeddings.weight.requires_grad = False
329
+ self.eval_embeddings.weight.requires_grad = False
330
+ self.num_filts=num_filts
331
+ self.token_dim = token_dim
332
+ # nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
333
+ self.sinr_inputs = sinr_inputs
334
+ if self.sinr_inputs:
335
+ if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
336
+ raise ValueError("If using sinr inputs to transformer with identity class token transformation"
337
+ "then token_dim of transformer must be equal to num_filts of sinr model")
338
+
339
+ # Add a class token
340
+ self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
341
+ nn.init.xavier_uniform_(self.class_token)
342
+
343
+ if register:
344
+ # Add a register token initialized with Xavier uniform initialization
345
+ self.register = nn.Parameter(torch.empty(1, self.token_dim))
346
+ # self.register = (self.register / 2)
347
+ nn.init.xavier_uniform_(self.register)
348
+ else:
349
+ self.register = None
350
+
351
+ self.text_inputs = text_inputs
352
+ if self.text_inputs:
353
+ #print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW")
354
+ self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
355
+ else:
356
+ self.text_model=None
357
+
358
+ # Type-specific embeddings for class, register, location, and text tokens
359
+ self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
360
+ nn.init.xavier_uniform_(self.class_type_embedding)
361
+ if register:
362
+ self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
363
+ nn.init.xavier_uniform_(self.register_type_embedding)
364
+ self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
365
+ nn.init.xavier_uniform_(self.location_type_embedding)
366
+ if text_inputs:
367
+ self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
368
+ nn.init.xavier_uniform_(self.text_type_embedding)
369
+
370
+ # Instantiate the class token transformation module
371
+ if class_token_transformation == 'identity':
372
+ self.class_token_transform = Identity(token_dim, num_filts)
373
+ elif class_token_transformation == 'linear':
374
+ self.class_token_transform = LinearTransformation(token_dim, num_filts)
375
+ elif class_token_transformation == 'single_layer_nn':
376
+ self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
377
+ elif class_token_transformation == 'two_layer_nn':
378
+ self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
379
+ elif class_token_transformation == 'sinr':
380
+ self.class_token_transform = HeadlessSINR(token_dim, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
381
+ else:
382
+ raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
383
+
384
+
385
+ def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None):
386
+ # Process input through the headless model to get feature embeddings
387
+ feature_embeddings = self.headless_model(x)
388
+
389
+ if return_feats:
390
+ return feature_embeddings
391
+
392
+ if context_sequence.dim() == 2:
393
+ context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
394
+
395
+ context_sequence = context_sequence[:, 1:, :]
396
+
397
+ if self.sinr_inputs:
398
+ # Pass through the headless model
399
+ context_sequence = self.headless_model(context_sequence)
400
+
401
+ # Add type-specific embedding to each location token
402
+ # print("SEE IF THIS WORKS")
403
+ context_sequence += self.location_type_embedding
404
+
405
+ batch_size = context_sequence.size(0)
406
+
407
+ # Expand the class token to match the batch size and add its type-specific embedding
408
+ class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
409
+
410
+ if self.text_inputs and (text_emb is not None):
411
+ text_mask = (text_emb.sum(dim=1) == 0)
412
+ text_emb = self.text_model(text_emb)
413
+ text_emb += self.text_type_embedding
414
+ text_emb[text_mask] = 0
415
+ # Reshape text_emb to have the shape (batch_size, 1, embedding_dim)
416
+ text_emb = text_emb.unsqueeze(1)
417
+
418
+
419
+ if self.register is None:
420
+ # context sequence = learnable class_token + rest of sequence
421
+ if self.text_inputs:
422
+ # Add the class token and text embeddings to the context sequence
423
+ context_sequence = torch.cat((class_token_expanded, text_emb, context_sequence), dim=1)
424
+ # Pad the context mask to account for the added text embeddings
425
+ context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
426
+ # Update the new part of the mask with the text_mask
427
+ context_mask[:, 1] = text_mask # Apply mask directly
428
+ else:
429
+ context_sequence = torch.cat((class_token_expanded, context_sequence), dim=1)
430
+ else:
431
+ # Expand the register token to match the batch size and add its type-specific embedding
432
+ register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
433
+ if self.text_inputs:
434
+ # Add all components: class token, register, text embeddings, and context
435
+ context_sequence = torch.cat((class_token_expanded, register_expanded, text_emb, context_sequence),
436
+ dim=1)
437
+ # Double pad the context mask: first for register, then for text embeddings
438
+ context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
439
+ context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
440
+ # Update the new part of the mask for text embeddings
441
+ context_mask[:, register_expanded.size(1) + 1] = text_mask # Apply mask directly
442
+ else:
443
+ context_sequence = torch.cat((class_token_expanded, register_expanded, context_sequence), dim=1)
444
+ # Update the context mask to account for the register token
445
+ context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
446
+
447
+ if use_eval_embeddings == False:
448
+ if class_of_interest == None:
449
+ # Get class-specific embeddings based on class_ids
450
+ class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
451
+ # pass these through the class token transformation
452
+ class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
453
+
454
+ if return_class_embeddings:
455
+ return class_embeddings
456
+ else:
457
+ # Update EMA embeddings for these class IDs
458
+ with torch.no_grad():
459
+ if self.training:
460
+ self.update_ema_embeddings(class_ids, class_embeddings)
461
+
462
+ # Matrix multiplication to produce logits
463
+ logits = feature_embeddings @ class_embeddings.T
464
+
465
+ # Apply sigmoid to convert logits to probabilities
466
+ probabilities = torch.sigmoid(logits)
467
+
468
+ return probabilities
469
+ else:
470
+ device = self.ema_embeddings.weight.device
471
+ class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
472
+ class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
473
+ print(f'using EMA estimate for class {class_of_interest}')
474
+ if return_class_embeddings:
475
+ return class_embedding
476
+ else:
477
+ # Matrix multiplication to produce logits
478
+ logits = feature_embeddings @ class_embedding.T
479
+
480
+ # Apply sigmoid to convert logits to probabilities
481
+ probabilities = torch.sigmoid(logits)
482
+ probabilities = probabilities.squeeze()
483
+ return probabilities
484
+ else:
485
+ self.eval()
486
+ if not hasattr(self, 'eval_embeddings'):
487
+ self.eval_embeddings = self.ema_embeddings
488
+ if class_of_interest == None:
489
+ # Get class-specific embeddings based on class_ids
490
+ class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
491
+ class_embeddings = self.class_token_transform(class_token_output)
492
+ # Update EMA embeddings for these class IDs
493
+
494
+ self.generate_eval_embeddings(class_ids, class_embeddings)
495
+
496
+ # Matrix multiplication to produce logits
497
+ logits = feature_embeddings @ class_embeddings.T
498
+
499
+ # Apply sigmoid to convert logits to probabilities
500
+ probabilities = torch.sigmoid(logits)
501
+
502
+ return probabilities
503
+ else:
504
+ device = self.ema_embeddings.weight.device
505
+ class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
506
+ class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
507
+ print(f'using eval embedding for class {class_of_interest}')
508
+ if return_class_embeddings:
509
+ return class_embedding
510
+ else:
511
+ # Matrix multiplication to produce logits
512
+ logits = feature_embeddings @ class_embedding.T
513
+
514
+ # Apply sigmoid to convert logits to probabilities
515
+ probabilities = torch.sigmoid(logits)
516
+ probabilities = probabilities.squeeze()
517
+ return probabilities
518
+
519
+ def init_eval_embeddings(self, num_classes):
520
+ self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
521
+ nn.init.xavier_uniform_(self.eval_embeddings.weight)
522
+
523
+ def get_ema_embeddings(self, class_ids):
524
+ # Method to access EMA embeddings
525
+ return self.ema_embeddings(class_ids)
526
+
527
+ def get_eval_embeddings(self, class_ids):
528
+ # Method to access eval embeddings
529
+ return self.eval_embeddings(class_ids)
530
+
531
+ def update_ema_embeddings(self, class_ids, current_embeddings):
532
+ if self.training:
533
+ # Get unique class IDs and their counts
534
+ unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
535
+
536
+ # Get current EMA embeddings for unique class IDs
537
+ ema_current = self.ema_embeddings(unique_class_ids)
538
+
539
+ # Initialize a placeholder for new EMA values
540
+ ema_new = torch.zeros_like(ema_current)
541
+
542
+ # Compute the average of current embeddings for each unique class ID
543
+ current_sum = torch.zeros_like(ema_current)
544
+ current_sum.index_add_(0, inverse_indices, current_embeddings)
545
+ current_avg = current_sum / counts.unsqueeze(1)
546
+
547
+ # Apply EMA update formula
548
+ ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
549
+
550
+ # Update the EMA embeddings for unique class IDs
551
+ self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
552
+
553
+ def generate_eval_embeddings(self, class_id, current_embedding):
554
+ self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
555
+
556
+ # self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
557
+
558
+
559
+ def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
560
+ # forward method that uses ema or eval embeddings rather than context sequence
561
+
562
+ # Process input through the headless model to get feature embeddings
563
+ feature_embeddings = self.headless_model(x)
564
+
565
+ if return_feats:
566
+ return feature_embeddings
567
+ else:
568
+ if class_of_interest == None:
569
+ # Get class-specific embeddings based on class_ids
570
+ if eval == False:
571
+ class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
572
+ else:
573
+ class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
574
+ if return_class_embeddings:
575
+ return class_embeddings
576
+ else:
577
+ # Matrix multiplication to produce logits
578
+ logits = feature_embeddings @ class_embeddings.T
579
+
580
+ # Apply sigmoid to convert logits to probabilities
581
+ probabilities = torch.sigmoid(logits)
582
+
583
+ return probabilities
584
+ else:
585
+ if eval == False:
586
+ device = self.ema_embeddings.weight.device
587
+ class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
588
+ class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
589
+ print(f'using EMA estimate for class {class_of_interest}')
590
+ if return_class_embeddings:
591
+ return class_embedding
592
+ else:
593
+ # Matrix multiplication to produce logits
594
+ logits = feature_embeddings @ class_embedding.T
595
+
596
+ # Apply sigmoid to convert logits to probabilities
597
+ probabilities = torch.sigmoid(logits)
598
+ probabilities = probabilities.squeeze()
599
+
600
+ return probabilities
601
+
602
+ else:
603
+ device = self.eval_embeddings.weight.device
604
+ class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
605
+ class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
606
+ #print(f'using eval estimate for class {class_of_interest}')
607
+ if return_class_embeddings:
608
+ return class_embedding
609
+ else:
610
+ # Matrix multiplication to produce logits
611
+ logits = feature_embeddings @ class_embedding.T
612
+
613
+ # Apply sigmoid to convert logits to probabilities
614
+ probabilities = torch.sigmoid(logits)
615
+ probabilities = probabilities.squeeze()
616
+
617
+ return probabilities
618
+
619
+ class VariableInputModel(nn.Module):
620
+ def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
621
+ nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
622
+ sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
623
+ text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'):
624
+
625
+ super(VariableInputModel, self).__init__()
626
+ self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
627
+ self.ema_factor = ema_factor
628
+ self.class_token_transformation = class_token_transformation
629
+
630
+ # Load pretrained state_dict if use_pretrained_sinr is set to True
631
+ if use_pretrained_sinr:
632
+ pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
633
+ filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
634
+ self.headless_model.load_state_dict(filtered_state_dict, strict=False)
635
+ #print(f'Using pretrained sinr from {pretrained_loc}')
636
+
637
+ # Freeze the SINR model if freeze_sinr is set to True
638
+ if freeze_sinr:
639
+ for param in self.headless_model.parameters():
640
+ param.requires_grad = False
641
+ print("Freezing SINR model parameters")
642
+
643
+ # self.transformer_model = MockTransformer(num_classes, num_filts)
644
+ self.transformer_model = TransformerEncoderModel(d_model=token_dim,
645
+ nhead=nhead,
646
+ num_encoder_layers=num_encoder_layers,
647
+ dim_feedforward=dim_feedforward,
648
+ dropout=dropout,
649
+ batch_first=batch_first,
650
+ output_dim=num_filts)
651
+
652
+ self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
653
+ # this is just a workaround for now to load eval embeddings - probably not needed long term
654
+ self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
655
+ self.ema_embeddings.weight.requires_grad = False
656
+ self.eval_embeddings.weight.requires_grad = False
657
+ self.num_filts=num_filts
658
+ self.token_dim = token_dim
659
+ # nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
660
+ self.sinr_inputs = sinr_inputs
661
+ if self.sinr_inputs:
662
+ if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
663
+ raise ValueError("If using sinr inputs to transformer with identity class token transformation"
664
+ "then token_dim of transformer must be equal to num_filts of sinr model")
665
+
666
+ # Add a class token
667
+ self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
668
+ nn.init.xavier_uniform_(self.class_token)
669
+
670
+ if register:
671
+ # Add a register token initialized with Xavier uniform initialization
672
+ self.register = nn.Parameter(torch.empty(1, self.token_dim))
673
+ # self.register = (self.register / 2)
674
+ nn.init.xavier_uniform_(self.register)
675
+ else:
676
+ self.register = None
677
+
678
+ self.text_inputs = text_inputs
679
+ if self.text_inputs:
680
+ print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW")
681
+ self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
682
+ else:
683
+ self.text_model=None
684
+ self.image_inputs = image_inputs
685
+ if self.image_inputs:
686
+ print("JUST USING A HEADLESS SINR FOR THE IMAGE MODEL RIGHT NOW")
687
+ self.image_model=HeadlessSINR(num_inputs=1024, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
688
+ else:
689
+ self.image_model=None
690
+ self.env_inputs = env_inputs
691
+ if self.env_inputs:
692
+ print("JUST USING A HEADLESS SINR FOR THE ENV MODEL RIGHT NOW")
693
+ self.env_model=HeadlessSINR(num_inputs=20, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
694
+ else:
695
+ self.env_model=None
696
+
697
+ # Type-specific embeddings for class, register, location, text, image and env tokens
698
+ self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
699
+ nn.init.xavier_uniform_(self.class_type_embedding)
700
+ if register:
701
+ self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
702
+ nn.init.xavier_uniform_(self.register_type_embedding)
703
+ self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
704
+ nn.init.xavier_uniform_(self.location_type_embedding)
705
+ if text_inputs:
706
+ self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
707
+ nn.init.xavier_uniform_(self.text_type_embedding)
708
+ if image_inputs:
709
+ self.image_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
710
+ nn.init.xavier_uniform_(self.image_type_embedding)
711
+ if env_inputs:
712
+ self.env_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
713
+ nn.init.xavier_uniform_(self.env_type_embedding)
714
+
715
+ # Instantiate the class token transformation module
716
+ if class_token_transformation == 'identity':
717
+ self.class_token_transform = Identity(token_dim, num_filts)
718
+ elif class_token_transformation == 'linear':
719
+ self.class_token_transform = LinearTransformation(token_dim, num_filts)
720
+ elif class_token_transformation == 'single_layer_nn':
721
+ self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
722
+ elif class_token_transformation == 'two_layer_nn':
723
+ self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
724
+ elif class_token_transformation == 'sinr':
725
+ self.class_token_transform = HeadlessSINR(token_dim, num_filts, 2, nonlin, lowrank, dropout_p=dropout)
726
+ else:
727
+ raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
728
+
729
+ def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False,
730
+ return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None,
731
+ image_emb=None, env_emb=None):
732
+ # Process input through the headless model to get feature embeddings
733
+ feature_embeddings = self.headless_model(x)
734
+
735
+ if return_feats:
736
+ return feature_embeddings
737
+
738
+ if context_sequence.dim() == 2:
739
+ context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
740
+
741
+ context_sequence = context_sequence[:, 1:, :]
742
+
743
+ context_mask = context_mask[:, 1:]
744
+
745
+ if self.sinr_inputs:
746
+ context_sequence = self.headless_model(context_sequence)
747
+
748
+ # Add type-specific embedding to each location token
749
+ context_sequence += self.location_type_embedding
750
+
751
+ batch_size = context_sequence.size(0)
752
+
753
+ # Initialize lists for tokens and masks
754
+ tokens = []
755
+ masks = []
756
+
757
+ # Process class token
758
+ class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
759
+ tokens.append(class_token_expanded)
760
+ # The class token is always present, so mask is False (i.e., not masked out)
761
+ class_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
762
+ masks.append(class_mask)
763
+
764
+ # Process register token if present
765
+ if self.register is not None:
766
+ register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
767
+ tokens.append(register_expanded)
768
+ register_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
769
+ masks.append(register_mask)
770
+
771
+ # Process text embeddings
772
+ if self.text_inputs and (text_emb is not None):
773
+ text_mask = (text_emb.sum(dim=1) == 0)
774
+ text_emb = self.text_model(text_emb)
775
+ text_emb += self.text_type_embedding
776
+ # Set embeddings to zero where mask is True
777
+ text_emb[text_mask] = 0
778
+ text_emb = text_emb.unsqueeze(1)
779
+ tokens.append(text_emb)
780
+ # Expand text_mask to match sequence dimensions
781
+ text_mask = text_mask.unsqueeze(1)
782
+ masks.append(text_mask)
783
+
784
+ # Process image embeddings
785
+ if self.image_inputs and (image_emb is not None):
786
+ image_mask = (image_emb.sum(dim=1) == 0)
787
+ image_emb = self.image_model(image_emb)
788
+ image_emb += self.image_type_embedding
789
+ image_emb[image_mask] = 0
790
+ image_emb = image_emb.unsqueeze(1)
791
+ tokens.append(image_emb)
792
+ image_mask = image_mask.unsqueeze(1)
793
+ masks.append(image_mask)
794
+
795
+ # Process env embeddings if needed (can be added similarly)
796
+ if self.env_inputs and (env_emb is not None):
797
+ env_mask = context_mask
798
+ env_emb = self.env_model(env_emb)
799
+ env_emb += self.env_type_embedding
800
+ env_emb[env_mask] = 0
801
+ env_emb = env_emb.unsqueeze(1)
802
+ tokens.append(env_emb)
803
+ env_mask = env_mask.unsqueeze(1)
804
+ masks.append(env_mask)
805
+
806
+ # Process location tokens
807
+ tokens.append(context_sequence)
808
+ masks.append(context_mask)
809
+
810
+ # Concatenate all tokens and masks
811
+ context_sequence = torch.cat(tokens, dim=1)
812
+ context_mask = torch.cat(masks, dim=1)
813
+
814
+ if not use_eval_embeddings:
815
+ if class_of_interest is None:
816
+ # Get class-specific embeddings based on class_ids
817
+ class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
818
+ # pass these through the class token transformation
819
+ class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
820
+
821
+ if return_class_embeddings:
822
+ return class_embeddings
823
+ else:
824
+ # Update EMA embeddings for these class IDs
825
+ with torch.no_grad():
826
+ if self.training:
827
+ self.update_ema_embeddings(class_ids, class_embeddings)
828
+
829
+ # Matrix multiplication to produce logits
830
+ logits = feature_embeddings @ class_embeddings.T
831
+
832
+ # Apply sigmoid to convert logits to probabilities
833
+ probabilities = torch.sigmoid(logits)
834
+
835
+ return probabilities
836
+ else:
837
+ device = self.ema_embeddings.weight.device
838
+ class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
839
+ class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
840
+ print(f'using EMA estimate for class {class_of_interest}')
841
+ if return_class_embeddings:
842
+ return class_embedding
843
+ else:
844
+ # Matrix multiplication to produce logits
845
+ logits = feature_embeddings @ class_embedding.T
846
+
847
+ # Apply sigmoid to convert logits to probabilities
848
+ probabilities = torch.sigmoid(logits)
849
+ probabilities = probabilities.squeeze()
850
+ return probabilities
851
+ else:
852
+ self.eval()
853
+ if not hasattr(self, 'eval_embeddings'):
854
+ print('No Eval Embeddings for this species?!')
855
+ self.eval_embeddings = self.ema_embeddings
856
+ if class_of_interest is None:
857
+ # Get class-specific embeddings based on class_ids
858
+ class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
859
+ class_embeddings = self.class_token_transform(class_token_output)
860
+ # Update EMA embeddings for these class IDs
861
+
862
+ self.generate_eval_embeddings(class_ids, class_embeddings)
863
+
864
+ # Matrix multiplication to produce logits
865
+ logits = feature_embeddings @ class_embeddings.T
866
+
867
+ # Apply sigmoid to convert logits to probabilities
868
+ probabilities = torch.sigmoid(logits)
869
+
870
+ return probabilities
871
+ else:
872
+ device = self.ema_embeddings.weight.device
873
+ class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
874
+ class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
875
+ print(f'using eval embedding for class {class_of_interest}')
876
+ if return_class_embeddings:
877
+ return class_embedding
878
+ else:
879
+ # Matrix multiplication to produce logits
880
+ logits = feature_embeddings @ class_embedding.T
881
+
882
+ # Apply sigmoid to convert logits to probabilities
883
+ probabilities = torch.sigmoid(logits)
884
+ probabilities = probabilities.squeeze()
885
+ return probabilities
886
+
887
+ def get_loc_emb(self, x):
888
+ feature_embeddings = self.headless_model(x)
889
+ return feature_embeddings
890
+
891
+ def init_eval_embeddings(self, num_classes):
892
+ self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
893
+ nn.init.xavier_uniform_(self.eval_embeddings.weight)
894
+
895
+ def get_ema_embeddings(self, class_ids):
896
+ # Method to access EMA embeddings
897
+ return self.ema_embeddings(class_ids)
898
+
899
+ def get_eval_embeddings(self, class_ids):
900
+ # Method to access eval embeddings
901
+ return self.eval_embeddings(class_ids)
902
+
903
+ def update_ema_embeddings(self, class_ids, current_embeddings):
904
+ if self.training:
905
+ # Get unique class IDs and their counts
906
+ unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
907
+
908
+ # Get current EMA embeddings for unique class IDs
909
+ ema_current = self.ema_embeddings(unique_class_ids)
910
+
911
+ # Initialize a placeholder for new EMA values
912
+ ema_new = torch.zeros_like(ema_current)
913
+
914
+ # Compute the average of current embeddings for each unique class ID
915
+ current_sum = torch.zeros_like(ema_current)
916
+ current_sum.index_add_(0, inverse_indices, current_embeddings)
917
+ current_avg = current_sum / counts.unsqueeze(1)
918
+
919
+ # Apply EMA update formula
920
+ ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
921
+
922
+ # Update the EMA embeddings for unique class IDs
923
+ self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
924
+
925
+ def generate_eval_embeddings(self, class_id, current_embedding):
926
+ self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
927
+
928
+ # self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
929
+
930
+ def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
931
+ # forward method that uses ema or eval embeddings rather than context sequence
932
+
933
+ # Process input through the headless model to get feature embeddings
934
+ feature_embeddings = self.headless_model(x)
935
+
936
+ if return_feats:
937
+ return feature_embeddings
938
+ else:
939
+ if class_of_interest is None:
940
+ # Get class-specific embeddings based on class_ids
941
+ if not eval:
942
+ class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
943
+ else:
944
+ class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
945
+ if return_class_embeddings:
946
+ return class_embeddings
947
+ else:
948
+ # Matrix multiplication to produce logits
949
+ logits = feature_embeddings @ class_embeddings.T
950
+
951
+ # Apply sigmoid to convert logits to probabilities
952
+ probabilities = torch.sigmoid(logits)
953
+
954
+ return probabilities
955
+ else:
956
+ if not eval:
957
+ device = self.ema_embeddings.weight.device
958
+ class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
959
+ class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
960
+ print(f'using EMA estimate for class {class_of_interest}')
961
+ if return_class_embeddings:
962
+ return class_embedding
963
+ else:
964
+ # Matrix multiplication to produce logits
965
+ logits = feature_embeddings @ class_embedding.T
966
+
967
+ # Apply sigmoid to convert logits to probabilities
968
+ probabilities = torch.sigmoid(logits)
969
+ probabilities = probabilities.squeeze()
970
+
971
+ return probabilities
972
+
973
+ else:
974
+ device = self.eval_embeddings.weight.device
975
+ class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
976
+ class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
977
+ #print(f'using eval estimate for class {class_of_interest}')
978
+ if return_class_embeddings:
979
+ return class_embedding
980
+ else:
981
+ # Matrix multiplication to produce logits
982
+ logits = feature_embeddings @ class_embedding.T
983
+
984
+ # Apply sigmoid to convert logits to probabilities
985
+ probabilities = torch.sigmoid(logits)
986
+ probabilities = probabilities.squeeze()
987
+
988
+ return probabilities
989
+
990
+
991
+ class LinNet(nn.Module):
992
+ def __init__(self, num_inputs, num_classes):
993
+ super(LinNet, self).__init__()
994
+ self.num_layers = 0
995
+ self.inc_bias = False
996
+ self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
997
+ self.feats = nn.Identity() # does not do anything
998
+
999
+ def forward(self, x, class_of_interest=None, return_feats=False):
1000
+ loc_emb = self.feats(x)
1001
+ if return_feats:
1002
+ return loc_emb
1003
+ if class_of_interest is None:
1004
+ class_pred = self.class_emb(loc_emb)
1005
+ else:
1006
+ class_pred = self.eval_single_class(loc_emb, class_of_interest)
1007
+
1008
+ return torch.sigmoid(class_pred)
1009
+
1010
+ def eval_single_class(self, x, class_of_interest):
1011
+ if self.inc_bias:
1012
+ return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
1013
+ else:
1014
+ return x @ self.class_emb.weight[class_of_interest, :]
1015
+
1016
+
1017
+ class ParallelMulti(torch.nn.Module):
1018
+ def __init__(self, x: list[torch.nn.Module]):
1019
+ super(ParallelMulti, self).__init__()
1020
+ self.layers = nn.ModuleList(x)
1021
+
1022
+ def forward(self, xs, **kwargs):
1023
+ out = torch.cat([self.layers[i](x, **kwargs) for i,x in enumerate(xs)], dim=1)
1024
+ return out
1025
+
1026
+
1027
+ class SequentialMulti(torch.nn.Sequential):
1028
+ def forward(self, *inputs, **kwargs):
1029
+ for module in self._modules.values():
1030
+ if type(inputs) == tuple:
1031
+ inputs = module(*inputs, **kwargs)
1032
+ else:
1033
+ inputs = module(inputs)
1034
+ return inputs
1035
+
1036
+
1037
+ # Chris's transformation classes
1038
+ class Identity(nn.Module):
1039
+ def __init__(self, in_dim, out_dim):
1040
+ super(Identity, self).__init__()
1041
+ # No parameters needed for identity transformation
1042
+
1043
+ def forward(self, x):
1044
+ return x
1045
+
1046
+ class LinearTransformation(nn.Module):
1047
+ def __init__(self, in_dim, out_dim, bias=True):
1048
+ super(LinearTransformation, self).__init__()
1049
+ self.linear = nn.Linear(in_dim, out_dim, bias=bias)
1050
+
1051
+ def forward(self, x):
1052
+ return self.linear(x)
1053
+
1054
+ class SingleLayerNN(nn.Module):
1055
+ def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
1056
+ super(SingleLayerNN, self).__init__()
1057
+ hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
1058
+ self.net = nn.Sequential(
1059
+ nn.Linear(in_dim, hidden_dim, bias=bias),
1060
+ nn.ReLU(),
1061
+ nn.Dropout(p=dropout_p),
1062
+ nn.Linear(hidden_dim, out_dim, bias=bias)
1063
+ )
1064
+
1065
+ def forward(self, x):
1066
+ return self.net(x)
1067
+
1068
+ class TwoLayerNN(nn.Module):
1069
+ def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
1070
+ super(TwoLayerNN, self).__init__()
1071
+ hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
1072
+ self.net = nn.Sequential(
1073
+ nn.Linear(in_dim, hidden_dim, bias=bias),
1074
+ nn.ReLU(),
1075
+ nn.Dropout(p=dropout_p),
1076
+ nn.Linear(hidden_dim, hidden_dim, bias=bias),
1077
+ nn.ReLU(),
1078
+ nn.Dropout(p=dropout_p),
1079
+ nn.Linear(hidden_dim, out_dim, bias=bias)
1080
+ )
1081
+
1082
+ def forward(self, x):
1083
+ return self.net(x)
1084
+
1085
+ class HyperNet(nn.Module):
1086
+ '''
1087
+ Hypernetwork-style model: a species encoder predicts per-species weights (and a bias) that are applied to location features produced by a shared position encoder.
1088
+ '''
1089
+ def __init__(self, params, num_inputs, num_classes, num_filts, pos_enc_depth, species_dim, species_enc_depth, species_filts, species_enc='embed', inference_only=False):
1090
+ super(HyperNet, self).__init__()
1091
+ if species_enc == 'embed':
1092
+ self.species_emb = nn.Embedding(num_classes, species_dim)
1093
+ self.species_emb.weight.data *= 0.01
1094
+ elif species_enc == 'taxa':
1095
+ self.species_emb = TaxaEncoder(params, './data/inat_taxa_info.csv', species_dim)
1096
+ elif species_enc == 'text':
1097
+ self.species_emb = TextEncoder(params, params['text_emb_path'], species_dim, './data/inat_taxa_info.csv')
1098
+ elif species_enc == 'wiki':
1099
+ self.species_emb = WikiEncoder(params, params['text_emb_path'], species_dim, inference_only=inference_only)
1100
+ if species_enc_depth == -1:
1101
+ self.species_enc = nn.Identity()
1102
+ elif species_enc_depth == 0:
1103
+ self.species_enc = nn.Linear(species_dim, num_filts+1)
1104
+ else:
1105
+ self.species_enc = SimpleFCNet(species_dim, num_filts+1, species_filts, depth=species_enc_depth)
1106
+ if 'geoprior' in params['loss']:
1107
+ self.species_params = nn.Parameter(torch.randn(num_classes, species_dim))
1108
+ self.species_params.data *= 0.0386
1109
+ self.pos_enc = SimpleFCNet(num_inputs, num_filts, num_filts, depth=pos_enc_depth)
1110
+
1111
+ def forward(self, x, y):
1112
+ ys, indmap = torch.unique(y, return_inverse=True)
1113
+ species = self.species_enc(self.species_emb(ys))
1114
+ species_w, species_b = species[...,:-1], species[...,-1:]
1115
+ pos = self.pos_enc(x)
1116
+ out = torch.bmm(species_w[indmap],pos[...,None])
1117
+ out = (out + 0*species_b[indmap]).squeeze(-1) #TODO
1118
+ if hasattr(self, 'species_params'):
1119
+ out2 = torch.bmm(self.species_params[ys][indmap],pos[...,None])
1120
+ out2 = out2.squeeze(-1)
1121
+ out3 = (species_w, self.species_params[ys], ys)
1122
+ return out, out2, out3
1123
+ else:
1124
+ return out
1125
+
1126
+ def zero_shot(self, x, species_emb):
1127
+ species = self.species_enc(self.species_emb.zero_shot(species_emb))
1128
+ species_w, _ = species[...,:-1], species[...,-1:]
1129
+ pos = self.pos_enc(x)
1130
+ out = pos @ species_w.T
1131
+ return out
1132
+
1133
+
1134
+ class TaxaEncoder(nn.Module):
1135
+ def __init__(self, params, fpath, embedding_dim):
1136
+ super(TaxaEncoder, self).__init__()
1137
+ import datasets
1138
+ with open('paths.json', 'r') as f:
1139
+ paths = json.load(f)
1140
+ data_dir = paths['train']
1141
+ obs_file = os.path.join(data_dir, params['obs_file'])
1142
+ taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
1143
+
1144
+ taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
1145
+ params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
1146
+
1147
+ locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
1148
+ unique_taxa, class_ids = np.unique(labels, return_inverse=True)
1149
+ class_to_taxa = unique_taxa.tolist()
1150
+
1151
+ self.fpath = fpath
1152
+ ids = []
1153
+ rows = []
1154
+ with open(fpath, newline='') as csvfile:
1155
+ spamreader = csv.reader(csvfile, delimiter=',')
1156
+ for row in spamreader:
1157
+ if row[0] == 'taxon_id':
1158
+ continue
1159
+ ids.append(int(row[0]))
1160
+ rows.append(row[3:])
1161
+ print()
1162
+ rows = np.array(rows)
1163
+ rows = [np.unique(rows[:,i], return_inverse=True)[1] for i in range(rows.shape[1])]
1164
+ rows = torch.from_numpy(np.vstack(rows).T)
1165
+ rows = rows
1166
+ self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
1167
+ embs = [nn.Embedding(rows[:,i].max()+2, embedding_dim, 0) for i in range(rows.shape[1])]
1168
+ embs[-1] = nn.Embedding(len(class_to_taxa), embedding_dim)
1169
+ rows2 = torch.zeros((len(class_to_taxa), 7), dtype=rows.dtype)
1170
+ startind = rows[:,-1].max()
1171
+ for i in range(len(class_to_taxa)):
1172
+ if class_to_taxa[i] in ids:
1173
+ rows2[i] = rows[ids.index(class_to_taxa[i])]+1
1174
+ rows2[i,-1] -= 1
1175
+ else:
1176
+ rows2[i,-1] = startind
1177
+ startind += 1
1178
+ self.register_buffer('rows', rows2)
1179
+ for e in embs:
1180
+ e.weight.data *= 0.01
1181
+ self.embs = nn.ModuleList(embs)
1182
+
1183
+ def forward(self, x):
1184
+ inds = self.rows[x]
1185
+ out = sum([self.embs[i](inds[...,i]) for i in range(inds.shape[-1])])
1186
+ return out
1187
+
1188
+
1189
+ class TextEncoder(nn.Module):
1190
+ def __init__(self, params, path, embedding_dim, fpath='inat_taxa_info.csv'):
1191
+ super(TextEncoder, self).__init__()
1192
+ import datasets
1193
+ with open('paths.json', 'r') as f:
1194
+ paths = json.load(f)
1195
+ data_dir = paths['train']
1196
+ obs_file = os.path.join(data_dir, params['obs_file'])
1197
+ taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
1198
+
1199
+ taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
1200
+ params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
1201
+
1202
+ locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
1203
+ unique_taxa, class_ids = np.unique(labels, return_inverse=True)
1204
+ class_to_taxa = unique_taxa.tolist()
1205
+
1206
+ self.fpath = fpath
1207
+ ids = []
1208
+ with open(fpath, newline='') as csvfile:
1209
+ spamreader = csv.reader(csvfile, delimiter=',')
1210
+ for row in spamreader:
1211
+ if row[0] == 'taxon_id':
1212
+ continue
1213
+ ids.append(int(row[0]))
1214
+ embs = torch.load(path)
1215
+ if len(embs) != len(ids):
1216
+ print("Warning: Number of embeddings doesn't match number of species")
1217
+ ids = ids[:embs.shape[0]]
1218
+ if isinstance(embs, list):
1219
+ embs = torch.stack(embs)
1220
+ self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
1221
+ indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
1222
+ embmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
1223
+ self.missing_emb = nn.Embedding(len(class_to_taxa)-embs.shape[0], embedding_dim)
1224
+
1225
+ startind = 0
1226
+ for i in range(len(class_to_taxa)):
1227
+ if class_to_taxa[i] in ids:
1228
+ indmap[i] = ids.index(class_to_taxa[i])
1229
+ else:
1230
+ embmap[i] = startind
1231
+ startind += 1
1232
+ self.scales = nn.Parameter(torch.zeros(len(class_to_taxa), 1))
1233
+ self.register_buffer('indmap', indmap, persistent=False)
1234
+ self.register_buffer('embmap', embmap, persistent=False)
1235
+ self.register_buffer('embs', embs, persistent=False)
1236
+ if params['text_hidden_dim'] == 0:
1237
+ self.linear1 = nn.Linear(embs.shape[1], embedding_dim)
1238
+ else:
1239
+ self.linear1 = nn.Linear(embs.shape[1], params['text_hidden_dim'])
1240
+ self.linear2 = nn.Linear(params['text_hidden_dim'], embedding_dim)
1241
+ self.act = nn.SiLU()
1242
+ if params['text_learn_dim'] > 0:
1243
+ self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
1244
+ self.learned_emb.weight.data *= 0.01
1245
+ self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
1246
+
1247
+ def forward(self, x):
1248
+ inds = self.indmap[x]
1249
+ out = self.embs[self.indmap[x].cpu()]
1250
+ out = self.linear1(out)
1251
+ if hasattr(self, 'linear2'):
1252
+ out = self.linear2(self.act(out))
1253
+ out = self.scales[x] * (out / (out.std(dim=1)[:, None]))
1254
+ out[inds == -1] = self.missing_emb(self.embmap[x[inds == -1]])
1255
+ if hasattr(self, 'learned_emb'):
1256
+ out2 = self.learned_emb(x)
1257
+ out2 = self.linear_learned(out2)
1258
+ out = out+out2
1259
+ return out
1260
+
1261
+
1262
+ class WikiEncoder(nn.Module):
1263
+ def __init__(self, params, path, embedding_dim, inference_only=False):
1264
+ super(WikiEncoder, self).__init__()
1265
+ self.path = path
1266
+ if not inference_only:
1267
+ import datasets
1268
+ with open('paths.json', 'r') as f:
1269
+ paths = json.load(f)
1270
+ data_dir = paths['train']
1271
+ obs_file = os.path.join(data_dir, params['obs_file'])
1272
+ taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
1273
+
1274
+ taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
1275
+ params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
1276
+
1277
+ locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
1278
+ if params['zero_shot']:
1279
+ with open('paths.json', 'r') as f:
1280
+ paths = json.load(f)
1281
+ with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
1282
+ data = json.load(f)
1283
+ D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
1284
+ D = D.item()
1285
+ taxa_snt = D['taxa'].tolist()
1286
+ taxa = [int(tt) for tt in data['taxa_presence'].keys()]
1287
+ taxa = list(set(taxa + taxa_snt))
1288
+ mask = labels != taxa[0]
1289
+ for i in range(1, len(taxa)):
1290
+ mask &= (labels != taxa[i])
1291
+ locs = locs[mask]
1292
+ dates = dates[mask]
1293
+ labels = labels[mask]
1294
+ unique_taxa, class_ids = np.unique(labels, return_inverse=True)
1295
+ class_to_taxa = unique_taxa.tolist()
1296
+
1297
+ embs = torch.load(path)
1298
+ ids = embs['taxon_id'].tolist()
1299
+ if 'keys' in embs:
1300
+ taxa_counts = torch.zeros(len(ids), dtype=torch.int32)
1301
+ for i,k in embs['keys']:
1302
+ taxa_counts[i] += 1
1303
+ else:
1304
+ taxa_counts = torch.ones(len(ids), dtype=torch.int32)
1305
+ count_sum = torch.cumsum(taxa_counts, dim=0) - taxa_counts
1306
+ embs = embs['data']
1307
+
1308
+ self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
1309
+ indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
1310
+ countmap = torch.zeros(len(class_to_taxa), dtype=torch.int)
1311
+ self.species_emb = nn.Embedding(len(class_to_taxa), embedding_dim)
1312
+ self.species_emb.weight.data *= 0.01
1313
+
1314
+ for i in range(len(class_to_taxa)):
1315
+ if class_to_taxa[i] in ids:
1316
+ i2 = ids.index(class_to_taxa[i])
1317
+ indmap[i] = count_sum[i2]
1318
+ countmap[i] = taxa_counts[i2]
1319
+
1320
+ self.register_buffer('indmap', indmap, persistent=False)
1321
+ self.register_buffer('countmap', countmap, persistent=False)
1322
+ self.register_buffer('embs', embs, persistent=False)
1323
+ assert embs.shape[1] == 4096
1324
+ self.scale = nn.Parameter(torch.zeros(1))
1325
+ if params['species_dropout'] > 0:
1326
+ self.dropout = nn.Dropout(p=params['species_dropout'])
1327
+ if params['text_hidden_dim'] == 0:
1328
+ self.linear1 = nn.Linear(4096, embedding_dim)
1329
+ else:
1330
+ self.linear1 = nn.Linear(4096, params['text_hidden_dim'])
1331
+ if params['text_batchnorm']:
1332
+ self.bn1 = nn.BatchNorm1d(params['text_hidden_dim'])
1333
+ for l in range(params['text_num_layers']-1):
1334
+ setattr(self, f'linear{l+2}', nn.Linear(params['text_hidden_dim'], params['text_hidden_dim']))
1335
+ if params['text_batchnorm']:
1336
+ setattr(self, f'bn{l+2}', nn.BatchNorm1d(params['text_hidden_dim']))
1337
+ setattr(self, f'linear{params["text_num_layers"]+1}', nn.Linear(params['text_hidden_dim'], embedding_dim))
1338
+ self.act = nn.SiLU()
1339
+ if params['text_learn_dim'] > 0:
1340
+ self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
1341
+ self.learned_emb.weight.data *= 0.01
1342
+ self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
1343
+
1344
+ def forward(self, x):
1345
+ inds = self.indmap[x] + (torch.rand(x.shape,device=x.device)*self.countmap[x]).floor().int()
1346
+ out = self.embs[inds]
1347
+ if hasattr(self, 'dropout'):
1348
+ out = self.dropout(out)
1349
+ out = self.linear1(out)
1350
+ if hasattr(self, 'linear2'):
1351
+ out = self.act(out)
1352
+ if hasattr(self, 'bn1'):
1353
+ out = self.bn1(out)
1354
+ i = 2
1355
+ while hasattr(self, f'linear{i}'):
1356
+ if hasattr(self, f'linear{i}'):
1357
+ out = self.act(getattr(self, f'linear{i}')(out))
1358
+ if hasattr(self, f'bn{i}'):
1359
+ out = getattr(self, f'bn{i}')(out)
1360
+ i += 1
1361
+ #out = self.scale * (out / (out.std(dim=1)[:, None]))
1362
+ out2 = self.species_emb(x)
1363
+ chosen = torch.rand((out.shape[0],), device=x.device)
1364
+ chosen = 1+0*chosen #TODO fix this
1365
+ chosen[inds == -1] = 0
1366
+ out = chosen[:,None] * out + (1-chosen[:,None])*out2
1367
+ if hasattr(self, 'learned_emb'):
1368
+ out2 = self.learned_emb(x)
1369
+ out2 = self.linear_learned(out2)
1370
+ out = out+out2
1371
+ return out
1372
+
1373
+
1374
+ def zero_shot(self, species_emb):
1375
+ out = species_emb
1376
+ out = self.linear1(out)
1377
+ if hasattr(self, 'linear2'):
1378
+ out = self.act(out)
1379
+ if hasattr(self, 'bn1'):
1380
+ out = self.bn1(out)
1381
+ i = 2
1382
+ while hasattr(self, f'linear{i}'):
1383
+ if hasattr(self, f'linear{i}'):
1384
+ out = self.act(getattr(self, f'linear{i}')(out))
1385
+ if hasattr(self, f'bn{i}'):
1386
+ out = getattr(self, f'bn{i}')(out)
1387
+ i += 1
1388
+ return out
1389
+
1390
+ def zero_shot_old(self, species_emb):
1391
+ out = species_emb
1392
+ out = self.linear1(out)
1393
+ if hasattr(self, 'linear2'):
1394
+ out = self.linear2(self.act(out))
1395
+ out = self.scale * (out / (out.std(dim=-1, keepdim=True)))
1396
+ return out
1397
+
1398
+ # MINE - would only be used for my models - not currently being used at all
1399
+ # CURRENTLY JUST USING A HEADLESS_SINR FOR THE TEXT ENCODER
1400
+ class MultiInputTextEncoder(nn.Module):
1401
+ def __init__(self, token_dim, dropout, input_dim=4096, depth=2, hidden_dim=512, nonlin='relu', batch_norm=True, layer_norm=False):
1402
+ super(MultiInputTextEncoder, self).__init__()
1403
+
1404
+ print("THINK ABOUT IF SOME OF THESE HYPERPARAMETERS SHOULD BE DISTINCT FROM THE TRANSFORMER VERSION")
1405
+ print("DEPTH / NUM_ENCODER_LAYERS, DROPOUT, DIM_FEEDFORWARD, ETC")
1406
+ print("AT PRESENT WE JUST HAVE A SORT OF BASIC VERSION IMPLEMENTED THAT ATTEMPTS TO BE LIKE MAX'S VERSION")
1407
+ print("ALSO, OPTION TO HAVE IT PRETRAINED? ADD RESIDUAL LAYERS?")
1408
+ self.token_dim=token_dim
1409
+ self.dropout=dropout
1410
+ self.input_dim=input_dim
1411
+ self.depth=depth
1412
+ self.hidden_dim=hidden_dim
1413
+ self.batch_norm = batch_norm
1414
+ self.layer_norm = layer_norm
1415
+
1416
+ if nonlin == 'relu':
1417
+ activation = nn.ReLU
1418
+ elif nonlin == 'silu':
1419
+ activation = nn.SiLU
1420
+ else:
1421
+ raise NotImplementedError('Invalid nonlinearity specified.')
1422
+
1423
+ self.dropout_layer = nn.Dropout(p=self.dropout)
1424
+ if self.depth <= 1:
1425
+ self.linear1 = nn.Linear(self.input_dim, self.token_dim)
1426
+
1427
+ else:
1428
+ self.linear1 = nn.Linear(self.input_dim, self.hidden_dim)
1429
+
1430
+ if self.batch_norm:
1431
+ self.bn1 = nn.BatchNorm1d(self.hidden_dim)
1432
+
1433
+ # if self.layer_norm:
1434
+ # self.ln1 = nn.LayerNorm(self.hidden_dim)
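A minimal usage sketch for the VariableInputModel defined above (hypothetical dimensions; it assumes HeadlessSINR and TransformerEncoderModel from earlier in models.py behave as they are constructed here, and uses made-up class IDs and random location features):

import torch
from models import VariableInputModel

model = VariableInputModel(num_inputs=4, num_filts=256, num_classes=10,
                           token_dim=256, sinr_inputs=True,
                           class_token_transformation='identity')
model.init_eval_embeddings(num_classes=10)  # placeholder per-class eval embeddings
model.eval()
loc_feats = torch.randn(8, 4)               # pre-encoded locations (e.g. sin_cos features)
with torch.no_grad():
    probs = model.embedding_forward(loc_feats, class_of_interest=3, eval=True)
print(probs.shape)                          # torch.Size([8]) of per-location probabilities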
paths.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "data": "data/",
3
+ "masks": "data/masks/",
4
+ "env": "data/env/",
5
+ "train": "data/train/",
6
+ "geo_prior": "data/eval/geo_prior/",
7
+ "snt": "data/eval/snt/",
8
+ "iucn": "data/eval/iucn/",
9
+ "geo_feature": "data/eval/geo_feature/"
10
+ }
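These directories are resolved at runtime by the scripts in this commit; a minimal sketch of the lookup, mirroring what viz_ls_map.py does:

import json

with open('paths.json', 'r') as f:
    paths = json.load(f)
mask_dir = paths['masks']   # 'data/masks/', where ocean_mask.npy / ocean_mask_hr.npy are expected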
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ gradio==3.36.1
2
+ h3==3.7.6
3
+ matplotlib==3.7.1
4
+ numpy==1.25.0
5
+ pandas==2.0.3
6
+ scikit_learn==1.3.0
7
+ scikit-image==0.19.3
8
+ tifffile==2023.7.4
9
+ torch==1.12.1
10
+ imagecodecs==2023.9.18
setup.py ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,326 @@
1
+ import matplotlib.pyplot as plt
2
+ import torch
3
+ import numpy as np
4
+ import math
5
+ import datetime
6
+ #from h3.unstable import vect
7
+ import h3
8
+
9
+ class CoordEncoder:
10
+
11
+ def __init__(self, input_enc, raster=None, input_dim=0):
12
+ self.input_enc = input_enc
13
+ self.raster = raster
14
+ self.input_dim = input_dim
15
+
16
+ def encode(self, locs, normalize=True):
17
+ # assumes lon, lat in range [-180, 180] and [-90, 90]
18
+ if normalize:
19
+ locs = normalize_coords(locs)
20
+ if self.input_enc == 'none':
21
+ loc_feats = locs * torch.tensor([[180.0,90.0]], device=locs.device)
22
+ elif self.input_enc == 'sin_cos': # sinusoidal encoding
23
+ loc_feats = encode_loc(locs, input_dim=self.input_dim)
24
+ elif self.input_enc == 'env': # bioclim variables
25
+ loc_feats = bilinear_interpolate(locs, self.raster)
26
+ elif self.input_enc == 'sin_cos_env': # sinusoidal encoding & bioclim variables
27
+ loc_feats = encode_loc(locs, input_dim=self.input_dim)
28
+ context_feats = bilinear_interpolate(locs, self.raster.to(locs.device))
29
+ loc_feats = torch.cat((loc_feats, context_feats), 1)
30
+ elif self.input_enc == 'satclip': #SatClip Embedding
31
+ if not hasattr(self, 'model'):
32
+ import sys
33
+ sys.path.append('./satclip/satclip')
34
+ from satclip.satclip.load import get_satclip
35
+ self.model = get_satclip('satclip/satclip-vit16-l10.ckpt', device="cpu")
36
+ self.model.eval()
37
+ self.model = self.model.to(locs.device)
38
+ locs = locs*torch.tensor([[180.0, 90.0]], device=locs.device)
39
+ max_batch = 1000000
40
+ loc_feats = torch.empty(locs.shape[0], 256, device=locs.device)
41
+ with torch.no_grad():
42
+ for i in range(0, locs.shape[0], max_batch):
43
+ loc_feats[i:i+max_batch] = self.model(locs[i:i+max_batch].double()).float()
44
+ else:
45
+ raise NotImplementedError('Unknown input encoding.')
46
+ return loc_feats
47
+
48
+ def encode_fast(self, loc: list[float], normalize=True):
49
+ assert not normalize
50
+ if self.input_enc == 'sin_cos':
51
+ loc_feats = encode_loc_fast(loc, input_dim=self.input_dim)
52
+ else:
53
+ raise NotImplementedError('Unknown input encoding.')
54
+ return loc_feats
55
+
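+ # Illustrative usage (not called in this file): encode two lon/lat points with the
+ # 'sin_cos' scheme; with input_dim=8 the output has shape (2, 8).
+ # enc = CoordEncoder('sin_cos', input_dim=8)
+ # locs = torch.tensor([[-92.0, 37.5], [10.0, 51.0]])  # lon, lat in degrees
+ # feats = enc.encode(locs)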
56
+
57
+ class TimeEncoder:
58
+
59
+ def __init__(self, input_enc='conical'):
60
+ self.input_enc = input_enc
61
+
62
+ def encode(self, intervals):
63
+ # assumes time, width in range [0, 1]
64
+ t_center = intervals[:, :1]
65
+ t_width = intervals[:, 1:]
66
+ if self.input_enc == 'conical':
67
+ t_feats = torch.cat([(1 - t_width) * torch.sin(2 * torch.pi * t_center),
68
+ (1 - t_width) * torch.cos(2 * torch.pi * t_center), 2 * t_width - 1], dim=1)
69
+ elif self.input_enc == 'cylindrical':
70
+ t_feats = torch.cat([torch.sin(2 * torch.pi * t_center), torch.cos(2 * torch.pi * t_center), 2 * t_width - 1], dim=1)
71
+ return t_feats
72
+
73
+ def encode_fast(self, intervals):
74
+ # assumes time, width in range [0, 1]
75
+ t_center, t_width = intervals
76
+ if self.input_enc == 'conical':
77
+ t_feats = torch.tensor([(1 - t_width) * math.sin(2 * math.pi * t_center),
78
+ (1 - t_width) * math.cos(2 * math.pi * t_center), 2 * t_width - 1])
79
+ elif self.input_enc == 'cylindrical':
80
+ t_feats = torch.tensor([math.sin(2 * math.pi * t_center),
81
+ math.cos(2 * math.pi * t_center), 2 * t_width - 1])
82
+ return t_feats
83
+
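+ # e.g. (hypothetical input) for t_center = 0.25 and t_width = 0.0 the 'conical'
+ # encoding is [sin(pi/2), cos(pi/2), 2*0 - 1] = [1.0, 0.0, -1.0]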
84
+
85
+ def normalize_coords(locs):
86
+ # locs is in lon [-180, 180], lat [-90, 90]
87
+ # output is in the range [-1, 1]
88
+
89
+ locs[:,0] /= 180.0
90
+ locs[:,1] /= 90.0
91
+
92
+ return locs
93
+
94
+ def encode_loc(loc_ip, concat_dim=1, input_dim=0):
95
+ # assumes inputs location are in range -1 to 1
96
+ # location is lon, lat
97
+ encs = []
98
+ for i in range(input_dim//4):
99
+ encs.append(torch.sin(math.pi*(2**i)*loc_ip))
100
+ encs.append(torch.cos(math.pi*(2**i)*loc_ip))
101
+ feats = torch.cat(encs, concat_dim)
102
+ return feats
103
+
104
+
105
+ def encode_loc_fast(loc_ip: list[float], input_dim=0):
106
+ # assumes inputs location are in range -1 to 1
107
+ # location is lon, lat
108
+ input_dim //= 2 # needed to make it compatible with encode_loc
109
+ feats = [(math.sin if i%(2*len(loc_ip))<len(loc_ip) else math.cos)(math.pi*(2**(i//(2*len(loc_ip))))*loc_ip[i%len(loc_ip)]) for i in range(input_dim)]
110
+ return feats
111
+
112
+
113
+ def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
114
+ # loc is N x 2 vector, where each row is [lon,lat] entry
115
+ # each entry spans range [-1,1]
116
+ # data is H x W x C, height x width x channel data matrix
117
+ # op will be N x C matrix of interpolated features
118
+
119
+ assert data is not None
120
+
121
+ # map to [0,1], then scale to data size
122
+ loc = (loc_ip.clone() + 1) / 2.0
123
+ loc[:,1] = 1 - loc[:,1] # this is because latitude goes from +90 on top to bottom while
124
+ # longitude goes from -180 to 180 left to right
125
+
126
+ assert not torch.any(torch.isnan(loc))
127
+
128
+ if remove_nans_raster:
129
+ data[torch.isnan(data)] = 0.0 # replace with mean value (0 is mean post-normalization)
130
+
131
+ # cast locations into pixel space
132
+ loc[:, 0] *= (data.shape[1]-1)
133
+ loc[:, 1] *= (data.shape[0]-1)
134
+
135
+ loc_int = torch.floor(loc).long() # integer pixel coordinates
136
+ xx = loc_int[:, 0]
137
+ yy = loc_int[:, 1]
138
+ xx_plus = xx + 1
139
+ xx_plus[xx_plus > (data.shape[1]-1)] = data.shape[1]-1
140
+ yy_plus = yy + 1
141
+ yy_plus[yy_plus > (data.shape[0]-1)] = data.shape[0]-1
142
+
143
+ loc_delta = loc - torch.floor(loc) # delta values
144
+ dx = loc_delta[:, 0].unsqueeze(1)
145
+ dy = loc_delta[:, 1].unsqueeze(1)
146
+
147
+ interp_val = data[yy, xx, :]*(1-dx)*(1-dy) + data[yy, xx_plus, :]*dx*(1-dy) + \
148
+ data[yy_plus, xx, :]*(1-dx)*dy + data[yy_plus, xx_plus, :]*dx*dy
149
+
150
+ return interp_val
151
+
152
+ def rand_samples(batch_size, device, rand_type='uniform'):
153
+ # randomly sample background locations
154
+
155
+ if rand_type == 'spherical':
156
+ rand_loc = torch.rand(batch_size, 2).to(device)
157
+ theta1 = 2.0*math.pi*rand_loc[:, 0]
158
+ theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
159
+ lat = 1.0 - 2.0*theta2/math.pi
160
+ lon = (theta1/math.pi) - 1.0
161
+ rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
162
+
163
+ elif rand_type == 'uniform':
164
+ rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0
165
+
166
+ return rand_loc
167
+
168
+ def get_time_stamp():
169
+ cur_time = str(datetime.datetime.now())
170
+ date, time = cur_time.split(' ')
171
+ h, m, s = time.split(':')
172
+ s = s.split('.')[0]
173
+ time_stamp = '{}-{}-{}-{}'.format(date, h, m, s)
174
+ return time_stamp
175
+
176
+ def coord_grid(grid_size, split_ids=None, split_of_interest=None):
177
+ # generate a grid of locations spaced evenly in coordinate space
178
+
179
+ feats = np.zeros((grid_size[0], grid_size[1], 2), dtype=np.float32)
180
+ mg = np.meshgrid(np.linspace(-180, 180, feats.shape[1]), np.linspace(90, -90, feats.shape[0]))
181
+ feats[:, :, 0] = mg[0]
182
+ feats[:, :, 1] = mg[1]
183
+ if split_ids is None or split_of_interest is None:
184
+ # return feats for all locations
185
+ # this will be an N x 2 array
186
+ return feats.reshape(feats.shape[0]*feats.shape[1], 2)
187
+ else:
188
+ # only select a subset of locations
189
+ ind_y, ind_x = np.where(split_ids==split_of_interest)
190
+
191
+ # these will be N_subset x 2 in size
192
+ return feats[ind_y, ind_x, :]
193
+
194
+ def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
195
+ # generates a checkerboard style train test split
196
+ # 0 is invalid, 1 is train, and 2 is test
197
+ # c_size is units of pixels
198
+ split_ids = np.ones((raster.shape[0], raster.shape[1]))
199
+ start = cell_size
200
+ for ii in np.arange(0, split_ids.shape[0], cell_size):
201
+ if start == 0:
202
+ start = cell_size
203
+ else:
204
+ start = 0
205
+ for jj in np.arange(start, split_ids.shape[1], cell_size*2):
206
+ split_ids[ii:ii+cell_size, jj:jj+cell_size] = 2
207
+ split_ids = split_ids*mask
208
+ if train_amt < 1.0:
209
+ # take a subset of the data
210
+ tr_y, tr_x = np.where(split_ids==1)
211
+ inds = np.random.choice(len(tr_y), int(len(tr_y)*(1.0-train_amt)), replace=False)
212
+ split_ids[tr_y[inds], tr_x[inds]] = 0
213
+ return split_ids
214
+
215
+ def average_precision_score_faster(y_true, y_scores):
216
+ # drop in replacement for sklearn's average_precision_score
217
+ # comparable up to floating point differences
218
+ num_positives = y_true.sum()
219
+ inds = np.argsort(y_scores)[::-1]
220
+ y_true_s = y_true[inds]
221
+
222
+ false_pos_c = np.cumsum(1.0 - y_true_s)
223
+ true_pos_c = np.cumsum(y_true_s)
224
+ recall = true_pos_c / num_positives
225
+ total_pred_pos = np.maximum(true_pos_c + false_pos_c, np.finfo(np.float32).eps)  # tp + fp, clamped to avoid divide-by-zero
226
+ precision = true_pos_c / total_pred_pos
227
+
228
+ recall_e = np.hstack((0, recall, 1))
229
+ recall_e = (recall_e[1:] - recall_e[:-1])[:-1]
230
+ map_score = (recall_e*precision).sum()
231
+ return map_score
232
+
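+ # Illustrative sanity check (hypothetical inputs): for y_true = [1, 0, 1] and
+ # y_scores = [0.9, 0.8, 0.7] this returns (1/1 + 2/3) / 2 = 0.8333..., matching
+ # sklearn.metrics.average_precision_score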
233
+ #TODO I might be able to just cast these to a float to make them 1 or 0
234
+ #TODO y_true are the same as the ones
235
+ def average_precision_score_fasterer(y_true, y_scores):
236
+ # drop in replacement for sklearn's average_precision_score
237
+ # comparable up to floating point differences
238
+ num_positives = y_true.sum()
239
+ inds = torch.argsort(y_scores, descending=True)
240
+ y_true_s = y_true[inds]
241
+
242
+ false_pos_c = torch.cumsum(1.0 - y_true_s, dim=0)
243
+ true_pos_c = torch.cumsum(y_true_s, dim=0)
244
+ recall = true_pos_c / num_positives
245
+ total_pred_pos = (true_pos_c + false_pos_c).clip(min=np.finfo(np.float32).eps)  # tp + fp, clamped to avoid divide-by-zero
246
+ precision = true_pos_c / total_pred_pos
247
+
248
+ recall_e = torch.cat([torch.zeros(1, device=recall.device), recall, torch.ones(1, device=recall.device)])
249
+ recall_e = (recall_e[1:] - recall_e[:-1])[:-1]
250
+ map_score = (recall_e*precision).sum()
251
+ return map_score
252
+
253
+
254
+ class DataPDFH3:
255
+ def __init__(self, data='data_pdf_h3.pt', device='cpu'):
256
+ super(DataPDFH3, self).__init__()
257
+ self.data = torch.cumsum(torch.load(data, map_location=device), dim=0)
258
+ self.data = torch.cat([torch.zeros_like(self.data[:1]), self.data], dim=0)
259
+ inds = torch.load('inds_h3.pt')
260
+ inds = ((inds >> 30) & 4194303)
261
+ self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
262
+ self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)
263
+ self.cum_counts = self.data.sum(dim=-1)
264
+
265
+ def _sample(self, pos, time, noise_level):
266
+ pos = pos.cpu()
267
+ time = time.cpu()
268
+ noise_level = noise_level.cpu()
269
+ t_low = (365*(time - 0.5*(noise_level))).int()
270
+ t_high = (365*(time + 0.5*(noise_level))).int()
271
+ t_high[t_low < 0] += 365
272
+ t_low[t_low < 0] += 365
273
+
274
+ pos_ind = torch.from_numpy((h3.latlng_to_cell(90*pos[:, 1], 180*pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
275
+ pos_ind = self.ind_map[pos_ind]
276
+ counts = self.data[t_high.clamp(max=364)+1, pos_ind] - self.data[t_low, pos_ind]
277
+ counts[t_high > 364] += self.data[(t_high[t_high > 364] - 365).clamp(max=364) + 1, pos_ind[t_high > 364]]
278
+ counts[t_high > 729] += self.data[(t_high[t_high > 729] - 730).clamp(max=364) + 1, pos_ind[t_high > 729]]
279
+ totals = self.cum_counts[t_high.clamp(max=364)+1] - self.cum_counts[t_low]
280
+ totals[t_high > 364] += self.cum_counts[(t_high[t_high > 364] - 365).clamp(max=364) + 1]
281
+ totals[t_high > 729] += self.cum_counts[(t_high[t_high > 729] - 730).clamp(max=364) + 1]
282
+ counts[pos_ind < 0] = 0
283
+ return counts, totals
284
+
285
+ def sample(self, pos, time, noise_level):
286
+ counts, totals = self._sample(pos, time, noise_level)
287
+ return counts/totals
288
+
289
+ def sample_log(self, pos, time, noise_level, eps=1e-2):
290
+ counts, totals = self._sample(pos, time, noise_level)
291
+ return torch.log(counts)-torch.log(totals+eps)
292
+
293
+
294
+ class LowRankModel:
295
+ def __init__(self, data='nmf_256.pt', device='cpu'):
296
+ super(LowRankModel, self).__init__()
297
+ dim=-1
298
+ x1, x2 = torch.load(data, map_location=device)
299
+ m = torch.load('class_counts_locs_h3.pt').float()
300
+ chosen_inds = m.sum(dim=0).to_dense().sort(descending=True).indices[:]
301
+ if dim == 0:
302
+ n = m.to_dense()[:, chosen_inds].sum(dim=dim, keepdim=True)
303
+ self.data = n*torch.softmax(x1 @ x2, dim=dim)
304
+ self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
305
+ elif dim == 1:
306
+ self.data = torch.softmax(x1 @ x2, dim=dim)
307
+ elif dim == -1:
308
+ self.data = torch.from_numpy(x1 @ x2)
309
+ self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
310
+ m = m.to_dense()[:, chosen_inds]
311
+ #self.data = m.to_dense().float()/torch.sum(m.to_dense(), dim=1, keepdim=True)
312
+ self.pc = m.sum(dim=1, keepdim=True) / m.sum()
313
+ inds = torch.load('inds_h3.pt')[chosen_inds]
314
+ inds = ((inds >> 30) & 4194303)
315
+ self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
316
+ self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)
317
+
318
+ def sample(self, pos):#, time, noise_level):
319
+ pos = pos.cpu()
320
+ pos_ind = torch.from_numpy((h3.latlng_to_cell(pos[:, 1], pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
321
+ pos_ind = self.ind_map[pos_ind]
322
+ out = self.data[:, pos_ind]
323
+ out *= self.pc
324
+ out = out/torch.sum(out, dim=0, keepdim=True)
325
+ out[:, pos_ind < 0] = 1.0/out.shape[0]
326
+ return out
viz_ls_map.py ADDED
@@ -0,0 +1,283 @@
1
+ """
2
+ Demo that takes an iNaturalist taxa ID as input and generates a prediction
3
+ for each location on the globe and saves the output as an image.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import os
10
+ import json
11
+ import argparse
12
+
13
+ import utils
14
+ import datasets
15
+ import eval
16
+ import create_inputs_to_fs_sinr
17
+
18
+ text_model = './experiments/gpt_data.pt'
19
+
20
+ def extract_grit_token(model, text:str):
21
+ def gritlm_instruction(instruction):
22
+ return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
23
+ d_rep = model.encode([text], instruction=gritlm_instruction(""))
24
+ d_rep = torch.from_numpy(d_rep)
25
+ return d_rep
26
+
27
+ def choose_context_points_from_map(eval_params):
28
+ context_points = []
29
+
30
+ if False:
31
+ def onclick(event):
32
+ if event.xdata is not None and event.ydata is not None:
33
+ # Convert image coordinates to normalized geographical coordinates
34
+ lon = event.xdata / mask.shape[1] * 2 - 1
35
+ lat = 1 - event.ydata / mask.shape[0] * 2
36
+ context_points.append((lon, lat))
37
+ print(f"Added context point: ({lon}, {lat})")
38
+
39
+ # Load ocean mask
40
+ with open('paths.json', 'r') as f:
41
+ paths = json.load(f)
42
+ if eval_params['high_res']:
43
+ mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
44
+ else:
45
+ mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
46
+
47
+ mask_inds = np.where(mask.reshape(-1) == 1)[0]
48
+
49
+ # # Generate input features
50
+ # locs = utils.coord_grid(mask.shape)
51
+ # if not eval_params['disable_ocean_mask']:
52
+ # locs = locs[mask_inds, :]
53
+ # locs = torch.from_numpy(locs)
54
+
55
+ # Reshape and create masked array for visualization
56
+ op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # Set to NaN
57
+ op_im[mask_inds] = 0 # Placeholder for the mask visualization
58
+ op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
59
+ op_im = np.ma.masked_invalid(op_im)
60
+
61
+ # Set color for masked values
62
+ cmap = plt.cm.plasma
63
+ cmap.set_bad(color='none')
64
+ plt.ioff()
65
+ # Display the map and capture context points
66
+ fig, ax = plt.subplots(figsize=(6, 3), dpi=334) # Define the figure size
67
+ ax.imshow(op_im, cmap=cmap, interpolation='nearest') # Display the image
68
+ ax.axis('off') # Turn off the axis
69
+
70
+ # Connect the onclick event to the handler
71
+ cid = fig.canvas.mpl_connect('button_press_event', onclick)
72
+
73
+ plt.show(block=True) # Block execution until the window is closed
74
+
75
+ print(f"Context points collected: {context_points}")
76
+
77
+ else:
78
+ #USA
79
+ #TODO: 37.541170, -92.003293 1. flip order, then 2. normalize so divide by 180 and 90
80
+ context_points = [(-0.5884012559178662, 0.46394662490802496), (-0.5451199953511522, 0.4504212309809269),
81
+ (-0.5437674559584422, 0.5342786733289353), (-0.589753795310576, 0.5342786733289353)]
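+ # e.g. following the TODO above, (lat 37.541170, lon -92.003293) maps to
+ # (lon / 180, lat / 90) = (-0.5111, 0.4171) in this normalized (lon, lat) order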
82
+ print(f"Context points collected: {context_points}")
83
+ return context_points
84
+
85
+ def main(eval_params):
86
+ # load params
87
+ with open('paths.json', 'r') as f:
88
+ paths = json.load(f)
89
+
90
+ ckp_name = os.path.split(eval_params['model_path'])[-1]
91
+ experiment_name = os.path.split(os.path.split(eval_params['model_path'])[-2])[-1]
92
+
93
+ eval_overrides = {'ckp_name':ckp_name,
94
+ 'experiment_name':experiment_name,
95
+ 'device':eval_params['device']}
96
+
97
+
98
+ train_overrides = {'dataset': 'eval_transformer'}
99
+ #grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
100
+ #grit_gpt = torch.load(text_model, map_location='cpu')
101
+ #context_model = torch.load("experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt", map_location=torch.device('cpu'))
102
+ context_data = np.load('data/positive_eval_data.npz')
103
+ text_type_value = 0
104
+
105
+ for pt in eval_params['context_pt_trial']:
106
+ number_of_context_points = pt
107
+ if eval_params['choose_context_points'] == 1:
108
+ #context_points = choose_context_points_from_map(eval_params)
109
+ text_emb, text_type_value = create_inputs_to_fs_sinr.use_pregenerated_textemb_fromchris(taxon_id=eval_params['test_taxa'],
110
+ text_type=eval_params['text_type'])
111
+ context_points = create_inputs_to_fs_sinr.get_eval_context_points(taxa_id=eval_params['test_taxa'],
112
+ context_data=context_data,
113
+ size=number_of_context_points)
114
+ model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embedding_from_given_points(
115
+ context_points=context_points,
116
+ overrides=eval_overrides,
117
+ taxa_of_interest=eval_params['taxa_id'],
118
+ train_overrides=train_overrides,
119
+ text_emb=text_emb)
120
+ #TODO: why is taxa_id updated to 'selected pts'??
121
+ eval_params['taxa_id'] = 'selected_points'
122
+ else:
123
+ model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embeddings(
124
+ overrides=eval_overrides,
125
+ taxa_of_interest=eval_params['taxa_id'],
126
+ num_context=eval_params['num_context'],
127
+ train_overrides=train_overrides)
128
+
129
+ if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
130
+ raster = datasets.load_env()
131
+ else:
132
+ raster = None
133
+ enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim'])
134
+ enc_time = utils.CoordEncoder('sin_cos', raster=None, input_dim=2 * train_params['params']['input_time_dim'])
135
+
136
+ # load ocean mask
137
+ if eval_params['high_res']:
138
+ mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
139
+ else:
140
+ mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
141
+ #mask = 0*mask+1
142
+ mask_inds = np.where(mask.reshape(-1) == 1)[0]
143
+
144
+ # generate input features
145
+ locs = utils.coord_grid(mask.shape)
146
+ if not eval_params['disable_ocean_mask']:
147
+ locs = locs[mask_inds, :]
148
+ locs = torch.from_numpy(locs)
149
+ locs_enc = enc.encode(locs).to(eval_params['device'])
150
+ if train_params['params']['input_time_dim'] > 0:
151
+ extra_input = torch.cat([enc_time.encode(torch.tensor([[0.0]]), normalize=False), torch.tensor([[1.0]])],
152
+ dim=1).to(eval_params['device'])
153
+ locs_enc = torch.cat((locs_enc, extra_input.repeat(locs_enc.shape[0], 1)), dim=1)
154
+
155
+ with torch.no_grad():
156
+ # Here if we set eval to False we will see what the ema embeddings look like (currently as ema is 1.0 this is just the last training example seen)
157
+ preds = model.embedding_forward(x=locs_enc, class_ids=None, return_feats=False, class_of_interest=class_of_interest, eval=True).cpu().numpy()
158
+
159
+ # threshold predictions
160
+ if eval_params['threshold'] > 0:
161
+ print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
162
+ preds[preds<eval_params['threshold']] = 0.0
163
+ preds[preds>=eval_params['threshold']] = 1.0
164
+
165
+ # mask data
+ if not eval_params['disable_ocean_mask']:
+ op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # set to NaN
+ op_im[mask_inds] = preds
+ else:
+ op_im = preds
+
+ # reshape and create masked array for visualization
+ op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
+ op_im = np.ma.masked_invalid(op_im)
+
+ # set color for masked values; use a copy so the globally registered 'plasma'
+ # colormap is not modified in place (masked ocean cells are drawn as transparent)
+ cmap = plt.cm.plasma.copy()
+ cmap.set_bad(color='none')
+ if eval_params['set_max_cmap_to_1']:
+ vmax = 1.0
+ else:
+ vmax = np.max(op_im)
+
+ # # Display the image
+ # if eval_params['show_map'] == 1:
+ # fig, ax = plt.subplots()
+ # cax = ax.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap)
+ # fig.colorbar(cax)
+ # plt.show(block=True) # Set block=True to block code execution until the window is closed
+
+ if eval_params['show_map'] == 1:
+ # Display the image
+ fig, ax = plt.subplots(figsize=(6, 3), dpi=334)
+ plt.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap, interpolation='nearest') # Display the image
+ plt.axis('off') # Turn off the axis
+
+ if eval_params['show_context_points'] == 1:
+ # Convert the tensor to a numpy array if it is not one already
+ context_locs = context_locs_of_interest.numpy() if isinstance(context_locs_of_interest, torch.Tensor) else context_locs_of_interest
+ # Convert context locations directly to image coordinates,
+ # skipping the dummy context point at (0, 0)
+ image_x = (context_locs[1:, 0] + 1) / 2 * op_im.shape[1] # Scale longitude from [-1, 1] to [0, image width]
+ image_y = (1 - (context_locs[1:, 1] + 1) / 2) * op_im.shape[0] # Scale latitude from [-1, 1] to [0, image height]
+
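+ # Each context point is drawn as a small image marker; this assumes a
+ # black_circle.png file is available in the working directory.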
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+ # Plot the context locations
+ def getImage(path):
+ return OffsetImage(plt.imread(path), zoom=0.04)
+
+ for x0, y0 in zip(image_x, image_y):
+ ab = AnnotationBbox(getImage('black_circle.png'), (x0, y0), frameon=False)
+ ax.add_artist(ab)
+ #plt.scatter(image_x, image_y, c='green', s=30, marker=r'$\checkmark$') # Adjust color and size of the point
+
+ #plt.show(block=True) # Block execution until the window is closed
+
+
+ exp_name = eval_params['model_path'].split(os.path.sep)[-2]
+
+ # save image
+ #save_loc = os.path.join(eval_params['op_path'], exp_name + '_' + str(eval_params['taxa_id']) + '_' + eval_params['additional_save_name'] +'_map.png')
+ #save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + eval_params['taxa_id'] + ')_'+ eval_params['text_type'] + '(' + str(text_type_value) + ')_' + str(number_of_context_points) +'.png'
+ save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + str(eval_params['taxa_id']) + ')_' + eval_params['text_type'] + '_' + str(number_of_context_points) + '.png'
+ print(f'Saving image to {save_loc}')
+ plt.savefig(save_loc, bbox_inches='tight', pad_inches=0, dpi=334)
+ # plt.imsave(fname=save_loc, arr=op_im, vmin=0, vmax=vmax, cmap=cmap)
+ plt.show(block=False) # Does not block execution
+
+ return True
+
+
+
+ if __name__ == '__main__':
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ info_str = '\nDemo that takes an iNaturalist taxon ID as input and ' + \
+ 'generates a predicted range for each location on the globe ' + \
+ 'and saves the output as an image.\n\n' + \
+ 'Warning: these estimated ranges should be validated before use.'
+
+ parser = argparse.ArgumentParser(usage=info_str)
+ # parser.add_argument('--model_path', type=str, default='./pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt')
+ # parser.add_argument('--model_path', type=str, default='./experiments/transformer_ema_1.0/model_10.pt')
+ # parser.add_argument('--model_path', type=str, default='./experiments/03_08_coord_multihead.pt/model.pt')
+ # parser.add_argument('--model_path', type=str, default='./experimentvs/coord_context_20_without_registry/model_best.pt')
+ # parser.add_argument('--model_path', type=str, default='./experiments/coord_sinr_inputs_context_20_without_registry/model_best.pt')
+ parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
+ #parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
+ # parser.add_argument('--taxa_id', type=int, default=144575, help='iNaturalist taxon ID.')
+ # parser.add_argument('--taxa_id', type=int, default=9083, help='iNaturalist taxon ID.')
+ parser.add_argument('--taxa_id', type=int, default=3352, help='iNaturalist taxon ID.')
+ parser.add_argument('--threshold', type=float, default=-1, help='Threshold the range map [0, 1].')
+ parser.add_argument('--op_path', type=str, default='./images/', help='Location where the output image will be saved.')
+ parser.add_argument('--rand_taxa', action='store_true', help='Select a random taxon.')
+ parser.add_argument('--high_res', action='store_true', help='Generate higher resolution output.')
+ parser.add_argument('--disable_ocean_mask', action='store_true', help='Do not use an ocean mask.')
+ parser.add_argument('--set_max_cmap_to_1', action='store_true', help='Consistent maximum intensity output.')
+ parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda')
+ #parser.add_argument('--device', type=str, default='cuda:3', help='cpu or cuda')
+ parser.add_argument('--show_map', type=int, default=1, help='Shows the map if 1.')
+ parser.add_argument('--show_context_points', type=int, default=1, help='Also plots context points if 1.')
+ parser.add_argument('--prefix', type=str, default='')
+ parser.add_argument('--num_context', type=int, default=5)
+ parser.add_argument('--choose_context_points', type=int, default=1)
+ parser.add_argument('--additional_save_name', type=str, default="")
+ # taxa: black & white warbler (10286), hyacinth macaw (18938), yellow baboon (67683),
+ # barn swallow (11901), pika (43188), loon (4626), European robin (13094),
+ # southern flying squirrel (46272)
+ parser.add_argument('--taxa_name', type=str, default='sfs', help='Name of the taxon.')
+ parser.add_argument('--test_taxa', type=int, default=46272, help='Taxon ID to test.')
+ parser.add_argument('--text_type', type=str, default='range', help='Type of text input: none, habitat, or range.')
+ parser.add_argument('--context_pt_trial', type=int, nargs='+', default=[0, 1, 2, 5, 10, 20], help='List of context point counts to try.')
+ eval_params = vars(parser.parse_args())
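+ # Example invocation (the taxon IDs and flags below are illustrative; adjust to your setup):
+ #   python viz_ls_map.py --taxa_id 43188 --test_taxa 43188 --taxa_name pika --text_type habitat --context_pt_trial 1 5 20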
+
+ if not os.path.isdir(eval_params['op_path']):
+ os.makedirs(eval_params['op_path'])
+
+ # Note: high-resolution output is forced here, regardless of the --high_res flag.
+ eval_params['high_res'] = True
+
+ main(eval_params)