File size: 5,170 Bytes
92df76e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import pandas as pd
from OpenAITools.FetchTools import fetch_clinical_trials, fetch_clinical_trials_jp
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from OpenAITools.CrinicalTrialTools import QuestionModifierEnglish, TumorNameExtractor, SimpleClinicalTrialAgent, GraderAgent

# モデルとエージェントの初期化
groq = ChatGroq(model_name="llama3-70b-8192", temperature=0)
modifier = QuestionModifierEnglish(groq)
extractor = TumorNameExtractor(groq)
CriteriaCheckAgent = SimpleClinicalTrialAgent(groq)
grader_agent = GraderAgent(groq)

# データフレームを生成する関数
def generate_dataframe_from_question(ex_question):
    # Modify and extract tumor name
    modified_question = modifier.modify_question(ex_question)
    tumor_name = extractor.extract_tumor_name(ex_question)

    # Get clinical trials data based on tumor name
    df = fetch_clinical_trials(tumor_name)
    df['AgentJudgment'] = None
    df['AgentGrade'] = None
    
    # NCTIDのリストを作成し、プログレスバーを表示
    NCTIDs = list(df['NCTID'])
    progress = gr.Progress(track_tqdm=True)
    
    for i, nct_id in enumerate(NCTIDs):
        target_criteria = df.loc[df['NCTID'] == nct_id, 'Eligibility Criteria'].values[0]
        agent_judgment = CriteriaCheckAgent.evaluate_eligibility(target_criteria, modified_question)
        agent_grade = grader_agent.evaluate_eligibility(agent_judgment)
        
        # Update DataFrame
        df.loc[df['NCTID'] == nct_id, 'AgentJudgment'] = agent_judgment
        df.loc[df['NCTID'] == nct_id, 'AgentGrade'] = agent_grade
        
        # プログレスバーを更新(進行状況を浮動小数点数で渡す)
        progress((i + 1) / len(NCTIDs))
        
    # 列を指定した順に並び替え
    columns_order = ['NCTID', 'AgentGrade', 'Title', 'AgentJudgment', 'Japanes Locations', 
                     'Primary Completion Date', 'Cancer', 'Summary', 'Eligibility Criteria']
    df = df[columns_order]

    return df, df  # フィルタ用と表示用にデータフレームを返す

# AgentGradeが特定の値(yes, no, unclear)の行だけを選択する関数
def filter_rows_by_grade(original_df, grade):
    df_filtered = original_df[original_df['AgentGrade'] == grade]
    return df_filtered, df_filtered  # フィルタした結果を2つ返す

# CSVとして保存しダウンロードする関数
def download_filtered_csv(df):
    file_path = "filtered_data.csv"  # 現在の作業ディレクトリに保存
    df.to_csv(file_path, index=False)  # CSVファイルとして保存
    return file_path

# Gradioインターフェースの作成
with gr.Blocks() as demo:
    # 説明
    gr.Markdown("## 質問を入力して、患者さんが参加可能な臨床治験の情報を収集。参加可能か否かを判断根拠も含めて提示します。結果はcsvとしてダウンロード可能です")
    
    # 質問入力ボックス
    question_input = gr.Textbox(label="質問を入力してください", placeholder="例: 65歳男性でBRCA遺伝子の変異がある前立腺癌患者さんが参加できる臨床治験を教えて下さい。")

    # データフレーム表示エリア
    dataframe_output = gr.DataFrame()
    
    # データの元となるDataFrameを保存するためのstate
    original_df = gr.State()
    filtered_df = gr.State()

    # データフレームを生成するボタン
    generate_button = gr.Button("日本で行われている患者さんの癌腫の臨床治験を全て取得する")

    # ボタンでAgentGradeがyes, no, unclearの行のみ表示
    yes_button = gr.Button("AI Agentが患者さんが参加可能であると判断した臨床治験のみを表示")
    no_button = gr.Button("I Agentが患者さんが参加不可であると判断した臨床治験のみを表示")
    unclear_button = gr.Button("AI Agentが与えられた情報だけでは判断不可能とした臨床治験のみを表示")
    
    # フィルタ結果をダウンロードするボタン
    download_button = gr.Button("フィルタ結果をCSVとしてダウンロード")
    download_output = gr.File()  # ダウンロード用の出力エリア

    # データフレームを生成して保存
    generate_button.click(fn=generate_dataframe_from_question, inputs=question_input, outputs=[dataframe_output, original_df])

    # yesボタン、noボタン、unclearボタンが押されたらフィルタしたデータを表示
    yes_button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State("yes")], outputs=[dataframe_output, filtered_df])
    no_button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State("no")], outputs=[dataframe_output, filtered_df])
    unclear_button.click(fn=filter_rows_by_grade, inputs=[original_df, gr.State("unclear")], outputs=[dataframe_output, filtered_df])

    # ダウンロードボタンを押すとフィルタ結果のCSVをダウンロード
    download_button.click(fn=download_filtered_csv, inputs=filtered_df, outputs=download_output)


if __name__ == "__main__":
    demo.launch()