Liu Yiwen committed on
Commit b4b95a6 · 1 Parent(s): 266b4a6

Added Plotly-based interactive plotting and a summary-statistics display

Files changed (4)
  1. README.md +1 -0
  2. __pycache__/utils.cpython-311.pyc +0 -0
  3. app.py +26 -113
  4. utils.py +119 -0
README.md CHANGED
@@ -13,6 +13,7 @@ pinned: false
 
  Access any slice of data of any dataset on the [Hugging Face Dataset Hub](https://huggingface.co/datasets)
 
+ This project is based on https://huggingface.co/spaces/lhoestq/datasets-explorer
  Run:
 
  ```python
__pycache__/utils.cpython-311.pyc ADDED
Binary file (6.52 kB).
 
app.py CHANGED
@@ -1,24 +1,20 @@
- import base64
  import copy
- from datetime import datetime, timedelta
- from io import BytesIO
- import random
+ import os
+ from functools import lru_cache, partial
+
  import gradio as gr
- from functools import lru_cache
- from hffs.fs import HfFileSystem
- from typing import List, Tuple, Callable
- from matplotlib import pyplot as plt
- import pandas as pd
  import numpy as np
+ import pandas as pd
  import pyarrow as pa
  import pyarrow.parquet as pq
- from functools import partial
  from tqdm.contrib.concurrent import thread_map
- from datasets import Features, Image, Audio, Sequence
  from fastapi import FastAPI, Response
  import uvicorn
- import os
- from gradio_datetimerange import DateTimeRange
+ from hffs.fs import HfFileSystem
+ from datasets import Features, Image, Audio, Sequence
+ from typing import List, Tuple, Callable
+
+ from utils import ndarray_to_base64, clean_up_df, create_statistic, create_plot
 
  class AppError(RuntimeError):
      pass
@@ -30,46 +26,7 @@ MAX_CACHED_BLOBS = PAGE_SIZE * 10
  TIME_PLOTS_NUM = 5
  _blobs_cache = {}
 
- #####################################################
- # Utils
- #####################################################
- def ndarray_to_base64(ndarray):
-     """
-     Plot a 1-D np.ndarray and convert the figure to a Base64-encoded image.
-     """
-     # Create the plot
-     plt.figure(figsize=(8, 4))
-     plt.plot(ndarray)
-     plt.title("Vector Plot")
-     plt.xlabel("Index")
-     plt.ylabel("Value")
-     plt.tight_layout()
-
-     # Save the figure to an in-memory byte buffer
-     buffer = BytesIO()
-     plt.savefig(buffer, format="png")
-     plt.close()
-     buffer.seek(0)
-
-     # Convert to a Base64 string
-     base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
-     return f"data:image/png;base64,{base64_str}"
-
- def flatten_ndarray_column(df, column_name):
-     def flatten_ndarray(ndarray):
-         if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
-             return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray])
-         elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
-             return np.expand_dims(ndarray, axis=0)
-         return ndarray
-
-     flattened_data = df[column_name].apply(flatten_ndarray)
-     max_length = max(flattened_data.apply(len))
-
-     for i in range(max_length):
-         df[f'{column_name}_{i}'] = flattened_data.apply(lambda x: x[i] if i < len(x) else np.nan)
 
-     return df
  #####################################################
  # Define routes for image and audio files
  #####################################################
@@ -239,7 +196,7 @@ def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[str, int
          info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
          return df.reset_index().to_markdown(index=False), max_page, info
      else:
-         # Other processing logic
+         # Special handling for the Salesforce/lotsa_data dataset
          info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
          return df, max_page, info
 
@@ -250,6 +207,7 @@ def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[str, int
 
 
  with gr.Blocks() as demo:
+     # Initialize the components
      gr.Markdown("# 📖 Datasets Explorer\n\nAccess any slice of data of any dataset on the [Hugging Face Dataset Hub](https://huggingface.co/datasets)")
      gr.Markdown("This is the dataset viewer from parquet export demo before the feature was added on the Hugging Face website.")
      cp_dataset = gr.Textbox("Salesforce/lotsa_data", label="Pick a dataset", placeholder="competitions/aiornot")
@@ -261,32 +219,16 @@ with gr.Blocks() as demo:
      cp_info = gr.Markdown("", visible=False)
      cp_result = gr.Markdown("", visible=False)
 
-     now = datetime.now()
-     df = pd.DataFrame({
-         'time': [now - timedelta(minutes=5*i) for i in range(25)] + [now],
-         'price': np.random.randint(100, 1000, 26),
-         'origin': [random.choice(["DFW", "DAL", "HOU"]) for _ in range(26)],
-         'destination': [random.choice(["JFK", "LGA", "EWR"]) for _ in range(26)],
-     })
-
+     # Initialize the components used to display the Salesforce/lotsa_data dataset
      componets = []
      for _ in range(TIME_PLOTS_NUM):
          with gr.Row():
-             textbox = gr.Textbox("名称或说明")
-             with gr.Column():
-                 daterange = DateTimeRange(["now - 24h", "now"])
-                 plot1 = gr.LinePlot(df, x="time", y="price", color="origin")
-                 # plot2 = gr.LinePlot(df, x="time", y="price", color="origin")
-                 daterange.bind([plot1,
-                                 # plot2,
-                                 ])
-             comp = {
-                 "textbox" : textbox,
-                 "daterange" : daterange,
-                 "plot1" : plot1,
-                 # "plot2" : plot2,
-             }
-             componets.append(comp)
+             with gr.Column(scale=2):
+                 textbox = gr.Textbox("名称或说明")
+                 statistics_textbox = gr.DataFrame()
+             with gr.Column(scale=3):
+                 plot = gr.Plot()
+             componets.append({"textbox": textbox, "statistics_textbox": statistics_textbox, "plot": plot})
 
      with gr.Row():
          cp_page = gr.Textbox("1", label="Page", placeholder="1", visible=False)
@@ -306,24 +248,19 @@ with gr.Blocks() as demo:
              markdown_result, max_page, info = get_page(dataset, config, split, page)
              ret[cp_result] = gr.update(visible=True, value=markdown_result)
          else:
+             # Special handling for the Salesforce/lotsa_data dataset
              df, max_page, info = get_page(dataset, config, split, page)
-             print(df.columns)
-             # TODO: len(row['target'][0]) raises an error when target is a 1-D array
-             df['timestamp'] = df.apply(lambda row: pd.date_range(start=row['start'], periods=len(row['target'][0]), freq=row['freq']).to_pydatetime().tolist(), axis=1)
-             df = flatten_ndarray_column(df, 'target')
-             # Drop the original start and freq columns
-             df.drop(columns=['start', 'freq', 'target'], inplace=True)
-             if 'past_feat_dynamic_real' in df.columns:
-                 df.drop(columns=['past_feat_dynamic_real'], inplace=True)
-             info = f"({info})" if info else ""
+             df = clean_up_df(df)
              for i, rows in df.iterrows():
                  index = rows['item_id']
+                 # Expand the single-row DataFrame into a new DataFrame
                  df_without_index = rows.drop('item_id').to_frame().T
                  df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
+                 df_statistics = create_statistic(df_expanded)
                  ret.update({
                      componets[i]["textbox"]: gr.update(value=f"item_id: {index}"),
-                     componets[i]["daterange"]: gr.update(value=[df_without_index['timestamp'][i][0], df_without_index['timestamp'][i][-1]]),
-                     componets[i]["plot1"]: gr.update(value=df_expanded, x="timestamp", y="target_0"),
+                     componets[i]["statistics_textbox"]: gr.update(value=df_statistics),
+                     componets[i]["plot"]: gr.update(value=create_plot(df_expanded))
                  })
              return {
                  **ret,
@@ -381,34 +318,10 @@ with gr.Blocks() as demo:
          }
      except AppError as err:
          return show_error(str(err))
-
-     """
-     Using gr.LinePlot in dynamically created components is buggy: the app hangs in the show_dataset part.
-     """
-     # @gr.render(triggers=[cp_go.click])
-     # def create_test():
-     #     now = datetime.now()
-     #     df = pd.DataFrame({
-     #         'time': [now - timedelta(minutes=5*i) for i in range(25)],
-     #         'price': np.random.randint(100, 1000, 25),
-     #         'origin': [random.choice(["DFW", "DAL", "HOU"]) for _ in range(25)],
-     #         'destination': [random.choice(["JFK", "LGA", "EWR"]) for _ in range(25)],
-     #     })
-     #     # componets = []
-     #     # daterange = DateTimeRange(["now - 24h", "now"])
-     #     plot1 = gr.LinePlot(df, x="time", y="price")
-     #     plot2 = gr.LinePlot(df, x="time", y="price", color="origin")
-     #     # # daterange.bind([plot1, plot2])
-     #     # componets.append(plot1)
-     #     # componets.append(plot2)
-     #     # componets.append(daterange)
-     #     # test = gr.Textbox(label="input")
-     #     # componets.append(test)
-     #     # return componets
 
  all_outputs = [cp_config, cp_split, cp_page, cp_goto_page, cp_goto_next_page, cp_result, cp_info, cp_error]
+ for componet in componets:
+     all_outputs += list(componet.values())
  cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
  cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
  cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
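
Note on the pattern above: the removed comment block records that creating plots dynamically with @gr.render hung in show_dataset, so the commit keeps a fixed pool of TIME_PLOTS_NUM pre-created component groups and fills them from the callback via a dict of gr.update values. A minimal, self-contained sketch of that pattern (component and function names here are hypothetical, not part of the commit):

```python
import gradio as gr

POOL_SIZE = 3  # app.py uses TIME_PLOTS_NUM = 5

with gr.Blocks() as demo_sketch:
    go_btn = gr.Button("Load")
    pool = []
    for _ in range(POOL_SIZE):
        with gr.Row():
            label = gr.Textbox(label="item")
            table = gr.DataFrame()
        pool.append({"label": label, "table": table})

    def load():
        # Return a dict keyed by component: components missing from the
        # dict keep their current value, so partially filled pages work.
        return {group["label"]: gr.update(value=f"item_id: {i}")
                for i, group in enumerate(pool)}

    # Every pooled component must be registered as a potential output.
    outputs = []
    for group in pool:
        outputs += list(group.values())
    go_btn.click(load, inputs=[], outputs=outputs)

# demo_sketch.launch()
```

Returning a dict rather than a tuple is what lets the app update only the slots it filled while still registering all pooled components in all_outputs.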
utils.py ADDED
@@ -0,0 +1,119 @@
+ #####################################################
+ # Utils
+ #####################################################
+ # Utility functions for data processing and plotting.
+
+ import base64
+ from io import BytesIO
+ from matplotlib import pyplot as plt
+ import pandas as pd
+ import plotly.graph_objects as go
+ import numpy as np
+
+
+ def ndarray_to_base64(ndarray):
+     """
+     Plot a 1-D np.ndarray and convert the figure to a Base64-encoded image.
+     """
+     # Create the plot
+     plt.figure(figsize=(8, 4))
+     plt.plot(ndarray)
+     plt.title("Vector Plot")
+     plt.xlabel("Index")
+     plt.ylabel("Value")
+     plt.tight_layout()
+
+     # Save the figure to an in-memory byte buffer
+     buffer = BytesIO()
+     plt.savefig(buffer, format="png")
+     plt.close()
+     buffer.seek(0)
+
+     # Convert to a Base64 string
+     base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
+     return f"data:image/png;base64,{base64_str}"
+
+ def flatten_ndarray_column(df, column_name):
+     """
+     Flatten a column of nested np.ndarray values into multiple columns.
+     """
+     def flatten_ndarray(ndarray):
+         if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O':
+             return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray])
+         elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1:
+             return np.expand_dims(ndarray, axis=0)
+         return ndarray
+
+     flattened_data = df[column_name].apply(flatten_ndarray)
+     max_length = max(flattened_data.apply(len))
+
+     for i in range(max_length):
+         df[f'{column_name}_{i}'] = flattened_data.apply(lambda x: x[i] if i < len(x) else np.nan)
+
+     return df
+
+ def create_plot(df):
+     """
+     Create a line plot of every column, using the first column as the x-axis.
+     """
+     fig = go.Figure()
+     for i, column in enumerate(df.columns[1:]):
+         fig.add_trace(go.Scatter(
+             x=df[df.columns[0]],
+             y=df[column],
+             mode='lines',
+             name=column,
+             visible=True if i == 0 else 'legendonly'
+         ))
+
+     # Configure the legend
+     fig.update_layout(
+         legend=dict(
+             title="Variables",
+             orientation="h",
+             yanchor="top",
+             y=-0.2,
+             xanchor="center",
+             x=0.5
+         ),
+         xaxis_title='Time',
+         yaxis_title='Values'
+     )
+     return fig
+
+ def create_statistic(df):
+     """
+     Compute summary statistics (mean/std/max/min) for each value column.
+     """
+     df_values = df.iloc[:, 1:]
+     # Compute the statistics
+     mean_values = df_values.mean()
+     std_values = df_values.std()
+     max_values = df_values.max()
+     min_values = df_values.min()
+
+     # Combine the statistics into a new DataFrame
+     stats_df = pd.DataFrame({
+         'Variables': df_values.columns,
+         'mean': mean_values.values,
+         'std': std_values.values,
+         'max': max_values.values,
+         'min': min_values.values
+     })
+     return stats_df
+
+ def clean_up_df(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Clean up the dataset: build a timestamp column and flatten the nested np.ndarray column into multiple columns.
+     """
+     df['timestamp'] = df.apply(lambda row: pd.date_range(
+         start=row['start'],
+         periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']),
+         freq=row['freq']
+     ).to_pydatetime().tolist(), axis=1)
+     df = flatten_ndarray_column(df, 'target')
+     # Drop the original start and freq columns
+     df.drop(columns=['start', 'freq', 'target'], inplace=True)
+     if 'past_feat_dynamic_real' in df.columns:
+         df.drop(columns=['past_feat_dynamic_real'], inplace=True)
+     return df
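
For orientation, here is a rough usage sketch of these helpers (not part of the commit). It assumes a row follows the Salesforce/lotsa_data schema that app.py expects (item_id, start, freq, target); the synthetic values are made up:

```python
import numpy as np
import pandas as pd

from utils import clean_up_df, create_plot, create_statistic

# Synthetic single-row frame mimicking the lotsa_data schema (values are made up).
df = pd.DataFrame({
    "item_id": ["item_0"],
    "start": [pd.Timestamp("2021-01-01")],
    "freq": ["H"],
    "target": [np.arange(24, dtype=float)],  # one day of hourly values
})

df = clean_up_df(df)  # -> columns: item_id, timestamp, target_0

# Mirror the per-row loop in app.py: explode the row, then summarize and plot it.
row = df.iloc[0].drop("item_id").to_frame().T
expanded = row.apply(pd.Series.explode).reset_index(drop=True).fillna(0)

print(create_statistic(expanded))  # mean/std/max/min for each target_* column
fig = create_plot(expanded)        # Plotly figure; extra traces start legend-only
fig.show()
```

This is the same per-row path the callback in app.py takes: clean_up_df once per page, then each exploded row feeds create_statistic and create_plot.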