igashov committed on
Commit
c1152c1
·
1 Parent(s): bc1ef42

Handle NaN values if linker size is small

Browse files
Files changed (2) hide show
  1. app.py +43 -32
  2. src/egnn.py +10 -3
app.py CHANGED
@@ -40,19 +40,6 @@ if not os.path.exists(diffusion_path):
40
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
41
  print('Loaded diffusion model')
42
 
43
-
44
- def sample_fn(_data):
45
- output, _ = size_nn.forward(_data, return_loss=False)
46
- probabilities = torch.softmax(output, dim=1)
47
- distribution = torch.distributions.Categorical(probs=probabilities)
48
- samples = distribution.sample()
49
- sizes = []
50
- for label in samples.detach().cpu().numpy():
51
- sizes.append(size_nn.linker_id2size[label])
52
- sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
53
- return sizes
54
-
55
-
56
  def read_molecule_content(path):
57
  with open(path, "r") as f:
58
  return "".join(f.readlines())
@@ -72,7 +59,7 @@ def read_molecule(path):
72
 
73
  def show_input(input_file):
74
  if input_file is None:
75
- return ''
76
  if isinstance(input_file, str):
77
  path = input_file
78
  else:
@@ -80,15 +67,24 @@ def show_input(input_file):
80
  extension = path.split('.')[-1]
81
  if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
82
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
83
- return output.IFRAME_TEMPLATE.format(html=msg)
 
 
 
84
 
85
  try:
86
  molecule = read_molecule_content(path)
87
  except Exception as e:
88
- return f'Could not read the molecule: {e}'
 
 
 
89
 
90
  html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
91
- return output.IFRAME_TEMPLATE.format(html=html)
 
 
 
92
 
93
 
94
  def draw_sample(idx, out_files):
@@ -109,7 +105,7 @@ def draw_sample(idx, out_files):
109
  return output.IFRAME_TEMPLATE.format(html=html)
110
 
111
 
112
- def generate(input_file, n_steps):
113
  if input_file is None:
114
  return ''
115
 
@@ -156,6 +152,21 @@ def generate(input_file, n_steps):
156
  ddpm.edm.T = n_steps
157
  assert ddpm.center_of_mass == 'fragments'
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  for data in dataloader:
160
  chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
161
  print('Generated linker')
@@ -208,6 +219,11 @@ with demo:
208
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
209
  input_file = gr.File(file_count='single', label='Input Fragments')
210
  n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Denoising Steps", step=10)
 
 
 
 
 
