Upload run_on_your_notes_and_trials.ipynb
Browse files- run_on_your_notes_and_trials.ipynb +1174 -0
run_on_your_notes_and_trials.ipynb
ADDED
@@ -0,0 +1,1174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "e88793d7-e431-47dd-9964-0a633b94062b",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": []
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": 1,
|
14 |
+
"id": "f29a5b89-3b48-4217-8dfc-cca8222e2d1e",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [
|
17 |
+
{
|
18 |
+
"name": "stderr",
|
19 |
+
"output_type": "stream",
|
20 |
+
"text": [
|
21 |
+
"/homes10/klkehl/miniconda3/envs/vllm2/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
22 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
23 |
+
"2025-01-16 18:44:13,978\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
|
24 |
+
]
|
25 |
+
}
|
26 |
+
],
|
27 |
+
"source": [
|
28 |
+
"from vllm import LLM, SamplingParams\n",
|
29 |
+
"import pandas as pd\n",
|
30 |
+
"import numpy as np\n",
|
31 |
+
"import torch.nn.functional as F\n",
|
32 |
+
"import torch\n",
|
33 |
+
"from transformers import AutoTokenizer\n",
|
34 |
+
"from transformers import AutoModelForCausalLM\n",
|
35 |
+
"import re\n",
|
36 |
+
"import os\n",
|
37 |
+
"from transformers import pipeline, AutoModel\n",
|
38 |
+
"from torch.nn import functional as F\n",
|
39 |
+
"import torch.nn as nn\n",
|
40 |
+
"from torch.utils.data import DataLoader\n",
|
41 |
+
"from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU\n"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": null,
|
47 |
+
"id": "779db7b2-7bdb-4dea-968e-6bec3b1c892c",
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": []
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 2,
|
55 |
+
"id": "d394de92-98cf-40e2-aa08-4e4f60f195bc",
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [
|
58 |
+
{
|
59 |
+
"data": {
|
60 |
+
"text/plain": [
|
61 |
+
"device(type='cuda')"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
"execution_count": 2,
|
65 |
+
"metadata": {},
|
66 |
+
"output_type": "execute_result"
|
67 |
+
}
|
68 |
+
],
|
69 |
+
"source": [
|
70 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
71 |
+
"device"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": null,
|
77 |
+
"id": "889585e0-3c94-485a-932a-b1cec935b1b3",
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [],
|
80 |
+
"source": []
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": null,
|
85 |
+
"id": "a41f8e56-e779-4540-a484-a8ec622be396",
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [],
|
88 |
+
"source": []
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": 3,
|
93 |
+
"id": "099b73f1-94ce-4908-8951-0041ede61ee8",
|
94 |
+
"metadata": {},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"# here, pull in your raw patient clinical notes, imaging reports, and pathology reports\n",
|
98 |
+
"# your input file should contain at minimum columns like ['mrn', 'date', and 'text']; one row per clinical document\n",
|
99 |
+
"# you can combine notes from multiple patients into one input file as long as there is an mrn field\n",
|
100 |
+
"# this notebook expects MRNs to be called 'dfci_mrn, dates to be called 'date', and clinical text to be called 'text', so rename your columns accordingly\n",
|
101 |
+
"#all_reports = pd.read_csv(\"your_patient_notes_file_here.csv\")\n"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": 4,
|
107 |
+
"id": "02d6736c-5df1-4507-b5c4-b29ac6d8ba0e",
|
108 |
+
"metadata": {},
|
109 |
+
"outputs": [],
|
110 |
+
"source": [
|
111 |
+
"# this is how i pull reports for patients at dfci, commented out for public use\n",
|
112 |
+
"\n",
|
113 |
+
"\n",
|
114 |
+
"# prefix = '/data/clin_notes_outcomes/pan_dfci_2024/derived_data/'\n",
|
115 |
+
"\n",
|
116 |
+
"# # pull in our large corpus of historical electronic health records data\n",
|
117 |
+
"# imaging = pd.read_parquet(prefix + 'all_imaging_reports.parquet')\n",
|
118 |
+
"# medonc = pd.read_parquet(prefix + 'all_clinical_notes.parquet')\n",
|
119 |
+
"# path = pd.read_parquet(prefix + 'all_path_reports.parquet')\n",
|
120 |
+
"\n",
|
121 |
+
"\n",
|
122 |
+
"# all_reports = pd.concat([imaging, medonc, path], axis=0).sort_values(by=['dfci_mrn','date']).reset_index(drop=True)\n"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": null,
|
128 |
+
"id": "a44f0c77-840f-455b-bb62-055b21493324",
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [],
|
131 |
+
"source": []
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": 5,
|
136 |
+
"id": "aef31b68-98ac-4d51-a8d0-4adbcd2b42ff",
|
137 |
+
"metadata": {},
|
138 |
+
"outputs": [],
|
139 |
+
"source": [
|
140 |
+
"all_reports = all_reports.sort_values(by=['dfci_mrn','date']).reset_index(drop=True)\n"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"cell_type": "code",
|
145 |
+
"execution_count": 6,
|
146 |
+
"id": "d4f43b9c-0e6e-4907-a786-c1ae82ce240c",
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [
|
149 |
+
{
|
150 |
+
"name": "stdout",
|
151 |
+
"output_type": "stream",
|
152 |
+
"text": [
|
153 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
154 |
+
"Index: 622 entries, 1627657 to 13607361\n",
|
155 |
+
"Data columns (total 9 columns):\n",
|
156 |
+
" # Column Non-Null Count Dtype \n",
|
157 |
+
"--- ------ -------------- ----- \n",
|
158 |
+
" 0 dfci_mrn 622 non-null int64 \n",
|
159 |
+
" 1 date 622 non-null datetime64[ns]\n",
|
160 |
+
" 2 text 622 non-null object \n",
|
161 |
+
" 3 scan_type 283 non-null object \n",
|
162 |
+
" 4 split 622 non-null object \n",
|
163 |
+
" 5 note_type 622 non-null object \n",
|
164 |
+
" 6 department 268 non-null object \n",
|
165 |
+
" 7 provider_type 268 non-null object \n",
|
166 |
+
" 8 path_type 71 non-null object \n",
|
167 |
+
"dtypes: datetime64[ns](1), int64(1), object(7)\n",
|
168 |
+
"memory usage: 48.6+ KB\n"
|
169 |
+
]
|
170 |
+
}
|
171 |
+
],
|
172 |
+
"source": [
|
173 |
+
"# these are the fields in the raw DFCI data, yours will differ\n",
|
174 |
+
"ten_sample_patients = all_reports.dfci_mrn.sample(n=10)\n",
|
175 |
+
"all_reports = all_reports[all_reports.dfci_mrn.isin(ten_sample_patients)]\n",
|
176 |
+
"all_reports.info()"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 7,
|
182 |
+
"id": "fc5d06b4-8762-4462-8b1b-2cdbbd0a8cf3",
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [],
|
185 |
+
"source": [
|
186 |
+
"# the next set of cells works to extract useful information from each clinical note in your dataset, yielding one long history document for each patient"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": 8,
|
192 |
+
"id": "8828ae7b-7bbc-4aa7-afe1-ace1cf36df27",
|
193 |
+
"metadata": {},
|
194 |
+
"outputs": [
|
195 |
+
{
|
196 |
+
"name": "stderr",
|
197 |
+
"output_type": "stream",
|
198 |
+
"text": [
|
199 |
+
"/tmp/ipykernel_2451805/729519135.py:37: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
200 |
+
" themodel.load_state_dict(torch.load('./tiny_bert_tagger_synthetic.pt'))\n"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"data": {
|
205 |
+
"text/plain": [
|
206 |
+
"TagModel(\n",
|
207 |
+
" (bert): BertModel(\n",
|
208 |
+
" (embeddings): BertEmbeddings(\n",
|
209 |
+
" (word_embeddings): Embedding(30522, 128, padding_idx=0)\n",
|
210 |
+
" (position_embeddings): Embedding(512, 128)\n",
|
211 |
+
" (token_type_embeddings): Embedding(2, 128)\n",
|
212 |
+
" (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n",
|
213 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
214 |
+
" )\n",
|
215 |
+
" (encoder): BertEncoder(\n",
|
216 |
+
" (layer): ModuleList(\n",
|
217 |
+
" (0-1): 2 x BertLayer(\n",
|
218 |
+
" (attention): BertAttention(\n",
|
219 |
+
" (self): BertSdpaSelfAttention(\n",
|
220 |
+
" (query): Linear(in_features=128, out_features=128, bias=True)\n",
|
221 |
+
" (key): Linear(in_features=128, out_features=128, bias=True)\n",
|
222 |
+
" (value): Linear(in_features=128, out_features=128, bias=True)\n",
|
223 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
224 |
+
" )\n",
|
225 |
+
" (output): BertSelfOutput(\n",
|
226 |
+
" (dense): Linear(in_features=128, out_features=128, bias=True)\n",
|
227 |
+
" (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n",
|
228 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
229 |
+
" )\n",
|
230 |
+
" )\n",
|
231 |
+
" (intermediate): BertIntermediate(\n",
|
232 |
+
" (dense): Linear(in_features=128, out_features=512, bias=True)\n",
|
233 |
+
" (intermediate_act_fn): GELUActivation()\n",
|
234 |
+
" )\n",
|
235 |
+
" (output): BertOutput(\n",
|
236 |
+
" (dense): Linear(in_features=512, out_features=128, bias=True)\n",
|
237 |
+
" (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n",
|
238 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
239 |
+
" )\n",
|
240 |
+
" )\n",
|
241 |
+
" )\n",
|
242 |
+
" )\n",
|
243 |
+
" (pooler): BertPooler(\n",
|
244 |
+
" (dense): Linear(in_features=128, out_features=128, bias=True)\n",
|
245 |
+
" (activation): Tanh()\n",
|
246 |
+
" )\n",
|
247 |
+
" )\n",
|
248 |
+
" (prediction_heads): ModuleList(\n",
|
249 |
+
" (0-8): 9 x Sequential(\n",
|
250 |
+
" (0): Linear(in_features=128, out_features=128, bias=True)\n",
|
251 |
+
" (1): ReLU()\n",
|
252 |
+
" (2): Linear(in_features=128, out_features=1, bias=True)\n",
|
253 |
+
" )\n",
|
254 |
+
" )\n",
|
255 |
+
")"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
"execution_count": 8,
|
259 |
+
"metadata": {},
|
260 |
+
"output_type": "execute_result"
|
261 |
+
}
|
262 |
+
],
|
263 |
+
"source": [
|
264 |
+
"valid_tags_list = ['is_tagged','cancer_type','stage_at_diagnosis','treatment','cancer_burden','cancer_status','adverse_event','comorbidity','biomarker']\n",
|
265 |
+
"best_f1_thresholds = [-1.2996799,\n",
|
266 |
+
" 1.8744006,\n",
|
267 |
+
" -0.90340906,\n",
|
268 |
+
" -1.3298296,\n",
|
269 |
+
" -1.3740511,\n",
|
270 |
+
" -0.97108084,\n",
|
271 |
+
" -1.0886533,\n",
|
272 |
+
" -1.9212211,\n",
|
273 |
+
" -0.7184834]\n",
|
274 |
+
"\n",
|
275 |
+
"\n",
|
276 |
+
"\n",
|
277 |
+
" \n",
|
278 |
+
"class TagModel(nn.Module):\n",
|
279 |
+
"\n",
|
280 |
+
" def __init__(self, num_tags, device):\n",
|
281 |
+
" super(TagModel, self).__init__()\n",
|
282 |
+
" \n",
|
283 |
+
" self.bert = AutoModel.from_pretrained('prajjwal1/bert-tiny').to(device)\n",
|
284 |
+
"\n",
|
285 |
+
" self.prediction_heads = nn.ModuleList([Sequential(Linear(128, 128), ReLU(), Linear(128,1)).to(device) for x in range(0, num_tags)])\n",
|
286 |
+
" \n",
|
287 |
+
"\n",
|
288 |
+
" def forward(self, x_text_tensor, x_attention_mask):\n",
|
289 |
+
" \n",
|
290 |
+
" main = self.bert(x_text_tensor, x_attention_mask)\n",
|
291 |
+
" main = main.last_hidden_state[:,0,:].squeeze(1)\n",
|
292 |
+
"\n",
|
293 |
+
" outputs = [x(main) for x in self.prediction_heads]\n",
|
294 |
+
"\n",
|
295 |
+
" return outputs\n",
|
296 |
+
"\n",
|
297 |
+
"num_valid_tags = len(valid_tags_list)\n",
|
298 |
+
"themodel = TagModel(num_valid_tags, device)\n",
|
299 |
+
"themodel.load_state_dict(torch.load('./tiny_bert_tagger_synthetic.pt'))\n",
|
300 |
+
"themodel.to(device)\n",
|
301 |
+
"\n",
|
302 |
+
"themodel.eval()"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "code",
|
307 |
+
"execution_count": 9,
|
308 |
+
"id": "e6922167-1ea6-4515-8a05-2c52d5e2715e",
|
309 |
+
"metadata": {},
|
310 |
+
"outputs": [],
|
311 |
+
"source": [
|
312 |
+
"from torch.utils import data\n",
|
313 |
+
"from transformers import AutoTokenizer\n",
|
314 |
+
"\n",
|
315 |
+
"class UnlabeledTagDataset(data.Dataset):\n",
|
316 |
+
" def __init__(self, pandas_dataset, valid_tags_list):\n",
|
317 |
+
" self.data = pandas_dataset.copy().reset_index(drop=True)\n",
|
318 |
+
" self.indices = self.data.index.unique()\n",
|
319 |
+
" self.tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-tiny', max_length=128, truncation_side='left') \n",
|
320 |
+
" self.valid_tags_list = valid_tags_list\n",
|
321 |
+
" \n",
|
322 |
+
" def __len__(self):\n",
|
323 |
+
" # how many notes in the dataset\n",
|
324 |
+
" return len(self.indices)\n",
|
325 |
+
" \n",
|
326 |
+
" def __getitem__(self, index):\n",
|
327 |
+
" # get data for notes corresponding to indices passed\n",
|
328 |
+
" this_index = self.indices[index]\n",
|
329 |
+
" pand = self.data.loc[this_index, :]\n",
|
330 |
+
" \n",
|
331 |
+
" encoded = self.tokenizer(pand['excerpt'], padding='max_length', max_length=128, truncation=True)\n",
|
332 |
+
"\n",
|
333 |
+
" x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)\n",
|
334 |
+
" x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)\n",
|
335 |
+
" \n",
|
336 |
+
"\n",
|
337 |
+
" return x_text_tensor, x_attention_mask\n",
|
338 |
+
" \n",
|
339 |
+
"def extract_relevant_text_from_patient(patient_frame_original, valid_tags_list, best_f1_thresholds, tagger_model):\n",
|
340 |
+
" num_valid_tags = len(valid_tags_list)\n",
|
341 |
+
" patient_frame = patient_frame_original.copy()\n",
|
342 |
+
" patient_frame['date'] = pd.to_datetime(patient_frame.date)\n",
|
343 |
+
" patient_frame = patient_frame.sort_values(by='date').reset_index()\n",
|
344 |
+
" chunk_frames = []\n",
|
345 |
+
" for i in range(0, patient_frame.shape[0]):\n",
|
346 |
+
" chunks = re.sub(\"\\n|\\r\", \" \", patient_frame.iloc[i].text.strip())\n",
|
347 |
+
" chunks = re.sub(r'\\s+', \" \", chunks)\n",
|
348 |
+
" chunks = \"<excerpt break>\" + re.sub(\"\\\\. \", \"<excerpt break>\", chunks) + \"<excerpt break>\"\n",
|
349 |
+
" chunks = pd.Series(chunks.split(\"<excerpt break>\")).str.strip()\n",
|
350 |
+
" chunks = chunks[chunks != '']\n",
|
351 |
+
" \n",
|
352 |
+
" chunk_frame = pd.DataFrame({'date':patient_frame.iloc[i].date, 'note_type':patient_frame.iloc[i].note_type, 'excerpt':chunks})\n",
|
353 |
+
" chunk_frames.append(chunk_frame)\n",
|
354 |
+
"\n",
|
355 |
+
" if len(chunk_frames) > 0:\n",
|
356 |
+
" chunk_frames = pd.concat(chunk_frames, axis=0)\n",
|
357 |
+
" chunk_frames = chunk_frames.drop_duplicates(subset=['excerpt'], keep='first')\n",
|
358 |
+
" \n",
|
359 |
+
" no_shuffle_valid_dataset = data.DataLoader(UnlabeledTagDataset(chunk_frames, valid_tags_list), batch_size=32, shuffle=False, num_workers=0)\n",
|
360 |
+
"\n",
|
361 |
+
" output_prediction_lists = [[] for x in range(num_valid_tags)]\n",
|
362 |
+
" for batch in no_shuffle_valid_dataset:\n",
|
363 |
+
" x_text_ids = batch[0].to(device)\n",
|
364 |
+
" x_attention_mask = batch[1].to(device)\n",
|
365 |
+
" with torch.no_grad():\n",
|
366 |
+
" predictions = tagger_model(x_text_ids, x_attention_mask)\n",
|
367 |
+
" \n",
|
368 |
+
" for j in range(num_valid_tags):\n",
|
369 |
+
" output_prediction_lists[j].append(predictions[j].squeeze(1).detach().cpu().numpy())\n",
|
370 |
+
" \n",
|
371 |
+
" output_prediction_lists = [np.concatenate(x) for x in output_prediction_lists]\n",
|
372 |
+
" \n",
|
373 |
+
" \n",
|
374 |
+
" output = chunk_frames.copy()\n",
|
375 |
+
" for x in range(num_valid_tags):\n",
|
376 |
+
" output['outcome_' + str(x) + '_logit'] = output_prediction_lists[x]\n",
|
377 |
+
" \n",
|
378 |
+
" output = output[output.outcome_0_logit > best_f1_thresholds[0]]\n",
|
379 |
+
"\n",
|
380 |
+
" output = output.groupby(['date','note_type'])['excerpt'].agg('. '.join).reset_index()\n",
|
381 |
+
" output = output[~output.excerpt.isnull()]\n",
|
382 |
+
" output['date_text'] = output['date'].astype(str) + \" \" + output['note_type'] + \" \" + output['excerpt']\n",
|
383 |
+
" return \"\\n\".join(output.date_text.tolist())\n",
|
384 |
+
" else:\n",
|
385 |
+
" return \"\"\n",
|
386 |
+
" "
|
387 |
+
]
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"execution_count": null,
|
392 |
+
"id": "3ce166f0-2d21-47a3-85f0-bb0d96e28dc5",
|
393 |
+
"metadata": {},
|
394 |
+
"outputs": [],
|
395 |
+
"source": []
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "code",
|
399 |
+
"execution_count": 10,
|
400 |
+
"id": "d1b9785a-2ecd-454d-a6be-26d106f0b827",
|
401 |
+
"metadata": {},
|
402 |
+
"outputs": [],
|
403 |
+
"source": [
|
404 |
+
"%%capture\n",
|
405 |
+
"# this generates a data frame with one row per patient, and a patient_long_text column with a bunch of relevant text extracted from each patient's notes\n",
|
406 |
+
"\n",
|
407 |
+
"patient_list = []\n",
|
408 |
+
"unique_patients = all_reports.groupby('dfci_mrn').first().reset_index()[['dfci_mrn']]\n",
|
409 |
+
"for i in range(unique_patients.shape[0]):\n",
|
410 |
+
" unique_patient = unique_patients.iloc[[i]]\n",
|
411 |
+
" patient_frame = all_reports[all_reports.dfci_mrn == unique_patient.dfci_mrn.iloc[0]]\n",
|
412 |
+
" if patient_frame.shape[0] > 0:\n",
|
413 |
+
" # this next line is used for retrospective analysis to restrict input text to text predating a treatment start\n",
|
414 |
+
" #patient_frame = patient_frame[pd.to_datetime(patient_frame.date) < patient_frame.treatment_start_date.iloc[0]]\n",
|
415 |
+
" unique_patient['patient_long_text'] = extract_relevant_text_from_patient(patient_frame, valid_tags_list, best_f1_thresholds, themodel)\n",
|
416 |
+
" patient_list.append(unique_patient) "
|
417 |
+
]
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"cell_type": "code",
|
421 |
+
"execution_count": 11,
|
422 |
+
"id": "86428fae-c91c-4b1b-88da-344dec8d6074",
|
423 |
+
"metadata": {},
|
424 |
+
"outputs": [
|
425 |
+
{
|
426 |
+
"name": "stdout",
|
427 |
+
"output_type": "stream",
|
428 |
+
"text": [
|
429 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
430 |
+
"Index: 10 entries, 0 to 9\n",
|
431 |
+
"Data columns (total 2 columns):\n",
|
432 |
+
" # Column Non-Null Count Dtype \n",
|
433 |
+
"--- ------ -------------- ----- \n",
|
434 |
+
" 0 dfci_mrn 10 non-null int64 \n",
|
435 |
+
" 1 patient_long_text 10 non-null object\n",
|
436 |
+
"dtypes: int64(1), object(1)\n",
|
437 |
+
"memory usage: 240.0+ bytes\n"
|
438 |
+
]
|
439 |
+
}
|
440 |
+
],
|
441 |
+
"source": [
|
442 |
+
"long_histories = pd.concat(patient_list, axis=0)\n",
|
443 |
+
"long_histories.info()"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"cell_type": "code",
|
448 |
+
"execution_count": 12,
|
449 |
+
"id": "40398c74-1f19-40e0-aefa-f5116eca7d7f",
|
450 |
+
"metadata": {},
|
451 |
+
"outputs": [],
|
452 |
+
"source": [
|
453 |
+
"# now you have long histories for each patient\n",
|
454 |
+
"# delete tiny bert tagging model to make room on GPU for llama\n",
|
455 |
+
"del themodel"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "code",
|
460 |
+
"execution_count": null,
|
461 |
+
"id": "77e81c05-2f29-4bff-9c06-3ff832c50e3b",
|
462 |
+
"metadata": {},
|
463 |
+
"outputs": [],
|
464 |
+
"source": []
|
465 |
+
},
|
466 |
+
{
|
467 |
+
"cell_type": "code",
|
468 |
+
"execution_count": null,
|
469 |
+
"id": "c0d98783-0f20-444f-8d14-cf059641bafb",
|
470 |
+
"metadata": {},
|
471 |
+
"outputs": [],
|
472 |
+
"source": []
|
473 |
+
},
|
474 |
+
{
|
475 |
+
"cell_type": "code",
|
476 |
+
"execution_count": 13,
|
477 |
+
"id": "d1ce57c1-3932-4fd9-8def-db554535c914",
|
478 |
+
"metadata": {},
|
479 |
+
"outputs": [],
|
480 |
+
"source": [
|
481 |
+
"# now get ready to use llama to summarize patient histories and extract trial spaces"
|
482 |
+
]
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"cell_type": "code",
|
486 |
+
"execution_count": null,
|
487 |
+
"id": "5d1dd378-4c37-4ade-92ac-7bd982d1355e",
|
488 |
+
"metadata": {},
|
489 |
+
"outputs": [],
|
490 |
+
"source": []
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"cell_type": "code",
|
494 |
+
"execution_count": 14,
|
495 |
+
"id": "4f074561-80ae-44a0-8085-9cac394b80a7",
|
496 |
+
"metadata": {},
|
497 |
+
"outputs": [
|
498 |
+
{
|
499 |
+
"name": "stdout",
|
500 |
+
"output_type": "stream",
|
501 |
+
"text": [
|
502 |
+
"INFO 01-16 18:49:14 awq_marlin.py:97] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.\n",
|
503 |
+
"INFO 01-16 18:49:14 config.py:905] Defaulting to use mp for distributed inference\n",
|
504 |
+
"WARNING 01-16 18:49:14 arg_utils.py:957] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n",
|
505 |
+
"INFO 01-16 18:49:14 config.py:1021] Chunked prefill is enabled with max_num_batched_tokens=512.\n",
|
506 |
+
"INFO 01-16 18:49:14 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', speculative_config=None, tokenizer='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=120000, download_dir='../meta_ai/', load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4, num_scheduler_steps=1, chunked_prefill_enabled=True multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)\n",
|
507 |
+
"WARNING 01-16 18:49:14 multiproc_gpu_executor.py:127] CUDA was previously initialized. We must use the `spawn` multiprocessing start method. Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.\n",
|
508 |
+
"WARNING 01-16 18:49:14 multiproc_gpu_executor.py:53] Reducing Torch parallelism from 32 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.\n",
|
509 |
+
"INFO 01-16 18:49:14 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager\n",
|
510 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:49:19 multiproc_worker_utils.py:215] Worker ready; awaiting tasks\n",
|
511 |
+
"INFO 01-16 18:49:20 utils.py:1008] Found nccl from library libnccl.so.2\n",
|
512 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:49:20 utils.py:1008] Found nccl from library libnccl.so.2\n",
|
513 |
+
"INFO 01-16 18:49:20 pynccl.py:63] vLLM is using nccl==2.20.5\n",
|
514 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:49:20 pynccl.py:63] vLLM is using nccl==2.20.5\n",
|
515 |
+
"INFO 01-16 18:49:20 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /homes10/klkehl/.cache/vllm/gpu_p2p_access_cache_for_2,3.json\n",
|
516 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:49:20 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /homes10/klkehl/.cache/vllm/gpu_p2p_access_cache_for_2,3.json\n",
|
517 |
+
"INFO 01-16 18:49:20 shm_broadcast.py:241] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7f4951076fc0>, local_subscribe_port=52437, remote_subscribe_port=None)\n",
|
518 |
+
"INFO 01-16 18:49:20 model_runner.py:1056] Starting to load model hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4...\n",
|
519 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:49:20 model_runner.py:1056] Starting to load model hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4...\n",
|
520 |
+
"INFO 01-16 18:49:21 weight_utils.py:243] Using model weights format ['*.safetensors']\n",
|
521 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:49:21 weight_utils.py:243] Using model weights format ['*.safetensors']\n"
|
522 |
+
]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"name": "stderr",
|
526 |
+
"output_type": "stream",
|
527 |
+
"text": [
|
528 |
+
"Loading safetensors checkpoint shards: 0% Completed | 0/9 [00:00<?, ?it/s]\n",
|
529 |
+
"Loading safetensors checkpoint shards: 11% Completed | 1/9 [00:17<02:23, 17.88s/it]\n",
|
530 |
+
"Loading safetensors checkpoint shards: 22% Completed | 2/9 [01:02<03:57, 33.89s/it]\n",
|
531 |
+
"Loading safetensors checkpoint shards: 33% Completed | 3/9 [01:49<03:58, 39.79s/it]\n",
|
532 |
+
"Loading safetensors checkpoint shards: 44% Completed | 4/9 [02:36<03:32, 42.58s/it]\n",
|
533 |
+
"Loading safetensors checkpoint shards: 56% Completed | 5/9 [03:23<02:56, 44.20s/it]\n",
|
534 |
+
"Loading safetensors checkpoint shards: 67% Completed | 6/9 [03:56<02:00, 40.29s/it]\n",
|
535 |
+
"Loading safetensors checkpoint shards: 78% Completed | 7/9 [04:43<01:25, 42.53s/it]\n",
|
536 |
+
"Loading safetensors checkpoint shards: 89% Completed | 8/9 [05:30<00:43, 43.99s/it]\n",
|
537 |
+
"Loading safetensors checkpoint shards: 100% Completed | 9/9 [06:16<00:00, 44.57s/it]\n",
|
538 |
+
"Loading safetensors checkpoint shards: 100% Completed | 9/9 [06:16<00:00, 41.83s/it]\n",
|
539 |
+
"\n"
|
540 |
+
]
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"name": "stdout",
|
544 |
+
"output_type": "stream",
|
545 |
+
"text": [
|
546 |
+
"INFO 01-16 18:55:45 model_runner.py:1067] Loading model weights took 18.5818 GB\n",
|
547 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:55:47 model_runner.py:1067] Loading model weights took 18.5807 GB\n",
|
548 |
+
"INFO 01-16 18:55:48 distributed_gpu_executor.py:57] # GPU blocks: 17638, # CPU blocks: 1638\n",
|
549 |
+
"INFO 01-16 18:55:48 distributed_gpu_executor.py:61] Maximum concurrency for 120000 tokens per request: 2.35x\n",
|
550 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:55:53 model_runner.py:1395] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
|
551 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:55:53 model_runner.py:1399] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
|
552 |
+
"INFO 01-16 18:55:53 model_runner.py:1395] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
|
553 |
+
"INFO 01-16 18:55:53 model_runner.py:1399] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
|
554 |
+
"INFO 01-16 18:56:21 custom_all_reduce.py:233] Registering 5635 cuda graph addresses\n",
|
555 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:56:21 custom_all_reduce.py:233] Registering 5635 cuda graph addresses\n",
|
556 |
+
"\u001b[1;36m(VllmWorkerProcess pid=2453819)\u001b[0;0m INFO 01-16 18:56:21 model_runner.py:1523] Graph capturing finished in 29 secs.\n",
|
557 |
+
"INFO 01-16 18:56:21 model_runner.py:1523] Graph capturing finished in 28 secs.\n"
|
558 |
+
]
|
559 |
+
}
|
560 |
+
],
|
561 |
+
"source": [
|
562 |
+
"# load llama\n",
|
563 |
+
"# modify this depending on your GPU setup and where you want to dowwnload the llm\n",
|
564 |
+
"# requires vllm\n",
|
565 |
+
"import os\n",
|
566 |
+
"from vllm import LLM\n",
|
567 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2,3\"\n",
|
568 |
+
"llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, download_dir = \"../meta_ai/\", gpu_memory_utilization=0.80, max_model_len=120000)"
|
569 |
+
]
|
570 |
+
},
|
571 |
+
{
|
572 |
+
"cell_type": "code",
|
573 |
+
"execution_count": null,
|
574 |
+
"id": "2a2882e7-a180-48e1-a50b-ec95730e08ce",
|
575 |
+
"metadata": {},
|
576 |
+
"outputs": [],
|
577 |
+
"source": []
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"cell_type": "code",
|
581 |
+
"execution_count": null,
|
582 |
+
"id": "485300a3-4661-4db5-b656-7aad11e9e8d3",
|
583 |
+
"metadata": {},
|
584 |
+
"outputs": [],
|
585 |
+
"source": []
|
586 |
+
},
|
587 |
+
{
|
588 |
+
"cell_type": "code",
|
589 |
+
"execution_count": null,
|
590 |
+
"id": "3eb4a246-27a7-4fa9-915d-17f463e23171",
|
591 |
+
"metadata": {},
|
592 |
+
"outputs": [],
|
593 |
+
"source": []
|
594 |
+
},
|
595 |
+
{
|
596 |
+
"cell_type": "code",
|
597 |
+
"execution_count": 15,
|
598 |
+
"id": "418b71f0-9c4a-4edc-b810-23db5b4e40f4",
|
599 |
+
"metadata": {},
|
600 |
+
"outputs": [],
|
601 |
+
"source": [
|
602 |
+
"# generate summaries for our patients"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"cell_type": "code",
|
607 |
+
"execution_count": 16,
|
608 |
+
"id": "ba00cbee-7b60-44c4-95e9-9da113de6de4",
|
609 |
+
"metadata": {},
|
610 |
+
"outputs": [],
|
611 |
+
"source": [
|
612 |
+
"def summarize_patients(patient_texts, llama_model):\n",
|
613 |
+
" \n",
|
614 |
+
"\n",
|
615 |
+
" prompts = []\n",
|
616 |
+
"\n",
|
617 |
+
" tokenizer = llama_model.get_tokenizer()\n",
|
618 |
+
"\n",
|
619 |
+
" prompts = []\n",
|
620 |
+
" for the_patient in patient_texts:\n",
|
621 |
+
"\n",
|
622 |
+
" patient_text_tokens = tokenizer(the_patient, add_special_tokens=False).input_ids\n",
|
623 |
+
" if len(patient_text_tokens) > 115000:\n",
|
624 |
+
" first_part = patient_text_tokens[:57500]\n",
|
625 |
+
" # Slice the last `slice_size` elements\n",
|
626 |
+
" last_part = patient_text_tokens[-57500:]\n",
|
627 |
+
" # Concatenate the two slices\n",
|
628 |
+
" patient_text_tokens = first_part + last_part\n",
|
629 |
+
" patient_text = tokenizer.decode(patient_text_tokens)\n",
|
630 |
+
" \n",
|
631 |
+
" messages = [{'role':'system', 'content': \"\"\"You are an experienced clinical oncology history summarization bot.\n",
|
632 |
+
" Your job is to construct a summary of the cancer history for a patient based on an excerpt of the patient's electronic health record. The text in the excerpt is provided in chronological order. \n",
|
633 |
+
" Document the cancer type/primary site (eg breast cancer, lung cancer, etc); histology (eg adenocarcinoma, squamous carcinoma, etc); current extent (localized, advanced, metastatic, etc); biomarkers (genomic results, protein expression, etc); and treatment history (surgery, radiation, chemotherapy/targeted therapy/immunotherapy, etc, including start and stop dates and best response if known).\n",
|
634 |
+
" Do not consider localized basal cell or squamous carcinomas of the skin, or colon polyps, to be cancers for your purposes.\n",
|
635 |
+
" Do not include the patient's name, but do include relevant dates whenever documented, including dates of diagnosis and start/stop dates of each treatment.\n",
|
636 |
+
" If a patient has a history of more than one cancer, document the cancers one at a time.\n",
|
637 |
+
" \"\"\"}, \n",
|
638 |
+
" {'role':'user', 'content': \"The excerpt is:\\n\" + the_patient + \"\"\"Now, write your summary. Do not add preceding text before the abstraction, and do not add notes or commentary afterwards. This will not be used for clinical care, so do not write any disclaimers or cautionary notes.\"\"\"}\n",
|
639 |
+
"\n",
|
640 |
+
" ]\n",
|
641 |
+
" \n",
|
642 |
+
"\n",
|
643 |
+
"\n",
|
644 |
+
" prompts.append(messages)\n",
|
645 |
+
"\n",
|
646 |
+
" trunc_messages = [x[1]['content'] for x in prompts]\n",
|
647 |
+
"\n",
|
648 |
+
" newprompts = []\n",
|
649 |
+
" for i, messages in enumerate(prompts):\n",
|
650 |
+
" messages[1]['content'] = trunc_messages[i]\n",
|
651 |
+
" template_prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
|
652 |
+
" newprompts.append(template_prompt)\n",
|
653 |
+
" \n",
|
654 |
+
"\n",
|
655 |
+
" \n",
|
656 |
+
" responses = llama_model.generate(\n",
|
657 |
+
" newprompts, \n",
|
658 |
+
" SamplingParams(\n",
|
659 |
+
" temperature=0.0,\n",
|
660 |
+
" top_p=0.2,\n",
|
661 |
+
" max_tokens=4096,\n",
|
662 |
+
" repetition_penalty=1.2,\n",
|
663 |
+
" stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
|
664 |
+
" ))\n",
|
665 |
+
"\n",
|
666 |
+
" response_texts = [x.outputs[0].text for x in responses]\n",
|
667 |
+
"\n",
|
668 |
+
"\n",
|
669 |
+
" return responses, response_texts\n",
|
670 |
+
" "
|
671 |
+
]
|
672 |
+
},
|
673 |
+
{
|
674 |
+
"cell_type": "code",
|
675 |
+
"execution_count": 17,
|
676 |
+
"id": "7de591f3-154f-4991-83db-0b45979e856a",
|
677 |
+
"metadata": {},
|
678 |
+
"outputs": [
|
679 |
+
{
|
680 |
+
"name": "stderr",
|
681 |
+
"output_type": "stream",
|
682 |
+
"text": [
|
683 |
+
"Processed prompts: 100%|█████████| 10/10 [00:47<00:00, 4.71s/it, est. speed input: 925.94 toks/s, output: 42.00 toks/s]\n"
|
684 |
+
]
|
685 |
+
}
|
686 |
+
],
|
687 |
+
"source": [
|
688 |
+
"long_histories['patient_summary'] = summarize_patients(long_histories.patient_long_text.tolist(), llama)[1]"
|
689 |
+
]
|
690 |
+
},
|
691 |
+
{
|
692 |
+
"cell_type": "code",
|
693 |
+
"execution_count": 18,
|
694 |
+
"id": "9605f785-8667-441f-838c-5dcf9eac19a9",
|
695 |
+
"metadata": {},
|
696 |
+
"outputs": [],
|
697 |
+
"source": [
|
698 |
+
"# now we turn attention to the clinical trials we want to match against\n",
|
699 |
+
"# assume you have a dataset of trials, each with an eligibilty_criteria text field as from clinicaltrials.gov\n",
|
700 |
+
"# here, i just used a download from ct.gov for trials relating to cancer\n",
|
701 |
+
"trials = pd.read_csv('ctgov_cancer_trials.csv')"
|
702 |
+
]
|
703 |
+
},
|
704 |
+
{
|
705 |
+
"cell_type": "code",
|
706 |
+
"execution_count": 19,
|
707 |
+
"id": "b13620e6-cdde-4668-a992-1e8fddf4ab8a",
|
708 |
+
"metadata": {},
|
709 |
+
"outputs": [],
|
710 |
+
"source": [
|
711 |
+
"# ultimately you want to have a raw trial_text field that combines the trial title, summary, and eligibility criteria text from ct.gov\n",
|
712 |
+
"trials['trial_text'] = trials['title'] + \"\\n\" + trials['brief_summary'] + \"\\n\" + trials['eligibility_criteria']"
|
713 |
+
]
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"cell_type": "code",
|
717 |
+
"execution_count": 20,
|
718 |
+
"id": "ee5d88da-8867-4aae-892a-799dbffeb8c9",
|
719 |
+
"metadata": {},
|
720 |
+
"outputs": [],
|
721 |
+
"source": [
|
722 |
+
"# now summarize the trials of interest to you based on the trial_text field\n",
|
723 |
+
"def summarize_trials_multi_cohort(eligibility_texts, llama_model):\n",
|
724 |
+
"\n",
|
725 |
+
" tokenizer = llama.get_tokenizer()\n",
|
726 |
+
" prompts = []\n",
|
727 |
+
" for trial in eligibility_texts:\n",
|
728 |
+
" messages = [\n",
|
729 |
+
" {'role':'system', 'content': \"\"\"You are an expert clinical oncologist with an encyclopedic knowledge of cancer and its treatments.\n",
|
730 |
+
" Your job is to review a clinical trial document and extract a list of structured clinical spaces that are eligible for that trial.\n",
|
731 |
+
" A clinical space is defined as a unique combination of cancer primary site, histology, which treatments a patient must have received, which treatments a patient must not have received, cancer burden (eg presence of metastatic disease), and tumor biomarkers (such as germline or somatic gene mutations or alterations, or protein expression on tumor) that a patient must have or must not have; that renders a patient eligible for the trial.\n",
|
732 |
+
" Trials often specify that a particular treatment is excluded only if it was given within a short period of time, for example 14 days, one month, etc , prior to trial start. Do not include this type of time-specific treatment eligibility criteria in your output at all.\n",
|
733 |
+
" Some trials have only one space, while others have several. Do not output a space that contains multiple cancer types and/or histologies. Instead, generate separate spaces for each cancer type/histology combination.\n",
|
734 |
+
" For biomarkers, if the trial specifies whether the biomarker will be assessed during screening, note that.\n",
|
735 |
+
" Spell out cancer types; do not abbreviate them. For example, write \"non-small cell lung cancer\" rather than \"NSCLC\".\n",
|
736 |
+
" Structure your output like this, as a list of spaces, with spaces separated by newlines, as below:\n",
|
737 |
+
" 1. Cancer type allowed: <cancer_type_allowed>. Histology allowed: <histology_allowed>. Cancer burden allowed: <cancer_burden_allowed>. Prior treatment required: <prior_treatments_requred>. Prior treatment excluded: <prior_treatments_excluded>. Biomarkers required: <biomarkers_required>. Biomarkers excluded: <biomarkers_excluded>.\n",
|
738 |
+
" 2. Cancer type allowed: <cancer_type_allowed>, etc.\n",
|
739 |
+
" If a particular concept is not mentioned in the trial text, do not include it in your definition of trial space(s).\n",
|
740 |
+
" \"\"\"}, \n",
|
741 |
+
" \n",
|
742 |
+
" {'role':'user', 'content': \"Here is a clinical trial document: \\n\" + trial + \"\\n\" + \"\"\"Now, generate your list of the trial space(s), formatted as above.\n",
|
743 |
+
" Do not provide any introductory, explanatory, concluding, or disclaimer text.\n",
|
744 |
+
" Reminder: Treatment history is an important component of trial space definitions, but treatment history requirements that are described as applying only in a given period of time prior to trial treatment MUST BE IGNORED.\"\"\"\n",
|
745 |
+
" }\n",
|
746 |
+
" ]\n",
|
747 |
+
" \n",
|
748 |
+
" prompts.append(tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False))\n",
|
749 |
+
" \n",
|
750 |
+
"\n",
|
751 |
+
" \n",
|
752 |
+
" responses = llama_model.generate(\n",
|
753 |
+
" prompts, \n",
|
754 |
+
" SamplingParams(\n",
|
755 |
+
" temperature=0.0,\n",
|
756 |
+
" top_p=0.9,\n",
|
757 |
+
" max_tokens=3096,\n",
|
758 |
+
" stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
|
759 |
+
" ))\n",
|
760 |
+
"\n",
|
761 |
+
" response_texts = [x.outputs[0].text for x in responses]\n",
|
762 |
+
"\n",
|
763 |
+
"\n",
|
764 |
+
" return responses, response_texts"
|
765 |
+
]
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"cell_type": "code",
|
769 |
+
"execution_count": 21,
|
770 |
+
"id": "f05c9d94-9190-4d29-bcf8-0800fbb37f42",
|
771 |
+
"metadata": {},
|
772 |
+
"outputs": [],
|
773 |
+
"source": [
|
774 |
+
"# this runs the trial summarization/space extraction\n",
|
775 |
+
"# i have a premade trial spaces file, so this is commented out\n",
|
776 |
+
"# trials['spaces'] = summarize_trials_multi_cohort(trials.trial_text.tolist(), llama)[1]\n",
|
777 |
+
"#trials.to_csv('ctgov_all_trials_unique_trial_spaces_10-31-24.csv')"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "code",
|
782 |
+
"execution_count": 22,
|
783 |
+
"id": "14d40b23-fcfa-4fe6-a0ff-c58d9750cab2",
|
784 |
+
"metadata": {},
|
785 |
+
"outputs": [],
|
786 |
+
"source": [
|
787 |
+
"trials = pd.read_csv('ctgov_all_trials_unique_trial_spaces_10-31-24.csv')"
|
788 |
+
]
|
789 |
+
},
|
790 |
+
{
|
791 |
+
"cell_type": "code",
|
792 |
+
"execution_count": 23,
|
793 |
+
"id": "8ee87e63-ef82-43a5-b243-0de4561a8bd0",
|
794 |
+
"metadata": {},
|
795 |
+
"outputs": [],
|
796 |
+
"source": [
|
797 |
+
"# now parse the extracted trial spaces to get one row per space (can be one or more rows per trial)"
|
798 |
+
]
|
799 |
+
},
|
800 |
+
{
|
801 |
+
"cell_type": "code",
|
802 |
+
"execution_count": 24,
|
803 |
+
"id": "9cc06840-5647-4524-a7bf-a1ad53a07b7c",
|
804 |
+
"metadata": {},
|
805 |
+
"outputs": [],
|
806 |
+
"source": [
|
807 |
+
"frames = []\n",
|
808 |
+
"for i in range(trials.shape[0]):\n",
|
809 |
+
" cohorts = pd.Series(trials.iloc[i].spaces.split(\"\\n\"))\n",
|
810 |
+
" cohorts = cohorts[~((cohorts.isnull()) | (cohorts == \"\\n\") | (cohorts == ''))].reset_index(drop=True)\n",
|
811 |
+
" frame = pd.DataFrame(np.repeat(trials.iloc[[i]], len(cohorts), axis=0), columns=trials.columns)\n",
|
812 |
+
" frame['this_space'] = cohorts\n",
|
813 |
+
" frame['space_number'] = frame.index\n",
|
814 |
+
" frames.append(frame)\n",
|
815 |
+
" "
|
816 |
+
]
|
817 |
+
},
|
818 |
+
{
|
819 |
+
"cell_type": "code",
|
820 |
+
"execution_count": 25,
|
821 |
+
"id": "541669eb-f92e-49f3-9a36-b6625448c1a4",
|
822 |
+
"metadata": {},
|
823 |
+
"outputs": [],
|
824 |
+
"source": [
|
825 |
+
"cohort_level_trials = pd.concat(frames, axis=0)"
|
826 |
+
]
|
827 |
+
},
|
828 |
+
{
|
829 |
+
"cell_type": "code",
|
830 |
+
"execution_count": 26,
|
831 |
+
"id": "51a04e84-7483-4398-b4a0-d0cdab790609",
|
832 |
+
"metadata": {},
|
833 |
+
"outputs": [
|
834 |
+
{
|
835 |
+
"name": "stdout",
|
836 |
+
"output_type": "stream",
|
837 |
+
"text": [
|
838 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
839 |
+
"Index: 38276 entries, 0 to 0\n",
|
840 |
+
"Data columns (total 10 columns):\n",
|
841 |
+
" # Column Non-Null Count Dtype \n",
|
842 |
+
"--- ------ -------------- ----- \n",
|
843 |
+
" 0 Unnamed: 0.1 38276 non-null object\n",
|
844 |
+
" 1 Unnamed: 0 38276 non-null object\n",
|
845 |
+
" 2 nct_id 38276 non-null object\n",
|
846 |
+
" 3 title 38276 non-null object\n",
|
847 |
+
" 4 brief_summary 38276 non-null object\n",
|
848 |
+
" 5 eligibility_criteria 38276 non-null object\n",
|
849 |
+
" 6 trial_text 38276 non-null object\n",
|
850 |
+
" 7 spaces 38276 non-null object\n",
|
851 |
+
" 8 this_space 38276 non-null object\n",
|
852 |
+
" 9 space_number 38276 non-null int64 \n",
|
853 |
+
"dtypes: int64(1), object(9)\n",
|
854 |
+
"memory usage: 3.2+ MB\n"
|
855 |
+
]
|
856 |
+
}
|
857 |
+
],
|
858 |
+
"source": [
|
859 |
+
"cohort_level_trials.info()"
|
860 |
+
]
|
861 |
+
},
|
862 |
+
{
|
863 |
+
"cell_type": "code",
|
864 |
+
"execution_count": 27,
|
865 |
+
"id": "648f0e1e-ef81-4983-8f03-1fbdb138f649",
|
866 |
+
"metadata": {},
|
867 |
+
"outputs": [
|
868 |
+
{
|
869 |
+
"data": {
|
870 |
+
"text/plain": [
|
871 |
+
"this_space\n",
|
872 |
+
"True 38140\n",
|
873 |
+
"False 136\n",
|
874 |
+
"Name: count, dtype: int64"
|
875 |
+
]
|
876 |
+
},
|
877 |
+
"execution_count": 27,
|
878 |
+
"metadata": {},
|
879 |
+
"output_type": "execute_result"
|
880 |
+
}
|
881 |
+
],
|
882 |
+
"source": [
|
883 |
+
"cohort_level_trials.this_space.str[0].isin(['1','2','3','4','5','6','7','8','9']).value_counts()"
|
884 |
+
]
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"cell_type": "code",
|
888 |
+
"execution_count": 28,
|
889 |
+
"id": "9ea048c1-c4ef-4202-a9be-a4658c4f1058",
|
890 |
+
"metadata": {},
|
891 |
+
"outputs": [],
|
892 |
+
"source": [
|
893 |
+
"trial_spaces = cohort_level_trials[cohort_level_trials.this_space.str[0].isin(['1','2','3','4','5','6','7','8','9'])]"
|
894 |
+
]
|
895 |
+
},
|
896 |
+
{
|
897 |
+
"cell_type": "code",
|
898 |
+
"execution_count": null,
|
899 |
+
"id": "852aee9d-ad97-4374-932f-6cae378dde2a",
|
900 |
+
"metadata": {},
|
901 |
+
"outputs": [],
|
902 |
+
"source": []
|
903 |
+
},
|
904 |
+
{
|
905 |
+
"cell_type": "code",
|
906 |
+
"execution_count": 29,
|
907 |
+
"id": "00d2220a-627a-4b67-be28-c42561c3c964",
|
908 |
+
"metadata": {},
|
909 |
+
"outputs": [],
|
910 |
+
"source": [
|
911 |
+
"# if you want to save the extracted individual trial 'spaces' do this\n",
|
912 |
+
"#trial_spaces.to_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')"
|
913 |
+
]
|
914 |
+
},
|
915 |
+
{
|
916 |
+
"cell_type": "code",
|
917 |
+
"execution_count": 30,
|
918 |
+
"id": "d8d351c0-26fd-47f5-98f7-843198909733",
|
919 |
+
"metadata": {},
|
920 |
+
"outputs": [],
|
921 |
+
"source": [
|
922 |
+
"# this trial dataframe now has one row per trial 'space'; i have pre-generated it\n",
|
923 |
+
"trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')"
|
924 |
+
]
|
925 |
+
},
|
926 |
+
{
|
927 |
+
"cell_type": "code",
|
928 |
+
"execution_count": null,
|
929 |
+
"id": "dd726671-8517-47f7-a306-fd28ae0ce25d",
|
930 |
+
"metadata": {},
|
931 |
+
"outputs": [],
|
932 |
+
"source": []
|
933 |
+
},
|
934 |
+
{
|
935 |
+
"cell_type": "code",
|
936 |
+
"execution_count": 31,
|
937 |
+
"id": "70dd7f46-8edb-4ff2-81d0-b94e64816ac5",
|
938 |
+
"metadata": {},
|
939 |
+
"outputs": [
|
940 |
+
{
|
941 |
+
"name": "stderr",
|
942 |
+
"output_type": "stream",
|
943 |
+
"text": [
|
944 |
+
"Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.23it/s]\n"
|
945 |
+
]
|
946 |
+
}
|
947 |
+
],
|
948 |
+
"source": [
|
949 |
+
"# now embed patients and trial spaces\n",
|
950 |
+
"from sentence_transformers import SentenceTransformer\n",
|
951 |
+
"\n",
|
952 |
+
"# lazily using cpu here\n",
|
953 |
+
"embedding_model = SentenceTransformer('ksg-dfci/TrialSpace', trust_remote_code=True, device='cpu')\n"
|
954 |
+
]
|
955 |
+
},
|
956 |
+
{
|
957 |
+
"cell_type": "code",
|
958 |
+
"execution_count": 32,
|
959 |
+
"id": "9a08f0ae-ee9b-475c-b2af-c46be89d71d5",
|
960 |
+
"metadata": {},
|
961 |
+
"outputs": [],
|
962 |
+
"source": [
|
963 |
+
"with torch.no_grad():\n",
|
964 |
+
" patient_embeddings = embedding_model.encode(long_histories.patient_summary, convert_to_tensor=True)"
|
965 |
+
]
|
966 |
+
},
|
967 |
+
{
|
968 |
+
"cell_type": "code",
|
969 |
+
"execution_count": 33,
|
970 |
+
"id": "aad4c966-f70a-48f7-b58c-d0932f9a9010",
|
971 |
+
"metadata": {},
|
972 |
+
"outputs": [],
|
973 |
+
"source": [
|
974 |
+
"# here's where we embed trial spaces\n",
|
975 |
+
"# this only needs to be run once to generate and save trial embeddings, or for a short list of trials you can run it every time\n",
|
976 |
+
"# here it is commented out, since I'll just load the previously generated embeddings\n",
|
977 |
+
"\n",
|
978 |
+
"# with torch.no_grad():\n",
|
979 |
+
"# trial_space_embeddings = embedding_model.encode(trial_spaces.this_space.tolist(), convert_to_tensor=True)\n",
|
980 |
+
"\n",
|
981 |
+
"# from safetensors.torch import save_file\n",
|
982 |
+
"# output_trial_file = {\"space_embeddings\": trial_space_embeddings}\n",
|
983 |
+
"# save_file(output_trial_file, \"trial_space_embeddings.safetensors\")\n",
|
984 |
+
"\n",
|
985 |
+
"# trial_space_embeddings.shape"
|
986 |
+
]
|
987 |
+
},
|
988 |
+
{
|
989 |
+
"cell_type": "code",
|
990 |
+
"execution_count": 34,
|
991 |
+
"id": "3690cf7b-7df8-46eb-86c0-1e9efdaa1f43",
|
992 |
+
"metadata": {},
|
993 |
+
"outputs": [],
|
994 |
+
"source": [
|
995 |
+
"# load trial space embeddings, should have same number of embeddings as there are in the trial spaces dataset\n",
|
996 |
+
"from safetensors import safe_open\n",
|
997 |
+
"with safe_open(\"trial_space_embeddings.safetensors\", framework=\"pt\", device='cpu') as f:\n",
|
998 |
+
" trial_space_embeddings = f.get_tensor(\"space_embeddings\")"
|
999 |
+
]
|
1000 |
+
},
|
1001 |
+
{
|
1002 |
+
"cell_type": "code",
|
1003 |
+
"execution_count": 35,
|
1004 |
+
"id": "7f1d5a59-d458-4cb8-8015-492ca1e31de5",
|
1005 |
+
"metadata": {},
|
1006 |
+
"outputs": [
|
1007 |
+
{
|
1008 |
+
"data": {
|
1009 |
+
"text/plain": [
|
1010 |
+
"(torch.Size([38140, 1024]), (38140, 10))"
|
1011 |
+
]
|
1012 |
+
},
|
1013 |
+
"execution_count": 35,
|
1014 |
+
"metadata": {},
|
1015 |
+
"output_type": "execute_result"
|
1016 |
+
}
|
1017 |
+
],
|
1018 |
+
"source": [
|
1019 |
+
"trial_space_embeddings.shape, trial_spaces.shape"
|
1020 |
+
]
|
1021 |
+
},
|
1022 |
+
{
|
1023 |
+
"cell_type": "code",
|
1024 |
+
"execution_count": 36,
|
1025 |
+
"id": "f541566e-1027-4de3-b5d9-6a04b1371cd2",
|
1026 |
+
"metadata": {},
|
1027 |
+
"outputs": [],
|
1028 |
+
"source": [
|
1029 |
+
"# now let's find the top ten trial 'spaces' for each patient based on cosine similarity\n",
|
1030 |
+
"\n",
|
1031 |
+
"output_list = []\n",
|
1032 |
+
"for i, patient_summary in enumerate(long_histories.patient_summary):\n",
|
1033 |
+
" patient_embedding = patient_embeddings[i, :]\n",
|
1034 |
+
" similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)\n",
|
1035 |
+
" sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
|
1036 |
+
" relevant_spaces = trial_spaces.iloc[sorted_indices[0:10].cpu().numpy()]\n",
|
1037 |
+
" output = pd.DataFrame({'patient_summary':patient_summary, 'this_space':relevant_spaces.this_space, 'nct_id':relevant_spaces.nct_id, \n",
|
1038 |
+
" 'trial_title':relevant_spaces.title, 'trial_brief_summary':relevant_spaces.brief_summary,\n",
|
1039 |
+
" 'trial_text':relevant_spaces.trial_text})\n",
|
1040 |
+
" output_list.append(output)\n",
|
1041 |
+
"\n",
|
1042 |
+
"output = pd.concat(output_list, axis=0).reset_index(drop=True)\n",
|
1043 |
+
"output['patient_summary'] = output.patient_summary.str.strip()"
|
1044 |
+
]
|
1045 |
+
},
|
1046 |
+
{
|
1047 |
+
"cell_type": "code",
|
1048 |
+
"execution_count": 37,
|
1049 |
+
"id": "6664ebe1-34a6-4184-985b-ef13f4a39369",
|
1050 |
+
"metadata": {},
|
1051 |
+
"outputs": [],
|
1052 |
+
"source": [
|
1053 |
+
"# now run 'trial checker' classifier to double check the top (10) matches we have pulled"
|
1054 |
+
]
|
1055 |
+
},
|
1056 |
+
{
|
1057 |
+
"cell_type": "code",
|
1058 |
+
"execution_count": 38,
|
1059 |
+
"id": "ceaca6ce-f7af-4156-a185-2b506f57e469",
|
1060 |
+
"metadata": {},
|
1061 |
+
"outputs": [
|
1062 |
+
{
|
1063 |
+
"name": "stderr",
|
1064 |
+
"output_type": "stream",
|
1065 |
+
"text": [
|
1066 |
+
"Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
|
1067 |
+
]
|
1068 |
+
}
|
1069 |
+
],
|
1070 |
+
"source": [
|
1071 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"roberta-large\")\n",
|
1072 |
+
"\n",
|
1073 |
+
"checker_pipe = pipeline(\n",
|
1074 |
+
" 'text-classification', \n",
|
1075 |
+
" 'ksg-dfci/TrialChecker', \n",
|
1076 |
+
" tokenizer=tokenizer, \n",
|
1077 |
+
" truncation=True, \n",
|
1078 |
+
" padding='max_length', \n",
|
1079 |
+
" max_length=512\n",
|
1080 |
+
")"
|
1081 |
+
]
|
1082 |
+
},
|
1083 |
+
{
|
1084 |
+
"cell_type": "code",
|
1085 |
+
"execution_count": 39,
|
1086 |
+
"id": "9d9bc3cf-9f2f-41a2-a15f-0b552a4007c6",
|
1087 |
+
"metadata": {},
|
1088 |
+
"outputs": [],
|
1089 |
+
"source": [
|
1090 |
+
"\n",
|
1091 |
+
"output['pt_trial_pair'] = (output['this_space'] + \"\\nNow here is the patient summary:\" + output['patient_summary'])\n",
|
1092 |
+
"\n",
|
1093 |
+
"classifier_results = checker_pipe(output['pt_trial_pair'].tolist())\n",
|
1094 |
+
"output['trial_checker_result'] = [x['label'] for x in classifier_results]\n",
|
1095 |
+
"output['trial_checker_score'] = [x['score'] for x in classifier_results]"
|
1096 |
+
]
|
1097 |
+
},
|
1098 |
+
{
|
1099 |
+
"cell_type": "code",
|
1100 |
+
"execution_count": 40,
|
1101 |
+
"id": "450939bd-44cc-4dc7-900a-255a46f2c3ab",
|
1102 |
+
"metadata": {},
|
1103 |
+
"outputs": [
|
1104 |
+
{
|
1105 |
+
"name": "stdout",
|
1106 |
+
"output_type": "stream",
|
1107 |
+
"text": [
|
1108 |
+
"<class 'pandas.core.frame.DataFrame'>\n",
|
1109 |
+
"RangeIndex: 100 entries, 0 to 99\n",
|
1110 |
+
"Data columns (total 9 columns):\n",
|
1111 |
+
" # Column Non-Null Count Dtype \n",
|
1112 |
+
"--- ------ -------------- ----- \n",
|
1113 |
+
" 0 patient_summary 100 non-null object \n",
|
1114 |
+
" 1 this_space 100 non-null object \n",
|
1115 |
+
" 2 nct_id 100 non-null object \n",
|
1116 |
+
" 3 trial_title 100 non-null object \n",
|
1117 |
+
" 4 trial_brief_summary 100 non-null object \n",
|
1118 |
+
" 5 trial_text 100 non-null object \n",
|
1119 |
+
" 6 pt_trial_pair 100 non-null object \n",
|
1120 |
+
" 7 trial_checker_result 100 non-null object \n",
|
1121 |
+
" 8 trial_checker_score 100 non-null float64\n",
|
1122 |
+
"dtypes: float64(1), object(8)\n",
|
1123 |
+
"memory usage: 7.2+ KB\n"
|
1124 |
+
]
|
1125 |
+
}
|
1126 |
+
],
|
1127 |
+
"source": [
|
1128 |
+
"\n",
|
1129 |
+
"output.info()"
|
1130 |
+
]
|
1131 |
+
},
|
1132 |
+
{
|
1133 |
+
"cell_type": "code",
|
1134 |
+
"execution_count": null,
|
1135 |
+
"id": "f00270f2-7ce1-42e7-8f16-986acdde76f8",
|
1136 |
+
"metadata": {
|
1137 |
+
"scrolled": true
|
1138 |
+
},
|
1139 |
+
"outputs": [],
|
1140 |
+
"source": [
|
1141 |
+
"output"
|
1142 |
+
]
|
1143 |
+
},
|
1144 |
+
{
|
1145 |
+
"cell_type": "code",
|
1146 |
+
"execution_count": null,
|
1147 |
+
"id": "a905f88e-479e-4e4c-8492-827f7768d91c",
|
1148 |
+
"metadata": {},
|
1149 |
+
"outputs": [],
|
1150 |
+
"source": []
|
1151 |
+
}
|
1152 |
+
],
|
1153 |
+
"metadata": {
|
1154 |
+
"kernelspec": {
|
1155 |
+
"display_name": "Python 3 (ipykernel)",
|
1156 |
+
"language": "python",
|
1157 |
+
"name": "python3"
|
1158 |
+
},
|
1159 |
+
"language_info": {
|
1160 |
+
"codemirror_mode": {
|
1161 |
+
"name": "ipython",
|
1162 |
+
"version": 3
|
1163 |
+
},
|
1164 |
+
"file_extension": ".py",
|
1165 |
+
"mimetype": "text/x-python",
|
1166 |
+
"name": "python",
|
1167 |
+
"nbconvert_exporter": "python",
|
1168 |
+
"pygments_lexer": "ipython3",
|
1169 |
+
"version": "3.12.5"
|
1170 |
+
}
|
1171 |
+
},
|
1172 |
+
"nbformat": 4,
|
1173 |
+
"nbformat_minor": 5
|
1174 |
+
}
|