ACMCMC commited on
Commit
4bb7c94
·
1 Parent(s): 1fb895a
Files changed (3) hide show
  1. app.py +44 -10
  2. llm_res.py +49 -36
  3. utils.py +1 -1
app.py CHANGED
@@ -28,6 +28,7 @@ show_graph = False
28
  show_analyze_status = False
29
  show_overview = False
30
  show_details = False
 
31
 
32
  # IRIS connection
33
  username = "demo"
@@ -66,7 +67,7 @@ with st.container():
66
  diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
67
  description_input, encoder
68
  )
69
- status.info(f'Found {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
70
  status.json(diseases_related_to_the_user_text, expanded=False)
71
  status.divider()
72
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
@@ -94,26 +95,37 @@ with st.container():
94
  clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
95
  augmented_set_of_diseases, encoder
96
  )
97
- status.info(f'Found {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
98
  status.json(clinical_trials_related_to_the_diseases, expanded=False)
99
  status.divider()
100
  status.write("Getting the details of the clinical trials...")
101
  json_of_clinical_trials = get_clinical_records_by_ids(
102
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
103
  )
 
104
  status.json(json_of_clinical_trials, expanded=False)
105
  status.divider()
106
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
107
- status.write("Getting a summary of the clinical trials...")
108
- response = get_short_summary_out_of_json_files(json_of_clinical_trials)
109
- disease_overview = response
 
 
 
 
 
110
  try:
111
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
112
  status.write("Getting summary statistics of the clinical trials...")
113
  response = tagging_insights_from_json(json_of_clinical_trials)
 
 
 
 
114
  print(f'Response from LLM tagging: {response}')
115
- status.write(f'Response from LLM tagging: {response}')
116
  except Exception as e:
 
117
  print(f'Error while extracting numerical data from the clinical trials: {e}')
118
  status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
119
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
@@ -170,10 +182,32 @@ $$"""
170
  # overview
171
  with st.container():
172
  if show_overview:
173
- st.write("## Overview of Related Clinical Trials")
174
- st.write(disease_overview)
175
- time.sleep(2)
176
- show_details = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  # details
 
28
  show_analyze_status = False
29
  show_overview = False
30
  show_details = False
31
+ show_metrics = False
32
 
33
  # IRIS connection
34
  username = "demo"
 
67
  diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
68
  description_input, encoder
69
  )
70
+ status.info(f'Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
71
  status.json(diseases_related_to_the_user_text, expanded=False)
72
  status.divider()
73
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
 
95
  clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
96
  augmented_set_of_diseases, encoder
97
  )
98
+ status.info(f'Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
99
  status.json(clinical_trials_related_to_the_diseases, expanded=False)
100
  status.divider()
101
  status.write("Getting the details of the clinical trials...")
102
  json_of_clinical_trials = get_clinical_records_by_ids(
103
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
104
  )
105
+ status.success(f'Details of the clinical trials obtained.')
106
  status.json(json_of_clinical_trials, expanded=False)
107
  status.divider()
108
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
109
+ try:
110
+ status.write("Getting a summary of the clinical trials...")
111
+ response = get_short_summary_out_of_json_files(json_of_clinical_trials)
112
+ status.success("Summary of the clinical trials obtained.")
113
+ disease_overview = response
114
+ except Exception as e:
115
+ print(f'Error while getting a summary of the clinical trials: {e}')
116
+ status.warning(f'Error while getting a summary of the clinical trials. This information will not be shown.')
117
  try:
118
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
119
  status.write("Getting summary statistics of the clinical trials...")
120
  response = tagging_insights_from_json(json_of_clinical_trials)
121
+ average_minimum_age = response["avg_min_age"]
122
+ average_maximum_age = response["avg_max_age"]
123
+ most_common_gender = response['most_common_gender']
124
+
125
  print(f'Response from LLM tagging: {response}')
126
+ status.success(f'Summary statistics of the clinical trials obtained.')
127
  except Exception as e:
128
+ raise e
129
  print(f'Error while extracting numerical data from the clinical trials: {e}')
130
  status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
131
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
 
182
  # overview
183
  with st.container():
184
  if show_overview:
185
+ try:
186
+ st.write("## Overview of Related Clinical Trials")
187
+ st.write(disease_overview)
188
+ time.sleep(1)
189
+ except Exception as e:
190
+ print(f'Error while showing the overview of the clinical trials: {e}')
191
+ finally:
192
+ show_metrics = True
193
+
194
+
195
+ with st.container():
196
+ if show_metrics:
197
+ try:
198
+ st.write("## Metrics of the Clinical Trials")
199
+ col1, col2, col3 = st.columns(3)
200
+ with col1:
201
+ st.metric("Average Minimum Age", average_minimum_age)
202
+ with col2:
203
+ st.metric("Average Maximum Age", average_maximum_age)
204
+ with col3:
205
+ st.metric("Most Common Gender", most_common_gender)
206
+ time.sleep(2)
207
+ except Exception as e:
208
+ print(f'Error while showing the metrics: {e}')
209
+ finally:
210
+ show_details = True
211
 
212
 
213
  # details
llm_res.py CHANGED
@@ -24,6 +24,7 @@ from langchain.chains.llm import LLMChain
24
  from langchain_core.prompts import PromptTemplate
25
  from collections import Counter
26
  import statistics
 
27
 
28
  load_dotenv()
29
 
@@ -134,11 +135,12 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
134
  # "eligibility": eligibility,
135
  # }
136
  # filtered_data.append(filtered_item)
137
-
138
  # return filtered_data
139
  # # for ele in filtered_data:
140
  # # print(ele)
141
 
 
142
  def process_dictionaty_with_llm_to_generate_response(json_data):
143
  # processed_data = process_json_data_for_llm(json_data)
144
  # res = tagging_chain.invoke({"input": processed_data})
@@ -217,9 +219,10 @@ def process_dictionaty_with_llm_to_generate_response(json_data):
217
  "eligibility": eligibility,
218
  }
219
  filtered_data.append(filtered_item)
220
-
221
  return filtered_data
222
 
 
223
  def get_short_summary_out_of_json_files(data_json):
224
  prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
225
 
@@ -272,29 +275,36 @@ General summary:"""
272
 
273
  return result
274
 
 
275
  def analyze_data(data):
276
- # Extract minimum and maximum ages
277
- min_ages = [int(age.split()[0]) for age in data['minimum_age'] if age]
278
- max_ages = [int(age.split()[0]) for age in data['maximum_age'] if age]
 
279
  # primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
280
-
281
  # Calculate average minimum and maximum ages
282
  avg_min_age = statistics.mean(min_ages) if min_ages else None
283
  avg_max_age = statistics.mean(max_ages) if max_ages else None
284
-
285
  # Find most common gender
286
- gender_counter = Counter(data['gender'])
287
  most_common_gender = gender_counter.most_common(1)[0][0]
288
-
289
  # Flatten keywords list and find common keywords
290
- keywords = [keyword for sublist in data['keywords'] for keyword in sublist]
291
- common_keywords = [word for word, count in Counter(keywords).most_common()]
292
-
293
- return avg_min_age, avg_max_age, most_common_gender, common_keywords
 
 
 
 
 
294
 
295
  def tagging_insights_from_json(data_json):
296
- processed_json= process_dictionaty_with_llm_to_generate_response(data_json)
297
-
298
  tagging_prompt = ChatPromptTemplate.from_template(
299
  """
300
  You are an expert on clinicial trials and analysis of their reports.
@@ -307,6 +317,7 @@ def tagging_insights_from_json(data_json):
307
  {input}
308
  """
309
  )
 
310
  class Classification(BaseModel):
311
  # description: str = Field(
312
  # description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
@@ -317,25 +328,25 @@ def tagging_insights_from_json(data_json):
317
  # status: list = Field(
318
  # description="Extract the status of all the clinical trials"
319
  # )
320
- #keywords: list = Field(
321
  # description="Extract the most relevant keywords for each clinical trials"
322
- #)
323
  # interventions: list = Field(
324
  # description="describe the interventions for each clinical trial using title, name and description"
325
  # )
326
- #primary_outcomes: list = Field(
327
  # description="get the timeframe of each clinical trial"
328
- #)
329
- #secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
330
- #eligibility: list = Field(
331
  # description="get the timeframe of each clinical trial"
332
- #)
333
  # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
334
  minimum_age: list = Field(
335
- description="get the minimum age from each experiment"
336
  )
337
  maximum_age: list = Field(
338
- description="get the maximum age from each experiment"
339
  )
340
  gender: list = Field(description="get the gender from each experiment")
341
 
@@ -343,15 +354,15 @@ def tagging_insights_from_json(data_json):
343
  return {
344
  # "project_title": self.project_title,
345
  # "status": self.status,
346
- #"keywords": self.keywords,
347
  # "interventions": self.interventions,
348
- #"primary_outcomes": self.primary_outcomes,
349
- #"secondary_outcomes": self.secondary_outcomes,
350
  # "eligibility": self.eligibility,
351
  # "healthy_volunteers": self.healthy_volunteers,
352
  "minimum_age": self.minimum_age,
353
  "maximum_age": self.maximum_age,
354
- "gender": self.gender
355
  }
356
 
357
  # LLM
@@ -365,18 +376,20 @@ def tagging_insights_from_json(data_json):
365
 
366
  tagging_chain = tagging_prompt | llm
367
 
368
- res= tagging_chain.invoke({"input": processed_json})
369
- result_dict= res.get_dict()
370
 
371
- avg_min_age, avg_max_age, most_common_gender, common_keywords= analyze_data(result_dict)
 
 
372
 
373
- #stats_dict= {'Average Minimum age': avg_min_age,
374
  # 'Average Maximum age': avg_max_age,
375
  # 'Most common gender undergoing the trials': most_common_gender,
376
  # 'common keywords found in the trials': common_keywords}
377
-
378
- print(f"Result_tagging: {result_dict}")
379
- return result_dict#, stats_dict
380
 
381
 
382
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
@@ -386,4 +399,4 @@ def tagging_insights_from_json(data_json):
386
  # json.dump(clinical_record_info, f, indent=4)
387
 
388
 
389
- # tagging_chain = tagging_insights_from_json(json_data)
 
24
  from langchain_core.prompts import PromptTemplate
25
  from collections import Counter
26
  import statistics
27
+ import regex as re
28
 
29
  load_dotenv()
30
 
 
135
  # "eligibility": eligibility,
136
  # }
137
  # filtered_data.append(filtered_item)
138
+
139
  # return filtered_data
140
  # # for ele in filtered_data:
141
  # # print(ele)
142
 
143
+
144
  def process_dictionaty_with_llm_to_generate_response(json_data):
145
  # processed_data = process_json_data_for_llm(json_data)
146
  # res = tagging_chain.invoke({"input": processed_data})
 
219
  "eligibility": eligibility,
220
  }
