Spaces:
Running
Running
Liu Yiwen
committed on
Commit
·
e03ca4d
1
Parent(s):
0edb9ff
更新了选择target的功能
Browse files- __pycache__/utils.cpython-311.pyc +0 -0
- app.py +18 -12
- utils.py +15 -10
__pycache__/utils.cpython-311.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
|
|
app.py
CHANGED
@@ -222,8 +222,10 @@ with gr.Blocks() as demo:
|
|
222 |
# componets = []
|
223 |
# for _ in range(TIME_PLOTS_NUM):
|
224 |
with gr.Row():
|
225 |
-
with gr.Column(scale=
|
226 |
-
|
|
|
|
|
227 |
with gr.Column(scale=1):
|
228 |
select_buttom = gr.Button("Show selected items")
|
229 |
with gr.Row():
|
@@ -232,7 +234,7 @@ with gr.Blocks() as demo:
|
|
232 |
with gr.Column(scale=3):
|
233 |
plot = gr.Plot()
|
234 |
user_input_text = gr.Textbox(placeholder="输入一些内容")
|
235 |
-
# componets.append({"
|
236 |
# "statistics_textbox": statistics_textbox,
|
237 |
# "user_input_text": user_input_text,
|
238 |
# "plot": plot})
|
@@ -248,7 +250,7 @@ with gr.Blocks() as demo:
|
|
248 |
cp_result: gr.update(visible=False, value=""),
|
249 |
}
|
250 |
|
251 |
-
def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str]) -> dict:
|
252 |
try:
|
253 |
ret = {}
|
254 |
if dataset != 'Salesforce/lotsa_data':
|
@@ -261,15 +263,17 @@ with gr.Blocks() as demo:
|
|
261 |
df_list, id_list = [], []
|
262 |
for i, page in enumerate(page):
|
263 |
df, max_page, info = get_page(dataset, config, split, page)
|
264 |
-
|
|
|
|
|
|
|
265 |
row = df.iloc[0]
|
266 |
id_list.append(row['item_id'])
|
267 |
# 将单行的DataFrame展开为新的DataFrame
|
268 |
df_without_index = row.drop('item_id').to_frame().T
|
269 |
df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
|
270 |
df_list.append(df_expanded)
|
271 |
-
|
272 |
-
tot_samples = max_page
|
273 |
return {
|
274 |
statistics_textbox: gr.update(value=create_statistic(df_list, id_list)),
|
275 |
plot: gr.update(value=create_plot(df_list, id_list)),
|
@@ -292,8 +296,9 @@ with gr.Blocks() as demo:
|
|
292 |
def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
|
293 |
try:
|
294 |
return {
|
295 |
-
**show_dataset_at_config_and_split_and_page(dataset, config, split, "1"),
|
296 |
-
|
|
|
297 |
# cp_page: gr.update(value="1", visible=True),
|
298 |
# cp_goto_page: gr.update(visible=True),
|
299 |
# cp_goto_next_page: gr.update(visible=True),
|
@@ -336,17 +341,18 @@ with gr.Blocks() as demo:
|
|
336 |
all_outputs = [cp_config, cp_split,
|
337 |
# cp_page, cp_goto_page, cp_goto_next_page,
|
338 |
cp_result, cp_info, cp_error,
|
339 |
-
|
|
|
340 |
cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
|
341 |
cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
|
342 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
343 |
# cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
344 |
# cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
345 |
user_input_text.submit(save_to_file, inputs=user_input_text)
|
346 |
-
select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split,
|
347 |
|
348 |
|
349 |
if __name__ == "__main__":
|
350 |
|
351 |
app = gr.mount_gradio_app(app, demo, path="/")
|
352 |
-
uvicorn.run(app, host="
|
|
|
222 |
# componets = []
|
223 |
# for _ in range(TIME_PLOTS_NUM):
|
224 |
with gr.Row():
|
225 |
+
with gr.Column(scale=2):
|
226 |
+
select_sample_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
|
227 |
+
with gr.Column(scale=2):
|
228 |
+
select_subtarget_box = gr.Dropdown(choices=["subtargets"], label="Select some subtargets", multiselect=True, interactive=True)
|
229 |
with gr.Column(scale=1):
|
230 |
select_buttom = gr.Button("Show selected items")
|
231 |
with gr.Row():
|
|
|
234 |
with gr.Column(scale=3):
|
235 |
plot = gr.Plot()
|
236 |
user_input_text = gr.Textbox(placeholder="输入一些内容")
|
237 |
+
# componets.append({"select_sample_box": select_sample_box,
|
238 |
# "statistics_textbox": statistics_textbox,
|
239 |
# "user_input_text": user_input_text,
|
240 |
# "plot": plot})
|
|
|
250 |
cp_result: gr.update(visible=False, value=""),
|
251 |
}
|
252 |
|
253 |
+
def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str|List[str], sub_targets: List[int|str]) -> dict:
|
254 |
try:
|
255 |
ret = {}
|
256 |
if dataset != 'Salesforce/lotsa_data':
|
|
|
263 |
df_list, id_list = [], []
|
264 |
for i, page in enumerate(page):
|
265 |
df, max_page, info = get_page(dataset, config, split, page)
|
266 |
+
global tot_samples, tot_targets
|
267 |
+
tot_samples, tot_targets = max_page, len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) else 1
|
268 |
+
|
269 |
+
df = clean_up_df(df, sub_targets)
|
270 |
row = df.iloc[0]
|
271 |
id_list.append(row['item_id'])
|
272 |
# 将单行的DataFrame展开为新的DataFrame
|
273 |
df_without_index = row.drop('item_id').to_frame().T
|
274 |
df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
|
275 |
df_list.append(df_expanded)
|
276 |
+
|
|
|
277 |
return {
|
278 |
statistics_textbox: gr.update(value=create_statistic(df_list, id_list)),
|
279 |
plot: gr.update(value=create_plot(df_list, id_list)),
|
|
|
296 |
def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
|
297 |
try:
|
298 |
return {
|
299 |
+
**show_dataset_at_config_and_split_and_page(dataset, config, split, "1", [0]),
|
300 |
+
select_sample_box: gr.update(choices=[f"{i+1}" for i in range(tot_samples)], value=["1"]),
|
301 |
+
select_subtarget_box: gr.update(choices=[i for i in range(tot_targets)]+['all'], value=[0]),
|
302 |
# cp_page: gr.update(value="1", visible=True),
|
303 |
# cp_goto_page: gr.update(visible=True),
|
304 |
# cp_goto_next_page: gr.update(visible=True),
|
|
|
341 |
all_outputs = [cp_config, cp_split,
|
342 |
# cp_page, cp_goto_page, cp_goto_next_page,
|
343 |
cp_result, cp_info, cp_error,
|
344 |
+
select_sample_box, select_subtarget_box,
|
345 |
+
select_buttom, statistics_textbox, user_input_text, plot]
|
346 |
cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
|
347 |
cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
|
348 |
cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
|
349 |
# cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
350 |
# cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
|
351 |
user_input_text.submit(save_to_file, inputs=user_input_text)
|
352 |
+
select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)
|
353 |
|
354 |
|
355 |
if __name__ == "__main__":
|
356 |
|
357 |
app = gr.mount_gradio_app(app, demo, path="/")
|
358 |
+
uvicorn.run(app, host="127.0.0.1", port=7860)
|
utils.py
CHANGED
@@ -33,22 +33,22 @@ def ndarray_to_base64(ndarray):
|
|
33 |
base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
34 |
return f"data:image/png;base64,{base64_str}"
|
35 |
|
36 |
-
def flatten_ndarray_column(df, column_name):
|
37 |
"""
|
38 |
-
将嵌套的np.ndarray
|
39 |
"""
|
40 |
-
def
|
41 |
if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
|
42 |
-
|
|
|
43 |
elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
|
44 |
return np.expand_dims(ndarray, axis=0)
|
45 |
return ndarray
|
46 |
|
47 |
-
|
48 |
-
max_length = max(flattened_data.apply(len))
|
49 |
|
50 |
-
for i in
|
51 |
-
df[f'{column_name}_{i}'] =
|
52 |
|
53 |
return df
|
54 |
|
@@ -110,16 +110,21 @@ def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
|
|
110 |
combined_stats_df = pd.concat(stats_list, ignore_index=True)
|
111 |
return combined_stats_df
|
112 |
|
113 |
-
def clean_up_df(df: pd.DataFrame) -> pd.DataFrame:
|
114 |
"""
|
115 |
清理数据集,将嵌套的np.ndarray列展平为多列。
|
116 |
"""
|
|
|
|
|
|
|
|
|
|
|
117 |
df['timestamp'] = df.apply(lambda row: pd.date_range(
|
118 |
start=row['start'],
|
119 |
periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
|
120 |
freq=row['freq']
|
121 |
).to_pydatetime().tolist(), axis=1)
|
122 |
-
df = flatten_ndarray_column(df, 'target')
|
123 |
# 删除原始的start和freq列
|
124 |
df.drop(columns=['start', 'freq', 'target'], inplace=True)
|
125 |
if 'past_feat_dynamic_real' in df.columns:
|
|
|
33 |
base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
34 |
return f"data:image/png;base64,{base64_str}"
|
35 |
|
36 |
+
def flatten_ndarray_column(df, column_name, rows_to_include):
    """
    Flatten a nested np.ndarray column into one new column per selected row.

    For every index ``i`` in ``rows_to_include`` a column named
    ``f"{column_name}_{i}"`` is added to ``df``. Each cell of that column holds
    row ``i`` of the corresponding cell in ``df[column_name]``, or ``np.nan``
    when the cell has no such row.

    The indices in ``rows_to_include`` always refer to positions in the
    ORIGINAL nested array. The previous implementation first compacted the
    array down to the selected rows and then indexed the compacted array with
    the original indices, so selecting e.g. ``[2]`` alone produced NaN and
    selecting ``[1, 2]`` returned the wrong row for ``1``.

    Args:
        df: DataFrame to extend. It is modified in place and also returned.
        column_name: Name of the column whose cells hold the arrays.
        rows_to_include: Iterable of integer row indices to extract.

    Returns:
        The same ``df`` object, with the new columns added.
    """
    def as_rows(cell):
        # A plain (non-object) 1-D array is a single series: expose it as a
        # one-row 2-D array so that it can be addressed as row 0.
        if isinstance(cell, np.ndarray) and cell.dtype != 'O' and cell.ndim == 1:
            return np.expand_dims(cell, axis=0)
        # Object arrays of sub-arrays and 2-D arrays are already row-indexable.
        return cell

    def pick_row(cell, row_index):
        rows = as_rows(cell)
        # Out-of-range selections yield NaN instead of raising.
        return rows[row_index] if row_index < len(rows) else np.nan

    for i in rows_to_include:
        # Bind ``i`` as a default so the lambda does not depend on late binding.
        df[f'{column_name}_{i}'] = df[column_name].apply(
            lambda cell, i=i: pick_row(cell, i)
        )

    return df
|
54 |
|
|
|
110 |
combined_stats_df = pd.concat(stats_list, ignore_index=True)
|
111 |
return combined_stats_df
|
112 |
|
113 |
+
def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
|
114 |
"""
|
115 |
清理数据集,将嵌套的np.ndarray列展平为多列。
|
116 |
"""
|
117 |
+
if 'all' in rows_to_include:
|
118 |
+
rows_to_include = list(range(len(df['target'][0]))) if isinstance(df['target'][0], np.ndarray) else 1
|
119 |
+
else:
|
120 |
+
rows_to_include = sorted(rows_to_include)
|
121 |
+
|
122 |
df['timestamp'] = df.apply(lambda row: pd.date_range(
|
123 |
start=row['start'],
|
124 |
periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
|
125 |
freq=row['freq']
|
126 |
).to_pydatetime().tolist(), axis=1)
|
127 |
+
df = flatten_ndarray_column(df, 'target', rows_to_include)
|
128 |
# 删除原始的start和freq列
|
129 |
df.drop(columns=['start', 'freq', 'target'], inplace=True)
|
130 |
if 'past_feat_dynamic_real' in df.columns:
|