File size: 7,289 Bytes
3061f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel
import chromadb
import gradio as gr
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

def meanpooling(output, mask):
    """Average token embeddings, counting only positions where *mask* is non-zero.

    Args:
        output: model output; ``output[0]`` holds the token embeddings
            (presumably shaped (batch, seq_len, hidden) — TODO confirm).
        mask: attention mask that, after ``unsqueeze(-1)``, can be expanded to
            the embeddings' shape (presumably (batch, seq_len)).

    Returns:
        Embeddings summed over dimension 1 and divided by the number of
        unmasked positions; the count is clamped to 1e-9 so an all-zero mask
        yields zeros instead of a division by zero.
    """
    token_embeddings = output[0]  # first element of the model output = per-token embeddings
    weights = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

# Pull the MIMIC-III dataset from the Hugging Face Hub; only its 'train' split is read below.
dataset = load_dataset("thankrandomness/mimic-iii")

# Carve a reproducible 80/20 split out of the original training rows (seed fixed
# at 42), then rebuild `dataset` so it exposes 'train' and 'validation' splits.
split_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
dataset = DatasetDict({
    'train': split_dataset['train'],
    'validation': split_dataset['test'],
})

# Tokenizer and encoder must come from the same checkpoint, so name it once.
EMBEDDING_MODEL_NAME = "neuml/pubmedbert-base-embeddings-matryoshka"
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)

def normalize_embedding(embedding):
    """Scale each embedding vector to unit L2 length and return it as a list.

    Normalisation is applied along the last axis, so a 2-D ``(batch, dim)``
    array has every row normalised independently and a 1-D vector is
    normalised as a whole. (Previously the whole matrix was divided by a
    single norm, which is only correct for a batch of one.) Vectors whose norm
    is zero are returned unchanged rather than producing NaNs.

    Args:
        embedding: array-like of numbers, 1-D or with vectors on the last axis.

    Returns:
        A (nested) Python list with the same shape as *embedding*. The result
        is always a list, including for zero-norm input.
    """
    arr = np.asarray(embedding)
    norms = np.linalg.norm(arr, axis=-1, keepdims=True)
    # Divide zero-norm vectors by 1 so they stay all-zero instead of becoming NaN.
    safe_norms = np.where(norms > 0, norms, 1.0)
    return (arr / safe_norms).tolist()

