MingGatsby commited on
Commit
fd99db1
1 Parent(s): 46f8c5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -153,6 +153,9 @@ torch.manual_seed(SEED)
153
 
154
  # Parameters
155
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
156
 
157
  def load_model(root_dir, model_name, model_file_name):
158
  if CUSTOM_MODEL_FLAG:
@@ -259,7 +262,7 @@ if uploaded_mri_file is not None:
259
 
260
  # Build the CAM (Class Activation Map)
261
  target_layers = [mri_model.model.norm]
262
- cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
263
  grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
264
  grayscale_cam = grayscale_cam[0, :]
265
 
 
153
 
154
  # Parameters
155
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
156
+ USE_CUDA = False
157
+ if device == torch.device("cuda"):
158
+ USE_CUDA = True
159
 
160
  def load_model(root_dir, model_name, model_file_name):
161
  if CUSTOM_MODEL_FLAG:
 
262
 
263
  # Build the CAM (Class Activation Map)
264
  target_layers = [mri_model.model.norm]
265
+ cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=USE_CUDA)
266
  grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
267
  grayscale_cam = grayscale_cam[0, :]
268