Spaces:
Runtime error
Runtime error
import numpy as np | |
import os.path | |
import subprocess | |
import torch | |
from Bio.PDB import PDBParser | |
from src import const | |
from src.visualizer import save_xyz_file | |
from src.utils import FoundNaNException | |
from src.datasets import get_one_hot | |
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0):
    """Sample linkers for a batch and write them to ``results/`` as XYZ files.

    ``ddpm.sample_chain`` is retried up to five times when it raises
    ``FoundNaNException``. The sampled coordinates are then shifted back by
    the mean position of the atoms selected by the centering mask, and one
    file per batch element is written through ``save_xyz_file``.

    Args:
        ddpm: Model object providing ``sample_chain``, ``n_dims`` and
            ``center_of_mass``.
        data: Batch mapping. Reads ``'positions'`` and ``'anchors'``, plus
            ``'fragment_mask'`` (without pocket) or ``'fragment_only_mask'``
            and ``'pocket_mask'`` (with pocket).
        sample_fn: Forwarded unchanged to ``ddpm.sample_chain``.
        name: Tag embedded in the output names
            (``output_{index}_{name}``).
        with_pocket: When true, the centering mask comes from
            ``'fragment_only_mask'`` and pocket atoms are masked out before
            saving.
        offset_idx: Added to the 1-based index used in the output names.

    Raises:
        RuntimeError: If every sampling attempt failed with
            ``FoundNaNException``.
    """
    max_attempts = 5
    chain = node_mask = None
    last_error = None
    for _ in range(max_attempts):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
            break
        except FoundNaNException as err:
            last_error = err

    # Without this guard an exhausted retry loop would crash below with an
    # opaque "'NoneType' object is not subscriptable".
    if chain is None:
        raise RuntimeError(
            f'Could not generate linker: sampling failed {max_attempts} times'
        ) from last_error

    print('Generated linker')
    # The last axis of the kept frame is split at n_dims: the leading part
    # is passed to save_xyz_file as positions, the remainder as features.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Put the molecule back to the initial orientation
    if with_pocket:
        com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
    else:
        com_mask = data['fragment_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
    pos_masked = data['positions'] * com_mask
    N = com_mask.sum(1, keepdims=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
    x = x + mean * node_mask

    if with_pocket:
        # Zero the pocket entries so only non-pocket atoms are written out.
        node_mask[torch.where(data['pocket_mask'])] = 0

    batch_size = len(data['positions'])
    names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')
def try_to_convert_to_sdf(name, num_samples):
    """Convert generated XYZ files to SDF with Open Babel, best effort.

    For each sample index ``1..num_samples`` the ``obabel`` executable is
    run on ``results/output_{i}_{name}_.xyz``. If the matching ``.sdf`` file
    exists afterwards its path is reported, otherwise the ``.xyz`` path is
    reported instead.

    Args:
        name: Tag used in the file names. It is passed to ``obabel`` as a
            plain argument, never through a shell, so spaces or shell
            metacharacters in it are harmless.
        num_samples: Number of consecutive sample files to convert.

    Returns:
        List of ``num_samples`` paths, one per sample, in index order.
    """
    out_files = []
    for i in range(num_samples):
        out_xyz = f'results/output_{i + 1}_{name}_.xyz'
        out_sdf = f'results/output_{i + 1}_{name}_.sdf'
        try:
            # Argument list with no shell: the file names cannot be
            # re-interpreted as shell syntax.
            subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        except OSError:
            # obabel is missing or not executable. Conversion is optional,
            # so fall through and report the XYZ file below.
            pass
        out_files.append(out_sdf if os.path.exists(out_sdf) else out_xyz)
    return out_files
def get_pocket(mol, pdb_path):
    """Collect the protein atoms that form the pocket around ``mol``.

    Every residue number that owns at least one atom within a distance of 6
    (in the coordinate units of the inputs) of any atom of ``mol`` is
    selected. All atoms of residues carrying a selected number are then
    gathered, skipping atoms whose element is not a key of
    ``const.GEOM_ATOM2IDX``.

    Args:
        mol: Molecule providing ``GetConformer().GetPositions()``.
        pdb_path: Path of the PDB file parsed with ``Bio.PDB.PDBParser``.

    Returns:
        Tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of numpy
        arrays with one entry per retained atom: coordinates, the
        ``get_one_hot`` encoding of the element, and its
        ``const.GEOM_CHARGES`` value.
    """
    contact_cutoff = 6

    struct = PDBParser().get_structure('', pdb_path)
    residues = list(struct.get_residues())

    # Flatten the structure into parallel per-atom arrays.
    residue_ids = []
    atom_coords = []
    for residue in residues:
        resid = residue.get_id()[1]
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            residue_ids.append(resid)
    residue_ids = np.array(residue_ids)
    atom_coords = np.array(atom_coords)

    # Pairwise protein-atom to molecule-atom distances, reduced to the
    # closest molecule atom for each protein atom.
    mol_atom_coords = mol.GetConformer().GetPositions()
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    in_contact = distances.min(1) <= contact_cutoff
    # A set gives constant-time membership in the loop below.
    contact_residues = set(np.unique(residue_ids[in_contact]).tolist())

    # NOTE(review): residues are matched on get_id()[1] alone, so residues
    # sharing that number (e.g. in another chain) are selected together.
    # Confirm this is intended.
    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for residue in residues:
        if residue.get_id()[1] not in contact_residues:
            continue
        for atom in residue.get_atoms():
            atom_type = atom.element.upper()
            if atom_type not in const.GEOM_ATOM2IDX:
                continue
            # tolist() yields plain Python floats before np.array is applied.
            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)
    return pocket_pos, pocket_one_hot, pocket_charges