stochasticribosome
commited on
Commit
•
22dca11
1
Parent(s):
41311f3
Add inference preprocessing
Browse files
main.py
CHANGED
@@ -12,6 +12,9 @@ import os
|
|
12 |
from MDmodel import GNN_MD
|
13 |
import h5py
|
14 |
from transformMD import GNNTransformMD
|
|
|
|
|
|
|
15 |
|
16 |
# JavaScript functions
|
17 |
resid_hover = """function(atom,viewer) {{
|
@@ -46,6 +49,78 @@ model = model.to('cpu')
|
|
46 |
model.eval()
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
def get_pdb(pdb_code="", filepath=""):
|
51 |
try:
|
@@ -65,24 +140,37 @@ def get_offset(pdb):
|
|
65 |
return int(line[22:27])
|
66 |
|
67 |
|
|
|
|
|
|
|
|
|
68 |
def predict(pdb_code, pdb_file):
|
69 |
-
path_to_pdb = get_pdb(pdb_code=pdb_code, filepath=pdb_file)
|
|
|
|
|
|
|
|
|
|
|
70 |
mdh5_file = "inference_for_md.hdf5"
|
|
|
|
|
|
|
|
|
71 |
md_H5File = h5py.File(mdh5_file)
|
72 |
|
73 |
column_names = ["x", "y", "z", "element"]
|
74 |
atoms_protein = pd.DataFrame(columns = column_names)
|
75 |
-
cutoff = md_H5File[
|
76 |
|
77 |
-
atoms_protein["x"] = md_H5File[
|
78 |
-
atoms_protein["y"] = md_H5File[
|
79 |
-
atoms_protein["z"] = md_H5File[
|
80 |
|
81 |
-
atoms_protein["element"] = md_H5File[
|
82 |
|
83 |
item = {}
|
84 |
item["scores"] = 0
|
85 |
-
item["id"] =
|
86 |
item["atoms_protein"] = atoms_protein
|
87 |
|
88 |
transform = GNNTransformMD()
|
@@ -99,7 +187,7 @@ def predict(pdb_code, pdb_file):
|
|
99 |
topN = 100
|
100 |
topN_ind = np.argsort(adaptability)[::-1][:topN]
|
101 |
|
102 |
-
pdb = open(
|
103 |
|
104 |
view = py3Dmol.view(width=600, height=400)
|
105 |
view.setBackgroundColor('white')
|
@@ -149,4 +237,4 @@ def run():
|
|
149 |
|
150 |
|
151 |
if __name__ == "__main__":
|
152 |
-
run()
|
|
|
12 |
from MDmodel import GNN_MD
|
13 |
import h5py
|
14 |
from transformMD import GNNTransformMD
|
15 |
+
import sys
|
16 |
+
import pytraj as pt
|
17 |
+
import pickle
|
18 |
|
19 |
# JavaScript functions
|
20 |
resid_hover = """function(atom,viewer) {{
|
|
|
49 |
model.eval()
|
50 |
|
51 |
|
52 |
+
def run_leap(fileName, path):
|
53 |
+
leapText = """
|
54 |
+
source leaprc.protein.ff14SB
|
55 |
+
source leaprc.water.tip3p
|
56 |
+
exp = loadpdb PATH4amb.pdb
|
57 |
+
saveamberparm exp PATHexp.top PATHexp.crd
|
58 |
+
quit
|
59 |
+
"""
|
60 |
+
with open(path+"leap.in", "w") as outLeap:
|
61 |
+
outLeap.write(leapText.replace('PATH', path))
|
62 |
+
os.system("tleap -f "+path+"leap.in >> "+path+"leap.out")
|
63 |
+
|
64 |
+
def convert_to_amber_format(pdbName):
|
65 |
+
fileName, path = pdbName+'.pdb', pdbName+'/'
|
66 |
+
os.system("pdb4amber -i "+fileName+" -p -y -o "+path+"4amb.pdb -l "+path+"pdb4amber_protein.log")
|
67 |
+
run_leap(fileName, path)
|
68 |
+
traj = pt.iterload(path+'exp.crd', top = path+'exp.top')
|
69 |
+
pt.write_traj(path+fileName, traj, overwrite= True)
|
70 |
+
print(path+fileName+' was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.')
|
71 |
+
return pt.iterload(path+'exp.crd', top = path+'exp.top')
|
72 |
+
|
73 |
+
def get_maps(mapPath):
|
74 |
+
residueMap = pickle.load(open(mapPath+'atoms_residue_map_generate.pickle','rb'))
|
75 |
+
nameMap = pickle.load(open(mapPath+'atoms_name_map_generate.pickle','rb'))
|
76 |
+
typeMap = pickle.load(open(mapPath+'atoms_type_map_generate.pickle','rb'))
|
77 |
+
elementMap = pickle.load(open(mapPath+'map_atomType_element_numbers.pickle','rb'))
|
78 |
+
return residueMap, nameMap, typeMap, elementMap
|
79 |
+
|
80 |
+
def get_residues_atomwise(residues):
|
81 |
+
atomwise = []
|
82 |
+
for name, nAtoms in residues:
|
83 |
+
for i in range(nAtoms):
|
84 |
+
atomwise.append(name)
|
85 |
+
return atomwise
|
86 |
+
|
87 |
+
def get_begin_atom_index(traj):
|
88 |
+
natoms = [m.n_atoms for m in traj.top.mols]
|
89 |
+
molecule_begin_atom_index = [0]
|
90 |
+
x = 0
|
91 |
+
for i in range(len(natoms)):
|
92 |
+
x += natoms[i]
|
93 |
+
molecule_begin_atom_index.append(x)
|
94 |
+
print('molecule begin atom index', molecule_begin_atom_index, natoms)
|
95 |
+
return molecule_begin_atom_index
|
96 |
+
|
97 |
+
def get_traj_info(traj, mapPath):
|
98 |
+
coordinates = traj.xyz
|
99 |
+
residueMap, nameMap, typeMap, elementMap = get_maps(mapPath)
|
100 |
+
types = [typeMap[a.type] for a in traj.top.atoms]
|
101 |
+
elements = [elementMap[typ] for typ in types]
|
102 |
+
atomic_numbers = [a.atomic_number for a in traj.top.atoms]
|
103 |
+
molecule_begin_atom_index = get_begin_atom_index(traj)
|
104 |
+
residues = [(residueMap[res.name], res.n_atoms) for res in traj.top.residues]
|
105 |
+
residues_atomwise = get_residues_atomwise(residues)
|
106 |
+
return coordinates[0], elements, types, atomic_numbers, residues_atomwise, molecule_begin_atom_index
|
107 |
+
|
108 |
+
def write_h5_info(outName, struct, atoms_type, atoms_number, atoms_residue, atoms_element, molecules_begin_atom_index, atoms_coordinates_ref):
|
109 |
+
if os.path.isfile(outName):
|
110 |
+
os.remove(outName)
|
111 |
+
with h5py.File(outName, 'w') as oF:
|
112 |
+
subgroup = oF.create_group(struct)
|
113 |
+
subgroup.create_dataset('atoms_residue', data= atoms_residue, compression = "gzip", dtype='i8')
|
114 |
+
subgroup.create_dataset('molecules_begin_atom_index', data= molecules_begin_atom_index, compression = "gzip", dtype='i8')
|
115 |
+
subgroup.create_dataset('atoms_type', data= atoms_type, compression = "gzip", dtype='i8')
|
116 |
+
subgroup.create_dataset('atoms_number', data= atoms_number, compression = "gzip", dtype='i8')
|
117 |
+
subgroup.create_dataset('atoms_element', data= atoms_element, compression = "gzip", dtype='i8')
|
118 |
+
subgroup.create_dataset('atoms_coordinates_ref', data= atoms_coordinates_ref, compression = "gzip", dtype='f8')
|
119 |
+
|
120 |
+
def preprocess(pdbid: str = None, ouputfile: str = "inference_for_md.hdf5", mask: str = "!@H=", mappath: str = "/maps/"):
|
121 |
+
traj = convert_to_amber_format(pdbid)
|
122 |
+
atoms_coordinates_ref, atoms_element, atoms_type, atoms_number, atoms_residue, molecules_begin_atom_index = get_traj_info(traj[mask], mappath)
|
123 |
+
write_h5_info(ouputfile, pdbid, atoms_type, atoms_number, atoms_residue, atoms_element, molecules_begin_atom_index, atoms_coordinates_ref)
|
124 |
|
125 |
def get_pdb(pdb_code="", filepath=""):
|
126 |
try:
|
|
|
140 |
return int(line[22:27])
|
141 |
|
142 |
|
143 |
+
def get_pdbid_from_filename(filename: str):
|
144 |
+
# Assuming the filename would be of the standard form 11GS.pdb
|
145 |
+
return filename.split(".")[0]
|
146 |
+
|
147 |
def predict(pdb_code, pdb_file):
|
148 |
+
#path_to_pdb = get_pdb(pdb_code=pdb_code, filepath=pdb_file)
|
149 |
+
|
150 |
+
#pdb = open(path_to_pdb, "r").read()
|
151 |
+
# switch to misato env if not running from container
|
152 |
+
|
153 |
+
pdbid = get_pdbid_from_filename(pdb_file)
|
154 |
mdh5_file = "inference_for_md.hdf5"
|
155 |
+
mappath = "/maps"
|
156 |
+
mask = "!@H="
|
157 |
+
preprocess(pdbid=pdbid, ouputfile=mdh5_file, mask=mask, mappath=mappath)
|
158 |
+
|
159 |
md_H5File = h5py.File(mdh5_file)
|
160 |
|
161 |
column_names = ["x", "y", "z", "element"]
|
162 |
atoms_protein = pd.DataFrame(columns = column_names)
|
163 |
+
cutoff = md_H5File[pdbid]["molecules_begin_atom_index"][:][-1] # cutoff defines protein atoms
|
164 |
|
165 |
+
atoms_protein["x"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 0]
|
166 |
+
atoms_protein["y"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 1]
|
167 |
+
atoms_protein["z"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 2]
|
168 |
|
169 |
+
atoms_protein["element"] = md_H5File[pdbid]["atoms_element"][:][:cutoff]
|
170 |
|
171 |
item = {}
|
172 |
item["scores"] = 0
|
173 |
+
item["id"] = pdbid
|
174 |
item["atoms_protein"] = atoms_protein
|
175 |
|
176 |
transform = GNNTransformMD()
|
|
|
187 |
topN = 100
|
188 |
topN_ind = np.argsort(adaptability)[::-1][:topN]
|
189 |
|
190 |
+
pdb = open(pdb_file.name, "r").read()
|
191 |
|
192 |
view = py3Dmol.view(width=600, height=400)
|
193 |
view.setBackgroundColor('white')
|
|
|
237 |
|
238 |
|
239 |
if __name__ == "__main__":
|
240 |
+
run()
|