# Scraped from a Hugging Face Space file view (Space status: Sleeping).
# File size: 4,264 bytes; commits: 0b54529, 9ff98d7.
import copy
import torch
def apply_overrides(params, overrides):
    """Return a deep copy of ``params`` with ``overrides`` applied on top.

    Args:
        params: Dict of default parameter values. It is deep-copied, so the
            caller's dict (including nested containers) is never mutated.
        overrides: Mapping of parameter name -> value, or ``None`` for no
            overrides. Override values are assigned as-is (not copied).

    Returns:
        A new dict holding every key of ``params`` plus every key of
        ``overrides``, with override values taking precedence.

    Names in ``overrides`` that have no default in ``params`` are NOT
    rejected: a warning is printed and the name is added to the result.
    """
    params = copy.deepcopy(params)
    if overrides is None:
        return params
    for param_name in overrides:
        if param_name not in params:
            # Unknown names are deliberately accepted (a `raise ValueError`
            # here was disabled), so the message must not claim failure;
            # it still surfaces likely typos in parameter names.
            print(f'override warning: no default parameter named {param_name}; adding it')
        params[param_name] = overrides[param_name]
    return params
def get_default_params_train(overrides=None):
    """Build the default training configuration and apply ``overrides``.

    Args:
        overrides: Optional dict of parameter name -> value merged on top of
            the defaults via ``apply_overrides``. ``None`` (the default)
            means no overrides. A ``None`` sentinel is used instead of a
            ``{}`` literal so no dict object is shared between calls.

    Returns:
        A new dict of training parameters.
    """
    if overrides is None:
        overrides = {}
    params = {}

    # --- misc ---
    # Fall back to CPU when CUDA is unavailable, matching
    # get_default_params_eval; the value stays a plain string.
    params['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'  # cuda, cpu
    params['save_base'] = './experiments/'
    params['save_frequency'] = 5
    params['experiment_name'] = 'demo'
    params['timestamp'] = False

    # --- data ---
    params['species_set'] = 'all'  # all, snt_birds
    params['hard_cap_seed'] = 9472
    params['hard_cap_num_per_class'] = -1  # -1 for no hard capping
    params['aux_species_seed'] = 8099
    params['num_aux_species'] = 0  # for snt_birds case, how many other species to add in
    params['input_time'] = False  # whether to input time as a feature
    params['input_time_dim'] = 0
    params['dataset'] = 'inat'  # inat, iucn_inat, iucn_uniform
    params['zero_shot'] = False
    params['subset_cap_name'] = None
    params['subset_cap_num_per_class'] = -1
    # Locally added parameters (author's note: impact still to be checked).
    params['seed'] = 1000
    params['add_location_noise'] = False
    params['variable_context_length'] = False
    params['eval_dataset'] = 'eval_transformer'
    params['eval_num_context'] = 20
    params['use_text_inputs'] = True
    params['use_image_inputs'] = False
    params['use_env_inputs'] = False
    params['class_token_transformation'] = 'identity'
    params['loc_prob'] = 1.0
    params['text_prob'] = 0.0
    params['image_prob'] = 0.0
    params['env_prob'] = 0.0

    # --- data files ---
    params['obs_file'] = 'geo_prior_train.csv'
    params['taxa_file'] = 'geo_prior_train_meta.json'

    # --- model ---
    params['model'] = 'ResidualFCNet'  # ResidualFCNet, LinNet
    params['num_filts'] = 256  # embedding dimension
    params['input_enc'] = 'sin_cos'  # sin_cos, env, sin_cos_env
    params['input_dim'] = 4
    params['depth'] = 4
    params['noise_time'] = False
    params['species_dim'] = 0
    params['species_enc_depth'] = 0
    params['species_filts'] = 256
    params['species_enc'] = 'embed'
    params['text_emb_path'] = ''
    params['image_emb_path'] = ''
    params['text_learn_dim'] = 0
    params['text_hidden_dim'] = 0
    params['text_num_layers'] = 1
    params['text_batchnorm'] = False
    params['species_dropout'] = 0.0
    params['geoprior_temp'] = 0.0
    # Locally added transformer / pretrained-SINR parameters.
    params['num_context'] = 50
    params['transformer_input_enc'] = 'sin_cos'
    params['transformer_dropout'] = 0.1
    params['num_heads'] = 8
    params['ema_factor'] = 0.1
    params['use_register'] = True
    params['use_pretrained_sinr'] = False
    params['pretrained_loc'] = ''
    params['freeze_sinr'] = False
    params['pos_enc_class'] = 'sinr'

    # --- loss ---
    params['loss'] = 'an_full'  # an_full, an_ssdl, an_slds
    params['pos_weight'] = 2048

    # --- optimization ---
    params['batch_size'] = 2048
    params['lr'] = 0.0005
    params['lr_decay'] = 0.98
    params['num_epochs'] = 10

    # --- saving ---
    params['log_frequency'] = 512

    params = apply_overrides(params, overrides)
    return params
def get_default_params_eval(overrides=None):
    """Build the default evaluation configuration and apply ``overrides``.

    Args:
        overrides: Optional dict of parameter name -> value merged on top of
            the defaults via ``apply_overrides``. ``None`` (the default)
            means no overrides. A ``None`` sentinel is used instead of a
            ``{}`` literal so no dict object is shared between calls.

    Returns:
        A new dict of evaluation parameters. ``'device'`` is a
        ``torch.device``: CUDA when available, otherwise CPU.
    """
    if overrides is None:
        overrides = {}
    params = {}

    # --- misc ---
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['seed'] = 2022
    params['exp_base'] = './experiments'
    params['ckp_name'] = 'model.pt'
    params['eval_type'] = 'snt'  # snt, iucn, geo_prior, geo_feature
    params['experiment_name'] = 'demo'
    params['input_dim'] = 4
    params['input_time'] = False
    params['input_time_dim'] = 0
    # NOTE: a local variant previously used num_samples = -1; 0 is the
    # value kept here.
    params['num_samples'] = 0
    params['text_section'] = ''
    params['extract_pos'] = False
    # NOTE: a 'target_background' (True) entry was previously disabled as
    # probably no longer needed.

    # --- geo prior ---
    params['batch_size'] = 2048

    # --- geo feature ---
    params['cell_size'] = 25

    params = apply_overrides(params, overrides)
    return params
# (end of scraped file view)