stochasticribosome commited on
Commit
22dca11
1 Parent(s): 41311f3

Add inference preprocessing

Browse files
Files changed (1) hide show
  1. main.py +97 -9
main.py CHANGED
@@ -12,6 +12,9 @@ import os
12
  from MDmodel import GNN_MD
13
  import h5py
14
  from transformMD import GNNTransformMD
 
 
 
15
 
16
  # JavaScript functions
17
  resid_hover = """function(atom,viewer) {{
@@ -46,6 +49,78 @@ model = model.to('cpu')
46
  model.eval()
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def get_pdb(pdb_code="", filepath=""):
51
  try:
@@ -65,24 +140,37 @@ def get_offset(pdb):
65
  return int(line[22:27])
66
 
67
 
 
 
 
 
68
  def predict(pdb_code, pdb_file):
69
- path_to_pdb = get_pdb(pdb_code=pdb_code, filepath=pdb_file)
 
 
 
 
 
70
  mdh5_file = "inference_for_md.hdf5"
 
 
 
 
71
  md_H5File = h5py.File(mdh5_file)
72
 
73
  column_names = ["x", "y", "z", "element"]
74
  atoms_protein = pd.DataFrame(columns = column_names)
75
- cutoff = md_H5File["11GS"]["molecules_begin_atom_index"][:][-1] # cutoff defines protein atoms
76
 
77
- atoms_protein["x"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 0]
78
- atoms_protein["y"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 1]
79
- atoms_protein["z"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 2]
80
 
81
- atoms_protein["element"] = md_H5File["11GS"]["atoms_element"][:][:cutoff]
82
 
83
  item = {}
84
  item["scores"] = 0
85
- item["id"] = "11GS"
86
  item["atoms_protein"] = atoms_protein
87
 
88
  transform = GNNTransformMD()
@@ -99,7 +187,7 @@ def predict(pdb_code, pdb_file):
99
  topN = 100
100
  topN_ind = np.argsort(adaptability)[::-1][:topN]
101
 
102
- pdb = open(path_to_pdb, "r").read()
103
 
104
  view = py3Dmol.view(width=600, height=400)
105
  view.setBackgroundColor('white')
@@ -149,4 +237,4 @@ def run():
149
 
150
 
151
  if __name__ == "__main__":
152
- run()
 
12
  from MDmodel import GNN_MD
13
  import h5py
14
  from transformMD import GNNTransformMD
15
+ import sys
16
+ import pytraj as pt
17
+ import pickle
18
 
19
  # JavaScript functions
20
  resid_hover = """function(atom,viewer) {{
 
49
  model.eval()
50
 
51
 
52
+ def run_leap(fileName, path):
53
+ leapText = """
54
+ source leaprc.protein.ff14SB
55
+ source leaprc.water.tip3p
56
+ exp = loadpdb PATH4amb.pdb
57
+ saveamberparm exp PATHexp.top PATHexp.crd
58
+ quit
59
+ """
60
+ with open(path+"leap.in", "w") as outLeap:
61
+ outLeap.write(leapText.replace('PATH', path))
62
+ os.system("tleap -f "+path+"leap.in >> "+path+"leap.out")
63
+
64
+ def convert_to_amber_format(pdbName):
65
+ fileName, path = pdbName+'.pdb', pdbName+'/'
66
+ os.system("pdb4amber -i "+fileName+" -p -y -o "+path+"4amb.pdb -l "+path+"pdb4amber_protein.log")
67
+ run_leap(fileName, path)
68
+ traj = pt.iterload(path+'exp.crd', top = path+'exp.top')
69
+ pt.write_traj(path+fileName, traj, overwrite= True)
70
+ print(path+fileName+' was created. Please always use this file for inspection because the coordinates might get translated during amber file generation and thus might vary from the input pdb file.')
71
+ return pt.iterload(path+'exp.crd', top = path+'exp.top')
72
+
73
+ def get_maps(mapPath):
74
+ residueMap = pickle.load(open(mapPath+'atoms_residue_map_generate.pickle','rb'))
75
+ nameMap = pickle.load(open(mapPath+'atoms_name_map_generate.pickle','rb'))
76
+ typeMap = pickle.load(open(mapPath+'atoms_type_map_generate.pickle','rb'))
77
+ elementMap = pickle.load(open(mapPath+'map_atomType_element_numbers.pickle','rb'))
78
+ return residueMap, nameMap, typeMap, elementMap
79
+
80
+ def get_residues_atomwise(residues):
81
+ atomwise = []
82
+ for name, nAtoms in residues:
83
+ for i in range(nAtoms):
84
+ atomwise.append(name)
85
+ return atomwise
86
+
87
+ def get_begin_atom_index(traj):
88
+ natoms = [m.n_atoms for m in traj.top.mols]
89
+ molecule_begin_atom_index = [0]
90
+ x = 0
91
+ for i in range(len(natoms)):
92
+ x += natoms[i]
93
+ molecule_begin_atom_index.append(x)
94
+ print('molecule begin atom index', molecule_begin_atom_index, natoms)
95
+ return molecule_begin_atom_index
96
+
97
+ def get_traj_info(traj, mapPath):
98
+ coordinates = traj.xyz
99
+ residueMap, nameMap, typeMap, elementMap = get_maps(mapPath)
100
+ types = [typeMap[a.type] for a in traj.top.atoms]
101
+ elements = [elementMap[typ] for typ in types]
102
+ atomic_numbers = [a.atomic_number for a in traj.top.atoms]
103
+ molecule_begin_atom_index = get_begin_atom_index(traj)
104
+ residues = [(residueMap[res.name], res.n_atoms) for res in traj.top.residues]
105
+ residues_atomwise = get_residues_atomwise(residues)
106
+ return coordinates[0], elements, types, atomic_numbers, residues_atomwise, molecule_begin_atom_index
107
+
108
+ def write_h5_info(outName, struct, atoms_type, atoms_number, atoms_residue, atoms_element, molecules_begin_atom_index, atoms_coordinates_ref):
109
+ if os.path.isfile(outName):
110
+ os.remove(outName)
111
+ with h5py.File(outName, 'w') as oF:
112
+ subgroup = oF.create_group(struct)
113
+ subgroup.create_dataset('atoms_residue', data= atoms_residue, compression = "gzip", dtype='i8')
114
+ subgroup.create_dataset('molecules_begin_atom_index', data= molecules_begin_atom_index, compression = "gzip", dtype='i8')
115
+ subgroup.create_dataset('atoms_type', data= atoms_type, compression = "gzip", dtype='i8')
116
+ subgroup.create_dataset('atoms_number', data= atoms_number, compression = "gzip", dtype='i8')
117
+ subgroup.create_dataset('atoms_element', data= atoms_element, compression = "gzip", dtype='i8')
118
+ subgroup.create_dataset('atoms_coordinates_ref', data= atoms_coordinates_ref, compression = "gzip", dtype='f8')
119
+
120
+ def preprocess(pdbid: str = None, ouputfile: str = "inference_for_md.hdf5", mask: str = "!@H=", mappath: str = "/maps/"):
121
+ traj = convert_to_amber_format(pdbid)
122
+ atoms_coordinates_ref, atoms_element, atoms_type, atoms_number, atoms_residue, molecules_begin_atom_index = get_traj_info(traj[mask], mappath)
123
+ write_h5_info(ouputfile, pdbid, atoms_type, atoms_number, atoms_residue, atoms_element, molecules_begin_atom_index, atoms_coordinates_ref)
124
 
125
  def get_pdb(pdb_code="", filepath=""):
126
  try:
 
140
  return int(line[22:27])
141
 
142
 
143
+ def get_pdbid_from_filename(filename: str):
144
+ # Assuming the filename would be of the standard form 11GS.pdb
145
+ return filename.split(".")[0]
146
+
147
  def predict(pdb_code, pdb_file):
148
+ #path_to_pdb = get_pdb(pdb_code=pdb_code, filepath=pdb_file)
149
+
150
+ #pdb = open(path_to_pdb, "r").read()
151
+ # switch to misato env if not running from container
152
+
153
+ pdbid = get_pdbid_from_filename(pdb_file)
154
  mdh5_file = "inference_for_md.hdf5"
155
+ mappath = "/maps"
156
+ mask = "!@H="
157
+ preprocess(pdbid=pdbid, ouputfile=mdh5_file, mask=mask, mappath=mappath)
158
+
159
  md_H5File = h5py.File(mdh5_file)
160
 
161
  column_names = ["x", "y", "z", "element"]
162
  atoms_protein = pd.DataFrame(columns = column_names)
163
+ cutoff = md_H5File[pdbid]["molecules_begin_atom_index"][:][-1] # cutoff defines protein atoms
164
 
165
+ atoms_protein["x"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 0]
166
+ atoms_protein["y"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 1]
167
+ atoms_protein["z"] = md_H5File[pdbid]["atoms_coordinates_ref"][:][:cutoff, 2]
168
 
169
+ atoms_protein["element"] = md_H5File[pdbid]["atoms_element"][:][:cutoff]
170
 
171
  item = {}
172
  item["scores"] = 0
173
+ item["id"] = pdbid
174
  item["atoms_protein"] = atoms_protein
175
 
176
  transform = GNNTransformMD()
 
187
  topN = 100
188
  topN_ind = np.argsort(adaptability)[::-1][:topN]
189
 
190
+ pdb = open(pdb_file.name, "r").read()
191
 
192
  view = py3Dmol.view(width=600, height=400)
193
  view.setBackgroundColor('white')
 
237
 
238
 
239
  if __name__ == "__main__":
240
+ run()