ClinicalTrialV3 / app.py
Satoc's picture
gegege
e1553b0
import os
import tempfile

import gradio as gr
import pandas as pd
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from sentence_transformers import SentenceTransformer
from sentence_transformers import util

from OpenAITools.CrinicalTrialTools import SimpleClinicalTrialAgent, GraderAgent, LLMTranslator, generate_ex_question_English
from OpenAITools.FetchTools import fetch_clinical_trials
from OpenAITools.JRCTTools import get_matched_df,GetJRCTCriteria
# Model and agent initialisation (runs once, at import time).
# A single Groq chat model (temperature 0) is shared by the translator,
# the eligibility-checking agent and the grading agent.
_LLM_MODEL_NAME = "llama3-70b-8192"
# Sentence-embedding model handed to get_matched_df to select trials
# relevant to the (translated) tumour type.
_EMBEDDING_MODEL_NAME = 'pritamdeka/S-PubMedBert-MS-MARCO'

groq = ChatGroq(model_name=_LLM_MODEL_NAME, temperature=0)
translator = LLMTranslator(groq)
CriteriaCheckAgent = SimpleClinicalTrialAgent(groq)
grader_agent = GraderAgent(groq)
selectionModel = SentenceTransformer(_EMBEDDING_MODEL_NAME)
# Build the per-trial eligibility results table.
def generate_dataframe(age, sex, tumor_type, GeneMutation, Meseable, Biopsiable,
                       progress=gr.Progress(track_tqdm=True)):
    """Screen JRCT trials for a patient profile and grade eligibility with LLM agents.

    The tumour type is translated to English and used, together with the
    sentence-embedding model, to select trials from the bundled JRCT CSV.
    For every selected trial, ``CriteriaCheckAgent`` judges the trial's
    criteria against an English patient description, and ``grader_agent``
    turns that judgment into a grade.

    Args:
        age: Patient age as entered in the UI (passed through unchanged).
        sex: Patient sex as selected in the UI.
        tumor_type: Tumour type (Japanese or English); translated before use.
        GeneMutation: Gene mutation text, e.g. "HER2".
        Meseable: Measurable-lesion answer from the UI ("有り"/"無し"/"不明").
        Biopsiable: Biopsiable-lesion answer from the UI (same choices).
        progress: Gradio progress tracker. Gradio attaches a live tracker only
            to a parameter whose *default value* is a ``gr.Progress``
            instance. One constructed inside the body is not wired to the
            event, and its ``track_tqdm`` has no effect.

    Returns:
        The results DataFrame, restricted to ``columns_order``, returned twice:
        once for the visible table and once for the ``gr.State`` read by the
        filter and download handlers.

    Raises:
        KeyError: if the matched DataFrame lacks any column in ``columns_order``.
    """
    # Translate the (possibly Japanese) tumour type into English.
    TumorName = translator.translate(tumor_type)
    # Build the English patient description the agents evaluate against.
    ex_question = generate_ex_question_English(age, sex, TumorName, GeneMutation, Meseable, Biopsiable)

    # Load the trial snapshot and select trials for the tumour type.
    # NOTE(review): 0.925 is presumably a similarity cut-off inside get_matched_df; confirm in JRCTTools.
    basedf = pd.read_csv("ClinicalTrialCSV/JRCT20241215CancerPost.csv", index_col=0)
    df = get_matched_df(basedf=basedf, query=TumorName, model=selectionModel, threshold=0.925)
    df['AgentJudgment'] = None
    df['AgentGrade'] = None

    # Evaluate each trial. With zero matches the loop is skipped, so the
    # progress division below can never divide by zero.
    total = len(df)
    for position, row_label in enumerate(df.index):
        # GetJRCTCriteria is given the row *position*; .loc needs the row *label*.
        TargetCriteria = GetJRCTCriteria(df, position)
        AgentJudgment = CriteriaCheckAgent.evaluate_eligibility(TargetCriteria, ex_question)
        AgentGrade = grader_agent.evaluate_eligibility(AgentJudgment)
        df.loc[row_label, 'AgentJudgment'] = AgentJudgment
        df.loc[row_label, 'AgentGrade'] = AgentGrade
        progress((position + 1) / total)

    # Reorder columns for display.
    columns_order = ['JRCT ID', 'Title', '研究・治験の目的','AgentJudgment', 'AgentGrade','主たる選択基準', '主たる除外基準','Inclusion Criteria','Exclusion Criteria','NCT No', 'JapicCTI No']
    df = df[columns_order]
    return df, df
