# Spaces: Sleeping -- appears to be page-status text captured along with the file; not part of the module.
import copy | |
import torch | |
def apply_overrides(params, overrides):
    """Return a deep copy of ``params`` with every entry of ``overrides`` applied.

    ``params`` itself is never modified. A name in ``overrides`` that is not
    already present in ``params`` triggers a printed notice, but the value is
    still stored under that name in the returned dict (no exception is
    raised).
    """
    merged = copy.deepcopy(params)
    for name in overrides:
        already_known = name in merged
        if not already_known:
            # Notice only: the assignment below still happens.
            print(f'override failed: no parameter named {name}')
        merged[name] = overrides[name]
    return merged
def get_default_params_train(overrides=None):
    """Build the default training parameter dict and apply ``overrides``.

    Args:
        overrides: Optional mapping of parameter name to replacement value.
            ``None`` (the default) means "no overrides".

    Returns:
        A new dict of training parameters: the defaults below merged with
        ``overrides`` by ``apply_overrides``.
    """
    # Use a None sentinel instead of a shared mutable ``{}`` default.
    if overrides is None:
        overrides = {}

    params = {
        # ---- misc ----
        # NOTE(review): the literal 'cuda' is used with no availability check.
        'device': 'cuda',  # cuda, cpu
        'save_base': './experiments/',
        'save_frequency': 5,
        'experiment_name': 'demo',
        'timestamp': False,

        # ---- data ----
        'species_set': 'all',  # all, snt_birds
        'hard_cap_seed': 9472,
        'hard_cap_num_per_class': -1,  # -1 for no hard capping
        'aux_species_seed': 8099,
        'num_aux_species': 0,  # for snt_birds case, how many other species to add in
        'input_time': False,  # whether to input time as a feature
        'input_time_dim': 0,
        'dataset': 'inat',  # inat, iucn_inat, iucn_uniform
        'zero_shot': False,
        'subset_cap_name': None,
        'subset_cap_num_per_class': -1,
        # Author's additions (original note: "check if there is any impact").
        'seed': 1000,
        'add_location_noise': False,
        'variable_context_length': False,
        'eval_dataset': 'eval_transformer',
        'eval_num_context': 20,
        'use_text_inputs': True,
        'use_image_inputs': False,
        'use_env_inputs': False,
        'class_token_transformation': 'identity',
        'loc_prob': 1.0,
        'text_prob': 0.0,
        'image_prob': 0.0,
        'env_prob': 0.0,

        # ---- data files ----
        'obs_file': 'geo_prior_train.csv',
        'taxa_file': 'geo_prior_train_meta.json',

        # ---- model ----
        'model': 'ResidualFCNet',  # ResidualFCNet, LinNet
        'num_filts': 256,  # embedding dimension
        'input_enc': 'sin_cos',  # sin_cos, env, sin_cos_env
        'input_dim': 4,
        'depth': 4,
        'noise_time': False,
        'species_dim': 0,
        'species_enc_depth': 0,
        'species_filts': 256,
        'species_enc': 'embed',
        'text_emb_path': '',
        'image_emb_path': '',
        'text_learn_dim': 0,
        'text_hidden_dim': 0,
        'text_num_layers': 1,
        'text_batchnorm': False,
        'species_dropout': 0.0,
        'geoprior_temp': 0.0,
        # Author's additions.
        'num_context': 50,
        'transformer_input_enc': 'sin_cos',
        'transformer_dropout': 0.1,
        'num_heads': 8,
        'ema_factor': 0.1,
        'use_register': True,
        'use_pretrained_sinr': False,
        'pretrained_loc': '',
        'freeze_sinr': False,
        'pos_enc_class': 'sinr',

        # ---- loss ----
        'loss': 'an_full',  # an_full, an_ssdl, an_slds
        'pos_weight': 2048,

        # ---- optimization ----
        'batch_size': 2048,
        'lr': 0.0005,
        'lr_decay': 0.98,
        'num_epochs': 10,

        # ---- saving ----
        'log_frequency': 512,
    }

    return apply_overrides(params, overrides)
def get_default_params_eval(overrides=None):
    """Build the default evaluation parameter dict and apply ``overrides``.

    Args:
        overrides: Optional mapping of parameter name to replacement value.
            ``None`` (the default) means "no overrides".

    Returns:
        A new dict of evaluation parameters: the defaults below merged with
        ``overrides`` by ``apply_overrides``.
    """
    # Use a None sentinel instead of a shared mutable ``{}`` default.
    if overrides is None:
        overrides = {}

    params = {
        # ---- misc ----
        # Selects CUDA when torch reports it as available, otherwise CPU.
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'seed': 2022,
        'exp_base': './experiments',
        'ckp_name': 'model.pt',
        'eval_type': 'snt',  # snt, iucn, geo_prior, geo_feature
        'experiment_name': 'demo',
        'input_dim': 4,
        'input_time': False,
        'input_time_dim': 0,
        # The original file carries a disabled alternative default of -1 for
        # 'num_samples' (tagged "mine"); the active value (tagged "maxs") is 0.
        'num_samples': 0,
        'text_section': '',
        'extract_pos': False,
        # A 'target_background' = True entry is disabled in the original with
        # the note "probably not needed anymore"; it is intentionally not set.

        # ---- geo prior ----
        'batch_size': 2048,

        # ---- geo feature ----
        'cell_size': 25,
    }

    return apply_overrides(params, overrides)