maseiler commited on
Commit
9b5c4aa
·
1 Parent(s): 27d40b9
Files changed (1) hide show
  1. app.py +144 -111
app.py CHANGED
@@ -19,9 +19,13 @@ import numpy as np
19
  from sentence_transformers import SentenceTransformer
20
 
21
 
22
- begin = st.container()
23
-
 
 
 
24
 
 
25
  username = "demo"
26
  password = "demo"
27
  hostname = os.getenv("IRIS_HOSTNAME", "localhost")
@@ -30,116 +34,145 @@ namespace = "USER"
30
  CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
31
  engine = create_engine(CONNECTION_STRING)
32
 
33
- begin.write("# Klìnic")
34
 
35
- description_input = begin.text_input(
36
- label="Enter the disease description 👇",
37
- placeholder="A disease that causes memory loss and other cognitive impairments.",
38
- )
39
- if begin.button("Analyze 🔎"):
40
- # 1. Embed the textual description that the user entered using the model
41
- # 2. Get 5 diseases with the highest cosine silimarity from the DB
42
- encoder = SentenceTransformer("allenai-specter")
43
- diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
44
- description_input, encoder
45
- )
46
- # for disease_label in diseases_related_to_the_user_text:
47
- # st.text(disease_label)
48
- # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
49
- diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
50
- get_similarities_among_diseases_uris(diseases_uris)
51
- print(diseases_related_to_the_user_text)
52
- # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
53
- # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
54
- augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
55
- print(augmented_set_of_diseases)
56
- # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
57
- clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
58
- augmented_set_of_diseases, encoder
59
- )
60
- print(f'clinical_trials_related_to_the_diseases: {clinical_trials_related_to_the_diseases}')
61
- json_of_clinical_trials = get_clinical_records_by_ids(
62
- [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
63
- )
64
- print(f'json_of_clinical_trials: {json_of_clinical_trials}')
65
- # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
66
- # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
67
- graph_of_diseases = agraph(
68
- nodes=[
69
- Node(id="A", label="Node A", size=10),
70
- Node(id="B", label="Node B", size=10),
71
- Node(id="C", label="Node C", size=10),
72
- Node(id="D", label="Node D", size=10),
73
- Node(id="E", label="Node E", size=10),
74
- Node(id="F", label="Node F", size=10),
75
- Node(id="G", label="Node G", size=10),
76
- Node(id="H", label="Node H", size=10),
77
- Node(id="I", label="Node I", size=10),
78
- Node(id="J", label="Node J", size=10),
79
- ],
80
- edges=[
81
- Edge(source="A", target="B"),
82
- Edge(source="B", target="C"),
83
- Edge(source="C", target="D"),
84
- Edge(source="D", target="E"),
85
- Edge(source="E", target="F"),
86
- Edge(source="F", target="G"),
87
- Edge(source="G", target="H"),
88
- Edge(source="H", target="I"),
89
- Edge(source="I", target="J"),
90
- ],
91
- config=Config(height=500, width=500),
92
- )
93
- # TODO: also when user clicks enter
94
-
95
- begin.write(":red[Here should be the graph]") # TODO remove
96
- chart_data = pd.DataFrame(
97
- np.random.randn(20, 3), columns=["a", "b", "c"]
98
- ) # TODO remove
99
- begin.scatter_chart(chart_data) # TODO remove
100
-
101
- begin.write("## Disease Overview")
102
- disease_overview = ":red[lorem ipsum]" # TODO
103
- begin.write(disease_overview)
104
-
105
- begin.write("## Clinical Trials Details")
106
- trials = []
107
- # TODO replace mock data
108
- with open("mock_trial.json") as f:
109
- d = json.load(f)
110
- for i in range(0, 5):
111
- trials.append(d)
112
-
113
- for trial in trials:
114
- with st.expander(f"{trial['protocolSection']['identificationModule']['nctId']}"):
115
- official_title = trial["protocolSection"]["identificationModule"][
116
- "officialTitle"
117
- ]
118
- st.write(f"##### {official_title}")
119
-
120
- brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
121
- st.write(brief_summary)
122
-
123
- status_module = {
124
- "Status": trial["protocolSection"]["statusModule"]["overallStatus"],
125
- "Status Date": trial["protocolSection"]["statusModule"][
126
- "statusVerifiedDate"
127
- ],
128
- }
129
- st.write("###### Status")
130
- st.table(status_module)
131
-
132
- design_module = {
133
- "Study Type": trial["protocolSection"]["designModule"]["studyType"],
134
- # "Phases": trial["protocolSection"]["designModule"]["phases"], # breaks formatting because it is an array
135
- "Allocation": trial["protocolSection"]["designModule"]["designInfo"][
136
- "allocation"
137
  ],
138
- "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
139
- "count"
 
 
 
 
 
 
 
 
140
  ],
141
- }
142
- st.write("###### Design")
143
- st.table(design_module)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- # TODO more modules?
 
19
  from sentence_transformers import SentenceTransformer
20
 
21
 
22
+ # variables to reveal next steps
23
+ show_graph = False
24
+ show_analyze_status = False
25
+ show_overview = False
26
+ show_details = False
27
 
28
+ # IRIS connection
29
  username = "demo"
30
  password = "demo"
