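"""Automated medical coding POC.

Embeds MIMIC-III clinical notes with a PubMedBERT sentence-embedding model, stores
annotation-level embeddings in ChromaDB, and serves a Gradio demo that retrieves the
closest codes for an input text.
"""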
import os
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel
import chromadb
import gradio as gr
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
# Mean Pooling - Take attention mask into account for correct averaging
def meanpooling(output, mask):
    embeddings = output[0]  # first element of the model output contains the token embeddings
    mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
    return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
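# Shape check: for a batch of N texts padded to T tokens, output[0] is (N, T, 768) for this
# BERT-base-sized model and mask is (N, T); the result is one (768,) sentence embedding per text.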
# Load the dataset
dataset = load_dataset("thankrandomness/mimic-iii")
# Split the dataset into train and validation sets
split_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
dataset = DatasetDict({
    'train': split_dataset['train'],
    'validation': split_dataset['test']
})
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
model = AutoModel.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
# Normalize embeddings to unit length (row-wise, so each vector is normalized independently)
def normalize_embedding(embedding):
    embedding = np.asarray(embedding)
    norms = np.linalg.norm(embedding, axis=-1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero vectors
    return (embedding / norms).tolist()
# Embed a list of texts and L2-normalize the resulting vectors
def embed_text(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        output = model(**inputs)
    embeddings = meanpooling(output, inputs['attention_mask'])
    return normalize_embedding(embeddings.numpy())
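# Sanity check (illustrative): embed_text(["chest pain"]) should return a list containing a
# single 768-dimensional unit-length vector.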
# Initialize an in-memory ChromaDB client and collection
client = chromadb.Client()
collection = client.get_or_create_collection(name="pubmedbert_matryoshka_embeddings")
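# Note (assumption): chromadb.Client() keeps everything in memory, so embeddings are rebuilt
# on every app start; chromadb.PersistentClient(path=...) could be swapped in to persist the
# collection across restarts.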
# Upsert notes, embeddings, and annotations into ChromaDB
def upsert_data(dataset_split):
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = []
            for annotation in note.get('annotations', []):
                try:
                    annotations_list.append({
                        "code": annotation['code'],
                        "code_system": annotation['code_system'],
                        "description": annotation['description'],
                    })
                except KeyError as e:
                    print(f"Skipping annotation due to missing key: {e}")
            if text and annotations_list:
                embedding = embed_text([text])[0]
                # One ChromaDB entry per annotation, all sharing the note's embedding
                for j, annotation in enumerate(annotations_list):
                    collection.upsert(
                        ids=[f"note_{note['note_id']}_{j}"],
                        embeddings=[embedding],
                        metadatas=[annotation],
                    )
            else:
                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
# Upsert training data
upsert_data(dataset['train'])
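# Note: this embeds every training note at startup, so the Space may take a while to become
# responsive on larger splits.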
# Retrieve the top-k nearest stored annotations for a piece of text
# (a distance threshold could be added here, but is currently disabled)
def retrieve_relevant_text(input_text, n_results=5):
    input_embedding = embed_text([input_text])[0]
    results = collection.query(
        query_embeddings=[input_embedding],
        n_results=n_results,
        include=["metadatas", "distances"]
    )
    output = []
    for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
        output.append({
            "similarity_score": distance,  # Chroma returns a distance: lower = more similar
            "code": metadata['code'],
            "code_system": metadata['code_system'],
            "description": metadata['description']
        })
    return output
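# Illustrative example (values are made up): retrieve_relevant_text("chest pain on exertion")
# might return entries such as
#   {"similarity_score": 0.42, "code": "786.50", "code_system": "ICD-9-CM",
#    "description": "Chest pain, unspecified"}
# ordered from nearest to farthest.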
# Evaluate retrieval quality on the validation/test split
def evaluate_efficiency(dataset_split):
    y_true = []
    y_pred = []
    total_similarity = 0
    total_items = 0
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = [annotation['code'] for annotation in note.get('annotations', []) if 'code' in annotation]
            if not (text and annotations_list):
                continue
            retrieved_results = retrieve_relevant_text(text)
            retrieved_codes = [result['code'] for result in retrieved_results]
            # Accumulate distances for the average-distance statistic
            for result in retrieved_results:
                total_similarity += result['similarity_score']
                total_items += 1
            # Align ground truth and predictions per note so both label lists stay the same length
            n = min(len(annotations_list), len(retrieved_codes))
            y_true.extend(annotations_list[:n])
            y_pred.extend(retrieved_codes[:n])
    avg_similarity = total_similarity / total_items if total_items > 0 else 0
    # Macro-averaged metrics over the aligned code lists
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    return precision, recall, f1, avg_similarity
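# Note: macro averaging weighs rare and frequent codes equally, and avg_similarity is the mean
# ChromaDB distance over retrieved items (lower is better), not an accuracy.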
# Calculate retrieval efficiency metrics
precision, recall, f1, avg_similarity = evaluate_efficiency(dataset['validation'])
# Gradio interface: format the retrieved annotations for display
def gradio_interface(input_text):
    results = retrieve_relevant_text(input_text)
    formatted_results = [
        f"Result {i + 1}:\n"
        f"Similarity Score: {result['similarity_score']:.2f}\n"
        f"Code: {result['code']}\n"
        f"Code System: {result['code_system']}\n"
        f"Description: {result['description']}\n"
        "-------------------"
        for i, result in enumerate(results)
    ]
    return "\n".join(formatted_results)
# Retrieval metrics string (avg_similarity is a mean distance, not an accuracy)
metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}, Avg. Distance: {avg_similarity:.2f}"
with gr.Blocks() as interface:
    gr.Markdown("# Automated Medical Coding POC")
    # gr.Markdown(metrics)
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Input Text")
            submit_button = gr.Button("Submit")
        with gr.Column():
            text_output = gr.Textbox(label="Retrieved Results", lines=10)
    submit_button.click(fn=gradio_interface, inputs=text_input, outputs=text_output)

interface.launch()