221
  filtered_data.append(filtered_item)
222
+
223
  return filtered_data
224
 
225
+
226
  def get_short_summary_out_of_json_files(data_json):
227
  prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
228
 
 
275
 
276
  return result
277
 
278
+
279
  def analyze_data(data):
280
+ print(f"Data: {data}")
281
+ # Extract minimum and maximum ages: Turn ['18 Years', '20 Years'] into [18, 20]
282
+ min_ages = [int(re.search(r"\d+", age).group()) for age in data["minimum_age"] if age]
283
+ max_ages = [int(re.search(r"\d+", age).group()) for age in data["maximum_age"] if age]
284
  # primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
285
+
286
  # Calculate average minimum and maximum ages
287
  avg_min_age = statistics.mean(min_ages) if min_ages else None
288
  avg_max_age = statistics.mean(max_ages) if max_ages else None
289
+
290
  # Find most common gender
291
+ gender_counter = Counter(data["gender"])
292
  most_common_gender = gender_counter.most_common(1)[0][0]
293
+
294
  # Flatten keywords list and find common keywords
295
+ #keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
296
+ #common_keywords = [word for word, count in Counter(keywords).most_common()]
297
+
298
+ return {
299
+ "avg_min_age": avg_min_age,
300
+ "avg_max_age": avg_max_age,
301
+ "most_common_gender": most_common_gender
302
+ }
303
+
304
 
