#####################################################
# Utils
#####################################################
# Utility functions for data processing and plotting.

import base64
from io import BytesIO

import numpy as np
import pandas as pd

# The plotting backends are optional at import time so that the pure
# data-processing helpers stay usable where they are not installed; the
# plotting functions raise a clear ImportError when they are actually called.
try:
    from matplotlib import pyplot as plt
except ImportError:
    plt = None

try:
    import plotly.graph_objects as go
except ImportError:
    go = None

# Column layout of the frame returned by create_statistic.
_STAT_COLUMNS = ['Variables', 'mean', 'std', 'max', 'min']


def _require_backend(module, name):
    """Raise ImportError when the optional plotting package *name* is missing."""
    if module is None:
        raise ImportError(f"{name} is required for this function but is not installed")


def ndarray_to_base64(ndarray):
    """Plot a one-dimensional ``np.ndarray`` and return it as a Base64 data URI.

    Args:
        ndarray: Values to draw as a line plot against their index.

    Returns:
        A ``data:image/png;base64,...`` string holding the rendered PNG.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    _require_backend(plt, "matplotlib")

    # Work on an explicit figure so no other "current" pyplot figure is touched.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()

        # Render the image into an in-memory byte stream.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"


def _select_row(cell, index):
    """Return row *index* of one nested cell, or None when it has no such row."""
    if isinstance(cell, np.ndarray) and cell.ndim == 1 and cell.dtype != object:
        # A plain numeric vector is a single series, i.e. exactly one row.
        cell = cell[np.newaxis, :]
    try:
        return cell[index]
    except (IndexError, KeyError, TypeError):
        # Out-of-range index or a cell that is not indexable (e.g. NaN/None).
        return None


def flatten_ndarray_column(df, column_name, rows_to_include):
    """Expand a column of nested arrays into one column per requested row.

    For every ``index`` in ``rows_to_include`` a column named
    ``f'{column_name}_{index}'`` is added, holding row ``index`` of each cell.
    A cell may be an object array of 1-D arrays, a 2-D array, or a plain 1-D
    numeric array (treated as a single row 0). Cells that do not contain the
    requested row yield ``None`` instead of shifting other rows into the
    wrong column.

    Args:
        df: DataFrame to extend; it is modified in place.
        column_name: Name of the column holding the nested arrays.
        rows_to_include: Row indices to extract from each cell.

    Returns:
        The same ``df`` object with the new columns added.
    """
    cells = df[column_name]
    for index in rows_to_include:
        # Bind ``index`` as a default so the lambda does not depend on the
        # loop variable's later values.
        df[f'{column_name}_{index}'] = cells.apply(
            lambda cell, index=index: _select_row(cell, index)
        )
    return df


def create_plot(dfs: list[pd.DataFrame], ids: list[str]):
    """Create a line chart containing every passed DataFrame.

    The first column of each frame is used as the x axis and every other
    column becomes one trace. Only the first value column of each frame is
    shown initially; the others start as legend-only entries.

    Args:
        dfs: DataFrames to plot.
        ids: One identifier per DataFrame, used in the trace names.

    Returns:
        The populated ``plotly.graph_objects.Figure``.

    Raises:
        ImportError: If plotly is not installed.
    """
    _require_backend(go, "plotly")

    fig = go.Figure()

    for df, df_id in zip(dfs, ids):
        if df.shape[1] < 2:
            # Nothing to draw without at least one value column.
            continue
        x_values = df[df.columns[0]]
        for i, column in enumerate(df.columns[1:]):
            fig.add_trace(go.Scatter(
                x=x_values,
                y=df[column],
                mode='lines',
                name=f"item_{df_id} - {column}",
                visible=True if i == 0 else 'legendonly'
            ))

    # Configure the legend: horizontal, centred below the plot area.
    fig.update_layout(
        legend=dict(
            title="Variables",
            orientation="h",
            yanchor="top",
            y=-0.2,
            xanchor="center",
            x=0.5
        ),
        xaxis_title='Time',
        yaxis_title='Values'
    )
    return fig


def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
    """Compute summary statistics for a list of datasets.

    For each frame the first column is skipped and mean, standard deviation,
    maximum and minimum are computed for every remaining column.

    Args:
        dfs: DataFrames to summarise.
        ids: One identifier per DataFrame, used as prefix in ``Variables``.

    Returns:
        A DataFrame with columns ``Variables``, ``mean``, ``std``, ``max`` and
        ``min``; numeric values are rounded to two decimals. An empty frame
        with those columns is returned when no datasets are given.
    """
    stats_list = []
    for df, df_id in zip(dfs, ids):
        df_values = df.iloc[:, 1:]
        stats_list.append(pd.DataFrame({
            'Variables': [f"{df_id}_{col}" for col in df_values.columns],
            'mean': df_values.mean().values,
            'std': df_values.std().values,
            'max': df_values.max().values,
            'min': df_values.min().values
        }))

    if not stats_list:
        # pd.concat raises on an empty list; return an empty, well-formed frame.
        return pd.DataFrame(columns=_STAT_COLUMNS)

    # DataFrame.round only touches numeric columns, so 'Variables' is kept as is.
    return pd.concat(stats_list, ignore_index=True).round(2)


def _series_length(target):
    """Return the number of time steps stored in one ``target`` cell."""
    if len(target) == 0:
        return 0
    first = target[0]
    # Nested targets hold one array per variate; flat targets are the series.
    return len(first) if isinstance(first, np.ndarray) else len(target)


def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
    """Clean up a dataset frame by expanding its nested ``target`` column.

    A ``timestamp`` column is built from each row's ``start`` and ``freq``
    with one entry per time step of ``target``, the requested rows of
    ``target`` are expanded into ``target_<index>`` columns, and the
    ``start``, ``freq``, ``target`` and (if present)
    ``past_feat_dynamic_real`` columns are dropped.

    Args:
        df: Frame with ``start``, ``freq`` and ``target`` columns. It is
            modified in place.
        rows_to_include: Row indices of ``target`` to keep; processed in
            ascending order.

    Returns:
        The same ``df`` object after clean-up.
    """
    rows_to_include = sorted(rows_to_include)

    timestamps = [
        pd.date_range(
            start=start, periods=_series_length(target), freq=freq
        ).to_pydatetime().tolist()
        for start, freq, target in zip(df['start'], df['freq'], df['target'])
    ]
    # An explicit object Series keeps each list as a single cell value.
    df['timestamp'] = pd.Series(timestamps, index=df.index, dtype=object)

    df = flatten_ndarray_column(df, 'target', rows_to_include)

    # Drop the raw columns now that they have been expanded.
    df.drop(columns=['start', 'freq', 'target'], inplace=True)
    if 'past_feat_dynamic_real' in df.columns:
        df.drop(columns=['past_feat_dynamic_real'], inplace=True)
    return df


if __name__ == '__main__':
    # Build test data.
    data1 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value1': [10, 15, 20],
        'Value2': [20, 25, 30]
    }

    data2 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value3': [5, 10, 15],
        'Value4': [15, 20, 25]
    }

    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)

    # Convert the time columns to datetime.
    df1['Time'] = pd.to_datetime(df1['Time'])
    df2['Time'] = pd.to_datetime(df2['Time'])

    # create_plot expects a list of frames and a matching list of ids.
    fig = create_plot([df1, df2], ['1', '2'])

    # Show the chart.
    fig.show()