TillCyrill commited on
Commit
6646de2
1 Parent(s): 1372a90
Files changed (1) hide show
  1. main.py +2 -2
main.py CHANGED
@@ -42,10 +42,10 @@ atom_mapping = {0:'H', 1:'C', 2:'N', 3:'O', 4:'F', 5:'P', 6:'S', 7:'CL', 8:'BR',
42
  model = GNN_MD(11, 64)
43
  state_dict = torch.load(
44
  "best_weights_rep0.pt",
45
- map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
46
  )["model_state_dict"]
47
  model.load_state_dict(state_dict)
48
- model = model.to("cuda" if torch.cuda.is_available() else "cpu")
49
  model.eval()
50
 
51
 
 
42
  model = GNN_MD(11, 64)
43
  state_dict = torch.load(
44
  "best_weights_rep0.pt",
45
+ map_location=torch.device("cpu"),
46
  )["model_state_dict"]
47
  model.load_state_dict(state_dict)
48
+ model = model.to('cpu')
49
  model.eval()
50
 
51