full commit
Browse files- main.py +179 -0
- requirements.txt +7 -0
main.py
CHANGED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import faiss
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
5 |
+
import json
|
6 |
+
import gradio as gr
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import tempfile
|
9 |
+
import os
|
10 |
+
|
11 |
+
class MedicalRAG:
|
12 |
+
def __init__(self, embed_path, pmids_path, content_path):
|
13 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
+
# Load data
|
15 |
+
self.embeddings = np.load(embed_path)
|
16 |
+
self.index = self._create_faiss_index(self.embeddings)
|
17 |
+
self.pmids, self.content = self._load_json_files(pmids_path, content_path)
|
18 |
+
# Setup models
|
19 |
+
self.encoder, self.tokenizer = self._setup_encoder()
|
20 |
+
self.generator = self._setup_generator()
|
21 |
+
|
22 |
+
def _create_faiss_index(self, embeddings):
|
23 |
+
index = faiss.IndexFlatIP(768) # 768 is embedding dimension
|
24 |
+
index.add(embeddings)
|
25 |
+
return index
|
26 |
+
|
27 |
+
def _load_json_files(self, pmids_path, content_path):
|
28 |
+
with open(pmids_path) as f:
|
29 |
+
pmids = json.load(f)
|
30 |
+
with open(content_path) as f:
|
31 |
+
content = json.load(f)
|
32 |
+
return pmids, content
|
33 |
+
|
34 |
+
def _setup_encoder(self):
|
35 |
+
model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(self.device)
|
36 |
+
tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
|
37 |
+
return model, tokenizer
|
38 |
+
|
39 |
+
def _setup_generator(self):
|
40 |
+
return pipeline(
|
41 |
+
"text-generation",
|
42 |
+
model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
43 |
+
device=self.device,
|
44 |
+
torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32
|
45 |
+
)
|
46 |
+
|
47 |
+
def encode_query(self, query):
|
48 |
+
with torch.no_grad():
|
49 |
+
inputs = self.tokenizer([query], truncation=True, padding=True,
|
50 |
+
return_tensors='pt', max_length=64).to(self.device)
|
51 |
+
embeddings = self.encoder(**inputs).last_hidden_state[:, 0, :]
|
52 |
+
return embeddings.cpu().numpy()
|
53 |
+
|
54 |
+
def search_documents(self, query_embedding, k=8):
|
55 |
+
scores, indices = self.index.search(query_embedding, k=k)
|
56 |
+
return [(self.pmids[idx], float(score)) for idx, score in zip(indices[0], scores[0])], indices[0]
|
57 |
+
|
58 |
+
def get_document_content(self, pmid):
|
59 |
+
doc = self.content.get(pmid, {})
|
60 |
+
return {
|
61 |
+
'title': doc.get('t', '').strip(),
|
62 |
+
'date': doc.get('d', '').strip(),
|
63 |
+
'abstract': doc.get('a', '').strip()
|
64 |
+
}
|
65 |
+
|
66 |
+
def visualize_embeddings(self, query_embed, relevant_indices, labels):
|
67 |
+
plt.figure(figsize=(20, len(relevant_indices) + 1))
|
68 |
+
|
69 |
+
# Prepare embeddings for visualization
|
70 |
+
embeddings = np.vstack([query_embed[0], self.embeddings[relevant_indices]])
|
71 |
+
normalized_embeddings = embeddings / np.max(np.abs(embeddings))
|
72 |
+
# plt
|
73 |
+
for idx, (embedding, label) in enumerate(zip(normalized_embeddings, labels)):
|
74 |
+
y_pos = len(labels) - 1 - idx
|
75 |
+
plt.imshow(embedding.reshape(1, -1), aspect='auto', extent=[0, 768, y_pos, y_pos+0.8],
|
76 |
+
cmap='inferno')
|
77 |
+
|
78 |
+
# Add labels and styling
|
79 |
+
plt.yticks(range(len(labels)), labels)
|
80 |
+
plt.xlabel('Embedding Dimensions')
|
81 |
+
plt.colorbar(label='Normalized Value')
|
82 |
+
plt.title('Query and Retrieved Document Embeddings')
|
83 |
+
|
84 |
+
# Save plot
|
85 |
+
temp_path = os.path.join(tempfile.gettempdir(), f'embeddings_{hash(str(embeddings))}.png')
|
86 |
+
plt.savefig(temp_path, bbox_inches='tight', dpi=150)
|
87 |
+
plt.close()
|
88 |
+
return temp_path
|
89 |
+
|
90 |
+
def generate_answer(self, query, contexts):
|
91 |
+
prompt = (
|
92 |
+
"<|im_start|>system\n"
|
93 |
+
"You are a helpful medical assistant. Answer questions based on the provided literature."
|
94 |
+
"<|im_end|>\n<|im_start|>user\n"
|
95 |
+
f"Based on these medical articles, answer this question:\n\n"
|
96 |
+
f"Question: {query}\n\n"
|
97 |
+
f"Relevant Literature:\n{contexts}\n"
|
98 |
+
"<|im_end|>\n<|im_start|>assistant"
|
99 |
+
)
|
100 |
+
|
101 |
+
response = self.generator(
|
102 |
+
prompt,
|
103 |
+
max_new_tokens=200,
|
104 |
+
temperature=0.3,
|
105 |
+
top_p=0.95,
|
106 |
+
do_sample=True
|
107 |
+
)
|
108 |
+
return response[0]['generated_text'].split("<|im_start|>assistant")[-1].strip()
|
109 |
+
|
110 |
+
def process_query(self, query):
|
111 |
+
try:
|
112 |
+
# Encode and search
|
113 |
+
query_embed = self.encode_query(query)
|
114 |
+
doc_matches, indices = self.search_documents(query_embed)
|
115 |
+
|
116 |
+
# Prepare documents and labels
|
117 |
+
documents = []
|
118 |
+
sources = []
|
119 |
+
labels = ["Query"]
|
120 |
+
|
121 |
+
for pmid, score in doc_matches:
|
122 |
+
doc = self.get_document_content(pmid)
|
123 |
+
if doc['abstract']:
|
124 |
+
documents.append(f"Title: {doc['title']}\nAbstract: {doc['abstract']}")
|
125 |
+
sources.append(f"PMID: {pmid}, Score: {score:.3f}, Link: https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
|
126 |
+
labels.append(f"Doc {len(labels)}: {doc['title'][:30]}...")
|
127 |
+
|
128 |
+
|
129 |
+
# Generate outputs
|
130 |
+
visualization = self.visualize_embeddings(query_embed, indices, labels)
|
131 |
+
answer = self.generate_answer(query, "\n\n".join(documents[:3]))
|
132 |
+
sources_text = "\n".join(sources)
|
133 |
+
context = "\n\n".join(documents)
|
134 |
+
|
135 |
+
return answer, sources_text, context, visualization
|
136 |
+
|
137 |
+
except Exception as e:
|
138 |
+
print(f"Error: {str(e)}")
|
139 |
+
return str(e), "Error retrieving sources", "", None
|
140 |
+
def create_interface():
|
141 |
+
rag = MedicalRAG(
|
142 |
+
embed_path="embeds_chunk_36.npy",
|
143 |
+
pmids_path="pmids_chunk_36.json",
|
144 |
+
content_path="pubmed_chunk_36.json"
|
145 |
+
)
|
146 |
+
|
147 |
+
with gr.Blocks(title="Medical Literature QA") as interface:
|
148 |
+
gr.Markdown("# Medical Literature Question Answering")
|
149 |
+
with gr.Row():
|
150 |
+
with gr.Column():
|
151 |
+
query = gr.Textbox(lines=2, placeholder="Enter your medical question...", label="Question")
|
152 |
+
submit = gr.Button("Submit", variant="primary")
|
153 |
+
sources = gr.Textbox(label="Sources", lines=3)
|
154 |
+
plot = gr.Image(label="Embedding Visualization")
|
155 |
+
with gr.Column():
|
156 |
+
answer = gr.Textbox(label="Answer", lines=5)
|
157 |
+
context = gr.Textbox(label="Context", lines=6)
|
158 |
+
with gr.Row():
|
159 |
+
gr.Examples(
|
160 |
+
examples=[
|
161 |
+
["What are the latest treatments for diabetes?"],
|
162 |
+
["How effective are COVID-19 vaccines?"],
|
163 |
+
["What are common symptoms of the flu?"],
|
164 |
+
["How can I maintain good heart health?"]
|
165 |
+
],
|
166 |
+
inputs=query
|
167 |
+
)
|
168 |
+
|
169 |
+
submit.click(
|
170 |
+
fn=rag.process_query,
|
171 |
+
inputs=query,
|
172 |
+
outputs=[answer, sources, context, plot]
|
173 |
+
)
|
174 |
+
|
175 |
+
return interface
|
176 |
+
|
177 |
+
if __name__ == "__main__":
|
178 |
+
demo = create_interface()
|
179 |
+
demo.launch(share=True)
|
requirements.txt
CHANGED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.4.1
|
2 |
+
transformers==4.33.2
|
3 |
+
faiss-cpu==1.9.0.post1
|
4 |
+
numpy==1.26.4
|
5 |
+
gradio==5.9.1
|
6 |
+
matplotlib==3.7.1
|
7 |
+
jsonschema==4.23.0
|