305
  def tagging_insights_from_json(data_json):
306
+ processed_json = process_dictionaty_with_llm_to_generate_response(data_json)
307
+
308
  tagging_prompt = ChatPromptTemplate.from_template(
309
  """
310
  You are an expert on clinicial trials and analysis of their reports.
 
317
  {input}
318
  """
319
  )
320
+
321
  class Classification(BaseModel):
322
  # description: str = Field(
323
  # description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
 
328
  # status: list = Field(
329
  # description="Extract the status of all the clinical trials"
330
  # )
331
+ # keywords: list = Field(
332
  # description="Extract the most relevant keywords for each clinical trials"
333
+ # )
334
  # interventions: list = Field(
335
  # description="describe the interventions for each clinical trial using title, name and description"
336
  # )
337
+ # primary_outcomes: list = Field(
338
  # description="get the timeframe of each clinical trial"
339
+ # )
340
+ # secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
341
+ # eligibility: list = Field(
342
  # description="get the timeframe of each clinical trial"
343
+ # )
344
  # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
345
  minimum_age: list = Field(
346
+ description="get the minimum age from each experiment"
347
  )
348
  maximum_age: list = Field(
349
+ description="get the maximum age from each experiment"
350
  )
351
  gender: list = Field(description="get the gender from each experiment")
