# DiffLinker / src/generation.py
# Last commit: d8600ba by igashov ("Change max batch_size"); file size 3.77 kB.
# (Source page offered raw / history / blame views.)
import numpy as np
import os.path
import subprocess
import torch
from Bio.PDB import PDBParser
from src import const
from src.visualizer import save_xyz_file
from src.utils import FoundNaNException
from src.datasets import get_one_hot
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0, max_attempts=5):
    """Sample linkers with ``ddpm`` and write one XYZ file per batch element.

    Sampling is retried up to ``max_attempts`` times whenever
    ``ddpm.sample_chain`` raises ``FoundNaNException``. The sampled coordinates
    are shifted back by the mean position of the reference atoms (fragments or
    anchors, depending on ``ddpm.center_of_mass``) and written to the
    ``results`` directory under the names ``output_{offset_idx + i + 1}_{name}``.

    Args:
        ddpm: Model exposing ``sample_chain``, ``n_dims`` and ``center_of_mass``.
        data: Batch dict; reads ``'positions'``, ``'anchors'`` and either
            ``'fragment_mask'`` or (with a pocket) ``'fragment_only_mask'``
            plus ``'pocket_mask'``.
        sample_fn: Passed through unchanged to ``ddpm.sample_chain``.
        name: Suffix used in the output file names.
        with_pocket: If True, centre on ``fragment_only_mask`` and zero the
            pocket entries of ``node_mask`` before saving (presumably so that
            ``save_xyz_file`` skips pocket atoms — confirm).
        offset_idx: Added to the per-sample index in the output file names, so
            several batches can be written without overwriting each other.
        max_attempts: Number of sampling attempts before giving up.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
        FoundNaNException: Re-raised from the last attempt when every attempt
            failed.
    """
    if max_attempts < 1:
        raise ValueError(f'max_attempts must be at least 1, got {max_attempts}')

    last_error = None
    for _ in range(max_attempts):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
            break
        except FoundNaNException as err:
            last_error = err
    else:
        # Every attempt hit NaNs: surface that instead of failing below with an
        # opaque "'NoneType' object is not subscriptable".
        raise last_error

    print('Generated linker')
    # The first n_dims node features are treated as coordinates and the
    # remainder as atom features.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Undo the centring used during sampling: translate the coordinates back by
    # the mean position of the reference atoms (fragments or anchors).
    if with_pocket:
        com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
    else:
        com_mask = data['fragment_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']

    pos_masked = data['positions'] * com_mask
    N = com_mask.sum(1, keepdim=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
    # Only shift nodes that are present according to node_mask.
    x = x + mean * node_mask

    if with_pocket:
        node_mask[torch.where(data['pocket_mask'])] = 0

    batch_size = len(data['positions'])
    names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')
def try_to_convert_to_sdf(name, num_samples):
    """Convert generated XYZ files to SDF with Open Babel where possible.

    For every sample ``i`` in ``1..num_samples`` this runs ``obabel`` on
    ``results/output_{i}_{name}_.xyz``. The SDF path is returned when
    ``results/output_{i}_{name}_.sdf`` exists after the conversion attempt;
    otherwise the XYZ path is returned.

    Args:
        name: Name suffix used when the XYZ files were written.
        num_samples: Number of samples to convert.

    Returns:
        List of ``num_samples`` file paths (SDF or XYZ), in sample order.
    """
    out_files = []
    for i in range(1, num_samples + 1):
        out_xyz = f'results/output_{i}_{name}_.xyz'
        out_sdf = f'results/output_{i}_{name}_.sdf'
        # Pass an argument list (no shell) so characters in ``name`` such as
        # spaces, quotes or ';' can neither split the command nor inject new ones.
        try:
            subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        except OSError:
            # e.g. obabel is not installed: keep the best-effort behaviour and
            # fall back to the XYZ file below.
            pass
        out_files.append(out_sdf if os.path.exists(out_sdf) else out_xyz)
    return out_files
def get_pocket(mol, pdb_path, distance_cutoff=6):
    """Extract the protein pocket atoms surrounding ``mol`` from a PDB file.

    A residue belongs to the pocket when any of its atoms lies within
    ``distance_cutoff`` of any atom of ``mol``. The cutoff is in the units of the
    PDB and conformer coordinates, presumably Angstrom. All atoms of pocket
    residues whose upper-cased element symbol is a key of ``const.GEOM_ATOM2IDX``
    are returned, in structure order. Atoms of other elements are skipped.

    Args:
        mol: Molecule providing ``mol.GetConformer().GetPositions()``
            (presumably an RDKit ``Mol`` — confirm against callers).
        pdb_path: Path to the protein PDB file.
        distance_cutoff: Contact threshold. The default 6 matches the previously
            hard-coded value.

    Returns:
        Tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of NumPy arrays:
        atom coordinates, one-hot atom types from ``get_one_hot`` and charges
        from ``const.GEOM_CHARGES``.
    """
    struct = PDBParser().get_structure('', pdb_path)
    # Track residues by their position in the structure rather than by residue
    # number alone. Residues sharing a number in different chains (or models)
    # must not be pulled into the pocket together.
    residues = list(struct.get_residues())

    residue_idx = []
    atom_coords = []
    for idx, residue in enumerate(residues):
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            residue_idx.append(idx)
    residue_idx = np.array(residue_idx)
    atom_coords = np.array(atom_coords)

    # Pairwise protein-atom x ligand-atom distances; a residue is in contact
    # if its closest atom is within the cutoff.
    mol_atom_coords = mol.GetConformer().GetPositions()
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    contact_idx = set(residue_idx[distances.min(1) <= distance_cutoff].tolist())

    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for idx in sorted(contact_idx):
        for atom in residues[idx].get_atoms():
            # NOTE(review): the symbol is upper-cased; if GEOM_ATOM2IDX uses
            # mixed-case keys (e.g. 'Cl'), such atoms are skipped — confirm.
            atom_type = atom.element.upper()
            if atom_type not in const.GEOM_ATOM2IDX:
                continue
            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)
    return pocket_pos, pocket_one_hot, pocket_charges