|
import gradio as gr |
|
import torch |
|
|
|
import sys |
|
|
|
sys.path.insert(0, 'med_rpg') |
|
|
|
|
|
from models import build_model |
|
from med_rpg import get_args_parser, medical_phrase_grounding |
|
import PIL.Image as Image |
|
from transformers import AutoTokenizer |
|
|
|
# ---------------------------------------------------------------------------
# Build args
# ---------------------------------------------------------------------------
# The option parser is supplied by the med_rpg module imported above; the
# parsed namespace is consumed below (args.bert_model, args.eval_model) and
# handed to build_model().
parser = get_args_parser()
args = parser.parse_args()
|
|
|
# ---------------------------------------------------------------------------
# Build model
# ---------------------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer for the text branch; the checkpoint name comes from the CLI args.
tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

model = build_model(args)
model.to(device)

# The checkpoint is deserialised onto the CPU and then copied into the model's
# parameters by load_state_dict().
# NOTE(review): torch.load() unpickles the file, so args.eval_model must only
# ever point at a trusted checkpoint.
checkpoint = torch.load(args.eval_model, map_location='cpu')
# NOTE(review): strict=False silently ignores missing/unexpected keys; a
# mismatched checkpoint will load without any error.
model.load_state_dict(checkpoint['model'], strict=False)

# This script only serves predictions, so put the network in inference mode
# (affects layers such as dropout / batch-norm). Assumes build_model() returns
# a torch.nn.Module; this is a no-op if the grounding helper already does it.
model.eval()
|
|
|
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@torch.no_grad()
def inference(image, text, bbox=None):
    """Run phrase grounding for one image/text pair and return its result.

    Args:
        image: Object exposing ``convert("RGB")`` (the UI supplies PIL images).
        text: Phrase to ground in the image.
        bbox: Optional box forwarded unchanged to ``medical_phrase_grounding``.
            Defaults to ``[0, 0, 0, 0]`` when omitted or ``None``.
            NOTE(review): the default has four entries while
            ``get_new_result`` passes five-element lists -- confirm which
            layout ``medical_phrase_grounding`` expects.

    Returns:
        Whatever ``medical_phrase_grounding`` returns for the module-level
        ``model`` and ``tokenizer``.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation.
    if bbox is None:
        bbox = [0, 0, 0, 0]

    image = image.convert("RGB")

    # Half-precision autocast only applies to CUDA; on CPU-only hosts make the
    # context an explicit no-op.
    with torch.autocast(
        device_type='cuda',
        dtype=torch.float16,
        enabled=torch.cuda.is_available(),
    ):
        return medical_phrase_grounding(model, tokenizer, image, text, bbox)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_result(image, evt: gr.SelectData):
    """Ground the selected span, reading its box from the span's second field.

    ``evt.value`` is indexed as ``[phrase, box_string]``. The first and last
    characters of ``box_string`` are dropped and the remainder is parsed as
    comma-separated integers.

    Returns:
        ``(phrase, grounded_image)``, or ``None`` when ``evt.value`` is falsy.
    """
    selection = evt.value
    if not selection:
        return None

    raw_box = selection[1]
    coords = list(map(int, raw_box[1:-1].split(",")))
    phrase = selection[0]
    grounded = inference(image, phrase, coords)
    return phrase, grounded
|
|
|
# Reference phrase for each labelled finding of the bundled report.
GT_text = dict(
    [
        ("Finding 1", "Small left apical pneumothorax"),
        ("Finding 2", "Greater coalescence of consolidation in both the right mid and lower lung zones"),
        ("Finding 3", "Consilidation in the left lower lobe"),
    ]
)

# Per finding: (tag, a, b, da, db). GT_bboxes stores [tag, a, b, a + da, b + db].
# NOTE(review): this looks like an index followed by x1, y1, x2, y2 derived
# from x, y, width, height -- confirm against medical_phrase_grounding.
_GT_BOX_SPECS = {
    "Finding 1": (1, 332, 28, 141, 48),
    "Finding 2": (2, 57, 177, 163, 165),
    "Finding 3": (3, 325, 231, 183, 132),
}

GT_bboxes = {
    finding: [tag, a, b, a + da, b + db]
    for finding, (tag, a, b, da, db) in _GT_BOX_SPECS.items()
}
|
def get_new_result(image, evt: gr.SelectData):
    """Handle a click on a report span and return ``(text, grounded_image)``.

    ``evt.value`` is indexed as ``[span_text, finding_label]``:

    * Span text ``"(Show GT)"``: ground the reference phrase from ``GT_text``
      using the full box stored in ``GT_bboxes``.
    * Any other labelled span: ground the clicked text itself, passing only
      the first entry of the stored box followed by four zeros.
    * Span without a label: nothing to ground, so both outputs are left
      untouched.

    Args:
        image: Image forwarded to ``inference``.
        evt: Gradio selection event for the highlighted-text component.

    Returns:
        A 2-tuple matching the two outputs wired to this handler.
    """
    finding = evt.value[1]
    if not finding:
        # Previously this path fell through and returned a bare None even
        # though the handler is registered with two outputs. Return one no-op
        # update per output so the current text and image stay as they are.
        return gr.update(), gr.update()

    clicked_text = evt.value[0]
    if clicked_text == "(Show GT)":
        bbox = GT_bboxes[finding]
        text = GT_text[finding]
    else:
        bbox = [GT_bboxes[finding][0], 0, 0, 0, 0]
        text = clicked_text

    output_img = inference(image, text, bbox)
    return text, output_img
|
|
|
def clear() -> str:
    """Return an empty string.

    Not wired to any component in the visible part of this file.
    """
    return str()
|
|
|
# Image file preloaded into both the input and the result image components.
_DEMO_IMAGE = "./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"

# (text, label) spans shown in the highlighted report. Spans labelled
# "Finding N" are handled by get_new_result(); spans with a None label carry
# no finding.
_REPORT_SPANS = [
    ("Small left apical pneumothorax", "Finding 1"),
    ("(Show GT)", "Finding 1"),
    ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
    ("Greater coalescence of consolidation in both the right mid and lower lung zones", "Finding 2"),
    ("(Show GT)", "Finding 2"),
    ("could be progressive atelectasis but is more concerning for pneumonia.", None),
    ("Consilidation in the left lower lobe", "Finding 3"),
    ("(Show GT)", "Finding 3"),
    (", however, has improved since ___ through ___.", None),
]

# NOTE(review): the original indentation was lost; the nesting below (result
# image beside the input column, inside the same row) is a reconstruction --
# confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <center> <h1>Medical Phrase Grounding Demo</h1> </center>
        <p style='text-align: center'> <a href='https://arxiv.org/abs/2303.07618' target='_blank'>Paper</a> </p>
        """)

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            input_image = gr.Image(type='pil', value=_DEMO_IMAGE)
            hl_text = gr.HighlightedText(
                label="Medical Report",
                combine_adjacent=False,
                show_legend=False,
                value=_REPORT_SPANS,
            )
            input_text = gr.Textbox(label="Input Text", interactive=False)

        output = gr.Image(
            type="pil",
            value=_DEMO_IMAGE,
            label="Grounding Results",
            interactive=False,
        ).style(height=500)

    # Clicking a span runs grounding on the current input image and fills
    # both the text box and the result image.
    hl_text.select(get_new_result, inputs=[input_image], outputs=[input_text, output])

demo.launch(share=True)
|
|
|
|
|
|