ACMCMC commited on
Commit
a6bd112
·
1 Parent(s): b375334
graph_analysis.m → MATLAB/get_metrics.m RENAMED
@@ -1,7 +1,6 @@
1
  % Read the CSV file
2
- data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
3
  data = renamevars(data,"#CUI1","CUI1");
4
- data = data(1:1000,:);
5
  ids_1 = data.CUI1;
6
  for k = 1 : length(ids_1)
7
  cellContents = ids_1{k};
@@ -10,7 +9,6 @@ for k = 1 : length(ids_1)
10
  end
11
  ids_1 = str2double(ids_1);
12
  ids_2 = data.CUI2;
13
- ids_2 = data.CUI1(2:end);
14
  for k = 1 : length(ids_2)
15
  cellContents = ids_2{k};
16
  % Truncate and stick back into the cell
@@ -18,11 +16,6 @@ for k = 1 : length(ids_2)
18
  end
19
  ids_2 = str2double(ids_2);
20
 
21
-
22
- ids_1 = ids_1(1:end-1);
23
- ids_2 = ids_2(2:end);
24
-
25
-
26
  % Get the number of unique nodes
27
  %nodes = unique([ids_1; ids_2]);
28
  %num_nodes = length(nodes);
@@ -36,8 +29,7 @@ ids_2 = ids_2(2:end);
36
  %G = digraph(A);
37
  G = digraph(ids_1, ids_2);
38
  [bin,binsize] = conncomp(G,'Type','weak');
39
- bin(1:100)
40
- size(unique(bin))
41
  max(binsize)
42
  pg_ranks = centrality(G,'pagerank');
43
  G.Nodes.PageRank = pg_ranks;
@@ -46,4 +38,4 @@ G.Nodes.PageRank = pg_ranks;
46
  %G.Nodes.Hubs = hub_ranks;
47
  %G.Nodes.Authorities = auth_ranks;
48
  G.Nodes
49
- %plot(G);
 
1
  % Read the CSV file
2
+ data = readtable('../MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
3
  data = renamevars(data,"#CUI1","CUI1");
 
4
  ids_1 = data.CUI1;
5
  for k = 1 : length(ids_1)
6
  cellContents = ids_1{k};
 
9
  end
10
  ids_1 = str2double(ids_1);
11
  ids_2 = data.CUI2;
 
12
  for k = 1 : length(ids_2)
13
  cellContents = ids_2{k};
14
  % Truncate and stick back into the cell
 
16
  end
17
  ids_2 = str2double(ids_2);
18
 
 
 
 
 
 
19
  % Get the number of unique nodes
20
  %nodes = unique([ids_1; ids_2]);
21
  %num_nodes = length(nodes);
 
29
  %G = digraph(A);
30
  G = digraph(ids_1, ids_2);
31
  [bin,binsize] = conncomp(G,'Type','weak');
32
+ bin(1:10)
 
33
  max(binsize)
34
  pg_ranks = centrality(G,'pagerank');
35
  G.Nodes.PageRank = pg_ranks;
 
38
  %G.Nodes.Hubs = hub_ranks;
39
  %G.Nodes.Authorities = auth_ranks;
40
  G.Nodes
41
+ %plot(G);
MATLAB/main.m CHANGED
@@ -17,7 +17,7 @@ end
17
 
18
  data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
19
  data = renamevars(data,"#CUI1","CUI1");
20
- data = data(1:1000,:);
21
 
22
  % Create a Map to store connections
23
  connectionsMap = containers.Map('KeyType','char', 'ValueType','any');
 
17
 
18
  data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
19
  data = renamevars(data,"#CUI1","CUI1");
20
+ data = data(1:2000,:);
21
 
22
  % Create a Map to store connections
23
  connectionsMap = containers.Map('KeyType','char', 'ValueType','any');
MATLAB/visualize_app.mlapp CHANGED
Binary files a/MATLAB/visualize_app.mlapp and b/MATLAB/visualize_app.mlapp differ
 
app.py CHANGED
@@ -1,27 +1,30 @@
1
- import streamlit as st
2
- from streamlit_agraph import agraph, Node, Edge, Config
3
  import os
4
- from sqlalchemy import create_engine, text
5
- import pandas as pd
6
  import time
 
 
 
 
 
 
 
 
 
 
7
  from utils import (
 
 
8
  get_all_diseases_name,
9
- get_most_similar_diseases_from_uri,
10
- get_uri_from_name,
11
  get_diseases_related_to_a_textual_description,
 
12
  get_similarities_among_diseases_uris,
13
- augment_the_set_of_diseaces,
14
- get_clinical_trials_related_to_diseases,
15
- get_clinical_records_by_ids,
16
  render_trial_details,
17
- filter_out_less_promising_diseases
18
  )
19
- from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
20
- import json
21
- import numpy as np
22
- from sentence_transformers import SentenceTransformer
23
- import matplotlib
24
-
25
 
26
  # variables to reveal next steps
27
  show_graph = False
@@ -42,17 +45,22 @@ engine = create_engine(CONNECTION_STRING)
42
 
43
  st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
44
  st.title("Klìnic", help="AI-powered clinical trial search engine")
45
- st.subheader("Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights.")
 
 
46
 
47
- with st.container(): # user input
48
  col1, col2 = st.columns((6, 1))
49
 
50
  with col1:
51
- description_input = st.text_area(label="Enter a disease description 👇", placeholder='A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.')
 
 
 
52
  with col2:
53
- st.text('') # dummy to center vertically
54
- st.text('') # dummy to center vertically
55
- st.text('') # dummy to center vertically
56
  show_analyze_status = st.button("Analyze 🔎")
57
 
58
 
@@ -64,45 +72,78 @@ with st.container():
64
  # 2. Get 5 diseases with the highest cosine silimarity from the DB
65
  status.write("Analyzing the description that you wrote...")
66
  encoder = SentenceTransformer("allenai-specter")
67
- diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
68
- description_input, encoder
 
 
 
 
 
69
  )
