Liu Yiwen committed on
Commit
e03ca4d
·
1 Parent(s): 0edb9ff

更新了选择target的功能

Browse files
Files changed (3) hide show
  1. __pycache__/utils.cpython-311.pyc +0 -0
  2. app.py +18 -12
  3. utils.py +15 -10
__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
 
app.py CHANGED
@@ -222,8 +222,10 @@ with gr.Blocks() as demo:
222
  # componets = []
223
  # for _ in range(TIME_PLOTS_NUM):
224
  with gr.Row():
225
- with gr.Column(scale=3):
226
- select_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
 
 
227
  with gr.Column(scale=1):
228
  select_buttom = gr.Button("Show selected items")
229
  with gr.Row():
@@ -232,7 +234,7 @@ with gr.Blocks() as demo:
232
  with gr.Column(scale=3):
233
  plot = gr.Plot()
234
  user_input_text = gr.Textbox(placeholder="输入一些内容")
235
- # componets.append({"select_box": select_box,
236
  # "statistics_textbox": statistics_textbox,
237
  # "user_input_text": user_input_text,
238
  # "plot": plot})
@@ -248,7 +250,7 @@ with gr.Blocks() as demo:
248
  cp_result: gr.update(visible=False, value=""),
249
  }
250
 
251
- def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str]) -> dict:
252
  try:
253
  ret = {}
254
  if dataset != 'Salesforce/lotsa_data':
@@ -261,15 +263,17 @@ with gr.Blocks() as demo:
261
  df_list, id_list = [], []
262
  for i, page in enumerate(page):
263
  df, max_page, info = get_page(dataset, config, split, page)
264
- df = clean_up_df(df)
 
 
 
265
  row = df.iloc[0]
266
  id_list.append(row['item_id'])
267
  # 将单行的DataFrame展开为新的DataFrame
268
  df_without_index = row.drop('item_id').to_frame().T
269
  df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
270
  df_list.append(df_expanded)
271
- global tot_samples
272
- tot_samples = max_page
273
  return {
274
  statistics_textbox: gr.update(value=create_statistic(df_list, id_list)),
275
  plot: gr.update(value=create_plot(df_list, id_list)),
@@ -292,8 +296,9 @@ with gr.Blocks() as demo:
292
  def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
293
  try:
294
  return {
295
- **show_dataset_at_config_and_split_and_page(dataset, config, split, "1"),
296
- select_box: gr.update(choices=[f"{i+1}" for i in range(tot_samples)], value=["1"]),
 
297
  # cp_page: gr.update(value="1", visible=True),
298
  # cp_goto_page: gr.update(visible=True),
299
  # cp_goto_next_page: gr.update(visible=True),
@@ -336,17 +341,18 @@ with gr.Blocks() as demo:
336
  all_outputs = [cp_config, cp_split,
337
  # cp_page, cp_goto_page, cp_goto_next_page,
338
  cp_result, cp_info, cp_error,
339
- select_box, select_buttom, statistics_textbox, user_input_text, plot]
 
340
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
341
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
342
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
343
  # cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
344
  # cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
345
  user_input_text.submit(save_to_file, inputs=user_input_text)
346
- select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_box], outputs=all_outputs)
347
 
348
 
349
  if __name__ == "__main__":
350
 
351
  app = gr.mount_gradio_app(app, demo, path="/")
352
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
222
  # componets = []
223
  # for _ in range(TIME_PLOTS_NUM):
224
  with gr.Row():
225
+ with gr.Column(scale=2):
226
+ select_sample_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
227
+ with gr.Column(scale=2):
228
+ select_subtarget_box = gr.Dropdown(choices=["subtargets"], label="Select some subtargets", multiselect=True, interactive=True)
229
  with gr.Column(scale=1):
230
  select_buttom = gr.Button("Show selected items")
231
  with gr.Row():
 
234
  with gr.Column(scale=3):
235
  plot = gr.Plot()
236
  user_input_text = gr.Textbox(placeholder="输入一些内容")
237
+ # componets.append({"select_sample_box": select_sample_box,
238
  # "statistics_textbox": statistics_textbox,
239
  # "user_input_text": user_input_text,
240
  # "plot": plot})
 
250
  cp_result: gr.update(visible=False, value=""),
251
  }
252
 
253
+ def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str]) -> dict:
254
  try:
255
  ret = {}
256
  if dataset != 'Salesforce/lotsa_data':
 
263
  df_list, id_list = [], []
264
  for i, page in enumerate(page):
265
  df, max_page, info = get_page(dataset, config, split, page)
266
+ global tot_samples, tot_targets
267
+ tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) else 1
268
+
269
+ df = clean_up_df(df, sub_targets)
270
  row = df.iloc[0]
271
  id_list.append(row['item_id'])
272
  # 将单行的DataFrame展开为新的DataFrame
273
  df_without_index = row.drop('item_id').to_frame().T
274
  df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
275
  df_list.append(df_expanded)
276
+
 
