Michaelj1 committed on
Commit ef6a061 · 1 Parent(s): 8f4010a

full commit

Files changed (2)
  1. main.py +179 -0
  2. requirements.txt +7 -0
main.py CHANGED
@@ -0,0 +1,179 @@
+ import faiss
+ import torch
+ import numpy as np
+ from transformers import AutoTokenizer, AutoModel, pipeline
+ import json
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import tempfile
+ import os
+
+ class MedicalRAG:
+     def __init__(self, embed_path, pmids_path, content_path):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # Load precomputed article embeddings and the PMID/content lookups
+         self.embeddings = np.load(embed_path)
+         self.index = self._create_faiss_index(self.embeddings)
+         self.pmids, self.content = self._load_json_files(pmids_path, content_path)
+         # Set up the query encoder and the answer generator
+         self.encoder, self.tokenizer = self._setup_encoder()
+         self.generator = self._setup_generator()
+
+     def _create_faiss_index(self, embeddings):
+         # Size the index to the embedding dimension (768 for MedCPT)
+         index = faiss.IndexFlatIP(embeddings.shape[1])
+         # FAISS expects float32 vectors
+         index.add(embeddings.astype(np.float32))
+         return index
+
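+     # Note: IndexFlatIP ranks by raw inner product, matching how the MedCPT
+     # encoders are normally used. If cosine similarity were wanted instead,
+     # both sides would need L2 normalization first, roughly:
+     #   faiss.normalize_L2(embeddings)   # in place, before index.add(...)
+     #   faiss.normalize_L2(query_embed)  # in place, before index.search(...)
+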
+     def _load_json_files(self, pmids_path, content_path):
+         with open(pmids_path) as f:
+             pmids = json.load(f)
+         with open(content_path) as f:
+             content = json.load(f)
+         return pmids, content
+
+     def _setup_encoder(self):
+         # MedCPT query encoder: a BERT-style model trained to embed PubMed queries
+         model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(self.device)
+         tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
+         return model, tokenizer
+
+     def _setup_generator(self):
+         return pipeline(
+             "text-generation",
+             model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+             device=self.device,
+             # Half precision on GPU to save memory; full precision on CPU
+             torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32
+         )
+
+     def encode_query(self, query):
+         with torch.no_grad():
+             # max_length=64 follows the MedCPT query-encoder example usage;
+             # longer questions are silently truncated
+             inputs = self.tokenizer([query], truncation=True, padding=True,
+                                     return_tensors='pt', max_length=64).to(self.device)
+             # Use the [CLS] token embedding as the query vector
+             embeddings = self.encoder(**inputs).last_hidden_state[:, 0, :]
+         return embeddings.cpu().numpy()
+
+     def search_documents(self, query_embedding, k=8):
+         scores, indices = self.index.search(query_embedding, k=k)
+         return [(self.pmids[idx], float(score)) for idx, score in zip(indices[0], scores[0])], indices[0]
+
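+     # Illustrative return value of search_documents for k=2 (values made up):
+     #   ([('12345678', 0.91), ('23456789', 0.88)], array([412, 87]))
+     # i.e. (PMID, inner-product score) pairs plus the raw row indices into
+     # self.embeddings, which the embedding plot reuses below.
+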
+     def get_document_content(self, pmid):
+         # Records in the PubMed chunk file use short keys:
+         # 't' = title, 'd' = date, 'a' = abstract
+         doc = self.content.get(pmid, {})
+         return {
+             'title': doc.get('t', '').strip(),
+             'date': doc.get('d', '').strip(),
+             'abstract': doc.get('a', '').strip()
+         }
+
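+     # Hypothetical shape of one record in pubmed_chunk_36.json (values made up):
+     #   {"12345678": {"t": "Some article title",
+     #                 "d": "2020 Jan", "a": "Some abstract text ..."}}
+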
+     def visualize_embeddings(self, query_embed, relevant_indices, labels):
+         plt.figure(figsize=(20, len(relevant_indices) + 1))
+
+         # Stack the query vector on top of the retrieved document vectors
+         embeddings = np.vstack([query_embed[0], self.embeddings[relevant_indices]])
+         normalized_embeddings = embeddings / np.max(np.abs(embeddings))
+         dim = embeddings.shape[1]
+
+         # Draw each vector as a thin horizontal heat strip, query row on top
+         for idx, embedding in enumerate(normalized_embeddings):
+             y_pos = len(labels) - 1 - idx
+             plt.imshow(embedding.reshape(1, -1), aspect='auto',
+                        extent=[0, dim, y_pos, y_pos + 0.8], cmap='inferno')
+
+         # Tick each strip at its vertical center; rows are drawn top-down,
+         # so positions are reversed to keep labels aligned with strips
+         plt.yticks([len(labels) - 1 - i + 0.4 for i in range(len(labels))], labels)
+         plt.xlim(0, dim)
+         plt.ylim(0, len(labels))
+         plt.xlabel('Embedding Dimensions')
+         plt.colorbar(label='Normalized Value')
+         plt.title('Query and Retrieved Document Embeddings')
+
+         # Save to a temp file so Gradio can display it as an image
+         temp_path = os.path.join(tempfile.gettempdir(), f'embeddings_{hash(str(embeddings))}.png')
+         plt.savefig(temp_path, bbox_inches='tight', dpi=150)
+         plt.close()
+         return temp_path
+
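+     # Reading the plot: each strip is one 768-dimensional vector rendered as
+     # colour; strips with similar banding are close in embedding space, which
+     # is exactly what the inner-product search ranks on.
+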
+     def generate_answer(self, query, contexts):
+         # ChatML-style prompt, the format SmolLM2-Instruct was tuned on
+         prompt = (
+             "<|im_start|>system\n"
+             "You are a helpful medical assistant. Answer questions based on the provided literature."
+             "<|im_end|>\n<|im_start|>user\n"
+             "Based on these medical articles, answer this question:\n\n"
+             f"Question: {query}\n\n"
+             f"Relevant Literature:\n{contexts}\n"
+             "<|im_end|>\n<|im_start|>assistant"
+         )
+
+         response = self.generator(
+             prompt,
+             max_new_tokens=200,
+             temperature=0.3,
+             top_p=0.95,
+             do_sample=True
+         )
+         # The pipeline echoes the prompt, so keep only the assistant's reply
+         return response[0]['generated_text'].split("<|im_start|>assistant")[-1].strip()
+
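+     # A sturdier way to build the same prompt would be the tokenizer's chat
+     # template (needs a newer transformers than the 4.33.2 pinned in
+     # requirements.txt), sketched here for reference:
+     #   messages = [{"role": "system", "content": "You are a helpful medical assistant. ..."},
+     #               {"role": "user", "content": f"Question: {query}\n\n{contexts}"}]
+     #   prompt = self.generator.tokenizer.apply_chat_template(
+     #       messages, tokenize=False, add_generation_prompt=True)
+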
+     def process_query(self, query):
+         try:
+             # Encode the question and retrieve the top-k articles
+             query_embed = self.encode_query(query)
+             doc_matches, indices = self.search_documents(query_embed)
+
+             # Keep only documents that have an abstract, tracking their
+             # embedding rows so plot labels stay aligned with plot strips
+             documents = []
+             sources = []
+             labels = ["Query"]
+             kept_indices = []
+
+             for (pmid, score), idx in zip(doc_matches, indices):
+                 doc = self.get_document_content(pmid)
+                 if doc['abstract']:
+                     documents.append(f"Title: {doc['title']}\nAbstract: {doc['abstract']}")
+                     sources.append(f"PMID: {pmid}, Score: {score:.3f}, Link: https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
+                     labels.append(f"Doc {len(labels)}: {doc['title'][:30]}...")
+                     kept_indices.append(idx)
+
+             # Generate outputs; only the top 3 abstracts are sent to the LLM
+             visualization = (self.visualize_embeddings(query_embed, np.array(kept_indices), labels)
+                              if kept_indices else None)
+             answer = self.generate_answer(query, "\n\n".join(documents[:3]))
+             sources_text = "\n".join(sources)
+             context = "\n\n".join(documents)
+
+             return answer, sources_text, context, visualization
+
+         except Exception as e:
+             print(f"Error: {str(e)}")
+             return str(e), "Error retrieving sources", "", None
+
+ def create_interface():
+     rag = MedicalRAG(
+         embed_path="embeds_chunk_36.npy",
+         pmids_path="pmids_chunk_36.json",
+         content_path="pubmed_chunk_36.json"
+     )
+
+     with gr.Blocks(title="Medical Literature QA") as interface:
+         gr.Markdown("# Medical Literature Question Answering")
+         with gr.Row():
+             with gr.Column():
+                 query = gr.Textbox(lines=2, placeholder="Enter your medical question...", label="Question")
+                 submit = gr.Button("Submit", variant="primary")
+                 sources = gr.Textbox(label="Sources", lines=3)
+                 plot = gr.Image(label="Embedding Visualization")
+             with gr.Column():
+                 answer = gr.Textbox(label="Answer", lines=5)
+                 context = gr.Textbox(label="Context", lines=6)
+         with gr.Row():
+             gr.Examples(
+                 examples=[
+                     ["What are the latest treatments for diabetes?"],
+                     ["How effective are COVID-19 vaccines?"],
+                     ["What are common symptoms of the flu?"],
+                     ["How can I maintain good heart health?"]
+                 ],
+                 inputs=query
+             )
+
+         submit.click(
+             fn=rag.process_query,
+             inputs=query,
+             outputs=[answer, sources, context, plot]
+         )
+
+     return interface
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True)
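+ # To run locally (assuming the three chunk-36 data files sit next to main.py):
+ #   pip install -r requirements.txt
+ #   python main.py
+ # share=True additionally publishes a temporary public Gradio link.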
requirements.txt CHANGED
@@ -0,0 +1,7 @@
+ torch==2.4.1
+ transformers==4.33.2
+ faiss-cpu==1.9.0.post1
+ numpy==1.26.4
+ gradio==5.9.1
+ matplotlib==3.7.1
+ jsonschema==4.23.0