Space status: Runtime error

Commit: add it to compose

inference.py  CHANGED  +7 -5
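The previous version passed allow_missing_keys=True to monai.data.Dataset, whose constructor (as far as I can tell from MONAI's signature, Dataset(data, transform=None)) does not accept that keyword, so building the dataset would raise a TypeError; that is the likely source of the runtime error above. This commit moves the keyword into the two Compose(...) calls instead and, since no reader annotation exists at inference time, comments out the label lookup so postprocessing runs on the prediction alone.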
@@ -55,7 +55,8 @@ transforms = Compose(
         EnsureTyped(keys=keys),
         ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
         ConcatItemsd(keys=("t2_anatomy_reader1"), name=CommonKeys.LABEL, dim=0),
-    ]
+    ],
+    allow_missing_keys=True,
 )
 
 
@@ -68,7 +69,8 @@ postprocessing = Compose(
             keys=CommonKeys.PRED,
             applied_labels=list(range(1, 3))
         ),
-    ]
+    ],
+    allow_missing_keys=True,
 )
 inferer = monai.inferers.SlidingWindowInferer(
     roi_size=(96, 96, 96),
@@ -133,19 +135,19 @@ def make_inference(data_dict:list) -> str:
     test_ds = Dataset(
         data=data_dict,
         transform=transforms,
-        allow_missing_keys=True,
     )
     model.eval()
     with torch.no_grad():
         example = test_ds[0]
-        label = example["t2_anatomy_reader1"]
+        # label = example["t2_anatomy_reader1"]
         input_tensor = example["t2"].unsqueeze(0)
         input_tensor = input_tensor.to(device)
         output_tensor = inferer(input_tensor, model)
         output_tensor = output_tensor.argmax(dim=1, keepdim=False)
         output_tensor = output_tensor.squeeze(0).to(torch.device("cpu"))
 
-        output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
+        # output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
+        output_tensor = postprocessing({"pred": output_tensor})["pred"]
         output_tensor = output_tensor.numpy().astype(np.uint8)
         target_shape = example["t2_meta_dict"]["spatial_shape"]
         output_tensor = resize_image(output_tensor, target_shape)
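A note on the Compose-level keyword: in the MONAI releases I know, allow_missing_keys is a parameter of the individual dictionary transforms (MapTransform subclasses such as EnsureTyped and ConcatItemsd), and I am not certain Compose itself accepts it, so a variant that leans only on the documented MapTransform API would set the flag per transform instead. A minimal sketch under that assumption, reusing the keys from the diff:

    from monai.transforms import Compose, ConcatItemsd, EnsureTyped
    from monai.utils import CommonKeys

    keys = ("t2", "t2_anatomy_reader1")  # assumed from the surrounding file

    transforms = Compose(
        [
            # Tolerate a missing label on each transform that touches it,
            # rather than relying on a Compose-level keyword.
            EnsureTyped(keys=keys, allow_missing_keys=True),
            ConcatItemsd(keys=("t2",), name=CommonKeys.IMAGE, dim=0),
            ConcatItemsd(
                keys=("t2_anatomy_reader1",),
                name=CommonKeys.LABEL,
                dim=0,
                allow_missing_keys=True,
            ),
        ]
    )

Incidentally, keys=("t2") in the diff is the plain string "t2", not a one-element tuple (that would be ("t2",)); MONAI key arguments accept either a single string or a sequence, so both spell the same thing here.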
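On the make_inference side, monai.data.Dataset takes only data and transform, which is why the stray keyword had to move rather than stay. Note also that the literal "pred" key matches CommonKeys.PRED (the string "pred"), so the label-free call still hands the postprocessing transforms the key they expect. A short sketch, with data_dict, transforms, postprocessing, and output_tensor assumed from the rest of inference.py:

    from monai.data import Dataset
    from monai.utils import CommonKeys

    # Dataset has no allow_missing_keys parameter; missing-key tolerance
    # belongs to the transforms, not to the dataset wrapper.
    test_ds = Dataset(data=data_dict, transform=transforms)

    # CommonKeys.PRED == "pred", so this is equivalent to the literal
    # {"pred": output_tensor} used in the diff.
    pred = postprocessing({CommonKeys.PRED: output_tensor})[CommonKeys.PRED]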
|