# RxnIM / molscribe/interface.py
# Uploaded by CYF200127 ("Upload 116 files"), commit 5e9bd47 (verified), 11 kB.
import argparse
from typing import List
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from .dataset import get_transforms
from .model import Encoder, Decoder
from .chemistry import convert_graph_to_smiles
from .tokenizer import get_tokenizer
# Human-readable bond labels, indexed by the integer values stored in a
# prediction's 'edges' matrix. Index 0 (the empty string) means "no bond".
BOND_TYPES = [
    "",
    "single",
    "double",
    "triple",
    "aromatic",
    "solid wedge",
    "dashed wedge",
]
def safe_load(module, module_states):
    """Load ``module_states`` into ``module`` non-strictly, dropping wrapper prefixes.

    Keys that start with one or more ``module.`` segments have those leading
    segments removed before loading. Such keys are presumably produced by saving
    a ``DataParallel``/``DistributedDataParallel``-wrapped model.

    :param module: the ``torch.nn.Module`` to load into.
    :param module_states: a state dict (mapping of parameter names to tensors).
    :return: ``(missing_keys, unexpected_keys)`` as reported by
        ``module.load_state_dict(..., strict=False)``. Callers that ignored the
        previous ``None`` return value are unaffected.
    """
    prefix = 'module.'

    def remove_prefix(state_dict):
        # Only strip *leading* 'module.' segments. A plain str.replace would also
        # rewrite keys that merely contain 'module.' (e.g. 'submodule.weight' ->
        # 'subweight'), and strict=False would then silently skip that parameter.
        cleaned = {}
        for key, value in state_dict.items():
            while key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    missing_keys, unexpected_keys = module.load_state_dict(remove_prefix(module_states), strict=False)
    return missing_keys, unexpected_keys
