yjwtheonly commited on
Commit
ac8e861
1 Parent(s): 0574a67

modification

Browse files
Files changed (2) hide show
  1. DiseaseSpecific/utils.py +1 -1
  2. server.py +2 -1
DiseaseSpecific/utils.py CHANGED
@@ -77,7 +77,7 @@ def load_model(model_path, args, n_ent, n_rel, device):
77
  for key, size, count in params:
78
  logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
79
 
80
- model.load_state_dict(model_params)
81
  model.eval()
82
  logger.info(model)
83
 
 
77
  for key, size, count in params:
78
  logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
79
 
80
+ model.load_state_dict(model_params, map_location=device)
81
  model.eval()
82
  logger.info(model)
83
 
server.py CHANGED
@@ -12,7 +12,8 @@ import spacy
12
  # os.system("python -m spacy download en-core-web-sm")
13
  import pickle as pkl
14
  #%%
15
-
 
16
  from torch.nn.modules.loss import CrossEntropyLoss
17
  from transformers import AutoTokenizer
18
  from transformers import BioGptForCausalLM, BartForConditionalGeneration
 
12
  # os.system("python -m spacy download en-core-web-sm")
13
  import pickle as pkl
14
  #%%
15
+ # please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
16
+ # torch.loa
17
  from torch.nn.modules.loss import CrossEntropyLoss
18
  from transformers import AutoTokenizer
19
  from transformers import BioGptForCausalLM, BartForConditionalGeneration