# SNOMED-Entity-Linking / segmentation.py
import numpy as np
import pandas as pd
import torch

from dataloader import create_dataloader
from utils import (align_decoded, batch_list, get_sequential_spans,
                   is_overlap, pad_seq)


def predict_segmentation(inp, model, device, batch_size=8):
    """Run the token-classification model over `inp`; return per-token probabilities."""
    test_loader = create_dataloader(inp, batch_size)
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in test_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # torch.sigmoid replaces the deprecated F.sigmoid.
            p = torch.sigmoid(model(**batch).logits).cpu().numpy()
            predictions.append(p)
    return np.concatenate(predictions, axis=0)
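
# For orientation: `create_dataloader` lives in dataloader.py (not shown here).
# Given how it is called above, it presumably batches the list of padded token
# dicts produced by `create_data` into tensors. A minimal sketch under that
# assumption, not the repository's actual implementation:
#
#     from torch.utils.data import DataLoader
#
#     def create_dataloader(inp, batch_size):
#         def collate(items):
#             # Stack each field ('input_ids', 'attention_mask', ...) into a tensor.
#             return {k: torch.tensor([it[k] for it in items]) for k in items[0]}
#         return DataLoader(inp, batch_size=batch_size, collate_fn=collate)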


def create_data(text, tokenizer, seq_len=512):
    """Tokenize `text` and split it into padded chunks of `seq_len` tokens."""
    tokens = tokenizer(text, add_special_tokens=False)
    _token_batches = {k: [pad_seq(x, seq_len) for x in batch_list(v, seq_len)]
                      for (k, v) in tokens.items()}
    n_batches = len(_token_batches['input_ids'])
    # Regroup the column-wise lists into one dict per chunk.
    return [{k: v[i] for k, v in _token_batches.items()}
            for i in range(n_batches)]
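
# `pad_seq` and `batch_list` come from utils.py (not shown). From their usage
# above, plausible implementations look like this sketch; these are
# assumptions, not the actual code:
#
#     def batch_list(seq, n):
#         # Split a flat token list into consecutive chunks of at most n items.
#         return [seq[i:i + n] for i in range(0, len(seq), n)]
#
#     def pad_seq(seq, n, pad_value=0):
#         # Right-pad a chunk with pad_value so every chunk has length n.
#         return seq + [pad_value] * (n - len(seq))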


def segment_tokens(notes, model, tokenizer, device, batch_size=8):
    """Map each note_id to a 1-D array of per-token mention probabilities."""
    predictions = {}
    for note in notes.itertuples():
        note_id = note.note_id
        raw_text = note.text.lower()
        inp = create_data(raw_text, tokenizer)
        pred_probs = predict_segmentation(inp, model, device, batch_size=batch_size)
        # Drop the trailing singleton class dimension, then stitch the
        # per-chunk predictions back into one sequence for the whole note.
        pred_probs = np.squeeze(pred_probs, -1)
        pred_probs = np.concatenate(pred_probs)
        predictions[note_id] = pred_probs
    return predictions
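
# `segment` below uses two more utils.py helpers whose code is not shown.
# Judging from how they are called, they plausibly behave like this sketch
# (assumptions, not the actual implementations):
#
#     def get_sequential_spans(mask):
#         # Turn a 0/1 array into (start, end) pairs covering each run of 1s.
#         spans, start = [], None
#         for i, v in enumerate(mask):
#             if v and start is None:
#                 start = i
#             elif not v and start is not None:
#                 spans.append((start, i))
#                 start = None
#         if start is not None:
#             spans.append((start, len(mask)))
#         return spans
#
#     def is_overlap(seen_spans, span):
#         # True if `span` intersects any already-accepted span.
#         s, e = span
#         return any(s < e2 and s2 < e for (s2, e2) in seen_spans)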


def segment(notes, model, tokenizer, device, thresh, batch_size=8):
    """Extract character-level mention spans from each note.

    Token probabilities above `thresh` are merged into contiguous spans;
    overlapping spans are resolved greedily in favour of the higher score.
    """
    predictions = []
    predictions_prob_map = segment_tokens(notes, model, tokenizer, device, batch_size)
    for note in notes.itertuples():
        note_id = note.note_id
        raw_text = note.text
        # Round-trip through the tokenizer so token-level predictions can be
        # aligned with character offsets in the original text.
        decoded_text = tokenizer.decode(tokenizer.encode(raw_text, add_special_tokens=False))
        pred_probs = predictions_prob_map[note_id]
        _, pred_probs = align_decoded(raw_text, decoded_text, pred_probs)
        pred_probs = np.array(pred_probs, 'float32')
        pred = (pred_probs > thresh).astype('uint8')
        spans = get_sequential_spans(pred)
        note_predictions = {'note_id': [], 'start': [], 'end': [], 'mention': [], 'score': []}
        for (start, end) in spans:
            note_predictions['note_id'].append(note_id)
            note_predictions['score'].append(pred_probs[start:end].mean())
            note_predictions['start'].append(start)
            note_predictions['end'].append(end)
            note_predictions['mention'].append(raw_text[start:end])
        note_predictions = pd.DataFrame(note_predictions)
        note_predictions = note_predictions.sort_values('score', ascending=False)
        # Remove overlapping spans, keeping the highest-scoring one.
        seen_spans = set()
        keep = []
        for span in note_predictions[['start', 'end']].values:
            span = tuple(span)
            if is_overlap(seen_spans, span):
                keep.append(False)
            else:
                seen_spans.add(span)
                keep.append(True)
        note_predictions = note_predictions[keep]
        predictions.append(note_predictions)
    return pd.concat(predictions).reset_index(drop=True)
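
# Example usage: a minimal sketch. The checkpoint path, threshold, and sample
# note below are placeholders, not values fixed by this repository; the model
# is assumed to be a single-label token classifier (num_labels=1), matching
# the squeeze(-1) in segment_tokens.
if __name__ == '__main__':
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')
    model = AutoModelForTokenClassification.from_pretrained('path/to/checkpoint').to(device)

    # `notes` needs 'note_id' and 'text' columns, as segment() expects.
    notes = pd.DataFrame({'note_id': ['n1'],
                          'text': ['Patient denies chest pain or shortness of breath.']})
    spans = segment(notes, model, tokenizer, device, thresh=0.5)
    print(spans[['note_id', 'start', 'end', 'mention', 'score']])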