Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
·
52ee7a9
1
Parent(s):
90c8ced
changes to gpt inference
Browse files- app.py +3 -2
- llm_res.py +83 -46
app.py
CHANGED
@@ -14,7 +14,7 @@ from utils import (
|
|
14 |
get_clinical_trials_related_to_diseases,
|
15 |
get_clinical_records_by_ids
|
16 |
)
|
17 |
-
from llm_res import
|
18 |
import json
|
19 |
import numpy as np
|
20 |
from sentence_transformers import SentenceTransformer
|
@@ -81,8 +81,9 @@ with st.container():
|
|
81 |
status.json(json_of_clinical_trials)
|
82 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
83 |
status.write("Getting a summary of the clinical trials...")
|
84 |
-
response =
|
85 |
print(f'Response from LLM: {response}')
|
|
|
86 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
87 |
status.write("Getting summary statistics of the clinical trials...")
|
88 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
|
|
14 |
get_clinical_trials_related_to_diseases,
|
15 |
get_clinical_records_by_ids
|
16 |
)
|
17 |
+
from llm_res import get_short_summary_out_of_json_files
|
18 |
import json
|
19 |
import numpy as np
|
20 |
from sentence_transformers import SentenceTransformer
|
|
|
81 |
status.json(json_of_clinical_trials)
|
82 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
83 |
status.write("Getting a summary of the clinical trials...")
|
84 |
+
response = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
85 |
print(f'Response from LLM: {response}')
|
86 |
+
status.write(f'Response from LLM: {response}')
|
87 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
88 |
status.write("Getting summary statistics of the clinical trials...")
|
89 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
llm_res.py
CHANGED
@@ -1,27 +1,27 @@
|
|
|
|
1 |
import json
|
2 |
-
from langchain_community.document_loaders.csv_loader import CSVLoader
|
3 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
-
import pandas as pd
|
5 |
-
import langchain
|
6 |
import os
|
|
|
|
|
|
|
7 |
import openai
|
8 |
-
import
|
|
|
|
|
9 |
from langchain import OpenAI
|
|
|
10 |
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
|
11 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
12 |
-
from langchain_community.document_loaders import JSONLoader
|
13 |
from langchain.document_loaders import UnstructuredURLLoader
|
14 |
from langchain.embeddings import OpenAIEmbeddings
|
|
|
15 |
from langchain.vectorstores import FAISS
|
|
|
|
|
16 |
from langchain_core.prompts import ChatPromptTemplate
|
17 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
18 |
from langchain_openai import ChatOpenAI
|
19 |
-
from
|
20 |
-
from langchain_core.
|
21 |
-
from langchain_openai import ChatOpenAI
|
22 |
-
from typing import List, Dict, Any
|
23 |
-
import requests
|
24 |
-
from dotenv import load_dotenv
|
25 |
|
26 |
load_dotenv()
|
27 |
|
@@ -78,17 +78,17 @@ def process_json_data_for_llm(data):
|
|
78 |
except:
|
79 |
status = ""
|
80 |
try:
|
81 |
-
|
82 |
"briefSummary"
|
83 |
]
|
84 |
except:
|
85 |
-
|
86 |
try:
|
87 |
-
|
88 |
"detailedDescription"
|
89 |
]
|
90 |
except:
|
91 |
-
|
92 |
try:
|
93 |
conditions = item["protocolSection"]["conditionsModule"]["conditions"]
|
94 |
except:
|
@@ -123,8 +123,8 @@ def process_json_data_for_llm(data):
|
|
123 |
"organization_name": organization_name,
|
124 |
"project_title": project_title,
|
125 |
"status": status,
|
126 |
-
"
|
127 |
-
"
|
128 |
"keywords": keywords,
|
129 |
"interventions": interventions,
|
130 |
"primary_outcomes": primary_outcomes,
|
@@ -137,22 +137,56 @@ def process_json_data_for_llm(data):
|
|
137 |
# print(ele)
|
138 |
|
139 |
|
140 |
-
def
|
141 |
-
|
142 |
-
"""
|
143 |
-
Extract the desired information from the following list of JSON clinical trials.
|
144 |
|
145 |
-
|
|
|
146 |
|
147 |
-
|
148 |
-
{input}
|
149 |
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
)
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
class Classification(BaseModel):
|
154 |
description: str = Field(
|
155 |
-
description="text description grouping all the clinical trials using
|
156 |
)
|
157 |
project_title: list = Field(
|
158 |
description="Extract the project title of all the clinical trials"
|
@@ -160,9 +194,9 @@ def llm_config():
|
|
160 |
status: list = Field(
|
161 |
description="Extract the status of all the clinical trials"
|
162 |
)
|
163 |
-
keywords: list = Field(
|
164 |
-
|
165 |
-
)
|
166 |
interventions: list = Field(
|
167 |
description="describe the interventions for each clinical trial using title, name and description"
|
168 |
)
|
@@ -170,17 +204,17 @@ def llm_config():
|
|
170 |
description="get the primary outcomes of each clinical trial"
|
171 |
)
|
172 |
# secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
173 |
-
eligibility: list = Field(
|
174 |
-
|
175 |
-
)
|
176 |
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
177 |
-
minimum_age: list = Field(
|
178 |
-
|
179 |
-
)
|
180 |
-
maximum_age: list = Field(
|
181 |
-
|
182 |
-
)
|
183 |
-
gender: list = Field(description="get the gender from each experiment")
|
184 |
|
185 |
def get_dict(self):
|
186 |
return {
|
@@ -205,9 +239,11 @@ def llm_config():
|
|
205 |
openai_api_key=os.environ["OPENAI_API_KEY"],
|
206 |
).with_structured_output(Classification)
|
207 |
|
208 |
-
|
209 |
|
210 |
-
|
|
|
|
|
211 |
|
212 |
|
213 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
@@ -216,9 +252,10 @@ def llm_config():
|
|
216 |
# with open('data.json', 'w') as f:
|
217 |
# json.dump(clinical_record_info, f, indent=4)
|
218 |
|
219 |
-
tagging_chain = llm_config()
|
|
|
220 |
|
221 |
def process_dictionaty_with_llm_to_generate_response(json_contents):
|
222 |
processed_data = process_json_data_for_llm(json_contents)
|
223 |
-
res = tagging_chain.invoke({"input": processed_data})
|
224 |
-
return res
|
|
|
1 |
+
import ast
|
2 |
import json
|
|
|
|
|
|
|
|
|
3 |
import os
|
4 |
+
from typing import Any, Dict, List
|
5 |
+
|
6 |
+
import langchain
|
7 |
import openai
|
8 |
+
import pandas as pd
|
9 |
+
import requests
|
10 |
+
from dotenv import load_dotenv
|
11 |
from langchain import OpenAI
|
12 |
+
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
13 |
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
|
|
|
|
|
14 |
from langchain.document_loaders import UnstructuredURLLoader
|
15 |
from langchain.embeddings import OpenAIEmbeddings
|
16 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from langchain.vectorstores import FAISS
|
18 |
+
from langchain_community.document_loaders import JSONLoader
|
19 |
+
from langchain_community.document_loaders.csv_loader import CSVLoader
|
20 |
from langchain_core.prompts import ChatPromptTemplate
|
21 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
22 |
from langchain_openai import ChatOpenAI
|
23 |
+
from langchain.chains.llm import LLMChain
|
24 |
+
from langchain_core.prompts import PromptTemplate
|
|
|
|
|
|
|
|
|
25 |
|
26 |
load_dotenv()
|
27 |
|
|
|
78 |
except:
|
79 |
status = ""
|
80 |
try:
|
81 |
+
briefDescription = item["protocolSection"]["descriptionModule"][
|
82 |
"briefSummary"
|
83 |
]
|
84 |
except:
|
85 |
+
briefDescription = ""
|
86 |
try:
|
87 |
+
detailedDescription = item["protocolSection"]["descriptionModule"][
|
88 |
"detailedDescription"
|
89 |
]
|
90 |
except:
|
91 |
+
detailedDescription = ""
|
92 |
try:
|
93 |
conditions = item["protocolSection"]["conditionsModule"]["conditions"]
|
94 |
except:
|
|
|
123 |
"organization_name": organization_name,
|
124 |
"project_title": project_title,
|
125 |
"status": status,
|
126 |
+
"briefDescription": briefDescription,
|
127 |
+
"detailedDescription": detailedDescription,
|
128 |
"keywords": keywords,
|
129 |
"interventions": interventions,
|
130 |
"primary_outcomes": primary_outcomes,
|
|
|
137 |
# print(ele)
|
138 |
|
139 |
|
140 |
+
def get_short_summary_out_of_json_files(data_json):
|
141 |
+
prompt_template = """ You are an expert clinician working on the analysis of reports of clinical trials.
|
|
|
|
|
142 |
|
143 |
+
# Task
|
144 |
+
You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
|
145 |
|
146 |
+
To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports.
|
|
|
147 |
|
148 |
+
{text}
|
149 |
+
|
150 |
+
General summary:"""
|
151 |
+
|
152 |
+
prompt = PromptTemplate.from_template(prompt_template)
|
153 |
+
|
154 |
+
llm = ChatOpenAI(
|
155 |
+
temperature=0.4, model_name="gpt-4-turbo", api_key=os.environ["OPENAI_API_KEY"]
|
156 |
+
)
|
157 |
+
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
158 |
+
|
159 |
+
# Define StuffDocumentsChain
|
160 |
+
stuff_chain = StuffDocumentsChain(
|
161 |
+
llm_chain=llm_chain, document_variable_name="text"
|
162 |
)
|
163 |
|
164 |
+
descriptions = [
|
165 |
+
(
|
166 |
+
x["detailedDescription"]
|
167 |
+
if "detailedDescription" in x and len(x["detailedDescription"]) > 0
|
168 |
+
else x["briefSummary"]
|
169 |
+
)
|
170 |
+
for x in data_json
|
171 |
+
if "detailedDescription" in x or "briefSummary" in x
|
172 |
+
]
|
173 |
+
|
174 |
+
combined_descriptions = ""
|
175 |
+
for i, description in enumerate(descriptions):
|
176 |
+
combined_descriptions += f"Report {i+1}:\n{description}\n"
|
177 |
+
|
178 |
+
print(f"Combined descriptions: {combined_descriptions}")
|
179 |
+
|
180 |
+
result = stuff_chain.run(combined_descriptions)
|
181 |
+
print(f"Result: {result}")
|
182 |
+
|
183 |
+
return result
|
184 |
+
|
185 |
+
|
186 |
+
def taggingTemplate():
|
187 |
class Classification(BaseModel):
|
188 |
description: str = Field(
|
189 |
+
description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
|
190 |
)
|
191 |
project_title: list = Field(
|
192 |
description="Extract the project title of all the clinical trials"
|
|
|
194 |
status: list = Field(
|
195 |
description="Extract the status of all the clinical trials"
|
196 |
)
|
197 |
+
# keywords: list = Field(
|
198 |
+
# description="Extract the most relevant keywords regrouping all the clinical trials"
|
199 |
+
# )
|
200 |
interventions: list = Field(
|
201 |
description="describe the interventions for each clinical trial using title, name and description"
|
202 |
)
|
|
|
204 |
description="get the primary outcomes of each clinical trial"
|
205 |
)
|
206 |
# secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
207 |
+
# eligibility: list = Field(
|
208 |
+
# description="get the eligibilityCriteria grouping all the clinical trials"
|
209 |
+
# )
|
210 |
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
211 |
+
# minimum_age: list = Field(
|
212 |
+
# description="get the minimum age from each experiment"
|
213 |
+
# )
|
214 |
+
# maximum_age: list = Field(
|
215 |
+
# description="get the maximum age from each experiment"
|
216 |
+
# )
|
217 |
+
# gender: list = Field(description="get the gender from each experiment")
|
218 |
|
219 |
def get_dict(self):
|
220 |
return {
|
|
|
239 |
openai_api_key=os.environ["OPENAI_API_KEY"],
|
240 |
).with_structured_output(Classification)
|
241 |
|
242 |
+
stuff_chain = StuffDocumentsChain(llm_chain=llm, document_variable_name="text")
|
243 |
|
244 |
+
# tagging_chain = prompt_template | llm
|
245 |
+
|
246 |
+
# return tagging_chain
|
247 |
|
248 |
|
249 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
|
|
252 |
# with open('data.json', 'w') as f:
|
253 |
# json.dump(clinical_record_info, f, indent=4)
|
254 |
|
255 |
+
# tagging_chain = llm_config()
|
256 |
+
|
257 |
|
258 |
def process_dictionaty_with_llm_to_generate_response(json_contents):
|
259 |
processed_data = process_json_data_for_llm(json_contents)
|
260 |
+
# res = tagging_chain.invoke({"input": processed_data})
|
261 |
+
# return res
|