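# ConceptNet generation trainer: specializes the ATOMIC generation
# trainer with ConceptNet-specific evaluation, generation, and
# best-checkpoint tracking.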
import random

import torch

import comet.src.data.config as cfg
import comet.src.train.atomic_train as base_train
import comet.src.train.batch as batch_utils
import comet.src.evaluate.conceptnet_evaluate as evaluate
import comet.src.evaluate.conceptnet_generate as gen

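# Factory that constructs the ConceptNet generation trainer.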
def make_trainer(opt, *args):
    return ConceptNetGenerationIteratorTrainer(opt, *args)


class ConceptNetGenerationIteratorTrainer(
        base_train.AtomicGenerationIteratorTrainer):
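    # Swap in the ConceptNet-specific evaluator.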
    def set_evaluator(self, opt, model, data_loader):
        self.evaluator = evaluate.make_evaluator(
            opt, model, data_loader)

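    # Swap in the ConceptNet-specific generator.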
    def set_generator(self, opt, model, data_loader):
        self.generator = gen.make_generator(
            opt, model, data_loader)

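    # Batching reuses the ATOMIC generation routine; it returns the
    # token-level loss, the "nums" bookkeeping counts, and a reset flag.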
    def batch(self, opt, *args):
        outputs = batch_utils.batch_atomic_generate(opt, *args)

        token_loss = outputs["loss"]
        nums = outputs["nums"]
        reset = outputs["reset"]

        return token_loss, nums, reset

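    # Track the best (lowest) dev "total_micro" score seen so far,
    # together with the epoch in which it was achieved.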
    def update_top_score(self, opt):
        print(self.top_score)

        tracked_scores = self.get_tracked_score()

        if self.top_score is None:
            self.top_score = {"epoch": {}, "score": {}}
            self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch
            self.top_score["score"]["total_micro"] = tracked_scores["total_micro"]
        elif tracked_scores["total_micro"] < self.top_score["score"]["total_micro"]:
            self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch
            self.top_score["score"]["total_micro"] = tracked_scores["total_micro"]

        print(self.top_score)

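    # The tracked score is the dev-set "total_micro" loss for the
    # current epoch.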
    def get_tracked_score(self):
        return {
            "total_micro": self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch]
        }

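    # Save checkpoints per the config; under the "best" save strategy,
    # only save when the current epoch produced a new best score.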
    def decide_to_save(self):
        to_save = cfg.save and not cfg.toy
        curr_epoch = self.opt.train.dynamic.epoch

        to_save = to_save or cfg.test_save

        print(cfg.save_strategy)

        if cfg.save_strategy == "best":
            if self.top_score["epoch"]["total_micro"] != curr_epoch:
                to_save = False

        return to_save