277
  return {
278
  statistics_textbox: gr.update(value=create_statistic(df_list, id_list)),
279
  plot: gr.update(value=create_plot(df_list, id_list)),
 
296
  def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
297
  try:
298
  return {
299
+ **show_dataset_at_config_and_split_and_page(dataset, config, split, "1", [0]),
300
+ select_sample_box: gr.update(choices=[f"{i+1}" for i in range(tot_samples)], value=["1"]),
301
+ select_subtarget_box: gr.update(choices=[i for i in range(tot_targets)]+['all'], value=[0]),
302
  # cp_page: gr.update(value="1", visible=True),
303
  # cp_goto_page: gr.update(visible=True),
304
  # cp_goto_next_page: gr.update(visible=True),
 
341
  all_outputs = [cp_config, cp_split,
342
  # cp_page, cp_goto_page, cp_goto_next_page,
343
  cp_result, cp_info, cp_error,
344
+ select_sample_box, select_subtarget_box,
345
+ select_buttom, statistics_textbox, user_input_text, plot]
346
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
347
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
348
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
349
  # cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
350
  # cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
351
  user_input_text.submit(save_to_file, inputs=user_input_text)
352
+ select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
353
 
354
 
355
  if __name__ == "__main__":
356
 
357
  app = gr.mount_gradio_app(app, demo, path="/")
358
+ uvicorn.run(app, host="127.0.0.1", port=7860)
utils.py CHANGED
@@ -33,22 +33,22 @@ def ndarray_to_base64(ndarray):
33
  base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
34
  return f"data:image/png;base64,{base64_str}"
35
 
36
- def flatten_ndarray_column(df, column_name):
37
  """
38
- 将嵌套的np.ndarray列展平为多列。
39
  """
40
- def flatten_ndarray(ndarray):
41
  if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
42
- return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray])
 
43
  elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
44
  return np.expand_dims(ndarray, axis=0)
45
  return ndarray
46
 
47
- flattened_data = df[column_name].apply(flatten_ndarray)
48
- max_length = max(flattened_data.apply(len))
49
 
50
- for i in range(max_length):
51
- df[f'{column_name}_{i}'] = flattened_data.apply(lambda x: x[i] if i < len(x) else np.nan)
52
 
53
  return df
54
 
@@ -110,16 +110,21 @@ def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
110
  combined_stats_df = pd.concat(stats_list, ignore_index=True)
111
  return combined_stats_df
112
 
113
- def clean_up_df(df: pd.DataFrame) -> pd.DataFrame:
114
  """
115
  清理数据集,将嵌套的np.ndarray列展平为多列。
116
  """
 
 
 
 
 
117
  df['timestamp'] = df.apply(lambda row: pd.date_range(
118
  start=row['start'],
119
  periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
120
  freq=row['freq']
121
  ).to_pydatetime().tolist(), axis=1)
122
- df = flatten_ndarray_column(df, 'target')
123
  # 删除原始的start和freq列
124
  df.drop(columns=['start', 'freq', 'target'], inplace=True)
125
  if 'past_feat_dynamic_real' in df.columns:
 
33
  base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
34
  return f"data:image/png;base64,{base64_str}"
35
 
36
+ def flatten_ndarray_column(df, column_name, rows_to_include):
37
  """
38
+ 将嵌套的np.ndarray列展平为多列,并只保留指定的行。
39
  """
40
+ def select_and_flatten(ndarray):
41
  if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
42
+ selected = [ndarray[i] for i in rows_to_include if i < len(ndarray)]
43
+ return np.concatenate([select_and_flatten(subarray) for subarray in selected])
44
  elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
45
  return np.expand_dims(ndarray, axis=0)
46
  return ndarray
47
 
48
+ selected_data = df[column_name].apply(select_and_flatten)
 
49
 
50
+ for i in rows_to_include:
51
+ df[f'{column_name}_{i}'] = selected_data.apply(lambda x: x[i] if i < len(x) else np.nan)
52
 
53
  return df
54
 
 
110
  combined_stats_df = pd.concat(stats_list, ignore_index=True)
111
  return combined_stats_df
112
 
113
+ def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
114
  """
115
  清理数据集,将嵌套的np.ndarray列展平为多列。
116
  """
117
+ if 'all' in rows_to_include:
118
+ rows_to_include = list(range(len(df['target'][0]))) if isinstance(df['target'][0], np.ndarray) else 1
119
+ else:
120
+ rows_to_include = sorted(rows_to_include)
121
+
122
  df['timestamp'] = df.apply(lambda row: pd.date_range(
123
  start=row['start'],
124
  periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
125
  freq=row['freq']
126
  ).to_pydatetime().tolist(), axis=1)
127
+ df = flatten_ndarray_column(df, 'target', rows_to_include)
128
  # 删除原始的start和freq列
129
  df.drop(columns=['start', 'freq', 'target'], inplace=True)
130
  if 'past_feat_dynamic_real' in df.columns: