Spaces:
Running
Running
File size: 4,881 Bytes
b4b95a6 e03ca4d b4b95a6 e03ca4d b4b95a6 e03ca4d b4b95a6 e03ca4d b4b95a6 e03ca4d b4b95a6 f312fcb b4b95a6 e6303fa b4b95a6 e6303fa b4b95a6 e6303fa b4b95a6 e6303fa b4b95a6 e6303fa b4b95a6 e6303fa beccaa7 e6303fa beccaa7 e6303fa b4b95a6 e03ca4d b4b95a6 fad121c e03ca4d b4b95a6 e03ca4d b4b95a6 e6303fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
#####################################################
# Utils
#####################################################
# This file contains utility functions for data processing and plotting.
import base64
from io import BytesIO
from matplotlib import pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import numpy as np
def ndarray_to_base64(ndarray):
    """
    Plot an array as a line chart and return the image as a PNG data URI.

    Args:
        ndarray: Values to plot. They are handed to Matplotlib's ``plot``
            unchanged; the original docstring describes a one-dimensional
            ``np.ndarray``.

    Returns:
        str: A ``data:image/png;base64,...`` string holding the rendered chart.
    """
    # Work on an explicit figure handle instead of pyplot's implicit
    # "current figure", so that exactly this figure is drawn on and closed.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()

        # Render the PNG into an in-memory byte stream.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even when plotting or saving fails;
        # otherwise it stays registered in pyplot and leaks memory.
        plt.close(fig)

    # getvalue() returns the whole buffer regardless of the stream position.
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"
def flatten_ndarray_column(df, column_name, rows_to_include):
    """
    Expand a column of (possibly nested) arrays into one column per selected row.

    For every entry of ``rows_to_include`` a new column named
    ``f"{column_name}_{index}"`` is added to ``df``. The frame is modified in
    place and also returned.

    Args:
        df: DataFrame that holds ``column_name``.
        column_name: Name of the column whose cells contain arrays.
        rows_to_include: Row indices to pick out of each object-dtype array.

    Returns:
        The same ``df`` object with the additional columns.
    """
    def _to_rows(value):
        # Anything that is not an ndarray is passed through untouched.
        if not isinstance(value, np.ndarray):
            return value
        if value.dtype == 'O':
            # Array of arrays: keep only the requested rows that exist,
            # normalise each one recursively and stack them.
            picked = []
            for row_idx in rows_to_include:
                if row_idx < len(value):
                    picked.append(_to_rows(value[row_idx]))
            return np.concatenate(picked)
        if value.ndim == 1:
            # A plain 1-D vector becomes a single-row 2-D array.
            return value[np.newaxis, :]
        return value

    stacked = df[column_name].apply(_to_rows)

    # The n-th requested index is read from the n-th stacked row.
    for position, row_idx in enumerate(rows_to_include):
        df[f'{column_name}_{row_idx}'] = stacked.apply(
            lambda rows, pos=position: rows[pos]
        )
    return df
def create_plot(dfs: list[pd.DataFrame], ids: list[str]):
    """
    Build a single Plotly line chart from all given DataFrames.

    The first column of each frame supplies the x values; every remaining
    column becomes its own line trace named ``item_<id> - <column>``.

    Args:
        dfs: DataFrames to draw.
        ids: Identifiers paired with ``dfs`` by position and used in trace names.

    Returns:
        The populated ``go.Figure``.
    """
    figure = go.Figure()

    for frame, frame_id in zip(dfs, ids):
        value_columns = frame.columns[1:]
        for position, column_name in enumerate(value_columns):
            # Only the first value column of each frame is drawn initially;
            # the others start as legend-only entries.
            visibility = 'legendonly' if position else True
            trace = go.Scatter(
                x=frame[frame.columns[0]],
                y=frame[column_name],
                mode='lines',
                name=f"item_{frame_id} - {column_name}",
                visible=visibility,
            )
            figure.add_trace(trace)

    # Legend settings: horizontal, centred, anchored below the plot area.
    legend_options = dict(
        title="Variables",
        orientation="h",
        yanchor="top",
        y=-0.2,
        xanchor="center",
        x=0.5,
    )
    figure.update_layout(
        legend=legend_options,
        xaxis_title='Time',
        yaxis_title='Values',
    )
    return figure
