import torch
import torch.utils.data
import torch.nn as nn
import math
import csv
import numpy as np
import json
import os
def get_model(params, inference_only=False): | |
if params['model'] == 'ResidualFCNet': | |
return ResidualFCNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'] + (20 if 'env' in params['loss'] else 0), params['num_filts'], params['depth']) | |
elif params['model'] == 'LinNet': | |
return LinNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes']) | |
elif params['model'] == 'HyperNet': | |
return HyperNet(params, params['input_dim'] + (20 if 'env' in params['input_enc'] else 0), params['num_classes'], params['num_filts'], params['depth'], | |
params['species_dim'], params['species_enc_depth'], params['species_filts'], params['species_enc'], inference_only=inference_only) | |
    # Chris's models
elif params['model'] == 'MultiInputModel': | |
return MultiInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), | |
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0), | |
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'], | |
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'], | |
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)), | |
sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False, | |
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'], | |
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'], | |
text_inputs=params['use_text_inputs'], class_token_transformation=params['class_token_transformation']) | |
elif params['model'] == 'VariableInputModel': | |
return VariableInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), | |
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0), | |
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'], | |
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'], | |
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)), | |
sinr_inputs=True if 'sinr' in params['transformer_input_enc'] else False, | |
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'], | |
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'], | |
text_inputs=params['use_text_inputs'], image_inputs=params['use_image_inputs'], | |
env_inputs=params['use_env_inputs'], | |
class_token_transformation=params['class_token_transformation']) | |
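# Illustrative usage sketch (not part of the original pipeline): a minimal params dict for the
# 'ResidualFCNet' branch of get_model above. The keys are the ones read in that branch; the
# values are assumed placeholders rather than defaults from any real config file.
def _example_get_model():
    params = {
        'model': 'ResidualFCNet',
        'input_dim': 4,           # e.g. a sin/cos lon-lat encoding (assumed)
        'input_time_dim': 0,
        'input_enc': 'sin_cos',   # no 'env' substring, so no +20 env dims are added
        'noise_time': False,
        'num_classes': 100,
        'loss': 'an_full',        # no 'env' substring, so no +20 output dims are added
        'num_filts': 256,
        'depth': 4,
    }
    return get_model(params)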
class ResLayer(nn.Module): | |
def __init__(self, linear_size, activation=nn.ReLU, p=0.5): | |
super(ResLayer, self).__init__() | |
self.l_size = linear_size | |
self.nonlin1 = activation() | |
self.nonlin2 = activation() | |
self.dropout1 = nn.Dropout(p=p) | |
self.w1 = nn.Linear(self.l_size, self.l_size) | |
self.w2 = nn.Linear(self.l_size, self.l_size) | |
def forward(self, x): | |
y = self.w1(x) | |
y = self.nonlin1(y) | |
y = self.dropout1(y) | |
y = self.w2(y) | |
y = self.nonlin2(y) | |
out = x + y | |
return out | |
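# Minimal sketch of ResLayer usage: the residual block preserves the feature width, so input and
# output shapes match. The batch size and width below are arbitrary assumptions.
def _example_reslayer():
    layer = ResLayer(256)
    x = torch.randn(8, 256)
    y = layer(x)  # y = x + dropout-regularised two-layer MLP(x)
    assert y.shape == x.shape
    return y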
class ResidualFCNet(nn.Module): | |
def __init__(self, num_inputs, num_classes, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5): | |
super(ResidualFCNet, self).__init__() | |
self.inc_bias = False | |
if lowrank < num_filts and lowrank != 0: | |
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias) | |
l2 = nn.Linear(lowrank, num_classes, bias=self.inc_bias) | |
self.class_emb = nn.Sequential(l1, l2) | |
else: | |
self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias) | |
if nonlin == 'relu': | |
activation = nn.ReLU | |
elif nonlin == 'silu': | |
activation = nn.SiLU | |
else: | |
raise NotImplementedError('Invalid nonlinearity specified.') | |
layers = [] | |
if depth != -1: | |
layers.append(nn.Linear(num_inputs, num_filts)) | |
layers.append(activation()) | |
for i in range(depth): | |
layers.append(ResLayer(num_filts, activation=activation)) | |
else: | |
layers.append(nn.Identity()) | |
self.feats = torch.nn.Sequential(*layers) | |
    def forward(self, x, class_of_interest=None, return_feats=False):
        loc_emb = self.feats(x)
        if return_feats:
            return loc_emb
        if class_of_interest is None:
            class_pred = self.class_emb(loc_emb)
        else:
            # Return the requested class together with the final class (index -1).
            class_pred = self.eval_single_class(loc_emb, class_of_interest)
            last_pred = self.eval_single_class(loc_emb, -1)
            return torch.sigmoid(class_pred), torch.sigmoid(last_pred)
        return torch.sigmoid(class_pred)
def eval_single_class(self, x, class_of_interest): | |
if self.inc_bias: | |
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest] | |
else: | |
return x @ self.class_emb.weight[class_of_interest, :] | |
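# Minimal sketch of ResidualFCNet usage: full per-class probabilities, raw location features, or a
# single class of interest (which also returns the final class, index -1). Sizes are assumptions.
def _example_residual_fcnet():
    net = ResidualFCNet(num_inputs=4, num_classes=10, num_filts=64, depth=2)
    x = torch.randn(8, 4)
    probs = net(x)                                # (8, 10) sigmoid probabilities
    feats = net(x, return_feats=True)             # (8, 64) location embeddings
    single, last = net(x, class_of_interest=3)    # class 3 and the final class (index -1)
    return probs, feats, single, last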
class SimpleFCNet(ResidualFCNet): | |
def forward(self, x, return_feats=True): | |
assert return_feats | |
loc_emb = self.feats(x) | |
class_pred = self.class_emb(loc_emb) | |
return class_pred | |
class MockTransformer(nn.Module): | |
def __init__(self, num_classes, num_dims): | |
super(MockTransformer, self).__init__() | |
self.species_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_dims) | |
def forward(self, class_ids): | |
return self.species_emb(class_ids) | |
class CombinedModel(nn.Module): | |
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1): | |
super(CombinedModel, self).__init__() | |
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank) | |
if lowrank < num_filts and lowrank != 0: | |
self.transformer_model = MockTransformer(num_classes, lowrank) | |
else: | |
self.transformer_model = MockTransformer(num_classes, num_filts) | |
self.ema_factor = ema_factor | |
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=lowrank if (lowrank < num_filts and lowrank != 0) else num_filts) | |
self.ema_embeddings.weight.data.copy_(self.transformer_model.species_emb.weight.data) # Initialize EMA with the same values as transformer | |
# this will have to change when I start using the actual transformer | |
def forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None): | |
# Process input through the headless model to get feature embeddings | |
feature_embeddings = self.headless_model(x) | |
if return_feats: | |
return feature_embeddings | |
else: | |
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
class_embeddings = self.transformer_model(class_ids) | |
if return_class_embeddings: | |
return class_embeddings | |
else: | |
# Update EMA embeddings for these class IDs | |
if self.training: | |
self.update_ema_embeddings(class_ids, class_embeddings) | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
device = self.ema_embeddings.weight.device | |
                class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor) | |
print(f'using EMA estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
def update_ema_embeddings(self, class_ids, current_embeddings): | |
if self.training: | |
# Get current EMA embeddings for the class IDs | |
ema_current = self.ema_embeddings(class_ids) | |
# Calculate new EMA values | |
ema_new = self.ema_factor * current_embeddings + (1 - self.ema_factor) * ema_current | |
# Update the EMA embeddings | |
self.ema_embeddings.weight.data[class_ids] = ema_new.detach() # Detach to prevent gradients from flowing here | |
def get_ema_embeddings(self, class_ids): | |
# Method to access EMA embeddings | |
return self.ema_embeddings(class_ids) | |
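# Minimal sketch of CombinedModel usage: location features from HeadlessSINR are matched against
# class embeddings from the mock transformer, while an EMA copy of those embeddings is kept for
# single-class inference. The sizes below are illustrative assumptions.
def _example_combined_model():
    model = CombinedModel(num_inputs=4, num_filts=32, num_classes=10, depth=1)
    x = torch.randn(8, 4)
    class_ids = torch.arange(8) % 10
    probs = model(x, class_ids=class_ids)    # (8, 8): locations vs. the batch's class embeddings
    single = model(x, class_of_interest=3)   # (8,): uses the EMA embedding for class 3
    return probs, single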
class HeadlessSINR(nn.Module): | |
def __init__(self, num_inputs, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5): | |
super(HeadlessSINR, self).__init__() | |
self.inc_bias = False | |
self.low_rank_feats = None | |
if lowrank < num_filts and lowrank != 0: | |
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias) | |
self.low_rank_feats = l1 | |
# else: | |
# self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias) | |
if nonlin == 'relu': | |
activation = nn.ReLU | |
elif nonlin == 'silu': | |
activation = nn.SiLU | |
else: | |
raise NotImplementedError('Invalid nonlinearity specified.') | |
# Create the layers list for feature extraction | |
layers = [] | |
if depth != -1: | |
layers.append(nn.Linear(num_inputs, num_filts)) | |
layers.append(activation()) | |
for i in range(depth): | |
layers.append(ResLayer(num_filts, activation=activation, p=dropout_p)) | |
else: | |
layers.append(nn.Identity()) | |
# Include low-rank features in the sequential model if it is defined | |
if self.low_rank_feats: | |
# Apply initial layers then low-rank features | |
layers.append(self.low_rank_feats) | |
# Set up the features as a sequential model | |
self.feats = nn.Sequential(*layers) | |
def forward(self, x): | |
loc_emb = self.feats(x) | |
return loc_emb | |
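# Minimal sketch of HeadlessSINR usage: the backbone returns location features that are num_filts
# wide, or lowrank wide when a low-rank projection is configured. Dimensions are assumptions.
def _example_headless_sinr():
    backbone = HeadlessSINR(num_inputs=4, num_filts=64, depth=2)
    feats = backbone(torch.randn(8, 4))    # (8, 64)
    lowrank = HeadlessSINR(num_inputs=4, num_filts=64, depth=2, lowrank=16)
    small = lowrank(torch.randn(8, 4))     # (8, 16)
    return feats, small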
class TransformerEncoderModel(nn.Module): | |
def __init__(self, d_model=256, nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, activation='relu', | |
batch_first=True, output_dim=256): # BATCH FIRST MIGHT HAVE TO CHANGE | |
super(TransformerEncoderModel, self).__init__() | |
self.input_layer_norm = nn.LayerNorm(normalized_shape=d_model) | |
# Create an encoder layer | |
encoder_layer = nn.TransformerEncoderLayer( | |
d_model=d_model, | |
nhead=nhead, | |
dim_feedforward=dim_feedforward, | |
dropout=dropout, | |
activation=activation, | |
batch_first=batch_first | |
) | |
# Stack the encoder layers into an encoder module | |
self.transformer_encoder = nn.TransformerEncoder( | |
encoder_layer=encoder_layer, | |
num_layers=num_encoder_layers | |
) | |
# Example output layer (modify according to your needs) | |
self.output_layer = nn.Linear(d_model, output_dim) | |
def forward(self, src, src_mask=None, src_key_padding_mask=None): | |
""" | |
Args: | |
src: the sequence to the encoder (shape: [seq_length, batch_size, d_model]) | |
src_mask: the mask for the src sequence (shape: [seq_length, seq_length]) | |
src_key_padding_mask: the mask for the padding tokens (shape: [batch_size, seq_length]) | |
Returns: | |
output of the transformer encoder | |
""" | |
# Pass the input through the transformer encoder | |
encoder_input = self.input_layer_norm(src) | |
encoder_output = self.transformer_encoder(encoder_input, src_key_padding_mask=src_key_padding_mask, mask=src_mask) | |
# # Pass the encoder output through the output layer | |
# output = self.output_layer(encoder_output) | |
# Assuming the class token is the first in the sequence | |
# batch_first so we have (batch, sequence, dim) | |
if encoder_output.ndim == 2: | |
# in situations where we don't have a batch | |
encoder_output = encoder_output.unsqueeze(0) | |
class_token_embedding = encoder_output[:, 0, :] | |
output = self.output_layer(class_token_embedding) # Process only the class token embedding | |
return output | |
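# Minimal sketch of TransformerEncoderModel usage: it consumes a (batch, seq, d_model) token
# sequence (batch_first=True) and returns a linear projection of the first (class) token.
# The sizes below are illustrative assumptions.
def _example_transformer_encoder():
    enc = TransformerEncoderModel(d_model=32, nhead=4, num_encoder_layers=1,
                                  dim_feedforward=64, output_dim=16)
    tokens = torch.randn(2, 5, 32)                   # 2 sequences of 5 tokens
    padding = torch.zeros(2, 5, dtype=torch.bool)    # False = token is attended to
    out = enc(tokens, src_key_padding_mask=padding)  # (2, 16)
    return out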
class MultiInputModel(nn.Module): | |
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1, | |
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256, | |
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='', | |
text_inputs=False, class_token_transformation='identity'): | |
super(MultiInputModel, self).__init__() | |
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout) | |
self.ema_factor = ema_factor | |
self.class_token_transformation = class_token_transformation | |
# Load pretrained state_dict if use_pretrained_sinr is set to True | |
if use_pretrained_sinr: | |
#pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict'] | |
pretrained_state_dict = torch.load(pretrained_loc, map_location=torch.device('cpu'))['state_dict'] | |
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')} | |
self.headless_model.load_state_dict(filtered_state_dict, strict=False) | |
#print(f'Using pretrained sinr from {pretrained_loc}') | |
# Freeze the SINR model if freeze_sinr is set to True | |
if freeze_sinr: | |
for param in self.headless_model.parameters(): | |
param.requires_grad = False | |
print("Freezing SINR model parameters") | |
# self.transformer_model = MockTransformer(num_classes, num_filts) | |
self.transformer_model = TransformerEncoderModel(d_model=token_dim, | |
nhead=nhead, | |
num_encoder_layers=num_encoder_layers, | |
dim_feedforward=dim_feedforward, | |
dropout=dropout, | |
batch_first=batch_first, | |
output_dim=num_filts) | |
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts) | |
# this is just a workaround for now to load eval embeddings - probably not needed long term | |
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts) | |
self.ema_embeddings.weight.requires_grad = False | |
self.eval_embeddings.weight.requires_grad = False | |
self.num_filts=num_filts | |
self.token_dim = token_dim | |
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think | |
self.sinr_inputs = sinr_inputs | |
if self.sinr_inputs: | |
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity': | |
                raise ValueError("If using sinr inputs to the transformer with the identity class token transformation, "
                                 "then token_dim of the transformer must equal num_filts of the sinr model")
# Add a class token | |
self.class_token = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.class_token) | |
if register: | |
# Add a register token initialized with Xavier uniform initialization | |
self.register = nn.Parameter(torch.empty(1, self.token_dim)) | |
# self.register = (self.register / 2) | |
nn.init.xavier_uniform_(self.register) | |
else: | |
self.register = None | |
self.text_inputs = text_inputs | |
if self.text_inputs: | |
#print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW") | |
self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout) | |
else: | |
self.text_model=None | |
# Type-specific embeddings for class, register, location, and text tokens | |
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.class_type_embedding) | |
if register: | |
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.register_type_embedding) | |
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.location_type_embedding) | |
if text_inputs: | |
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.text_type_embedding) | |
# Instantiate the class token transformation module | |
if class_token_transformation == 'identity': | |
self.class_token_transform = Identity(token_dim, num_filts) | |
elif class_token_transformation == 'linear': | |
self.class_token_transform = LinearTransformation(token_dim, num_filts) | |
elif class_token_transformation == 'single_layer_nn': | |
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout) | |
elif class_token_transformation == 'two_layer_nn': | |
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout) | |
elif class_token_transformation == 'sinr': | |
self.class_token_transform = HeadlessSINR(token_dim, num_filts, depth, nonlin, lowrank, dropout_p=dropout) | |
else: | |
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}") | |
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None): | |
# Process input through the headless model to get feature embeddings | |
feature_embeddings = self.headless_model(x) | |
if return_feats: | |
return feature_embeddings | |
if context_sequence.dim() == 2: | |
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing | |
context_sequence = context_sequence[:, 1:, :] | |
if self.sinr_inputs: | |
# Pass through the headless model | |
context_sequence = self.headless_model(context_sequence) | |
# Add type-specific embedding to each location token | |
context_sequence += self.location_type_embedding | |
batch_size = context_sequence.size(0) | |
# Expand the class token to match the batch size and add its type-specific embedding | |
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding | |
if self.text_inputs and (text_emb is not None): | |
text_mask = (text_emb.sum(dim=1) == 0) | |
text_emb = self.text_model(text_emb) | |
text_emb += self.text_type_embedding | |
text_emb[text_mask] = 0 | |
# Reshape text_emb to have the shape (batch_size, 1, embedding_dim) | |
text_emb = text_emb.unsqueeze(1) | |
if self.register is None: | |
# context sequence = learnable class_token + rest of sequence | |
if self.text_inputs: | |
# Add the class token and text embeddings to the context sequence | |
context_sequence = torch.cat((class_token_expanded, text_emb, context_sequence), dim=1) | |
# Pad the context mask to account for the added text embeddings | |
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False) | |
# Update the new part of the mask with the text_mask | |
context_mask[:, 1] = text_mask # Apply mask directly | |
else: | |
context_sequence = torch.cat((class_token_expanded, context_sequence), dim=1) | |
else: | |
# Expand the register token to match the batch size and add its type-specific embedding | |
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding | |
if self.text_inputs: | |
# Add all components: class token, register, text embeddings, and context | |
context_sequence = torch.cat((class_token_expanded, register_expanded, text_emb, context_sequence), | |
dim=1) | |
# Double pad the context mask: first for register, then for text embeddings | |
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False) | |
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False) | |
# Update the new part of the mask for text embeddings | |
context_mask[:, register_expanded.size(1) + 1] = text_mask # Apply mask directly | |
else: | |
context_sequence = torch.cat((class_token_expanded, register_expanded, context_sequence), dim=1) | |
# Update the context mask to account for the register token | |
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False) | |
        if not use_eval_embeddings:
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask) | |
# pass these through the class token transformation | |
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts) | |
if return_class_embeddings: | |
return class_embeddings | |
else: | |
# Update EMA embeddings for these class IDs | |
with torch.no_grad(): | |
if self.training: | |
self.update_ema_embeddings(class_ids, class_embeddings) | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
device = self.ema_embeddings.weight.device | |
                class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor) | |
print(f'using EMA estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
else: | |
self.eval() | |
if not hasattr(self, 'eval_embeddings'): | |
self.eval_embeddings = self.ema_embeddings | |
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask) | |
class_embeddings = self.class_token_transform(class_token_output) | |
# Update EMA embeddings for these class IDs | |
self.generate_eval_embeddings(class_ids, class_embeddings) | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
device = self.ema_embeddings.weight.device | |
                class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor) | |
print(f'using eval embedding for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
def init_eval_embeddings(self, num_classes): | |
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts) | |
nn.init.xavier_uniform_(self.eval_embeddings.weight) | |
def get_ema_embeddings(self, class_ids): | |
# Method to access EMA embeddings | |
return self.ema_embeddings(class_ids) | |
def get_eval_embeddings(self, class_ids): | |
# Method to access eval embeddings | |
return self.eval_embeddings(class_ids) | |
def update_ema_embeddings(self, class_ids, current_embeddings): | |
if self.training: | |
# Get unique class IDs and their counts | |
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True) | |
# Get current EMA embeddings for unique class IDs | |
ema_current = self.ema_embeddings(unique_class_ids) | |
# Initialize a placeholder for new EMA values | |
ema_new = torch.zeros_like(ema_current) | |
# Compute the average of current embeddings for each unique class ID | |
current_sum = torch.zeros_like(ema_current) | |
current_sum.index_add_(0, inverse_indices, current_embeddings) | |
current_avg = current_sum / counts.unsqueeze(1) | |
# Apply EMA update formula | |
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current | |
# Update the EMA embeddings for unique class IDs | |
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients | |
def generate_eval_embeddings(self, class_id, current_embedding): | |
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients | |
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients | |
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False): | |
# forward method that uses ema or eval embeddings rather than context sequence | |
# Process input through the headless model to get feature embeddings | |
feature_embeddings = self.headless_model(x) | |
if return_feats: | |
return feature_embeddings | |
else: | |
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
                if not eval:
class_embeddings = self.get_ema_embeddings(class_ids=class_ids) | |
else: | |
class_embeddings = self.get_eval_embeddings(class_ids=class_ids) | |
if return_class_embeddings: | |
return class_embeddings | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
                if not eval:
device = self.ema_embeddings.weight.device | |
                    class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor) | |
print(f'using EMA estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
else: | |
device = self.eval_embeddings.weight.device | |
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device) | |
class_embedding = self.get_eval_embeddings(class_of_interest_tensor) | |
#print(f'using eval estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
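# Minimal sketch of MultiInputModel usage with SINR-encoded context tokens. The geometry below
# (4-d raw inputs, 32-d features/tokens, a 3-token context) is an assumption for illustration;
# with sinr_inputs=True and the 'identity' transformation, token_dim must equal num_filts, and the
# raw context locations are first passed through the shared SINR backbone.
def _example_multi_input_model():
    model = MultiInputModel(num_inputs=4, num_filts=32, num_classes=10, depth=1,
                            nhead=4, num_encoder_layers=1, dim_feedforward=64,
                            dropout=0.0, token_dim=32, sinr_inputs=True)
    x = torch.randn(2, 4)                        # query locations (already encoded)
    context = torch.randn(2, 3, 4)               # per-class context locations (raw model inputs)
    mask = torch.zeros(2, 3, dtype=torch.bool)   # nothing padded
    class_ids = torch.tensor([0, 1])
    probs = model(x, context, mask, class_ids=class_ids)   # (2, 2) location-vs-class probabilities
    return probs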
class VariableInputModel(nn.Module): | |
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1, | |
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256, | |
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='', | |
text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'): | |
super(VariableInputModel, self).__init__() | |
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout) | |
self.ema_factor = ema_factor | |
self.class_token_transformation = class_token_transformation | |
# Load pretrained state_dict if use_pretrained_sinr is set to True | |
if use_pretrained_sinr: | |
pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict'] | |
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')} | |
self.headless_model.load_state_dict(filtered_state_dict, strict=False) | |
#print(f'Using pretrained sinr from {pretrained_loc}') | |
# Freeze the SINR model if freeze_sinr is set to True | |
if freeze_sinr: | |
for param in self.headless_model.parameters(): | |
param.requires_grad = False | |
print("Freezing SINR model parameters") | |
# self.transformer_model = MockTransformer(num_classes, num_filts) | |
self.transformer_model = TransformerEncoderModel(d_model=token_dim, | |
nhead=nhead, | |
num_encoder_layers=num_encoder_layers, | |
dim_feedforward=dim_feedforward, | |
dropout=dropout, | |
batch_first=batch_first, | |
output_dim=num_filts) | |
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts) | |
# this is just a workaround for now to load eval embeddings - probably not needed long term | |
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts) | |
self.ema_embeddings.weight.requires_grad = False | |
self.eval_embeddings.weight.requires_grad = False | |
self.num_filts=num_filts | |
self.token_dim = token_dim | |
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think | |
self.sinr_inputs = sinr_inputs | |
if self.sinr_inputs: | |
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity': | |
                raise ValueError("If using sinr inputs to the transformer with the identity class token transformation, "
                                 "then token_dim of the transformer must equal num_filts of the sinr model")
# Add a class token | |
self.class_token = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.class_token) | |
if register: | |
# Add a register token initialized with Xavier uniform initialization | |
self.register = nn.Parameter(torch.empty(1, self.token_dim)) | |
# self.register = (self.register / 2) | |
nn.init.xavier_uniform_(self.register) | |
else: | |
self.register = None | |
self.text_inputs = text_inputs | |
if self.text_inputs: | |
print("JUST USING A HEADLESS SINR FOR THE TEXT MODEL RIGHT NOW") | |
self.text_model=HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout) | |
else: | |
self.text_model=None | |
self.image_inputs = image_inputs | |
if self.image_inputs: | |
print("JUST USING A HEADLESS SINR FOR THE IMAGE MODEL RIGHT NOW") | |
self.image_model=HeadlessSINR(num_inputs=1024, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout) | |
else: | |
self.image_model=None | |
self.env_inputs = env_inputs | |
if self.env_inputs: | |
print("JUST USING A HEADLESS SINR FOR THE ENV MODEL RIGHT NOW") | |
self.env_model=HeadlessSINR(num_inputs=20, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout) | |
else: | |
self.env_model=None | |
# Type-specific embeddings for class, register, location, text, image and env tokens | |
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.class_type_embedding) | |
if register: | |
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.register_type_embedding) | |
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.location_type_embedding) | |
if text_inputs: | |
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.text_type_embedding) | |
if image_inputs: | |
self.image_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.image_type_embedding) | |
if env_inputs: | |
self.env_type_embedding = nn.Parameter(torch.empty(1, self.token_dim)) | |
nn.init.xavier_uniform_(self.env_type_embedding) | |
# Instantiate the class token transformation module | |
if class_token_transformation == 'identity': | |
self.class_token_transform = Identity(token_dim, num_filts) | |
elif class_token_transformation == 'linear': | |
self.class_token_transform = LinearTransformation(token_dim, num_filts) | |
elif class_token_transformation == 'single_layer_nn': | |
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout) | |
elif class_token_transformation == 'two_layer_nn': | |
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout) | |
elif class_token_transformation == 'sinr': | |
self.class_token_transform = HeadlessSINR(token_dim, num_filts, 2, nonlin, lowrank, dropout_p=dropout) | |
else: | |
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}") | |
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False, | |
return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None, | |
image_emb=None, env_emb=None): | |
# Process input through the headless model to get feature embeddings | |
feature_embeddings = self.headless_model(x) | |
if return_feats: | |
return feature_embeddings | |
if context_sequence.dim() == 2: | |
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing | |
context_sequence = context_sequence[:, 1:, :] | |
context_mask = context_mask[:, 1:] | |
if self.sinr_inputs: | |
context_sequence = self.headless_model(context_sequence) | |
# Add type-specific embedding to each location token | |
context_sequence += self.location_type_embedding | |
batch_size = context_sequence.size(0) | |
# Initialize lists for tokens and masks | |
tokens = [] | |
masks = [] | |
# Process class token | |
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding | |
tokens.append(class_token_expanded) | |
# The class token is always present, so mask is False (i.e., not masked out) | |
class_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device) | |
masks.append(class_mask) | |
# Process register token if present | |
if self.register is not None: | |
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding | |
tokens.append(register_expanded) | |
register_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device) | |
masks.append(register_mask) | |
# Process text embeddings | |
if self.text_inputs and (text_emb is not None): | |
text_mask = (text_emb.sum(dim=1) == 0) | |
text_emb = self.text_model(text_emb) | |
text_emb += self.text_type_embedding | |
# Set embeddings to zero where mask is True | |
text_emb[text_mask] = 0 | |
text_emb = text_emb.unsqueeze(1) | |
tokens.append(text_emb) | |
# Expand text_mask to match sequence dimensions | |
text_mask = text_mask.unsqueeze(1) | |
masks.append(text_mask) | |
# Process image embeddings | |
if self.image_inputs and (image_emb is not None): | |
image_mask = (image_emb.sum(dim=1) == 0) | |
image_emb = self.image_model(image_emb) | |
image_emb += self.image_type_embedding | |
image_emb[image_mask] = 0 | |
image_emb = image_emb.unsqueeze(1) | |
tokens.append(image_emb) | |
image_mask = image_mask.unsqueeze(1) | |
masks.append(image_mask) | |
# Process env embeddings if needed (can be added similarly) | |
if self.env_inputs and (env_emb is not None): | |
env_mask = context_mask | |
env_emb = self.env_model(env_emb) | |
env_emb += self.env_type_embedding | |
env_emb[env_mask] = 0 | |
env_emb = env_emb.unsqueeze(1) | |
tokens.append(env_emb) | |
env_mask = env_mask.unsqueeze(1) | |
masks.append(env_mask) | |
# Process location tokens | |
tokens.append(context_sequence) | |
masks.append(context_mask) | |
# Concatenate all tokens and masks | |
context_sequence = torch.cat(tokens, dim=1) | |
context_mask = torch.cat(masks, dim=1) | |
        if not use_eval_embeddings:
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask) | |
# pass these through the class token transformation | |
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts) | |
if return_class_embeddings: | |
return class_embeddings | |
else: | |
# Update EMA embeddings for these class IDs | |
with torch.no_grad(): | |
if self.training: | |
self.update_ema_embeddings(class_ids, class_embeddings) | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
device = self.ema_embeddings.weight.device | |
                class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor) | |
print(f'using EMA estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
else: | |
self.eval() | |
if not hasattr(self, 'eval_embeddings'): | |
print('No Eval Embeddings for this species?!') | |
self.eval_embeddings = self.ema_embeddings | |
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask) | |
class_embeddings = self.class_token_transform(class_token_output) | |
# Update EMA embeddings for these class IDs | |
self.generate_eval_embeddings(class_ids, class_embeddings) | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
device = self.ema_embeddings.weight.device | |
                class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor) | |
print(f'using eval embedding for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
def get_loc_emb(self, x): | |
feature_embeddings = self.headless_model(x) | |
return feature_embeddings | |
def init_eval_embeddings(self, num_classes): | |
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts) | |
nn.init.xavier_uniform_(self.eval_embeddings.weight) | |
def get_ema_embeddings(self, class_ids): | |
# Method to access EMA embeddings | |
return self.ema_embeddings(class_ids) | |
def get_eval_embeddings(self, class_ids): | |
# Method to access eval embeddings | |
return self.eval_embeddings(class_ids) | |
def update_ema_embeddings(self, class_ids, current_embeddings): | |
if self.training: | |
# Get unique class IDs and their counts | |
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True) | |
# Get current EMA embeddings for unique class IDs | |
ema_current = self.ema_embeddings(unique_class_ids) | |
# Initialize a placeholder for new EMA values | |
ema_new = torch.zeros_like(ema_current) | |
# Compute the average of current embeddings for each unique class ID | |
current_sum = torch.zeros_like(ema_current) | |
current_sum.index_add_(0, inverse_indices, current_embeddings) | |
current_avg = current_sum / counts.unsqueeze(1) | |
# Apply EMA update formula | |
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current | |
# Update the EMA embeddings for unique class IDs | |
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients | |
def generate_eval_embeddings(self, class_id, current_embedding): | |
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients | |
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients | |
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False): | |
# forward method that uses ema or eval embeddings rather than context sequence | |
# Process input through the headless model to get feature embeddings | |
feature_embeddings = self.headless_model(x) | |
if return_feats: | |
return feature_embeddings | |
else: | |
            if class_of_interest is None:
# Get class-specific embeddings based on class_ids | |
                if not eval:
class_embeddings = self.get_ema_embeddings(class_ids=class_ids) | |
else: | |
class_embeddings = self.get_eval_embeddings(class_ids=class_ids) | |
if return_class_embeddings: | |
return class_embeddings | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embeddings.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
return probabilities | |
else: | |
                if not eval:
device = self.ema_embeddings.weight.device | |
                    class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor) | |
print(f'using EMA estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
else: | |
device = self.eval_embeddings.weight.device | |
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device) | |
class_embedding = self.get_eval_embeddings(class_of_interest_tensor) | |
#print(f'using eval estimate for class {class_of_interest}') | |
if return_class_embeddings: | |
return class_embedding | |
else: | |
# Matrix multiplication to produce logits | |
logits = feature_embeddings @ class_embedding.T | |
# Apply sigmoid to convert logits to probabilities | |
probabilities = torch.sigmoid(logits) | |
probabilities = probabilities.squeeze() | |
return probabilities | |
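# Sketch of the eval-embedding workflow shared by MultiInputModel and VariableInputModel:
# init_eval_embeddings allocates a fresh table, a forward pass with use_eval_embeddings=True
# writes the class embeddings produced from the context sequence into it, and embedding_forward
# with eval=True reuses them without re-running the transformer. Argument names are placeholders
# and the model and tensors are assumed to be on the same device.
def _example_eval_embedding_workflow(model, num_classes, x, context_sequence, context_mask, class_ids):
    model.init_eval_embeddings(num_classes)
    model.eval()
    with torch.no_grad():
        # Populate eval embeddings for the classes in this batch.
        model(x, context_sequence, context_mask, class_ids=class_ids, use_eval_embeddings=True)
        # Reuse the stored embeddings for prediction.
        probs = model.embedding_forward(x, class_ids=class_ids, eval=True)
    return probs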
class LinNet(nn.Module): | |
def __init__(self, num_inputs, num_classes): | |
super(LinNet, self).__init__() | |
self.num_layers = 0 | |
self.inc_bias = False | |
self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias) | |
self.feats = nn.Identity() # does not do anything | |
def forward(self, x, class_of_interest=None, return_feats=False): | |
loc_emb = self.feats(x) | |
if return_feats: | |
return loc_emb | |
if class_of_interest is None: | |
class_pred = self.class_emb(loc_emb) | |
else: | |
class_pred = self.eval_single_class(loc_emb, class_of_interest) | |
return torch.sigmoid(class_pred) | |
def eval_single_class(self, x, class_of_interest): | |
if self.inc_bias: | |
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest] | |
else: | |
return x @ self.class_emb.weight[class_of_interest, :] | |
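# Minimal sketch of LinNet usage: a purely linear baseline over already-encoded inputs.
# The sizes below are assumptions.
def _example_linnet():
    net = LinNet(num_inputs=4, num_classes=10)
    x = torch.randn(8, 4)
    probs = net(x)                            # (8, 10) sigmoid probabilities
    single = net(x, class_of_interest=2)      # (8,) probabilities for class 2 only
    return probs, single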
class ParallelMulti(torch.nn.Module): | |
def __init__(self, x: list[torch.nn.Module]): | |
super(ParallelMulti, self).__init__() | |
self.layers = nn.ModuleList(x) | |
def forward(self, xs, **kwargs): | |
out = torch.cat([self.layers[i](x, **kwargs) for i,x in enumerate(xs)], dim=1) | |
return out | |
class SequentialMulti(torch.nn.Sequential): | |
def forward(self, *inputs, **kwargs): | |
for module in self._modules.values(): | |
if type(inputs) == tuple: | |
inputs = module(*inputs, **kwargs) | |
else: | |
inputs = module(inputs) | |
return inputs | |
# Chris's transformation classes | |
class Identity(nn.Module): | |
def __init__(self, in_dim, out_dim): | |
super(Identity, self).__init__() | |
# No parameters needed for identity transformation | |
def forward(self, x): | |
return x | |
class LinearTransformation(nn.Module): | |
def __init__(self, in_dim, out_dim, bias=True): | |
super(LinearTransformation, self).__init__() | |
self.linear = nn.Linear(in_dim, out_dim, bias=bias) | |
def forward(self, x): | |
return self.linear(x) | |
class SingleLayerNN(nn.Module): | |
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True): | |
super(SingleLayerNN, self).__init__() | |
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension | |
self.net = nn.Sequential( | |
nn.Linear(in_dim, hidden_dim, bias=bias), | |
nn.ReLU(), | |
nn.Dropout(p=dropout_p), | |
nn.Linear(hidden_dim, out_dim, bias=bias) | |
) | |
def forward(self, x): | |
return self.net(x) | |
class TwoLayerNN(nn.Module): | |
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True): | |
super(TwoLayerNN, self).__init__() | |
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension | |
self.net = nn.Sequential( | |
nn.Linear(in_dim, hidden_dim, bias=bias), | |
nn.ReLU(), | |
nn.Dropout(p=dropout_p), | |
nn.Linear(hidden_dim, hidden_dim, bias=bias), | |
nn.ReLU(), | |
nn.Dropout(p=dropout_p), | |
nn.Linear(hidden_dim, out_dim, bias=bias) | |
) | |
def forward(self, x): | |
return self.net(x) | |
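# Minimal sketch of the class-token transformation heads above: each maps the transformer's
# token_dim-wide class token to a num_filts-wide embedding (Identity assumes the two widths are
# equal). The 32/64 sizes are illustrative assumptions; dropout layers are active in train mode.
def _example_class_token_transforms():
    token = torch.randn(2, 32)
    out_identity = Identity(32, 32)(token)            # (2, 32), returned unchanged
    out_linear = LinearTransformation(32, 64)(token)  # (2, 64)
    out_single = SingleLayerNN(32, 64)(token)         # (2, 64)
    out_two = TwoLayerNN(32, 64)(token)               # (2, 64)
    return out_identity, out_linear, out_single, out_two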
class HyperNet(nn.Module): | |
    '''
    Hypernetwork: a species encoder produces a per-species weight vector (plus bias) that is
    applied to position features from a shared location encoder.
    '''
def __init__(self, params, num_inputs, num_classes, num_filts, pos_enc_depth, species_dim, species_enc_depth, species_filts, species_enc='embed', inference_only=False): | |
super(HyperNet, self).__init__() | |
if species_enc == 'embed': | |
self.species_emb = nn.Embedding(num_classes, species_dim) | |
self.species_emb.weight.data *= 0.01 | |
elif species_enc == 'taxa': | |
self.species_emb = TaxaEncoder(params, './data/inat_taxa_info.csv', species_dim) | |
elif species_enc == 'text': | |
self.species_emb = TextEncoder(params, params['text_emb_path'], species_dim, './data/inat_taxa_info.csv') | |
elif species_enc == 'wiki': | |
self.species_emb = WikiEncoder(params, params['text_emb_path'], species_dim, inference_only=inference_only) | |
if species_enc_depth == -1: | |
self.species_enc = nn.Identity() | |
elif species_enc_depth == 0: | |
self.species_enc = nn.Linear(species_dim, num_filts+1) | |
else: | |
self.species_enc = SimpleFCNet(species_dim, num_filts+1, species_filts, depth=species_enc_depth) | |
if 'geoprior' in params['loss']: | |
self.species_params = nn.Parameter(torch.randn(num_classes, species_dim)) | |
self.species_params.data *= 0.0386 | |
self.pos_enc = SimpleFCNet(num_inputs, num_filts, num_filts, depth=pos_enc_depth) | |
def forward(self, x, y): | |
ys, indmap = torch.unique(y, return_inverse=True) | |
species = self.species_enc(self.species_emb(ys)) | |
species_w, species_b = species[...,:-1], species[...,-1:] | |
pos = self.pos_enc(x) | |
out = torch.bmm(species_w[indmap],pos[...,None]) | |
out = (out + 0*species_b[indmap]).squeeze(-1) #TODO | |
if hasattr(self, 'species_params'): | |
out2 = torch.bmm(self.species_params[ys][indmap],pos[...,None]) | |
out2 = out2.squeeze(-1) | |
out3 = (species_w, self.species_params[ys], ys) | |
return out, out2, out3 | |
else: | |
return out | |
def zero_shot(self, x, species_emb): | |
species = self.species_enc(self.species_emb.zero_shot(species_emb)) | |
species_w, _ = species[...,:-1], species[...,-1:] | |
pos = self.pos_enc(x) | |
out = pos @ species_w.T | |
return out | |
class TaxaEncoder(nn.Module): | |
def __init__(self, params, fpath, embedding_dim): | |
super(TaxaEncoder, self).__init__() | |
import datasets | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
data_dir = paths['train'] | |
obs_file = os.path.join(data_dir, params['obs_file']) | |
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'], | |
params['aux_species_seed'], params['taxa_file'], taxa_file_snt) | |
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest) | |
unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
class_to_taxa = unique_taxa.tolist() | |
self.fpath = fpath | |
ids = [] | |
rows = [] | |
with open(fpath, newline='') as csvfile: | |
spamreader = csv.reader(csvfile, delimiter=',') | |
for row in spamreader: | |
if row[0] == 'taxon_id': | |
continue | |
ids.append(int(row[0])) | |
rows.append(row[3:]) | |
        rows = np.array(rows)
        rows = [np.unique(rows[:, i], return_inverse=True)[1] for i in range(rows.shape[1])]
        rows = torch.from_numpy(np.vstack(rows).T)
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)} | |
embs = [nn.Embedding(rows[:,i].max()+2, embedding_dim, 0) for i in range(rows.shape[1])] | |
embs[-1] = nn.Embedding(len(class_to_taxa), embedding_dim) | |
rows2 = torch.zeros((len(class_to_taxa), 7), dtype=rows.dtype) | |
startind = rows[:,-1].max() | |
for i in range(len(class_to_taxa)): | |
if class_to_taxa[i] in ids: | |
rows2[i] = rows[ids.index(class_to_taxa[i])]+1 | |
rows2[i,-1] -= 1 | |
else: | |
rows2[i,-1] = startind | |
startind += 1 | |
self.register_buffer('rows', rows2) | |
for e in embs: | |
e.weight.data *= 0.01 | |
self.embs = nn.ModuleList(embs) | |
def forward(self, x): | |
inds = self.rows[x] | |
out = sum([self.embs[i](inds[...,i]) for i in range(inds.shape[-1])]) | |
return out | |
class TextEncoder(nn.Module): | |
def __init__(self, params, path, embedding_dim, fpath='inat_taxa_info.csv'): | |
super(TextEncoder, self).__init__() | |
import datasets | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
data_dir = paths['train'] | |
obs_file = os.path.join(data_dir, params['obs_file']) | |
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'], | |
params['aux_species_seed'], params['taxa_file'], taxa_file_snt) | |
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest) | |
unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
class_to_taxa = unique_taxa.tolist() | |
self.fpath = fpath | |
ids = [] | |
with open(fpath, newline='') as csvfile: | |
spamreader = csv.reader(csvfile, delimiter=',') | |
for row in spamreader: | |
if row[0] == 'taxon_id': | |
continue | |
ids.append(int(row[0])) | |
        embs = torch.load(path)
        if isinstance(embs, list):
            embs = torch.stack(embs)
        if len(embs) != len(ids):
            print("Warning: number of embeddings doesn't match number of species")
            ids = ids[:embs.shape[0]]
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)} | |
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int) | |
embmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int) | |
self.missing_emb = nn.Embedding(len(class_to_taxa)-embs.shape[0], embedding_dim) | |
startind = 0 | |
for i in range(len(class_to_taxa)): | |
if class_to_taxa[i] in ids: | |
indmap[i] = ids.index(class_to_taxa[i]) | |
else: | |
embmap[i] = startind | |
startind += 1 | |
self.scales = nn.Parameter(torch.zeros(len(class_to_taxa), 1)) | |
self.register_buffer('indmap', indmap, persistent=False) | |
self.register_buffer('embmap', embmap, persistent=False) | |
self.register_buffer('embs', embs, persistent=False) | |
if params['text_hidden_dim'] == 0: | |
self.linear1 = nn.Linear(embs.shape[1], embedding_dim) | |
else: | |
self.linear1 = nn.Linear(embs.shape[1], params['text_hidden_dim']) | |
self.linear2 = nn.Linear(params['text_hidden_dim'], embedding_dim) | |
self.act = nn.SiLU() | |
if params['text_learn_dim'] > 0: | |
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim']) | |
self.learned_emb.weight.data *= 0.01 | |
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim) | |
def forward(self, x): | |
inds = self.indmap[x] | |
out = self.embs[self.indmap[x].cpu()] | |
out = self.linear1(out) | |
if hasattr(self, 'linear2'): | |
out = self.linear2(self.act(out)) | |
out = self.scales[x] * (out / (out.std(dim=1)[:, None])) | |
out[inds == -1] = self.missing_emb(self.embmap[x[inds == -1]]) | |
if hasattr(self, 'learned_emb'): | |
out2 = self.learned_emb(x) | |
out2 = self.linear_learned(out2) | |
out = out+out2 | |
return out | |
class WikiEncoder(nn.Module): | |
def __init__(self, params, path, embedding_dim, inference_only=False): | |
super(WikiEncoder, self).__init__() | |
self.path = path | |
if not inference_only: | |
import datasets | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
data_dir = paths['train'] | |
obs_file = os.path.join(data_dir, params['obs_file']) | |
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'], | |
params['aux_species_seed'], params['taxa_file'], taxa_file_snt) | |
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest) | |
if params['zero_shot']: | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True) | |
D = D.item() | |
taxa_snt = D['taxa'].tolist() | |
taxa = [int(tt) for tt in data['taxa_presence'].keys()] | |
taxa = list(set(taxa + taxa_snt)) | |
mask = labels != taxa[0] | |
for i in range(1, len(taxa)): | |
mask &= (labels != taxa[i]) | |
locs = locs[mask] | |
dates = dates[mask] | |
labels = labels[mask] | |
unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
class_to_taxa = unique_taxa.tolist() | |
embs = torch.load(path) | |
ids = embs['taxon_id'].tolist() | |
if 'keys' in embs: | |
taxa_counts = torch.zeros(len(ids), dtype=torch.int32) | |
for i,k in embs['keys']: | |
taxa_counts[i] += 1 | |
else: | |
taxa_counts = torch.ones(len(ids), dtype=torch.int32) | |
count_sum = torch.cumsum(taxa_counts, dim=0) - taxa_counts | |
embs = embs['data'] | |
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)} | |
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int) | |
countmap = torch.zeros(len(class_to_taxa), dtype=torch.int) | |
self.species_emb = nn.Embedding(len(class_to_taxa), embedding_dim) | |
self.species_emb.weight.data *= 0.01 | |
for i in range(len(class_to_taxa)): | |
if class_to_taxa[i] in ids: | |
i2 = ids.index(class_to_taxa[i]) | |
indmap[i] = count_sum[i2] | |
countmap[i] = taxa_counts[i2] | |
self.register_buffer('indmap', indmap, persistent=False) | |
self.register_buffer('countmap', countmap, persistent=False) | |
self.register_buffer('embs', embs, persistent=False) | |
assert embs.shape[1] == 4096 | |
self.scale = nn.Parameter(torch.zeros(1)) | |
if params['species_dropout'] > 0: | |
self.dropout = nn.Dropout(p=params['species_dropout']) | |
if params['text_hidden_dim'] == 0: | |
self.linear1 = nn.Linear(4096, embedding_dim) | |
else: | |
self.linear1 = nn.Linear(4096, params['text_hidden_dim']) | |
if params['text_batchnorm']: | |
self.bn1 = nn.BatchNorm1d(params['text_hidden_dim']) | |
for l in range(params['text_num_layers']-1): | |
setattr(self, f'linear{l+2}', nn.Linear(params['text_hidden_dim'], params['text_hidden_dim'])) | |
if params['text_batchnorm']: | |
setattr(self, f'bn{l+2}', nn.BatchNorm1d(params['text_hidden_dim'])) | |
setattr(self, f'linear{params["text_num_layers"]+1}', nn.Linear(params['text_hidden_dim'], embedding_dim)) | |
self.act = nn.SiLU() | |
if params['text_learn_dim'] > 0: | |
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim']) | |
self.learned_emb.weight.data *= 0.01 | |
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim) | |
def forward(self, x): | |
inds = self.indmap[x] + (torch.rand(x.shape,device=x.device)*self.countmap[x]).floor().int() | |
out = self.embs[inds] | |
if hasattr(self, 'dropout'): | |
out = self.dropout(out) | |
out = self.linear1(out) | |
if hasattr(self, 'linear2'): | |
out = self.act(out) | |
if hasattr(self, 'bn1'): | |
out = self.bn1(out) | |
i = 2 | |
while hasattr(self, f'linear{i}'): | |
if hasattr(self, f'linear{i}'): | |
out = self.act(getattr(self, f'linear{i}')(out)) | |
if hasattr(self, f'bn{i}'): | |
out = getattr(self, f'bn{i}')(out) | |
i += 1 | |
#out = self.scale * (out / (out.std(dim=1)[:, None])) | |
out2 = self.species_emb(x) | |
chosen = torch.rand((out.shape[0],), device=x.device) | |
chosen = 1+0*chosen #TODO fix this | |
chosen[inds == -1] = 0 | |
out = chosen[:,None] * out + (1-chosen[:,None])*out2 | |
if hasattr(self, 'learned_emb'): | |
out2 = self.learned_emb(x) | |
out2 = self.linear_learned(out2) | |
out = out+out2 | |
return out | |
def zero_shot(self, species_emb): | |
out = species_emb | |
out = self.linear1(out) | |
if hasattr(self, 'linear2'): | |
out = self.act(out) | |
if hasattr(self, 'bn1'): | |
out = self.bn1(out) | |
i = 2 | |
while hasattr(self, f'linear{i}'): | |
if hasattr(self, f'linear{i}'): | |
out = self.act(getattr(self, f'linear{i}')(out)) | |
if hasattr(self, f'bn{i}'): | |
out = getattr(self, f'bn{i}')(out) | |
i += 1 | |
return out | |
def zero_shot_old(self, species_emb): | |
out = species_emb | |
out = self.linear1(out) | |
if hasattr(self, 'linear2'): | |
out = self.linear2(self.act(out)) | |
out = self.scale * (out / (out.std(dim=-1, keepdim=True))) | |
return out | |
# Note: MultiInputTextEncoder below is only intended for the Multi/Variable input models and is
# not currently used; a HeadlessSINR currently serves as the text encoder instead.
class MultiInputTextEncoder(nn.Module): | |
def __init__(self, token_dim, dropout, input_dim=4096, depth=2, hidden_dim=512, nonlin='relu', batch_norm=True, layer_norm=False): | |
super(MultiInputTextEncoder, self).__init__() | |
print("THINK ABOUT IF SOME OF THESE HYPERPARAMETERS SHOULD BE DISTINCT FROM THE TRANSFORMER VERSION") | |
print("DEPTH / NUM_ENCODER_LAYERS, DROPOUT, DIM_FEEDFORWARD, ETC") | |
print("AT PRESENT WE JUST HAVE A SORT OF BASIC VERSION IMPLEMENTED THAT ATTEMPTS TO BE LIKE MAX'S VERSION") | |
print("ALSO, OPTION TO HAVE IT PRETRAINED? ADD RESIDUAL LAYERS?") | |
self.token_dim=token_dim | |
self.dropout=dropout | |
self.input_dim=input_dim | |
self.depth=depth | |
self.hidden_dim=hidden_dim | |
self.batch_norm = batch_norm | |
self.layer_norm = layer_norm | |
if nonlin == 'relu': | |
activation = nn.ReLU | |
elif nonlin == 'silu': | |
activation = nn.SiLU | |
else: | |
raise NotImplementedError('Invalid nonlinearity specified.') | |
self.dropout_layer = nn.Dropout(p=self.dropout) | |
if self.depth <= 1: | |
self.linear1 = nn.Linear(self.input_dim, self.token_dim) | |
else: | |
self.linear1 = nn.Linear(self.input_dim, self.hidden_dim) | |
if self.batch_norm: | |
self.bn1 = nn.BatchNorm1d(self.hidden_dim) | |
# if self.layer_norm: | |
# self.ln1 = nn.LayerNorm(self.hidden_dim) | |