# Re-training configuration for the "span" model.
#
# Data paths, loaders and most hyper-parameters are taken from basic.jsonnet.
# The word embedding, span extractor and span finder are each configured with
# a "_pretrained" entry pointing at the archive in `pretrained_path`, while
# span_typing is configured as an "mlp" using the base hidden dimensions.

local env = import "../env.jsonnet";
local base = import "basic.jsonnet";

local debug = false;

# re-train
# Archive referenced by every "_pretrained" entry below. Presumably
# overridable through the PRETRAINED_PATH environment variable -- this depends
# on env.str in ../env.jsonnet; confirm there.
local pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best");
# Learning rate used for the span finder / span extractor parameter groups.
# Presumably overridable through RT_LR (see env.json in ../env.jsonnet).
local rt_lr = env.json("RT_LR", 5e-5);

# module
local cuda_devices = base.cuda_devices;
local num_devices = std.length(cuda_devices);

{
  dataset_reader: base.dataset_reader,
  train_data_path: base.train_data_path,
  validation_data_path: base.validation_data_path,
  test_data_path: base.test_data_path,

  datasets_for_vocab_creation: ["train"],

  data_loader: base.data_loader,
  validation_data_loader: base.validation_data_loader,

  model: {
    type: "span",
    # The three entries below share the same archive and differ only in
    # the module path they name inside it; none of them is frozen.
    word_embedding: {
      "_pretrained": {
        "archive_file": pretrained_path,
        "module_path": "word_embedding",
        "freeze": false,
      },
    },
    span_extractor: {
      "_pretrained": {
        "archive_file": pretrained_path,
        "module_path": "_span_extractor",
        "freeze": false,
      },
    },
    span_finder: {
      "_pretrained": {
        "archive_file": pretrained_path,
        "module_path": "_span_finder",
        "freeze": false,
      },
    },
    # Unlike the modules above, span_typing has no "_pretrained" entry.
    span_typing: {
      type: "mlp",
      hidden_dims: base.model.span_typing.hidden_dims,
    },
    metrics: [{ type: "srl" }],

    typing_loss_factor: base.model.typing_loss_factor,
    label_dim: base.model.label_dim,
    max_decoding_spans: 128,
    max_recursion_depth: 2,
    debug: debug,
  },

  trainer: {
    num_epochs: base.trainer.num_epochs,
    patience: base.trainer.patience,
    # A computed field name that evaluates to null (an `if` without `else`
    # whose condition is false) omits the field, so "cuda_device" is only
    # emitted when exactly one device is configured.
    [if num_devices == 1 then "cuda_device"]: cuda_devices[0],
    validation_metric: "+arg-c_f",
    num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps,
    optimizer: {
      type: "transformer",
      base: {
        type: "adam",
        lr: base.trainer.optimizer.base.lr,
      },
      embeddings_lr: 0.0,
      encoder_lr: 1e-5,
      pooler_lr: 1e-5,
      layer_fix: base.trainer.optimizer.layer_fix,
      # Parameters matching these patterns get the re-train learning rate.
      parameter_groups: [
        [["_span_finder.*"], { lr: rt_lr }],
        [["_span_extractor.*"], { lr: rt_lr }],
      ],
    },
  },

  # Emitted only when more than one device is configured.
  [if num_devices > 1 then "distributed"]: {
    "cuda_devices": cuda_devices,
  },
  # Emitted only when exactly one device is configured.
  [if num_devices == 1 then "evaluate_on_test"]: true,
}