Spaces:
Sleeping
Sleeping
angelazhu96
commited on
Commit
·
9ff98d7
1
Parent(s):
dcb7cfe
code for viz
Browse files- app.py +87 -0
- create_inputs_to_fs_sinr.py +124 -0
- eval.py +0 -0
- get_gt.py +369 -0
- models.py +1434 -0
- paths.json +10 -0
- requirements.txt +10 -0
- setup.py +0 -0
- utils.py +326 -0
- viz_ls_map.py +283 -0
app.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from viz_ls_map import main
|
3 |
+
from get_gt import generate_ground_truth
|
4 |
+
|
5 |
+
def predict_species_distribution(taxa_id, taxa_name, text_type, num_context_points):
|
6 |
+
"""
|
7 |
+
Function to predict species distribution and visualize the map.
|
8 |
+
"""
|
9 |
+
isSnt = False
|
10 |
+
taxa_id = int(taxa_id)
|
11 |
+
#num_context_points = [0, 1, 2, 5, 10, 20]
|
12 |
+
num_context_points = [1]
|
13 |
+
|
14 |
+
# Generate ground truth for the species
|
15 |
+
generate_ground_truth(taxa_id, isSnt)
|
16 |
+
image_path_gt = f'images/species_presence_hr_{taxa_id}.png'
|
17 |
+
output_images = []
|
18 |
+
#print(num_context_points)
|
19 |
+
|
20 |
+
for text_type_i in ['none','range','habitat']:
|
21 |
+
# Set up evaluation parameters
|
22 |
+
eval_params = {
|
23 |
+
'model_path': './experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt',
|
24 |
+
'taxa_id': taxa_id,
|
25 |
+
'threshold': -1,
|
26 |
+
'op_path': './images/',
|
27 |
+
'rand_taxa': False,
|
28 |
+
'high_res': True,
|
29 |
+
'disable_ocean_mask': False,
|
30 |
+
'set_max_cmap_to_1': False,
|
31 |
+
'device': 'cpu',
|
32 |
+
'show_map': 1,
|
33 |
+
'show_context_points': 1,
|
34 |
+
'prefix': '',
|
35 |
+
'num_context': num_context_points,
|
36 |
+
'choose_context_points': 1,
|
37 |
+
'additional_save_name': "",
|
38 |
+
'taxa_name': taxa_name,
|
39 |
+
'test_taxa': taxa_id,
|
40 |
+
'text_type': text_type_i, # 'none', 'habitat', or 'range'
|
41 |
+
'context_pt_trial': num_context_points,
|
42 |
+
}
|
43 |
+
|
44 |
+
# Run the FS-SINR model with the specified parameters
|
45 |
+
main(eval_params)
|
46 |
+
|
47 |
+
# The output image is saved in './images/' with the predicted range map
|
48 |
+
#image_path = f'./images/{taxa_name}_predicted_range.png'
|
49 |
+
|
50 |
+
for k in num_context_points:
|
51 |
+
# Assume image filenames are stored like this
|
52 |
+
image_path = f'./images/testenv_{taxa_name}(selected_points)_{text_type_i}_{k}.png'
|
53 |
+
output_images.append(image_path)
|
54 |
+
|
55 |
+
|
56 |
+
return [image_path_gt] + output_images
|
57 |
+
#return True
|
58 |
+
|
59 |
+
# Define the Gradio interface
|
60 |
+
with gr.Blocks() as demo:
|
61 |
+
gr.Markdown("# View Species Distribution Predictions using FS-SINR")
|
62 |
+
|
63 |
+
# Input fields for the Gradio interface
|
64 |
+
taxa_id = gr.Number(label="Taxa ID", value=43188)
|
65 |
+
taxa_name = gr.Textbox(label="Taxa Name", value="test_pika")
|
66 |
+
text_type = gr.Radio(label="Text Type", choices=['none', 'habitat', 'range'], value='none')
|
67 |
+
#num_context_points = gr.Slider(label="Number of Context Points", minimum=1, maximum=20, value=5, step=1)
|
68 |
+
num_context_points = gr.CheckboxGroup([0,1,2,3,4,5,10,15,20], label="Number of Context Points")
|
69 |
+
|
70 |
+
# Button to trigger the prediction
|
71 |
+
predict_button = gr.Button("Predict Species Distribution")
|
72 |
+
|
73 |
+
# Output: predicted range map
|
74 |
+
ground_truth = gr.Image(label="Ground Truth Map")
|
75 |
+
none_maps = gr.Image(label=f"Map for No Text Input and Context Point {1}")
|
76 |
+
range_maps = gr.Image(label=f"Map for Range Text input and Context Point {1}")
|
77 |
+
hab_maps = gr.Image(label=f"Map for Habitat Text input and Context Point {1}")
|
78 |
+
output_images = [ground_truth, none_maps, range_maps, hab_maps]
|
79 |
+
|
80 |
+
|
81 |
+
# Link the button to the function and inputs
|
82 |
+
predict_button.click(fn=predict_species_distribution,
|
83 |
+
inputs=[taxa_id, taxa_name, text_type, num_context_points],
|
84 |
+
outputs=output_images)
|
85 |
+
|
86 |
+
# Launch the Gradio interface
|
87 |
+
demo.launch()
|
create_inputs_to_fs_sinr.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import csv
|
3 |
+
from gritlm import GritLM
|
4 |
+
import pandas as pd
|
5 |
+
import ast
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
input_text4 = ['The hyacinth macaw prefers semi-open, somewhat wooded habitats. It usually avoids dense, humid forest, and in regions dominated by such habitats, it is generally restricted to the edge or relatively open sections (e.g. along major rivers). In different areas of their range, these parrots are found in savannah grasslands, in dry thorn forests known as caatinga, and in palm stands or swamps, particularly the moriche palm (Mauritia flexuosa).',
|
9 |
+
'The hyacinth macaw occurs today in three main areas in South America: In the Pantanal region of Brazil, and adjacent eastern Bolivia and northeastern Paraguay, in the cerrado regions of the eastern interior of Brazil (Maranhão, Piauí, Bahia, Tocantins, Goiás, Mato Grosso, Mato Grosso do Sul, and Minas Gerais), and in the relatively open areas associated with the Tocantins River, Xingu River, Tapajós River, and the Marajó island in the eastern Amazon Basin of Brazil.',
|
10 |
+
'They are diurnal, terrestrial, and live in complex, mixed-gender social groups of 8 to 200 individuals per troop. They prefer savannas and light forests with a climate that is suitable for their omnivorous diet.',
|
11 |
+
'Yellow baboons inhabit savannas and light forests in eastern Africa, from Kenya and Tanzania to Zimbabwe and Botswana.']
|
12 |
+
input_text5 = ['chappell roan', 'europe', 'pawpaw',
|
13 |
+
'sierra nevada', 'great lakes', 'Treaty of Waitangi',
|
14 |
+
'hello kitty', 'disney', 'madagascar', 'Andes', 'africa',
|
15 |
+
'dessert', 'whale', 'moon snail', 'unicorn', 'rainfall',
|
16 |
+
'species occurs above 2000m of elevation', 'froyo', 'desert',
|
17 |
+
'dragon', 'bear', 'selkie', 'loch ness monster']
|
18 |
+
|
19 |
+
def extract_grit_token(model, text:str):
|
20 |
+
def gritlm_instruction(instruction):
|
21 |
+
return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
|
22 |
+
d_rep = model.encode([text], instruction=gritlm_instruction(""))
|
23 |
+
d_rep = torch.from_numpy(d_rep)
|
24 |
+
return d_rep
|
25 |
+
|
26 |
+
def generate_text_embs(text, output_file):
|
27 |
+
grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
|
28 |
+
|
29 |
+
with open(output_file, mode='w') as file:
|
30 |
+
writer = csv.writer(file)
|
31 |
+
writer.writerow(['Text', 'Embedding'])
|
32 |
+
for i in range(0, len(text)):
|
33 |
+
text_emb = extract_grit_token(grit, text[i]).to('cpu')
|
34 |
+
print(f" {text[i]}: {text_emb} ")
|
35 |
+
writer.writerow([text[i], text_emb.tolist()])
|
36 |
+
|
37 |
+
#TODO: max's generate text using grit
|
38 |
+
def generate_text_emb(text):
|
39 |
+
grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
|
40 |
+
text_emb = extract_grit_token(grit, text)
|
41 |
+
return text_emb
|
42 |
+
|
43 |
+
def use_pregenerated_textemb_fromgpt(taxon_id):
|
44 |
+
embs_loaded = torch.load('experiments/gpt_data.pt', map_location='cpu')
|
45 |
+
|
46 |
+
emb_ids = embs_loaded['taxon_id'].tolist() #(2785,)
|
47 |
+
keys1 = embs_loaded['keys'] #(11140, 2)
|
48 |
+
embs = embs_loaded['data'] # torch.Size([11140, 4096])
|
49 |
+
print(embs_loaded['taxon_id'].size())
|
50 |
+
|
51 |
+
matching_indices = [i for i, (tid) in enumerate(emb_ids) if tid == taxon_id]
|
52 |
+
print(matching_indices)
|
53 |
+
taxon_embeddings = embs[matching_indices, :] # Get embeddings for the matching indices
|
54 |
+
matching_keys = [keys1[i] for i in matching_indices] # Get the corresponding (taxon_id, text_type) keys
|
55 |
+
|
56 |
+
print(f"Found {len(matching_keys)} embeddings for taxon ID {taxon_id}:")
|
57 |
+
for i, key in enumerate(matching_keys):
|
58 |
+
print(f"Text Type: {key[1]}, Embedding: {taxon_embeddings[i, :]}")
|
59 |
+
|
60 |
+
return taxon_embeddings[i, :]
|
61 |
+
|
62 |
+
def use_pregenerated_textemb_fromchris(taxon_id, text_type):
|
63 |
+
#zero vector is for no text input
|
64 |
+
text_embedding = torch.zeros(1,4096)
|
65 |
+
if text_type is None or text_type == 'none':
|
66 |
+
return text_embedding, 0
|
67 |
+
|
68 |
+
embs1 = torch.load('experiments/gpt_data.pt', map_location='cpu')
|
69 |
+
emb_ids1 = embs1['taxon_id'].tolist()
|
70 |
+
keys1 = embs1['keys']
|
71 |
+
embs1 = embs1['data']
|
72 |
+
|
73 |
+
taxa_of_interest = taxon_id
|
74 |
+
taxa_index_of_interest = emb_ids1.index(taxa_of_interest) # gets 5
|
75 |
+
|
76 |
+
#keys_with_taxa_of_interest = [key for key in keys1 if key[0] == taxa_index_of_interest]
|
77 |
+
#indices_with_taxa_of_interest = [(key, i) for i, key in enumerate(keys1) if key[0] == taxa_index_of_interest]
|
78 |
+
possible_text_embedding_indexes = [i for i, key in enumerate(keys1) if key[0] == taxa_index_of_interest and key[1]==text_type]
|
79 |
+
|
80 |
+
if len(possible_text_embedding_indexes) != 1:
|
81 |
+
return text_embedding, 0
|
82 |
+
# take a look and choose what you want
|
83 |
+
# for key in indices_with_taxa_of_interest:
|
84 |
+
# print(key)
|
85 |
+
|
86 |
+
# ((5, 'range'), 20)
|
87 |
+
# ((5, 'habitat'), 21)
|
88 |
+
# ((5, 'species_description'), 22)
|
89 |
+
# ((5, 'overview_summary'), 23)
|
90 |
+
|
91 |
+
#macaw: range: 20, habitat: 21
|
92 |
+
#baboon: range: 7928, habitat: 7929
|
93 |
+
#black&white warbler: range: 16, habitat: 17
|
94 |
+
#barn swallow: range: 1652, habitat: 1653
|
95 |
+
#pika: range: 7116, habitat: 7117
|
96 |
+
#loon: range: 11056, habitat:11057
|
97 |
+
#euro robin: range: 2020, habitat: 2021
|
98 |
+
#sfs: range: 7148, habitat: 7149
|
99 |
+
text_embedding_index = possible_text_embedding_indexes[0]
|
100 |
+
text_embedding = embs1[text_embedding_index].unsqueeze(0)
|
101 |
+
#print(text_embedding_index)
|
102 |
+
return text_embedding, text_embedding_index
|
103 |
+
|
104 |
+
def use_pregenerated_textemb_fromcsv(input_text):
|
105 |
+
text_data = pd.read_csv('data/text_embs/text_embeddings_fig4.csv')
|
106 |
+
result_row = text_data[text_data['Text'] == input_text]
|
107 |
+
text_emb = ast.literal_eval(result_row['Embedding'].values[0])
|
108 |
+
embedding_tensor = torch.FloatTensor(text_emb)
|
109 |
+
return embedding_tensor
|
110 |
+
|
111 |
+
def get_eval_context_points(taxa_id, context_data, size):
|
112 |
+
all_context_pts = context_data['locs'][context_data['labels'] == np.argwhere(context_data['class_to_taxa'] == taxa_id)[0]][1:]
|
113 |
+
context_pts = all_context_pts[0:size]
|
114 |
+
dummy_classtoken = np.array([[0,0]])
|
115 |
+
context_pts = np.vstack((dummy_classtoken, context_pts))
|
116 |
+
#print(f"context point shape: {np.shape(context_pts)}")
|
117 |
+
normalized_pts = torch.from_numpy(context_pts) * torch.tensor([[1/180,1/90]], device='cpu')
|
118 |
+
|
119 |
+
return normalized_pts
|
120 |
+
|
121 |
+
if __name__ == '__main__':
|
122 |
+
print('starting to generate text_embs')
|
123 |
+
output_file = './data/text_embs/text_embeddings_fig4.csv'
|
124 |
+
use_pregenerated_textemb_fromchris()
|
eval.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
get_gt.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import numpy as np
|
2 |
+
# import h3
|
3 |
+
# import json
|
4 |
+
# import os
|
5 |
+
#
|
6 |
+
# snt=False
|
7 |
+
#
|
8 |
+
# def get_labels(species, data):
|
9 |
+
# species = str(species)
|
10 |
+
# lat = []
|
11 |
+
# lon = []
|
12 |
+
# gt = []
|
13 |
+
# for hx in data:
|
14 |
+
# cur_lat, cur_lon = h3.h3_to_geo(hx)
|
15 |
+
# if species in data[hx]:
|
16 |
+
# cur_label = int(len(data[hx][species]) > 0)
|
17 |
+
# gt.append(cur_label)
|
18 |
+
# lat.append(cur_lat)
|
19 |
+
# lon.append(cur_lon)
|
20 |
+
# lat = np.array(lat).astype(np.float32)
|
21 |
+
# lon = np.array(lon).astype(np.float32)
|
22 |
+
# obs_locs = np.vstack((lon, lat)).T
|
23 |
+
# gt = np.array(gt).astype(np.float32)
|
24 |
+
# return obs_locs, gt
|
25 |
+
#
|
26 |
+
# def lonlat_to_pixel(lonlat, grid_width, grid_height):
|
27 |
+
# # Convert normalized lon/lat (-1 to 1) to pixel coordinates
|
28 |
+
# x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int)
|
29 |
+
# y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int)
|
30 |
+
# return x_pixel, y_pixel
|
31 |
+
#
|
32 |
+
# ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
|
33 |
+
# # 1002, 2004 pixels
|
34 |
+
# # 0 in ocean (needs to be masked out)
|
35 |
+
#
|
36 |
+
# if snt:
|
37 |
+
# with open('paths.json', 'r') as f:
|
38 |
+
# paths = json.load(f)
|
39 |
+
# D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
|
40 |
+
# D = D.item()
|
41 |
+
# loc_indices_per_species = D['loc_indices_per_species']
|
42 |
+
# labels_per_species = D['labels_per_species']
|
43 |
+
# taxa = D['taxa']
|
44 |
+
# obs_locs = D['obs_locs']
|
45 |
+
# obs_locs_idx = D['obs_locs_idx']
|
46 |
+
# else:
|
47 |
+
# with open('paths.json', 'r') as f:
|
48 |
+
# paths = json.load(f)
|
49 |
+
# with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
|
50 |
+
# data = json.load(f)
|
51 |
+
# obs_locs = np.array(data['locs'], dtype=np.float32)
|
52 |
+
# taxa = [int(tt) for tt in data['taxa_presence'].keys()]
|
53 |
+
# a = 6
|
54 |
+
# # data['taxa_presence'] is a dict where keys are "taxa" and then the values are the indices of "obs_locs" where the species is present
|
55 |
+
# # obs locs is in lon, lat with -180 to 180 and -90 to 90
|
56 |
+
|
57 |
+
import numpy as np
|
58 |
+
import h3
|
59 |
+
import json
|
60 |
+
import os
|
61 |
+
import matplotlib.pyplot as plt
|
62 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
63 |
+
|
64 |
+
|
65 |
+
def get_labels(species, data):
|
66 |
+
species = str(species)
|
67 |
+
lat = []
|
68 |
+
lon = []
|
69 |
+
gt = []
|
70 |
+
for hx in data:
|
71 |
+
cur_lat, cur_lon = h3.h3_to_geo(hx)
|
72 |
+
if species in data[hx]:
|
73 |
+
cur_label = int(len(data[hx][species]) > 0)
|
74 |
+
gt.append(cur_label)
|
75 |
+
lat.append(cur_lat)
|
76 |
+
lon.append(cur_lon)
|
77 |
+
lat = np.array(lat).astype(np.float32)
|
78 |
+
lon = np.array(lon).astype(np.float32)
|
79 |
+
obs_locs = np.vstack((lon, lat)).T
|
80 |
+
gt = np.array(gt).astype(np.float32)
|
81 |
+
return obs_locs, gt
|
82 |
+
|
83 |
+
def lonlat_to_pixel(lonlat, grid_width, grid_height):
|
84 |
+
# Convert normalized lon/lat (-1 to 1) to pixel coordinates
|
85 |
+
x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int)
|
86 |
+
y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int)
|
87 |
+
return x_pixel, y_pixel
|
88 |
+
|
89 |
+
# def plot_heatmap(data,save_loc):
|
90 |
+
# # Apply mask if provided
|
91 |
+
# ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
|
92 |
+
# # 1002, 2004 pixels
|
93 |
+
# # 0 in ocean (needs to be masked out)
|
94 |
+
#
|
95 |
+
# # Convert ocean_mask to boolean mask
|
96 |
+
# mask = ocean_mask.astype(bool)
|
97 |
+
# mask = mask[::2, ::2]
|
98 |
+
#
|
99 |
+
# if mask is not None:
|
100 |
+
# data = np.where(mask, data, 0)
|
101 |
+
#
|
102 |
+
# # Set NaN values to 0 for plotting
|
103 |
+
# data = np.nan_to_num(data, nan=0.0)
|
104 |
+
#
|
105 |
+
# fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
|
106 |
+
# ax.set_xlim(-180, 180)
|
107 |
+
# ax.set_ylim(-90, 90)
|
108 |
+
# ax.axis('off')
|
109 |
+
#
|
110 |
+
# # Use 'magma' colormap with two discrete colors
|
111 |
+
# cmap = plt.get_cmap('magma', 2)
|
112 |
+
# cmap.set_bad(color='none')
|
113 |
+
# plt.rcParams['font.family'] = 'serif'
|
114 |
+
#
|
115 |
+
# cax_im = ax.imshow(data, extent=(-180, 180, -90, 90), origin='upper', cmap=cmap, vmin=0, vmax=1)
|
116 |
+
#
|
117 |
+
# plt.tight_layout()
|
118 |
+
# pdf_save_loc = save_loc + '.pdf'
|
119 |
+
# png_save_loc = save_loc + '.png'
|
120 |
+
# plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
|
121 |
+
# plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
|
122 |
+
# plt.close(fig)
|
123 |
+
|
124 |
+
def plot_heatmap(data, save_loc):
|
125 |
+
# Load the ocean mask
|
126 |
+
ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
|
127 |
+
# 1002, 2004 pixels
|
128 |
+
# 0 in ocean (needs to be masked out)
|
129 |
+
|
130 |
+
# Convert ocean_mask to boolean mask
|
131 |
+
mask = ocean_mask.astype(bool)
|
132 |
+
# If you need to downsample the mask, uncomment the following line
|
133 |
+
mask = mask[::2, ::2]
|
134 |
+
|
135 |
+
# Set ocean areas to np.nan
|
136 |
+
data = np.where(mask, data, np.nan)
|
137 |
+
|
138 |
+
# Create a masked array where NaNs are masked
|
139 |
+
data_masked = np.ma.array(data, mask=np.isnan(data))
|
140 |
+
|
141 |
+
fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
|
142 |
+
ax.set_xlim(-180, 180)
|
143 |
+
ax.set_ylim(-90, 90)
|
144 |
+
ax.axis('off')
|
145 |
+
|
146 |
+
# Use 'magma' colormap with two discrete colors
|
147 |
+
cmap = plt.get_cmap('plasma', 2)
|
148 |
+
# Set color for masked (NaN) values
|
149 |
+
cmap.set_bad(color='none') # 'none' makes it transparent; use 'white' for white background
|
150 |
+
|
151 |
+
# Plot the data
|
152 |
+
cax_im = ax.imshow(
|
153 |
+
data_masked,
|
154 |
+
extent=(-180, 180, -90, 90),
|
155 |
+
origin='upper',
|
156 |
+
cmap=cmap,
|
157 |
+
vmin=0,
|
158 |
+
vmax=1,
|
159 |
+
interpolation='nearest'
|
160 |
+
)
|
161 |
+
|
162 |
+
plt.tight_layout()
|
163 |
+
pdf_save_loc = save_loc + '.pdf'
|
164 |
+
png_save_loc = save_loc + '.png'
|
165 |
+
plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
|
166 |
+
plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
|
167 |
+
plt.close(fig)
|
168 |
+
|
169 |
+
def plot_heatmap_2(data, save_loc):
|
170 |
+
# Load the ocean mask
|
171 |
+
ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
|
172 |
+
# 1002, 2004 pixels
|
173 |
+
# 0 in ocean (needs to be masked out)
|
174 |
+
|
175 |
+
# Convert ocean_mask to boolean mask
|
176 |
+
mask = ocean_mask.astype(bool)
|
177 |
+
# If you need to downsample the mask, uncomment the following line
|
178 |
+
|
179 |
+
# Set ocean areas to np.nan
|
180 |
+
data = np.where(mask, data, np.nan)
|
181 |
+
|
182 |
+
# Create a masked array where NaNs are masked
|
183 |
+
data_masked = np.ma.array(data, mask=np.isnan(data))
|
184 |
+
|
185 |
+
fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
|
186 |
+
ax.set_xlim(-180, 180)
|
187 |
+
ax.set_ylim(-90, 90)
|
188 |
+
ax.axis('off')
|
189 |
+
|
190 |
+
# Use 'magma' colormap with two discrete colors
|
191 |
+
cmap = plt.get_cmap('plasma', 2)
|
192 |
+
# Set color for masked (NaN) values
|
193 |
+
cmap.set_bad(color='none') # 'none' makes it transparent; use 'white' for white background
|
194 |
+
|
195 |
+
# Plot the data
|
196 |
+
cax_im = ax.imshow(
|
197 |
+
data_masked,
|
198 |
+
extent=(-180, 180, -90, 90),
|
199 |
+
origin='upper',
|
200 |
+
cmap=cmap,
|
201 |
+
vmin=0,
|
202 |
+
vmax=1,
|
203 |
+
interpolation='nearest'
|
204 |
+
)
|
205 |
+
|
206 |
+
plt.tight_layout()
|
207 |
+
pdf_save_loc = save_loc + '.pdf'
|
208 |
+
png_save_loc = save_loc + '.png'
|
209 |
+
plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
|
210 |
+
plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
|
211 |
+
plt.show(block=False)
|
212 |
+
plt.close(fig)
|
213 |
+
|
214 |
+
def generate_ground_truth(taxa_id, snt=True, grid_height=501, grid_width=1002):
|
215 |
+
print(taxa_id)
|
216 |
+
if snt:
|
217 |
+
with open('paths.json', 'r') as f:
|
218 |
+
paths = json.load(f)
|
219 |
+
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
|
220 |
+
D = D.item()
|
221 |
+
loc_indices_per_species = D['loc_indices_per_species']
|
222 |
+
labels_per_species = D['labels_per_species']
|
223 |
+
taxa = D['taxa']
|
224 |
+
obs_locs = D['obs_locs']
|
225 |
+
obs_locs_idx = D['obs_locs_idx']
|
226 |
+
# class_index = np.where(taxa==taxa_id)
|
227 |
+
# class_index = class_index[0]
|
228 |
+
# class_index = class_index[0]
|
229 |
+
# species_loc_indices = loc_indices_per_species[class_index]
|
230 |
+
# species_locs = obs_locs[species_loc_indices]
|
231 |
+
# presence_indices = labels_per_species[class_index]
|
232 |
+
# species_locs = species_locs[presence_indices==1]
|
233 |
+
|
234 |
+
# Ensure class_index is correctly obtained as an integer index
|
235 |
+
class_indices = np.where(taxa == taxa_id)[0]
|
236 |
+
if len(class_indices) == 0:
|
237 |
+
raise ValueError(f"taxa_id {taxa_id} not found in taxa")
|
238 |
+
class_index = class_indices[0]
|
239 |
+
|
240 |
+
# Convert loc_indices_per_species[class_index] to a NumPy array
|
241 |
+
species_loc_indices = np.array(loc_indices_per_species[class_index])
|
242 |
+
|
243 |
+
# Retrieve the species locations using the indices
|
244 |
+
species_locs = obs_locs[species_loc_indices]
|
245 |
+
|
246 |
+
# Convert labels_per_species[class_index] to a NumPy array
|
247 |
+
presence_indices = np.array(labels_per_species[class_index])
|
248 |
+
|
249 |
+
# Filter species_locs where presence_indices == 1
|
250 |
+
species_locs = species_locs[presence_indices == 1]
|
251 |
+
|
252 |
+
else:
|
253 |
+
with open('paths.json', 'r') as f:
|
254 |
+
paths = json.load(f)
|
255 |
+
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
|
256 |
+
data = json.load(f)
|
257 |
+
obs_locs = np.array(data['locs'], dtype=np.float32)
|
258 |
+
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
|
259 |
+
indices = data['taxa_presence'][str(taxa_id)]
|
260 |
+
species_locs = obs_locs[indices] # shape (N, 2)
|
261 |
+
|
262 |
+
|
263 |
+
# Normalize lonlat
|
264 |
+
species_locs_normalized = species_locs.copy()
|
265 |
+
species_locs_normalized[:, 0] = species_locs_normalized[:, 0] / 180 # lon / 180
|
266 |
+
species_locs_normalized[:, 1] = species_locs_normalized[:, 1] / 90 # lat / 90# Get grid dimensions from ocean_mas
|
267 |
+
|
268 |
+
|
269 |
+
# Get pixel coordinates
|
270 |
+
x_pixel, y_pixel = lonlat_to_pixel(species_locs_normalized, grid_width, grid_height)
|
271 |
+
|
272 |
+
# Ensure x_pixel and y_pixel are within bounds
|
273 |
+
x_pixel = np.clip(x_pixel, 0, grid_width - 1)
|
274 |
+
y_pixel = np.clip(y_pixel, 0, grid_height - 1)
|
275 |
+
|
276 |
+
# Create data array
|
277 |
+
data_array = np.zeros((grid_height, grid_width))
|
278 |
+
|
279 |
+
# Set pixels where species is present
|
280 |
+
data_array[y_pixel, x_pixel] = 1
|
281 |
+
|
282 |
+
# Now call plot_heatmap
|
283 |
+
title = f"Species presence for taxa {taxa_id}"
|
284 |
+
save_loc = f"./images/species_presence_{taxa_id}"
|
285 |
+
plot_heatmap(data_array, save_loc)
|
286 |
+
|
287 |
+
grid_height = 1002
|
288 |
+
grid_width = 2004
|
289 |
+
|
290 |
+
if snt:
|
291 |
+
with open('paths.json', 'r') as f:
|
292 |
+
paths = json.load(f)
|
293 |
+
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
|
294 |
+
D = D.item()
|
295 |
+
loc_indices_per_species = D['loc_indices_per_species']
|
296 |
+
labels_per_species = D['labels_per_species']
|
297 |
+
taxa = D['taxa']
|
298 |
+
obs_locs = D['obs_locs']
|
299 |
+
obs_locs_idx = D['obs_locs_idx']
|
300 |
+
# class_index = np.where(taxa==taxa_id)
|
301 |
+
# class_index = class_index[0]
|
302 |
+
# class_index = class_index[0]
|
303 |
+
# species_loc_indices = loc_indices_per_species[class_index]
|
304 |
+
# species_locs = obs_locs[species_loc_indices]
|
305 |
+
# presence_indices = labels_per_species[class_index]
|
306 |
+
# species_locs = species_locs[presence_indices==1]
|
307 |
+
|
308 |
+
# Ensure class_index is correctly obtained as an integer index
|
309 |
+
class_indices = np.where(taxa == taxa_id)[0]
|
310 |
+
if len(class_indices) == 0:
|
311 |
+
raise ValueError(f"taxa_id {taxa_id} not found in taxa")
|
312 |
+
class_index = class_indices[0]
|
313 |
+
|
314 |
+
# Convert loc_indices_per_species[class_index] to a NumPy array
|
315 |
+
species_loc_indices = np.array(loc_indices_per_species[class_index])
|
316 |
+
|
317 |
+
# Retrieve the species locations using the indices
|
318 |
+
species_locs = obs_locs[species_loc_indices]
|
319 |
+
|
320 |
+
# Convert labels_per_species[class_index] to a NumPy array
|
321 |
+
presence_indices = np.array(labels_per_species[class_index])
|
322 |
+
|
323 |
+
# Filter species_locs where presence_indices == 1
|
324 |
+
species_locs = species_locs[presence_indices == 1]
|
325 |
+
|
326 |
+
else:
|
327 |
+
with open('paths.json', 'r') as f:
|
328 |
+
paths = json.load(f)
|
329 |
+
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
|
330 |
+
data = json.load(f)
|
331 |
+
obs_locs = np.array(data['locs'], dtype=np.float32)
|
332 |
+
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
|
333 |
+
indices = data['taxa_presence'][str(taxa_id)]
|
334 |
+
species_locs = obs_locs[indices] # shape (N, 2)
|
335 |
+
|
336 |
+
|
337 |
+
# Normalize lonlat
|
338 |
+
species_locs_normalized = species_locs.copy()
|
339 |
+
species_locs_normalized[:, 0] = species_locs_normalized[:, 0] / 180 # lon / 180
|
340 |
+
species_locs_normalized[:, 1] = species_locs_normalized[:, 1] / 90 # lat / 90# Get grid dimensions from ocean_mas
|
341 |
+
|
342 |
+
|
343 |
+
# Get pixel coordinates
|
344 |
+
x_pixel, y_pixel = lonlat_to_pixel(species_locs_normalized, grid_width, grid_height)
|
345 |
+
|
346 |
+
# Ensure x_pixel and y_pixel are within bounds
|
347 |
+
x_pixel = np.clip(x_pixel, 0, grid_width - 1)
|
348 |
+
y_pixel = np.clip(y_pixel, 0, grid_height - 1)
|
349 |
+
|
350 |
+
# Create data array
|
351 |
+
data_array = np.zeros((grid_height, grid_width))
|
352 |
+
|
353 |
+
# Set pixels where species is present
|
354 |
+
data_array[y_pixel, x_pixel] = 1
|
355 |
+
|
356 |
+
# Now call plot_heatmap
|
357 |
+
title = f"Species presence for taxa {taxa_id}"
|
358 |
+
save_loc = f"./images/species_presence_hr_{taxa_id}"
|
359 |
+
plot_heatmap_2(data_array, save_loc)
|
360 |
+
return True
|
361 |
+
|
362 |
+
if __name__ == '__main__':
|
363 |
+
snt = True
|
364 |
+
grid_height = 501
|
365 |
+
grid_width = 1002
|
366 |
+
taxa_id = 11901 # Or any taxa id you want to plot, as string
|
367 |
+
|
368 |
+
#TODO: why snt true? can't generate gt for (hyacinth macaw(18938), yellow baboon(67683), pika(43188), southernflyingsquirrel (46272))
|
369 |
+
generate_ground_truth(taxa_id=taxa_id, snt=snt, grid_height=grid_height, grid_width=grid_width)
|
models.py
ADDED
@@ -0,0 +1,1434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils.data
|
3 |
+
import torch.nn as nn
|
4 |
+
import math
|
5 |
+
import csv
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
|
10 |
+
|
11 |
+
def get_model(params, inference_only=False):
|
12 |
+
if params['model'] == 'ResidualFCNet':
|
13 |
+
return ResidualFCNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'] + (20 if 'env' in params['loss'] else 0), params['num_filts'], params['depth'])
|
14 |
+
elif params['model'] == 'LinNet':
|
15 |
+
return LinNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'])
|
16 |
+
elif params['model'] == 'HyperNet':
|
17 |
+
return HyperNet(params, params['input_dim'] + (20 if 'env' in params['input_enc'] else 0), params['num_classes'], params['num_filts'], params['depth'],
|
18 |
+
params['species_dim'], params['species_enc_depth'], params['species_filts'], params['species_enc'], inference_only=inference_only)
|
19 |
+
# chris models
|
20 |
+
elif params['model'] == 'MultiInputModel':
|
21 |
+
return MultiInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
|
22 |
+
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
|
23 |
+
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
|
24 |
+
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
|
25 |
+
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
|
26 |
+
sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False,
|
27 |
+
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
|
28 |
+
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
|
29 |
+
text_inputs=params['use_text_inputs'], class_token_transformation=params['class_token_transformation'])
|
30 |
+
elif params['model'] == 'VariableInputModel':
|
31 |
+
return VariableInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
|
32 |
+
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
|
33 |
+
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
|
34 |
+
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
|
35 |
+
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
|
36 |
+
sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False,
|
37 |
+
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
|
38 |
+
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
|
39 |
+
text_inputs=params['use_text_inputs'], image_inputs=params['use_image_inputs'],
|
40 |
+
env_inputs=params['use_env_inputs'],
|
41 |
+
class_token_transformation=params['class_token_transformation'])
|
42 |
+
|
43 |
+
# class VariableInputModel(nn.Module):
|
44 |
+
# def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
|
45 |
+
# nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
|
46 |
+
# sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
|
47 |
+
# text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'):
|
48 |
+
|
49 |
+
|
50 |
+
class ResLayer(nn.Module):
|
51 |
+
def __init__(self, linear_size, activation=nn.ReLU, p=0.5):
|
52 |
+
super(ResLayer, self).__init__()
|
53 |
+
self.l_size = linear_size
|
54 |
+
self.nonlin1 = activation()
|
55 |
+
self.nonlin2 = activation()
|
56 |
+
self.dropout1 = nn.Dropout(p=p)
|
57 |
+
self.w1 = nn.Linear(self.l_size, self.l_size)
|
58 |
+
self.w2 = nn.Linear(self.l_size, self.l_size)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
y = self.w1(x)
|
62 |
+
y = self.nonlin1(y)
|
63 |
+
y = self.dropout1(y)
|
64 |
+
y = self.w2(y)
|
65 |
+
y = self.nonlin2(y)
|
66 |
+
out = x + y
|
67 |
+
return out
|
68 |
+
|
69 |
+
|
70 |
+
class ResidualFCNet(nn.Module):
|
71 |
+
def __init__(self, num_inputs, num_classes, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
|
72 |
+
super(ResidualFCNet, self).__init__()
|
73 |
+
self.inc_bias = False
|
74 |
+
if lowrank < num_filts and lowrank != 0:
|
75 |
+
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
|
76 |
+
l2 = nn.Linear(lowrank, num_classes, bias=self.inc_bias)
|
77 |
+
self.class_emb = nn.Sequential(l1, l2)
|
78 |
+
else:
|
79 |
+
self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
|
80 |
+
if nonlin == 'relu':
|
81 |
+
activation = nn.ReLU
|
82 |
+
elif nonlin == 'silu':
|
83 |
+
activation = nn.SiLU
|
84 |
+
else:
|
85 |
+
raise NotImplementedError('Invalid nonlinearity specified.')
|
86 |
+
layers = []
|
87 |
+
if depth != -1:
|
88 |
+
layers.append(nn.Linear(num_inputs, num_filts))
|
89 |
+
layers.append(activation())
|
90 |
+
for i in range(depth):
|
91 |
+
layers.append(ResLayer(num_filts, activation=activation))
|
92 |
+
else:
|
93 |
+
layers.append(nn.Identity())
|
94 |
+
self.feats = torch.nn.Sequential(*layers)
|
95 |
+
|
96 |
+
def forward(self, x, class_of_interest=None, return_feats=False):
|
97 |
+
loc_emb = self.feats(x)
|
98 |
+
if return_feats:
|
99 |
+
return loc_emb
|
100 |
+
if class_of_interest is None:
|
101 |
+
class_pred = self.class_emb(loc_emb)
|
102 |
+
else:
|
103 |
+
class_pred = self.eval_single_class(loc_emb, class_of_interest), self.eval_single_class(loc_emb, -1)
|
104 |
+
return torch.sigmoid(class_pred[0]), torch.sigmoid(class_pred[1])
|
105 |
+
return torch.sigmoid(class_pred)
|
106 |
+
|
107 |
+
def eval_single_class(self, x, class_of_interest):
|
108 |
+
if self.inc_bias:
|
109 |
+
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
|
110 |
+
else:
|
111 |
+
return x @ self.class_emb.weight[class_of_interest, :]
|
112 |
+
|
113 |
+
|
114 |
+
class SimpleFCNet(ResidualFCNet):
|
115 |
+
def forward(self, x, return_feats=True):
|
116 |
+
assert return_feats
|
117 |
+
loc_emb = self.feats(x)
|
118 |
+
class_pred = self.class_emb(loc_emb)
|
119 |
+
return class_pred
|
120 |
+
|
121 |
+
|
122 |
+
class MockTransformer(nn.Module):
|
123 |
+
def __init__(self, num_classes, num_dims):
|
124 |
+
super(MockTransformer, self).__init__()
|
125 |
+
self.species_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_dims)
|
126 |
+
|
127 |
+
def forward(self, class_ids):
|
128 |
+
return self.species_emb(class_ids)
|
129 |
+
|
130 |
+
|
131 |
+
class CombinedModel(nn.Module):
|
132 |
+
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1):
|
133 |
+
super(CombinedModel, self).__init__()
|
134 |
+
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank)
|
135 |
+
if lowrank < num_filts and lowrank != 0:
|
136 |
+
self.transformer_model = MockTransformer(num_classes, lowrank)
|
137 |
+
else:
|
138 |
+
self.transformer_model = MockTransformer(num_classes, num_filts)
|
139 |
+
self.ema_factor = ema_factor
|
140 |
+
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=lowrank if (lowrank < num_filts and lowrank != 0) else num_filts)
|
141 |
+
self.ema_embeddings.weight.data.copy_(self.transformer_model.species_emb.weight.data) # Initialize EMA with the same values as transformer
|
142 |
+
# this will have to change when I start using the actual transformer
|
143 |
+
|
144 |
+
def forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None):
|
145 |
+
# Process input through the headless model to get feature embeddings
|
146 |
+
feature_embeddings = self.headless_model(x)
|
147 |
+
|
148 |
+
if return_feats:
|
149 |
+
return feature_embeddings
|
150 |
+
else:
|
151 |
+
if class_of_interest == None:
|
152 |
+
# Get class-specific embeddings based on class_ids
|
153 |
+
class_embeddings = self.transformer_model(class_ids)
|
154 |
+
if return_class_embeddings:
|
155 |
+
return class_embeddings
|
156 |
+
else:
|
157 |
+
# Update EMA embeddings for these class IDs
|
158 |
+
if self.training:
|
159 |
+
self.update_ema_embeddings(class_ids, class_embeddings)
|
160 |
+
|
161 |
+
# Matrix multiplication to produce logits
|
162 |
+
logits = feature_embeddings @ class_embeddings.T
|
163 |
+
|
164 |
+
# Apply sigmoid to convert logits to probabilities
|
165 |
+
probabilities = torch.sigmoid(logits)
|
166 |
+
|
167 |
+
return probabilities
|
168 |
+
else:
|
169 |
+
device = self.ema_embeddings.weight.device
|
170 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
171 |
+
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
|
172 |
+
print(f'using EMA estimate for class {class_of_interest}')
|
173 |
+
if return_class_embeddings:
|
174 |
+
return class_embedding
|
175 |
+
else:
|
176 |
+
# Matrix multiplication to produce logits
|
177 |
+
logits = feature_embeddings @ class_embedding.T
|
178 |
+
|
179 |
+
# Apply sigmoid to convert logits to probabilities
|
180 |
+
probabilities = torch.sigmoid(logits)
|
181 |
+
probabilities = probabilities.squeeze()
|
182 |
+
|
183 |
+
return probabilities
|
184 |
+
|
185 |
+
def update_ema_embeddings(self, class_ids, current_embeddings):
|
186 |
+
if self.training:
|
187 |
+
# Get current EMA embeddings for the class IDs
|
188 |
+
ema_current = self.ema_embeddings(class_ids)
|
189 |
+
|
190 |
+
# Calculate new EMA values
|
191 |
+
ema_new = self.ema_factor * current_embeddings + (1 - self.ema_factor) * ema_current
|
192 |
+
|
193 |
+
# Update the EMA embeddings
|
194 |
+
self.ema_embeddings.weight.data[class_ids] = ema_new.detach() # Detach to prevent gradients from flowing here
|
195 |
+
|
196 |
+
def get_ema_embeddings(self, class_ids):
|
197 |
+
# Method to access EMA embeddings
|
198 |
+
return self.ema_embeddings(class_ids)
|
199 |
+
|
200 |
+
class HeadlessSINR(nn.Module):
|
201 |
+
def __init__(self, num_inputs, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
|
202 |
+
super(HeadlessSINR, self).__init__()
|
203 |
+
self.inc_bias = False
|
204 |
+
self.low_rank_feats = None
|
205 |
+
if lowrank < num_filts and lowrank != 0:
|
206 |
+
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
|
207 |
+
self.low_rank_feats = l1
|
208 |
+
# else:
|
209 |
+
# self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
|
210 |
+
if nonlin == 'relu':
|
211 |
+
activation = nn.ReLU
|
212 |
+
elif nonlin == 'silu':
|
213 |
+
activation = nn.SiLU
|
214 |
+
else:
|
215 |
+
raise NotImplementedError('Invalid nonlinearity specified.')
|
216 |
+
|
217 |
+
# Create the layers list for feature extraction
|
218 |
+
layers = []
|
219 |
+
if depth != -1:
|
220 |
+
layers.append(nn.Linear(num_inputs, num_filts))
|
221 |
+
layers.append(activation())
|
222 |
+
for i in range(depth):
|
223 |
+
layers.append(ResLayer(num_filts, activation=activation, p=dropout_p))
|
224 |
+
else:
|
225 |
+
layers.append(nn.Identity())
|
226 |
+
# Include low-rank features in the sequential model if it is defined
|
227 |
+
if self.low_rank_feats:
|
228 |
+
# Apply initial layers then low-rank features
|
229 |
+
layers.append(self.low_rank_feats)
|
230 |
+
# Set up the features as a sequential model
|
231 |
+
self.feats = nn.Sequential(*layers)
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
loc_emb = self.feats(x)
|
235 |
+
return loc_emb
|
236 |
+
|
237 |
+
|
238 |
+
class TransformerEncoderModel(nn.Module):
|
239 |
+
def __init__(self, d_model=256, nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, activation='relu',
|
240 |
+
batch_first=True, output_dim=256): # BATCH FIRST MIGHT HAVE TO CHANGE
|
241 |
+
super(TransformerEncoderModel, self).__init__()
|
242 |
+
self.input_layer_norm = nn.LayerNorm(normalized_shape=d_model)
|
243 |
+
# Create an encoder layer
|
244 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
245 |
+
d_model=d_model,
|
246 |
+
nhead=nhead,
|
247 |
+
dim_feedforward=dim_feedforward,
|
248 |
+
dropout=dropout,
|
249 |
+
activation=activation,
|
250 |
+
batch_first=batch_first
|
251 |
+
)
|
252 |
+
|
253 |
+
# Stack the encoder layers into an encoder module
|
254 |
+
self.transformer_encoder = nn.TransformerEncoder(
|
255 |
+
encoder_layer=encoder_layer,
|
256 |
+
num_layers=num_encoder_layers
|
257 |
+
)
|
258 |
+
|
259 |
+
# Example output layer (modify according to your needs)
|
260 |
+
self.output_layer = nn.Linear(d_model, output_dim)
|
261 |
+
|
262 |
+
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
263 |
+
"""
|
264 |
+
Args:
|
265 |
+
src: the sequence to the encoder (shape: [seq_length, batch_size, d_model])
|
266 |
+
src_mask: the mask for the src sequence (shape: [seq_length, seq_length])
|
267 |
+
src_key_padding_mask: the mask for the padding tokens (shape: [batch_size, seq_length])
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
output of the transformer encoder
|
271 |
+
"""
|
272 |
+
# Pass the input through the transformer encoder
|
273 |
+
encoder_input = self.input_layer_norm(src)
|
274 |
+
encoder_output = self.transformer_encoder(encoder_input, src_key_padding_mask=src_key_padding_mask, mask=src_mask)
|
275 |
+
|
276 |
+
# # Pass the encoder output through the output layer
|
277 |
+
# output = self.output_layer(encoder_output)
|
278 |
+
|
279 |
+
# Assuming the class token is the first in the sequence
|
280 |
+
# batch_first so we have (batch, sequence, dim)
|
281 |
+
if encoder_output.ndim == 2:
|
282 |
+
# in situations where we don't have a batch
|
283 |
+
encoder_output = encoder_output.unsqueeze(0)
|
284 |
+
|
285 |
+
class_token_embedding = encoder_output[:, 0, :]
|
286 |
+
|
287 |
+
output = self.output_layer(class_token_embedding) # Process only the class token embedding
|
288 |
+
return output
|
289 |
+
|
290 |
+
|
291 |
+
class MultiInputModel(nn.Module):
|
292 |
+
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
|
293 |
+
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
|
294 |
+
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
|
295 |
+
text_inputs=False, class_token_transformation='identity'):
|
296 |
+
super(MultiInputModel, self).__init__()
|
297 |
+
|
298 |
+
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
|
299 |
+
self.ema_factor = ema_factor
|
300 |
+
self.class_token_transformation = class_token_transformation
|
301 |
+
|
302 |
+
# Load pretrained state_dict if use_pretrained_sinr is set to True
|
303 |
+
if use_pretrained_sinr:
|
304 |
+
#pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
|
305 |
+
pretrained_state_dict = torch.load(pretrained_loc, map_location=torch.device('cpu'))['state_dict']
|
306 |
+
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
|
307 |
+
self.headless_model.load_state_dict(filtered_state_dict, strict=False)
|
308 |
+
#print(f'Using pretrained sinr from {pretrained_loc}')
|
309 |
+
|
310 |
+
# Freeze the SINR model if freeze_sinr is set to True
|
311 |
+
if freeze_sinr:
|
312 |
+
for param in self.headless_model.parameters():
|
313 |
+
param.requires_grad = False
|
314 |
+
print("Freezing SINR model parameters")
|
315 |
+
|
316 |
+
# self.transformer_model = MockTransformer(num_classes, num_filts)
|
317 |
+
self.transformer_model = TransformerEncoderModel(d_model=token_dim,
|
318 |
+
nhead=nhead,
|
319 |
+
num_encoder_layers=num_encoder_layers,
|
320 |
+
dim_feedforward=dim_feedforward,
|
321 |
+
dropout=dropout,
|
322 |
+
batch_first=batch_first,
|
323 |
+
output_dim=num_filts)
|
324 |
+
|
325 |
+
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
|
326 |
+
# this is just a workaround for now to load eval embeddings - probably not needed long term
|
327 |
+
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
|
328 |
+
self.ema_embeddings.weight.requires_grad = False
|
329 |
+
self.eval_embeddings.weight.requires_grad = False
|
330 |
+
self.num_filts=num_filts
|
331 |
+
self.token_dim = token_dim
|
332 |
+
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
|
333 |
+
self.sinr_inputs = sinr_inputs
|
334 |
+
if self.sinr_inputs:
|
335 |
+
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
|
336 |
+
raise ValueError("If using sinr inputs to transformer with identity class token transformation"
|
337 |
+
"then token_dim of transformer must be equal to num_filts of sinr model")
|
338 |
+
|
339 |
+
# Add a class token
|
340 |
+
self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
|
341 |
+
nn.init.xavier_uniform_(self.class_token)
|
342 |
+
|
343 |
+
if register:
|
344 |
+
# Add a register token initialized with Xavier uniform initialization
|
345 |
+
self.register = nn.Parameter(torch.empty(1, self.token_dim))
|
346 |
+
# self.register = (self.register / 2)
|
347 |
+
nn.init.xavier_uniform_(self.register)
|
348 |
+
else:
|
349 |
+
self.register = None
|
350 |
+
|
351 |
+
self.text_inputs = text_inputs
|
352 |
+
if self.text_inputs:
|
353 |
+
#print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW")
|
354 |
+
self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
|
355 |
+
else:
|
356 |
+
self.text_model=None
|
357 |
+
|
358 |
+
# Type-specific embeddings for class, register, location, and text tokens
|
359 |
+
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
360 |
+
nn.init.xavier_uniform_(self.class_type_embedding)
|
361 |
+
if register:
|
362 |
+
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
363 |
+
nn.init.xavier_uniform_(self.register_type_embedding)
|
364 |
+
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
365 |
+
nn.init.xavier_uniform_(self.location_type_embedding)
|
366 |
+
if text_inputs:
|
367 |
+
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
368 |
+
nn.init.xavier_uniform_(self.text_type_embedding)
|
369 |
+
|
370 |
+
# Instantiate the class token transformation module
|
371 |
+
if class_token_transformation == 'identity':
|
372 |
+
self.class_token_transform = Identity(token_dim, num_filts)
|
373 |
+
elif class_token_transformation == 'linear':
|
374 |
+
self.class_token_transform = LinearTransformation(token_dim, num_filts)
|
375 |
+
elif class_token_transformation == 'single_layer_nn':
|
376 |
+
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
|
377 |
+
elif class_token_transformation == 'two_layer_nn':
|
378 |
+
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
|
379 |
+
elif class_token_transformation == 'sinr':
|
380 |
+
self.class_token_transform = HeadlessSINR(token_dim, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
|
381 |
+
else:
|
382 |
+
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
|
383 |
+
|
384 |
+
|
385 |
+
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None):
|
386 |
+
# Process input through the headless model to get feature embeddings
|
387 |
+
feature_embeddings = self.headless_model(x)
|
388 |
+
|
389 |
+
if return_feats:
|
390 |
+
return feature_embeddings
|
391 |
+
|
392 |
+
if context_sequence.dim() == 2:
|
393 |
+
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
|
394 |
+
|
395 |
+
context_sequence = context_sequence[:, 1:, :]
|
396 |
+
|
397 |
+
if self.sinr_inputs:
|
398 |
+
# Pass through the headless model
|
399 |
+
context_sequence = self.headless_model(context_sequence)
|
400 |
+
|
401 |
+
# Add type-specific embedding to each location token
|
402 |
+
# print("SEE IF THIS WORKS")
|
403 |
+
context_sequence += self.location_type_embedding
|
404 |
+
|
405 |
+
batch_size = context_sequence.size(0)
|
406 |
+
|
407 |
+
# Expand the class token to match the batch size and add its type-specific embedding
|
408 |
+
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
|
409 |
+
|
410 |
+
if self.text_inputs and (text_emb is not None):
|
411 |
+
text_mask = (text_emb.sum(dim=1) == 0)
|
412 |
+
text_emb = self.text_model(text_emb)
|
413 |
+
text_emb += self.text_type_embedding
|
414 |
+
text_emb[text_mask] = 0
|
415 |
+
# Reshape text_emb to have the shape (batch_size, 1, embedding_dim)
|
416 |
+
text_emb = text_emb.unsqueeze(1)
|
417 |
+
|
418 |
+
|
419 |
+
if self.register is None:
|
420 |
+
# context sequence = learnable class_token + rest of sequence
|
421 |
+
if self.text_inputs:
|
422 |
+
# Add the class token and text embeddings to the context sequence
|
423 |
+
context_sequence = torch.cat((class_token_expanded, text_emb, context_sequence), dim=1)
|
424 |
+
# Pad the context mask to account for the added text embeddings
|
425 |
+
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
|
426 |
+
# Update the new part of the mask with the text_mask
|
427 |
+
context_mask[:, 1] = text_mask # Apply mask directly
|
428 |
+
else:
|
429 |
+
context_sequence = torch.cat((class_token_expanded, context_sequence), dim=1)
|
430 |
+
else:
|
431 |
+
# Expand the register token to match the batch size and add its type-specific embedding
|
432 |
+
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
|
433 |
+
if self.text_inputs:
|
434 |
+
# Add all components: class token, register, text embeddings, and context
|
435 |
+
context_sequence = torch.cat((class_token_expanded, register_expanded, text_emb, context_sequence),
|
436 |
+
dim=1)
|
437 |
+
# Double pad the context mask: first for register, then for text embeddings
|
438 |
+
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
|
439 |
+
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
|
440 |
+
# Update the new part of the mask for text embeddings
|
441 |
+
context_mask[:, register_expanded.size(1) + 1] = text_mask # Apply mask directly
|
442 |
+
else:
|
443 |
+
context_sequence = torch.cat((class_token_expanded, register_expanded, context_sequence), dim=1)
|
444 |
+
# Update the context mask to account for the register token
|
445 |
+
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
|
446 |
+
|
447 |
+
if use_eval_embeddings == False:
|
448 |
+
if class_of_interest == None:
|
449 |
+
# Get class-specific embeddings based on class_ids
|
450 |
+
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
|
451 |
+
# pass these through the class token transformation
|
452 |
+
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
|
453 |
+
|
454 |
+
if return_class_embeddings:
|
455 |
+
return class_embeddings
|
456 |
+
else:
|
457 |
+
# Update EMA embeddings for these class IDs
|
458 |
+
with torch.no_grad():
|
459 |
+
if self.training:
|
460 |
+
self.update_ema_embeddings(class_ids, class_embeddings)
|
461 |
+
|
462 |
+
# Matrix multiplication to produce logits
|
463 |
+
logits = feature_embeddings @ class_embeddings.T
|
464 |
+
|
465 |
+
# Apply sigmoid to convert logits to probabilities
|
466 |
+
probabilities = torch.sigmoid(logits)
|
467 |
+
|
468 |
+
return probabilities
|
469 |
+
else:
|
470 |
+
device = self.ema_embeddings.weight.device
|
471 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
472 |
+
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
|
473 |
+
print(f'using EMA estimate for class {class_of_interest}')
|
474 |
+
if return_class_embeddings:
|
475 |
+
return class_embedding
|
476 |
+
else:
|
477 |
+
# Matrix multiplication to produce logits
|
478 |
+
logits = feature_embeddings @ class_embedding.T
|
479 |
+
|
480 |
+
# Apply sigmoid to convert logits to probabilities
|
481 |
+
probabilities = torch.sigmoid(logits)
|
482 |
+
probabilities = probabilities.squeeze()
|
483 |
+
return probabilities
|
484 |
+
else:
|
485 |
+
self.eval()
|
486 |
+
if not hasattr(self, 'eval_embeddings'):
|
487 |
+
self.eval_embeddings = self.ema_embeddings
|
488 |
+
if class_of_interest == None:
|
489 |
+
# Get class-specific embeddings based on class_ids
|
490 |
+
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
|
491 |
+
class_embeddings = self.class_token_transform(class_token_output)
|
492 |
+
# Update EMA embeddings for these class IDs
|
493 |
+
|
494 |
+
self.generate_eval_embeddings(class_ids, class_embeddings)
|
495 |
+
|
496 |
+
# Matrix multiplication to produce logits
|
497 |
+
logits = feature_embeddings @ class_embeddings.T
|
498 |
+
|
499 |
+
# Apply sigmoid to convert logits to probabilities
|
500 |
+
probabilities = torch.sigmoid(logits)
|
501 |
+
|
502 |
+
return probabilities
|
503 |
+
else:
|
504 |
+
device = self.ema_embeddings.weight.device
|
505 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
506 |
+
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
|
507 |
+
print(f'using eval embedding for class {class_of_interest}')
|
508 |
+
if return_class_embeddings:
|
509 |
+
return class_embedding
|
510 |
+
else:
|
511 |
+
# Matrix multiplication to produce logits
|
512 |
+
logits = feature_embeddings @ class_embedding.T
|
513 |
+
|
514 |
+
# Apply sigmoid to convert logits to probabilities
|
515 |
+
probabilities = torch.sigmoid(logits)
|
516 |
+
probabilities = probabilities.squeeze()
|
517 |
+
return probabilities
|
518 |
+
|
519 |
+
def init_eval_embeddings(self, num_classes):
|
520 |
+
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
|
521 |
+
nn.init.xavier_uniform_(self.eval_embeddings.weight)
|
522 |
+
|
523 |
+
def get_ema_embeddings(self, class_ids):
|
524 |
+
# Method to access EMA embeddings
|
525 |
+
return self.ema_embeddings(class_ids)
|
526 |
+
|
527 |
+
def get_eval_embeddings(self, class_ids):
|
528 |
+
# Method to access eval embeddings
|
529 |
+
return self.eval_embeddings(class_ids)
|
530 |
+
|
531 |
+
def update_ema_embeddings(self, class_ids, current_embeddings):
|
532 |
+
if self.training:
|
533 |
+
# Get unique class IDs and their counts
|
534 |
+
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
|
535 |
+
|
536 |
+
# Get current EMA embeddings for unique class IDs
|
537 |
+
ema_current = self.ema_embeddings(unique_class_ids)
|
538 |
+
|
539 |
+
# Initialize a placeholder for new EMA values
|
540 |
+
ema_new = torch.zeros_like(ema_current)
|
541 |
+
|
542 |
+
# Compute the average of current embeddings for each unique class ID
|
543 |
+
current_sum = torch.zeros_like(ema_current)
|
544 |
+
current_sum.index_add_(0, inverse_indices, current_embeddings)
|
545 |
+
current_avg = current_sum / counts.unsqueeze(1)
|
546 |
+
|
547 |
+
# Apply EMA update formula
|
548 |
+
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
|
549 |
+
|
550 |
+
# Update the EMA embeddings for unique class IDs
|
551 |
+
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
|
552 |
+
|
553 |
+
def generate_eval_embeddings(self, class_id, current_embedding):
|
554 |
+
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
|
555 |
+
|
556 |
+
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
|
557 |
+
|
558 |
+
|
559 |
+
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
|
560 |
+
# forward method that uses ema or eval embeddings rather than context sequence
|
561 |
+
|
562 |
+
# Process input through the headless model to get feature embeddings
|
563 |
+
feature_embeddings = self.headless_model(x)
|
564 |
+
|
565 |
+
if return_feats:
|
566 |
+
return feature_embeddings
|
567 |
+
else:
|
568 |
+
if class_of_interest == None:
|
569 |
+
# Get class-specific embeddings based on class_ids
|
570 |
+
if eval == False:
|
571 |
+
class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
|
572 |
+
else:
|
573 |
+
class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
|
574 |
+
if return_class_embeddings:
|
575 |
+
return class_embeddings
|
576 |
+
else:
|
577 |
+
# Matrix multiplication to produce logits
|
578 |
+
logits = feature_embeddings @ class_embeddings.T
|
579 |
+
|
580 |
+
# Apply sigmoid to convert logits to probabilities
|
581 |
+
probabilities = torch.sigmoid(logits)
|
582 |
+
|
583 |
+
return probabilities
|
584 |
+
else:
|
585 |
+
if eval == False:
|
586 |
+
device = self.ema_embeddings.weight.device
|
587 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
588 |
+
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
|
589 |
+
print(f'using EMA estimate for class {class_of_interest}')
|
590 |
+
if return_class_embeddings:
|
591 |
+
return class_embedding
|
592 |
+
else:
|
593 |
+
# Matrix multiplication to produce logits
|
594 |
+
logits = feature_embeddings @ class_embedding.T
|
595 |
+
|
596 |
+
# Apply sigmoid to convert logits to probabilities
|
597 |
+
probabilities = torch.sigmoid(logits)
|
598 |
+
probabilities = probabilities.squeeze()
|
599 |
+
|
600 |
+
return probabilities
|
601 |
+
|
602 |
+
else:
|
603 |
+
device = self.eval_embeddings.weight.device
|
604 |
+
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
|
605 |
+
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
|
606 |
+
#print(f'using eval estimate for class {class_of_interest}')
|
607 |
+
if return_class_embeddings:
|
608 |
+
return class_embedding
|
609 |
+
else:
|
610 |
+
# Matrix multiplication to produce logits
|
611 |
+
logits = feature_embeddings @ class_embedding.T
|
612 |
+
|
613 |
+
# Apply sigmoid to convert logits to probabilities
|
614 |
+
probabilities = torch.sigmoid(logits)
|
615 |
+
probabilities = probabilities.squeeze()
|
616 |
+
|
617 |
+
return probabilities
|
618 |
+
|
619 |
+
class VariableInputModel(nn.Module):
|
620 |
+
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
|
621 |
+
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
|
622 |
+
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
|
623 |
+
text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'):
|
624 |
+
|
625 |
+
super(VariableInputModel, self).__init__()
|
626 |
+
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
|
627 |
+
self.ema_factor = ema_factor
|
628 |
+
self.class_token_transformation = class_token_transformation
|
629 |
+
|
630 |
+
# Load pretrained state_dict if use_pretrained_sinr is set to True
|
631 |
+
if use_pretrained_sinr:
|
632 |
+
pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
|
633 |
+
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
|
634 |
+
self.headless_model.load_state_dict(filtered_state_dict, strict=False)
|
635 |
+
#print(f'Using pretrained sinr from {pretrained_loc}')
|
636 |
+
|
637 |
+
# Freeze the SINR model if freeze_sinr is set to True
|
638 |
+
if freeze_sinr:
|
639 |
+
for param in self.headless_model.parameters():
|
640 |
+
param.requires_grad = False
|
641 |
+
print("Freezing SINR model parameters")
|
642 |
+
|
643 |
+
# self.transformer_model = MockTransformer(num_classes, num_filts)
|
644 |
+
self.transformer_model = TransformerEncoderModel(d_model=token_dim,
|
645 |
+
nhead=nhead,
|
646 |
+
num_encoder_layers=num_encoder_layers,
|
647 |
+
dim_feedforward=dim_feedforward,
|
648 |
+
dropout=dropout,
|
649 |
+
batch_first=batch_first,
|
650 |
+
output_dim=num_filts)
|
651 |
+
|
652 |
+
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
|
653 |
+
# this is just a workaround for now to load eval embeddings - probably not needed long term
|
654 |
+
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
|
655 |
+
self.ema_embeddings.weight.requires_grad = False
|
656 |
+
self.eval_embeddings.weight.requires_grad = False
|
657 |
+
self.num_filts=num_filts
|
658 |
+
self.token_dim = token_dim
|
659 |
+
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
|
660 |
+
self.sinr_inputs = sinr_inputs
|
661 |
+
if self.sinr_inputs:
|
662 |
+
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
|
663 |
+
raise ValueError("If using sinr inputs to transformer with identity class token transformation"
|
664 |
+
"then token_dim of transformer must be equal to num_filts of sinr model")
|
665 |
+
|
666 |
+
# Add a class token
|
667 |
+
self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
|
668 |
+
nn.init.xavier_uniform_(self.class_token)
|
669 |
+
|
670 |
+
if register:
|
671 |
+
# Add a register token initialized with Xavier uniform initialization
|
672 |
+
self.register = nn.Parameter(torch.empty(1, self.token_dim))
|
673 |
+
# self.register = (self.register / 2)
|
674 |
+
nn.init.xavier_uniform_(self.register)
|
675 |
+
else:
|
676 |
+
self.register = None
|
677 |
+
|
678 |
+
self.text_inputs = text_inputs
|
679 |
+
if self.text_inputs:
|
680 |
+
print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW")
|
681 |
+
self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
|
682 |
+
else:
|
683 |
+
self.text_model=None
|
684 |
+
self.image_inputs = image_inputs
|
685 |
+
if self.image_inputs:
|
686 |
+
print("JUST USING A HEADLESS SINR FOR THE IMAGE MODEL RIGHT NOW")
|
687 |
+
self.image_model=HeadlessSINR(num_inputs=1024, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
|
688 |
+
else:
|
689 |
+
self.image_model=None
|
690 |
+
self.env_inputs = env_inputs
|
691 |
+
if self.env_inputs:
|
692 |
+
print("JUST USING A HEADLESS SINR FOR THE ENV MODEL RIGHT NOW")
|
693 |
+
self.env_model=HeadlessSINR(num_inputs=20, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
|
694 |
+
else:
|
695 |
+
self.env_model=None
|
696 |
+
|
697 |
+
# Type-specific embeddings for class, register, location, text, image and env tokens
|
698 |
+
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
699 |
+
nn.init.xavier_uniform_(self.class_type_embedding)
|
700 |
+
if register:
|
701 |
+
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
702 |
+
nn.init.xavier_uniform_(self.register_type_embedding)
|
703 |
+
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
704 |
+
nn.init.xavier_uniform_(self.location_type_embedding)
|
705 |
+
if text_inputs:
|
706 |
+
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
707 |
+
nn.init.xavier_uniform_(self.text_type_embedding)
|
708 |
+
if image_inputs:
|
709 |
+
self.image_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
710 |
+
nn.init.xavier_uniform_(self.image_type_embedding)
|
711 |
+
if env_inputs:
|
712 |
+
self.env_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
|
713 |
+
nn.init.xavier_uniform_(self.env_type_embedding)
|
714 |
+
|
715 |
+
# Instantiate the class token transformation module
|
716 |
+
if class_token_transformation == 'identity':
|
717 |
+
self.class_token_transform = Identity(token_dim, num_filts)
|
718 |
+
elif class_token_transformation == 'linear':
|
719 |
+
self.class_token_transform = LinearTransformation(token_dim, num_filts)
|
720 |
+
elif class_token_transformation == 'single_layer_nn':
|
721 |
+
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
|
722 |
+
elif class_token_transformation == 'two_layer_nn':
|
723 |
+
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
|
724 |
+
elif class_token_transformation == 'sinr':
|
725 |
+
self.class_token_transform = HeadlessSINR(token_dim, num_filts, 2, nonlin, lowrank, dropout_p=dropout)
|
726 |
+
else:
|
727 |
+
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
|
728 |
+
|
729 |
+
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False,
|
730 |
+
return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None,
|
731 |
+
image_emb=None, env_emb=None):
|
732 |
+
# Process input through the headless model to get feature embeddings
|
733 |
+
feature_embeddings = self.headless_model(x)
|
734 |
+
|
735 |
+
if return_feats:
|
736 |
+
return feature_embeddings
|
737 |
+
|
738 |
+
if context_sequence.dim() == 2:
|
739 |
+
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
|
740 |
+
|
741 |
+
context_sequence = context_sequence[:, 1:, :]
|
742 |
+
|
743 |
+
context_mask = context_mask[:, 1:]
|
744 |
+
|
745 |
+
if self.sinr_inputs:
|
746 |
+
context_sequence = self.headless_model(context_sequence)
|
747 |
+
|
748 |
+
# Add type-specific embedding to each location token
|
749 |
+
context_sequence += self.location_type_embedding
|
750 |
+
|
751 |
+
batch_size = context_sequence.size(0)
|
752 |
+
|
753 |
+
# Initialize lists for tokens and masks
|
754 |
+
tokens = []
|
755 |
+
masks = []
|
756 |
+
|
757 |
+
# Process class token
|
758 |
+
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
|
759 |
+
tokens.append(class_token_expanded)
|
760 |
+
# The class token is always present, so mask is False (i.e., not masked out)
|
761 |
+
class_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
|
762 |
+
masks.append(class_mask)
|
763 |
+
|
764 |
+
# Process register token if present
|
765 |
+
if self.register is not None:
|
766 |
+
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
|
767 |
+
tokens.append(register_expanded)
|
768 |
+
register_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
|
769 |
+
masks.append(register_mask)
|
770 |
+
|
771 |
+
# Process text embeddings
|
772 |
+
if self.text_inputs and (text_emb is not None):
|
773 |
+
text_mask = (text_emb.sum(dim=1) == 0)
|
774 |
+
text_emb = self.text_model(text_emb)
|
775 |
+
text_emb += self.text_type_embedding
|
776 |
+
# Set embeddings to zero where mask is True
|
777 |
+
text_emb[text_mask] = 0
|
778 |
+
text_emb = text_emb.unsqueeze(1)
|
779 |
+
tokens.append(text_emb)
|
780 |
+
# Expand text_mask to match sequence dimensions
|
781 |
+
text_mask = text_mask.unsqueeze(1)
|
782 |
+
masks.append(text_mask)
|
783 |
+
|
784 |
+
# Process image embeddings
|
785 |
+
if self.image_inputs and (image_emb is not None):
|
786 |
+
image_mask = (image_emb.sum(dim=1) == 0)
|
787 |
+
image_emb = self.image_model(image_emb)
|
788 |
+
image_emb += self.image_type_embedding
|
789 |
+
image_emb[image_mask] = 0
|
790 |
+
image_emb = image_emb.unsqueeze(1)
|
791 |
+
tokens.append(image_emb)
|
792 |
+
image_mask = image_mask.unsqueeze(1)
|
793 |
+
masks.append(image_mask)
|
794 |
+
|
795 |
+
# Process env embeddings if needed (can be added similarly)
|
796 |
+
if self.env_inputs and (env_emb is not None):
|
797 |
+
env_mask = context_mask
|
798 |
+
env_emb = self.env_model(env_emb)
|
799 |
+
env_emb += self.env_type_embedding
|
800 |
+
env_emb[env_mask] = 0
|
801 |
+
env_emb = env_emb.unsqueeze(1)
|
802 |
+
tokens.append(env_emb)
|
803 |
+
env_mask = env_mask.unsqueeze(1)
|
804 |
+
masks.append(env_mask)
|
805 |
+
|
806 |
+
# Process location tokens
|
807 |
+
tokens.append(context_sequence)
|
808 |
+
masks.append(context_mask)
|
809 |
+
|
810 |
+
# Concatenate all tokens and masks
|
811 |
+
context_sequence = torch.cat(tokens, dim=1)
|
812 |
+
context_mask = torch.cat(masks, dim=1)
|
813 |
+
|
814 |
+
if use_eval_embeddings == False:
|
815 |
+
if class_of_interest == None:
|
816 |
+
# Get class-specific embeddings based on class_ids
|
817 |
+
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
|
818 |
+
# pass these through the class token transformation
|
819 |
+
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
|
820 |
+
|
821 |
+
if return_class_embeddings:
|
822 |
+
return class_embeddings
|
823 |
+
else:
|
824 |
+
# Update EMA embeddings for these class IDs
|
825 |
+
with torch.no_grad():
|
826 |
+
if self.training:
|
827 |
+
self.update_ema_embeddings(class_ids, class_embeddings)
|
828 |
+
|
829 |
+
# Matrix multiplication to produce logits
|
830 |
+
logits = feature_embeddings @ class_embeddings.T
|
831 |
+
|
832 |
+
# Apply sigmoid to convert logits to probabilities
|
833 |
+
probabilities = torch.sigmoid(logits)
|
834 |
+
|
835 |
+
return probabilities
|
836 |
+
else:
|
837 |
+
device = self.ema_embeddings.weight.device
|
838 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
839 |
+
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
|
840 |
+
print(f'using EMA estimate for class {class_of_interest}')
|
841 |
+
if return_class_embeddings:
|
842 |
+
return class_embedding
|
843 |
+
else:
|
844 |
+
# Matrix multiplication to produce logits
|
845 |
+
logits = feature_embeddings @ class_embedding.T
|
846 |
+
|
847 |
+
# Apply sigmoid to convert logits to probabilities
|
848 |
+
probabilities = torch.sigmoid(logits)
|
849 |
+
probabilities = probabilities.squeeze()
|
850 |
+
return probabilities
|
851 |
+
else:
|
852 |
+
self.eval()
|
853 |
+
if not hasattr(self, 'eval_embeddings'):
|
854 |
+
print('No Eval Embeddings for this species?!')
|
855 |
+
self.eval_embeddings = self.ema_embeddings
|
856 |
+
if class_of_interest == None:
|
857 |
+
# Get class-specific embeddings based on class_ids
|
858 |
+
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
|
859 |
+
class_embeddings = self.class_token_transform(class_token_output)
|
860 |
+
# Update EMA embeddings for these class IDs
|
861 |
+
|
862 |
+
self.generate_eval_embeddings(class_ids, class_embeddings)
|
863 |
+
|
864 |
+
# Matrix multiplication to produce logits
|
865 |
+
logits = feature_embeddings @ class_embeddings.T
|
866 |
+
|
867 |
+
# Apply sigmoid to convert logits to probabilities
|
868 |
+
probabilities = torch.sigmoid(logits)
|
869 |
+
|
870 |
+
return probabilities
|
871 |
+
else:
|
872 |
+
device = self.ema_embeddings.weight.device
|
873 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
874 |
+
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
|
875 |
+
print(f'using eval embedding for class {class_of_interest}')
|
876 |
+
if return_class_embeddings:
|
877 |
+
return class_embedding
|
878 |
+
else:
|
879 |
+
# Matrix multiplication to produce logits
|
880 |
+
logits = feature_embeddings @ class_embedding.T
|
881 |
+
|
882 |
+
# Apply sigmoid to convert logits to probabilities
|
883 |
+
probabilities = torch.sigmoid(logits)
|
884 |
+
probabilities = probabilities.squeeze()
|
885 |
+
return probabilities
|
886 |
+
|
887 |
+
def get_loc_emb(self, x):
|
888 |
+
feature_embeddings = self.headless_model(x)
|
889 |
+
return feature_embeddings
|
890 |
+
|
891 |
+
def init_eval_embeddings(self, num_classes):
|
892 |
+
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
|
893 |
+
nn.init.xavier_uniform_(self.eval_embeddings.weight)
|
894 |
+
|
895 |
+
def get_ema_embeddings(self, class_ids):
|
896 |
+
# Method to access EMA embeddings
|
897 |
+
return self.ema_embeddings(class_ids)
|
898 |
+
|
899 |
+
def get_eval_embeddings(self, class_ids):
|
900 |
+
# Method to access eval embeddings
|
901 |
+
return self.eval_embeddings(class_ids)
|
902 |
+
|
903 |
+
def update_ema_embeddings(self, class_ids, current_embeddings):
|
904 |
+
if self.training:
|
905 |
+
# Get unique class IDs and their counts
|
906 |
+
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
|
907 |
+
|
908 |
+
# Get current EMA embeddings for unique class IDs
|
909 |
+
ema_current = self.ema_embeddings(unique_class_ids)
|
910 |
+
|
911 |
+
# Initialize a placeholder for new EMA values
|
912 |
+
ema_new = torch.zeros_like(ema_current)
|
913 |
+
|
914 |
+
# Compute the average of current embeddings for each unique class ID
|
915 |
+
current_sum = torch.zeros_like(ema_current)
|
916 |
+
current_sum.index_add_(0, inverse_indices, current_embeddings)
|
917 |
+
current_avg = current_sum / counts.unsqueeze(1)
|
918 |
+
|
919 |
+
# Apply EMA update formula
|
920 |
+
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
|
921 |
+
|
922 |
+
# Update the EMA embeddings for unique class IDs
|
923 |
+
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
|
924 |
+
|
925 |
+
def generate_eval_embeddings(self, class_id, current_embedding):
|
926 |
+
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
|
927 |
+
|
928 |
+
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
|
929 |
+
|
930 |
+
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
|
931 |
+
# forward method that uses ema or eval embeddings rather than context sequence
|
932 |
+
|
933 |
+
# Process input through the headless model to get feature embeddings
|
934 |
+
feature_embeddings = self.headless_model(x)
|
935 |
+
|
936 |
+
if return_feats:
|
937 |
+
return feature_embeddings
|
938 |
+
else:
|
939 |
+
if class_of_interest == None:
|
940 |
+
# Get class-specific embeddings based on class_ids
|
941 |
+
if eval == False:
|
942 |
+
class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
|
943 |
+
else:
|
944 |
+
class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
|
945 |
+
if return_class_embeddings:
|
946 |
+
return class_embeddings
|
947 |
+
else:
|
948 |
+
# Matrix multiplication to produce logits
|
949 |
+
logits = feature_embeddings @ class_embeddings.T
|
950 |
+
|
951 |
+
# Apply sigmoid to convert logits to probabilities
|
952 |
+
probabilities = torch.sigmoid(logits)
|
953 |
+
|
954 |
+
return probabilities
|
955 |
+
else:
|
956 |
+
if eval == False:
|
957 |
+
device = self.ema_embeddings.weight.device
|
958 |
+
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
|
959 |
+
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
|
960 |
+
print(f'using EMA estimate for class {class_of_interest}')
|
961 |
+
if return_class_embeddings:
|
962 |
+
return class_embedding
|
963 |
+
else:
|
964 |
+
# Matrix multiplication to produce logits
|
965 |
+
logits = feature_embeddings @ class_embedding.T
|
966 |
+
|
967 |
+
# Apply sigmoid to convert logits to probabilities
|
968 |
+
probabilities = torch.sigmoid(logits)
|
969 |
+
probabilities = probabilities.squeeze()
|
970 |
+
|
971 |
+
return probabilities
|
972 |
+
|
973 |
+
else:
|
974 |
+
device = self.eval_embeddings.weight.device
|
975 |
+
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
|
976 |
+
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
|
977 |
+
#print(f'using eval estimate for class {class_of_interest}')
|
978 |
+
if return_class_embeddings:
|
979 |
+
return class_embedding
|
980 |
+
else:
|
981 |
+
# Matrix multiplication to produce logits
|
982 |
+
logits = feature_embeddings @ class_embedding.T
|
983 |
+
|
984 |
+
# Apply sigmoid to convert logits to probabilities
|
985 |
+
probabilities = torch.sigmoid(logits)
|
986 |
+
probabilities = probabilities.squeeze()
|
987 |
+
|
988 |
+
return probabilities
|
989 |
+
|
990 |
+
|
991 |
+
class LinNet(nn.Module):
|
992 |
+
def __init__(self, num_inputs, num_classes):
|
993 |
+
super(LinNet, self).__init__()
|
994 |
+
self.num_layers = 0
|
995 |
+
self.inc_bias = False
|
996 |
+
self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
|
997 |
+
self.feats = nn.Identity() # does not do anything
|
998 |
+
|
999 |
+
def forward(self, x, class_of_interest=None, return_feats=False):
|
1000 |
+
loc_emb = self.feats(x)
|
1001 |
+
if return_feats:
|
1002 |
+
return loc_emb
|
1003 |
+
if class_of_interest is None:
|
1004 |
+
class_pred = self.class_emb(loc_emb)
|
1005 |
+
else:
|
1006 |
+
class_pred = self.eval_single_class(loc_emb, class_of_interest)
|
1007 |
+
|
1008 |
+
return torch.sigmoid(class_pred)
|
1009 |
+
|
1010 |
+
def eval_single_class(self, x, class_of_interest):
|
1011 |
+
if self.inc_bias:
|
1012 |
+
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
|
1013 |
+
else:
|
1014 |
+
return x @ self.class_emb.weight[class_of_interest, :]
|
1015 |
+
|
1016 |
+
|
1017 |
+
class ParallelMulti(torch.nn.Module):
|
1018 |
+
def __init__(self, x: list[torch.nn.Module]):
|
1019 |
+
super(ParallelMulti, self).__init__()
|
1020 |
+
self.layers = nn.ModuleList(x)
|
1021 |
+
|
1022 |
+
def forward(self, xs, **kwargs):
|
1023 |
+
out = torch.cat([self.layers[i](x, **kwargs) for i,x in enumerate(xs)], dim=1)
|
1024 |
+
return out
|
1025 |
+
|
1026 |
+
|
1027 |
+
class SequentialMulti(torch.nn.Sequential):
|
1028 |
+
def forward(self, *inputs, **kwargs):
|
1029 |
+
for module in self._modules.values():
|
1030 |
+
if type(inputs) == tuple:
|
1031 |
+
inputs = module(*inputs, **kwargs)
|
1032 |
+
else:
|
1033 |
+
inputs = module(inputs)
|
1034 |
+
return inputs
|
1035 |
+
|
1036 |
+
|
1037 |
+
# Chris's transformation classes
|
1038 |
+
class Identity(nn.Module):
|
1039 |
+
def __init__(self, in_dim, out_dim):
|
1040 |
+
super(Identity, self).__init__()
|
1041 |
+
# No parameters needed for identity transformation
|
1042 |
+
|
1043 |
+
def forward(self, x):
|
1044 |
+
return x
|
1045 |
+
|
1046 |
+
class LinearTransformation(nn.Module):
|
1047 |
+
def __init__(self, in_dim, out_dim, bias=True):
|
1048 |
+
super(LinearTransformation, self).__init__()
|
1049 |
+
self.linear = nn.Linear(in_dim, out_dim, bias=bias)
|
1050 |
+
|
1051 |
+
def forward(self, x):
|
1052 |
+
return self.linear(x)
|
1053 |
+
|
1054 |
+
class SingleLayerNN(nn.Module):
|
1055 |
+
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
|
1056 |
+
super(SingleLayerNN, self).__init__()
|
1057 |
+
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
|
1058 |
+
self.net = nn.Sequential(
|
1059 |
+
nn.Linear(in_dim, hidden_dim, bias=bias),
|
1060 |
+
nn.ReLU(),
|
1061 |
+
nn.Dropout(p=dropout_p),
|
1062 |
+
nn.Linear(hidden_dim, out_dim, bias=bias)
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
def forward(self, x):
|
1066 |
+
return self.net(x)
|
1067 |
+
|
1068 |
+
class TwoLayerNN(nn.Module):
|
1069 |
+
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
|
1070 |
+
super(TwoLayerNN, self).__init__()
|
1071 |
+
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
|
1072 |
+
self.net = nn.Sequential(
|
1073 |
+
nn.Linear(in_dim, hidden_dim, bias=bias),
|
1074 |
+
nn.ReLU(),
|
1075 |
+
nn.Dropout(p=dropout_p),
|
1076 |
+
nn.Linear(hidden_dim, hidden_dim, bias=bias),
|
1077 |
+
nn.ReLU(),
|
1078 |
+
nn.Dropout(p=dropout_p),
|
1079 |
+
nn.Linear(hidden_dim, out_dim, bias=bias)
|
1080 |
+
)
|
1081 |
+
|
1082 |
+
def forward(self, x):
|
1083 |
+
return self.net(x)
|
1084 |
+
|
1085 |
+
class HyperNet(nn.Module):
|
1086 |
+
'''
|
1087 |
+
:param asdf
|
1088 |
+
'''
|
1089 |
+
def __init__(self, params, num_inputs, num_classes, num_filts, pos_enc_depth, species_dim, species_enc_depth, species_filts, species_enc='embed', inference_only=False):
|
1090 |
+
super(HyperNet, self).__init__()
|
1091 |
+
if species_enc == 'embed':
|
1092 |
+
self.species_emb = nn.Embedding(num_classes, species_dim)
|
1093 |
+
self.species_emb.weight.data *= 0.01
|
1094 |
+
elif species_enc == 'taxa':
|
1095 |
+
self.species_emb = TaxaEncoder(params, './data/inat_taxa_info.csv', species_dim)
|
1096 |
+
elif species_enc == 'text':
|
1097 |
+
self.species_emb = TextEncoder(params, params['text_emb_path'], species_dim, './data/inat_taxa_info.csv')
|
1098 |
+
elif species_enc == 'wiki':
|
1099 |
+
self.species_emb = WikiEncoder(params, params['text_emb_path'], species_dim, inference_only=inference_only)
|
1100 |
+
if species_enc_depth == -1:
|
1101 |
+
self.species_enc = nn.Identity()
|
1102 |
+
elif species_enc_depth == 0:
|
1103 |
+
self.species_enc = nn.Linear(species_dim, num_filts+1)
|
1104 |
+
else:
|
1105 |
+
self.species_enc = SimpleFCNet(species_dim, num_filts+1, species_filts, depth=species_enc_depth)
|
1106 |
+
if 'geoprior' in params['loss']:
|
1107 |
+
self.species_params = nn.Parameter(torch.randn(num_classes, species_dim))
|
1108 |
+
self.species_params.data *= 0.0386
|
1109 |
+
self.pos_enc = SimpleFCNet(num_inputs, num_filts, num_filts, depth=pos_enc_depth)
|
1110 |
+
|
1111 |
+
def forward(self, x, y):
|
1112 |
+
ys, indmap = torch.unique(y, return_inverse=True)
|
1113 |
+
species = self.species_enc(self.species_emb(ys))
|
1114 |
+
species_w, species_b = species[...,:-1], species[...,-1:]
|
1115 |
+
pos = self.pos_enc(x)
|
1116 |
+
out = torch.bmm(species_w[indmap],pos[...,None])
|
1117 |
+
out = (out + 0*species_b[indmap]).squeeze(-1) #TODO
|
1118 |
+
if hasattr(self, 'species_params'):
|
1119 |
+
out2 = torch.bmm(self.species_params[ys][indmap],pos[...,None])
|
1120 |
+
out2 = out2.squeeze(-1)
|
1121 |
+
out3 = (species_w, self.species_params[ys], ys)
|
1122 |
+
return out, out2, out3
|
1123 |
+
else:
|
1124 |
+
return out
|
1125 |
+
|
1126 |
+
def zero_shot(self, x, species_emb):
|
1127 |
+
species = self.species_enc(self.species_emb.zero_shot(species_emb))
|
1128 |
+
species_w, _ = species[...,:-1], species[...,-1:]
|
1129 |
+
pos = self.pos_enc(x)
|
1130 |
+
out = pos @ species_w.T
|
1131 |
+
return out
|
1132 |
+
|
1133 |
+
|
1134 |
+
class TaxaEncoder(nn.Module):
|
1135 |
+
def __init__(self, params, fpath, embedding_dim):
|
1136 |
+
super(TaxaEncoder, self).__init__()
|
1137 |
+
import datasets
|
1138 |
+
with open('paths.json', 'r') as f:
|
1139 |
+
paths = json.load(f)
|
1140 |
+
data_dir = paths['train']
|
1141 |
+
obs_file = os.path.join(data_dir, params['obs_file'])
|
1142 |
+
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
|
1143 |
+
|
1144 |
+
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
|
1145 |
+
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
|
1146 |
+
|
1147 |
+
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
|
1148 |
+
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
|
1149 |
+
class_to_taxa = unique_taxa.tolist()
|
1150 |
+
|
1151 |
+
self.fpath = fpath
|
1152 |
+
ids = []
|
1153 |
+
rows = []
|
1154 |
+
with open(fpath, newline='') as csvfile:
|
1155 |
+
spamreader = csv.reader(csvfile, delimiter=',')
|
1156 |
+
for row in spamreader:
|
1157 |
+
if row[0] == 'taxon_id':
|
1158 |
+
continue
|
1159 |
+
ids.append(int(row[0]))
|
1160 |
+
rows.append(row[3:])
|
1161 |
+
print()
|
1162 |
+
rows = np.array(rows)
|
1163 |
+
rows = [np.unique(rows[:,i], return_inverse=True)[1] for i in range(rows.shape[1])]
|
1164 |
+
rows = torch.from_numpy(np.vstack(rows).T)
|
1165 |
+
rows = rows
|
1166 |
+
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
|
1167 |
+
embs = [nn.Embedding(rows[:,i].max()+2, embedding_dim, 0) for i in range(rows.shape[1])]
|
1168 |
+
embs[-1] = nn.Embedding(len(class_to_taxa), embedding_dim)
|
1169 |
+
rows2 = torch.zeros((len(class_to_taxa), 7), dtype=rows.dtype)
|
1170 |
+
startind = rows[:,-1].max()
|
1171 |
+
for i in range(len(class_to_taxa)):
|
1172 |
+
if class_to_taxa[i] in ids:
|
1173 |
+
rows2[i] = rows[ids.index(class_to_taxa[i])]+1
|
1174 |
+
rows2[i,-1] -= 1
|
1175 |
+
else:
|
1176 |
+
rows2[i,-1] = startind
|
1177 |
+
startind += 1
|
1178 |
+
self.register_buffer('rows', rows2)
|
1179 |
+
for e in embs:
|
1180 |
+
e.weight.data *= 0.01
|
1181 |
+
self.embs = nn.ModuleList(embs)
|
1182 |
+
|
1183 |
+
def forward(self, x):
|
1184 |
+
inds = self.rows[x]
|
1185 |
+
out = sum([self.embs[i](inds[...,i]) for i in range(inds.shape[-1])])
|
1186 |
+
return out
|
1187 |
+
|
1188 |
+
|
1189 |
+
class TextEncoder(nn.Module):
|
1190 |
+
def __init__(self, params, path, embedding_dim, fpath='inat_taxa_info.csv'):
|
1191 |
+
super(TextEncoder, self).__init__()
|
1192 |
+
import datasets
|
1193 |
+
with open('paths.json', 'r') as f:
|
1194 |
+
paths = json.load(f)
|
1195 |
+
data_dir = paths['train']
|
1196 |
+
obs_file = os.path.join(data_dir, params['obs_file'])
|
1197 |
+
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
|
1198 |
+
|
1199 |
+
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
|
1200 |
+
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
|
1201 |
+
|
1202 |
+
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
|
1203 |
+
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
|
1204 |
+
class_to_taxa = unique_taxa.tolist()
|
1205 |
+
|
1206 |
+
self.fpath = fpath
|
1207 |
+
ids = []
|
1208 |
+
with open(fpath, newline='') as csvfile:
|
1209 |
+
spamreader = csv.reader(csvfile, delimiter=',')
|
1210 |
+
for row in spamreader:
|
1211 |
+
if row[0] == 'taxon_id':
|
1212 |
+
continue
|
1213 |
+
ids.append(int(row[0]))
|
1214 |
+
embs = torch.load(path)
|
1215 |
+
if len(embs) != len(ids):
|
1216 |
+
print("Warning: Number of embeddings doesn't match number of species")
|
1217 |
+
ids = ids[:embs.shape[0]]
|
1218 |
+
if isinstance(embs, list):
|
1219 |
+
embs = torch.stack(embs)
|
1220 |
+
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
|
1221 |
+
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
|
1222 |
+
embmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
|
1223 |
+
self.missing_emb = nn.Embedding(len(class_to_taxa)-embs.shape[0], embedding_dim)
|
1224 |
+
|
1225 |
+
startind = 0
|
1226 |
+
for i in range(len(class_to_taxa)):
|
1227 |
+
if class_to_taxa[i] in ids:
|
1228 |
+
indmap[i] = ids.index(class_to_taxa[i])
|
1229 |
+
else:
|
1230 |
+
embmap[i] = startind
|
1231 |
+
startind += 1
|
1232 |
+
self.scales = nn.Parameter(torch.zeros(len(class_to_taxa), 1))
|
1233 |
+
self.register_buffer('indmap', indmap, persistent=False)
|
1234 |
+
self.register_buffer('embmap', embmap, persistent=False)
|
1235 |
+
self.register_buffer('embs', embs, persistent=False)
|
1236 |
+
if params['text_hidden_dim'] == 0:
|
1237 |
+
self.linear1 = nn.Linear(embs.shape[1], embedding_dim)
|
1238 |
+
else:
|
1239 |
+
self.linear1 = nn.Linear(embs.shape[1], params['text_hidden_dim'])
|
1240 |
+
self.linear2 = nn.Linear(params['text_hidden_dim'], embedding_dim)
|
1241 |
+
self.act = nn.SiLU()
|
1242 |
+
if params['text_learn_dim'] > 0:
|
1243 |
+
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
|
1244 |
+
self.learned_emb.weight.data *= 0.01
|
1245 |
+
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
|
1246 |
+
|
1247 |
+
def forward(self, x):
|
1248 |
+
inds = self.indmap[x]
|
1249 |
+
out = self.embs[self.indmap[x].cpu()]
|
1250 |
+
out = self.linear1(out)
|
1251 |
+
if hasattr(self, 'linear2'):
|
1252 |
+
out = self.linear2(self.act(out))
|
1253 |
+
out = self.scales[x] * (out / (out.std(dim=1)[:, None]))
|
1254 |
+
out[inds == -1] = self.missing_emb(self.embmap[x[inds == -1]])
|
1255 |
+
if hasattr(self, 'learned_emb'):
|
1256 |
+
out2 = self.learned_emb(x)
|
1257 |
+
out2 = self.linear_learned(out2)
|
1258 |
+
out = out+out2
|
1259 |
+
return out
|
1260 |
+
|
1261 |
+
|
1262 |
+
class WikiEncoder(nn.Module):
|
1263 |
+
def __init__(self, params, path, embedding_dim, inference_only=False):
|
1264 |
+
super(WikiEncoder, self).__init__()
|
1265 |
+
self.path = path
|
1266 |
+
if not inference_only:
|
1267 |
+
import datasets
|
1268 |
+
with open('paths.json', 'r') as f:
|
1269 |
+
paths = json.load(f)
|
1270 |
+
data_dir = paths['train']
|
1271 |
+
obs_file = os.path.join(data_dir, params['obs_file'])
|
1272 |
+
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
|
1273 |
+
|
1274 |
+
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
|
1275 |
+
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
|
1276 |
+
|
1277 |
+
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
|
1278 |
+
if params['zero_shot']:
|
1279 |
+
with open('paths.json', 'r') as f:
|
1280 |
+
paths = json.load(f)
|
1281 |
+
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
|
1282 |
+
data = json.load(f)
|
1283 |
+
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
|
1284 |
+
D = D.item()
|
1285 |
+
taxa_snt = D['taxa'].tolist()
|
1286 |
+
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
|
1287 |
+
taxa = list(set(taxa + taxa_snt))
|
1288 |
+
mask = labels != taxa[0]
|
1289 |
+
for i in range(1, len(taxa)):
|
1290 |
+
mask &= (labels != taxa[i])
|
1291 |
+
locs = locs[mask]
|
1292 |
+
dates = dates[mask]
|
1293 |
+
labels = labels[mask]
|
1294 |
+
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
|
1295 |
+
class_to_taxa = unique_taxa.tolist()
|
1296 |
+
|
1297 |
+
embs = torch.load(path)
|
1298 |
+
ids = embs['taxon_id'].tolist()
|
1299 |
+
if 'keys' in embs:
|
1300 |
+
taxa_counts = torch.zeros(len(ids), dtype=torch.int32)
|
1301 |
+
for i,k in embs['keys']:
|
1302 |
+
taxa_counts[i] += 1
|
1303 |
+
else:
|
1304 |
+
taxa_counts = torch.ones(len(ids), dtype=torch.int32)
|
1305 |
+
count_sum = torch.cumsum(taxa_counts, dim=0) - taxa_counts
|
1306 |
+
embs = embs['data']
|
1307 |
+
|
1308 |
+
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
|
1309 |
+
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
|
1310 |
+
countmap = torch.zeros(len(class_to_taxa), dtype=torch.int)
|
1311 |
+
self.species_emb = nn.Embedding(len(class_to_taxa), embedding_dim)
|
1312 |
+
self.species_emb.weight.data *= 0.01
|
1313 |
+
|
1314 |
+
for i in range(len(class_to_taxa)):
|
1315 |
+
if class_to_taxa[i] in ids:
|
1316 |
+
i2 = ids.index(class_to_taxa[i])
|
1317 |
+
indmap[i] = count_sum[i2]
|
1318 |
+
countmap[i] = taxa_counts[i2]
|
1319 |
+
|
1320 |
+
self.register_buffer('indmap', indmap, persistent=False)
|
1321 |
+
self.register_buffer('countmap', countmap, persistent=False)
|
1322 |
+
self.register_buffer('embs', embs, persistent=False)
|
1323 |
+
assert embs.shape[1] == 4096
|
1324 |
+
self.scale = nn.Parameter(torch.zeros(1))
|
1325 |
+
if params['species_dropout'] > 0:
|
1326 |
+
self.dropout = nn.Dropout(p=params['species_dropout'])
|
1327 |
+
if params['text_hidden_dim'] == 0:
|
1328 |
+
self.linear1 = nn.Linear(4096, embedding_dim)
|
1329 |
+
else:
|
1330 |
+
self.linear1 = nn.Linear(4096, params['text_hidden_dim'])
|
1331 |
+
if params['text_batchnorm']:
|
1332 |
+
self.bn1 = nn.BatchNorm1d(params['text_hidden_dim'])
|
1333 |
+
for l in range(params['text_num_layers']-1):
|
1334 |
+
setattr(self, f'linear{l+2}', nn.Linear(params['text_hidden_dim'], params['text_hidden_dim']))
|
1335 |
+
if params['text_batchnorm']:
|
1336 |
+
setattr(self, f'bn{l+2}', nn.BatchNorm1d(params['text_hidden_dim']))
|
1337 |
+
setattr(self, f'linear{params["text_num_layers"]+1}', nn.Linear(params['text_hidden_dim'], embedding_dim))
|
1338 |
+
self.act = nn.SiLU()
|
1339 |
+
if params['text_learn_dim'] > 0:
|
1340 |
+
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
|
1341 |
+
self.learned_emb.weight.data *= 0.01
|
1342 |
+
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
|
1343 |
+
|
1344 |
+
def forward(self, x):
|
1345 |
+
inds = self.indmap[x] + (torch.rand(x.shape,device=x.device)*self.countmap[x]).floor().int()
|
1346 |
+
out = self.embs[inds]
|
1347 |
+
if hasattr(self, 'dropout'):
|
1348 |
+
out = self.dropout(out)
|
1349 |
+
out = self.linear1(out)
|
1350 |
+
if hasattr(self, 'linear2'):
|
1351 |
+
out = self.act(out)
|
1352 |
+
if hasattr(self, 'bn1'):
|
1353 |
+
out = self.bn1(out)
|
1354 |
+
i = 2
|
1355 |
+
while hasattr(self, f'linear{i}'):
|
1356 |
+
if hasattr(self, f'linear{i}'):
|
1357 |
+
out = self.act(getattr(self, f'linear{i}')(out))
|
1358 |
+
if hasattr(self, f'bn{i}'):
|
1359 |
+
out = getattr(self, f'bn{i}')(out)
|
1360 |
+
i += 1
|
1361 |
+
#out = self.scale * (out / (out.std(dim=1)[:, None]))
|
1362 |
+
out2 = self.species_emb(x)
|
1363 |
+
chosen = torch.rand((out.shape[0],), device=x.device)
|
1364 |
+
chosen = 1+0*chosen #TODO fix this
|
1365 |
+
chosen[inds == -1] = 0
|
1366 |
+
out = chosen[:,None] * out + (1-chosen[:,None])*out2
|
1367 |
+
if hasattr(self, 'learned_emb'):
|
1368 |
+
out2 = self.learned_emb(x)
|
1369 |
+
out2 = self.linear_learned(out2)
|
1370 |
+
out = out+out2
|
1371 |
+
return out
|
1372 |
+
|
1373 |
+
|
1374 |
+
def zero_shot(self, species_emb):
|
1375 |
+
out = species_emb
|
1376 |
+
out = self.linear1(out)
|
1377 |
+
if hasattr(self, 'linear2'):
|
1378 |
+
out = self.act(out)
|
1379 |
+
if hasattr(self, 'bn1'):
|
1380 |
+
out = self.bn1(out)
|
1381 |
+
i = 2
|
1382 |
+
while hasattr(self, f'linear{i}'):
|
1383 |
+
if hasattr(self, f'linear{i}'):
|
1384 |
+
out = self.act(getattr(self, f'linear{i}')(out))
|
1385 |
+
if hasattr(self, f'bn{i}'):
|
1386 |
+
out = getattr(self, f'bn{i}')(out)
|
1387 |
+
i += 1
|
1388 |
+
return out
|
1389 |
+
|
1390 |
+
def zero_shot_old(self, species_emb):
|
1391 |
+
out = species_emb
|
1392 |
+
out = self.linear1(out)
|
1393 |
+
if hasattr(self, 'linear2'):
|
1394 |
+
out = self.linear2(self.act(out))
|
1395 |
+
out = self.scale * (out / (out.std(dim=-1, keepdim=True)))
|
1396 |
+
return out
|
1397 |
+
|
1398 |
+
# MINE - would only be used for my models - not currently being used at all
|
1399 |
+
# CURRENTLY JUST USING A HEADLESS_SINR FOR THE TEXT ENCODER
|
1400 |
+
class MultiInputTextEncoder(nn.Module):
|
1401 |
+
def __init__(self, token_dim, dropout, input_dim=4096, depth=2, hidden_dim=512, nonlin='relu', batch_norm=True, layer_norm=False):
|
1402 |
+
super(MultiInputTextEncoder, self).__init__()
|
1403 |
+
|
1404 |
+
print("THINK ABOUT IF SOME OF THESE HYPERPARAMETERS SHOULD BE DISTINCT FROM THE TRANSFORMER VERSION")
|
1405 |
+
print("DEPTH / NUM_ENCODER_LAYERS, DROPOUT, DIM_FEEDFORWARD, ETC")
|
1406 |
+
print("AT PRESENT WE JUST HAVE A SORT OF BASIC VERSION IMPLEMENTED THAT ATTEMPTS TO BE LIKE MAX'S VERSION")
|
1407 |
+
print("ALSO, OPTION TO HAVE IT PRETRAINED? ADD RESIDUAL LAYERS?")
|
1408 |
+
self.token_dim=token_dim
|
1409 |
+
self.dropout=dropout
|
1410 |
+
self.input_dim=input_dim
|
1411 |
+
self.depth=depth
|
1412 |
+
self.hidden_dim=hidden_dim
|
1413 |
+
self.batch_norm = batch_norm
|
1414 |
+
self.layer_norm = layer_norm
|
1415 |
+
|
1416 |
+
if nonlin == 'relu':
|
1417 |
+
activation = nn.ReLU
|
1418 |
+
elif nonlin == 'silu':
|
1419 |
+
activation = nn.SiLU
|
1420 |
+
else:
|
1421 |
+
raise NotImplementedError('Invalid nonlinearity specified.')
|
1422 |
+
|
1423 |
+
self.dropout_layer = nn.Dropout(p=self.dropout)
|
1424 |
+
if self.depth <= 1:
|
1425 |
+
self.linear1 = nn.Linear(self.input_dim, self.token_dim)
|
1426 |
+
|
1427 |
+
else:
|
1428 |
+
self.linear1 = nn.Linear(self.input_dim, self.hidden_dim)
|
1429 |
+
|
1430 |
+
if self.batch_norm:
|
1431 |
+
self.bn1 = nn.BatchNorm1d(self.hidden_dim)
|
1432 |
+
|
1433 |
+
# if self.layer_norm:
|
1434 |
+
# self.ln1 = nn.LayerNorm(self.hidden_dim)
|
paths.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"data": "data/",
|
3 |
+
"masks": "data/masks/",
|
4 |
+
"env": "data/env/",
|
5 |
+
"train": "data/train/",
|
6 |
+
"geo_prior": "data/eval/geo_prior/",
|
7 |
+
"snt": "data/eval/snt/",
|
8 |
+
"iucn": "data/eval/iucn/",
|
9 |
+
"geo_feature": "data/eval/geo_feature/"
|
10 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==3.36.1
|
2 |
+
h3==3.7.6
|
3 |
+
matplotlib==3.7.1
|
4 |
+
numpy==1.25.0
|
5 |
+
pandas==2.0.3
|
6 |
+
scikit_learn==1.3.0
|
7 |
+
scikit-image==0.19.3
|
8 |
+
tifffile==2023.7.4
|
9 |
+
torch==1.12.1
|
10 |
+
imagecodecs==2023.9.18
|
setup.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import datetime
|
6 |
+
#from h3.unstable import vect
|
7 |
+
import h3
|
8 |
+
|
9 |
+
class CoordEncoder:
|
10 |
+
|
11 |
+
def __init__(self, input_enc, raster=None, input_dim=0):
|
12 |
+
self.input_enc = input_enc
|
13 |
+
self.raster = raster
|
14 |
+
self.input_dim = input_dim
|
15 |
+
|
16 |
+
def encode(self, locs, normalize=True):
|
17 |
+
# assumes lon, lat in range [-180, 180] and [-90, 90]
|
18 |
+
if normalize:
|
19 |
+
locs = normalize_coords(locs)
|
20 |
+
if self.input_enc == 'none':
|
21 |
+
loc_feats = locs * torch.tensor([[180.0,90.0]], device=locs.device)
|
22 |
+
elif self.input_enc == 'sin_cos': # sinusoidal encoding
|
23 |
+
loc_feats = encode_loc(locs, input_dim=self.input_dim)
|
24 |
+
elif self.input_enc == 'env': # bioclim variables
|
25 |
+
loc_feats = bilinear_interpolate(locs, self.raster)
|
26 |
+
elif self.input_enc == 'sin_cos_env': # sinusoidal encoding & bioclim variables
|
27 |
+
loc_feats = encode_loc(locs, input_dim=self.input_dim)
|
28 |
+
context_feats = bilinear_interpolate(locs, self.raster.to(locs.device))
|
29 |
+
loc_feats = torch.cat((loc_feats, context_feats), 1)
|
30 |
+
elif self.input_enc == 'satclip': #SatClip Embedding
|
31 |
+
if not hasattr(self, 'model'):
|
32 |
+
import sys
|
33 |
+
sys.path.append('./satclip/satclip')
|
34 |
+
from satclip.satclip.load import get_satclip
|
35 |
+
self.model = get_satclip('satclip/satclip-vit16-l10.ckpt', device="cpu")
|
36 |
+
self.model.eval()
|
37 |
+
self.model = self.model.to(locs.device)
|
38 |
+
locs = locs*torch.tensor([[180.0, 90.0]], device=locs.device)
|
39 |
+
max_batch = 1000000
|
40 |
+
loc_feats = torch.empty(locs.shape[0], 256, device=locs.device)
|
41 |
+
with torch.no_grad():
|
42 |
+
for i in range(0, locs.shape[0], max_batch):
|
43 |
+
loc_feats[i:i+max_batch] = self.model(locs[i:i+max_batch].double()).float()
|
44 |
+
else:
|
45 |
+
raise NotImplementedError('Unknown input encoding.')
|
46 |
+
return loc_feats
|
47 |
+
|
48 |
+
def encode_fast(self, loc: list[float], normalize=True):
|
49 |
+
assert not normalize
|
50 |
+
if self.input_enc == 'sin_cos':
|
51 |
+
loc_feats = encode_loc_fast(loc, input_dim=self.input_dim)
|
52 |
+
else:
|
53 |
+
raise NotImplementedError('Unknown input encoding.')
|
54 |
+
return loc_feats
|
55 |
+
|
56 |
+
|
57 |
+
class TimeEncoder:
|
58 |
+
|
59 |
+
def __init__(self, input_enc='conical'):
|
60 |
+
self.input_enc = input_enc
|
61 |
+
|
62 |
+
def encode(self, intervals):
|
63 |
+
# assumes time, width in range [0, 1]
|
64 |
+
t_center = intervals[:, :1]
|
65 |
+
t_width = intervals[:, 1:]
|
66 |
+
if self.input_enc == 'conical':
|
67 |
+
t_feats = torch.cat([(1 - t_width) * torch.sin(2 * torch.pi * t_center),
|
68 |
+
(1 - t_width) * torch.cos(2 * torch.pi * t_center), 2 * t_width - 1], dim=1)
|
69 |
+
elif self.input_enc == 'cylindrical':
|
70 |
+
t_feats = torch.cat([torch.sin(2 * torch.pi * t_center), torch.cos(2 * torch.pi * t_center), 2 * t_width - 1], dim=1)
|
71 |
+
return t_feats
|
72 |
+
|
73 |
+
def encode_fast(self, intervals):
|
74 |
+
# assumes time, width in range [0, 1]
|
75 |
+
t_center, t_width = intervals
|
76 |
+
if self.input_enc == 'conical':
|
77 |
+
t_feats = torch.tensor([(1 - t_width) * math.sin(2 * math.pi * t_center),
|
78 |
+
(1 - t_width) * math.cos(2 * math.pi * t_center), 2 * t_width - 1])
|
79 |
+
elif self.input_enc == 'cylindrical':
|
80 |
+
t_feats = torch.tensor([math.sin(2 * math.pi * t_center),
|
81 |
+
math.cos(2 * math.pi * t_center), 2 * t_width - 1])
|
82 |
+
return t_feats
|
83 |
+
|
84 |
+
|
85 |
+
def normalize_coords(locs):
|
86 |
+
# locs is in lon {-180, 180}, lat {90, -90}
|
87 |
+
# output is in the range [-1, 1]
|
88 |
+
|
89 |
+
locs[:,0] /= 180.0
|
90 |
+
locs[:,1] /= 90.0
|
91 |
+
|
92 |
+
return locs
|
93 |
+
|
94 |
+
def encode_loc(loc_ip, concat_dim=1, input_dim=0):
|
95 |
+
# assumes inputs location are in range -1 to 1
|
96 |
+
# location is lon, lat
|
97 |
+
encs = []
|
98 |
+
for i in range(input_dim//4):
|
99 |
+
encs.append(torch.sin(math.pi*(2**i)*loc_ip))
|
100 |
+
encs.append(torch.cos(math.pi*(2**i)*loc_ip))
|
101 |
+
feats = torch.cat(encs, concat_dim)
|
102 |
+
return feats
|
103 |
+
|
104 |
+
|
105 |
+
def encode_loc_fast(loc_ip: list[float], input_dim=0):
|
106 |
+
# assumes inputs location are in range -1 to 1
|
107 |
+
# location is lon, lat
|
108 |
+
input_dim //= 2 # needed to make it compatible with encode_loc
|
109 |
+
feats = [(math.sin if i%(2*len(loc_ip))<len(loc_ip) else math.cos)(math.pi*(2**(i//(2*len(loc_ip))))*loc_ip[i%len(loc_ip)]) for i in range(input_dim)]
|
110 |
+
return feats
|
111 |
+
|
112 |
+
|
113 |
+
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
|
114 |
+
# loc is N x 2 vector, where each row is [lon,lat] entry
|
115 |
+
# each entry spans range [-1,1]
|
116 |
+
# data is H x W x C, height x width x channel data matrix
|
117 |
+
# op will be N x C matrix of interpolated features
|
118 |
+
|
119 |
+
assert data is not None
|
120 |
+
|
121 |
+
# map to [0,1], then scale to data size
|
122 |
+
loc = (loc_ip.clone() + 1) / 2.0
|
123 |
+
loc[:,1] = 1 - loc[:,1] # this is because latitude goes from +90 on top to bottom while
|
124 |
+
# longitude goes from -90 to 90 left to right
|
125 |
+
|
126 |
+
assert not torch.any(torch.isnan(loc))
|
127 |
+
|
128 |
+
if remove_nans_raster:
|
129 |
+
data[torch.isnan(data)] = 0.0 # replace with mean value (0 is mean post-normalization)
|
130 |
+
|
131 |
+
# cast locations into pixel space
|
132 |
+
loc[:, 0] *= (data.shape[1]-1)
|
133 |
+
loc[:, 1] *= (data.shape[0]-1)
|
134 |
+
|
135 |
+
loc_int = torch.floor(loc).long() # integer pixel coordinates
|
136 |
+
xx = loc_int[:, 0]
|
137 |
+
yy = loc_int[:, 1]
|
138 |
+
xx_plus = xx + 1
|
139 |
+
xx_plus[xx_plus > (data.shape[1]-1)] = data.shape[1]-1
|
140 |
+
yy_plus = yy + 1
|
141 |
+
yy_plus[yy_plus > (data.shape[0]-1)] = data.shape[0]-1
|
142 |
+
|
143 |
+
loc_delta = loc - torch.floor(loc) # delta values
|
144 |
+
dx = loc_delta[:, 0].unsqueeze(1)
|
145 |
+
dy = loc_delta[:, 1].unsqueeze(1)
|
146 |
+
|
147 |
+
interp_val = data[yy, xx, :]*(1-dx)*(1-dy) + data[yy, xx_plus, :]*dx*(1-dy) + \
|
148 |
+
data[yy_plus, xx, :]*(1-dx)*dy + data[yy_plus, xx_plus, :]*dx*dy
|
149 |
+
|
150 |
+
return interp_val
|
151 |
+
|
152 |
+
def rand_samples(batch_size, device, rand_type='uniform'):
|
153 |
+
# randomly sample background locations
|
154 |
+
|
155 |
+
if rand_type == 'spherical':
|
156 |
+
rand_loc = torch.rand(batch_size, 2).to(device)
|
157 |
+
theta1 = 2.0*math.pi*rand_loc[:, 0]
|
158 |
+
theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
|
159 |
+
lat = 1.0 - 2.0*theta2/math.pi
|
160 |
+
lon = (theta1/math.pi) - 1.0
|
161 |
+
rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
|
162 |
+
|
163 |
+
elif rand_type == 'uniform':
|
164 |
+
rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0
|
165 |
+
|
166 |
+
return rand_loc
|
167 |
+
|
168 |
+
def get_time_stamp():
|
169 |
+
cur_time = str(datetime.datetime.now())
|
170 |
+
date, time = cur_time.split(' ')
|
171 |
+
h, m, s = time.split(':')
|
172 |
+
s = s.split('.')[0]
|
173 |
+
time_stamp = '{}-{}-{}-{}'.format(date, h, m, s)
|
174 |
+
return time_stamp
|
175 |
+
|
176 |
+
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
|
177 |
+
# generate a grid of locations spaced evenly in coordinate space
|
178 |
+
|
179 |
+
feats = np.zeros((grid_size[0], grid_size[1], 2), dtype=np.float32)
|
180 |
+
mg = np.meshgrid(np.linspace(-180, 180, feats.shape[1]), np.linspace(90, -90, feats.shape[0]))
|
181 |
+
feats[:, :, 0] = mg[0]
|
182 |
+
feats[:, :, 1] = mg[1]
|
183 |
+
if split_ids is None or split_of_interest is None:
|
184 |
+
# return feats for all locations
|
185 |
+
# this will be an N x 2 array
|
186 |
+
return feats.reshape(feats.shape[0]*feats.shape[1], 2)
|
187 |
+
else:
|
188 |
+
# only select a subset of locations
|
189 |
+
ind_y, ind_x = np.where(split_ids==split_of_interest)
|
190 |
+
|
191 |
+
# these will be N_subset x 2 in size
|
192 |
+
return feats[ind_y, ind_x, :]
|
193 |
+
|
194 |
+
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
|
195 |
+
# generates a checkerboard style train test split
|
196 |
+
# 0 is invalid, 1 is train, and 2 is test
|
197 |
+
# c_size is units of pixels
|
198 |
+
split_ids = np.ones((raster.shape[0], raster.shape[1]))
|
199 |
+
start = cell_size
|
200 |
+
for ii in np.arange(0, split_ids.shape[0], cell_size):
|
201 |
+
if start == 0:
|
202 |
+
start = cell_size
|
203 |
+
else:
|
204 |
+
start = 0
|
205 |
+
for jj in np.arange(start, split_ids.shape[1], cell_size*2):
|
206 |
+
split_ids[ii:ii+cell_size, jj:jj+cell_size] = 2
|
207 |
+
split_ids = split_ids*mask
|
208 |
+
if train_amt < 1.0:
|
209 |
+
# take a subset of the data
|
210 |
+
tr_y, tr_x = np.where(split_ids==1)
|
211 |
+
inds = np.random.choice(len(tr_y), int(len(tr_y)*(1.0-train_amt)), replace=False)
|
212 |
+
split_ids[tr_y[inds], tr_x[inds]] = 0
|
213 |
+
return split_ids
|
214 |
+
|
215 |
+
def average_precision_score_faster(y_true, y_scores):
|
216 |
+
# drop in replacement for sklearn's average_precision_score
|
217 |
+
# comparable up to floating point differences
|
218 |
+
num_positives = y_true.sum()
|
219 |
+
inds = np.argsort(y_scores)[::-1]
|
220 |
+
y_true_s = y_true[inds]
|
221 |
+
|
222 |
+
false_pos_c = np.cumsum(1.0 - y_true_s)
|
223 |
+
true_pos_c = np.cumsum(y_true_s)
|
224 |
+
recall = true_pos_c / num_positives
|
225 |
+
false_neg = np.maximum(true_pos_c + false_pos_c, np.finfo(np.float32).eps)
|
226 |
+
precision = true_pos_c / false_neg
|
227 |
+
|
228 |
+
recall_e = np.hstack((0, recall, 1))
|
229 |
+
recall_e = (recall_e[1:] - recall_e[:-1])[:-1]
|
230 |
+
map_score = (recall_e*precision).sum()
|
231 |
+
return map_score
|
232 |
+
|
233 |
+
#TODO I might be able to just cast these to a float to make them 1 or 0
|
234 |
+
#TODO y_true are the same as the ones
|
235 |
+
def average_precision_score_fasterer(y_true, y_scores):
|
236 |
+
# drop in replacement for sklearn's average_precision_score
|
237 |
+
# comparable up to floating point differences
|
238 |
+
num_positives = y_true.sum()
|
239 |
+
inds = torch.argsort(y_scores, descending=True)
|
240 |
+
y_true_s = y_true[inds]
|
241 |
+
|
242 |
+
false_pos_c = torch.cumsum(1.0 - y_true_s, dim=0)
|
243 |
+
true_pos_c = torch.cumsum(y_true_s, dim=0)
|
244 |
+
recall = true_pos_c / num_positives
|
245 |
+
false_neg = (true_pos_c + false_pos_c).clip(min=np.finfo(np.float32).eps)
|
246 |
+
precision = true_pos_c / false_neg
|
247 |
+
|
248 |
+
recall_e = torch.cat([torch.zeros(1, device=recall.device), recall, torch.ones(1, device=recall.device)])
|
249 |
+
recall_e = (recall_e[1:] - recall_e[:-1])[:-1]
|
250 |
+
map_score = (recall_e*precision).sum()
|
251 |
+
return map_score
|
252 |
+
|
253 |
+
|
254 |
+
class DataPDFH3:
|
255 |
+
def __init__(self, data='data_pdf_h3.pt', device='cpu'):
|
256 |
+
super(DataPDFH3, self).__init__()
|
257 |
+
self.data = torch.cumsum(torch.load(data, map_location=device), dim=0)
|
258 |
+
self.data = torch.cat([torch.zeros_like(self.data[:1]), self.data], dim=0)
|
259 |
+
inds = torch.load('inds_h3.pt')
|
260 |
+
inds = ((inds >> 30) & 4194303)
|
261 |
+
self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
|
262 |
+
self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)
|
263 |
+
self.cum_counts = self.data.sum(dim=-1)
|
264 |
+
|
265 |
+
def _sample(self, pos, time, noise_level):
|
266 |
+
pos = pos.cpu()
|
267 |
+
time = time.cpu()
|
268 |
+
noise_level = noise_level.cpu()
|
269 |
+
t_low = (365*(time - 0.5*(noise_level))).int()
|
270 |
+
t_high = (365*(time + 0.5*(noise_level))).int()
|
271 |
+
t_high[t_low < 0] += 365
|
272 |
+
t_low[t_low < 0] += 365
|
273 |
+
|
274 |
+
pos_ind = torch.from_numpy((h3.latlng_to_cell(90*pos[:, 1], 180*pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
|
275 |
+
pos_ind = self.ind_map[pos_ind]
|
276 |
+
counts = self.data[t_high.clamp(max=364)+1, pos_ind] - self.data[t_low, pos_ind]
|
277 |
+
counts[t_high > 364] += self.data[(t_high[t_high > 364] - 365).clamp(max=364) + 1, pos_ind[t_high > 364]]
|
278 |
+
counts[t_high > 729] += self.data[(t_high[t_high > 729] - 730).clamp(max=364) + 1, pos_ind[t_high > 729]]
|
279 |
+
totals = self.cum_counts[t_high.clamp(max=364)+1] - self.cum_counts[t_low]
|
280 |
+
totals[t_high > 364] += self.cum_counts[(t_high[t_high > 364] - 365).clamp(max=364) + 1]
|
281 |
+
totals[t_high > 729] += self.cum_counts[(t_high[t_high > 729] - 730).clamp(max=364) + 1]
|
282 |
+
counts[pos_ind < 0] = 0
|
283 |
+
return counts, totals
|
284 |
+
|
285 |
+
def sample(self, pos, time, noise_level):
|
286 |
+
counts, totals = self._sample(pos, time, noise_level)
|
287 |
+
return counts/totals
|
288 |
+
|
289 |
+
def sample_log(self, pos, time, noise_level, eps=1e-2):
|
290 |
+
counts, totals = self._sample(pos, time, noise_level)
|
291 |
+
return torch.log(counts)-torch.log(totals+eps)
|
292 |
+
|
293 |
+
|
294 |
+
class LowRankModel:
|
295 |
+
def __init__(self, data='nmf_256.pt', device='cpu'):
|
296 |
+
super(LowRankModel, self).__init__()
|
297 |
+
dim=-1
|
298 |
+
x1, x2 = torch.load(data, map_location=device)
|
299 |
+
m = torch.load('class_counts_locs_h3.pt').float()
|
300 |
+
chosen_inds = m.sum(dim=0).to_dense().sort(descending=True).indices[:]
|
301 |
+
if dim == 0:
|
302 |
+
n = m.to_dense()[:, chosen_inds].sum(dim=dim, keepdim=True)
|
303 |
+
self.data = n*torch.softmax(x1 @ x2, dim=dim)
|
304 |
+
self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
|
305 |
+
elif dim == 1:
|
306 |
+
self.data = torch.softmax(x1 @ x2, dim=dim)
|
307 |
+
elif dim == -1:
|
308 |
+
self.data = torch.from_numpy(x1 @ x2)
|
309 |
+
self.data = self.data/torch.sum(self.data, dim=1, keepdim=True)
|
310 |
+
m = m.to_dense()[:, chosen_inds]
|
311 |
+
#self.data = m.to_dense().float()/torch.sum(m.to_dense(), dim=1, keepdim=True)
|
312 |
+
self.pc = m.sum(dim=1, keepdim=True) / m.sum()
|
313 |
+
inds = torch.load('inds_h3.pt')[chosen_inds]
|
314 |
+
inds = ((inds >> 30) & 4194303)
|
315 |
+
self.ind_map = -1+torch.zeros(2 ** 22, dtype=torch.int32)
|
316 |
+
self.ind_map[inds] = torch.arange(inds.shape[0], dtype=torch.int32)
|
317 |
+
|
318 |
+
def sample(self, pos):#, time, noise_level):
|
319 |
+
pos = pos.cpu()
|
320 |
+
pos_ind = torch.from_numpy((h3.latlng_to_cell(pos[:, 1], pos[:, 0], 5).astype(np.int64) >> 30) & 4194303)
|
321 |
+
pos_ind = self.ind_map[pos_ind]
|
322 |
+
out = self.data[:, pos_ind]
|
323 |
+
out *= self.pc
|
324 |
+
out = out/torch.sum(out, dim=0, keepdim=True)
|
325 |
+
out[:, pos_ind < 0] = 1.0/out.shape[0]
|
326 |
+
return out
|
viz_ls_map.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Demo that takes an iNaturalist taxa ID as input and generates a prediction
|
3 |
+
for each location on the globe and saves the ouput as an image.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
import argparse
|
12 |
+
|
13 |
+
import utils
|
14 |
+
import datasets
|
15 |
+
import eval
|
16 |
+
import create_inputs_to_fs_sinr
|
17 |
+
|
18 |
+
text_model = './experiments/gpt_data.pt'
|
19 |
+
|
20 |
+
def extract_grit_token(model, text:str):
|
21 |
+
def gritlm_instruction(instruction):
|
22 |
+
return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
|
23 |
+
d_rep = model.encode([text], instruction=gritlm_instruction(""))
|
24 |
+
d_rep = torch.from_numpy(d_rep)
|
25 |
+
return d_rep
|
26 |
+
|
27 |
+
def choose_context_points_from_map(eval_params):
|
28 |
+
context_points = []
|
29 |
+
|
30 |
+
if False:
|
31 |
+
def onclick(event):
|
32 |
+
if event.xdata is not None and event.ydata is not None:
|
33 |
+
# Convert image coordinates to normalized geographical coordinates
|
34 |
+
lon = event.xdata / mask.shape[1] * 2 - 1
|
35 |
+
lat = 1 - event.ydata / mask.shape[0] * 2
|
36 |
+
context_points.append((lon, lat))
|
37 |
+
print(f"Added context point: ({lon}, {lat})")
|
38 |
+
|
39 |
+
# Load ocean mask
|
40 |
+
with open('paths.json', 'r') as f:
|
41 |
+
paths = json.load(f)
|
42 |
+
if eval_params['high_res']:
|
43 |
+
mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
|
44 |
+
else:
|
45 |
+
mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
|
46 |
+
|
47 |
+
mask_inds = np.where(mask.reshape(-1) == 1)[0]
|
48 |
+
|
49 |
+
# # Generate input features
|
50 |
+
# locs = utils.coord_grid(mask.shape)
|
51 |
+
# if not eval_params['disable_ocean_mask']:
|
52 |
+
# locs = locs[mask_inds, :]
|
53 |
+
# locs = torch.from_numpy(locs)
|
54 |
+
|
55 |
+
# Reshape and create masked array for visualization
|
56 |
+
op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # Set to NaN
|
57 |
+
op_im[mask_inds] = 0 # Placeholder for the mask visualization
|
58 |
+
op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
|
59 |
+
op_im = np.ma.masked_invalid(op_im)
|
60 |
+
|
61 |
+
# Set color for masked values
|
62 |
+
cmap = plt.cm.plasma
|
63 |
+
cmap.set_bad(color='none')
|
64 |
+
plt.ioff()
|
65 |
+
# Display the map and capture context points
|
66 |
+
fig, ax = plt.subplots(figsize=(6, 3), dpi=334) # Define the figure size
|
67 |
+
ax.imshow(op_im, cmap=cmap, interpolation='nearest') # Display the image
|
68 |
+
ax.axis('off') # Turn off the axis
|
69 |
+
|
70 |
+
# Connect the onclick event to the handler
|
71 |
+
cid = fig.canvas.mpl_connect('button_press_event', onclick)
|
72 |
+
|
73 |
+
plt.show(block=True) # Block execution until the window is closed
|
74 |
+
|
75 |
+
print(f"Context points collected: {context_points}")
|
76 |
+
|
77 |
+
else:
|
78 |
+
#USA
|
79 |
+
#TODO: 37.541170, -92.003293 1. flip order, then 2. normalize so divide by 180 and 90
|
80 |
+
context_points = [(-0.5884012559178662, 0.46394662490802496), (-0.5451199953511522, 0.4504212309809269),
|
81 |
+
(-0.5437674559584422, 0.5342786733289353), (-0.589753795310576, 0.5342786733289353)]
|
82 |
+
print(f"Context points collected: {context_points}")
|
83 |
+
return context_points
|
84 |
+
|
85 |
+
def main(eval_params):
|
86 |
+
# load params
|
87 |
+
with open('paths.json', 'r') as f:
|
88 |
+
paths = json.load(f)
|
89 |
+
|
90 |
+
ckp_name = os.path.split(eval_params['model_path'])[-1]
|
91 |
+
experiment_name = os.path.split(os.path.split(eval_params['model_path'])[-2])[-1]
|
92 |
+
|
93 |
+
eval_overrides = {'ckp_name':ckp_name,
|
94 |
+
'experiment_name':experiment_name,
|
95 |
+
'device':eval_params['device']}
|
96 |
+
|
97 |
+
|
98 |
+
train_overrides = {'dataset': 'eval_transformer'}
|
99 |
+
#grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
|
100 |
+
#grit_gpt = torch.load(text_model, map_location='cpu')
|
101 |
+
#context_model = torch.load("experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt", map_location=torch.device('cpu'))
|
102 |
+
context_data = np.load('data/positive_eval_data.npz')
|
103 |
+
text_type_value = 0
|
104 |
+
|
105 |
+
for pt in eval_params['context_pt_trial']:
|
106 |
+
number_of_context_points = pt
|
107 |
+
if eval_params['choose_context_points'] == 1:
|
108 |
+
#context_points = choose_context_points_from_map(eval_params)
|
109 |
+
text_emb, text_type_value = create_inputs_to_fs_sinr.use_pregenerated_textemb_fromchris(taxon_id=eval_params['test_taxa'],
|
110 |
+
text_type=eval_params['text_type'])
|
111 |
+
context_points = create_inputs_to_fs_sinr.get_eval_context_points(taxa_id=eval_params['test_taxa'],
|
112 |
+
context_data=context_data,
|
113 |
+
size=number_of_context_points)
|
114 |
+
model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embedding_from_given_points(
|
115 |
+
context_points=context_points,
|
116 |
+
overrides=eval_overrides,
|
117 |
+
taxa_of_interest=eval_params['taxa_id'],
|
118 |
+
train_overrides=train_overrides,
|
119 |
+
text_emb=text_emb)
|
120 |
+
#TODO: why is taxa_id updated to 'selected pts'??
|
121 |
+
eval_params['taxa_id'] = 'selected_points'
|
122 |
+
else:
|
123 |
+
model, context_locs_of_interest, train_params, class_of_interest = eval.generate_eval_embeddings(
|
124 |
+
overrides=eval_overrides,
|
125 |
+
taxa_of_interest=eval_params['taxa_id'],
|
126 |
+
num_context=eval_params['num_context'],
|
127 |
+
train_overrides=train_overrides)
|
128 |
+
|
129 |
+
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
|
130 |
+
raster = datasets.load_env()
|
131 |
+
else:
|
132 |
+
raster = None
|
133 |
+
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster, input_dim=train_params['params']['input_dim'])
|
134 |
+
enc_time = utils.CoordEncoder('sin_cos', raster=None, input_dim=2 * train_params['params']['input_time_dim'])
|
135 |
+
|
136 |
+
# load ocean mask
|
137 |
+
if eval_params['high_res']:
|
138 |
+
mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
|
139 |
+
else:
|
140 |
+
mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
|
141 |
+
#mask = 0*mask+1
|
142 |
+
mask_inds = np.where(mask.reshape(-1) == 1)[0]
|
143 |
+
|
144 |
+
# generate input features
|
145 |
+
locs = utils.coord_grid(mask.shape)
|
146 |
+
if not eval_params['disable_ocean_mask']:
|
147 |
+
locs = locs[mask_inds, :]
|
148 |
+
locs = torch.from_numpy(locs)
|
149 |
+
locs_enc = enc.encode(locs).to(eval_params['device'])
|
150 |
+
if train_params['params']['input_time_dim'] > 0:
|
151 |
+
extra_input = torch.cat([enc_time.encode(torch.tensor([[0.0]]), normalize=False), torch.tensor([[1.0]])],
|
152 |
+
dim=1).to(eval_params['device'])
|
153 |
+
locs_enc = torch.cat((locs_enc, extra_input.repeat(locs_enc.shape[0], 1)), dim=1)
|
154 |
+
|
155 |
+
with torch.no_grad():
|
156 |
+
# Here if we set eval to False we will see what the ema embeddings look like (currently as ema is 1.0 this is just the last training example seen)
|
157 |
+
preds = model.embedding_forward(x=locs_enc, class_ids=None, return_feats=False, class_of_interest=class_of_interest, eval=True).cpu().numpy()
|
158 |
+
|
159 |
+
# threshold predictions
|
160 |
+
if eval_params['threshold'] > 0:
|
161 |
+
print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
|
162 |
+
preds[preds<eval_params['threshold']] = 0.0
|
163 |
+
preds[preds>=eval_params['threshold']] = 1.0
|
164 |
+
|
165 |
+
# mask data
|
166 |
+
if not eval_params['disable_ocean_mask']:
|
167 |
+
op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # set to NaN
|
168 |
+
op_im[mask_inds] = preds
|
169 |
+
else:
|
170 |
+
op_im = preds
|
171 |
+
|
172 |
+
# reshape and create masked array for visualization
|
173 |
+
op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
|
174 |
+
op_im = np.ma.masked_invalid(op_im)
|
175 |
+
|
176 |
+
# set color for masked values
|
177 |
+
cmap = plt.cm.plasma
|
178 |
+
cmap.set_bad(color='none')
|
179 |
+
if eval_params['set_max_cmap_to_1']:
|
180 |
+
vmax = 1.0
|
181 |
+
else:
|
182 |
+
vmax = np.max(op_im)
|
183 |
+
|
184 |
+
# # Display the image
|
185 |
+
# if eval_params['show_map'] == 1:
|
186 |
+
# fig, ax = plt.subplots()
|
187 |
+
# cax = ax.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap)
|
188 |
+
# fig.colorbar(cax)
|
189 |
+
# plt.show(block=True) # Set block=True to block code execution until the window is closed
|
190 |
+
|
191 |
+
if eval_params['show_map'] == 1:
|
192 |
+
# Display the image
|
193 |
+
fig, ax = plt.subplots(figsize=(6,3), dpi=334)
|
194 |
+
plt.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap, interpolation='nearest') # Display the image
|
195 |
+
plt.axis('off') # Turn off the axis
|
196 |
+
|
197 |
+
if eval_params['show_context_points'] == 1:
|
198 |
+
# Convert the tensor to numpy array if it's not already
|
199 |
+
context_locs = context_locs_of_interest.numpy() if isinstance(context_locs_of_interest, torch.Tensor) else context_locs_of_interest
|
200 |
+
# Convert context locations directly to image coordinates
|
201 |
+
#delete our dumby context point (at 0,0)
|
202 |
+
image_x = (context_locs[1:, 0] + 1) / 2 * op_im.shape[1] # Scale longitude from [-1, 1] to [0, image width]
|
203 |
+
image_y = (1 - (context_locs[1:, 1] + 1) / 2) * op_im.shape[
|
204 |
+
0] # Scale latitude from [-1, 1] to [0, image height]
|
205 |
+
|
206 |
+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
207 |
+
# Plot the context locations
|
208 |
+
def getImage(path):
|
209 |
+
return OffsetImage(plt.imread(path), zoom=.04)
|
210 |
+
|
211 |
+
for x0, y0 in zip(image_x, image_y):
|
212 |
+
ab = AnnotationBbox(getImage('black_circle.png'), (x0, y0), frameon=False)
|
213 |
+
ax.add_artist(ab)
|
214 |
+
#plt.scatter(image_x, image_y, c='green', s=30, marker=r'$\checkmark$') # Adjust color and size of the point
|
215 |
+
|
216 |
+
#plt.show(block=True) # Block execution until the window is closed
|
217 |
+
|
218 |
+
|
219 |
+
exp_name = eval_params['model_path'].split(os.path.sep)[-2]
|
220 |
+
|
221 |
+
# save image
|
222 |
+
#save_loc = os.path.join(eval_params['op_path'], exp_name + '_' + str(eval_params['taxa_id']) + '_' + eval_params['additional_save_name'] +'_map.png')
|
223 |
+
#save_loc = os.path.join(eval_params['op_path'], exp_name + '_' + str(eval_params['taxa_id']) + '_' + eval_params['additional_save_name'] +'_map.png')
|
224 |
+
#save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + eval_params['taxa_id'] + ')_'+ eval_params['text_type'] + '(' + str(text_type_value) + ')_' + str(number_of_context_points) +'.png'
|
225 |
+
save_loc = 'images/testenv_' + eval_params['taxa_name'] + '(' + eval_params['taxa_id'] + ')_'+ eval_params['text_type'] + '_' + str(number_of_context_points) +'.png'
|
226 |
+
print(f'Saving image to {save_loc}')
|
227 |
+
plt.savefig(save_loc, bbox_inches='tight', pad_inches=0, dpi=334)
|
228 |
+
# plt.imsave(fname=save_loc, arr=op_im, vmin=0, vmax=vmax, cmap=cmap)
|
229 |
+
plt.show(block=False) # Block execution until the window is closed
|
230 |
+
|
231 |
+
return True
|
232 |
+
|
233 |
+
|
234 |
+
if __name__ == '__main__':
|
235 |
+
|
236 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
237 |
+
|
238 |
+
|
239 |
+
info_str = '\nDemo that takes an iNaturalist taxa ID as input and ' + \
|
240 |
+
'generates a predicted range for each location on the globe ' + \
|
241 |
+
'and saves the ouput as an image.\n\n' + \
|
242 |
+
'Warning: these estimated ranges should be validated before use.'
|
243 |
+
|
244 |
+
parser = argparse.ArgumentParser(usage=info_str)
|
245 |
+
# parser.add_argument('--model_path', type=str, default='./pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt')
|
246 |
+
# parser.add_argument('--model_path', type=str, default='./experiments/transformer_ema_1.0/model_10.pt')
|
247 |
+
# parser.add_argument('--model_path', type=str, default='./experiments/03_08_coord_multihead.pt/model.pt')
|
248 |
+
# parser.add_argument('--model_path', type=str, default='./experimentvs/coord_context_20_without_registry/model_best.pt')
|
249 |
+
# parser.add_argument('--model_path', type=str, default='./experiments/coord_sinr_inputs_context_20_without_registry/model_best.pt')
|
250 |
+
parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
|
251 |
+
#parser.add_argument('--model_path', type=str, default='./experiments/zero_shot_ls_sin_cos_cap_1000_text_context_20_sinr_two_layer_nn/model.pt')
|
252 |
+
# parser.add_argument('--taxa_id', type=int, default=144575, help='iNaturalist taxon ID.')
|
253 |
+
# parser.add_argument('--taxa_id', type=int, default=9083, help='iNaturalist taxon ID.')
|
254 |
+
parser.add_argument('--taxa_id', type=int, default=3352, help='iNaturalist taxon ID.')
|
255 |
+
parser.add_argument('--threshold', type=float, default=-1, help='Threshold the range map [0, 1].')
|
256 |
+
parser.add_argument('--op_path', type=str, default='./images/', help='Location where the output image will be saved.')
|
257 |
+
parser.add_argument('--rand_taxa', action='store_true', help='Select a random taxa.')
|
258 |
+
parser.add_argument('--high_res', action='store_true', help='Generate higher resolution output.')
|
259 |
+
parser.add_argument('--disable_ocean_mask', action='store_true', help='Do not use an ocean mask.')
|
260 |
+
parser.add_argument('--set_max_cmap_to_1', action='store_true', help='Consistent maximum intensity ouput.')
|
261 |
+
parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda')
|
262 |
+
#parser.add_argument('--device', type=str, default='cuda:3', help='cpu or cuda')
|
263 |
+
parser.add_argument('--show_map', type=int, default=1, help='shows the map if 1')
|
264 |
+
parser.add_argument('--show_context_points', type=int, default=1, help='also plots context points if 1')
|
265 |
+
parser.add_argument('--prefix', type=str, default='')
|
266 |
+
parser.add_argument('--num_context', type=int, default=5)
|
267 |
+
parser.add_argument('--choose_context_points', type=int, default=1)
|
268 |
+
parser.add_argument('--additional_save_name', type=str, default="")
|
269 |
+
#taxas: black&whitewarbler(10286), hyacinth macaw(18938), yellow baboon(67683)
|
270 |
+
# bawnswallow (11901), pika(43188), loon(4626), eurorobin(13094)
|
271 |
+
# southernflyingsquirrel (46272)
|
272 |
+
parser.add_argument('--taxa_name', type=str, default='sfs', help='Name of the taxon.')
|
273 |
+
parser.add_argument('--test_taxa', type=int, default=46272, help='Taxon ID to test.')
|
274 |
+
parser.add_argument('--text_type', type=str, default='range', help='Type of text for input.')
|
275 |
+
parser.add_argument('--context_pt_trial', type=int, nargs='+', default=[0, 1, 2, 5, 10, 20], help='List of context points for trial.')
|
276 |
+
eval_params = vars(parser.parse_args())
|
277 |
+
|
278 |
+
if not os.path.isdir(eval_params['op_path']):
|
279 |
+
os.makedirs(eval_params['op_path'])
|
280 |
+
|
281 |
+
eval_params['high_res'] = True
|
282 |
+
|
283 |
+
main(eval_params)
|