31
  hostname = os.getenv("IRIS_HOSTNAME", "localhost")
 
34
  CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
35
  engine = create_engine(CONNECTION_STRING)
36
 
 
37
 
38
+ st.title("Klìnic")
39
+ st.header("", divider='rainbow')
40
+ st.text('') # dummy to add spacing
41
+
42
+ with st.container(): # user input
43
+ col1, col2 = st.columns((6, 1))
44
+
45
+ with col1:
46
+ description_input = st.text_area(label="Enter the disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
47
+
48
+ with col2:
49
+ st.text('') # dummy to center vertically
50
+ st.text('') # dummy to center vertically
51
+ st.text('') # dummy to center vertically
52
+ show_analyze_status = st.button("Analyze 🔎")
53
+
54
+
55
+ # analyze
56
+ with st.container():
57
+ if show_analyze_status:
58
+ with st.status("Analyzing...") as status:
59
+ # 1. Embed the textual description that the user entered using the model
60
+ # 2. Get 5 diseases with the highest cosine silimarity from the DB
61
+ encoder = SentenceTransformer("allenai-specter")
62
+ diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
63
+ description_input, encoder
64
+ )
65
+ # for disease_label in diseases_related_to_the_user_text:
66
+ # st.text(disease_label)
67
+ # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
68
+ diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
69
+ get_similarities_among_diseases_uris(diseases_uris)
70
+ #print(diseases_related_to_the_user_text)
71
+ # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
72
+ # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
73
+ augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
74
+ #print(augmented_set_of_diseases)
75
+ # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
76
+ clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
77
+ augmented_set_of_diseases, encoder
78
+ )
79
+ #print(f'clinical_trials_related_to_the_diseases: {clinical_trials_related_to_the_diseases}')
80
+ json_of_clinical_trials = get_clinical_records_by_ids(
81
+ [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
82
+ )
83
+ #print(f'json_of_clinical_trials: {json_of_clinical_trials}')
84
+ # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
85
+ # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
86
+ status.update(label="Done!", state="complete")
87
+ time.sleep(1)
88
+ show_graph = True
89
+
90
+
91
+ # graph
92
+ with st.container():
93
+ if show_graph:
94
+ # TODO actual graph
95
+ graph_of_diseases = agraph(
96
+ nodes=[
97
+ Node(id="A", label="Node A", size=10),
98
+ Node(id="B", label="Node B", size=10),
99
+ Node(id="C", label="Node C", size=10),
100
+ Node(id="D", label="Node D", size=10),
101
+ Node(id="E", label="Node E", size=10),
102
+ Node(id="F", label="Node F", size=10),
103
+ Node(id="G", label="Node G", size=10),
104
+ Node(id="H", label="Node H", size=10),
105
+ Node(id="I", label="Node I", size=10),
106
+ Node(id="J", label="Node J", size=10),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  ],
108
+ edges=[
109
+ Edge(source="A", target="B"),
110
+ Edge(source="B", target="C"),
111
+ Edge(source="C", target="D"),
112
+ Edge(source="D", target="E"),
113
+ Edge(source="E", target="F"),
114
+ Edge(source="F", target="G"),
115
+ Edge(source="G", target="H"),
116
+ Edge(source="H", target="I"),
117
+ Edge(source="I", target="J"),
118
  ],
119
+ config=Config(height=500, width=500),
120
+ )
121
+ time.sleep(2)
122
+ show_overview = True
123
+
124
+
125
+ # overview
126
+ with st.container():
127
+ if show_overview:
128
+ st.write("## Disease Overview")
129
+ disease_overview = ":red[lorem ipsum]" # TODO
130
+ st.write(disease_overview)
131
+ time.sleep(2)
132
+ show_details = True
133
+
134
+
135
+ # details
136
+ with st.container():
137
+ if show_details:
138
+ st.write("## Clinical Trials Details")
139
+ trials = []
140
+ # TODO replace mock data
141
+ with open("mock_trial.json") as f:
142
+ d = json.load(f)
143
+ for i in range(0, 5):
144
+ trials.append(d)
145
+
146
+ for trial in trials:
147
+ with st.expander(f"{trial['protocolSection']['identificationModule']['nctId']}"):
148
+ official_title = trial["protocolSection"]["identificationModule"][
149
+ "officialTitle"
150
+ ]
151
+ st.write(f"##### {official_title}")
152
+
153
+ brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
154
+ st.write(brief_summary)
155
+
156
+ status_module = {
157
+ "Status": trial["protocolSection"]["statusModule"]["overallStatus"],
158
+ "Status Date": trial["protocolSection"]["statusModule"][
159
+ "statusVerifiedDate"
160
+ ],
161
+ }
162
+ st.write("###### Status")
163
+ st.table(status_module)
164
+
165
+ design_module = {
166
+ "Study Type": trial["protocolSection"]["designModule"]["studyType"],
167
+ # "Phases": trial["protocolSection"]["designModule"]["phases"], # breaks formatting because it is an array
168
+ "Allocation": trial["protocolSection"]["designModule"]["designInfo"][
169
+ "allocation"
170
+ ],
171
+ "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
172
+ "count"
173
+ ],
174
+ }
175
+ st.write("###### Design")
176
+ st.table(design_module)
177
 
178
+ # TODO more modules?