70
- status.info(f'Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
71
  status.json(diseases_related_to_the_user_text, expanded=False)
72
  status.divider()
73
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
74
- status.write("Getting the similarities among the diseases to filter out less promising ones...")
75
- diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
 
 
 
 
76
  similarities = get_similarities_among_diseases_uris(diseases_uris)
77
- status.info(f'Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings.')
 
 
78
  status.json(similarities, expanded=False)
79
- filtered_diseases_uris, df_similarities = filter_out_less_promising_diseases(similarities)
 
 
80
  # Apply a colormap to the table
81
- status.table(df_similarities.style.background_gradient(cmap='viridis', axis=None))
82
- status.info(f'Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases.')
 
 
 
 
83
  status.json(filtered_diseases_uris, expanded=False)
84
  status.divider()
85
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
86
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
87
- status.write("Augmenting the set of diseases by finding others with related embeddings...")
 
 
88
  augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
89
- # print(augmented_set_of_diseases)
90
- status.info(f'Augmented set of diseases: {len(augmented_set_of_diseases)} diseases.')
 
 
 
 
 
 
 
 
 
 
 
91
  status.json(augmented_set_of_diseases, expanded=False)
92
  status.divider()
93
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
94
  status.write("Getting the clinical trials related to the diseases found...")
95
- clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
96
- augmented_set_of_diseases, encoder
 
 
 
 
 
97
  )
