File size: 6,999 Bytes
91ef820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import torch

import sys
# sys.path.insert(0, '/Users/daipl/Desktop/MedRPG_Demo/med_rpg')
sys.path.insert(0, 'med_rpg')

# import datasets
from models import build_model
from med_rpg import get_args_parser, medical_phrase_grounding
import PIL.Image as Image
from transformers import AutoTokenizer

# ---------------------------------------------------------------------------
# Build args
# ---------------------------------------------------------------------------
# The command line is parsed at import time; `args.bert_model` and
# `args.eval_model` are read below.
parser = get_args_parser()
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Build model
# ---------------------------------------------------------------------------
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

model = build_model(args)
model.to(device)

# The checkpoint is deserialized onto the CPU; the weights stored under its
# 'model' key are then loaded into the model that already lives on `device`.
checkpoint = torch.load(args.eval_model, map_location='cpu')
load_result = model.load_state_dict(checkpoint['model'], strict=False)

# strict=False tolerates key mismatches; surface them instead of hiding them
# so a wrong/partial checkpoint is noticed at start-up.
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Warning: checkpoint does not match the model exactly "
        f"(missing keys: {len(load_result.missing_keys)}, "
        f"unexpected keys: {len(load_result.unexpected_keys)})"
    )

# This script only serves inference, so disable train-time behaviour
# (dropout, batch-norm statistic updates).
model.eval()

'''
inference model
'''
@torch.no_grad()
def inference(image, text, bbox=None):
    """Run phrase grounding for one image/phrase pair.

    Args:
        image: Image object exposing ``convert`` (the demo passes the PIL
            image produced by ``gr.Image(type='pil')``).
        text: Phrase forwarded to ``medical_phrase_grounding``.
        bbox: Optional box forwarded unchanged to
            ``medical_phrase_grounding``. When omitted, a fresh
            ``[0, 0, 0, 0]`` list is used for every call.

    Returns:
        Whatever ``medical_phrase_grounding`` returns; the demo routes it to
        an image output component.
    """
    # Build the default per call: a list default in the signature would be a
    # single object shared (and possibly mutated) across calls.
    if bbox is None:
        bbox = [0, 0, 0, 0]

    rgb_image = image.convert("RGB")

    # Only request CUDA mixed precision when CUDA exists; asking for it on a
    # CPU-only host makes torch emit a warning on every call before disabling it.
    use_amp = torch.cuda.is_available()
    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
        return medical_phrase_grounding(model, tokenizer, rgb_image, text, bbox)

# """
#     Small left apical pneumothorax unchanged in size since ___:56 a.m., 
#     and no appreciable left pleural effusion, 
#     basal pleural tubes still in place and reportedly on waterseal. 
#     Greater coalescence of consolidation in both the right mid and lower lung zones could be progressive atelectasis but is more concerning for pneumonia. 
#     Consolidation in the left lower lobe, however, has improved since ___ through ___. 
#     There is no right pleural effusion or definite right pneumothorax. 
#     Cardiomediastinal silhouette is normal. 
#     Distention of large and small bowel seen in the imaged portion of the upper abdomen is unchanged. 
# """
        
def get_result(image, evt: gr.SelectData):
    """Handle a click on a highlighted span whose label is a bbox string.

    ``evt.value`` is expected to be a ``(phrase, label)`` pair where the label
    looks like ``"[x, y, w, h]"``.

    Args:
        image: Image forwarded to ``inference``.
        evt: Selection event supplied by Gradio.

    Returns:
        ``(phrase, output_image)`` for a labelled span. For an empty
        selection or a span without a label, a pair of no-op updates so both
        outputs stay unchanged.
    """
    selection = evt.value
    # Unlabelled spans carry `None` as label; slicing it would raise TypeError.
    if not selection or not selection[1]:
        return gr.update(), gr.update()

    phrase, bbox_label = selection[0], selection[1]
    # Drop the surrounding "[" and "]" and parse the comma-separated integers.
    bbox = [int(num) for num in bbox_label.strip("[]").split(",")]
    output_img = inference(image, phrase, bbox)
    return phrase, output_img

