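"""SegVol demo inference script.

Runs zoom-out (global, low-resolution) and zoom-in (ROI sliding-window)
segmentation on a single CT volume using text / box / point prompts, and
reports Dice scores against the ground-truth labels.
"""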
import argparse
import os
import torch
import torch.nn.functional as F
import json
from segment_anything_volumetric import sam_model_registry
from network.model import SegVol
from data_process.demo_data_process import process_ct_gt
import monai.transforms as transforms
from utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
from utils.visualize import draw_result
def set_parse():
    # %% set up parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_mode", default=True, type=bool)
    parser.add_argument("--resume", type=str, default='')
    parser.add_argument("-infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
    parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
    parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
    parser.add_argument('-work_dir', type=str, default='./work_dir')
    ### demo
    parser.add_argument('--demo_config', type=str, required=True)
    parser.add_argument("--clip_ckpt", type=str, default='./config/clip')
    args = parser.parse_args()
    return args


def dice_score(preds, labels): # on GPU
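    """Compute a smoothed Dice score between raw logits and binary labels.

    `preds` is passed through a sigmoid and thresholded at 0.5 before the
    overlap is computed; the denominator carries a +1 smoothing term.
    Tensors with fewer than 1e8 elements are moved to the GPU for the
    computation and back to CPU afterwards.
    """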
    assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
    predict = preds.view(1, -1)
    target = labels.view(1, -1)
    if target.shape[1] < 1e8:
        predict = predict.cuda()
        target = target.cuda()
    predict = torch.sigmoid(predict)
    predict = torch.where(predict > 0.5, 1., 0.)
    tp = torch.sum(torch.mul(predict, target))
    den = torch.sum(predict) + torch.sum(target) + 1
    dice = 2 * tp / den
    if target.shape[1] < 1e8:
        predict = predict.cpu()
        target = target.cpu()
    return dice


def zoom_in_zoom_out(args, segvol_model, image, image_resize, gt3D, gt3D_resize, categories=None):
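    """Two-stage (zoom-out / zoom-in) inference for every requested category.

    Zoom-out: run the model on the resized global volume with the selected
    text / box / point prompts, then resize the logits back to the original
    image shape. Zoom-in (if args.use_zoom_in): locate a foreground ROI in
    the global logits, crop it from the full-resolution image, and refine it
    with sliding-window inference, carrying the prompts over as a binary
    "prompt reflection" volume. Returns a dict mapping each category name to
    (dice, image, point prompt, box prompt, logits, label).
    """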
    logits_labels_record = {}
    image_single_resize = image_resize
    image_single = image[0,0]
    ori_shape = image_single.shape
    for item_idx in range(len(categories)):
        # get label to generate prompts
        label_single = gt3D[0][item_idx]
        label_single_resize = gt3D_resize[0][item_idx]
        # skip meaningless categories
        if torch.sum(label_single) == 0:
            print('No object, skip')
            continue
        # generate prompts
        text_single = categories[item_idx] if args.use_text_prompt else None
        if categories is not None: print(f'inference |{categories[item_idx]}| target...')
        points_single = None
        box_single = None
        if args.use_point_prompt:
            point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
            points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
            binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
        if args.use_box_prompt:
            box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
            binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)

        ####################
        # zoom-out inference:
        print('--- zoom out inference ---')
        print(f'use text-prompt [{text_single is not None}], use box-prompt [{box_single is not None}], use point-prompt [{points_single is not None}]')
        with torch.no_grad():
            logits_global_single = segvol_model(image_single_resize.cuda(),
                                                text=text_single,
                                                boxes=box_single,
                                                points=points_single)
        # resize back global logits
        logits_global_single = F.interpolate(
            logits_global_single.cpu(),
            size=ori_shape, mode='nearest')[0][0]

        # build prompt reflection for zoom-in
        if args.use_point_prompt:
            binary_points = F.interpolate(
                binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
                size=ori_shape, mode='nearest')[0][0]
        if args.use_box_prompt:
            binary_cube = F.interpolate(
                binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
                size=ori_shape, mode='nearest')[0][0]

        zoom_out_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
        logits_labels_record[categories[item_idx]] = (
            zoom_out_dice,
            image_single,
            points_single,
            box_single,
            logits_global_single,
            label_single)
        print(f'zoom out inference done with zoom_out_dice: {zoom_out_dice:.4f}')

        if not args.use_zoom_in:
            continue

        ####################
        # zoom-in inference:
        min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
        if min_d is None:
            print('Fail to detect foreground!')
            continue

        # Crop roi
        image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
        global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]) > 0.5).long()

        assert not (args.use_box_prompt and args.use_point_prompt)
        # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
        prompt_reflection = None
        if args.use_box_prompt:
            binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
            prompt_reflection = (
                binary_cube_cropped.unsqueeze(0).unsqueeze(0),
                global_preds.unsqueeze(0).unsqueeze(0)
            )
        if args.use_point_prompt:
            binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
            prompt_reflection = (
                binary_points_cropped.unsqueeze(0).unsqueeze(0),
                global_preds.unsqueeze(0).unsqueeze(0)
            )

        ## inference
        with torch.no_grad():
            logits_single_cropped = sliding_window_inference(
                image_single_cropped.cuda(), prompt_reflection,
                args.spatial_size, 1, segvol_model, args.infer_overlap,
                text=text_single,
                use_box=args.use_box_prompt,
                use_point=args.use_point_prompt,
            )
            logits_single_cropped = logits_single_cropped.cpu().squeeze()
        logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped

        zoom_in_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
        logits_labels_record[categories[item_idx]] = (
            zoom_in_dice,
            image_single,
            points_single,
            box_single,
            logits_global_single,
            label_single)
        print(f'===> zoom out dice {zoom_out_dice:.4f} -> zoom-out-zoom-in dice {zoom_in_dice:.4f} <===')
    return logits_labels_record


def inference_single_ct(args, segvol_model, data_item, categories):
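    """Run zoom-in-zoom-out inference on one preprocessed CT data item and,
    if args.visualize is set, draw the per-category results with draw_result.
    """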
    segvol_model.eval()
    image, gt3D = data_item["image"].float(), data_item["label"]
    image_zoom_out, gt3D_zoom_out = data_item["zoom_out_image"].float(), data_item['zoom_out_label']
    logits_labels_record = zoom_in_zoom_out(
        args, segvol_model,
        image.unsqueeze(0), image_zoom_out.unsqueeze(0),
        gt3D.unsqueeze(0), gt3D_zoom_out.unsqueeze(0),  # add batch dim
        categories=categories)

    # visualize
    if args.visualize:
        for target, values in logits_labels_record.items():
            dice_score, image, point_prompt, box_prompt, logits, labels = values
            print(f'{target} result with Dice score {dice_score:.4f} visualizing')
            draw_result(target + f"-Dice {dice_score:.4f}", image, box_prompt, point_prompt, logits, labels, args.spatial_size, args.work_dir)


def main(args):
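    """Build the SegVol model, load the checkpoint given by --resume, read the
    demo config (CT path, GT path, categories), preprocess the case, and run
    single-CT inference with text + box prompts and zoom-in enabled.
    """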
    gpu = 0
    torch.cuda.set_device(gpu)
    # build model
    sam_model = sam_model_registry['vit'](args=args)
    segvol_model = SegVol(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
        clip_ckpt=args.clip_ckpt,
        roi_size=args.spatial_size,
        patch_size=args.patch_size,
        test_mode=args.test_mode,
    ).cuda()
    segvol_model = torch.nn.DataParallel(segvol_model, device_ids=[gpu])

    # load param
    if os.path.isfile(args.resume):
        ## Map model to be loaded to specified single GPU
        loc = 'cuda:{}'.format(gpu)
        checkpoint = torch.load(args.resume, map_location=loc)
        segvol_model.load_state_dict(checkpoint['model'], strict=True)
        print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # load demo config
    with open(args.demo_config, 'r') as file:
        config_dict = json.load(file)
    ct_path, gt_path, categories = config_dict['demo_case']['ct_path'], config_dict['demo_case']['gt_path'], config_dict['categories']

    # preprocess for data
    data_item = process_ct_gt(ct_path, gt_path, categories, args.spatial_size)  # keys: image, label

    # seg config for prompt & zoom-in-zoom-out
    args.use_zoom_in = True
    args.use_text_prompt = True
    args.use_box_prompt = True
    args.use_point_prompt = False
    args.visualize = False

    inference_single_ct(args, segvol_model, data_item, categories)


if __name__ == "__main__":
    args = set_parse()
    main(args)