Liu Yiwen commited on
Commit
fad121c
·
1 Parent(s): f312fcb

修复单个target的bug,增加模拟和云端的通信功能

Browse files
__pycache__/comm_test.cpython-311.pyc ADDED
Binary file (1.02 kB). View file
 
__pycache__/comm_utils.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
 
app.py CHANGED
@@ -15,6 +15,7 @@ from datasets import Features, Image, Audio, Sequence
15
  from typing import List, Tuple, Callable
16
 
17
  from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
 
18
 
19
  class AppError(RuntimeError):
20
  pass
@@ -233,10 +234,13 @@ with gr.Blocks() as demo:
233
  statistics_textbox = gr.DataFrame()
234
  with gr.Column(scale=3):
235
  plot = gr.Plot()
236
- user_input_text = gr.Textbox(placeholder="输入一些内容")
 
 
 
237
  # componets.append({"select_sample_box": select_sample_box,
238
  # "statistics_textbox": statistics_textbox,
239
- # "user_input_text": user_input_text,
240
  # "plot": plot})
241
 
242
  # with gr.Row():
@@ -261,11 +265,13 @@ with gr.Blocks() as demo:
261
  if type(page) == str:
262
  page = [page]
263
  df_list, id_list = [], []
 
264
  for i, page in enumerate(page):
265
  df, max_page, info = get_page(dataset, config, split, page)
266
  global tot_samples, tot_targets
267
- tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) else 1
268
-
 
269
  df = clean_up_df(df, sub_targets)
270
  row = df.iloc[0]
271
  id_list.append(row['item_id'])
@@ -334,25 +340,24 @@ with gr.Blocks() as demo:
334
  except AppError as err:
335
  return show_error(str(err))
336
 
337
- def save_to_file(user_input):
338
- with open("user_input.txt", "w") as file:
339
- file.write(user_input)
340
-
341
  all_outputs = [cp_config, cp_split,
342
  # cp_page, cp_goto_page, cp_goto_next_page,
343
  cp_result, cp_info, cp_error,
344
  select_sample_box, select_subtarget_box,
345
- select_buttom, statistics_textbox, user_input_text, plot]
346
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
347
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
348
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
349
  # cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
350
  # cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
351
- user_input_text.submit(save_to_file, inputs=user_input_text)
352
  select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
353
 
354
 
355
  if __name__ == "__main__":
356
 
357
  app = gr.mount_gradio_app(app, demo, path="/")
358
- uvicorn.run(app, host="127.0.0.1", port=7860)
 
 
 
 
15
  from typing import List, Tuple, Callable
16
 
17
  from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
18
+ from comm_utils import save_to_file, send_msg_to_server
19
 
20
  class AppError(RuntimeError):
21
  pass
 
234
  statistics_textbox = gr.DataFrame()
235
  with gr.Column(scale=3):
236
  plot = gr.Plot()
237
+ with gr.Row():
238
+ user_input_box = gr.Textbox(placeholder="输入一些内容", label="输入", lines=5, interactive=True)
239
+ user_output_box = gr.Textbox(label="回答", lines=5, interactive=False)
240
+ user_io_buttom = gr.Button("发送", interactive=True)
241
  # componets.append({"select_sample_box": select_sample_box,
242
  # "statistics_textbox": statistics_textbox,
243
+ # "user_input_box": user_input_box,
244
  # "plot": plot})
245
 
246
  # with gr.Row():
 
265
  if type(page) == str:
266
  page = [page]
267
  df_list, id_list = [], []
268
+ # TODO: 将以下内容封装为函数
269
  for i, page in enumerate(page):
270
  df, max_page, info = get_page(dataset, config, split, page)
271
  global tot_samples, tot_targets
272
+ tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) and df['target'][0].dtype == 'O' else 1
273
+ if 'all' in sub_targets:
274
+ sub_targets = [i for i in range(tot_targets)]
275
  df = clean_up_df(df, sub_targets)
276
  row = df.iloc[0]
277
  id_list.append(row['item_id'])
 
340
  except AppError as err:
341
  return show_error(str(err))
342
 
 
 
 
 
343
  all_outputs = [cp_config, cp_split,
344
  # cp_page, cp_goto_page, cp_goto_next_page,
345
  cp_result, cp_info, cp_error,
346
  select_sample_box, select_subtarget_box,
347
+ select_buttom, statistics_textbox, plot]
348
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
349
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
350
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
351
  # cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
352
  # cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
353
+ user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
354
  select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
355
 
356
 
357
  if __name__ == "__main__":
358
 
359
  app = gr.mount_gradio_app(app, demo, path="/")
360
+ host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
361
+ import subprocess
362
+ subprocess.Popen(["python", "test_server.py"])
363
+ uvicorn.run(app, host=host, port=7860)
comm_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+
4
+ API_URL = "http://127.0.0.1:5000/api/process"
5
+
6
+ def save_to_file(user_input):
7
+ with open("user_input.txt", "w") as file:
8
+ file.write(user_input)
9
+
10
+
11
+ def send_msg_to_server(input_text):
12
+ try:
13
+ # 构造请求数据
14
+ payload = {"text": input_text}
15
+ headers = {"Content-Type": "application/json"}
16
+
17
+ # 发送请求
18
+ response = requests.post(API_URL, json=payload, headers=headers)
19
+ response.raise_for_status() # 检查是否请求成功
20
+
21
+ # 返回响应结果
22
+ result = response.json() # 假设服务器返回的是 JSON 格式
23
+ return result.get("processed_text", "No result returned.")
24
+ except requests.RequestException as e:
25
+ return f"请求失败:{e}"
test_server.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from flask import Flask, request, jsonify
3
+
4
+ app = Flask(__name__)
5
+
6
+ @app.route('/api/process', methods=['POST'])
7
+ def process_text():
8
+ data = request.get_json()
9
+ input_text = data.get("text", "")
10
+
11
+ time.sleep(1)
12
+ processed_text = f"{input_text[::-1]}"
13
+
14
+ return jsonify({"processed_text": processed_text})
15
+
16
+ if __name__ == "__main__":
17
+ app.run(host="127.0.0.1", port=5000)
utils.py CHANGED
@@ -114,10 +114,7 @@ def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
114
  """
115
  清理数据集,将嵌套的np.ndarray列展平为多列。
116
  """
117
- if 'all' in rows_to_include:
118
- rows_to_include = list(range(len(df['target'][0]))) if isinstance(df['target'][0], np.ndarray) else 1
119
- else:
120
- rows_to_include = sorted(rows_to_include)
121
 
122
  df['timestamp'] = df.apply(lambda row: pd.date_range(
123
  start=row['start'],
 
114
  """
115
  清理数据集,将嵌套的np.ndarray列展平为多列。
116
  """
117
+ rows_to_include = sorted(rows_to_include)
 
 
 
118
 
119
  df['timestamp'] = df.apply(lambda row: pd.date_range(
120
  start=row['start'],