MingGatsby commited on
Commit
bf8ba85
1 Parent(s): 83997ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -143,17 +143,17 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
  ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
144
  mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY
145
 
146
- def load_model(root_dir, model_name, model_file_name, map_location):
147
  if CUSTOM_MODEL_FLAG:
148
  model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
149
  else:
150
  model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
151
- model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=map_location)))
152
  model.eval()
153
  return model
154
 
155
- ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME, device)
156
- mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME, device)
157
  if LIST_MODEL_MODULES:
158
  for ct_name, _ in ct_model.named_modules():
159
  print(ct_name)
 
143
  ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
144
  mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY
145
 
146
+ def load_model(root_dir, model_name, model_file_name):
147
  if CUSTOM_MODEL_FLAG:
148
  model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
149
  else:
150
  model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
151
+ model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=device)))
152
  model.eval()
153
  return model
154
 
155
+ ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
156
+ mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)
157
  if LIST_MODEL_MODULES:
158
  for ct_name, _ in ct_model.named_modules():
159
  print(ct_name)