// Re-training config for the "span" model.
// The word embedding, span extractor and span finder are initialised from a
// previously trained archive (`pretrained_path`) and left trainable. A new
// MLP span typer is built from `basic.jsonnet` dimensions. Data, loader and
// most trainer settings are taken from `basic.jsonnet` unchanged.
//
// NOTE: this file relies on `//` line comments, so it must stay multi-line.
// Collapsing it onto one line makes the first comment swallow the rest of the
// configuration.
local env = import "../env.jsonnet";
local base = import "basic.jsonnet";

local debug = false;

// --- re-train ----------------------------------------------------------------
// Archive whose sub-modules are reused; overridable via the PRETRAINED_PATH
// environment variable (read through the project's `env` helper).
local pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best");
// Learning rate for the reused `_span_finder` / `_span_extractor` parameter
// groups (see `parameter_groups` below); overridable via RT_LR.
local rt_lr = env.json("RT_LR", 5e-5);

// --- module ------------------------------------------------------------------
local cuda_devices = base.cuda_devices;
local num_devices = std.length(cuda_devices);

// Returns a `_pretrained` initializer that loads the sub-module at
// `module_path` from `pretrained_path` without freezing it.
local pretrained(module_path) = {
  "_pretrained": {
    "archive_file": pretrained_path,
    "module_path": module_path,
    "freeze": false,
  },
};

{
  dataset_reader: base.dataset_reader,
  train_data_path: base.train_data_path,
  validation_data_path: base.validation_data_path,
  test_data_path: base.test_data_path,
  datasets_for_vocab_creation: ["train"],
  data_loader: base.data_loader,
  validation_data_loader: base.validation_data_loader,

  model: {
    type: "span",
    // Reused (trainable) sub-modules from the pretrained archive.
    word_embedding: pretrained("word_embedding"),
    span_extractor: pretrained("_span_extractor"),
    span_finder: pretrained("_span_finder"),
    // The typing head is NOT loaded from the archive; it is built fresh.
    span_typing: {
      type: 'mlp',
      hidden_dims: base.model.span_typing.hidden_dims,
    },
    metrics: [{type: "srl"}],
    typing_loss_factor: base.model.typing_loss_factor,
    label_dim: base.model.label_dim,
    max_decoding_spans: 128,
    max_recursion_depth: 2,
    debug: debug,
  },

  trainer: {
    num_epochs: base.trainer.num_epochs,
    patience: base.trainer.patience,
    // Only emitted for a single device. The field (and the `[0]` lookup) is
    // dropped otherwise, because a null field name omits the field.
    [if num_devices == 1 then "cuda_device"]: cuda_devices[0],
    // "+" prefix: larger is better (AllenNLP convention).
    validation_metric: "+arg-c_f",
    num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps,
    optimizer: {
      type: "transformer",
      base: {
        type: "adam",
        lr: base.trainer.optimizer.base.lr,
      },
      // NOTE(review): presumably 0.0 keeps the transformer embedding weights
      // fixed even though `word_embedding` has freeze: false — confirm against
      // the project's "transformer" optimizer.
      embeddings_lr: 0.0,
      encoder_lr: 1e-5,
      pooler_lr: 1e-5,
      layer_fix: base.trainer.optimizer.layer_fix,
      // Regex-matched parameter groups for the reused sub-modules.
      parameter_groups: [
        [['_span_finder.*'], {'lr': rt_lr}],
        [['_span_extractor.*'], {'lr': rt_lr}],
      ],
    },
  },

  // Multi-device runs use distributed training instead of `cuda_device`.
  [if num_devices > 1 then "distributed"]: {
    "cuda_devices": cuda_devices,
  },
  // Test-set evaluation is only enabled for single-device runs.
  [if num_devices == 1 then "evaluate_on_test"]: true,
}