def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
    """
    Compute summary statistics for a list of datasets.

    The first column of every frame is skipped; for each remaining column the
    mean, standard deviation, maximum and minimum are computed.

    Args:
        dfs: DataFrames to summarise.
        ids: Identifiers paired with ``dfs`` by position; each one prefixes the
            column names in the ``Variables`` column (``<id>_<column>``).

    Returns:
        pd.DataFrame: One row per value column with the columns ``Variables``,
        ``mean``, ``std``, ``max`` and ``min``; numbers are rounded to two
        decimals. An empty frame with those columns is returned when no
        datasets are given.
    """
    stat_columns = ['Variables', 'mean', 'std', 'max', 'min']

    stats_list = []
    for df, df_id in zip(dfs, ids):
        # Everything after the first column is treated as a value column.
        df_values = df.iloc[:, 1:]
        stats_list.append(pd.DataFrame({
            'Variables': [f"{df_id}_{col}" for col in df_values.columns],
            'mean': df_values.mean().values,
            'std': df_values.std().values,
            'max': df_values.max().values,
            'min': df_values.min().values,
        }))

    # pd.concat raises on an empty list, so return an empty table instead.
    if not stats_list:
        return pd.DataFrame(columns=stat_columns)

    combined_stats_df = pd.concat(stats_list, ignore_index=True)

    def _round_number(value):
        # Leave non-numeric cells (the variable names) untouched.
        return round(value, 2) if isinstance(value, (int, float)) else value

    # Element-wise rounding through Series.map: DataFrame.applymap is
    # deprecated since pandas 2.1, while Series.map is available in both
    # older and newer pandas releases.
    return combined_stats_df.apply(lambda column: column.map(_round_number))
def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
    """
    Clean a dataset by expanding its nested ``target`` arrays into columns.

    A ``timestamp`` column is generated from each row's ``start`` and ``freq``,
    the ``target`` column is expanded via ``flatten_ndarray_column`` and the
    source columns are dropped afterwards. ``df`` is modified in place.

    Args:
        df: Frame with ``start``, ``freq`` and ``target`` columns.
        rows_to_include: Row indices of ``target`` to keep; processed in
            ascending order.

    Returns:
        The cleaned DataFrame.
    """
    def _timestamps_for(row):
        start = row['start']
        target = row['target']
        # When the first element is itself an array, the series length is the
        # length of that inner array; otherwise it is the length of target.
        if isinstance(target[0], np.ndarray):
            n_periods = len(target[0])
        else:
            n_periods = len(target)
        stamps = pd.date_range(start=start, periods=n_periods, freq=row['freq'])
        return stamps.to_pydatetime().tolist()

    ordered_rows = sorted(rows_to_include)
    df['timestamp'] = df.apply(_timestamps_for, axis=1)
    df = flatten_ndarray_column(df, 'target', ordered_rows)

    # Drop the source columns, plus the optional dynamic-feature column.
    obsolete = ['start', 'freq', 'target']
    if 'past_feat_dynamic_real' in df.columns:
        obsolete.append('past_feat_dynamic_real')
    df.drop(columns=obsolete, inplace=True)
    return df
if __name__ == '__main__':
    # Manual smoke test: build two small sample frames and plot them together.
    data1 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value1': [10, 15, 20],
        'Value2': [20, 25, 30]
    }
    data2 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value3': [5, 10, 15],
        'Value4': [15, 20, 25]
    }
    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)

    # Convert the time column to datetime values.
    df1['Time'] = pd.to_datetime(df1['Time'])
    df2['Time'] = pd.to_datetime(df2['Time'])

    # create_plot expects a list of frames and a parallel list of ids;
    # passing the two frames positionally would make it iterate over column
    # names and fail.
    fig = create_plot([df1, df2], ['1', '2'])

    # Display the chart.
    fig.show()