# File size: 12,867 bytes; revision 91ef820.
import argparse
import numpy as np
import torch
# import datasets
import utils.misc as misc
from utils.box_utils import xywh2xyxy
from utils.visual_bbox import visualBBox
# from models import build_model
import transforms as T
import PIL.Image as Image
import data_loader
from transformers import AutoTokenizer
def get_args_parser():
    """Create the command-line option parser used by this script.

    The parser is built with ``add_help=False`` so it can be supplied through
    ``parents=[...]`` to another ``ArgumentParser`` that owns ``-h/--help``.

    Returns:
        argparse.ArgumentParser: parser holding the optimisation, model,
        dataset and checkpoint options registered below.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    add = parser.add_argument  # short alias; every option is registered through it

    # ---- optimisation / schedule ----
    add('--lr', type=float, default=1e-4)
    add('--lr_bert', type=float, default=0.)
    add('--lr_visu_cnn', type=float, default=0.)
    add('--lr_visu_tra', type=float, default=1e-5)
    add('--batch_size', type=int, default=32)
    add('--weight_decay', type=float, default=1e-4)
    add('--epochs', type=int, default=100)
    add('--lr_power', type=float, default=0.9, help='lr poly power')
    add('--clip_max_norm', type=float, default=0.,
        help='gradient clipping max norm')
    add('--eval', dest='eval', action='store_true', default=False,
        help='if evaluation only')
    add('--optimizer', type=str, default='rmsprop')
    add('--lr_scheduler', type=str, default='poly')
    add('--lr_drop', type=int, default=80)

    # ---- model selection ----
    add('--model_name', type=str, default='TransVG_ca',
        help="Name of model to be exploited.")

    # ---- encoder depth of the two branches ----
    add('--bert_enc_num', type=int, default=12)
    add('--detr_enc_num', type=int, default=6)

    # ---- DETR: backbone ----
    add('--backbone', type=str, default='resnet50',
        help="Name of the convolutional backbone to use")
    add('--dilation', action='store_true',
        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    add('--position_embedding', type=str, default='sine',
        choices=('sine', 'learned'),
        help="Type of positional embedding to use on top of the image features")

    # ---- DETR: transformer ----
    add('--enc_layers', type=int, default=6,
        help="Number of encoding layers in the transformer")
    add('--dec_layers', type=int, default=0,
        help="Number of decoding layers in the transformer")
    add('--dim_feedforward', type=int, default=2048,
        help="Intermediate size of the feedforward layers in the transformer blocks")
    add('--hidden_dim', type=int, default=256,
        help="Size of the embeddings (dimension of the transformer)")
    add('--dropout', type=float, default=0.1,
        help="Dropout applied in the transformer")
    add('--nheads', type=int, default=8,
        help="Number of attention heads inside the transformer's attentions")
    add('--num_queries', type=int, default=100,
        help="Number of query slots")
    add('--pre_norm', action='store_true')

    add('--imsize', type=int, default=640, help='image size')
    add('--emb_size', type=int, default=512,
        help='fusion module embedding dimensions')

    # ---- vision-language transformer ----
    add('--use_vl_type_embed', action='store_true',
        help="If true, use vl_type embedding")
    add('--vl_dropout', type=float, default=0.1,
        help="Dropout applied in the vision-language transformer")
    add('--vl_nheads', type=int, default=8,
        help="Number of attention heads inside the vision-language transformer's attentions")
    add('--vl_hidden_dim', type=int, default=256,
        help='Size of the embeddings (dimension of the vision-language transformer)')
    add('--vl_dim_feedforward', type=int, default=2048,
        help="Intermediate size of the feedforward layers in the vision-language transformer blocks")
    add('--vl_enc_layers', type=int, default=6,
        help='Number of encoders in the vision-language transformer')

    # ---- dataset ----
    add('--dataset', type=str, default='MS_CXR',
        help='referit/flickr/unc/unc+/gref')
    add('--max_query_len', type=int, default=20,
        help='maximum time steps (lang length) per batch')

    # ---- runtime / paths ----
    add('--output_dir', default='outputs',
        help='path where to save, empty for no saving')
    add('--device', default='cuda',
        help='device to use for training / testing')
    add('--detr_model', type=str, default='./saved_models/detr-r50.pth',
        help='detr model')
    add('--bert_model', type=str, default='bert-base-uncased',
        help='bert model')

    # ---- evaluation ----
    add('--eval_model', type=str,
        default='med_rpg/checkpoint/best_miou_checkpoint.pth')

    return parser
def make_transforms(imsize):
    """Assemble the preprocessing pipeline applied to an input sample.

    Args:
        imsize: size forwarded to ``T.RandomResize`` (as a one-element list)
            and to ``T.NormalizeAndPad`` (as ``size``).

    Returns:
        A ``T.Compose`` chaining resize, tensor conversion and
        normalise-and-pad, in that order.
    """
    steps = [
        T.RandomResize([imsize]),
        T.ToTensor(),
        T.NormalizeAndPad(size=imsize),
    ]
    return T.Compose(steps)
def medical_phrase_grounding(model, tokenizer, orig_img, text, bbox = None):
    """Localise a text phrase in an image and return the visualised result.

    Args:
        model: grounding network. It is switched to eval mode, called as
            ``model(img_data, text_data)`` and must return a mapping with a
            ``'pred_box'`` entry.
        tokenizer: tokenizer forwarded to
            ``data_loader.convert_examples_to_features``.
        orig_img: input image. It is run through the ``make_transforms``
            pipeline and also handed to ``visualBBox``.
            NOTE(review): presumably a PIL image -- confirm against callers.
        text: phrase to ground.
        bbox: optional reference box, forwarded unchanged to ``visualBBox``.
            NOTE(review): ``main()`` converts its xywh box to xyxy before
            calling ``visualBBox``, so this is presumably expected in xyxy
            form here -- confirm against callers.

    Returns:
        Whatever ``visualBBox(orig_img, pred_box, bbox)`` returns.
    """
    image_size = 640  # used for the transform and to rescale the predicted box
    max_query_len = 20

    # Use the device the model already lives on so inputs and weights always
    # agree; fall back to the CUDA-if-available choice when the model exposes
    # no parameters to inspect.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The former ``if bbox is not None:`` branch contained only commented-out
    # xywh->xyxy conversions (an empty, invalid block), so it is dropped and
    # ``bbox`` is passed through untouched.

    ## encode phrase to bert input
    examples = data_loader.read_examples(text, 1)
    features = data_loader.convert_examples_to_features(
        examples=examples, seq_length=max_query_len, tokenizer=tokenizer, usemarker=None)
    word_id = torch.tensor(features[0].input_ids)
    word_mask = torch.tensor(features[0].input_mask)

    ## transform image
    # The pipeline works on a sample dict; an all-zero placeholder box is
    # supplied ("for avoid bug" in the original) since no real box is needed.
    fake_bbox = torch.tensor(np.array([0, 0, 0, 0], dtype=int)).float()
    input_dict = {'img': orig_img, 'box': fake_bbox, 'text': text}
    input_dict = make_transforms(imsize=image_size)(input_dict)
    img = input_dict['img']
    img_mask = input_dict['mask']

    # Add a leading batch dimension of 1 and move both modalities to the device.
    img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0)).to(device)
    text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0)).to(device)

    ## model infer
    model.eval()
    with torch.no_grad():
        outputs = model(img_data, text_data)

    # Convert the prediction with xywh2xyxy, scale by image_size, take the
    # first (only) batch entry and round each coordinate.
    pred_box = xywh2xyxy(outputs['pred_box'].detach().cpu()) * image_size
    pred_box = pred_box.numpy()[0]
    pred_box = [round(pred_box[i]) for i in range(4)]

    output_img = visualBBox(orig_img, pred_box, bbox)
    return output_img
def main(args):
    """Run phrase grounding on a built-in sample case and visualise the result.

    Args:
        args: parsed options. This function reads ``bert_model``,
            ``max_query_len``, ``eval_model`` and (when present) ``imsize``;
            the whole namespace is also passed to ``build_model``.

    Returns:
        None. The predicted and reference boxes are handed to ``visualBBox``.
    """
    # The module-level ``from models import build_model`` is commented out,
    # which left ``build_model`` undefined here (NameError). Import it lazily
    # so that importing this module does not require the model package.
    from models import build_model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Follow --imsize (default 640) so the transform and box rescaling use the
    # same value that is passed to build_model via ``args``.
    image_size = getattr(args, "imsize", 640)

    ## build data
    # case1
    img_path = "images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"
    phrase = 'Small left apical pneumothorax'
    bbox = [332, 28, 141, 48]  # xywh
    # Other sample cases (image path / phrase / xywh box) that can be swapped in:
    #   case2: 'files/p10/p10977201/s59062881/00363400-cee06fa7-8c2ca1f7-2678a170-b3a62a6e.jpg'
    #          'small apical pneumothorax', [161, 134, 111, 37]
    #   case3: 'files/p18/p18426683/s59612243/95423e8e-45dff550-563d3eba-b8bc94be-a87f5a1d.jpg'
    #          'cardiac silhouette enlarged', [196, 312, 371, 231]
    #   case4: 'files/p10/p10048451/s53489305/4b7f7a4c-18c39245-53724c25-06878595-7e41bb94.jpg'
    #          'Focal opacity in the lingular lobe', [467, 373, 131, 189]
    #   case5: 'files/p19/p19757720/s59572378/13255e1f-91b7b172-02baaeee-340ec493-0e531681.jpg'
    #          'multisegmental right upper lobe consolidation is present', [9, 86, 232, 278]
    #   case6: 'files/p10/p10469621/s56786891/04e10148-c36f7afb-d0aaf964-152d8a5d-a02ab550.jpg'
    #          'right middle lobe opacity, suspicious for pneumonia in the proper clinical setting',
    #          [108, 405, 162, 83]
    #   case7: 'files/p10/p10670818/s50191454/1176839d-cf4f677f-d597a1ef-548bc32a-c05429f3.jpg'
    #          'Newly appeared lingular opacity', [392, 297, 141, 151]

    # xywh -> xyxy for the reference box given to visualBBox.
    x, y, w, h = bbox
    bbox = [x, y, x + w, y + h]

    ## encode phrase to bert input
    examples = data_loader.read_examples(phrase, 1)
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
    features = data_loader.convert_examples_to_features(
        examples=examples, seq_length=args.max_query_len, tokenizer=tokenizer, usemarker=None)
    word_id = torch.tensor(features[0].input_ids)
    word_mask = torch.tensor(features[0].input_mask)

    ## read and transform image
    # ``convert`` returns a new image, so the file can be closed right away
    # instead of leaking the handle.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    # The pipeline works on a sample dict; an all-zero placeholder box is
    # supplied ("for avoid bug" in the original) since no real box is needed.
    fake_bbox = torch.tensor(np.array([0, 0, 0, 0], dtype=int)).float()
    input_dict = {'img': img, 'box': fake_bbox, 'text': phrase}
    input_dict = make_transforms(imsize=image_size)(input_dict)
    img = input_dict['img']
    img_mask = input_dict['mask']

    # Add a leading batch dimension of 1 to both modalities.
    img_data = misc.NestedTensor(img.unsqueeze(0), img_mask.unsqueeze(0))
    text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0))

    ## build model
    model = build_model(args)
    model.to(device)
    # NOTE(review): torch.load unpickles the checkpoint file; only point
    # --eval_model at trusted checkpoints.
    checkpoint = torch.load(args.eval_model, map_location='cpu')
    model.load_state_dict(checkpoint['model'])

    ## model infer
    img_data = img_data.to(device)
    text_data = text_data.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(img_data, text_data)

    # Convert the prediction with xywh2xyxy, scale by image_size, take the
    # first (only) batch entry and round each coordinate.
    pred_box = xywh2xyxy(outputs['pred_box'].detach().cpu()) * image_size
    pred_box = pred_box.numpy()[0]
    pred_box = [round(pred_box[i]) for i in range(4)]

    visualBBox(img_path, pred_box, bbox)
if __name__ == '__main__':
    # Script entry point: wrap the shared options in a parser that supplies
    # -h/--help, parse the command line and run the demo.
    cli = argparse.ArgumentParser(
        'TransVG evaluation script',
        parents=[get_args_parser()],
    )
    main(cli.parse_args())
# End of file.