Spaces:
Sleeping
Sleeping
import random | |
import torch | |
import comet.src.data.config as cfg | |
import comet.src.train.atomic_train as base_train | |
import comet.src.train.batch as batch_utils | |
import comet.src.evaluate.conceptnet_evaluate as evaluate | |
import comet.src.evaluate.conceptnet_generate as gen | |
def make_trainer(opt, *args):
    """Build the trainer used for ConceptNet generation.

    Every argument is forwarded unchanged to
    ``ConceptNetGenerationIteratorTrainer``; the new instance is returned.
    """
    trainer = ConceptNetGenerationIteratorTrainer(opt, *args)
    return trainer
class ConceptNetGenerationIteratorTrainer(
        base_train.AtomicGenerationIteratorTrainer):
    """ConceptNet variant of the ATOMIC generation iterator trainer.

    Overrides the hooks that build the evaluator and generator so they come
    from the ConceptNet evaluate/generate modules, and tracks the best
    ``"total_micro"`` dev score to drive the "best" checkpoint strategy.

    NOTE(review): ``self.opt``, ``self.losses`` and ``self.top_score`` are
    read here but never created in this class -- presumably the base trainer
    sets them up (with ``top_score`` starting as ``None``); confirm there.
    """

    def set_evaluator(self, opt, model, data_loader):
        """Create the ConceptNet evaluator and store it as ``self.evaluator``."""
        self.evaluator = evaluate.make_evaluator(opt, model, data_loader)

    def set_generator(self, opt, model, data_loader):
        """Create the ConceptNet generator and store it as ``self.generator``."""
        self.generator = gen.make_generator(opt, model, data_loader)

    def batch(self, opt, *args):
        """Run one generation batch.

        Delegates to ``batch_utils.batch_atomic_generate`` and returns the
        ``"loss"``, ``"nums"`` and ``"reset"`` entries of its result, in that
        order, as a tuple.
        """
        outputs = batch_utils.batch_atomic_generate(opt, *args)
        return outputs["loss"], outputs["nums"], outputs["reset"]

    def update_top_score(self, opt):
        """Record the current epoch as best if its tracked score improved.

        The first call (``self.top_score is None``) always records the current
        epoch. Afterwards the record is replaced only when the new
        ``"total_micro"`` value is strictly lower than the stored one, i.e.
        lower is treated as better.

        ``opt`` is accepted for interface compatibility but not used; the
        epoch is read from ``self.opt``.
        """
        # Progress output: state before the update.
        print(self.top_score)

        current = self.get_tracked_score()["total_micro"]

        if self.top_score is None:
            self.top_score = {"epoch": {}, "score": {}}
            improved = True
        else:
            improved = current < self.top_score["score"]["total_micro"]

        if improved:
            epoch = self.opt.train.dynamic.epoch
            self.top_score["epoch"]["total_micro"] = epoch
            self.top_score["score"]["total_micro"] = current

        # Progress output: state after the update.
        print(self.top_score)

    def get_tracked_score(self):
        """Return the score used for model selection at the current epoch.

        The result is a dict with the single key ``"total_micro"``, read from
        ``self.losses["dev"]["total_micro"]`` at the current epoch.
        """
        epoch = self.opt.train.dynamic.epoch
        return {"total_micro": self.losses["dev"]["total_micro"][epoch]}

    def decide_to_save(self):
        """Decide whether the model should be saved at the current epoch.

        Saving is enabled when ``cfg.save`` is set and ``cfg.toy`` is not, or
        when ``cfg.test_save`` is set. Under the ``"best"`` save strategy the
        result is forced to ``False`` unless the current epoch is the one
        recorded in ``self.top_score``.

        NOTE(review): with the "best" strategy this subscripts
        ``self.top_score``, so ``update_top_score`` must have run at least
        once beforehand -- confirm against the caller.
        """
        to_save = cfg.save and not cfg.toy
        curr_epoch = self.opt.train.dynamic.epoch
        to_save = to_save or cfg.test_save

        # Progress output: which save strategy is active.
        print(cfg.save_strategy)

        is_best_strategy = cfg.save_strategy == "best"
        if is_best_strategy and self.top_score["epoch"]["total_micro"] != curr_epoch:
            to_save = False
        return to_save