# RxnIM/molscribe/utils.py
import os
import random
import numpy as np
import torch
import math
import time
import datetime
import json
# Output formats: each entry maps a format name to its tokenizer file and the
# maximum decoded sequence length.
FORMAT_INFO = {
"inchi": {
"name": "InChI_text",
"tokenizer": "tokenizer_inchi.json",
"max_len": 300
},
"atomtok": {
"name": "SMILES_atomtok",
"tokenizer": "tokenizer_smiles_atomtok.json",
"max_len": 256
},
"nodes": {"max_len": 384},
"atomtok_coords": {"max_len": 480},
"chartok_coords": {"max_len": 480}
}
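# Example lookup (illustrative only): the maximum decoding length used when
# generating atom-level SMILES.
#
#   max_len = FORMAT_INFO['atomtok']['max_len']   # 256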
def init_logger(log_file='train.log'):
    """Create a logger that writes messages to both stdout and `log_file`."""
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
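# Minimal usage sketch (the log path is hypothetical):
#
#   logger = init_logger('output/train.log')
#   logger.info('training started')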
def init_summary_writer(save_path):
from tensorboardX import SummaryWriter
summary = SummaryWriter(save_path)
return summary
def save_args(args):
    """Write all parsed arguments to a timestamped log file under args.save_path."""
    dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M")
    path = os.path.join(args.save_path, f'train_{dt}.log')
    with open(path, 'w') as f:
        for k, v in vars(args).items():
            f.write(f"**** {k} = *{v}*\n")
def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs for reproducible runs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class EpochMeter(AverageMeter):
    """AverageMeter that additionally accumulates a separate epoch-level average."""
    def __init__(self):
        super().__init__()
        self.epoch = AverageMeter()
def update(self, val, n=1):
super().update(val, n)
self.epoch.update(val, n)
class LossMeter(EpochMeter):
    """Tracks the total loss plus named sub-losses, each with batch and epoch averages."""
    def __init__(self):
        self.subs = {}
        super().__init__()
def reset(self):
super().reset()
for k in self.subs:
self.subs[k].reset()
def update(self, loss, losses, n=1):
loss = loss.item()
super().update(loss, n)
losses = {k: v.item() for k, v in losses.items()}
for k, v in losses.items():
if k not in self.subs:
self.subs[k] = EpochMeter()
self.subs[k].update(v, n)
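# Minimal usage sketch (the names `loss_dict` and `batch_size` are made up for
# illustration): the total loss and each named sub-loss get batch- and
# epoch-level running averages.
#
#   meter = LossMeter()
#   loss_dict = {'coords': torch.tensor(0.3), 'edges': torch.tensor(0.1)}
#   meter.update(sum(loss_dict.values()), loss_dict, n=batch_size)
#   print(meter.avg, {k: m.epoch.avg for k, m in meter.subs.items()})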
def asMinutes(s):
    """Format a duration in seconds as 'Xm Ys'."""
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
def timeSince(since, percent):
    """Return elapsed time since `since` and the estimated remaining time,
    given `percent` as the fraction of work completed (0 < percent <= 1)."""
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))
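# Example (hypothetical numbers): after 250 of 1000 steps, report elapsed time
# and the projected time remaining.
#
#   start = time.time()
#   # ... run 250 of 1000 training steps ...
#   print(timeSince(start, 250 / 1000))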
def print_rank_0(message):
    """Print only from rank 0 when torch.distributed is initialized; otherwise print normally."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
def to_device(data, device):
    """Recursively move tensors (possibly nested in lists/dicts) to the given device."""
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, list):
        return [to_device(v, device) for v in data]
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    # Leave non-tensor leaves (ints, strings, None, ...) unchanged instead of dropping them.
    return data
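# Sketch of moving a nested batch to the GPU (this batch layout is an
# assumption, not the actual dataloader output of this repo):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   batch = {'images': torch.zeros(2, 3, 384, 384),
#            'refs': [torch.zeros(2, 16), torch.zeros(2, 16)]}
#   batch = to_device(batch, device)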
def round_floats(o):
    """Recursively round every float in a nested dict/list/tuple to 3 decimals.
    Tuples are returned as lists."""
    if isinstance(o, float):
        return round(o, 3)
    if isinstance(o, dict):
        return {k: round_floats(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [round_floats(x) for x in o]
    return o
def format_df(df):
    """JSON-serialize the structured prediction columns of a DataFrame,
    rounding floats to three decimals and stripping whitespace."""
    def _dumps(obj):
        if obj is None:
            return obj
        return json.dumps(round_floats(obj)).replace(" ", "")
    for field in ['node_coords', 'node_symbols', 'edges']:
        if field in df.columns:
            df[field] = [_dumps(obj) for obj in df[field]]
    return df
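# Sketch of serializing prediction columns before export (the DataFrame below
# is made up for illustration; pandas is assumed to be imported by the caller):
#
#   import pandas as pd
#   df = pd.DataFrame({'node_coords': [[[0.123456, 0.654321]]],
#                      'node_symbols': [['C']],
#                      'edges': [[[0]]]})
#   format_df(df).to_csv('predictions.csv', index=False)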