"""Rasterise species presence data onto lon/lat grids and save heatmaps.

Data layout, as described by the comments in the original script:

* IUCN JSON (``iucn_res_5.json``): ``data['locs']`` holds observation
  locations as ``(lon, lat)`` with lon in [-180, 180] and lat in [-90, 90];
  ``data['taxa_presence']`` maps a taxon id (string key) to the indices of
  ``locs`` where that species is present.
* Ocean mask (``data/masks/ocean_mask.npy``): 1002 x 2004 pixels, 0 in the
  ocean (ocean pixels are masked out when plotting).

All file paths are relative to the current working directory.
"""
import json
import os

import h3
import matplotlib.pyplot as plt
import numpy as np
# Not used in this part of the file; kept so existing importers keep working.
from mpl_toolkits.axes_grid1 import make_axes_locatable  # noqa: F401

# Location of the land/ocean mask used by the plotting helpers.
OCEAN_MASK_PATH = "data/masks/ocean_mask.npy"
# JSON file mapping dataset names ('snt', 'iucn') to their directories.
PATHS_FILE = "paths.json"
# (height, width) of the high-resolution grid; equals the ocean-mask size.
HIGH_RES_GRID = (1002, 2004)


def get_labels(species, data):
    """Collect locations and binary labels for one species from H3-keyed data.

    Args:
        species: Species identifier; converted to ``str`` before lookup.
        data: Mapping from H3 cell id to a mapping from species id (string)
            to a sized collection. Cells that do not list the species are
            skipped.

    Returns:
        ``(obs_locs, gt)`` where ``obs_locs`` is a float32 array of shape
        (N, 2) holding ``(lon, lat)`` of each cell that lists the species,
        and ``gt`` is a float32 array of length N that is 1.0 where the
        species' collection is non-empty and 0.0 otherwise.
    """
    species = str(species)
    lat, lon, gt = [], [], []
    for hx in data:
        cell = data[hx]
        if species not in cell:
            continue
        # Only resolve the cell centre for cells that are actually kept.
        cur_lat, cur_lon = h3.h3_to_geo(hx)
        gt.append(int(len(cell[species]) > 0))
        lat.append(cur_lat)
        lon.append(cur_lon)
    lat = np.array(lat, dtype=np.float32)
    lon = np.array(lon, dtype=np.float32)
    obs_locs = np.vstack((lon, lat)).T
    gt = np.array(gt, dtype=np.float32)
    return obs_locs, gt


def lonlat_to_pixel(lonlat, grid_width, grid_height):
    """Convert normalised lon/lat (each in -1..1) to integer pixel indices.

    Longitude -1 maps to column 0 and +1 to column ``grid_width - 1``;
    latitude +1 maps to row 0 and -1 to row ``grid_height - 1`` (the image
    origin is the top-left corner). Inputs outside [-1, 1] produce indices
    outside the grid; no clipping is done here.

    Args:
        lonlat: Array of shape (N, 2) with columns ``(lon, lat)``.
        grid_width: Number of pixel columns.
        grid_height: Number of pixel rows.

    Returns:
        ``(x_pixel, y_pixel)`` integer arrays of length N.
    """
    x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int)
    y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int)
    return x_pixel, y_pixel


def _render_heatmap(data, save_loc, *, mask_stride=1, show=False):
    """Mask out the ocean, draw ``data`` as a two-colour map and save it.

    Args:
        data: 2-D array whose shape matches the ocean mask after the mask
            has been subsampled by ``mask_stride``.
        save_loc: Output path without extension; ``.pdf`` and ``.png`` files
            are written. The parent directory is created if missing.
        mask_stride: Keep every ``mask_stride``-th row and column of the
            ocean mask before applying it.
        show: If true, call ``plt.show(block=False)`` after saving.
    """
    # NOTE(security): allow_pickle=True can execute arbitrary code when the
    # file is untrusted; only load masks from a trusted location.
    ocean_mask = np.load(OCEAN_MASK_PATH, allow_pickle=True)
    land = ocean_mask.astype(bool)[::mask_stride, ::mask_stride]

    # Ocean pixels become NaN and are then masked so they draw transparent.
    data = np.where(land, data, np.nan)
    data_masked = np.ma.array(data, mask=np.isnan(data))

    out_dir = os.path.dirname(save_loc)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
    try:
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)
        ax.axis('off')

        # 'plasma' colormap reduced to two discrete colours (absent/present).
        cmap = plt.get_cmap('plasma', 2)
        # 'none' makes masked pixels transparent; use 'white' for a white
        # background instead.
        cmap.set_bad(color='none')

        ax.imshow(
            data_masked,
            extent=(-180, 180, -90, 90),
            origin='upper',
            cmap=cmap,
            vmin=0,
            vmax=1,
            interpolation='nearest',
        )

        plt.tight_layout()
        plt.savefig(save_loc + '.pdf', bbox_inches='tight', pad_inches=0)
        plt.savefig(save_loc + '.png', bbox_inches='tight', pad_inches=0)
        if show:
            plt.show(block=False)
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)


def plot_heatmap(data, save_loc):
    """Save a half-resolution heatmap of ``data``.

    The ocean mask is subsampled by 2 in each direction, so ``data`` must
    have half the mask's height and width (501 x 1002 for the 1002 x 2004
    mask described in the module docstring).

    Args:
        data: 2-D array of values in [0, 1].
        save_loc: Output path without extension (``.pdf`` and ``.png`` are
            written).
    """
    _render_heatmap(data, save_loc, mask_stride=2)


