yjwtheonly commited on
Commit
f6678bd
1 Parent(s): ac8e861

modification

Browse files
Files changed (1) hide show
  1. DiseaseSpecific/utils.py +2 -2
DiseaseSpecific/utils.py CHANGED
@@ -71,13 +71,13 @@ def load_model(model_path, args, n_ent, n_rel, device):
71
  model = add_model(args, n_ent, n_rel)
72
  model.to(device)
73
  logger.info('Loading saved model from {0}'.format(model_path))
74
- state = torch.load(model_path)
75
  model_params = state['state_dict']
76
  params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
77
  for key, size, count in params:
78
  logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
79
 
80
- model.load_state_dict(model_params, map_location=device)
81
  model.eval()
82
  logger.info(model)
83
 
 
71
  model = add_model(args, n_ent, n_rel)
72
  model.to(device)
73
  logger.info('Loading saved model from {0}'.format(model_path))
74
+ state = torch.load(model_path, map_location=device)
75
  model_params = state['state_dict']
76
  params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
77
  for key, size, count in params:
78
  logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
79
 
80
+ model.load_state_dict(model_params)
81
  model.eval()
82
  logger.info(model)
83