Spaces:
Running
Running
yjwtheonly
commited on
Commit
•
f6678bd
1
Parent(s):
ac8e861
modification
Browse files- DiseaseSpecific/utils.py +2 -2
DiseaseSpecific/utils.py
CHANGED
@@ -71,13 +71,13 @@ def load_model(model_path, args, n_ent, n_rel, device):
|
|
71 |
model = add_model(args, n_ent, n_rel)
|
72 |
model.to(device)
|
73 |
logger.info('Loading saved model from {0}'.format(model_path))
|
74 |
-
state = torch.load(model_path)
|
75 |
model_params = state['state_dict']
|
76 |
params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
|
77 |
for key, size, count in params:
|
78 |
logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
|
79 |
|
80 |
-
model.load_state_dict(model_params
|
81 |
model.eval()
|
82 |
logger.info(model)
|
83 |
|
|
|
71 |
model = add_model(args, n_ent, n_rel)
|
72 |
model.to(device)
|
73 |
logger.info('Loading saved model from {0}'.format(model_path))
|
74 |
+
state = torch.load(model_path, map_location=device)
|
75 |
model_params = state['state_dict']
|
76 |
params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
|
77 |
for key, size, count in params:
|
78 |
logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
|
79 |
|
80 |
+
model.load_state_dict(model_params)
|
81 |
model.eval()
|
82 |
logger.info(model)
|
83 |
|