##################################################### # Utils ##################################################### # 本文件包含了一些用于数据处理和绘图的实用函数。 import base64 from io import BytesIO from matplotlib import pyplot as plt import pandas as pd import plotly.graph_objects as go import numpy as np def ndarray_to_base64(ndarray): """ 将一维np.ndarray绘图并转换为Base64编码。 """ # 创建绘图 plt.figure(figsize=(8, 4)) plt.plot(ndarray) plt.title("Vector Plot") plt.xlabel("Index") plt.ylabel("Value") plt.tight_layout() # 保存图像到内存字节流 buffer = BytesIO() plt.savefig(buffer, format="png") plt.close() buffer.seek(0) # 转换为Base64字符串 base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8') return f"data:image/png;base64,{base64_str}" def flatten_ndarray_column(df, column_name): """ 将嵌套的np.ndarray列展平为多列。 """ def flatten_ndarray(ndarray): if isinstance(ndarray, np.ndarray) and ndarray.dtype == 'O': return np.concatenate([flatten_ndarray(subarray) for subarray in ndarray]) elif isinstance(ndarray, np.ndarray) and ndarray.ndim == 1: return np.expand_dims(ndarray, axis=0) return ndarray flattened_data = df[column_name].apply(flatten_ndarray) max_length = max(flattened_data.apply(len)) for i in range(max_length): df[f'{column_name}_{i}'] = flattened_data.apply(lambda x: x[i] if i < len(x) else np.nan) return df def create_plot(dfs:list[pd.DataFrame], ids:list[str]): """ 创建一个包含所有传入 DataFrame 的线图。 """ fig = go.Figure() for df, df_id in zip(dfs, ids): for i, column in enumerate(df.columns[1:]): fig.add_trace(go.Scatter( x=df[df.columns[0]], y=df[column], mode='lines', name=f"item_{df_id} - {column}", visible=True if i == 0 else 'legendonly' )) # 配置图例 fig.update_layout( legend=dict( title="Variables", orientation="h", yanchor="top", y=-0.2, xanchor="center", x=0.5 ), xaxis_title='Time', yaxis_title='Values' ) return fig def create_statistic(dfs: list[pd.DataFrame], ids: list[str]): """ 计算数据集列表的统计信息。 """ stats_list = [] for df, id in zip(dfs, ids): df_values = df.iloc[:, 1:] # 计算统计值 mean_values = df_values.mean().round(2) std_values = df_values.std().round(2) max_values = df_values.max().round(2) min_values = df_values.min().round(2) # 将这些统计信息合并成一个新的DataFrame stats_df = pd.DataFrame({ 'Variables': [f"{id}_{col}" for col in df_values.columns], 'mean': mean_values.values, 'std': std_values.values, 'max': max_values.values, 'min': min_values.values }) stats_list.append(stats_df) # 合并所有统计信息DataFrame combined_stats_df = pd.concat(stats_list, ignore_index=True) return combined_stats_df def clean_up_df(df: pd.DataFrame) -> pd.DataFrame: """ 清理数据集,将嵌套的np.ndarray列展平为多列。 """ df['timestamp'] = df.apply(lambda row: pd.date_range( start=row['start'], periods=len(row['target'][0]) if isinstance(row['target'][0], np.ndarray) else len(row['target']), freq=row['freq'] ).to_pydatetime().tolist(), axis=1) df = flatten_ndarray_column(df, 'target') # 删除原始的start和freq列 df.drop(columns=['start', 'freq', 'target'], inplace=True) if 'past_feat_dynamic_real' in df.columns: df.drop(columns=['past_feat_dynamic_real'], inplace=True) return df if __name__ == '__main__': # 创建测试数据 data1 = { 'Time': ['2023-01-01', '2023-01-02', '2023-01-03'], 'Value1': [10, 15, 20], 'Value2': [20, 25, 30] } data2 = { 'Time': ['2023-01-01', '2023-01-02', '2023-01-03'], 'Value3': [5, 10, 15], 'Value4': [15, 20, 25] } df1 = pd.DataFrame(data1) df2 = pd.DataFrame(data2) # 转换时间列为日期时间格式 df1['Time'] = pd.to_datetime(df1['Time']) df2['Time'] = pd.to_datetime(df2['Time']) # 创建图表 fig = create_plot(df1, df2) # 显示图表 fig.show()