352
 
 
354
  return {
355
  # "project_title": self.project_title,
356
  # "status": self.status,
357
+ # "keywords": self.keywords,
358
  # "interventions": self.interventions,
359
+ # "primary_outcomes": self.primary_outcomes,
360
+ # "secondary_outcomes": self.secondary_outcomes,
361
  # "eligibility": self.eligibility,
362
  # "healthy_volunteers": self.healthy_volunteers,
363
  "minimum_age": self.minimum_age,
364
  "maximum_age": self.maximum_age,
365
+ "gender": self.gender,
366
  }
367
 
368
  # LLM
 
376
 
377
  tagging_chain = tagging_prompt | llm
378
 
379
+ res = tagging_chain.invoke({"input": processed_json})
380
+ unprocessed_results_dict = res.get_dict()
381
 
382
+ results_dict = analyze_data(
383
+ unprocessed_results_dict
384
+ )
385
 
386
+ # stats_dict= {'Average Minimum age': avg_min_age,
387
  # 'Average Maximum age': avg_max_age,
388
  # 'Most common gender undergoing the trials': most_common_gender,
389
  # 'common keywords found in the trials': common_keywords}
390
+
391
+ print(f"Result_tagging: {results_dict}")
392
+ return results_dict
393
 
394
 
395
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
 
399
  # json.dump(clinical_record_info, f, indent=4)
400
 
401
 
402
+ # tagging_chain = tagging_insights_from_json(json_data)
utils.py CHANGED
@@ -189,7 +189,7 @@ def get_clinical_trials_related_to_diseases(
189
  with engine.connect() as conn:
190
  with conn.begin():
191
  sql = f"""
192
- SELECT TOP 10 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
193
  FROM Test.ClinicalTrials d
194
  ORDER BY distance DESC
195
  """
 
189
  with engine.connect() as conn:
190
  with conn.begin():
191
  sql = f"""
192
+ SELECT TOP 15 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
193
  FROM Test.ClinicalTrials d
194
  ORDER BY distance DESC
195
  """