File size: 6,999 Bytes
91ef820 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import torch
import sys
# sys.path.insert(0, '/Users/daipl/Desktop/MedRPG_Demo/med_rpg')
sys.path.insert(0, 'med_rpg')
# import datasets
from models import build_model
from med_rpg import get_args_parser, medical_phrase_grounding
import PIL.Image as Image
from transformers import AutoTokenizer
'''
build args
'''
# NOTE: parse_args() runs at import time, so this module cannot be imported
# without the MedRPG command-line arguments being valid.
parser = get_args_parser()
args = parser.parse_args()
'''
build model
'''
# device = 'cpu'
# Prefer CUDA when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('mps')
# Check that MPS is available
# if not torch.backends.mps.is_available():
#     if not torch.backends.mps.is_built():
#         print("MPS not available because the current PyTorch install was not "
#               "built with MPS enabled.")
#     else:
#         print("MPS not available because the current MacOS version is not 12.3+ "
#               "and/or you do not have an MPS-enabled device on this machine.")
# else:
#     device = torch.device("mps")
# Tokenizer for the text branch; args.bert_model names the HF checkpoint.
tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
## build model
model = build_model(args)
model.to(device)
# Checkpoint is loaded onto CPU first, then the weights land on `device`
# because the model was already moved there above.
checkpoint = torch.load(args.eval_model, map_location='cpu')
# strict=False tolerates missing/unexpected keys in the checkpoint;
# presumably intentional for partial fine-tuned weights — confirm.
model.load_state_dict(checkpoint['model'], strict=False)
'''
inference model
'''
@torch.no_grad()
def inference(image, text, bbox=None):
    """Ground a report phrase in a radiograph image.

    Args:
        image: PIL image; converted to RGB before being fed to the model.
        text: the report phrase to localize.
        bbox: optional ground-truth box list; defaults to [0, 0, 0, 0].
            NOTE(review): get_new_result passes 5-element lists
            ([id, x1, y1, x2, y2]) while the default here has 4 elements —
            exact semantics live in medical_phrase_grounding; confirm.

    Returns:
        Whatever medical_phrase_grounding returns (the annotated image).
    """
    # Avoid the shared-mutable-default pitfall: build a fresh list per call.
    if bbox is None:
        bbox = [0, 0, 0, 0]
    image = image.convert("RGB")
    # The original hard-coded device_type='cuda', which fails on CPU-only
    # hosts even though `device` falls back to CPU above. Enable fp16
    # autocast only when actually running on CUDA; elsewhere the context is
    # a no-op and the model runs in full precision.
    with torch.autocast(device_type=device.type, dtype=torch.float16,
                        enabled=device.type == "cuda"):
        return medical_phrase_grounding(model, tokenizer, image, text, bbox)
# """
# Small left apical pneumothorax unchanged in size since ___:56 a.m.,
# and no appreciable left pleural effusion,
# basal pleural tubes still in place and reportedly on waterseal.
# Greater coalescence of consolidation in both the right mid and lower lung zones could be progressive atelectasis but is more concerning for pneumonia.
# Consolidation in the left lower lobe, however, has improved since ___ through ___.
# There is no right pleural effusion or definite right pneumothorax.
# Cardiomediastinal silhouette is normal.
# Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.
# """
def get_result(image, evt: gr.SelectData):
    """Callback for HighlightedText labels of the form "[x1, y1, x2, y2]".

    Parses the selected span's label into an integer bbox, runs grounding on
    the span's text, and returns (phrase, annotated image).
    """
    if not evt.value:
        # The original implicitly returned None here, which breaks a
        # two-output Gradio callback; return explicit empty outputs instead.
        return "", None
    bbox_str = evt.value[1][1:-1]  # strip the surrounding "[" and "]"
    bbox = [int(num) for num in bbox_str.split(",")]
    output_img = inference(image, evt.value[0], bbox)
    return evt.value[0], output_img
