Spaces:
Runtime error
Runtime error
File size: 3,642 Bytes
b5dbcf3 689c24b b5dbcf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
from gradio import *
from run import *
# Sample table that pre-populates the editable Dataframe in the UI.
# `pd`, `os` and `main_path` are expected to arrive via `from run import *`.
szse_summary_df = pd.read_csv(os.path.join(main_path, "data/df1.csv"))

# Tab title; it is also the key of this tab's entry in `default_val_dict`.
tableqa_ = "数据表问答(编辑数据)"

# Initial widget values, grouped per tab.
default_val_dict = {
    tableqa_: dict(
        # Sample question, roughly: "What is the average market value where
        # EPS is greater than 0 and the weekly change is greater than 5?"
        tqa_question="EPS大于0且周涨跌大于5的平均市值是多少?",
        tqa_header=szse_summary_df.columns.tolist(),
        tqa_rows=szse_summary_df.values.tolist(),
        tqa_data_path=os.path.join(main_path, "data/df1.csv"),
        # Answer displayed before the user runs a query of their own.
        tqa_answer=dict(
            sql_query="SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            cnt_num=2,
            conclusion=[57.645],
        ),
    ),
}
def tableqa_layer(post_data):
    """Answer a question about a JSON-encoded table.

    Args:
        post_data: Mapping with three string entries:
            ``"question"`` - the question text;
            ``"table_rows"`` - JSON text that decodes to a list of rows;
            ``"table_header"`` - JSON text that decodes to a list of column
            names.

    Returns:
        Whatever ``single_table_pred(question, df)`` returns for the
        DataFrame built from the decoded rows and header.

    Raises:
        KeyError: If one of the three entries is missing.
        AssertionError: If an entry is not a string, if the decoded rows or
            header is not a list, or if the header length differs from the
            length of the first row.
        json.JSONDecodeError: If rows or header is not valid JSON.
    """
    question = post_data["question"]
    raw_rows = post_data["table_rows"]
    raw_header = post_data["table_header"]

    # The checks raise AssertionError explicitly instead of using `assert`:
    # the exception type stays what callers have always seen, but the checks
    # are no longer stripped when Python runs with -O.
    for field, value in (
        ("question", question),
        ("table_rows", raw_rows),
        ("table_header", raw_header),
    ):
        if not isinstance(value, str):
            raise AssertionError(
                f"{field} must be a str, got {type(value).__name__}"
            )

    table_rows = json.loads(raw_rows)
    table_header = json.loads(raw_header)

    for field, value in (
        ("table_rows", table_rows),
        ("table_header", table_header),
    ):
        if not isinstance(value, list):
            raise AssertionError(
                f"{field} must decode to a list, got {type(value).__name__}"
            )

    # Only the first row is compared against the header, and only when both
    # are non-empty; anything else is left to pandas.
    if table_rows and table_header and len(table_header) != len(table_rows[0]):
        raise AssertionError(
            f"header has {len(table_header)} columns but the first row has "
            f"{len(table_rows[0])}"
        )

    df = pd.DataFrame(table_rows, columns=table_header)
    return single_table_pred(question, df)
def _is_blank_cell(cell):
    """Return True when *cell* holds no entered value (None, "" or NaN)."""
    if cell is None:
        return True
    if isinstance(cell, str):
        return cell == ""
    # NaN is the only float that compares unequal to itself.
    return isinstance(cell, float) and cell != cell


def _to_jsonable(obj):
    """``json.dumps`` fallback: unwrap values that expose ``tolist()``."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(
        f"Object of type {type(obj).__name__} is not JSON serializable"
    )


def run_tableqa(*args):
    """Button callback: run table QA on the question and the edited table.

    Args:
        *args: Exactly two positional values, ``(question, data)``.
            ``question`` is the question text. ``data`` must expose
            ``.columns.tolist()`` and ``.values.tolist()`` (presumably the
            pandas DataFrame delivered by the Dataframe component; confirm
            against the Gradio version in use).

    Returns:
        The response of ``tableqa_layer`` reduced to plain JSON types
        (dict / list / str / numbers / None).

    Raises:
        ValueError: If the number of positional arguments is not two.
    """
    question, data = args

    header = data.columns.tolist()
    # Drop rows that are entirely blank. Blankness is tested explicitly
    # rather than by truthiness, so a genuine data row made only of zeros or
    # False is kept, while a row made only of None / "" / NaN is removed.
    rows = [
        row
        for row in data.values.tolist()
        if not all(_is_blank_cell(cell) for cell in row)
    ]

    # tableqa_layer takes header and rows as JSON text and performs the type
    # validation of all three fields itself.
    resp = tableqa_layer(
        {
            "question": question,
            "table_header": json.dumps(header),
            "table_rows": json.dumps(rows),
        }
    )

    # Unwrap array-like values (anything with `tolist`) stored under the two
    # result keys.
    for key in ("cnt_num", "conclusion"):
        if key in resp and hasattr(resp[key], "tolist"):
            resp[key] = resp[key].tolist()

    # Round-trip through JSON so only plain types are returned; the fallback
    # also unwraps array-like values nested deeper in the response.
    return json.loads(json.dumps(resp, default=_to_jsonable))
# Build the UI. Components and event listeners are declared inside the
# Blocks context so they are registered on `demo`.
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    gr.Markdown("")
    gr.Markdown("This _example_ was **drive** from <br/><b><h4>[https://github.com/svjack/tableQA-Chinese](https://github.com/svjack/tableQA-Chinese)</h4></b>\n")

    with gr.Tabs():
        # TableQA tab.
        with gr.TabItem("数据表问答(TableQA)"):
            with gr.Tabs():
                with gr.TabItem(tableqa_):
                    defaults = default_val_dict[tableqa_]

                    # Question input, pre-filled with the sample question.
                    tqa_question = gr.Textbox(
                        defaults["tqa_question"],
                        label="问句:(输入)",
                    )
                    # Editable table; row_count is one more than the number
                    # of sample rows.
                    tqa_data = gr.Dataframe(
                        headers=defaults["tqa_header"],
                        value=defaults["tqa_rows"],
                        row_count=len(defaults["tqa_rows"]) + 1,
                    )
                    # Referenced as gr.JSON, like every other component here,
                    # so the name does not depend on a wildcard import that a
                    # later `from run import *` could shadow.
                    tqa_answer = gr.JSON(
                        defaults["tqa_answer"],
                        label="问句:(输出)",
                    )
                    tqa_button = gr.Button("得到答案")

    # Clicking the button runs the query and shows the result as JSON.
    tqa_button.click(
        run_tableqa,
        inputs=[tqa_question, tqa_data],
        outputs=tqa_answer,
    )

# Listen on all interfaces.
demo.launch(server_name="0.0.0.0")
|