# Ground-truth phrase for each annotated finding, keyed by the label used in
# the highlighted report.
GT_text = {
    "Finding 1": "Small left apical pneumothorax",
    "Finding 2": "Greater coalescence of consolidation in both the right mid and lower lung zones",
    # Spelling fixed ("Consilidation" -> "Consolidation"): this phrase is sent
    # to the model as the query text.
    "Finding 3": "Consolidation in the left lower lobe",
}

# Annotated boxes as [index, x, y, w, h]; the leading value matches the
# finding number.
_GT_BBOXES_XYWH = {
    "Finding 1": [1, 332, 28, 141, 48],
    "Finding 2": [2, 57, 177, 163, 165],
    "Finding 3": [3, 325, 231, 183, 132],
}

# Same boxes converted to corner form: [index, x1, y1, x2, y2].
GT_bboxes = {
    finding: [index, x, y, x + w, y + h]
    for finding, (index, x, y, w, h) in _GT_BBOXES_XYWH.items()
}
def get_new_result(image, evt: gr.SelectData):
    """Handle a click on a span of the highlighted report.

    ``evt.value`` is read as a ``(span_text, finding_label)`` pair. Clicking a
    "(Show GT)" span runs inference with the ground-truth phrase and box of
    that finding; clicking the phrase itself runs inference with the phrase
    and a zeroed box that only keeps the finding's leading index.

    Args:
        image: Image forwarded to ``inference``.
        evt: Selection event supplied by Gradio.

    Returns:
        ``(text, output_image)`` for a labelled span. For an empty selection
        or an unlabelled span, a pair of no-op updates so both outputs stay
        unchanged (a bare ``None`` does not supply the two output values the
        listener is wired to).
    """
    selection = evt.value
    if not selection or not selection[1]:
        return gr.update(), gr.update()

    span_text, finding = selection[0], selection[1]
    if span_text == "(Show GT)":
        bbox = GT_bboxes[finding]
        text = GT_text[finding]
    else:
        # Keep only the finding index; the coordinates are zeroed out.
        bbox = [GT_bboxes[finding][0], 0, 0, 0, 0]
        text = span_text

    output_img = inference(image, text, bbox)
    return text, output_img
    
def clear() -> str:
    """Return an empty string (usable as the new value of a text field)."""
    return ""

# Sample image shown both as the input and as the initial "result" picture.
_DEMO_IMAGE = "./images/649af982-e3af4e3a-75013d30-cdc71514-a34738fd.jpg"

with gr.Blocks() as demo:
    gr.Markdown(
    """
    <center> <h1>Medical Phrase Grounding Demo</h1> </center>
    <p style='text-align: center'> <a href='https://arxiv.org/abs/2303.07618' target='_blank'>Paper</a> </p>
    """)
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            input_image = gr.Image(type='pil', value=_DEMO_IMAGE)
            # Each entry is (span text, label). Labelled spans are clickable
            # findings; every finding is followed by a "(Show GT)" span that
            # carries the same label. Unlabelled spans use None.
            hl_text = gr.HighlightedText(
                label="Medical Report",
                combine_adjacent=False,
                show_legend=False,
                value=[
                    ("Small left apical pneumothorax", "Finding 1"),
                    ("(Show GT)", "Finding 1"),
                    ("unchanged in size since ___:56 a.m., and no appreciable left pleural effusion, basal pleural tubes still in place and reportedly on waterseal.", None),
                    ("Greater coalescence of consolidation in both the right mid and lower lung zones", "Finding 2"),
                    ("(Show GT)", "Finding 2"),
                    ("could be progressive atelectasis but is more concerning for pneumonia.", None),
                    # Spelling fixed ("Consilidation" -> "Consolidation"):
                    # the clicked span text is forwarded to the model.
                    ("Consolidation in the left lower lobe", "Finding 3"),
                    ("(Show GT)", "Finding 3"),
                    (", however, has improved since ___ through ___.", None),
                ],
            )
            # Read-only box that echoes the phrase used for the last run.
            input_text = gr.Textbox(label="Input Text", interactive=False)
        output = gr.Image(type="pil", value=_DEMO_IMAGE, label="Grounding Results", interactive=False).style(height=500)
        # Clicking a span of the report triggers grounding on the input image.
        hl_text.select(get_new_result, inputs=[input_image], outputs=[input_text, output])
demo.launch(share=True)