|
import os |
|
import math |
|
import json |
|
import random |
|
import argparse |
|
import numpy as np |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import pytorch_lightning as pl |
|
from pytorch_lightning import LightningModule, LightningDataModule |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from pytorch_lightning.strategies.ddp import DDPStrategy |
|
from transformers import get_scheduler |
|
|
|
from reaction.model import Encoder, Decoder |
|
from reaction.pix2seq import build_pix2seq_model |
|
from reaction.loss import Criterion |
|
from reaction.tokenizer import get_tokenizer |
|
from reaction.dataset import ReactionDataset, get_collate_fn |
|
from reaction.data import postprocess_reactions |
|
from reaction.evaluate import CocoEvaluator, ReactionEvaluator |
|
import reaction.utils as utils |
|
|
|
|
|
def get_args(notebook=False):
    """Build the command-line parser and return the parsed namespace.

    Args:
        notebook: When true, parse an empty argument list (so every option
            takes its default) instead of reading ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``images`` converted from a
        comma-separated string into a list of strings.
    """
    parser = argparse.ArgumentParser()

    def add_flag(name, to=parser, **extra):
        # Boolean switch: False unless given on the command line.
        to.add_argument(name, action='store_true', **extra)

    def add_option(name, kind, default, to=parser, **extra):
        # Typed option with an explicit default.
        to.add_argument(name, type=kind, default=default, **extra)

    # Run mode / general
    add_flag('--do_train')
    add_flag('--do_valid')
    add_flag('--do_test')
    add_flag('--fp16')
    add_option('--seed', int, 42)
    add_option('--gpus', int, 1)
    add_option('--print_freq', int, 200)
    add_flag('--debug')
    add_flag('--no_eval')

    # Encoder / decoder model
    add_option('--encoder', str, 'resnet34')
    add_option('--decoder', str, 'lstm')
    add_flag('--trunc_encoder')
    add_flag('--no_pretrained')
    add_flag('--use_checkpoint')
    add_option('--lstm_dropout', float, 0.5)
    add_option('--embed_dim', int, 256)
    add_flag('--enc_pos_emb')

    lstm_group = parser.add_argument_group("lstm_options")
    add_option('--decoder_dim', int, 512, to=lstm_group)
    add_option('--decoder_layer', int, 1, to=lstm_group)
    add_option('--attention_dim', int, 256, to=lstm_group)

    transformer_group = parser.add_argument_group("transformer_options")
    add_option("--dec_num_layers", int, 6, to=transformer_group,
               help="No. of layers in transformer decoder")
    add_option("--dec_hidden_size", int, 256, to=transformer_group,
               help="Decoder hidden size")
    add_option("--dec_attn_heads", int, 8, to=transformer_group,
               help="Decoder no. of attention heads")
    add_option("--dec_num_queries", int, 128, to=transformer_group)
    add_option("--hidden_dropout", float, 0.1, to=transformer_group,
               help="Hidden dropout")
    add_option("--attn_dropout", float, 0.1, to=transformer_group,
               help="Attention dropout")
    add_option("--max_relative_positions", int, 0, to=transformer_group,
               help="Max relative positions")

    # Pix2Seq
    add_flag('--pix2seq', help="specify the model from playground")
    add_option('--pix2seq_ckpt', str, None)
    add_flag('--large_scale_jitter', help='large scale jitter')
    add_flag('--pred_eos', help='use eos token instead of predicting 100 objects')

    # Backbone
    add_option('--backbone', str, 'resnet50',
               help="Name of the convolutional backbone to use")
    add_flag('--dilation',
             help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    add_option('--position_embedding', str, 'sine', choices=('sine', 'learned'),
               help="Type of positional embedding to use on top of the image features")

    # Transformer
    add_option('--enc_layers', int, 6,
               help="Number of encoding layers in the transformer")
    add_option('--dec_layers', int, 6,
               help="Number of decoding layers in the transformer")
    add_option('--dim_feedforward', int, 1024,
               help="Intermediate size of the feedforward layers in the transformer blocks")
    add_option('--hidden_dim', int, 256,
               help="Size of the embeddings (dimension of the transformer)")
    add_option('--dropout', float, 0.1,
               help="Dropout applied in the transformer")
    add_option('--nheads', int, 8,
               help="Number of attention heads inside the transformer's attentions")
    add_flag('--pre_norm')

    # Data
    add_option('--data_path', str, None)
    add_option('--image_path', str, None)
    add_option('--train_file', str, None)
    add_option('--valid_file', str, None)
    add_option('--test_file', str, None)
    add_option('--vocab_file', str, None)
    add_option('--format', str, 'reaction')
    add_option('--num_workers', int, 8)
    add_option('--input_size', int, 224)
    add_flag('--augment')
    add_flag('--composite_augment')
    add_option('--coord_bins', int, 100)
    add_flag('--sep_xy')
    add_flag('--rand_order', help="randomly permute the sequence of input targets")
    add_flag('--add_noise')
    add_flag('--mix_noise')
    add_flag('--shuffle_bbox')
    add_option('--images', str, '')

    # Training
    add_option('--epochs', int, 8)
    add_option('--batch_size', int, 256)
    add_option('--lr', float, 1e-4)
    add_option('--weight_decay', float, 0.05)
    add_option('--max_grad_norm', float, 5.)
    add_option('--scheduler', str, 'cosine', choices=['cosine', 'constant'])
    add_option('--warmup_ratio', float, 0)
    add_option('--gradient_accumulation_steps', int, 1)
    add_option('--load_path', str, None)
    add_flag('--load_encoder_only')
    add_option('--train_steps_per_epoch', int, -1)
    add_option('--eval_per_epoch', int, 10)
    add_option('--save_path', str, 'output/')
    add_option('--save_mode', str, 'best', choices=['best', 'all', 'last'])
    add_option('--load_ckpt', str, 'best')
    add_flag('--resume')
    add_option('--num_train_example', int, None)
    add_option('--label_smoothing', float, 0.0)
    add_flag('--save_image')

    # Inference
    add_option('--beam_size', int, 1)
    add_option('--n_best', int, 1)
    add_flag('--molscribe')

    # ``None`` makes argparse fall back to sys.argv, exactly like parse_args().
    argv = [] if notebook else None
    args = parser.parse_args(argv)

    # Comma-separated string -> list (the default '' becomes ['']).
    args.images = args.images.split(',')

    return args
|
|
|
|
|
class ReactionExtractor(LightningModule):
    """Lightning module wrapping an ``Encoder``/``Decoder`` pair and its loss.

    ``eval_dataset`` is not set in ``__init__``; it must be assigned on the
    instance before validation or testing, because the epoch-end hook reads
    ``self.eval_dataset``.
    """

    def __init__(self, args, tokenizer):
        """Build encoder, decoder and criterion from ``args`` and ``tokenizer``."""
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.encoder = Encoder(args, pretrained=(not args.no_pretrained))
        # Published on ``args`` before the decoder is constructed
        # (presumably Decoder reads it -- confirm in reaction.model).
        args.encoder_dim = self.encoder.n_features
        self.decoder = Decoder(args, tokenizer)
        self.criterion = Criterion(args, tokenizer)

    def training_step(self, batch, batch_idx):
        """Run one forward pass and return the summed loss."""
        indices, images, refs = batch
        features, hiddens = self.encoder(images, refs)
        results = self.decoder(features, hiddens, refs)
        losses = self.criterion(results, refs)
        loss = sum(losses.values())
        self.log('train/loss', loss)
        # get_last_lr() is the supported accessor outside scheduler.step();
        # get_lr() warns when called from here.
        current_lr = self.lr_schedulers().get_last_lr()[0]
        self.log('lr', current_lr, prog_bar=True, logger=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Decode one batch and return ``(indices, batch_preds)``."""
        indices, images, refs = batch
        features, hiddens = self.encoder(images, refs)
        batch_preds, _ = self.decoder.decode(
            features, hiddens, refs,
            beam_size=self.args.beam_size, n_best=self.args.n_best)
        return indices, batch_preds

    def _gather_outputs(self, outputs, distributed):
        """Return the step outputs of all ranks as one flat list.

        When no process group is initialized, or the world size is 1, the
        local ``outputs`` are returned unchanged.
        """
        if not distributed:
            return outputs
        world_size = dist.get_world_size()
        if world_size == 1:
            return outputs
        # all_gather_object needs one slot per participating process.
        per_rank = [None] * world_size
        dist.all_gather_object(per_rank, outputs)
        return [item for rank_outputs in per_rank for item in rank_outputs]

    def validation_epoch_end(self, outputs, phase='val'):
        """Merge predictions, evaluate on global rank zero and log the score.

        Args:
            outputs: The list of values returned by ``validation_step`` (or
                ``test_step``) on this process.
            phase: Prefix of the logged metric (``'val'`` or ``'test'``); in
                the ``'test'`` phase the evaluation results are also printed.

        Raises:
            NotImplementedError: If evaluation is enabled and ``args.format``
                is neither ``'bbox'`` nor ``'reaction'``.
        """
        distributed = dist.is_available() and dist.is_initialized()
        gathered_outputs = self._gather_outputs(outputs, distributed)

        data_format = self.args.format
        predictions = utils.merge_predictions(gathered_outputs)

        name = self.eval_dataset.name
        scores = [0]

        if self.trainer.is_global_zero:
            root_dir = self.trainer.default_root_dir
            if not self.args.no_eval:
                if data_format == 'bbox':
                    coco_evaluator = CocoEvaluator(self.eval_dataset.coco)
                    stats = coco_evaluator.evaluate(predictions['bbox'])
                    scores = results = list(stats)
                elif data_format == 'reaction':
                    epoch = self.trainer.current_epoch
                    evaluator = ReactionEvaluator()
                    results, *_ = evaluator.evaluate_summarize(
                        self.eval_dataset.data, predictions['reaction'])
                    overall = results['overall']
                    precision, recall, f1 = overall['precision'], overall['recall'], overall['f1']
                    scores = [f1]
                    self.print(f'Epoch: {epoch:>3} Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}')
                    results['mol_only'], *_ = evaluator.evaluate_summarize(
                        self.eval_dataset.data, predictions['reaction'], mol_only=True, merge_condition=True)
                else:
                    raise NotImplementedError
                with open(os.path.join(root_dir, f'eval_{name}.json'), 'w') as f:
                    json.dump(results, f)
                if phase == 'test':
                    self.print(json.dumps(results, indent=4))
            with open(os.path.join(root_dir, f'prediction_{name}.json'), 'w') as f:
                json.dump(predictions, f)

        # Share rank zero's score with every rank; skipped when there is no
        # process group, where the call would raise.
        if distributed:
            dist.broadcast_object_list(scores)
        self.log(f'{phase}/score', scores[0], prog_bar=True, rank_zero_only=True)

    def test_step(self, batch, batch_idx):
        """Same as ``validation_step``."""
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """Same as ``validation_epoch_end`` with ``phase='test'``."""
        return self.validation_epoch_end(outputs, phase='test')

    def predict_step(self, batch, batch_idx):
        """Same as ``validation_step``."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Create AdamW and a per-step LR scheduler.

        Reads ``self.trainer.num_training_steps``, which must have been set
        on the trainer before fitting.
        """
        num_training_steps = self.trainer.num_training_steps
        self.print(f'Num training steps: {num_training_steps}')
        num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)

        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
|
|
|
|
|
class ReactionExtractorPix2Seq(ReactionExtractor):
    """Variant of ``ReactionExtractor`` backed by a single Pix2Seq model.

    Only the constructor and the training/validation steps are overridden;
    the epoch-end evaluation and optimizer setup are inherited.
    """

    def __init__(self, args, tokenizer):
        """Build the Pix2Seq model for ``args.format`` and the criterion."""
        # Deliberately skips ReactionExtractor.__init__ (which builds the
        # Encoder/Decoder pair) and runs the next initializer in the MRO.
        super(ReactionExtractor, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.format = args.format
        self.model = build_pix2seq_model(args, tokenizer[self.format])
        self.criterion = Criterion(args, tokenizer)
        self.molscribe = None

    def training_step(self, batch, batch_idx):
        """Run one forward pass and return the summed loss."""
        indices, images, refs = batch
        fmt = self.format
        logits = self.model(images, refs[fmt])
        # Target is the first element of refs[<format>_out] with column 0 dropped.
        targets = refs[fmt + '_out'][0][:, 1:]
        results = {fmt: (logits, targets)}
        losses = self.criterion(results, refs)
        loss = sum(losses.values())
        self.log('train/loss', loss)
        # get_last_lr() is the supported accessor outside scheduler.step();
        # get_lr() warns when called from here.
        current_lr = self.lr_schedulers().get_last_lr()[0]
        self.log('lr', current_lr, prog_bar=True, logger=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Generate sequences for a batch and convert them to predictions.

        Returns:
            ``(indices, batch_preds)`` where ``batch_preds`` maps the active
            format to a list of per-image predictions and ``'file_name'`` to
            the matching file names. For a format other than ``'reaction'``
            or ``'bbox'`` the prediction list stays empty.
        """
        indices, images, refs = batch
        fmt = self.format
        tokenizer = self.tokenizer[fmt]
        batch_preds = {fmt: [], 'file_name': []}
        pred_seqs, pred_scores = self.model(images, max_len=tokenizer.max_len)
        for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)):
            if fmt in ('reaction', 'bbox'):
                data = tokenizer.sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs['scale'][i])
                if fmt == 'reaction':
                    data = postprocess_reactions(data)
                batch_preds[fmt].append(data)
            batch_preds['file_name'].append(refs['file_name'][i])
        return indices, batch_preds
