Spaces:
Running
Running
Liu Yiwen
commited on
Commit
·
fad121c
1
Parent(s):
f312fcb
修复单个target的bug,增加模拟和云端的通信功能
Browse files- __pycache__/comm_test.cpython-311.pyc +0 -0
- __pycache__/comm_utils.cpython-311.pyc +0 -0
- __pycache__/utils.cpython-311.pyc +0 -0
- app.py +16 -11
- comm_utils.py +25 -0
- test_server.py +17 -0
- utils.py +1 -4
__pycache__/comm_test.cpython-311.pyc
ADDED
Binary file (1.02 kB). View file
|
|
__pycache__/comm_utils.cpython-311.pyc
ADDED
Binary file (1.48 kB). View file
|
|
__pycache__/utils.cpython-311.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
|
|
app.py
CHANGED
@@ -15,6 +15,7 @@ from datasets import Features, Image, Audio, Sequence
|
|
15 |
from typing import List, Tuple, Callable
|
16 |
|
17 |
from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
|
|
|
18 |
|
19 |
class AppError(RuntimeError):
|
20 |
pass
|
@@ -233,10 +234,13 @@ with gr.Blocks() as demo:
|
|
233 |
statistics_textbox = gr.DataFrame()
|
234 |
with gr.Column(scale=3):
|
235 |
plot = gr.Plot()
|
236 |
-
|
|
|
|
|
|
|
237 |
# componets.append({"select_sample_box": select_sample_box,
|
238 |
# "statistics_textbox": statistics_textbox,
|
239 |
-
# "
|
240 |
# "plot": plot})
|
241 |
|
242 |
# with gr.Row():
|
@@ -261,11 +265,13 @@ with gr.Blocks() as demo:
|
|
261 |
if type(page) == str:
|
262 |
page = [page]
|
263 |
df_list, id_list = [], []
|
|
|
264 |
for i, page in enumerate(page):
|
265 |
df, max_page, info = get_page(dataset, config, split, page)
|
266 |
global tot_samples, tot_targets
|
267 |
-
tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) else 1
|
268 |
-
|
|
|
269 |
df = clean_up_df(df, sub_targets)
|
270 |
row = df.iloc[0]
|
271 |
id_list.append(row['item_id'])
|
@@ -334,25 +340,24 @@ with gr.Blocks() as demo:
|
|
334 |
except AppError as err:
|
335 |
return show_error(str(err))
|
336 |
|
337 |
-
def save_to_file(user_input):
|
338 |
-
with open("user_input.txt", "w") as file:
|
339 |
-
file.write(user_input)
|
340 |
-
|
341 |
all_outputs = [cp_config, cp_split,
|
342 |
# cp_page, cp_goto_page, cp_goto_next_page,
|
343 |
cp_result, cp_info, cp_error,
|
344 |
select_sample_box, select_subtarget_box,
|
345 |
-
select_buttom, statistics_textbox,
|
346 |
cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
|
347 |
cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
|
348 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
349 |
# cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
350 |
# cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
351 |
-
|
352 |
select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
|
353 |
|
354 |
|
355 |
if __name__ == "__main__":
|
356 |
|
357 |
app = gr.mount_gradio_app(app, demo, path="/")
|
358 |
-
|
|
|
|
|
|
|
|
15 |
from typing import List, Tuple, Callable
|
16 |
|
17 |
from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
|
18 |
+
from comm_utils import save_to_file, send_msg_to_server
|
19 |
|
20 |
class AppError(RuntimeError):
|
21 |
pass
|
|
|
234 |
statistics_textbox = gr.DataFrame()
|
235 |
with gr.Column(scale=3):
|
236 |
plot = gr.Plot()
|
237 |
+
with gr.Row():
|
238 |
+
user_input_box = gr.Textbox(placeholder="输入一些内容", label="输入", lines=5, interactive=True)
|
239 |
+
user_output_box = gr.Textbox(label="回答", lines=5, interactive=False)
|
240 |
+
user_io_buttom = gr.Button("发送", interactive=True)
|
241 |
# componets.append({"select_sample_box": select_sample_box,
|
242 |
# "statistics_textbox": statistics_textbox,
|
243 |
+
# "user_input_box": user_input_box,
|
244 |
# "plot": plot})
|
245 |
|
246 |
# with gr.Row():
|
|
|
265 |
if type(page) == str:
|
266 |
page = [page]
|
267 |
df_list, id_list = [], []
|
268 |
+
# TODO: 将以下内容封装为函数
|
269 |
for i, page in enumerate(page):
|
270 |
df, max_page, info = get_page(dataset, config, split, page)
|
271 |
global tot_samples, tot_targets
|
272 |
+
tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) and df['target'][0].dtype == 'O' else 1
|
273 |
+
if 'all' in sub_targets:
|
274 |
+
sub_targets = [i for i in range(tot_targets)]
|
275 |
df = clean_up_df(df, sub_targets)
|
276 |
row = df.iloc[0]
|
277 |
id_list.append(row['item_id'])
|
|
|
340 |
except AppError as err:
|
341 |
return show_error(str(err))
|
342 |
|
|
|
|
|
|
|
|
|
343 |
all_outputs = [cp_config, cp_split,
|
344 |
# cp_page, cp_goto_page, cp_goto_next_page,
|
345 |
cp_result, cp_info, cp_error,
|
346 |
select_sample_box, select_subtarget_box,
|
347 |
+
select_buttom, statistics_textbox, plot]
|
348 |
cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
|
349 |
cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
|
350 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
351 |
# cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
352 |
# cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
353 |
+
user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
|
354 |
select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
|
355 |
|
356 |
|
357 |
if __name__ == "__main__":
|
358 |
|
359 |
app = gr.mount_gradio_app(app, demo, path="/")
|
360 |
+
host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
|
361 |
+
import subprocess
|
362 |
+
subprocess.Popen(["python", "test_server.py"])
|
363 |
+
uvicorn.run(app, host=host, port=7860)
|
comm_utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
|
3 |
+
|
4 |
+
API_URL = "http://127.0.0.1:5000/api/process"
|
5 |
+
|
6 |
+
def save_to_file(user_input):
|
7 |
+
with open("user_input.txt", "w") as file:
|
8 |
+
file.write(user_input)
|
9 |
+
|
10 |
+
|
11 |
+
def send_msg_to_server(input_text):
|
12 |
+
try:
|
13 |
+
# 构造请求数据
|
14 |
+
payload = {"text": input_text}
|
15 |
+
headers = {"Content-Type": "application/json"}
|
16 |
+
|
17 |
+
# 发送请求
|
18 |
+
response = requests.post(API_URL, json=payload, headers=headers)
|
19 |
+
response.raise_for_status() # 检查是否请求成功
|
20 |
+
|
21 |
+
# 返回响应结果
|
22 |
+
result = response.json() # 假设服务器返回的是 JSON 格式
|
23 |
+
return result.get("processed_text", "No result returned.")
|
24 |
+
except requests.RequestException as e:
|
25 |
+
return f"请求失败:{e}"
|
test_server.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from flask import Flask, request, jsonify
|
3 |
+
|
4 |
+
app = Flask(__name__)
|
5 |
+
|
6 |
+
@app.route('/api/process', methods=['POST'])
|
7 |
+
def process_text():
|
8 |
+
data = request.get_json()
|
9 |
+
input_text = data.get("text", "")
|
10 |
+
|
11 |
+
time.sleep(1)
|
12 |
+
processed_text = f"{input_text[::-1]}"
|
13 |
+
|
14 |
+
return jsonify({"processed_text": processed_text})
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
app.run(host="127.0.0.1", port=5000)
|
utils.py
CHANGED
@@ -114,10 +114,7 @@ def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
|
|
114 |
"""
|
115 |
清理数据集,将嵌套的np.ndarray列展平为多列。
|
116 |
"""
|
117 |
-
|
118 |
-
rows_to_include = list(range(len(df['target'][0]))) if isinstance(df['target'][0], np.ndarray) else 1
|
119 |
-
else:
|
120 |
-
rows_to_include = sorted(rows_to_include)
|
121 |
|
122 |
df['timestamp'] = df.apply(lambda row: pd.date_range(
|
123 |
start=row['start'],
|
|
|
114 |
"""
|
115 |
清理数据集,将嵌套的np.ndarray列展平为多列。
|
116 |
"""
|
117 |
+
rows_to_include = sorted(rows_to_include)
|
|
|
|
|
|
|
118 |
|
119 |
df['timestamp'] = df.apply(lambda row: pd.date_range(
|
120 |
start=row['start'],
|