# Ground-truth phrase for each clickable finding in the report widget.
# (The "Consilidation" spelling matches the report text shown in the UI.)
GT_text = {
    "Finding 1": "Small left apical pneumothorax",
    "Finding 2": "Greater coalescence of consolidation in both the right mid and lower lung zones",
    "Finding 3": "Consilidation in the left lower lobe",
}
# Ground-truth boxes annotated as (id, x, y, w, h); the model wants
# [id, x1, y1, x2, y2], so convert once here.
_GT_XYWH = {
    "Finding 1": (1, 332, 28, 141, 48),
    "Finding 2": (2, 57, 177, 163, 165),
    "Finding 3": (3, 325, 231, 183, 132),
}
GT_bboxes = {
    name: [idx, x, y, x + w, y + h]
    for name, (idx, x, y, w, h) in _GT_XYWH.items()
}
def get_new_result(image, evt: gr.SelectData):
    """Handle a click on a highlighted report span.

    evt.value is (span_text, label); presumably label is a "Finding N" key
    into GT_text/GT_bboxes — confirm against the hl_text value list.
    Clicking "(Show GT)" grounds the ground-truth phrase with its annotated
    box; clicking the phrase itself grounds that text with a zeroed box.
    Returns (phrase, annotated image).
    """
    finding = evt.value[1]
    if not finding:
        # Span has no finding label: no result (implicit None, as before).
        return
    if evt.value[0] == "(Show GT)":
        text = GT_text[finding]
        bbox = GT_bboxes[finding]
    else:
        text = evt.value[0]
        # Keep the finding id, zero out the coordinates.
        bbox = [GT_bboxes[finding][0], 0, 0, 0, 0]
    return text, inference(image, text, bbox)
def clear():
    """Reset the input-text box by returning an empty string."""
    return ""
# --- Gradio UI: image + clickable report on the left, grounding result on the right.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    <center> <h1>Medical Phrase Grounding Demo</h1> </center>
    <p style='text-align: center'> <a href='https://arxiv.org/abs/2303.07618' target='_blank'>Paper</a> </p>
    """)
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            # The demo chest X-ray that findings are grounded against.
            input_image = gr.Image(type='pil', value="./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg")
            # Report text with clickable spans. Each tuple is (text, label);
            # spans labelled "Finding N" are wired to GT_text/GT_bboxes in
            # get_new_result, and the "(Show GT)" spans trigger the
            # ground-truth box for the preceding finding.
            hl_text = gr.HighlightedText(
                label="Medical Report",
                combine_adjacent=False,
                # combine_adjacent=True,
                show_legend=False,
                # value = [("Small left apical pneumothorax","[332, 28, 141, 48]"),
                #          ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
                #          ("Greater coalescence of consolidation in both the right mid and lower lung zones","[57, 177, 163, 165]"),
                #          ("could be progressive atelectasis but is more concerning for pneumonia.", None),
                #          ("Consilidation in the left lower lobe","[325, 231, 183, 132]"),
                #          (", however, has improved since ___ through ___.", None),
                #          # ("There is no right pleural effusion or definite right pneumothorax.", None),
                #          # ("Cardiomediastinal silhouette is normal.", None),
                #          # ("Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.", None),
                #          ]
                value = [("Small left apical pneumothorax","Finding 1"),
                         ("(Show GT)","Finding 1"),
                         ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
                         ("Greater coalescence of consolidation in both the right mid and lower lung zones","Finding 2"),
                         ("(Show GT)","Finding 2"),
                         ("could be progressive atelectasis but is more concerning for pneumonia.", None),
                         ("Consilidation in the left lower lobe","Finding 3"),
                         ("(Show GT)","Finding 3"),
                         # ", however, has improved since ___ through ___.",
                         (", however, has improved since ___ through ___.", None),
                         ]
            )
            # Read-only echo of the phrase that was grounded.
            input_text = gr.Textbox(label="Input Text", interactive=False)
            # bbox = gr.Dataframe(
            #     headers=["x", "y", "w", "h"],
            #     datatype=["number", "number", "number", "number"],
            #     label="Groud-Truth Bounding Box",
            #     value=[[332, 28, 141, 48]]
            # )
            # with gr.Row():
            #     clear_btn = gr.Button("Clear")
            #     run_btn = gr.Button("Run")
        # output = gr.Image(type="pil", label="Grounding Results", interactive=False).style(height=500)
        # NOTE(review): .style() is removed in Gradio 4.x — this pins the
        # demo to an older Gradio; confirm the deployed version.
        output = gr.Image(type="pil", value="./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg", label="Grounding Results", interactive=False).style(height=500)
    # Clicking any highlighted span runs grounding and updates both outputs.
    hl_text.select(get_new_result, inputs=[input_image], outputs=[input_text, output])
    # run_btn.click(fn=inference, inputs=[input_image, input_text], outputs=output)
    # clear_btn.click(fn=clear, outputs=input_text)
# share=True creates a public Gradio tunnel URL in addition to localhost.
demo.launch(share=True)
|