|
|
|
|
|
class ReactionDataModule(LightningDataModule):
    """Data module that owns the train / valid / test ``ReactionDataset``s.

    Which datasets exist depends on the ``do_train`` / ``do_valid`` /
    ``do_test`` switches in ``args``; they are created by ``prepare_data``.
    """

    def __init__(self, args, tokenizer):
        """Store the configuration and build the padding-aware collate function."""
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.collate_fn = get_collate_fn(self.pad_id)

    @property
    def pad_id(self):
        """``PAD_ID`` of the tokenizer selected by ``args.format``."""
        active_tokenizer = self.tokenizer[self.args.format]
        return active_tokenizer.PAD_ID

    def prepare_data(self):
        """Instantiate the datasets required by the enabled run modes."""
        cfg = self.args
        tokenizer = self.tokenizer
        if cfg.do_train:
            self.train_dataset = ReactionDataset(cfg, tokenizer, cfg.train_file, split='train')
        # Training also validates, so the validation set is needed in both modes.
        if cfg.do_train or cfg.do_valid:
            self.val_dataset = ReactionDataset(cfg, tokenizer, cfg.valid_file, split='valid')
        if cfg.do_test:
            self.test_dataset = ReactionDataset(cfg, tokenizer, cfg.test_file, split='test')

    def print_stats(self):
        """Print the size of every dataset that ``prepare_data`` created."""
        cfg = self.args
        if cfg.do_train:
            n_train = len(self.train_dataset)
            print(f'Train dataset: {n_train}')
        if cfg.do_train or cfg.do_valid:
            n_valid = len(self.val_dataset)
            print(f'Valid dataset: {n_valid}')
        if cfg.do_test:
            n_test = len(self.test_dataset)
            print(f'Test dataset: {n_test}')

    def _build_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the shared batch settings."""
        cfg = self.args
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            collate_fn=self.collate_fn)

    def train_dataloader(self):
        """DataLoader over the training set."""
        # NOTE(review): no ``shuffle`` is passed; presumably ordering is left
        # to the sampler the trainer installs -- confirm.
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        """DataLoader over the validation set."""
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        """DataLoader over the test set."""
        return self._build_loader(self.test_dataset)
|
|
|
|
|
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Checkpoint callback whose file path is exactly the formatted name."""

    def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
        """Return ``format_checkpoint_name(monitor_candidates)`` unchanged.

        ``trainer`` and ``del_filepath`` are accepted for signature
        compatibility but not used.
        """
        # NOTE(review): presumably overrides the parent so an existing file is
        # reused instead of getting a version suffix -- confirm against the
        # installed Lightning version.
        return self.format_checkpoint_name(monitor_candidates)