211
  examples = gr.Dataset(
212
  components=[gr.File(visible=False)],
213
  samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
@@ -235,21 +251,21 @@ with demo:
235
  input_file.change(
236
  fn=show_input,
237
  inputs=[input_file],
238
- outputs=[visualization],
 
 
 
 
 
239
  )
240
  examples.click(
241
- fn=lambda idx: [
242
- f'examples/example_{idx+1}.sdf',
243
- 10,
244
- show_input(f'examples/example_{idx+1}.sdf'),
245
- gr.Radio(value='Sample 1', visible=False)
246
- ],
247
  inputs=[examples],
248
- outputs=[input_file, n_steps, visualization, samples]
249
  )
250
  button.click(
251
  fn=generate,
252
- inputs=[input_file, n_steps],
253
  outputs=[visualization, output_files, samples],
254
  )
255
  samples.change(
@@ -257,10 +273,5 @@ with demo:
257
  inputs=[samples, output_files],
258
  outputs=[visualization],
259
  )
260
- input_file.clear(
261
- fn=lambda: ['', gr.Radio(value='Sample 1', visible=False)],
262
- inputs=[],
263
- outputs=[visualization, samples],
264
- )
265
 
266
  demo.launch(server_name=args.ip)
 
40
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
41
  print('Loaded diffusion model')
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def read_molecule_content(path):
44
  with open(path, "r") as f:
45
  return "".join(f.readlines())
 
59
 
60
  def show_input(input_file):
61
  if input_file is None:
62
+ return ['', gr.Radio.update(visible=False, value='Sample 1')]
63
  if isinstance(input_file, str):
64
  path = input_file
65
  else:
 
67
  extension = path.split('.')[-1]
68
  if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
69
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
70
+ return [
71
+ output.IFRAME_TEMPLATE.format(html=msg),
72
+ gr.Radio.update(visible=False)
73
+ ]
74
 
75
  try:
76
  molecule = read_molecule_content(path)
77
  except Exception as e:
78
+ return [
79
+ f'Could not read the molecule: {e}',
80
+ gr.Radio.update(visible=False)
81
+ ]
82
 
83
  html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
84
+ return [
85
+ output.IFRAME_TEMPLATE.format(html=html),
86
+ gr.Radio.update(visible=False)
87
+ ]
88
 
89
 
90
  def draw_sample(idx, out_files):
 
105
  return output.IFRAME_TEMPLATE.format(html=html)
106
 
107
 
108
+ def generate(input_file, n_steps, n_atoms):
109
  if input_file is None:
110
  return ''
111
 
 
152
  ddpm.edm.T = n_steps
153
  assert ddpm.center_of_mass == 'fragments'
154
 
155
+ if n_atoms == 0:
156
+ def sample_fn(_data):
157
+ out, _ = size_nn.forward(_data, return_loss=False)
158
+ probabilities = torch.softmax(out, dim=1)
159
+ distribution = torch.distributions.Categorical(probs=probabilities)
160
+ samples = distribution.sample()
161
+ sizes = []
162
+ for label in samples.detach().cpu().numpy():
163
+ sizes.append(size_nn.linker_id2size[label])
164
+ sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
165
+ return sizes
166
+ else:
167
+ def sample_fn(_data):
168
+ return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
169
+
170
  for data in dataloader:
171
  chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
172
  print('Generated linker')
 
219
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
220
  input_file = gr.File(file_count='single', label='Input Fragments')
221
  n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Denoising Steps", step=10)
222
+ n_atoms = gr.Slider(
223
+ minimum=0, maximum=20,
224
+ label="Linker Size: DiffLinker will predict it if set to 0",
225
+ step=1
226
+ )
227
  examples = gr.Dataset(
228
  components=[gr.File(visible=False)],
229
  samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
 
251
  input_file.change(
252
  fn=show_input,
253
  inputs=[input_file],
254
+ outputs=[visualization, samples],
255
+ )
256
+ input_file.clear(
257
+ fn=lambda: [None, '', gr.Radio.update(visible=False)],
258
+ inputs=[],
259
+ outputs=[input_file, visualization, samples],
260
  )
261
  examples.click(
262
+ fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, 0] + show_input(f'examples/example_{idx+1}.sdf'),
 
 
 
 
 
263
  inputs=[examples],
264
+ outputs=[input_file, n_steps, n_atoms, visualization, samples]
265
  )
266
  button.click(
267
  fn=generate,
268
+ inputs=[input_file, n_steps, n_atoms],
269
  outputs=[visualization, output_files, samples],
270
  )
271
  samples.change(
 
273
  inputs=[samples, output_files],
274
  outputs=[visualization],
275
  )
 
 
 
 
 
276
 
277
  demo.launch(server_name=args.ip)
src/egnn.py CHANGED
@@ -421,13 +421,20 @@ class Dynamics(nn.Module):
421
  if self.condition_time:
422
  h_final = h_final[:, :-1]
423
 
 
 
 
 
 
 
 
 
 
 
424
  vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
425
  h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
426
  node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
427
 
428
- if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
429
- raise utils.FoundNaNException(vel, h_final)
430
-
431
  if self.centering:
432
  vel = utils.remove_mean_with_mask(vel, node_mask)
433
 
 
421
  if self.condition_time:
422
  h_final = h_final[:, :-1]
423
 
424
+ if torch.any(torch.isnan(vel)):
425
+ print('Found NaN values in velocities')
426
+ nan_mask = torch.isnan(vel).float()
427
+ vel = x * nan_mask + torch.nan_to_num(vel) * (1 - nan_mask)
428
+
429
+ if torch.any(torch.isnan(h_final)):
430
+ print('Found NaN values in features')
431
+ nan_mask = torch.isnan(h_final).float()
432
+ h_final = h[:, :h_final.shape[1]] * nan_mask + torch.nan_to_num(h_final) * (1 - nan_mask)
433
+
434
  vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
435
  h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
436
  node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
437
 
 
 
 
438
  if self.centering:
439
  vel = utils.remove_mean_with_mask(vel, node_mask)
440