# Filter the results table by AgentGrade (yes / no / unclear).
def filter_rows_by_grade(original_df, grade):
    """Return the rows of *original_df* whose ``AgentGrade`` matches *grade*.

    Matching ignores surrounding whitespace and letter case, because the grade
    text is produced by an LLM grader. Exact matches behave exactly as before.
    NOTE(review): confirm GraderAgent's output format.

    Args:
        original_df: Full results table held in a ``gr.State``. It is ``None``
            until "Generate Clinical Trials Data" has been run.
        grade: Grade to keep, e.g. "yes", "no" or "unclear".

    Returns:
        The filtered DataFrame twice (visible table and filter state). If no
        results exist yet, an empty DataFrame is returned for both instead of
        raising ``TypeError`` on ``None``.
    """
    if original_df is None:
        empty = pd.DataFrame()
        return empty, empty

    wanted = str(grade).strip().lower()
    # Non-string grades (e.g. None for an unevaluated row) never match.
    # astype(bool) keeps the mask boolean even when the frame is empty.
    mask = original_df['AgentGrade'].map(
        lambda value: isinstance(value, str) and value.strip().lower() == wanted
    ).astype(bool)
    df_filtered = original_df.loc[mask]
    return df_filtered, df_filtered
# Helpers that save a results table as CSV for the gr.File download outputs.
def _save_csv_for_download(df, filename):
    """Write *df* as CSV named *filename* in a fresh temp directory and return its path.

    Each call gets its own directory, so concurrent sessions cannot overwrite
    each other's export, while the downloaded file keeps its user-facing name.

    Args:
        df: DataFrame to save, or ``None`` when nothing has been generated or
            filtered yet (the ``gr.State`` default).
        filename: Basename for the written file.

    Returns:
        The path of the written CSV. Returns ``None`` when *df* is ``None``,
        which leaves the ``gr.File`` output empty instead of raising
        ``AttributeError``.
    """
    if df is None:
        return None
    file_path = os.path.join(tempfile.mkdtemp(prefix="clinical_trials_"), filename)
    df.to_csv(file_path, index=False)
    return file_path


def download_filtered_csv(df):
    """Save the currently filtered results as ``filtered_data.csv`` and return the path (``None`` if there is no data)."""
    return _save_csv_for_download(df, "filtered_data.csv")


# Save the full (unfiltered) results for download.
def download_full_csv(df):
    """Save the full results as ``full_data.csv`` and return the path (``None`` if there is no data)."""
    return _save_csv_for_download(df, "full_data.csv")
# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("## 臨床試験適格性評価インターフェース")

    # Patient-profile inputs, passed positionally to generate_dataframe.
    age_input = gr.Textbox(label="Age", placeholder="例: 65")
    sex_input = gr.Dropdown(choices=["男性", "女性"], label="Sex")
    tumor_type_input = gr.Textbox(label="Tumor Type", placeholder="例: gastric cancer, 日本語でも良いですが英語の方が精度が高いです。")
    gene_mutation_input = gr.Textbox(label="Gene Mutation", placeholder="例: HER2")
    measurable_input = gr.Dropdown(choices=["有り", "無し", "不明"], label="Measurable Tumor")
    biopsiable_input = gr.Dropdown(choices=["有り", "無し", "不明"], label="Biopsiable Tumor")

    # Results table plus per-session state: the full result set and the most
    # recently filtered subset (both None until populated).
    dataframe_output = gr.DataFrame()
    original_df = gr.State()
    filtered_df = gr.State()

    # Generate button.
    generate_button = gr.Button("Generate Clinical Trials Data")
    # Filter buttons.
    yes_button = gr.Button("Show Eligible Trials")
    no_button = gr.Button("Show Ineligible Trials")
    unclear_button = gr.Button("Show Unclear Trials")
    # Download buttons and their file outputs.
    download_filtered_button = gr.Button("Download Filtered Data")
    download_filtered_output = gr.File(label="Download Filtered Data")
    download_full_button = gr.Button("Download Full Data")
    download_full_output = gr.File(label="Download Full Data")

    # Event wiring.
    generate_button.click(fn=generate_dataframe, inputs=[age_input, sex_input, tumor_type_input, gene_mutation_input, measurable_input, biopsiable_input], outputs=[dataframe_output, original_df])
    # Each filter button feeds its grade to filter_rows_by_grade as a constant
    # gr.State input (created in the same yes/no/unclear order as before).
    for _button, _grade in ((yes_button, "yes"), (no_button, "no"), (unclear_button, "unclear")):
        _button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State(_grade)], outputs=[dataframe_output, filtered_df])
    download_filtered_button.click(fn=download_filtered_csv, inputs=filtered_df, outputs=download_filtered_output)
    download_full_button.click(fn=download_full_csv, inputs=original_df, outputs=download_full_output)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()