98
- status.info(f'Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
99
  status.json(clinical_trials_related_to_the_diseases, expanded=False)
100
  status.divider()
101
  status.write("Getting the details of the clinical trials...")
102
  json_of_clinical_trials = get_clinical_records_by_ids(
103
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
104
  )
105
- status.success(f'Details of the clinical trials obtained.')
106
  status.json(json_of_clinical_trials, expanded=False)
107
  status.divider()
108
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
@@ -112,22 +153,27 @@ with st.container():
112
  status.success("Summary of the clinical trials obtained.")
113
  disease_overview = response
114
  except Exception as e:
115
- print(f'Error while getting a summary of the clinical trials: {e}')
116
- status.warning(f'Error while getting a summary of the clinical trials. This information will not be shown.')
 
 
117
  try:
118
- # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
119
  status.write("Getting summary statistics of the clinical trials...")
120
  response = tagging_insights_from_json(json_of_clinical_trials)
121
  average_minimum_age = response["avg_min_age"]
122
  average_maximum_age = response["avg_max_age"]
123
- most_common_gender = response['most_common_gender']
124
 
125
- print(f'Response from LLM tagging: {response}')
126
- status.success(f'Summary statistics of the clinical trials obtained.')
127
  except Exception as e:
128
- raise e
129
- print(f'Error while extracting numerical data from the clinical trials: {e}')
130
- status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
 
 
 
131
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
132
  status.update(label="Done!", state="complete")
133
  status.balloons()
@@ -146,37 +192,55 @@ We use the embeddings of the diseases to determine the similarity between them.
146
  [TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.
147
 
148
  Specifically, it optimizes the following cost function:
149
- $$"""
150
- )
151
- # TODO actual graph
152
- graph_of_diseases = agraph(
153
- nodes=[
154
- Node(id="A", label="Node A", size=10),
155
- Node(id="B", label="Node B", size=10),
156
- Node(id="C", label="Node C", size=10),
157
- Node(id="D", label="Node D", size=10),
158
- Node(id="E", label="Node E", size=10),
159
- Node(id="F", label="Node F", size=10),
160
- Node(id="G", label="Node G", size=10),
161
- Node(id="H", label="Node H", size=10),
162
- Node(id="I", label="Node I", size=10),
163
- Node(id="J", label="Node J", size=10),
164
- ],
165
- edges=[
166
- Edge(source="A", target="B"),
167
- Edge(source="B", target="C"),
168
- Edge(source="C", target="D"),
169
- Edge(source="D", target="E"),
170
- Edge(source="E", target="F"),
171
- Edge(source="F", target="G"),
172
- Edge(source="G", target="H"),
173
- Edge(source="H", target="I"),
174
- Edge(source="I", target="J"),
175
- ],
176
- config=Config(height=500, width=500),
177
  )
178
- time.sleep(2)
179
- show_overview = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  # overview
@@ -187,7 +251,7 @@ with st.container():
187
  st.write(disease_overview)
188
  time.sleep(1)
189
  except Exception as e:
190
- print(f'Error while showing the overview of the clinical trials: {e}')
191
  finally:
192
  show_metrics = True
193
 
@@ -196,7 +260,7 @@ with st.container():
196
  if show_metrics:
197
  try:
198
  st.write("## Metrics of the Clinical Trials")
199
- col1, col2, col3 = st.columns(3)
200
  with col1:
201
  st.metric("Average Minimum Age", average_minimum_age)
202
  with col2:
@@ -205,7 +269,7 @@ with st.container():
205
  st.metric("Most Common Gender", most_common_gender)
206
  time.sleep(2)
207
  except Exception as e:
208
- print(f'Error while showing the metrics: {e}')
209
  finally:
210
  show_details = True
211
 
@@ -215,7 +279,10 @@ with st.container():
215
  if show_details:
216
  st.write("## Clinical Trials Details")
217
 
218
- tab_titles = [f"{trial['protocolSection']['identificationModule']['nctId']}" for trial in trials]
 
 
 
219
 
220
  tabs = st.tabs(tab_titles)
221
 
@@ -231,7 +298,7 @@ if show_graph_of_all_diseases:
231
  chosen_disease_name = st.selectbox(
232
  "Choose a disease",
233
  st.session_state.disease_names,
234
- )
235
 
236
  st.write("You selected:", chosen_disease_name)
237
  chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
@@ -239,41 +306,39 @@ if show_graph_of_all_diseases:
239
  nodes = []
240
  edges = []
241
 
242
-
243
- nodes.append( Node(id=chosen_disease_uri,
244
- label=chosen_disease_name,
245
- size=25,
246
- shape="circular")
247
  )
248
 
249
- similar_diseases = get_most_similar_diseases_from_uri(engine, chosen_disease_uri, threshold=0.6)
 
 
250
  print(similar_diseases)
251
  for uri, name, weight in similar_diseases:
252
- nodes.append( Node(id=uri,
253
- label=name,
254
- size=25,
255
- shape="circular")
256
- )
257
 
258
  print(True if float(weight) > 0.7 else False)
259
- edges.append( Edge(source=chosen_disease_uri,
260
- target=uri,
261
- color="red" if float(weight) > 0.7 else "blue",
262
- weight=float(weight)**10,
263
- type="CURVE_SMOOTH"
264
- # type="STRAIGHT"
265
- )
266
- )
267
-
268
- config = Config(width=750,
269
- height=950,
270
- directed=False,
271
- physics=True,
272
- hierarchical=False,
273
- collapsible=False,
274
- # **kwargs
275
- )
 
 
 
276
 
277
- return_value = agraph(nodes=nodes,
278
- edges=edges,
279
- config=config)
 
1
+ import json
 
2
  import os
 
 
3
  import time
4
+
5
+ import matplotlib
6
+ import numpy as np
7
+ import pandas as pd
8
+ import streamlit as st
9
+ from sentence_transformers import SentenceTransformer
10
+ from sqlalchemy import create_engine, text
11
+ from streamlit_agraph import Config, Edge, Node, agraph
12
+
13
+ from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
14
  from utils import (
15
+ augment_the_set_of_diseaces,
16
+ filter_out_less_promising_diseases,
17
  get_all_diseases_name,
18
+ get_clinical_records_by_ids,
19
+ get_clinical_trials_related_to_diseases,
20
  get_diseases_related_to_a_textual_description,
21
+ get_most_similar_diseases_from_uri,
22
  get_similarities_among_diseases_uris,
23
+ get_similarities_df,
24
+ get_uri_from_name,
 
25
  render_trial_details,
26
+ get_labels_of_diseases_from_uris,
27
  )
 
 
 
 
 
 
28
 
29
  # variables to reveal next steps
30
  show_graph = False
 
45
 
46
  st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
47
  st.title("Klìnic", help="AI-powered clinical trial search engine")
48
+ st.subheader(
49
+ "Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights."
50
+ )
51
 
52
+ with st.container(): # user input
53
  col1, col2 = st.columns((6, 1))
54
 
55
  with col1:
56
+ description_input = st.text_area(
57
+ label="Enter a disease description 👇",
58
+ placeholder="A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.",
59
+ )
60
  with col2:
61
+ st.text("") # dummy to center vertically
62
+ st.text("") # dummy to center vertically
63
+ st.text("") # dummy to center vertically
64
  show_analyze_status = st.button("Analyze 🔎")
65
 
66
 
 
72
  # 2. Get 5 diseases with the highest cosine silimarity from the DB
73
  status.write("Analyzing the description that you wrote...")
74
  encoder = SentenceTransformer("allenai-specter")
75
+ diseases_related_to_the_user_text = (
76
+ get_diseases_related_to_a_textual_description(
77
+ description_input, encoder
78
+ )
79
+ )
80
+ status.info(
81
+ f"Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered."
82
  )
 
83
  status.json(diseases_related_to_the_user_text, expanded=False)
84
  status.divider()
85
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
86
+ status.write(
87
+ "Getting the similarities among the diseases to filter out less promising ones..."
88
+ )
89
+ diseases_uris = [
90
+ disease["uri"] for disease in diseases_related_to_the_user_text
91
+ ]
92
  similarities = get_similarities_among_diseases_uris(diseases_uris)
93
+ status.info(
94
+ f"Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings."
95
+ )
96
  status.json(similarities, expanded=False)
97
+ filtered_diseases_uris, df_similarities = (
98
+ filter_out_less_promising_diseases(similarities)
99
+ )
100
  # Apply a colormap to the table
101
+ status.table(
102
+ df_similarities.style.background_gradient(cmap="viridis", axis=None)
103
+ )
104
+ status.info(
105
+ f"Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases."
106
+ )
107
  status.json(filtered_diseases_uris, expanded=False)
108
  status.divider()
109
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
110
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
111
+ status.write(
112
+ "Augmenting the set of diseases by finding others with related embeddings..."
113
+ )
114
  augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
115
+ similarities_of_augmented_set_of_diseases = (
116
+ get_similarities_among_diseases_uris(augmented_set_of_diseases)
117
+ )
118
+ df_similarities_augmented_set = get_similarities_df(
119
+ similarities_of_augmented_set_of_diseases
120
+ )
121
+ status.table(
122
+ df_similarities_augmented_set.style.background_gradient(cmap="viridis", axis=None)
123
+ )
124
+ status.json(similarities_of_augmented_set_of_diseases, expanded=True)
125
+ status.info(
126
+ f"Augmented set of diseases: {len(augmented_set_of_diseases)} diseases."
127
+ )
128
  status.json(augmented_set_of_diseases, expanded=False)
129
  status.divider()
130
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
131
  status.write("Getting the clinical trials related to the diseases found...")
132
+ clinical_trials_related_to_the_diseases = (
133
+ get_clinical_trials_related_to_diseases(
134
+ augmented_set_of_diseases, encoder
135
+ )
136
+ )
137
+ status.info(
138
+ f"Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases."
139
  )
 
140
  status.json(clinical_trials_related_to_the_diseases, expanded=False)
141
  status.divider()
142
  status.write("Getting the details of the clinical trials...")
143
  json_of_clinical_trials = get_clinical_records_by_ids(
144
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
145
  )
146
+ status.success(f"Details of the clinical trials obtained.")
147
  status.json(json_of_clinical_trials, expanded=False)
148
  status.divider()
149
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
 
153
  status.success("Summary of the clinical trials obtained.")
154
  disease_overview = response
155
  except Exception as e:
156
+ print(f"Error while getting a summary of the clinical trials: {e}")
157
+ status.warning(
158
+ f"Error while getting a summary of the clinical trials. This information will not be shown."
159
+ )
160
  try:
161
+ # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
162
  status.write("Getting summary statistics of the clinical trials...")
163
  response = tagging_insights_from_json(json_of_clinical_trials)
164
  average_minimum_age = response["avg_min_age"]
165
  average_maximum_age = response["avg_max_age"]
166
+ most_common_gender = response["most_common_gender"]
167
 
168
+ print(f"Response from LLM tagging: {response}")
169
+ status.success(f"Summary statistics of the clinical trials obtained.")
170
  except Exception as e:
171
+ print(
172
+ f"Error while extracting numerical data from the clinical trials: {e}"
173
+ )
174
+ status.warning(
175
+ f"Error while extracting numerical data from the clinical trials. This information will not be shown."
176
+ )
177
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
178
  status.update(label="Done!", state="complete")
179
  status.balloons()
 
192
  [TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.
193
 
194
  Specifically, it optimizes the following cost function:
195
+ $\\text{minimize} \\sum_{(h, r, t) \\in S} \\max(0, \\gamma + f(h, r, t) - f(h, r, t')) + \\sum_{(h, r, t) \\in S'} f(h, r, t)$
196
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  )
198
+ try:
199
+ edges_to_show = []
200
+ labels_of_diseases = get_labels_of_diseases_from_uris(
201
+ df_similarities_augmented_set.index
202
+ )
203
+ uris_and_labels_of_diseases = dict(
204
+ zip(df_similarities_augmented_set.index, labels_of_diseases)
205
+ )
206
+ color_mapper = matplotlib.cm.get_cmap("viridis")
207
+ for source in df_similarities_augmented_set.index:
208
+ for target in df_similarities_augmented_set.columns:
209
+ if source != target:
210
+ weight = df_similarities_augmented_set.loc[source, target]
211
+ color = color_mapper(weight)
212
+ # Convert from rgba to hex
213
+ color = matplotlib.colors.to_hex(color)
214
+ edges_to_show.append(
215
+ Edge(
216
+ source=source,
217
+ target=target,
218
+ # Dynamic color based on the weight
219
+ color=color,
220
+ weight=weight**10,
221
+ type="CURVE_SMOOTH",
222
+ label=f"{weight:.2f}",
223
+ )
224
+ )
225
+ graph_of_diseases = agraph(
226
+ nodes=[
227
+ Node(
228
+ id=disease,
229
+ label=disease,#uris_and_labels_of_diseases[disease],
230
+ size=25,
231
+ shape="circular",
232
+ )
233
+ for disease in df_similarities_augmented_set.index
234
+ ],
235
+ edges=edges_to_show,
236
+ config=Config(height=500, width=500),
237
+ )
238
+ time.sleep(2)
239
+ except Exception as e:
240
+ print(f"Error while showing the graph of the diseases: {e}")
241
+ st.error("Error while showing the graph of the diseases.")
242
+ finally:
243
+ show_overview = True
244
 
245
 
246
  # overview
 
251
  st.write(disease_overview)
252
  time.sleep(1)
253
  except Exception as e:
254
+ print(f"Error while showing the overview of the clinical trials: {e}")
255
  finally:
256
  show_metrics = True
257
 
 
260
  if show_metrics:
261
  try:
262
  st.write("## Metrics of the Clinical Trials")
263
+ col1, col2, col3 = st.columns(3)
264
  with col1:
265
  st.metric("Average Minimum Age", average_minimum_age)
266
  with col2:
 
269
  st.metric("Most Common Gender", most_common_gender)
270
  time.sleep(2)
271
  except Exception as e:
272
+ print(f"Error while showing the metrics: {e}")
273
  finally:
274
  show_details = True
275
 
 
279
  if show_details:
280
  st.write("## Clinical Trials Details")
281
 
282
+ tab_titles = [
283
+ f"{trial['protocolSection']['identificationModule']['nctId']}"
284
+ for trial in trials
285
+ ]
286
 
287
  tabs = st.tabs(tab_titles)
288
 
 
298
  chosen_disease_name = st.selectbox(
299
  "Choose a disease",
300
  st.session_state.disease_names,
301
+ )
302
 
303
  st.write("You selected:", chosen_disease_name)
304
  chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
 
306
  nodes = []
307
  edges = []
308
 
309
+ nodes.append(
310
+ Node(
311
+ id=chosen_disease_uri, label=chosen_disease_name, size=25, shape="circular"
312
+ )
 
313
  )
314
 
315
+ similar_diseases = get_most_similar_diseases_from_uri(
316
+ engine, chosen_disease_uri, threshold=0.6
317
+ )
318
  print(similar_diseases)
319
  for uri, name, weight in similar_diseases:
320
+ nodes.append(Node(id=uri, label=name, size=25, shape="circular"))
 
 
 
 
321
 
322
  print(True if float(weight) > 0.7 else False)
323
+ edges.append(
324
+ Edge(
325
+ source=chosen_disease_uri,
326
+ target=uri,
327
+ color="red" if float(weight) > 0.7 else "blue",
328
+ weight=float(weight) ** 10,
329
+ type="CURVE_SMOOTH",
330
+ # type="STRAIGHT"
331
+ )
332
+ )
333
+
334
+ config = Config(
335
+ width=750,
336
+ height=950,
337
+ directed=False,
338
+ physics=True,
339
+ hierarchical=False,
340
+ collapsible=False,
341
+ # **kwargs
342
+ )
343
 
344
+ return_value = agraph(nodes=nodes, edges=edges, config=config)
 
 
calculate_smilar_nodes.py CHANGED
@@ -6,6 +6,7 @@ def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings
6
  distance = head_embedding + relation_embeddings - tail_embedding
7
  return distance
8
 
 
9
  def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=10):
10
  distances = []
11
  for i in range(len(entity_embeddings)):
@@ -14,6 +15,7 @@ def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=
14
  distances.sort(key=lambda x: x[1].norm().item())
15
  return distances[:top_n]
16
 
 
17
  # %%
18
  import pandas as pd
19
 
@@ -55,9 +57,13 @@ print(
55
  )
56
  # %%
57
  # Calculate similar nodes to the head
58
- similar_nodes = calculate_similar_nodes(head, entity_embeddings["embedding"], relation_embeddings["embedding"])
 
 
59
  print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
60
  # Print the similar nodes
61
  for i, (node, distance) in enumerate(similar_nodes):
62
- print(f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {distance.norm().item()}")
63
- # %%
 
 
 
6
  distance = head_embedding + relation_embeddings - tail_embedding
7
  return distance
8
 
9
+
10
  def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=10):
11
  distances = []
12
  for i in range(len(entity_embeddings)):
 
15
  distances.sort(key=lambda x: x[1].norm().item())
16
  return distances[:top_n]
17
 
18
+
19
  # %%
20
  import pandas as pd
21
 
 
57
  )
58
  # %%
59
  # Calculate similar nodes to the head
60
+ similar_nodes = calculate_similar_nodes(
61
+ head, entity_embeddings["embedding"], relation_embeddings["embedding"]
62
+ )
63
  print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
64
  # Print the similar nodes
65
  for i, (node, distance) in enumerate(similar_nodes):
66
+ print(
67
+ f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {distance.norm().item()}"
68
+ )
69
+ # %%
llm_res.py CHANGED
@@ -1,15 +1,19 @@
1
  import ast
2
  import json
3
  import os
 
 
4
  from typing import Any, Dict, List
5
 
6
  import langchain
7
  import openai
8
  import pandas as pd
 
9
  import requests
10
  from dotenv import load_dotenv
11
  from langchain import OpenAI
12
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 
13
  from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
14
  from langchain.document_loaders import UnstructuredURLLoader
15
  from langchain.embeddings import OpenAIEmbeddings
@@ -17,14 +21,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain.vectorstores import FAISS
18
  from langchain_community.document_loaders import JSONLoader
19
  from langchain_community.document_loaders.csv_loader import CSVLoader
20
- from langchain_core.prompts import ChatPromptTemplate
21
  from langchain_core.pydantic_v1 import BaseModel, Field
22
  from langchain_openai import ChatOpenAI
23
- from langchain.chains.llm import LLMChain
24
- from langchain_core.prompts import PromptTemplate
25
- from collections import Counter
26
- import statistics
27
- import regex as re
28
 
29
  load_dotenv()
30
 
@@ -245,7 +244,7 @@ General summary:"""
245
  prompt = PromptTemplate.from_template(prompt_template)
246
 
247
  llm = ChatOpenAI(
248
- temperature=0.4, model_name="gpt-4-turbo", api_key=os.environ["OPENAI_API_KEY"]
249
  )
250
  llm_chain = LLMChain(llm=llm, prompt=prompt)
251
 
@@ -279,8 +278,12 @@ General summary:"""
279
  def analyze_data(data):
280
  print(f"Data: {data}")
281
  # Extract minimum and maximum ages: Turn ['18 Years', '20 Years'] into [18, 20]
282
- min_ages = [int(re.search(r"\d+", age).group()) for age in data["minimum_age"] if age]
283
- max_ages = [int(re.search(r"\d+", age).group()) for age in data["maximum_age"] if age]
 
 
 
 
284
  # primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
285
 
286
  # Calculate average minimum and maximum ages
@@ -292,13 +295,13 @@ def analyze_data(data):
292
  most_common_gender = gender_counter.most_common(1)[0][0]
293
 
294
  # Flatten keywords list and find common keywords
295
- #keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
296
- #common_keywords = [word for word, count in Counter(keywords).most_common()]
297
 
298
  return {
299
  "avg_min_age": avg_min_age,
300
  "avg_max_age": avg_max_age,
301
- "most_common_gender": most_common_gender
302
  }
303
 
304
 
@@ -379,9 +382,7 @@ def tagging_insights_from_json(data_json):
379
  res = tagging_chain.invoke({"input": processed_json})
380
  unprocessed_results_dict = res.get_dict()
381
 
382
- results_dict = analyze_data(
383
- unprocessed_results_dict
384
- )
385
 
386
  # stats_dict= {'Average Minimum age': avg_min_age,
387
  # 'Average Maximum age': avg_max_age,
 
1
  import ast
2
  import json
3
  import os
4
+ import statistics
5
+ from collections import Counter
6
  from typing import Any, Dict, List
7
 
8
  import langchain
9
  import openai
10
  import pandas as pd
11
+ import regex as re
12
  import requests
13
  from dotenv import load_dotenv
14
  from langchain import OpenAI
15
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
16
+ from langchain.chains.llm import LLMChain
17
  from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
18
  from langchain.document_loaders import UnstructuredURLLoader
19
  from langchain.embeddings import OpenAIEmbeddings
 
21
  from langchain.vectorstores import FAISS
22
  from langchain_community.document_loaders import JSONLoader
23
  from langchain_community.document_loaders.csv_loader import CSVLoader
24
+ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
25
  from langchain_core.pydantic_v1 import BaseModel, Field
26
  from langchain_openai import ChatOpenAI
 
 
 
 
 
27
 
28
  load_dotenv()
29
 
 
244
  prompt = PromptTemplate.from_template(prompt_template)
245
 
246
  llm = ChatOpenAI(
247
+ temperature=0.5, model_name="gpt-4-turbo", api_key=os.environ["OPENAI_API_KEY"]
248
  )
249
  llm_chain = LLMChain(llm=llm, prompt=prompt)
250
 
 
278
  def analyze_data(data):
279
  print(f"Data: {data}")
280
  # Extract minimum and maximum ages: Turn ['18 Years', '20 Years'] into [18, 20]
281
+ min_ages = [
282
+ int(re.search(r"\d+", age).group()) for age in data["minimum_age"] if age
283
+ ]
284
+ max_ages = [
285
+ int(re.search(r"\d+", age).group()) for age in data["maximum_age"] if age
286
+ ]
287
  # primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
288
 
289
  # Calculate average minimum and maximum ages
 
295
  most_common_gender = gender_counter.most_common(1)[0][0]
296
 
297
  # Flatten keywords list and find common keywords
298
+ # keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
299
+ # common_keywords = [word for word, count in Counter(keywords).most_common()]
300
 
301
  return {
302
  "avg_min_age": avg_min_age,
303
  "avg_max_age": avg_max_age,
304
+ "most_common_gender": most_common_gender,
305
  }
306
 
307
 
 
382
  res = tagging_chain.invoke({"input": processed_json})
383
  unprocessed_results_dict = res.get_dict()
384
 
385
+ results_dict = analyze_data(unprocessed_results_dict)
 
 
386
 
387
  # stats_dict= {'Average Minimum age': avg_min_age,
388
  # 'Average Maximum age': avg_max_age,
main.ipynb CHANGED
@@ -245,52 +245,55 @@
245
  }
246
  ],
247
  "source": [
248
- "df_summary = pd.read_csv('file_db/brief_summaries.txt', delimiter='|')\n",
249
- "df_summary = df_summary.rename(columns={'description': 'summary'})\n",
250
  "\n",
251
  "### create and merge intervention ###\n",
252
- "df_intervention = pd.read_csv('file_db/interventions.txt', delimiter='|')\n",
253
  "\n",
254
- "intervention_grouped = df_intervention.groupby('nct_id')['name'].apply(list).reset_index()\n",
255
- "intervention_grouped = intervention_grouped.rename(columns={'name': 'intervention_name'})\n",
 
 
 
 
256
  "merged_df = pd.merge(\n",
257
- " df_summary[['nct_id', 'summary']], \n",
258
- " intervention_grouped[['nct_id', 'intervention_name']], \n",
259
- " on='nct_id')\n",
 
260
  "\n",
261
- "df_intervention = df_intervention.rename(columns={'description': 'intervention_description'})\n",
 
 
262
  "\n",
263
  "merged_df = pd.merge(\n",
264
  " merged_df,\n",
265
- " df_intervention[['nct_id', 'intervention_type', 'intervention_description']], \n",
266
- " on='nct_id')\n",
 
267
  "\n",
268
  "### create and merge keywords ###\n",
269
- "df_keyword = pd.read_csv('file_db/keywords.txt', delimiter='|')\n",
270
- "keywords_grouped = df_keyword.groupby('nct_id')['name'].apply(list).reset_index()\n",
271
- "keywords_grouped = keywords_grouped.rename(columns={'name': 'keywords'})\n",
272
  "\n",
273
- "merged_df = pd.merge(\n",
274
- " merged_df,\n",
275
- " keywords_grouped,\n",
276
- " on='nct_id'\n",
277
- ")\n",
278
  "\n",
279
  "### create and merge browse conditions\n",
280
- "df_condition = pd.read_csv('file_db/browse_conditions.txt', delimiter='|')\n",
281
- "conditions_grouped = df_condition.groupby('nct_id')['downcase_mesh_term'].apply(list).reset_index()\n",
282
- "conditions_grouped = conditions_grouped.rename(columns={'downcase_mesh_term': 'desease_condition'})\n",
283
- "\n",
284
- "merged_df = pd.merge(\n",
285
- " merged_df,\n",
286
- " conditions_grouped,\n",
287
- " on='nct_id'\n",
288
  ")\n",
289
  "\n",
290
- "merged_df = merged_df.drop_duplicates(subset='nct_id')\n",
291
  "\n",
292
- "merged_df.head()\n",
293
- "\n"
 
294
  ]
295
  },
296
  {
 
245
  }
246
  ],
247
  "source": [
248
+ "df_summary = pd.read_csv(\"file_db/brief_summaries.txt\", delimiter=\"|\")\n",
249
+ "df_summary = df_summary.rename(columns={\"description\": \"summary\"})\n",
250
  "\n",
251
  "### create and merge intervention ###\n",
252
+ "df_intervention = pd.read_csv(\"file_db/interventions.txt\", delimiter=\"|\")\n",
253
  "\n",
254
+ "intervention_grouped = (\n",
255
+ " df_intervention.groupby(\"nct_id\")[\"name\"].apply(list).reset_index()\n",
256
+ ")\n",
257
+ "intervention_grouped = intervention_grouped.rename(\n",
258
+ " columns={\"name\": \"intervention_name\"}\n",
259
+ ")\n",
260
  "merged_df = pd.merge(\n",
261
+ " df_summary[[\"nct_id\", \"summary\"]],\n",
262
+ " intervention_grouped[[\"nct_id\", \"intervention_name\"]],\n",
263
+ " on=\"nct_id\",\n",
264
+ ")\n",
265
  "\n",
266
+ "df_intervention = df_intervention.rename(\n",
267
+ " columns={\"description\": \"intervention_description\"}\n",
268
+ ")\n",
269
  "\n",
270
  "merged_df = pd.merge(\n",
271
  " merged_df,\n",
272
+ " df_intervention[[\"nct_id\", \"intervention_type\", \"intervention_description\"]],\n",
273
+ " on=\"nct_id\",\n",
274
+ ")\n",
275
  "\n",
276
  "### create and merge keywords ###\n",
277
+ "df_keyword = pd.read_csv(\"file_db/keywords.txt\", delimiter=\"|\")\n",
278
+ "keywords_grouped = df_keyword.groupby(\"nct_id\")[\"name\"].apply(list).reset_index()\n",
279
+ "keywords_grouped = keywords_grouped.rename(columns={\"name\": \"keywords\"})\n",
280
  "\n",
281
+ "merged_df = pd.merge(merged_df, keywords_grouped, on=\"nct_id\")\n",
 
 
 
 
282
  "\n",
283
  "### create and merge browse conditions\n",
284
+ "df_condition = pd.read_csv(\"file_db/browse_conditions.txt\", delimiter=\"|\")\n",
285
+ "conditions_grouped = (\n",
286
+ " df_condition.groupby(\"nct_id\")[\"downcase_mesh_term\"].apply(list).reset_index()\n",
287
+ ")\n",
288
+ "conditions_grouped = conditions_grouped.rename(\n",
289
+ " columns={\"downcase_mesh_term\": \"desease_condition\"}\n",
 
 
290
  ")\n",
291
  "\n",
292
+ "merged_df = pd.merge(merged_df, conditions_grouped, on=\"nct_id\")\n",
293
  "\n",
294
+ "merged_df = merged_df.drop_duplicates(subset=\"nct_id\")\n",
295
+ "\n",
296
+ "merged_df.head()"
297
  ]
298
  },
299
  {
utils.py CHANGED
@@ -1,11 +1,12 @@
1
  # %%
2
- from typing import List, Dict, Any
3
  import os
4
- from sqlalchemy import create_engine, text
 
 
5
  import requests
6
- from sentence_transformers import SentenceTransformer
7
  import streamlit as st
8
- import pandas as pd
 
9
 
10
  username = "demo"
11
  password = "demo"
@@ -124,16 +125,19 @@ def get_similarities_among_diseases_uris(
124
  result = conn.execute(text(sql))
125
  data = result.fetchall()
126
 
127
- return [{
128
- "uri1": row[0].split("/")[-1],
129
- "uri2": row[1].split("/")[-1],
130
- "distance": float(row[2]),
131
- } for row in data]
 
 
 
132
 
133
 
134
  def augment_the_set_of_diseaces(diseases: List[str]) -> str:
135
  augmented_diseases = diseases.copy()
136
- for i in range(15-len(augmented_diseases)):
137
  with engine.connect() as conn:
138
  with conn.begin():
139
  sql = f"""
@@ -153,6 +157,7 @@ def augment_the_set_of_diseaces(diseases: List[str]) -> str:
153
 
154
  return augmented_diseases
155
 
 
156
  def get_embedding(string: str, encoder) -> List[float]:
157
  # Embed the string using sentence-transformers
158
  vector = encoder.encode(string, show_progress_bar=False)
@@ -176,11 +181,14 @@ def get_diseases_related_to_a_textual_description(
176
  result = conn.execute(text(sql))
177
  data = result.fetchall()
178
 
179
- return [{"uri": row[0], "distance": float(row[1])} for row in data if float(row[1]) > 0.8]
 
 
 
 
180
 
181
- def get_clinical_trials_related_to_diseases(
182
- diseases: List[str], encoder
183
- ) -> List[str]:
184
  # Embed the diseases using sentence-transformers
185
  diseases_string = ", ".join(diseases)
186
  disease_embedding = get_embedding(diseases_string, encoder)
@@ -189,7 +197,7 @@ def get_clinical_trials_related_to_diseases(
189
  with engine.connect() as conn:
190
  with conn.begin():
191
  sql = f"""
192
- SELECT TOP 15 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
193
  FROM Test.ClinicalTrials d
194
  ORDER BY distance DESC
195
  """
@@ -198,82 +206,139 @@ def get_clinical_trials_related_to_diseases(
198
 
199
  return [{"nct_id": row[0], "distance": row[1]} for row in data]
200
 
201
- def filter_out_less_promising_diseases(info_dicts: List[Dict[str, Any]]) -> List[str]:
 
202
  # Find out the score of each disease by averaging the cosine similarity of the embeddings of the diseases that include it as uri1 or uri2
203
- df_diseases_similarities = pd.DataFrame(info_dicts)
204
  # Use uri1 as the index, and uri2 as the columns. The values are the distances.
205
- df_diseases_similarities = df_diseases_similarities.pivot(index="uri1", columns="uri2", values="distance")
 
 
206
  # Fill the diagonal with 1.0
207
  df_diseases_similarities = df_diseases_similarities.fillna(1.0)
208
 
209
- # Filter out the diseases that are 1 standard deviation below the mean
 
 
 
 
 
 
210
  mean = df_diseases_similarities.mean().mean()
211
  std = df_diseases_similarities.mean().std()
212
- filtered_diseases = df_diseases_similarities.mean()[df_diseases_similarities.mean() > mean - std].index.tolist()
 
 
213
  return filtered_diseases, df_diseases_similarities
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  def to_capitalized_case(string: str) -> str:
216
  string = string.replace("_", " ")
217
  if string.isupper():
218
  return string[0] + string[1:].lower()
219
-
 
220
  def list_to_capitalized_case(strings: List[str]) -> str:
221
  strings = [to_capitalized_case(s) for s in strings]
222
  return ", ".join(strings)
223
 
 
224
  def render_trial_details(trial: dict) -> None:
225
- # TODO: handle key errors for all cases (→ do not render)
226
-
227
- official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
228
- st.write(f"##### {official_title}")
229
-
230
- try:
231
- st.write(trial["protocolSection"]["descriptionModule"]["briefSummary"])
232
- except KeyError:
233
- try:
234
- st.write(trial["protocolSection"]["descriptionModule"]["detailedDescription"])
235
- except KeyError:
236
- st.error("No description available.")
237
-
238
- st.write("###### Status")
239
- try:
240
- status_module = {
241
- "Status": to_capitalized_case(trial["protocolSection"]["statusModule"]["overallStatus"]),
242
- "Status Date": trial["protocolSection"]["statusModule"]["statusVerifiedDate"],
243
- "Has Results": trial["hasResults"]
244
- }
245
- st.table(status_module)
246
- except KeyError:
247
- st.info("No status information available.")
248
-
249
- st.write("###### Design")
250
- try:
251
- design_module = {
252
- "Study Type": to_capitalized_case(trial["protocolSection"]["designModule"]["studyType"]),
253
- "Phases": list_to_capitalized_case(trial["protocolSection"]["designModule"]["phases"]),
254
- "Allocation": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["allocation"]),
255
- "Primary Purpose": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]),
256
- "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"]["count"],
257
- "Masking": to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["masking"]),
258
- "Who Masked": list_to_capitalized_case(trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"]["whoMasked"])
259
- }
260
- st.table(design_module)
261
- except KeyError:
262
- st.info("No design information available.")
263
-
264
- st.write("###### Interventions")
265
- try:
266
- interventions_module = {}
267
- for intervention in trial["protocolSection"]["armsInterventionsModule"]["interventions"]:
268
- name = intervention["name"]
269
- desc = intervention["description"]
270
- interventions_module[name] = desc
271
- st.table(interventions_module)
272
- except KeyError:
273
- st.info("No interventions information available.")
274
-
275
- # Button to go to ClinicalTrials.gov and see the trial. It takes the user to the official page of the trial.
276
- st.markdown(f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{trial['protocolSection']['identificationModule']['nctId']})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  if __name__ == "__main__":
279
  username = "demo"
 
1
  # %%
 
2
  import os
3
+ from typing import Any, Dict, List
4
+
5
+ import pandas as pd
6
  import requests
 
7
  import streamlit as st
8
+ from sentence_transformers import SentenceTransformer
9
+ from sqlalchemy import create_engine, text
10
 
11
  username = "demo"
12
  password = "demo"
 
125
  result = conn.execute(text(sql))
126
  data = result.fetchall()
127
 
128
+ return [
129
+ {
130
+ "uri1": row[0].split("/")[-1],
131
+ "uri2": row[1].split("/")[-1],
132
+ "distance": float(row[2]),
133
+ }
134
+ for row in data
135
+ ]
136
 
137
 
138
  def augment_the_set_of_diseaces(diseases: List[str]) -> str:
139
  augmented_diseases = diseases.copy()
140
+ for i in range(10 - len(augmented_diseases)):
141
  with engine.connect() as conn:
142
  with conn.begin():
143
  sql = f"""
 
157
 
158
  return augmented_diseases
159
 
160
+
161
  def get_embedding(string: str, encoder) -> List[float]:
162
  # Embed the string using sentence-transformers
163
  vector = encoder.encode(string, show_progress_bar=False)
 
181
  result = conn.execute(text(sql))
182
  data = result.fetchall()
183
 
184
+ return [
185
+ {"uri": row[0], "distance": float(row[1])}
186
+ for row in data
187
+ if float(row[1]) > 0.8
188
+ ]
189
 
190
+
191
+ def get_clinical_trials_related_to_diseases(diseases: List[str], encoder) -> List[str]:
 
192
  # Embed the diseases using sentence-transformers
193
  diseases_string = ", ".join(diseases)
194
  disease_embedding = get_embedding(diseases_string, encoder)
 
197
  with engine.connect() as conn:
198
  with conn.begin():
199
  sql = f"""
200
+ SELECT TOP 20 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
201
  FROM Test.ClinicalTrials d
202
  ORDER BY distance DESC
203
  """
 
206
 
207
  return [{"nct_id": row[0], "distance": row[1]} for row in data]
208
 
209
+
210
+ def get_similarities_df(diseases: List[Dict[str, Any]]) -> pd.DataFrame:
211
  # Find out the score of each disease by averaging the cosine similarity of the embeddings of the diseases that include it as uri1 or uri2
212
+ df_diseases_similarities = pd.DataFrame(diseases)
213
  # Use uri1 as the index, and uri2 as the columns. The values are the distances.
214
+ df_diseases_similarities = df_diseases_similarities.pivot(
215
+ index="uri1", columns="uri2", values="distance"
216
+ )
217
  # Fill the diagonal with 1.0
218
  df_diseases_similarities = df_diseases_similarities.fillna(1.0)
219
 
220
+ return df_diseases_similarities
221
+
222
+
223
+ def filter_out_less_promising_diseases(info_dicts: List[Dict[str, Any]]) -> List[str]:
224
+ df_diseases_similarities = get_similarities_df(info_dicts)
225
+
226
+ # Filter out the diseases that are 0.2 standard deviations below the mean
227
  mean = df_diseases_similarities.mean().mean()
228
  std = df_diseases_similarities.mean().std()
229
+ filtered_diseases = df_diseases_similarities.mean()[
230
+ df_diseases_similarities.mean() > mean - 0.2 * std
231
+ ].index.tolist()
232
  return filtered_diseases, df_diseases_similarities
233
 
234
+
235
+ def get_labels_of_diseases_from_uris(uris: List[str]) -> List[str]:
236
+ with engine.connect() as conn:
237
+ with conn.begin():
238
+ joined_uris = ", ".join([f"'{uri}'" for uri in uris])
239
+ sql = f"""
240
+ SELECT label FROM Test.EntityEmbeddings
241
+ WHERE uri IN ({joined_uris})
242
+ """
243
+ result = conn.execute(text(sql))
244
+ data = result.fetchall()
245
+
246
+ return [row[0] for row in data]
247
+
248
+
249
  def to_capitalized_case(string: str) -> str:
250
  string = string.replace("_", " ")
251
  if string.isupper():
252
  return string[0] + string[1:].lower()
253
+
254
+
255
  def list_to_capitalized_case(strings: List[str]) -> str:
256
  strings = [to_capitalized_case(s) for s in strings]
257
  return ", ".join(strings)
258
 
259
+
260
  def render_trial_details(trial: dict) -> None:
261
+ # TODO: handle key errors for all cases (→ do not render)
262
+
263
+ official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
264
+ st.write(f"##### {official_title}")
265
+
266
+ try:
267
+ st.write(trial["protocolSection"]["descriptionModule"]["briefSummary"])
268
+ except KeyError:
269
+ try:
270
+ st.write(
271
+ trial["protocolSection"]["descriptionModule"]["detailedDescription"]
272
+ )
273
+ except KeyError:
274
+ st.error("No description available.")
275
+
276
+ st.write("###### Status")
277
+ try:
278
+ status_module = {
279
+ "Status": to_capitalized_case(
280
+ trial["protocolSection"]["statusModule"]["overallStatus"]
281
+ ),
282
+ "Status Date": trial["protocolSection"]["statusModule"][
283
+ "statusVerifiedDate"
284
+ ],
285
+ "Has Results": trial["hasResults"],
286
+ }
287
+ st.table(status_module)
288
+ except KeyError:
289
+ st.info("No status information available.")
290
+
291
+ st.write("###### Design")
292
+ try:
293
+ design_module = {
294
+ "Study Type": to_capitalized_case(
295
+ trial["protocolSection"]["designModule"]["studyType"]
296
+ ),
297
+ "Phases": list_to_capitalized_case(
298
+ trial["protocolSection"]["designModule"]["phases"]
299
+ ),
300
+ "Allocation": to_capitalized_case(
301
+ trial["protocolSection"]["designModule"]["designInfo"]["allocation"]
302
+ ),
303
+ "Primary Purpose": to_capitalized_case(
304
+ trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]
305
+ ),
306
+ "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
307
+ "count"
308
+ ],
309
+ "Masking": to_capitalized_case(
310
+ trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"][
311
+ "masking"
312
+ ]
313
+ ),
314
+ "Who Masked": list_to_capitalized_case(
315
+ trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"][
316
+ "whoMasked"
317
+ ]
318
+ ),
319
+ }
320
+ st.table(design_module)
321
+ except KeyError:
322
+ st.info("No design information available.")
323
+
324
+ st.write("###### Interventions")
325
+ try:
326
+ interventions_module = {}
327
+ for intervention in trial["protocolSection"]["armsInterventionsModule"][
328
+ "interventions"
329
+ ]:
330
+ name = intervention["name"]
331
+ desc = intervention["description"]
332
+ interventions_module[name] = desc
333
+ st.table(interventions_module)
334
+ except KeyError:
335
+ st.info("No interventions information available.")
336
+
337
+ # Button to go to ClinicalTrials.gov and see the trial. It takes the user to the official page of the trial.
338
+ st.markdown(
339
+ f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{trial['protocolSection']['identificationModule']['nctId']})"
340
+ )
341
+
342
 
343
  if __name__ == "__main__":
344
  username = "demo"