def embed_text(text):
    """Encode *text* with the module-level tokenizer/model and normalise the result.

    Args:
        text: input handed straight to the tokenizer (callers in this file pass
            a one-element list of strings).

    Returns:
        ``normalize_embedding`` applied to the mean-pooled embeddings, one row
        per input text.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,  # long notes are truncated; presumably the model's position limit
        return_tensors='pt',
    )
    # Inference only: skip autograd bookkeeping during the forward pass.
    with torch.no_grad():
        model_output = model(**encoded)
    pooled = meanpooling(model_output, encoded['attention_mask'])
    return normalize_embedding(pooled.numpy())

# ChromaDB client and the collection that stores the note embeddings with their
# annotation metadata. get_or_create_collection (rather than create_collection)
# reuses an existing collection of this name instead of raising, so re-running
# this code in the same process does not fail. No "hnsw:space" metadata is
# passed, so Chroma's default distance metric applies (L2 per Chroma's docs).
client = chromadb.Client()
collection = client.get_or_create_collection(name="pubmedbert_matryoshka_embeddings")

def upsert_data(dataset_split):
    """Embed every usable note in *dataset_split* and store one Chroma record per annotation.

    A note is usable when it has non-empty ``text`` and at least one annotation
    carrying all of ``code``, ``code_system`` and ``description``; incomplete
    annotations are skipped with a message. The note text is embedded once, and
    each of its annotations is stored under the id ``note_<note_id>_<j>`` with
    that shared embedding and the annotation dict as metadata. Unusable notes
    are skipped with a message.

    Args:
        dataset_split: iterable of rows, each holding a ``notes`` list of dicts
            (schema inferred from the keys read below — confirm against the dataset).
    """
    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = []
            for annotation in note.get('annotations', []):
                try:
                    annotations_list.append({
                        "code": annotation['code'],
                        "code_system": annotation['code_system'],
                        "description": annotation['description'],
                    })
                except KeyError as e:
                    print(f"Skipping annotation due to missing key: {e}")

            if not (text and annotations_list):
                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
                continue

            embedding = embed_text([text])[0]
            # One batched upsert per note instead of one round-trip per annotation;
            # all annotations of a note share the note's embedding.
            collection.upsert(
                ids=[f"note_{note['note_id']}_{j}" for j in range(len(annotations_list))],
                embeddings=[embedding] * len(annotations_list),
                metadatas=annotations_list,
            )

# Index the training split so the validation notes and UI queries can search it.
train_split = dataset['train']
upsert_data(train_split)

def retrieve_relevant_text(input_text, n_results=5, max_distance=None):
    """Return the annotation records whose stored embeddings are nearest to *input_text*.

    Args:
        input_text: free text to embed and use as the query.
        n_results: number of nearest records to request from Chroma
            (default 5, the previously hard-coded value).
        max_distance: optional cut-off. When given, results whose Chroma
            distance is greater than it are dropped. ``None`` (the default)
            keeps every result, as before.

    Returns:
        A list of dicts with keys ``similarity_score``, ``code``,
        ``code_system`` and ``description``, in the order Chroma returns them.
        NOTE: ``similarity_score`` holds Chroma's *distance* (smaller = closer),
        not a similarity; the key name is kept for existing callers.
    """
    input_embedding = embed_text([input_text])[0]
    results = collection.query(
        query_embeddings=[input_embedding],
        n_results=n_results,
        # upsert_data in this file never stores documents, so fetch only what is used.
        include=["metadatas", "distances"],
    )

    output = []
    for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
        if max_distance is not None and distance > max_distance:
            continue
        output.append({
            "similarity_score": distance,
            "code": metadata['code'],
            "code_system": metadata['code_system'],
            "description": metadata['description'],
        })
    return output

def evaluate_efficiency(dataset_split):
    """Score retrieval against the gold annotation codes of *dataset_split*.

    Each note with text and at least one annotation code is used as a query.
    The retrieved codes are paired positionally with the gold codes. For each
    note, both lists are cut to the length of the shorter one, so every pair
    comes from the same note. (Previously only the predictions were truncated
    per note; whenever fewer codes were retrieved than annotated, later pairs
    shifted onto the wrong notes.)

    Args:
        dataset_split: iterable of rows with a ``notes`` list of dicts, as used
            by ``upsert_data``.

    Returns:
        ``(precision, recall, f1, avg_similarity)``:
        - precision, recall, f1: macro-averaged scikit-learn scores over the
          paired codes, or 0.0 each when nothing could be paired;
        - avg_similarity: the mean Chroma distance over all retrieved results
          (smaller = closer; the name is kept for callers), or 0 when nothing
          was retrieved.
    """
    y_true = []
    y_pred = []
    total_distance = 0.0
    total_items = 0

    for row in dataset_split:
        for note in row['notes']:
            text = note.get('text', '')
            true_codes = [a['code'] for a in note.get('annotations', []) if 'code' in a]
            if not (text and true_codes):
                continue

            retrieved_results = retrieve_relevant_text(text)
            retrieved_codes = [result['code'] for result in retrieved_results]

            total_distance += sum(result['similarity_score'] for result in retrieved_results)
            total_items += len(retrieved_results)

            # Keep gold and predicted lists the same length *per note* so the
            # positional pairing cannot drift into the next note.
            k = min(len(true_codes), len(retrieved_codes))
            y_true.extend(true_codes[:k])
            y_pred.extend(retrieved_codes[:k])

    avg_similarity = total_distance / total_items if total_items > 0 else 0

    if not y_true:
        # Nothing was paired; don't hand empty label lists to scikit-learn.
        return 0.0, 0.0, 0.0, avg_similarity

    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    return precision, recall, f1, avg_similarity

# Score retrieval on the held-out validation split once, at start-up.
validation_split = dataset['validation']
precision, recall, f1, avg_similarity = evaluate_efficiency(validation_split)

def gradio_interface(input_text):
    """Render the retrieval results for *input_text* as text for the Gradio output box.

    Each result becomes a numbered section showing its score (the Chroma
    distance that retrieve_relevant_text stores under ``similarity_score``),
    code, code system and description, followed by a dashed separator.
    Sections are joined with newlines; no results yields an empty string.
    """
    sections = []
    for rank, hit in enumerate(retrieve_relevant_text(input_text), start=1):
        sections.append(
            f"Result {rank}:\n"
            f"Similarity Score: {hit['similarity_score']:.2f}\n"
            f"Code: {hit['code']}\n"
            f"Code System: {hit['code_system']}\n"
            f"Description: {hit['description']}\n"
            "-------------------"
        )
    return "\n".join(sections)

# Summary line for the validation-set evaluation. avg_similarity is the mean
# Chroma distance of the retrieved results (lower = closer), not an accuracy, so
# it is labelled as a distance. Precision/recall/F1 are also available here.
# The summary is not shown in the UI yet (see the commented Markdown call below).
metrics = f"Mean retrieval distance: {avg_similarity:.2f}"

with gr.Blocks() as interface:
    gr.Markdown("# Automated Medical Coding POC")
    # To show the evaluation summary in the UI, add: gr.Markdown(metrics)
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Input Text")
            submit_button = gr.Button("Submit")
        with gr.Column():
            text_output = gr.Textbox(label="Retrieved Results", lines=10)
    submit_button.click(fn=gradio_interface, inputs=text_input, outputs=text_output)

interface.launch()