class MolScribe:
    """Inference wrapper around a trained MolScribe encoder/decoder checkpoint."""

    def __init__(self, model_path, device=None):
        """
        MolScribe Interface
        :param model_path: path of the model checkpoint.
        :param device: torch device, defaults to be CPU.
        """
        # NOTE(security): torch.load unpickles the checkpoint file; only load
        # checkpoints from trusted sources.
        model_states = torch.load(model_path, map_location=torch.device('cpu'))
        args = self._get_args(model_states['args'])
        if device is None:
            device = torch.device('cpu')
        self.device = device
        self.tokenizer = get_tokenizer(args)
        self.encoder, self.decoder = self._get_model(args, self.tokenizer, self.device, model_states)
        self.transform = get_transforms(args.input_size, augment=False)

    def _get_args(self, args_states=None):
        """Build the default argument namespace, then overlay saved checkpoint args.

        :param args_states: optional mapping of argument name -> value (the
            checkpoint's ``'args'`` entry); each entry overrides or adds to the
            defaults declared below.
        :return: an ``argparse.Namespace``.
        """
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--encoder', type=str, default='swin_base')
        parser.add_argument('--decoder', type=str, default='transformer')
        parser.add_argument('--trunc_encoder', action='store_true')  # use the hidden states before downsample
        parser.add_argument('--no_pretrained', action='store_true')
        parser.add_argument('--use_checkpoint', action='store_true', default=True)
        parser.add_argument('--dropout', type=float, default=0.5)
        parser.add_argument('--embed_dim', type=int, default=256)
        parser.add_argument('--enc_pos_emb', action='store_true')
        group = parser.add_argument_group("transformer_options")
        group.add_argument("--dec_num_layers", help="No. of layers in transformer decoder", type=int, default=6)
        group.add_argument("--dec_hidden_size", help="Decoder hidden size", type=int, default=256)
        group.add_argument("--dec_attn_heads", help="Decoder no. of attention heads", type=int, default=8)
        group.add_argument("--dec_num_queries", type=int, default=128)
        group.add_argument("--hidden_dropout", help="Hidden dropout", type=float, default=0.1)
        group.add_argument("--attn_dropout", help="Attention dropout", type=float, default=0.1)
        group.add_argument("--max_relative_positions", help="Max relative positions", type=int, default=0)
        parser.add_argument('--continuous_coords', action='store_true')
        parser.add_argument('--compute_confidence', action='store_true')
        # Data
        parser.add_argument('--input_size', type=int, default=384)
        parser.add_argument('--vocab_file', type=str, default=None)
        parser.add_argument('--coord_bins', type=int, default=64)
        parser.add_argument('--sep_xy', action='store_true', default=True)
        # Parse an empty argv so only the declared defaults are used, never sys.argv.
        args = parser.parse_args([])
        if args_states:
            for key, value in args_states.items():
                args.__dict__[key] = value
        return args

    def _get_model(self, args, tokenizer, device, states):
        """Construct encoder/decoder, load checkpoint weights, and prepare them for inference.

        :param args: namespace from ``_get_args``; ``args.encoder_dim`` is set here.
        :param tokenizer: tokenizer passed through to ``Decoder``.
        :param device: device the models are moved to.
        :param states: checkpoint dict with ``'encoder'`` and ``'decoder'`` state dicts.
        :return: ``(encoder, decoder)``, on ``device`` and in eval mode.
        """
        encoder = Encoder(args, pretrained=False)
        # Presumably consumed by Decoder to size its input projection.
        args.encoder_dim = encoder.n_features
        decoder = Decoder(args, tokenizer)
        safe_load(encoder, states['encoder'])
        safe_load(decoder, states['decoder'])
        # predict_images moves input batches to self.device, so the weights must be there too.
        encoder.to(device)
        decoder.to(device)
        # Inference only: turn off dropout (args.dropout defaults to 0.5) and
        # other train-time behaviour so predictions are deterministic.
        encoder.eval()
        decoder.eval()
        return encoder, decoder

    def predict_images(self, input_images: List, return_atoms_bonds=False, return_confidence=False, batch_size=16):
        """Predict molecular structures for a sequence of images.

        :param input_images: images accepted by ``self.transform`` (presumably
            H x W x 3 RGB arrays, as produced by ``predict_image_files``).
        :param return_atoms_bonds: add per-atom ("atoms") and per-bond ("bonds") details.
        :param return_confidence: add confidence scores to the results.
        :param batch_size: number of images per forward pass.
        :return: a list with one result dict per input image, in input order.
        """
        # len() rather than truthiness so numpy arrays of images are accepted too.
        if len(input_images) == 0:
            return []
        device = self.device
        predictions = []
        self.decoder.compute_confidence = return_confidence
        for idx in range(0, len(input_images), batch_size):
            batch_images = input_images[idx:idx + batch_size]
            images = [self.transform(image=image, keypoints=[])['image'] for image in batch_images]
            images = torch.stack(images, dim=0).to(device)
            with torch.no_grad():
                features, hiddens = self.encoder(images)
                batch_predictions = self.decoder.decode(features, hiddens)
            predictions += batch_predictions
        return self.convert_graph_to_output(predictions, input_images, return_confidence, return_atoms_bonds)

    def convert_graph_to_output(self, predictions, input_images, return_confidence=True, return_atoms_bonds=True):
        """Convert raw decoder predictions into user-facing result dicts.

        :param predictions: decoder outputs; each has ``'chartok_coords'`` (with
            ``'coords'``, ``'symbols'`` and, for confidence, ``'atom_scores'``),
            ``'edges'`` (an atoms x atoms matrix of indices into ``BOND_TYPES``)
            and, for confidence, ``'overall_score'`` and ``'edge_scores'``.
        :param input_images: the original images, forwarded to ``convert_graph_to_smiles``.
        :param return_confidence: include overall/atom/bond confidence values.
        :param return_atoms_bonds: include "atoms" and "bonds" lists.
        :return: a list of dicts, one per prediction.
        """
        node_coords = [pred['chartok_coords']['coords'] for pred in predictions]
        node_symbols = [pred['chartok_coords']['symbols'] for pred in predictions]
        edges = [pred['edges'] for pred in predictions]
        smiles_list, molblock_list, _r_success = convert_graph_to_smiles(
            node_coords, node_symbols, edges, images=input_images)
        outputs = []
        for smiles, molblock, pred in zip(smiles_list, molblock_list, predictions):
            coords = pred['chartok_coords']['coords']
            symbols = pred['chartok_coords']['symbols']
            # NOTE: "oringinal_coords" / "orignal_edges" are misspelled, but they are
            # kept verbatim because downstream callers may read these exact keys.
            pred_dict = {"smiles": smiles, "molfile": molblock, "oringinal_coords": coords,
                         "original_symbols": symbols, "orignal_edges": pred['edges']}
            if return_confidence:
                pred_dict["confidence"] = pred["overall_score"]
            if return_atoms_bonds:
                # get atoms info
                atom_list = []
                for i, (symbol, coord) in enumerate(zip(symbols, coords)):
                    atom_dict = {"atom_symbol": symbol, "x": round(coord[0], 3), "y": round(coord[1], 3)}
                    if return_confidence:
                        atom_dict["confidence"] = pred['chartok_coords']['atom_scores'][i]
                    atom_list.append(atom_dict)
                pred_dict["atoms"] = atom_list
                # get bonds info: scan the upper triangle (i < j); 0 means "no bond"
                bond_list = []
                num_atoms = len(symbols)
                for i in range(num_atoms - 1):
                    for j in range(i + 1, num_atoms):
                        bond_type_int = pred['edges'][i][j]
                        if bond_type_int != 0:
                            bond_dict = {"bond_type": BOND_TYPES[bond_type_int], "endpoint_atoms": (i, j)}
                            if return_confidence:
                                bond_dict["confidence"] = pred["edge_scores"][i][j]
                            bond_list.append(bond_dict)
                pred_dict["bonds"] = bond_list
            outputs.append(pred_dict)
        return outputs

    def predict_image(self, image, return_atoms_bonds=False, return_confidence=False):
        """Predict a single image; returns the single result dict."""
        return self.predict_images([
            image], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0]

    def predict_image_files(self, image_files: List, return_atoms_bonds=False, return_confidence=False):
        """Read image files with OpenCV, convert BGR -> RGB, and predict them as a batch.

        :param image_files: paths readable by ``cv2.imread``.
        :return: a list with one result dict per file, in input order.
        """
        input_images = []
        for path in image_files:
            image = cv2.imread(path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            input_images.append(image)
        return self.predict_images(
            input_images, return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)

    def predict_image_file(self, image_file: str, return_atoms_bonds=False, return_confidence=False):
        """Predict a single image file; returns the single result dict."""
        return self.predict_image_files(
            [image_file], return_atoms_bonds=return_atoms_bonds, return_confidence=return_confidence)[0]

    def draw_prediction(self, prediction, image, notebook=False):
        """Draw predicted atoms and bonds on a figure sized like ``image``.

        :param prediction: a result dict produced with ``return_atoms_bonds=True``.
        :param image: the source image array with shape (h, w, channels).
        :param notebook: if True, leave the figure open (e.g. for inline display)
            and return None.
        :return: the rendered RGBA image array when ``notebook`` is False.
        :raises ValueError: if ``prediction`` has no "atoms" or "bonds".
        """
        if "atoms" not in prediction or "bonds" not in prediction:
            raise ValueError("atoms and bonds information are not provided.")
        h, w, _ = image.shape
        # Rescale so the longer side is 400; atom x/y are multiplied by w/h below,
        # so they are presumably normalized to [0, 1].
        h, w = np.array([h, w]) * 400 / max(h, w)
        image = cv2.resize(image, (int(w), int(h)))
        fig, ax = plt.subplots(1, 1)
        ax.set_xlim(-0.05 * w, w * 1.05)
        # Inverted y limits: image coordinates grow downwards.
        ax.set_ylim(1.05 * h, -0.05 * h)
        plt.imshow(image, alpha=0.)
        x = [a['x'] * w for a in prediction['atoms']]
        y = [a['y'] * h for a in prediction['atoms']]
        markersize = min(w, h) / 3
        plt.scatter(x, y, marker='o', s=markersize, color='lightskyblue', zorder=10)
        for i, atom in enumerate(prediction['atoms']):
            symbol = atom['atom_symbol'].lstrip('[').rstrip(']')
            plt.annotate(symbol, xy=(x[i], y[i]), ha='center', va='center', color='black', zorder=100)
        for bond in prediction['bonds']:
            u, v = bond['endpoint_atoms']
            x1, y1, x2, y2 = x[u], y[u], x[v], y[v]
            bond_type = bond['bond_type']
            if bond_type == 'single':
                color = 'tab:green'
                ax.plot([x1, x2], [y1, y2], color, linewidth=4)
            elif bond_type == 'aromatic':
                color = 'tab:purple'
                ax.plot([x1, x2], [y1, y2], color, linewidth=4)
            elif bond_type == 'double':
                color = 'tab:green'
                ax.plot([x1, x2], [y1, y2], color=color, linewidth=7)
                ax.plot([x1, x2], [y1, y2], color='w', linewidth=1.5, zorder=2.1)
            elif bond_type == 'triple':
                color = 'tab:green'
                x1s, x2s = 0.8 * x1 + 0.2 * x2, 0.2 * x1 + 0.8 * x2
                y1s, y2s = 0.8 * y1 + 0.2 * y2, 0.2 * y1 + 0.8 * y2
                ax.plot([x1s, x2s], [y1s, y2s], color=color, linewidth=9)
                ax.plot([x1, x2], [y1, y2], color='w', linewidth=5, zorder=2.05)
                ax.plot([x1, x2], [y1, y2], color=color, linewidth=2, zorder=2.1)
            else:
                # Wedge bonds are drawn as arrows: the head sits at atom u for a
                # solid wedge and at atom v for any other (dashed) wedge.
                length = 10
                width = 10
                color = 'tab:green'
                if bond_type == 'solid wedge':
                    ax.annotate('', xy=(x1, y1), xytext=(x2, y2),
                                arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2)
                else:
                    ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                                arrowprops=dict(color=color, width=3, headwidth=width, headlength=length), zorder=2)
        if not notebook:
            canvas = FigureCanvasAgg(fig)
            # The Agg buffer only exists once the figure has been rendered.
            canvas.draw()
            buf = canvas.buffer_rgba()
            # Copy out of the renderer's buffer before the figure is released.
            result_image = np.array(buf)
            # pyplot keeps every open figure alive; close it to avoid leaking memory.
            plt.close(fig)
            return result_image