File size: 4,264 Bytes
0b54529
9ff98d7
 
0b54529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import copy
import torch

def apply_overrides(params, overrides):
    """Return a deep copy of ``params`` with ``overrides`` applied on top.

    The ``params`` argument itself is never modified. An override whose name
    is not already a key of ``params`` is reported on stdout and is then
    added to the result anyway (it is not treated as an error).

    Args:
        params: dict of parameter name -> value to start from.
        overrides: dict of parameter name -> replacement value.

    Returns:
        A new dict holding the merged parameters.
    """
    merged = copy.deepcopy(params)
    for name, value in overrides.items():
        if name not in merged:
            # Best-effort: warn about the unknown name, but still accept it.
            print(f'override failed: no parameter named {name}')
        merged[name] = value
    return merged

def get_default_params_train(overrides=None):
    """Build the default training parameters, then apply ``overrides``.

    Args:
        overrides: optional dict of parameter name -> value that is applied
            on top of the defaults through ``apply_overrides``. ``None``
            (the default) means "no overrides".

    Returns:
        The parameter dict produced by ``apply_overrides``.
    """
    # A ``None`` sentinel replaces the former mutable ``{}`` default so that
    # no dict object is shared between calls.
    if overrides is None:
        overrides = {}

    params = {}

    # ---- misc ----
    params['device'] = 'cuda'  # cuda, cpu
    params['save_base'] = './experiments/'
    params['save_frequency'] = 5
    params['experiment_name'] = 'demo'
    params['timestamp'] = False

    # ---- data ----
    params['species_set'] = 'all'  # all, snt_birds
    params['hard_cap_seed'] = 9472
    params['hard_cap_num_per_class'] = -1  # -1 for no hard capping
    params['aux_species_seed'] = 8099
    params['num_aux_species'] = 0  # for snt_birds case, how many other species to add in
    params['input_time'] = False  # whether to input time as a feature
    params['input_time_dim'] = 0
    params['dataset'] = 'inat'  # inat, iucn_inat, iucn_uniform
    params['zero_shot'] = False
    params['subset_cap_name'] = None
    params['subset_cap_num_per_class'] = -1
    # NOTE: the entries below were added later by the author; check whether
    # they have any impact.
    params['seed'] = 1000
    params['add_location_noise'] = False
    params['variable_context_length'] = False
    params['eval_dataset'] = 'eval_transformer'
    params['eval_num_context'] = 20
    params['use_text_inputs'] = True
    params['use_image_inputs'] = False
    params['use_env_inputs'] = False
    params['class_token_transformation'] = 'identity'
    params['loc_prob'] = 1.0
    params['text_prob'] = 0.0
    params['image_prob'] = 0.0
    params['env_prob'] = 0.0

    # ---- data files ----
    params['obs_file'] = 'geo_prior_train.csv'
    params['taxa_file'] = 'geo_prior_train_meta.json'

    # ---- model ----
    params['model'] = 'ResidualFCNet'  # ResidualFCNet, LinNet
    params['num_filts'] = 256  # embedding dimension
    params['input_enc'] = 'sin_cos'  # sin_cos, env, sin_cos_env
    params['input_dim'] = 4
    params['depth'] = 4
    params['noise_time'] = False

    params['species_dim'] = 0
    params['species_enc_depth'] = 0
    params['species_filts'] = 256
    params['species_enc'] = 'embed'
    params['text_emb_path'] = ''
    params['image_emb_path'] = ''
    params['text_learn_dim'] = 0
    params['text_hidden_dim'] = 0
    params['text_num_layers'] = 1
    params['text_batchnorm'] = False
    params['species_dropout'] = 0.0
    params['geoprior_temp'] = 0.0
    # NOTE: the entries below were added later by the author.
    params['num_context'] = 50
    params['transformer_input_enc'] = 'sin_cos'
    params['transformer_dropout'] = 0.1
    params['num_heads'] = 8
    params['ema_factor'] = 0.1
    params['use_register'] = True
    params['use_pretrained_sinr'] = False
    params['pretrained_loc'] = ''
    params['freeze_sinr'] = False
    params['pos_enc_class'] = 'sinr'

    # ---- loss ----
    params['loss'] = 'an_full'  # an_full, an_ssdl, an_slds
    params['pos_weight'] = 2048

    # ---- optimization ----
    params['batch_size'] = 2048
    params['lr'] = 0.0005
    params['lr_decay'] = 0.98
    params['num_epochs'] = 10

    # ---- saving ----
    params['log_frequency'] = 512

    # Overrides are applied last so that they win over every default above.
    params = apply_overrides(params, overrides)

    return params


def get_default_params_eval(overrides=None):
    """Build the default evaluation parameters, then apply ``overrides``.

    Args:
        overrides: optional dict of parameter name -> value that is applied
            on top of the defaults through ``apply_overrides``. ``None``
            (the default) means "no overrides".

    Returns:
        The parameter dict produced by ``apply_overrides``.
    """
    # A ``None`` sentinel replaces the former mutable ``{}`` default so that
    # no dict object is shared between calls.
    if overrides is None:
        overrides = {}

    params = {}

    # ---- misc ----
    # Unlike a fixed string, this picks CUDA only when it is available.
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['seed'] = 2022
    params['exp_base'] = './experiments'
    params['ckp_name'] = 'model.pt'
    params['eval_type'] = 'snt'  # snt, iucn, geo_prior, geo_feature
    params['experiment_name'] = 'demo'
    params['input_dim'] = 4
    params['input_time'] = False
    params['input_time_dim'] = 0
    # NOTE: the original comments record two candidate defaults for
    # num_samples: -1 (marked "mine") and 0 (marked "maxs"). 0 is the value
    # in effect.
    params['num_samples'] = 0
    params['text_section'] = ''
    params['extract_pos'] = False
    # NOTE: a 'target_background' = True entry used to be set here; the
    # author marked it as probably not needed anymore.

    # ---- geo prior ----
    params['batch_size'] = 2048

    # ---- geo feature ----
    params['cell_size'] = 25

    # Overrides are applied last so that they win over every default above.
    params = apply_overrides(params, overrides)

    return params