|
|
|
|
|
def main():
    """Entry point: parse arguments, then train / validate / test as requested.

    When training, a fresh model is built and fitted (optionally resuming
    from ``checkpoints/last.ckpt``), after which the best checkpoint is
    reloaded. Otherwise the model is loaded from ``checkpoints/best.ckpt``
    under ``args.save_path``.
    """
    args = get_args()
    pl.seed_everything(args.seed, workers=True)

    if args.debug:
        args.save_path = "output/debug"

    tokenizer = get_tokenizer(args)

    MODEL = ReactionExtractorPix2Seq if args.pix2seq else ReactionExtractor
    if args.do_train:
        model = MODEL(args, tokenizer)
    else:
        model = MODEL.load_from_checkpoint(
            os.path.join(args.save_path, 'checkpoints/best.ckpt'), strict=False,
            args=args, tokenizer=tokenizer)

    dm = ReactionDataModule(args, tokenizer)
    dm.prepare_data()
    dm.print_stats()

    checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')

    trainer = pl.Trainer(
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator='gpu',
        # Use the --gpus option rather than a fixed device count, so the
        # device count matches the value used for num_training_steps below.
        devices=args.gpus,
        logger=logger,
        default_root_dir=args.save_path,
        callbacks=[checkpoint, lr_monitor],
        max_epochs=args.epochs,
        gradient_clip_val=args.max_grad_norm,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        check_val_every_n_epoch=args.eval_per_epoch,
        log_every_n_steps=10,
        deterministic=True)

    if args.do_train:
        # Optimizer steps per epoch (examples / effective global batch size)
        # times epochs; read back in configure_optimizers.
        effective_batch = args.batch_size * args.gpus * args.gradient_accumulation_steps
        trainer.num_training_steps = math.ceil(len(dm.train_dataset) / effective_batch) * args.epochs
        model.eval_dataset = dm.val_dataset
        ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None
        trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
        # best_model_path is empty when no checkpoint was ever scored (e.g.
        # validation never ran); keep the in-memory model in that case.
        best_path = checkpoint.best_model_path
        if best_path:
            model = MODEL.load_from_checkpoint(best_path, args=args, tokenizer=tokenizer)

    if args.do_valid:
        model.eval_dataset = dm.val_dataset
        trainer.validate(model, datamodule=dm)

    if args.do_test:
        model.eval_dataset = dm.test_dataset
        trainer.test(model, datamodule=dm)
|
|
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|