import torch
import torch.utils.data
import torch.nn as nn
import math
import csv
import numpy as np
import json
import os
def get_model(params, inference_only=False):
if params['model'] == 'ResidualFCNet':
return ResidualFCNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'] + (20 if 'env' in params['loss'] else 0), params['num_filts'], params['depth'])
elif params['model'] == 'LinNet':
return LinNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'])
elif params['model'] == 'HyperNet':
return HyperNet(params, params['input_dim'] + (20 if 'env' in params['input_enc'] else 0), params['num_classes'], params['num_filts'], params['depth'],
params['species_dim'], params['species_enc_depth'], params['species_filts'], params['species_enc'], inference_only=inference_only)
# chris models
elif params['model'] == 'MultiInputModel':
return MultiInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False,
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
text_inputs=params['use_text_inputs'], class_token_transformation=params['class_token_transformation'])
elif params['model'] == 'VariableInputModel':
return VariableInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False,
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
text_inputs=params['use_text_inputs'], image_inputs=params['use_image_inputs'],
env_inputs=params['use_env_inputs'],
class_token_transformation=params['class_token_transformation'])
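# Illustrative usage of get_model (sketch only): the params dict below is a hypothetical,
# minimal configuration covering just the keys read by the 'ResidualFCNet' branch; real
# experiment configs contain many more entries.
def _example_get_model_usage():
    params = {
        'model': 'ResidualFCNet',
        'input_dim': 4, 'input_time_dim': 0, 'input_enc': 'sin_cos', 'noise_time': False,
        'num_classes': 10, 'num_filts': 256, 'depth': 4, 'loss': 'an_full',
    }
    return get_model(params)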
class ResLayer(nn.Module):
def __init__(self, linear_size, activation=nn.ReLU, p=0.5):
super(ResLayer, self).__init__()
self.l_size = linear_size
self.nonlin1 = activation()
self.nonlin2 = activation()
self.dropout1 = nn.Dropout(p=p)
self.w1 = nn.Linear(self.l_size, self.l_size)
self.w2 = nn.Linear(self.l_size, self.l_size)
def forward(self, x):
y = self.w1(x)
y = self.nonlin1(y)
y = self.dropout1(y)
y = self.w2(y)
y = self.nonlin2(y)
out = x + y
return out
class ResidualFCNet(nn.Module):
def __init__(self, num_inputs, num_classes, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
super(ResidualFCNet, self).__init__()
self.inc_bias = False
if lowrank < num_filts and lowrank != 0:
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
l2 = nn.Linear(lowrank, num_classes, bias=self.inc_bias)
self.class_emb = nn.Sequential(l1, l2)
else:
self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
if nonlin == 'relu':
activation = nn.ReLU
elif nonlin == 'silu':
activation = nn.SiLU
else:
raise NotImplementedError('Invalid nonlinearity specified.')
layers = []
if depth != -1:
layers.append(nn.Linear(num_inputs, num_filts))
layers.append(activation())
for i in range(depth):
layers.append(ResLayer(num_filts, activation=activation))
else:
layers.append(nn.Identity())
self.feats = torch.nn.Sequential(*layers)
def forward(self, x, class_of_interest=None, return_feats=False):
loc_emb = self.feats(x)
if return_feats:
return loc_emb
if class_of_interest is None:
class_pred = self.class_emb(loc_emb)
        else:
            # Visualisation path: return sigmoid scores for both the class of interest and the last class.
            class_pred = self.eval_single_class(loc_emb, class_of_interest), self.eval_single_class(loc_emb, -1)
            return torch.sigmoid(class_pred[0]), torch.sigmoid(class_pred[1])
        return torch.sigmoid(class_pred)
def eval_single_class(self, x, class_of_interest):
if self.inc_bias:
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
else:
return x @ self.class_emb.weight[class_of_interest, :]
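# Sketch of how ResidualFCNet is typically used (hypothetical shapes): the model maps
# encoded locations to per-class sigmoid scores, or returns the location features themselves.
def _example_residualfcnet_usage():
    model = ResidualFCNet(num_inputs=4, num_classes=10, num_filts=256, depth=4)
    x = torch.randn(8, 4)                      # batch of encoded locations
    probs = model(x)                           # (8, 10) per-class probabilities
    feats = model(x, return_feats=True)        # (8, 256) location features
    single = model(x, class_of_interest=3)     # tuple: scores for class 3 and for the last class
    return probs, feats, single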
class SimpleFCNet(ResidualFCNet):
def forward(self, x, return_feats=True):
assert return_feats
loc_emb = self.feats(x)
class_pred = self.class_emb(loc_emb)
return class_pred
class MockTransformer(nn.Module):
def __init__(self, num_classes, num_dims):
super(MockTransformer, self).__init__()
self.species_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_dims)
def forward(self, class_ids):
return self.species_emb(class_ids)
class CombinedModel(nn.Module):
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1):
super(CombinedModel, self).__init__()
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank)
if lowrank < num_filts and lowrank != 0:
self.transformer_model = MockTransformer(num_classes, lowrank)
else:
self.transformer_model = MockTransformer(num_classes, num_filts)
self.ema_factor = ema_factor
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=lowrank if (lowrank < num_filts and lowrank != 0) else num_filts)
self.ema_embeddings.weight.data.copy_(self.transformer_model.species_emb.weight.data) # Initialize EMA with the same values as transformer
# this will have to change when I start using the actual transformer
def forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None):
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
else:
if class_of_interest == None:
# Get class-specific embeddings based on class_ids
class_embeddings = self.transformer_model(class_ids)
if return_class_embeddings:
return class_embeddings
else:
# Update EMA embeddings for these class IDs
if self.training:
self.update_ema_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
def update_ema_embeddings(self, class_ids, current_embeddings):
if self.training:
# Get current EMA embeddings for the class IDs
ema_current = self.ema_embeddings(class_ids)
# Calculate new EMA values
ema_new = self.ema_factor * current_embeddings + (1 - self.ema_factor) * ema_current
# Update the EMA embeddings
self.ema_embeddings.weight.data[class_ids] = ema_new.detach() # Detach to prevent gradients from flowing here
def get_ema_embeddings(self, class_ids):
# Method to access EMA embeddings
return self.ema_embeddings(class_ids)
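# Minimal sketch (hypothetical sizes) of CombinedModel: location features from HeadlessSINR
# are matched against class embeddings from the mock transformer, and EMA copies of those
# embeddings are maintained during training for later single-class evaluation.
def _example_combinedmodel_usage():
    model = CombinedModel(num_inputs=4, num_filts=128, num_classes=10)
    x = torch.randn(8, 4)
    class_ids = torch.randint(0, 10, (8,))
    probs = model(x, class_ids=class_ids)       # (8, 8) location-vs-class probabilities
    single = model(x, class_of_interest=2)      # (8,) probabilities from the EMA embedding
    return probs, single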
class HeadlessSINR(nn.Module):
def __init__(self, num_inputs, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
super(HeadlessSINR, self).__init__()
self.inc_bias = False
self.low_rank_feats = None
if lowrank < num_filts and lowrank != 0:
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
self.low_rank_feats = l1
# else:
# self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
if nonlin == 'relu':
activation = nn.ReLU
elif nonlin == 'silu':
activation = nn.SiLU
else:
raise NotImplementedError('Invalid nonlinearity specified.')
# Create the layers list for feature extraction
layers = []
if depth != -1:
layers.append(nn.Linear(num_inputs, num_filts))
layers.append(activation())
for i in range(depth):
layers.append(ResLayer(num_filts, activation=activation, p=dropout_p))
else:
layers.append(nn.Identity())
# Include low-rank features in the sequential model if it is defined
if self.low_rank_feats:
# Apply initial layers then low-rank features
layers.append(self.low_rank_feats)
# Set up the features as a sequential model
self.feats = nn.Sequential(*layers)
def forward(self, x):
loc_emb = self.feats(x)
return loc_emb
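# Sketch (hypothetical sizes): HeadlessSINR is the ResidualFCNet backbone without the
# classification head; with a non-zero lowrank it also projects features down to `lowrank` dims.
def _example_headless_sinr_usage():
    backbone = HeadlessSINR(num_inputs=4, num_filts=256, depth=2, lowrank=64)
    feats = backbone(torch.randn(8, 4))    # (8, 64) low-rank location features
    return feats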
class TransformerEncoderModel(nn.Module):
def __init__(self, d_model=256, nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, activation='relu',
batch_first=True, output_dim=256): # BATCH FIRST MIGHT HAVE TO CHANGE
super(TransformerEncoderModel, self).__init__()
self.input_layer_norm = nn.LayerNorm(normalized_shape=d_model)
# Create an encoder layer
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
batch_first=batch_first
)
# Stack the encoder layers into an encoder module
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers
)
# Example output layer (modify according to your needs)
self.output_layer = nn.Linear(d_model, output_dim)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
Args:
src: the sequence to the encoder (shape: [seq_length, batch_size, d_model])
src_mask: the mask for the src sequence (shape: [seq_length, seq_length])
src_key_padding_mask: the mask for the padding tokens (shape: [batch_size, seq_length])
Returns:
output of the transformer encoder
"""
# Pass the input through the transformer encoder
encoder_input = self.input_layer_norm(src)
encoder_output = self.transformer_encoder(encoder_input, src_key_padding_mask=src_key_padding_mask, mask=src_mask)
# # Pass the encoder output through the output layer
# output = self.output_layer(encoder_output)
# Assuming the class token is the first in the sequence
# batch_first so we have (batch, sequence, dim)
if encoder_output.ndim == 2:
# in situations where we don't have a batch
encoder_output = encoder_output.unsqueeze(0)
class_token_embedding = encoder_output[:, 0, :]
output = self.output_layer(class_token_embedding) # Process only the class token embedding
return output
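# Sketch (hypothetical shapes): the encoder takes a batch-first token sequence plus a key
# padding mask, and returns a projection of the first (class) token only.
def _example_transformer_encoder_usage():
    enc = TransformerEncoderModel(d_model=256, nhead=8, num_encoder_layers=2, output_dim=128)
    tokens = torch.randn(8, 6, 256)                       # (batch, sequence, d_model)
    padding = torch.zeros(8, 6, dtype=torch.bool)         # True marks padded positions
    cls = enc(tokens, src_key_padding_mask=padding)       # (8, 128) class-token embedding
    return cls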
class MultiInputModel(nn.Module):
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
text_inputs=False, class_token_transformation='identity'):
super(MultiInputModel, self).__init__()
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
self.ema_factor = ema_factor
self.class_token_transformation = class_token_transformation
# Load pretrained state_dict if use_pretrained_sinr is set to True
if use_pretrained_sinr:
#pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
pretrained_state_dict = torch.load(pretrained_loc, map_location=torch.device('cpu'))['state_dict']
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
self.headless_model.load_state_dict(filtered_state_dict, strict=False)
#print(f'Using pretrained sinr from {pretrained_loc}')
# Freeze the SINR model if freeze_sinr is set to True
if freeze_sinr:
for param in self.headless_model.parameters():
param.requires_grad = False
print("Freezing SINR model parameters")
# self.transformer_model = MockTransformer(num_classes, num_filts)
self.transformer_model = TransformerEncoderModel(d_model=token_dim,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=batch_first,
output_dim=num_filts)
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
# this is just a workaround for now to load eval embeddings - probably not needed long term
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
self.ema_embeddings.weight.requires_grad = False
self.eval_embeddings.weight.requires_grad = False
self.num_filts=num_filts
self.token_dim = token_dim
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
self.sinr_inputs = sinr_inputs
if self.sinr_inputs:
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
raise ValueError("If using sinr inputs to transformer with identity class token transformation"
"then token_dim of transformer must be equal to num_filts of sinr model")
# Add a class token
self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_token)
if register:
# Add a register token initialized with Xavier uniform initialization
self.register = nn.Parameter(torch.empty(1, self.token_dim))
# self.register = (self.register / 2)
nn.init.xavier_uniform_(self.register)
else:
self.register = None
self.text_inputs = text_inputs
if self.text_inputs:
#print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW")
self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.text_model=None
# Type-specific embeddings for class, register, location, and text tokens
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_type_embedding)
if register:
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.register_type_embedding)
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.location_type_embedding)
if text_inputs:
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.text_type_embedding)
# Instantiate the class token transformation module
if class_token_transformation == 'identity':
self.class_token_transform = Identity(token_dim, num_filts)
elif class_token_transformation == 'linear':
self.class_token_transform = LinearTransformation(token_dim, num_filts)
elif class_token_transformation == 'single_layer_nn':
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'two_layer_nn':
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'sinr':
self.class_token_transform = HeadlessSINR(token_dim, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
else:
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None):
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
if context_sequence.dim() == 2:
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
        # Drop the first context token (the mask is not sliced; its leading entry is reused for the tokens prepended below).
        context_sequence = context_sequence[:, 1:, :]
if self.sinr_inputs:
# Pass through the headless model
context_sequence = self.headless_model(context_sequence)
# Add type-specific embedding to each location token
# print("SEE IF THIS WORKS")
context_sequence += self.location_type_embedding
batch_size = context_sequence.size(0)
# Expand the class token to match the batch size and add its type-specific embedding
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
if self.text_inputs and (text_emb is not None):
text_mask = (text_emb.sum(dim=1) == 0)
text_emb = self.text_model(text_emb)
text_emb += self.text_type_embedding
text_emb[text_mask] = 0
# Reshape text_emb to have the shape (batch_size, 1, embedding_dim)
text_emb = text_emb.unsqueeze(1)
if self.register is None:
# context sequence = learnable class_token + rest of sequence
if self.text_inputs:
# Add the class token and text embeddings to the context sequence
context_sequence = torch.cat((class_token_expanded, text_emb, context_sequence), dim=1)
# Pad the context mask to account for the added text embeddings
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
# Update the new part of the mask with the text_mask
context_mask[:, 1] = text_mask # Apply mask directly
else:
context_sequence = torch.cat((class_token_expanded, context_sequence), dim=1)
else:
# Expand the register token to match the batch size and add its type-specific embedding
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
if self.text_inputs:
# Add all components: class token, register, text embeddings, and context
context_sequence = torch.cat((class_token_expanded, register_expanded, text_emb, context_sequence),
dim=1)
# Double pad the context mask: first for register, then for text embeddings
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
# Update the new part of the mask for text embeddings
context_mask[:, register_expanded.size(1) + 1] = text_mask # Apply mask directly
else:
context_sequence = torch.cat((class_token_expanded, register_expanded, context_sequence), dim=1)
# Update the context mask to account for the register token
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
        if not use_eval_embeddings:
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
# pass these through the class token transformation
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
if return_class_embeddings:
return class_embeddings
else:
# Update EMA embeddings for these class IDs
with torch.no_grad():
if self.training:
self.update_ema_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
self.eval()
if not hasattr(self, 'eval_embeddings'):
self.eval_embeddings = self.ema_embeddings
if class_of_interest == None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
class_embeddings = self.class_token_transform(class_token_output)
# Update EMA embeddings for these class IDs
self.generate_eval_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
print(f'using eval embedding for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
def init_eval_embeddings(self, num_classes):
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
nn.init.xavier_uniform_(self.eval_embeddings.weight)
def get_ema_embeddings(self, class_ids):
# Method to access EMA embeddings
return self.ema_embeddings(class_ids)
def get_eval_embeddings(self, class_ids):
# Method to access eval embeddings
return self.eval_embeddings(class_ids)
def update_ema_embeddings(self, class_ids, current_embeddings):
if self.training:
# Get unique class IDs and their counts
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
# Get current EMA embeddings for unique class IDs
ema_current = self.ema_embeddings(unique_class_ids)
# Initialize a placeholder for new EMA values
ema_new = torch.zeros_like(ema_current)
# Compute the average of current embeddings for each unique class ID
current_sum = torch.zeros_like(ema_current)
current_sum.index_add_(0, inverse_indices, current_embeddings)
current_avg = current_sum / counts.unsqueeze(1)
# Apply EMA update formula
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
# Update the EMA embeddings for unique class IDs
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
def generate_eval_embeddings(self, class_id, current_embedding):
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
# forward method that uses ema or eval embeddings rather than context sequence
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
else:
if class_of_interest == None:
# Get class-specific embeddings based on class_ids
if eval == False:
class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
else:
class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
if return_class_embeddings:
return class_embeddings
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
if eval == False:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
device = self.eval_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
#print(f'using eval estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
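# Sketch (hypothetical shapes/hyperparameters) of a MultiInputModel forward pass. With
# sinr_inputs=True and the identity class token transformation, token_dim must equal num_filts.
def _example_multiinputmodel_usage():
    model = MultiInputModel(num_inputs=4, num_filts=256, num_classes=10, depth=2,
                            token_dim=256, sinr_inputs=True,
                            class_token_transformation='identity')
    x = torch.randn(8, 4)                          # query locations
    context = torch.randn(8, 5, 4)                 # context locations per example
    mask = torch.zeros(8, 5, dtype=torch.bool)     # True marks padded context positions
    class_ids = torch.randint(0, 10, (8,))
    probs = model(x, context, mask, class_ids=class_ids)   # (8, 8) query-vs-class probabilities
    return probs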
class VariableInputModel(nn.Module):
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'):
super(VariableInputModel, self).__init__()
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
self.ema_factor = ema_factor
self.class_token_transformation = class_token_transformation
# Load pretrained state_dict if use_pretrained_sinr is set to True
if use_pretrained_sinr:
pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
self.headless_model.load_state_dict(filtered_state_dict, strict=False)
#print(f'Using pretrained sinr from {pretrained_loc}')
# Freeze the SINR model if freeze_sinr is set to True
if freeze_sinr:
for param in self.headless_model.parameters():
param.requires_grad = False
print("Freezing SINR model parameters")
# self.transformer_model = MockTransformer(num_classes, num_filts)
self.transformer_model = TransformerEncoderModel(d_model=token_dim,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=batch_first,
output_dim=num_filts)
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
# this is just a workaround for now to load eval embeddings - probably not needed long term
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
self.ema_embeddings.weight.requires_grad = False
self.eval_embeddings.weight.requires_grad = False
self.num_filts=num_filts
self.token_dim = token_dim
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
self.sinr_inputs = sinr_inputs
if self.sinr_inputs:
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
raise ValueError("If using sinr inputs to transformer with identity class token transformation"
"then token_dim of transformer must be equal to num_filts of sinr model")
# Add a class token
self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_token)
if register:
# Add a register token initialized with Xavier uniform initialization
self.register = nn.Parameter(torch.empty(1, self.token_dim))
# self.register = (self.register / 2)
nn.init.xavier_uniform_(self.register)
else:
self.register = None
self.text_inputs = text_inputs
if self.text_inputs:
print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW")
self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.text_model=None
self.image_inputs = image_inputs
if self.image_inputs:
print("JUST USING A HEADLESS SINR FOR THE IMAGE MODEL RIGHT NOW")
self.image_model=HeadlessSINR(num_inputs=1024, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.image_model=None
self.env_inputs = env_inputs
if self.env_inputs:
print("JUST USING A HEADLESS SINR FOR THE ENV MODEL RIGHT NOW")
self.env_model=HeadlessSINR(num_inputs=20, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.env_model=None
# Type-specific embeddings for class, register, location, text, image and env tokens
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_type_embedding)
if register:
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.register_type_embedding)
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.location_type_embedding)
if text_inputs:
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.text_type_embedding)
if image_inputs:
self.image_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.image_type_embedding)
if env_inputs:
self.env_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.env_type_embedding)
# Instantiate the class token transformation module
if class_token_transformation == 'identity':
self.class_token_transform = Identity(token_dim, num_filts)
elif class_token_transformation == 'linear':
self.class_token_transform = LinearTransformation(token_dim, num_filts)
elif class_token_transformation == 'single_layer_nn':
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'two_layer_nn':
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'sinr':
self.class_token_transform = HeadlessSINR(token_dim, num_filts, 2, nonlin, lowrank, dropout_p=dropout)
else:
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False,
return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None,
image_emb=None, env_emb=None):
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
if context_sequence.dim() == 2:
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
context_sequence = context_sequence[:, 1:, :]
context_mask = context_mask[:, 1:]
if self.sinr_inputs:
context_sequence = self.headless_model(context_sequence)
# Add type-specific embedding to each location token
context_sequence += self.location_type_embedding
batch_size = context_sequence.size(0)
# Initialize lists for tokens and masks
tokens = []
masks = []
# Process class token
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
tokens.append(class_token_expanded)
# The class token is always present, so mask is False (i.e., not masked out)
class_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
masks.append(class_mask)
# Process register token if present
if self.register is not None:
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
tokens.append(register_expanded)
register_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
masks.append(register_mask)
# Process text embeddings
if self.text_inputs and (text_emb is not None):
text_mask = (text_emb.sum(dim=1) == 0)
text_emb = self.text_model(text_emb)
text_emb += self.text_type_embedding
# Set embeddings to zero where mask is True
text_emb[text_mask] = 0
text_emb = text_emb.unsqueeze(1)
tokens.append(text_emb)
# Expand text_mask to match sequence dimensions
text_mask = text_mask.unsqueeze(1)
masks.append(text_mask)
# Process image embeddings
if self.image_inputs and (image_emb is not None):
image_mask = (image_emb.sum(dim=1) == 0)
image_emb = self.image_model(image_emb)
image_emb += self.image_type_embedding
image_emb[image_mask] = 0
image_emb = image_emb.unsqueeze(1)
tokens.append(image_emb)
image_mask = image_mask.unsqueeze(1)
masks.append(image_mask)
# Process env embeddings if needed (can be added similarly)
if self.env_inputs and (env_emb is not None):
env_mask = context_mask
env_emb = self.env_model(env_emb)
env_emb += self.env_type_embedding
env_emb[env_mask] = 0
env_emb = env_emb.unsqueeze(1)
tokens.append(env_emb)
env_mask = env_mask.unsqueeze(1)
masks.append(env_mask)
# Process location tokens
tokens.append(context_sequence)
masks.append(context_mask)
# Concatenate all tokens and masks
context_sequence = torch.cat(tokens, dim=1)
context_mask = torch.cat(masks, dim=1)
        if not use_eval_embeddings:
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
# pass these through the class token transformation
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
if return_class_embeddings:
return class_embeddings
else:
# Update EMA embeddings for these class IDs
with torch.no_grad():
if self.training:
self.update_ema_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
self.eval()
if not hasattr(self, 'eval_embeddings'):
print('No Eval Embeddings for this species?!')
self.eval_embeddings = self.ema_embeddings
if class_of_interest == None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
class_embeddings = self.class_token_transform(class_token_output)
# Update EMA embeddings for these class IDs
self.generate_eval_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
print(f'using eval embedding for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
def get_loc_emb(self, x):
feature_embeddings = self.headless_model(x)
return feature_embeddings
def init_eval_embeddings(self, num_classes):
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
nn.init.xavier_uniform_(self.eval_embeddings.weight)
def get_ema_embeddings(self, class_ids):
# Method to access EMA embeddings
return self.ema_embeddings(class_ids)
def get_eval_embeddings(self, class_ids):
# Method to access eval embeddings
return self.eval_embeddings(class_ids)
def update_ema_embeddings(self, class_ids, current_embeddings):
if self.training:
# Get unique class IDs and their counts
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
# Get current EMA embeddings for unique class IDs
ema_current = self.ema_embeddings(unique_class_ids)
# Initialize a placeholder for new EMA values
ema_new = torch.zeros_like(ema_current)
# Compute the average of current embeddings for each unique class ID
current_sum = torch.zeros_like(ema_current)
current_sum.index_add_(0, inverse_indices, current_embeddings)
current_avg = current_sum / counts.unsqueeze(1)
# Apply EMA update formula
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
# Update the EMA embeddings for unique class IDs
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
def generate_eval_embeddings(self, class_id, current_embedding):
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
# forward method that uses ema or eval embeddings rather than context sequence
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
else:
if class_of_interest == None:
# Get class-specific embeddings based on class_ids
if eval == False:
class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
else:
class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
if return_class_embeddings:
return class_embeddings
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
if eval == False:
device = self.ema_embeddings.weight.device
class_of_interest_tensor =torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
device = self.eval_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
#print(f'using eval estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
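# Sketch (hypothetical shapes): VariableInputModel accepts the same location context as
# MultiInputModel plus optional text/image/env tokens; embedding_forward instead scores
# locations directly against the stored EMA (or eval) class embeddings.
def _example_variableinputmodel_usage():
    model = VariableInputModel(num_inputs=4, num_filts=256, num_classes=10, depth=2,
                               token_dim=256, sinr_inputs=True,
                               class_token_transformation='identity')
    x = torch.randn(8, 4)
    context = torch.randn(8, 5, 4)
    mask = torch.zeros(8, 5, dtype=torch.bool)
    class_ids = torch.randint(0, 10, (8,))
    probs = model(x, context, mask, class_ids=class_ids)         # (8, 8) via the transformer
    ema_probs = model.embedding_forward(x, class_ids=class_ids)  # (8, 8) via EMA embeddings
    return probs, ema_probs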
class LinNet(nn.Module):
def __init__(self, num_inputs, num_classes):
super(LinNet, self).__init__()
self.num_layers = 0
self.inc_bias = False
self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
self.feats = nn.Identity() # does not do anything
def forward(self, x, class_of_interest=None, return_feats=False):
loc_emb = self.feats(x)
if return_feats:
return loc_emb
if class_of_interest is None:
class_pred = self.class_emb(loc_emb)
else:
class_pred = self.eval_single_class(loc_emb, class_of_interest)
return torch.sigmoid(class_pred)
def eval_single_class(self, x, class_of_interest):
if self.inc_bias:
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
else:
return x @ self.class_emb.weight[class_of_interest, :]
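# Sketch (hypothetical sizes): LinNet is the linear baseline; it applies a single linear
# layer directly to the encoded inputs and returns sigmoid scores.
def _example_linnet_usage():
    model = LinNet(num_inputs=4, num_classes=10)
    probs = model(torch.randn(8, 4))    # (8, 10)
    return probs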
class ParallelMulti(torch.nn.Module):
def __init__(self, x: list[torch.nn.Module]):
super(ParallelMulti, self).__init__()
self.layers = nn.ModuleList(x)
def forward(self, xs, **kwargs):
out = torch.cat([self.layers[i](x, **kwargs) for i,x in enumerate(xs)], dim=1)
return out
class SequentialMulti(torch.nn.Sequential):
def forward(self, *inputs, **kwargs):
for module in self._modules.values():
            if isinstance(inputs, tuple):
inputs = module(*inputs, **kwargs)
else:
inputs = module(inputs)
return inputs
# Chris's transformation classes
class Identity(nn.Module):
def __init__(self, in_dim, out_dim):
super(Identity, self).__init__()
# No parameters needed for identity transformation
def forward(self, x):
return x
class LinearTransformation(nn.Module):
def __init__(self, in_dim, out_dim, bias=True):
super(LinearTransformation, self).__init__()
self.linear = nn.Linear(in_dim, out_dim, bias=bias)
def forward(self, x):
return self.linear(x)
class SingleLayerNN(nn.Module):
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
super(SingleLayerNN, self).__init__()
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim, bias=bias),
nn.ReLU(),
nn.Dropout(p=dropout_p),
nn.Linear(hidden_dim, out_dim, bias=bias)
)
def forward(self, x):
return self.net(x)
class TwoLayerNN(nn.Module):
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
super(TwoLayerNN, self).__init__()
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim, bias=bias),
nn.ReLU(),
nn.Dropout(p=dropout_p),
nn.Linear(hidden_dim, hidden_dim, bias=bias),
nn.ReLU(),
nn.Dropout(p=dropout_p),
nn.Linear(hidden_dim, out_dim, bias=bias)
)
def forward(self, x):
return self.net(x)
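# Sketch (hypothetical sizes): the transformation classes above all map the transformer's
# class-token embedding (token_dim) to the SINR feature dimension (num_filts).
def _example_class_token_transforms():
    token = torch.randn(8, 256)
    out_linear = LinearTransformation(256, 128)(token)   # (8, 128)
    out_mlp = TwoLayerNN(256, 128)(token)                # (8, 128)
    return out_linear, out_mlp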
class HyperNet(nn.Module):
    '''
    Hypernetwork model: a species encoder ('embed', 'taxa', 'text', or 'wiki') produces per-species
    classifier weights (plus a bias term) that are applied to location features from a position encoder.
    '''
def __init__(self, params, num_inputs, num_classes, num_filts, pos_enc_depth, species_dim, species_enc_depth, species_filts, species_enc='embed', inference_only=False):
super(HyperNet, self).__init__()
if species_enc == 'embed':
self.species_emb = nn.Embedding(num_classes, species_dim)
self.species_emb.weight.data *= 0.01
elif species_enc == 'taxa':
self.species_emb = TaxaEncoder(params, './data/inat_taxa_info.csv', species_dim)
elif species_enc == 'text':
self.species_emb = TextEncoder(params, params['text_emb_path'], species_dim, './data/inat_taxa_info.csv')
elif species_enc == 'wiki':
self.species_emb = WikiEncoder(params, params['text_emb_path'], species_dim, inference_only=inference_only)
if species_enc_depth == -1:
self.species_enc = nn.Identity()
elif species_enc_depth == 0:
self.species_enc = nn.Linear(species_dim, num_filts+1)
else:
self.species_enc = SimpleFCNet(species_dim, num_filts+1, species_filts, depth=species_enc_depth)
if 'geoprior' in params['loss']:
self.species_params = nn.Parameter(torch.randn(num_classes, species_dim))
self.species_params.data *= 0.0386
self.pos_enc = SimpleFCNet(num_inputs, num_filts, num_filts, depth=pos_enc_depth)
def forward(self, x, y):
ys, indmap = torch.unique(y, return_inverse=True)
species = self.species_enc(self.species_emb(ys))
species_w, species_b = species[...,:-1], species[...,-1:]
pos = self.pos_enc(x)
out = torch.bmm(species_w[indmap],pos[...,None])
out = (out + 0*species_b[indmap]).squeeze(-1) #TODO
if hasattr(self, 'species_params'):
out2 = torch.bmm(self.species_params[ys][indmap],pos[...,None])
out2 = out2.squeeze(-1)
out3 = (species_w, self.species_params[ys], ys)
return out, out2, out3
else:
return out
def zero_shot(self, x, species_emb):
species = self.species_enc(self.species_emb.zero_shot(species_emb))
species_w, _ = species[...,:-1], species[...,-1:]
pos = self.pos_enc(x)
out = pos @ species_w.T
return out
class TaxaEncoder(nn.Module):
def __init__(self, params, fpath, embedding_dim):
super(TaxaEncoder, self).__init__()
import datasets
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
self.fpath = fpath
ids = []
rows = []
with open(fpath, newline='') as csvfile:
spamreader = csv.reader(csvfile, delimiter=',')
for row in spamreader:
if row[0] == 'taxon_id':
continue
ids.append(int(row[0]))
rows.append(row[3:])
        rows = np.array(rows)
        rows = [np.unique(rows[:, i], return_inverse=True)[1] for i in range(rows.shape[1])]
        rows = torch.from_numpy(np.vstack(rows).T)
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
embs = [nn.Embedding(rows[:,i].max()+2, embedding_dim, 0) for i in range(rows.shape[1])]
embs[-1] = nn.Embedding(len(class_to_taxa), embedding_dim)
rows2 = torch.zeros((len(class_to_taxa), 7), dtype=rows.dtype)
startind = rows[:,-1].max()
for i in range(len(class_to_taxa)):
if class_to_taxa[i] in ids:
rows2[i] = rows[ids.index(class_to_taxa[i])]+1
rows2[i,-1] -= 1
else:
rows2[i,-1] = startind
startind += 1
self.register_buffer('rows', rows2)
for e in embs:
e.weight.data *= 0.01
self.embs = nn.ModuleList(embs)
def forward(self, x):
inds = self.rows[x]
out = sum([self.embs[i](inds[...,i]) for i in range(inds.shape[-1])])
return out
class TextEncoder(nn.Module):
def __init__(self, params, path, embedding_dim, fpath='inat_taxa_info.csv'):
super(TextEncoder, self).__init__()
import datasets
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
self.fpath = fpath
ids = []
with open(fpath, newline='') as csvfile:
spamreader = csv.reader(csvfile, delimiter=',')
for row in spamreader:
if row[0] == 'taxon_id':
continue
ids.append(int(row[0]))
embs = torch.load(path)
if len(embs) != len(ids):
print("Warning: Number of embeddings doesn't match number of species")
ids = ids[:embs.shape[0]]
if isinstance(embs, list):
embs = torch.stack(embs)
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
embmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
self.missing_emb = nn.Embedding(len(class_to_taxa)-embs.shape[0], embedding_dim)
startind = 0
for i in range(len(class_to_taxa)):
if class_to_taxa[i] in ids:
indmap[i] = ids.index(class_to_taxa[i])
else:
embmap[i] = startind
startind += 1
self.scales = nn.Parameter(torch.zeros(len(class_to_taxa), 1))
self.register_buffer('indmap', indmap, persistent=False)
self.register_buffer('embmap', embmap, persistent=False)
self.register_buffer('embs', embs, persistent=False)
if params['text_hidden_dim'] == 0:
self.linear1 = nn.Linear(embs.shape[1], embedding_dim)
else:
self.linear1 = nn.Linear(embs.shape[1], params['text_hidden_dim'])
self.linear2 = nn.Linear(params['text_hidden_dim'], embedding_dim)
self.act = nn.SiLU()
if params['text_learn_dim'] > 0:
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
self.learned_emb.weight.data *= 0.01
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
def forward(self, x):
inds = self.indmap[x]
out = self.embs[self.indmap[x].cpu()]
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.linear2(self.act(out))
out = self.scales[x] * (out / (out.std(dim=1)[:, None]))
out[inds == -1] = self.missing_emb(self.embmap[x[inds == -1]])
if hasattr(self, 'learned_emb'):
out2 = self.learned_emb(x)
out2 = self.linear_learned(out2)
out = out+out2
return out
class WikiEncoder(nn.Module):
def __init__(self, params, path, embedding_dim, inference_only=False):
super(WikiEncoder, self).__init__()
self.path = path
if not inference_only:
import datasets
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
if params['zero_shot']:
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
D = D.item()
taxa_snt = D['taxa'].tolist()
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
taxa = list(set(taxa + taxa_snt))
mask = labels != taxa[0]
for i in range(1, len(taxa)):
mask &= (labels != taxa[i])
locs = locs[mask]
dates = dates[mask]
labels = labels[mask]
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
embs = torch.load(path)
ids = embs['taxon_id'].tolist()
if 'keys' in embs:
taxa_counts = torch.zeros(len(ids), dtype=torch.int32)
for i,k in embs['keys']:
taxa_counts[i] += 1
else:
taxa_counts = torch.ones(len(ids), dtype=torch.int32)
count_sum = torch.cumsum(taxa_counts, dim=0) - taxa_counts
embs = embs['data']
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
countmap = torch.zeros(len(class_to_taxa), dtype=torch.int)
self.species_emb = nn.Embedding(len(class_to_taxa), embedding_dim)
self.species_emb.weight.data *= 0.01
for i in range(len(class_to_taxa)):
if class_to_taxa[i] in ids:
i2 = ids.index(class_to_taxa[i])
indmap[i] = count_sum[i2]
countmap[i] = taxa_counts[i2]
self.register_buffer('indmap', indmap, persistent=False)
self.register_buffer('countmap', countmap, persistent=False)
self.register_buffer('embs', embs, persistent=False)
assert embs.shape[1] == 4096
self.scale = nn.Parameter(torch.zeros(1))
if params['species_dropout'] > 0:
self.dropout = nn.Dropout(p=params['species_dropout'])
if params['text_hidden_dim'] == 0:
self.linear1 = nn.Linear(4096, embedding_dim)
else:
self.linear1 = nn.Linear(4096, params['text_hidden_dim'])
if params['text_batchnorm']:
self.bn1 = nn.BatchNorm1d(params['text_hidden_dim'])
for l in range(params['text_num_layers']-1):
setattr(self, f'linear{l+2}', nn.Linear(params['text_hidden_dim'], params['text_hidden_dim']))
if params['text_batchnorm']:
setattr(self, f'bn{l+2}', nn.BatchNorm1d(params['text_hidden_dim']))
setattr(self, f'linear{params["text_num_layers"]+1}', nn.Linear(params['text_hidden_dim'], embedding_dim))
self.act = nn.SiLU()
if params['text_learn_dim'] > 0:
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
self.learned_emb.weight.data *= 0.01
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
def forward(self, x):
inds = self.indmap[x] + (torch.rand(x.shape,device=x.device)*self.countmap[x]).floor().int()
out = self.embs[inds]
if hasattr(self, 'dropout'):
out = self.dropout(out)
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.act(out)
if hasattr(self, 'bn1'):
out = self.bn1(out)
i = 2
while hasattr(self, f'linear{i}'):
if hasattr(self, f'linear{i}'):
out = self.act(getattr(self, f'linear{i}')(out))
if hasattr(self, f'bn{i}'):
out = getattr(self, f'bn{i}')(out)
i += 1
#out = self.scale * (out / (out.std(dim=1)[:, None]))
out2 = self.species_emb(x)
chosen = torch.rand((out.shape[0],), device=x.device)
chosen = 1+0*chosen #TODO fix this
chosen[inds == -1] = 0
out = chosen[:,None] * out + (1-chosen[:,None])*out2
if hasattr(self, 'learned_emb'):
out2 = self.learned_emb(x)
out2 = self.linear_learned(out2)
out = out+out2
return out
def zero_shot(self, species_emb):
out = species_emb
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.act(out)
if hasattr(self, 'bn1'):
out = self.bn1(out)
i = 2
while hasattr(self, f'linear{i}'):
if hasattr(self, f'linear{i}'):
out = self.act(getattr(self, f'linear{i}')(out))
if hasattr(self, f'bn{i}'):
out = getattr(self, f'bn{i}')(out)
i += 1
return out
def zero_shot_old(self, species_emb):
out = species_emb
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.linear2(self.act(out))
out = self.scale * (out / (out.std(dim=-1, keepdim=True)))
return out
# Note: the encoder below is only intended for the multi-input models and is not currently used;
# a HeadlessSINR is used as the text encoder for now.
class MultiInputTextEncoder(nn.Module):
def __init__(self, token_dim, dropout, input_dim=4096, depth=2, hidden_dim=512, nonlin='relu', batch_norm=True, layer_norm=False):
super(MultiInputTextEncoder, self).__init__()
print("THINK ABOUT IF SOME OF THESE HYPERPARAMETERS SHOULD BE DISTINCT FROM THE TRANSFORMER VERSION")
print("DEPTH / NUM_ENCODER_LAYERS, DROPOUT, DIM_FEEDFORWARD, ETC")
print("AT PRESENT WE JUST HAVE A SORT OF BASIC VERSION IMPLEMENTED THAT ATTEMPTS TO BE LIKE MAX'S VERSION")
print("ALSO, OPTION TO HAVE IT PRETRAINED? ADD RESIDUAL LAYERS?")
self.token_dim=token_dim
self.dropout=dropout
self.input_dim=input_dim
self.depth=depth
self.hidden_dim=hidden_dim
self.batch_norm = batch_norm
self.layer_norm = layer_norm
if nonlin == 'relu':
activation = nn.ReLU
elif nonlin == 'silu':
activation = nn.SiLU
else:
raise NotImplementedError('Invalid nonlinearity specified.')
self.dropout_layer = nn.Dropout(p=self.dropout)
if self.depth <= 1:
self.linear1 = nn.Linear(self.input_dim, self.token_dim)
else:
self.linear1 = nn.Linear(self.input_dim, self.hidden_dim)
if self.batch_norm:
self.bn1 = nn.BatchNorm1d(self.hidden_dim)
# if self.layer_norm:
# self.ln1 = nn.LayerNorm(self.hidden_dim)