MingGatsby
commited on
Commit
•
bf8ba85
1
Parent(s):
83997ab
Update app.py
Browse files
app.py
CHANGED
@@ -143,17 +143,17 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
143 |
ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
|
144 |
mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY
|
145 |
|
146 |
-
def load_model(root_dir, model_name, model_file_name
|
147 |
if CUSTOM_MODEL_FLAG:
|
148 |
model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
|
149 |
else:
|
150 |
model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
|
151 |
-
model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=
|
152 |
model.eval()
|
153 |
return model
|
154 |
|
155 |
-
ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME
|
156 |
-
mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME
|
157 |
if LIST_MODEL_MODULES:
|
158 |
for ct_name, _ in ct_model.named_modules():
|
159 |
print(ct_name)
|
|
|
143 |
ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
|
144 |
mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY
|
145 |
|
146 |
+
def load_model(root_dir, model_name, model_file_name):
|
147 |
if CUSTOM_MODEL_FLAG:
|
148 |
model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
|
149 |
else:
|
150 |
model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
|
151 |
+
model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=device)))
|
152 |
model.eval()
|
153 |
return model
|
154 |
|
155 |
+
ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
|
156 |
+
mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)
|
157 |
if LIST_MODEL_MODULES:
|
158 |
for ct_name, _ in ct_model.named_modules():
|
159 |
print(ct_name)
|