|
import numpy as np |
|
import multiprocessing |
|
|
|
import rdkit |
|
import rdkit.Chem as Chem |
|
rdkit.RDLogger.DisableLog('rdApp.*') |
|
from SmilesPE.pretokenizer import atomwise_tokenizer |
|
|
|
|
|
def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    """Canonicalize a single SMILES string with RDKit.

    Args:
        smiles: Input SMILES. Any value that is not a non-empty string is
            treated as invalid.
        ignore_chiral: If True, canonicalize with ``useChiral=False``.
        ignore_cistrans: If True, remove the '/' and backslash characters
            from the string before canonicalizing.
        replace_rgroup: If True, rewrite bracket tokens before canonicalizing:
            ``[R<digits>]`` becomes ``[<digits>*]`` and any other bracket token
            that ``Chem.AtomFromSmiles`` cannot parse becomes ``*``.

    Returns:
        A ``(canon_smiles, success)`` tuple. For invalid input this is
        ``('', False)``. If canonicalization fails, ``canon_smiles`` is the
        (possibly preprocessed) input string and ``success`` is False.
    """
    if not isinstance(smiles, str) or not smiles:
        return '', False

    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')

    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                # startswith() instead of symbol[0] so an empty bracket body
                # cannot raise IndexError.
                if symbol.startswith('R') and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)

    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
    except Exception:
        # Deliberately best-effort: an unparsable SMILES is reported through
        # the success flag. `Exception` (not a bare except) keeps
        # KeyboardInterrupt/SystemExit propagating, e.g. inside pool workers.
        return smiles, False
    return canon_smiles, True
|
|
|
|
|
def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    """Canonicalize a collection of SMILES strings in a process pool.

    Args:
        smiles_list: Iterable of SMILES values; each one is passed to
            ``canonicalize_smiles``.
        ignore_chiral: Forwarded to ``canonicalize_smiles``.
        ignore_cistrans: Forwarded to ``canonicalize_smiles``.
        replace_rgroup: Forwarded to ``canonicalize_smiles``.
        num_workers: Number of worker processes in the pool.

    Returns:
        A ``(canon_smiles, valid_ratio)`` tuple: the list of canonicalized
        strings in input order, and the mean of the per-item success flags.
        For empty input, ``([], 0.0)`` is returned.
    """
    jobs = [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list]
    if not jobs:
        # With no results, `zip(*results)` has nothing to unpack and would
        # raise ValueError; also avoids spawning a pool for nothing.
        return [], 0.0

    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles, jobs, chunksize=128)

    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)
|
|
|
|
|
class SmilesEvaluator(object):
    """Compare predicted SMILES against a fixed list of gold SMILES.

    The gold list is canonicalized once at construction time in three
    variants (default, chirality ignored, cis/trans markers ignored);
    ``evaluate`` canonicalizes predictions the same way and reports
    exact-match rates.
    """

    def __init__(self, gold_smiles, num_workers=16):
        """Canonicalize the gold SMILES.

        Args:
            gold_smiles: Sequence of gold SMILES strings.
            num_workers: Number of worker processes used for canonicalization,
                both here and in ``evaluate``.
        """
        self.gold_smiles = gold_smiles
        self.num_workers = num_workers
        self.gold_canon_smiles, self.gold_valid = convert_smiles_to_canonsmiles(
            gold_smiles, num_workers=num_workers)
        # NOTE: despite the attribute names, these hold the forms with
        # chirality / cis-trans information *ignored*, respectively.
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(
            gold_smiles, ignore_chiral=True, num_workers=num_workers)
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(
            gold_smiles, ignore_cistrans=True, num_workers=num_workers)
        self.gold_canon_smiles = self._replace_empty(self.gold_canon_smiles)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty."""
        return [smiles if isinstance(smiles, str) and smiles else "<empty>"
                for smiles in smiles_list]

    def evaluate(self, pred_smiles):
        """Score predictions against the gold SMILES.

        Args:
            pred_smiles: Sequence of predicted SMILES, aligned with the gold
                list given at construction.

        Returns:
            Dict with the keys:
                ``gold_valid`` / ``pred_valid``: mean canonicalization success.
                ``canon_smiles_em``: exact-match rate of canonical SMILES.
                ``graph``: match rate with chirality ignored.
                ``canon_smiles``: match rate with cis/trans markers ignored.
                ``chiral_ratio``: fraction of gold entries containing '@'.
                ``chiral``: match rate (cis/trans ignored) restricted to gold
                    entries containing '@', or -1 if there are none.
        """
        results = {}
        results['gold_valid'] = self.gold_valid

        # Use the worker count given at construction instead of silently
        # falling back to the helper's default.
        pred_canon_smiles, pred_valid = convert_smiles_to_canonsmiles(
            pred_smiles, num_workers=self.num_workers)
        results['canon_smiles_em'] = (np.array(self.gold_canon_smiles) == np.array(pred_canon_smiles)).mean()
        results['pred_valid'] = pred_valid

        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(
            pred_smiles, ignore_chiral=True, num_workers=self.num_workers)
        results['graph'] = (np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral)).mean()

        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(
            pred_smiles, ignore_cistrans=True, num_workers=self.num_workers)
        results['canon_smiles'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)).mean()

        # Accuracy on the subset whose gold SMILES carries a chirality tag.
        chiral_hits = [g == p
                       for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans)
                       if '@' in g]
        results['chiral_ratio'] = len(chiral_hits) / len(self.gold_smiles)
        results['chiral'] = np.mean(chiral_hits) if chiral_hits else -1
        return results
|
|