# MedRPG / app.py
# zy5830850 -- "First model version" (commit 91ef820)
import gradio as gr
import torch
import sys
# sys.path.insert(0, '/Users/daipl/Desktop/MedRPG_Demo/med_rpg')
sys.path.insert(0, 'med_rpg')
# import datasets
from models import build_model
from med_rpg import get_args_parser, medical_phrase_grounding
import PIL.Image as Image
from transformers import AutoTokenizer
# ---------------------------------------------------------------------------
# Build args
# ---------------------------------------------------------------------------
parser = get_args_parser()
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Build model
# ---------------------------------------------------------------------------
# Prefer CUDA when present, otherwise run on CPU. (Apple MPS was explored
# during development; selecting torch.device("mps") when
# torch.backends.mps.is_available() is True is the way to try it.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

model = build_model(args)
model.to(device)

# Weights are loaded onto CPU first; load_state_dict copies them into the
# parameters that already live on `device`.
# NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints. On PyTorch >= 2.6 the default weights_only=True may reject a
# checkpoint that pickles extra (non-tensor) objects; confirm.
checkpoint = torch.load(args.eval_model, map_location='cpu')
# strict=False tolerates key mismatches; report them so that loading the
# wrong checkpoint does not go unnoticed.
load_result = model.load_state_dict(checkpoint['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"[MedRPG] checkpoint loaded with {len(load_result.missing_keys)} missing "
        f"and {len(load_result.unexpected_keys)} unexpected key(s)"
    )
# This app only runs inference: switch any dropout / batch-norm layers to
# their evaluation behaviour.
model.eval()
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@torch.no_grad()
def inference(image, text, bbox=None):
    """Ground the phrase *text* on *image* with the loaded MedRPG model.

    Args:
        image: PIL image (the UI supplies it via ``gr.Image(type='pil')``);
            it is converted to RGB before inference.
        text: phrase to ground, passed to ``medical_phrase_grounding``.
        bbox: box list passed through to ``medical_phrase_grounding``.
            Defaults to a fresh ``[0, 0, 0, 0]`` on every call. Callers in
            this file pass ``[finding_index, x1, y1, x2, y2]``; the exact
            format expected downstream is not visible here -- confirm.

    Returns:
        Whatever ``medical_phrase_grounding`` returns. The UI displays it in
        a ``gr.Image(type="pil")`` output.
    """
    import contextlib

    # A new list per call, so a downstream mutation can never leak into the
    # default of later calls.
    if bbox is None:
        bbox = [0, 0, 0, 0]
    image = image.convert("RGB")
    # fp16 autocast only on CUDA. On CPU the original cuda-autocast was
    # disabled anyway (with a warning on every call); run full precision.
    if device.type == "cuda":
        amp = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp = contextlib.nullcontext()
    with amp:
        return medical_phrase_grounding(model, tokenizer, image, text, bbox)
# """
# Small left apical pneumothorax unchanged in size since ___:56 a.m.,
# and no appreciable left pleural effusion,
# basal pleural tubes still in place and reportedly on waterseal.
# Greater coalescence of consolidation in both the right mid and lower lung zones could be progressive atelectasis but is more concerning for pneumonia.
# Consolidation in the left lower lobe, however, has improved since ___ through ___.
# There is no right pleural effusion or definite right pneumothorax.
# Cardiomediastinal silhouette is normal.
# Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged.
# """
def get_result(image, evt: gr.SelectData):
    """Ground a clicked HighlightedText segment whose label is a box string.

    Expects ``evt.value`` to be a ``[text, label]`` pair where ``label`` is a
    string such as ``"[332, 28, 141, 48]"``. This handler is not wired to any
    event in the visible code.

    Returns:
        ``(phrase, grounded_image)``. If nothing usable was clicked (an empty
        selection or an unlabelled segment), returns ``gr.update()`` for both
        outputs so that they stay unchanged. The original crashed in that case.
    """
    if not evt.value or not evt.value[1]:
        return gr.update(), gr.update()
    phrase, label = evt.value[0], evt.value[1]
    # "[a, b, c, d]" -> [a, b, c, d]; int() tolerates the spaces after commas.
    bbox = [int(num) for num in label.strip().strip("[]").split(",")]
    output_img = inference(image, phrase, bbox)
    return phrase, output_img
# Ground-truth phrase for each finding of the example report.
# ("Consolidation" was previously misspelled "Consilidation"; the report
# text spells it correctly, and the phrase is fed to the tokenizer/model.)
GT_text = {
    "Finding 1": "Small left apical pneumothorax",
    "Finding 2": "Greater coalescence of consolidation in both the right mid and lower lung zones",
    "Finding 3": "Consolidation in the left lower lobe",
}

# Ground-truth boxes as [finding_index, x, y, w, h]
# (presumably pixel coordinates of the example image -- confirm).
_GT_BBOXES_XYWH = {
    "Finding 1": [1, 332, 28, 141, 48],
    "Finding 2": [2, 57, 177, 163, 165],
    "Finding 3": [3, 325, 231, 183, 132],
}

# Corner form [finding_index, x1, y1, x2, y2] with x2 = x + w, y2 = y + h.
# This is the form handed to inference() for "(Show GT)".
GT_bboxes = {
    name: [idx, x, y, x + w, y + h]
    for name, (idx, x, y, w, h) in _GT_BBOXES_XYWH.items()
}
def get_new_result(image, evt: gr.SelectData):
    """Handle a click on a segment of the "Medical Report" HighlightedText.

    ``evt.value`` is indexed as a ``[text, label]`` pair, where ``label`` is a
    key of ``GT_text`` / ``GT_bboxes`` ("Finding N") or ``None`` for plain
    report text.

    * ``"(Show GT)"`` segment: grounds the ground-truth phrase of that finding
      together with its full ground-truth box.
    * Finding phrase: grounds the clicked text with ``[finding_index, 0, 0, 0,
      0]``. The coordinates are zeroed; presumably this means "no GT box to
      draw" downstream -- confirm.
    * Unlabelled segment: nothing to ground; both outputs are left unchanged.
      The original raised UnboundLocalError here.

    Returns:
        ``(text, grounded_image)`` for the ``input_text`` and ``output`` components.

    Raises:
        gr.Error: if no input image is present.
    """
    phrase, finding = evt.value[0], evt.value[1]
    if not finding:
        return gr.update(), gr.update()
    if image is None:
        raise gr.Error("Please provide an input image.")
    if phrase == "(Show GT)":
        # Copy so that downstream code can never mutate the module-level GT data.
        bbox = list(GT_bboxes[finding])
        text = GT_text[finding]
    else:
        bbox = [GT_bboxes[finding][0], 0, 0, 0, 0]
        text = phrase
    output_img = inference(image, text, bbox)
    return text, output_img
def clear():
    """Return an empty string, suitable for resetting a text component."""
    return str()
# Example chest X-ray shown both as the input and as the initial output.
EXAMPLE_IMAGE = "./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"

# Example report split into (text, label) segments for the HighlightedText.
# Labelled segments are clickable: "Finding N" keys into GT_text / GT_bboxes,
# and a "(Show GT)" segment grounds that finding's ground-truth phrase and
# box (see get_new_result). A None label marks plain report text.
# ("Consolidation" was previously misspelled "Consilidation"; the report
# spells it correctly.)
REPORT_SEGMENTS = [
    ("Small left apical pneumothorax", "Finding 1"),
    ("(Show GT)", "Finding 1"),
    ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
    ("Greater coalescence of consolidation in both the right mid and lower lung zones", "Finding 2"),
    ("(Show GT)", "Finding 2"),
    ("could be progressive atelectasis but is more concerning for pneumonia.", None),
    ("Consolidation in the left lower lobe", "Finding 3"),
    ("(Show GT)", "Finding 3"),
    (", however, has improved since ___ through ___.", None),
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
<center> <h1>Medical Phrase Grounding Demo</h1> </center>
<p style='text-align: center'> <a href='https://arxiv.org/abs/2303.07618' target='_blank'>Paper</a> </p>
""")
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            input_image = gr.Image(type='pil', value=EXAMPLE_IMAGE)
            hl_text = gr.HighlightedText(
                label="Medical Report",
                combine_adjacent=False,
                show_legend=False,
                value=REPORT_SEGMENTS,
            )
            input_text = gr.Textbox(label="Input Text", interactive=False)
        # NOTE(review): the original indentation was lost. The output image is
        # reconstructed as the second element of the row (next to the input
        # column); confirm the intended layout.
        # NOTE(review): `.style()` is deprecated in later Gradio 3.x releases
        # and removed in 4.x (use `height=` on the component there).
        output = gr.Image(
            type="pil",
            value=EXAMPLE_IMAGE,
            label="Grounding Results",
            interactive=False,
        ).style(height=500)
    # Clicking a report segment grounds it and shows the phrase plus the result.
    hl_text.select(get_new_result, inputs=[input_image], outputs=[input_text, output])

if __name__ == "__main__":
    # share=True creates a public link when run locally; Hugging Face Spaces
    # ignores it.
    demo.launch(share=True)