def plot_heatmap_2(data, save_loc):
    """Save a full-resolution heatmap of ``data`` and show it without blocking.

    The ocean mask is applied at its native resolution, so ``data`` must
    have the same shape as the mask.

    Args:
        data: 2-D array of values in [0, 1].
        save_loc: Output path without extension (``.pdf`` and ``.png`` are
            written).
    """
    _render_heatmap(data, save_loc, mask_stride=1, show=True)


def _load_species_locs(taxa_id, snt):
    """Return the (N, 2) ``(lon, lat)`` presence locations of one taxon.

    Args:
        taxa_id: Taxon identifier to look up.
        snt: If true, read ``snt_res_5.npy`` from ``paths['snt']``;
            otherwise read ``iucn_res_5.json`` from ``paths['iucn']``.

    Raises:
        ValueError: ``snt`` is true and ``taxa_id`` is not in the dataset.
        KeyError: ``snt`` is false and ``taxa_id`` is not in the dataset.
    """
    with open(PATHS_FILE, 'r') as f:
        paths = json.load(f)

    if snt:
        # NOTE(security): allow_pickle=True can execute arbitrary code when
        # the file is untrusted; only load datasets from a trusted location.
        D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
        D = D.item()
        # asarray guarantees an element-wise comparison even if 'taxa' was
        # stored as a plain list.
        class_indices = np.where(np.asarray(D['taxa']) == taxa_id)[0]
        if len(class_indices) == 0:
            raise ValueError(f"taxa_id {taxa_id} not found in taxa")
        class_index = class_indices[0]
        species_loc_indices = np.array(D['loc_indices_per_species'][class_index])
        species_locs = D['obs_locs'][species_loc_indices]
        presence = np.array(D['labels_per_species'][class_index])
        # Keep only the locations labelled as presence.
        return species_locs[presence == 1]

    with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
        data = json.load(f)
    obs_locs = np.array(data['locs'], dtype=np.float32)
    indices = data['taxa_presence'][str(taxa_id)]
    return obs_locs[indices]


def _rasterize_presence(species_locs, grid_height, grid_width):
    """Burn ``(lon, lat)`` degree locations into a 0/1 presence grid.

    Args:
        species_locs: Array of shape (N, 2) with lon in degrees
            (-180..180) and lat in degrees (-90..90).
        grid_height: Number of rows of the output grid.
        grid_width: Number of columns of the output grid.

    Returns:
        Float array of shape ``(grid_height, grid_width)`` that is 1 at
        every pixel containing at least one location and 0 elsewhere.
    """
    # Scale degrees to the -1..1 range expected by lonlat_to_pixel. Building
    # a new array keeps integer-typed inputs from being truncated.
    normalized = np.column_stack(
        (species_locs[:, 0] / 180, species_locs[:, 1] / 90)
    )
    x_pixel, y_pixel = lonlat_to_pixel(normalized, grid_width, grid_height)
    # Guard against coordinates that fall marginally outside the valid range.
    x_pixel = np.clip(x_pixel, 0, grid_width - 1)
    y_pixel = np.clip(y_pixel, 0, grid_height - 1)

    grid = np.zeros((grid_height, grid_width))
    grid[y_pixel, x_pixel] = 1
    return grid


def generate_ground_truth(taxa_id, snt=True, grid_height=501, grid_width=1002):
    """Plot the ground-truth presence map of a taxon at two resolutions.

    Writes ``./images/species_presence_<taxa_id>.{pdf,png}`` on a
    ``grid_height`` x ``grid_width`` grid, then
    ``./images/species_presence_hr_<taxa_id>.{pdf,png}`` on the fixed
    1002 x 2004 grid. The dataset is read once and reused for both plots.

    Args:
        taxa_id: Taxon identifier to plot.
        snt: Select the S&T dataset (true) or the IUCN dataset (false).
        grid_height: Rows of the low-resolution grid; must be half the
            ocean-mask height because ``plot_heatmap`` subsamples the mask
            by 2.
        grid_width: Columns of the low-resolution grid; must be half the
            ocean-mask width.

    Returns:
        ``True`` once both plots have been written.

    Raises:
        ValueError: ``snt`` is true and ``taxa_id`` is not in the dataset.
        KeyError: ``snt`` is false and ``taxa_id`` is not in the dataset.
    """
    print(taxa_id)
    species_locs = _load_species_locs(taxa_id, snt)

    low_res = _rasterize_presence(species_locs, grid_height, grid_width)
    plot_heatmap(low_res, f"./images/species_presence_{taxa_id}")

    hr_height, hr_width = HIGH_RES_GRID
    high_res = _rasterize_presence(species_locs, hr_height, hr_width)
    plot_heatmap_2(high_res, f"./images/species_presence_hr_{taxa_id}")

    return True


if __name__ == '__main__':
    snt = True
    grid_height = 501
    grid_width = 1002
    taxa_id = 11901  # Or any taxa id you want to plot
    # TODO: why snt true? can't generate gt for (hyacinth macaw (18938),
    # yellow baboon (67683), pika (43188), southern flying squirrel (46272))
    generate_ground_truth(taxa_id=taxa_id, snt=snt, grid_height=grid_height, grid_width=grid_width)