# RxnIM/rxn/main.py
import os
import math
import json
import random
import argparse
import numpy as np
import torch
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.strategies.ddp import DDPStrategy
from transformers import get_scheduler
from reaction.model import Encoder, Decoder
from reaction.pix2seq import build_pix2seq_model
from reaction.loss import Criterion
from reaction.tokenizer import get_tokenizer
from reaction.dataset import ReactionDataset, get_collate_fn
from reaction.data import postprocess_reactions
from reaction.evaluate import CocoEvaluator, ReactionEvaluator
import reaction.utils as utils
def get_args(notebook=False):
parser = argparse.ArgumentParser()
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--do_valid', action='store_true')
parser.add_argument('--do_test', action='store_true')
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--gpus', type=int, default=1)
parser.add_argument('--print_freq', type=int, default=200)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--no_eval', action='store_true')
# Model
parser.add_argument('--encoder', type=str, default='resnet34')
parser.add_argument('--decoder', type=str, default='lstm')
    parser.add_argument('--trunc_encoder', action='store_true')  # use the hidden states before downsampling
parser.add_argument('--no_pretrained', action='store_true')
parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--lstm_dropout', type=float, default=0.5)
parser.add_argument('--embed_dim', type=int, default=256)
parser.add_argument('--enc_pos_emb', action='store_true')
group = parser.add_argument_group("lstm_options")
group.add_argument('--decoder_dim', type=int, default=512)
group.add_argument('--decoder_layer', type=int, default=1)
group.add_argument('--attention_dim', type=int, default=256)
group = parser.add_argument_group("transformer_options")
group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
group.add_argument("--dec_num_queries", type=int, default=128)
group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1)
group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0)
# Pix2Seq
    parser.add_argument('--pix2seq', action='store_true', help="use the Pix2Seq model (from the playground codebase) instead of the default encoder-decoder")
parser.add_argument('--pix2seq_ckpt', type=str, default=None)
parser.add_argument('--large_scale_jitter', action='store_true', help='large scale jitter')
    parser.add_argument('--pred_eos', action='store_true', help='predict an EOS token instead of always emitting 100 objects')
# * Backbone
parser.add_argument('--backbone', default='resnet50', type=str, help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")
# * Transformer
parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=1024, type=int,
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--pre_norm', action='store_true')
# Data
parser.add_argument('--data_path', type=str, default=None)
parser.add_argument('--image_path', type=str, default=None)
parser.add_argument('--train_file', type=str, default=None)
parser.add_argument('--valid_file', type=str, default=None)
parser.add_argument('--test_file', type=str, default=None)
parser.add_argument('--vocab_file', type=str, default=None)
parser.add_argument('--format', type=str, default='reaction')
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--augment', action='store_true')
parser.add_argument('--composite_augment', action='store_true')
parser.add_argument('--coord_bins', type=int, default=100)
parser.add_argument('--sep_xy', action='store_true')
parser.add_argument('--rand_order', action='store_true', help="randomly permute the sequence of input targets")
parser.add_argument('--add_noise', action='store_true')
parser.add_argument('--mix_noise', action='store_true')
parser.add_argument('--shuffle_bbox', action='store_true')
parser.add_argument('--images', type=str, default='')
# Training
parser.add_argument('--epochs', type=int, default=8)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=0.05)
parser.add_argument('--max_grad_norm', type=float, default=5.)
parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
parser.add_argument('--warmup_ratio', type=float, default=0)
parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
parser.add_argument('--load_path', type=str, default=None)
parser.add_argument('--load_encoder_only', action='store_true')
parser.add_argument('--train_steps_per_epoch', type=int, default=-1)
parser.add_argument('--eval_per_epoch', type=int, default=10)
parser.add_argument('--save_path', type=str, default='output/')
parser.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last'])
parser.add_argument('--load_ckpt', type=str, default='best')
parser.add_argument('--resume', action='store_true')
parser.add_argument('--num_train_example', type=int, default=None)
parser.add_argument('--label_smoothing', type=float, default=0.0)
parser.add_argument('--save_image', action='store_true')
# Inference
parser.add_argument('--beam_size', type=int, default=1)
parser.add_argument('--n_best', type=int, default=1)
parser.add_argument('--molscribe', action='store_true')
args = parser.parse_args([]) if notebook else parser.parse_args()
args.images = args.images.split(',')
return args
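# Example invocation (illustrative; paths and flag choices depend on your data layout):
#   python main.py --do_train --do_valid --pix2seq --format reaction \
#       --data_path data --image_path images --train_file train.json --valid_file valid.json
# In a notebook, bypass sys.argv and override fields directly:
#   args = get_args(notebook=True); args.do_test = True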
class ReactionExtractor(LightningModule):
def __init__(self, args, tokenizer):
super().__init__()
self.args = args
self.tokenizer = tokenizer
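        # The encoder reports its feature dimensionality back through args so the
        # decoder can be sized to match.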
self.encoder = Encoder(args, pretrained=(not args.no_pretrained))
args.encoder_dim = self.encoder.n_features
self.decoder = Decoder(args, tokenizer)
self.criterion = Criterion(args, tokenizer)
def training_step(self, batch, batch_idx):
indices, images, refs = batch
features, hiddens = self.encoder(images, refs)
results = self.decoder(features, hiddens, refs)
losses = self.criterion(results, refs)
loss = sum(losses.values())
self.log('train/loss', loss)
        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False)
return loss
def validation_step(self, batch, batch_idx):
indices, images, refs = batch
features, hiddens = self.encoder(images, refs)
batch_preds, batch_beam_preds = self.decoder.decode(
features, hiddens, refs,
beam_size=self.args.beam_size, n_best=self.args.n_best)
return indices, batch_preds
def validation_epoch_end(self, outputs, phase='val'):
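        # Under DDP each rank sees only its shard of the validation set, so gather
        # every rank's outputs before scoring the full split on rank zero.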
if self.trainer.num_devices > 1:
            gathered_outputs = [None for _ in range(self.trainer.num_devices)]
dist.all_gather_object(gathered_outputs, outputs)
gathered_outputs = sum(gathered_outputs, [])
else:
gathered_outputs = outputs
format = self.args.format
predictions = utils.merge_predictions(gathered_outputs)
name = self.eval_dataset.name
scores = [0]
if self.trainer.is_global_zero:
if not self.args.no_eval:
if format == 'bbox':
                    coco_evaluator = CocoEvaluator(self.eval_dataset.coco)
                    stats = coco_evaluator.evaluate(predictions['bbox'])
                    results = list(stats)
                    # A single-element list keeps the broadcast length consistent with the placeholder on other ranks.
                    scores = [results[0]]
elif format == 'reaction':
epoch = self.trainer.current_epoch
evaluator = ReactionEvaluator()
results, *_ = evaluator.evaluate_summarize(self.eval_dataset.data, predictions['reaction'])
precision, recall, f1 = \
results['overall']['precision'], results['overall']['recall'], results['overall']['f1']
scores = [f1]
self.print(f'Epoch: {epoch:>3} Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}')
results['mol_only'], *_ = evaluator.evaluate_summarize(
self.eval_dataset.data, predictions['reaction'], mol_only=True, merge_condition=True)
else:
raise NotImplementedError
with open(os.path.join(self.trainer.default_root_dir, f'eval_{name}.json'), 'w') as f:
json.dump(results, f)
if phase == 'test':
self.print(json.dumps(results, indent=4))
with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f:
json.dump(predictions, f)
        if self.trainer.num_devices > 1:
            dist.broadcast_object_list(scores)
        self.log(f'{phase}/score', scores[0], prog_bar=True, rank_zero_only=True)
def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, phase='test')
def predict_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
def configure_optimizers(self):
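        # trainer.num_training_steps is attached manually in main(); warmup is a fixed
        # fraction of it, and the scheduler advances once per optimizer step ('interval': 'step').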
num_training_steps = self.trainer.num_training_steps
self.print(f'Num training steps: {num_training_steps}')
num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
# parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
class ReactionExtractorPix2Seq(ReactionExtractor):
def __init__(self, args, tokenizer):
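        # Call LightningModule.__init__ directly, deliberately skipping ReactionExtractor.__init__
        # (which builds the encoder/decoder this subclass replaces with a Pix2Seq model).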
super(ReactionExtractor, self).__init__()
self.args = args
self.tokenizer = tokenizer
self.format = args.format
self.model = build_pix2seq_model(args, tokenizer[self.format])
self.criterion = Criterion(args, tokenizer)
self.molscribe = None
def training_step(self, batch, batch_idx):
indices, images, refs = batch
format = self.format
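        # Pair the model's output with the reference sequence minus its leading (start)
        # token, in the (prediction, target) layout the criterion consumes.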
results = {format: (self.model(images, refs[format]), refs[format+'_out'][0][:, 1:])}
losses = self.criterion(results, refs)
loss = sum(losses.values())
self.log('train/loss', loss)
        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False)
return loss
def validation_step(self, batch, batch_idx):
indices, images, refs = batch
format = self.format
batch_preds = {format: [], 'file_name': []}
pred_seqs, pred_scores = self.model(images, max_len=self.tokenizer[format].max_len)
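        # Convert each predicted token sequence back into structured objects, rescaling
        # coordinates to the original image size via refs['scale'].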
for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)):
if format == 'reaction':
reactions = self.tokenizer[format].sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs['scale'][i])
reactions = postprocess_reactions(reactions)
batch_preds[format].append(reactions)
            elif format == 'bbox':
bboxes = self.tokenizer[format].sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs['scale'][i])
batch_preds[format].append(bboxes)
batch_preds['file_name'].append(refs['file_name'][i])
return indices, batch_preds
class ReactionDataModule(LightningDataModule):
def __init__(self, args, tokenizer):
super().__init__()
self.args = args
self.tokenizer = tokenizer
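        # Pads variable-length targets within a batch using the tokenizer's PAD id.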
self.collate_fn = get_collate_fn(self.pad_id)
@property
def pad_id(self):
return self.tokenizer[self.args.format].PAD_ID
def prepare_data(self):
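        # Datasets are built here because main() calls dm.prepare_data() directly; under
        # subprocess DDP every rank re-runs main(), so each rank constructs its own copies.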
args = self.args
if args.do_train:
self.train_dataset = ReactionDataset(args, self.tokenizer, args.train_file, split='train')
if self.args.do_train or self.args.do_valid:
self.val_dataset = ReactionDataset(args, self.tokenizer, args.valid_file, split='valid')
if self.args.do_test:
self.test_dataset = ReactionDataset(args, self.tokenizer, args.test_file, split='test')
def print_stats(self):
if self.args.do_train:
print(f'Train dataset: {len(self.train_dataset)}')
if self.args.do_train or self.args.do_valid:
print(f'Valid dataset: {len(self.val_dataset)}')
if self.args.do_test:
print(f'Test dataset: {len(self.test_dataset)}')
def train_dataloader(self):
return torch.utils.data.DataLoader(
self.train_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
collate_fn=self.collate_fn)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self.val_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
collate_fn=self.collate_fn)
def test_dataloader(self):
return torch.utils.data.DataLoader(
self.test_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
collate_fn=self.collate_fn)
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
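        # Skip the base class's name de-duplication (which appends -v1, -v2, ...) so
        # 'best.ckpt' is overwritten in place whenever a new best score arrives.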
filepath = self.format_checkpoint_name(monitor_candidates)
return filepath
def main():
args = get_args()
pl.seed_everything(args.seed, workers=True)
if args.debug:
args.save_path = "output/debug"
tokenizer = get_tokenizer(args)
MODEL = ReactionExtractorPix2Seq if args.pix2seq else ReactionExtractor
if args.do_train:
model = MODEL(args, tokenizer)
else:
model = MODEL.load_from_checkpoint(os.path.join(args.save_path, 'checkpoints/best.ckpt'), strict=False,
args=args, tokenizer=tokenizer)
dm = ReactionDataModule(args, tokenizer)
dm.prepare_data()
dm.print_stats()
checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True)
# checkpoint = ModelCheckpoint(monitor=None, save_top_k=0, save_last=True)
lr_monitor = LearningRateMonitor(logging_interval='step')
logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')
trainer = pl.Trainer(
strategy=DDPStrategy(find_unused_parameters=False),
accelerator='gpu',
        devices=args.gpus,
logger=logger,
default_root_dir=args.save_path,
callbacks=[checkpoint, lr_monitor],
max_epochs=args.epochs,
gradient_clip_val=args.max_grad_norm,
accumulate_grad_batches=args.gradient_accumulation_steps,
check_val_every_n_epoch=args.eval_per_epoch,
log_every_n_steps=10,
deterministic=True)
if args.do_train:
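        # Effective steps per epoch = dataset size / (batch size * devices * grad accumulation);
        # configure_optimizers reads trainer.num_training_steps to size the LR schedule.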
trainer.num_training_steps = math.ceil(
len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs
model.eval_dataset = dm.val_dataset
ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None
trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
model = MODEL.load_from_checkpoint(checkpoint.best_model_path, args=args, tokenizer=tokenizer)
if args.do_valid:
model.eval_dataset = dm.val_dataset
trainer.validate(model, datamodule=dm)
if args.do_test:
model.eval_dataset = dm.test_dataset
trainer.test(model, datamodule=dm)
if __name__ == "__main__":
main()