Spaces:
Running
Running
danielhajialigol
commited on
Commit
•
4067c90
1
Parent(s):
d1017a6
adding ner models
Browse files- app.py +39 -22
- discharge_embeddings.pt +2 -2
- requirements.txt +1 -0
- utils.py +46 -4
app.py
CHANGED
@@ -4,32 +4,38 @@ import pandas as pd
|
|
4 |
import torch
|
5 |
|
6 |
from model import MimicTransformer
|
7 |
-
from utils import load_rule, get_attribution, get_drg_link, get_icd_annotations, visualize_attn
|
8 |
-
from transformers import set_seed
|
9 |
|
10 |
set_seed(42)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def read_model(model, path):
|
13 |
model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
|
14 |
return model
|
15 |
|
16 |
-
model_path = 'checkpoint_0_9113.bin'
|
17 |
mimic = MimicTransformer(cutoff=512)
|
18 |
-
|
19 |
-
related_tensor = torch.load('discharge_embeddings.pt')
|
20 |
-
|
21 |
-
# get model and results
|
22 |
mimic = read_model(model=mimic, path=model_path)
|
23 |
-
all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
|
24 |
-
|
25 |
tokenizer = mimic.tokenizer
|
26 |
mimic.eval()
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
related_summaries = [[ex1]]
|
34 |
related_chosen = []
|
35 |
related_attn = []
|
@@ -59,9 +65,14 @@ def get_model_results(text):
|
|
59 |
'logits': logits
|
60 |
}
|
61 |
|
62 |
-
def find_related_summaries(
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
65 |
scores_indices = scores.topk(k=5, dim=0)
|
66 |
indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
|
67 |
summaries = []
|
@@ -74,8 +85,13 @@ def find_related_summaries(raw_embedding):
|
|
74 |
|
75 |
|
76 |
def run(text, related_discharges=False):
|
|
|
77 |
model_results = get_model_results(text=text)
|
78 |
drg_code = model_results['class']
|
|
|
|
|
|
|
|
|
79 |
drg_link = get_drg_link(drg_code=drg_code)
|
80 |
icd_results = get_icd_annotations(text=text)
|
81 |
row = rule_df[rule_df['DRG_CODE'] == drg_code]
|
@@ -85,7 +101,7 @@ def run(text, related_discharges=False):
|
|
85 |
model_results['icd_results'] = icd_results
|
86 |
global related_summaries
|
87 |
# related_summaries = generate_similar_summeries()
|
88 |
-
related_summaries = find_related_summaries(
|
89 |
if related_discharges:
|
90 |
return visualize_attn(model_results=model_results)
|
91 |
return (
|
@@ -193,10 +209,11 @@ def main():
|
|
193 |
|
194 |
# input to related summaries
|
195 |
with gr.Row() as row:
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
200 |
|
201 |
with gr.Row() as row:
|
202 |
related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
|
|
|
4 |
import torch
|
5 |
|
6 |
from model import MimicTransformer
|
7 |
+
from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_icd_annotations, visualize_attn
|
8 |
+
from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
|
9 |
|
10 |
set_seed(42)
|
11 |
+
model_path = 'checkpoint_0_9113.bin'
|
12 |
+
related_tensor = torch.load('discharge_embeddings.pt')
|
13 |
+
all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
|
14 |
+
|
15 |
+
similarity_tokenizer = AutoTokenizer.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
|
16 |
+
similarity_model = AutoModel.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
|
17 |
+
similarity_model.eval()
|
18 |
|
19 |
def read_model(model, path):
|
20 |
model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
|
21 |
return model
|
22 |
|
|
|
23 |
mimic = MimicTransformer(cutoff=512)
|
|
|
|
|
|
|
|
|
24 |
mimic = read_model(model=mimic, path=model_path)
|
|
|
|
|
25 |
tokenizer = mimic.tokenizer
|
26 |
mimic.eval()
|
27 |
|
28 |
+
# disease ner model
|
29 |
+
pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")
|
30 |
+
|
31 |
+
#
|
32 |
+
|
33 |
+
ex1 = """HEAD CT: Head CT showed no intracranial hemorrhage or mass effect, but old infarction consistent with past medical history."""
|
34 |
+
ex2 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
|
35 |
+
ex3 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
|
36 |
+
ex4 = """Blood tests revealed a raised BNP. An ECG showed evidence of left-ventricular hypertrophy and echocardiography revealed grossly impaired ventricular function (ejection fraction 35%). A chest X-ray demonstrated bilateral pleural effusions, with evidence of upper lobe diversion."""
|
37 |
+
ex5 = """Mrs Smith presented to A&E with worsening shortness of breath and ankle swelling. On arrival, she was tachypnoeic and hypoxic (oxygen saturation 82% on air). Clinical examination revealed reduced breath sounds and dullness to percussion in both lung bases. There was also a significant degree of lower limb oedema extending up to the mid-thigh bilaterally."""
|
38 |
+
examples = [ex1, ex2, ex3, ex4, ex5]
|
39 |
related_summaries = [[ex1]]
|
40 |
related_chosen = []
|
41 |
related_attn = []
|
|
|
65 |
'logits': logits
|
66 |
}
|
67 |
|
68 |
+
def find_related_summaries(text):
|
69 |
+
inputs = similarity_tokenizer(
|
70 |
+
text, padding='max_length', truncation=True, return_tensors='pt', max_length=512
|
71 |
+
)
|
72 |
+
outputs = similarity_model(**inputs)
|
73 |
+
embedding = mean_pooling(outputs, attention_mask=inputs.attention_mask)
|
74 |
+
embedding = torch.nn.functional.normalize(embedding)
|
75 |
+
scores = torch.mm(related_tensor, embedding.transpose(1,0))
|
76 |
scores_indices = scores.topk(k=5, dim=0)
|
77 |
indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
|
78 |
summaries = []
|
|
|
85 |
|
86 |
|
87 |
def run(text, related_discharges=False):
|
88 |
+
# initial drg results
|
89 |
model_results = get_model_results(text=text)
|
90 |
drg_code = model_results['class']
|
91 |
+
|
92 |
+
# find diseases
|
93 |
+
diseases = get_diseases(text=text, pipe=pipe)
|
94 |
+
model_results['diseases'] = diseases
|
95 |
drg_link = get_drg_link(drg_code=drg_code)
|
96 |
icd_results = get_icd_annotations(text=text)
|
97 |
row = rule_df[rule_df['DRG_CODE'] == drg_code]
|
|
|
101 |
model_results['icd_results'] = icd_results
|
102 |
global related_summaries
|
103 |
# related_summaries = generate_similar_summeries()
|
104 |
+
related_summaries = find_related_summaries(text=text)
|
105 |
if related_discharges:
|
106 |
return visualize_attn(model_results=model_results)
|
107 |
return (
|
|
|
209 |
|
210 |
# input to related summaries
|
211 |
with gr.Row() as row:
|
212 |
+
with gr.Column(scale=5) as col:
|
213 |
+
input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
|
214 |
+
with gr.Column(scale=1) as col:
|
215 |
+
rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
|
216 |
+
sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
|
217 |
|
218 |
with gr.Row() as row:
|
219 |
related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
|
discharge_embeddings.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5596bf755d73898c6544b6254c2415283e2891deb6e3748f51ae9fb10794baee
|
3 |
+
size 30720786
|
requirements.txt
CHANGED
@@ -4,3 +4,4 @@ gradio
|
|
4 |
transformers
|
5 |
captum
|
6 |
tqdm
|
|
|
|
4 |
transformers
|
5 |
captum
|
6 |
tqdm
|
7 |
+
sentence-transformers
|
utils.py
CHANGED
@@ -20,6 +20,28 @@ class PyTMinMaxScalerVectorized(object):
|
|
20 |
scale = 1.0 / (tensor.max(dim=0, keepdim=True)[0] - tensor.min(dim=0, keepdim=True)[0])
|
21 |
tensor.mul_(scale).sub_(tensor.min(dim=0, keepdim=True)[0])
|
22 |
return tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
def find_end(text):
|
25 |
"""Find the end of the report."""
|
@@ -230,7 +252,12 @@ def visualize_attn(model_results):
|
|
230 |
raw_input_ids=tokens,
|
231 |
convergence_score=1
|
232 |
)
|
233 |
-
return visualize_text(
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
|
236 |
def modify_attn_html(attn_html):
|
@@ -238,7 +265,7 @@ def modify_attn_html(attn_html):
|
|
238 |
htmls = [attn_split[0]]
|
239 |
for html in attn_split[1:]:
|
240 |
# wrap around href tag
|
241 |
-
href_html = f'<a href="https://
|
242 |
<mark{html} \
|
243 |
</a>'
|
244 |
htmls.append(href_html)
|
@@ -258,36 +285,51 @@ def get_icd_html(icd_list):
|
|
258 |
if len(icd_list) == 0:
|
259 |
return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
|
260 |
final_html = '<td>'
|
|
|
261 |
for icd_dict in icd_list:
|
262 |
text, link = icd_dict['text'], icd_dict['link']
|
|
|
|
|
263 |
tmp_html = visualization.format_classname(classname=text)
|
264 |
html = modify_code_html(html=tmp_html, link=link, icd=True)
|
265 |
final_html += html
|
|
|
266 |
return final_html + '</td>'
|
267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
269 |
|
270 |
# copied out of captum because we need raw html instead of a jupyter widget
|
271 |
-
def visualize_text(datarecord, drg_link, icd_annotations):
|
272 |
dom = ["<table width: 100%>"]
|
273 |
rows = [
|
274 |
"<th style='text-align: left'>Predicted DRG</th>"
|
275 |
"<th style='text-align: left'>Word Importance</th>"
|
|
|
276 |
"<th style='text-align: left'>ICD Codes</th>"
|
277 |
]
|
278 |
pred_class_html = visualization.format_classname(datarecord.pred_class)
|
279 |
icd_class_html = get_icd_html(icd_annotations)
|
|
|
280 |
pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
|
281 |
word_attn_html = visualization.format_word_importances(
|
282 |
datarecord.raw_input_ids, datarecord.word_attributions
|
283 |
)
|
284 |
-
word_attn_html = modify_attn_html(word_attn_html)
|
285 |
rows.append(
|
286 |
"".join(
|
287 |
[
|
288 |
"<tr>",
|
289 |
pred_class_html,
|
290 |
word_attn_html,
|
|
|
291 |
icd_class_html,
|
292 |
"<tr>",
|
293 |
]
|
|
|
20 |
scale = 1.0 / (tensor.max(dim=0, keepdim=True)[0] - tensor.min(dim=0, keepdim=True)[0])
|
21 |
tensor.mul_(scale).sub_(tensor.min(dim=0, keepdim=True)[0])
|
22 |
return tensor
|
23 |
+
|
24 |
+
def get_diseases(text, pipe):
|
25 |
+
results = pipe(text)
|
26 |
+
diseases = []
|
27 |
+
disease_span = []
|
28 |
+
for result in results:
|
29 |
+
ent = result['entity']
|
30 |
+
# start of a new entity
|
31 |
+
if ent == 'B-DISEASE':
|
32 |
+
disease_span = result['start'], result['end']
|
33 |
+
elif ent == 'I-DISEASE':
|
34 |
+
disease_span = disease_span[0], result['end']
|
35 |
+
else:
|
36 |
+
if len(disease_span) > 1:
|
37 |
+
disease = text[disease_span[0]: disease_span[1]]
|
38 |
+
if len(disease) > 2:
|
39 |
+
diseases.append(disease)
|
40 |
+
disease_span = []
|
41 |
+
if len(disease_span) > 1:
|
42 |
+
disease = text[disease_span[0]: disease_span[1]]
|
43 |
+
diseases.append(disease)
|
44 |
+
return diseases
|
45 |
|
46 |
def find_end(text):
|
47 |
"""Find the end of the report."""
|
|
|
252 |
raw_input_ids=tokens,
|
253 |
convergence_score=1
|
254 |
)
|
255 |
+
return visualize_text(
|
256 |
+
viz_record,
|
257 |
+
drg_link=model_results['drg_link'],
|
258 |
+
icd_annotations=model_results['icd_results'],
|
259 |
+
diseases=model_results['diseases']
|
260 |
+
)
|
261 |
|
262 |
|
263 |
def modify_attn_html(attn_html):
|
|
|
265 |
htmls = [attn_split[0]]
|
266 |
for html in attn_split[1:]:
|
267 |
# wrap around href tag
|
268 |
+
href_html = f'<a href="https://" \
|
269 |
<mark{html} \
|
270 |
</a>'
|
271 |
htmls.append(href_html)
|
|
|
285 |
if len(icd_list) == 0:
|
286 |
return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
|
287 |
final_html = '<td>'
|
288 |
+
icd_set = set()
|
289 |
for icd_dict in icd_list:
|
290 |
text, link = icd_dict['text'], icd_dict['link']
|
291 |
+
if text in icd_set:
|
292 |
+
continue
|
293 |
tmp_html = visualization.format_classname(classname=text)
|
294 |
html = modify_code_html(html=tmp_html, link=link, icd=True)
|
295 |
final_html += html
|
296 |
+
icd_set.add(text)
|
297 |
return final_html + '</td>'
|
298 |
|
299 |
+
|
300 |
+
def get_disease_html(diseases):
|
301 |
+
if len(diseases) == 0:
|
302 |
+
return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
|
303 |
+
diseases = list(set(diseases))
|
304 |
+
diseases_str = ', '.join(diseases)
|
305 |
+
html = visualization.format_classname(classname=diseases_str)
|
306 |
+
return html + '</td>'
|
307 |
+
|
308 |
|
309 |
|
310 |
# copied out of captum because we need raw html instead of a jupyter widget
|
311 |
+
def visualize_text(datarecord, drg_link, icd_annotations, diseases):
|
312 |
dom = ["<table width: 100%>"]
|
313 |
rows = [
|
314 |
"<th style='text-align: left'>Predicted DRG</th>"
|
315 |
"<th style='text-align: left'>Word Importance</th>"
|
316 |
+
"<th style='text-align: left'>Diseases</th>"
|
317 |
"<th style='text-align: left'>ICD Codes</th>"
|
318 |
]
|
319 |
pred_class_html = visualization.format_classname(datarecord.pred_class)
|
320 |
icd_class_html = get_icd_html(icd_annotations)
|
321 |
+
disease_html = get_disease_html(diseases)
|
322 |
pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
|
323 |
word_attn_html = visualization.format_word_importances(
|
324 |
datarecord.raw_input_ids, datarecord.word_attributions
|
325 |
)
|
|
|
326 |
rows.append(
|
327 |
"".join(
|
328 |
[
|
329 |
"<tr>",
|
330 |
pred_class_html,
|
331 |
word_attn_html,
|
332 |
+
disease_html,
|
333 |
icd_class_html,
|
334 |
"<tr>",
|
335 |
]
|