lotsa_explorer / utils.py
Liu Yiwen
修复了小数位保留的bug
beccaa7
#####################################################
# Utils
#####################################################
# 本文件包含了一些用于数据处理和绘图的实用函数。
import base64
from io import BytesIO
from matplotlib import pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import numpy as np
def ndarray_to_base64(ndarray):
    """
    Plot a one-dimensional ``np.ndarray`` and return the image as a Base64 data URI.

    Args:
        ndarray: Values to draw as a line plot; handed straight to Matplotlib's
            ``plot``.

    Returns:
        A string of the form ``data:image/png;base64,<payload>`` containing the
        rendered PNG.
    """
    # Work on an explicit figure rather than pyplot's implicit "current figure",
    # so unrelated open figures are never drawn on or closed by this function.
    fig = plt.figure(figsize=(8, 4))
    try:
        ax = fig.add_subplot(111)
        ax.plot(ndarray)
        ax.set_title("Vector Plot")
        ax.set_xlabel("Index")
        ax.set_ylabel("Value")
        fig.tight_layout()

        # Render the image into an in-memory byte stream.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # pyplot keeps a reference to every figure until it is closed, so close
        # it even when plotting or rendering raised.
        plt.close(fig)

    # getvalue() returns the whole buffer regardless of the stream position.
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}"
def flatten_ndarray_column(df, column_name, rows_to_include):
    """
    Expand a column of nested arrays into one column per requested row.

    For every ``index`` in ``rows_to_include`` a column named
    ``f"{column_name}_{index}"`` is added to ``df``; each of its cells holds row
    ``index`` of the corresponding cell in ``column_name``. A cell that is a 1-D
    non-object array is treated as a single row (row 0). When a cell does not
    contain the requested row, the new cell is ``None``.

    Args:
        df: DataFrame to extend. It is modified in place.
        column_name: Name of the column whose cells hold the nested arrays.
        rows_to_include: Row indices to extract from every cell.

    Returns:
        The same ``df`` object, with the extracted columns added.
    """
    def pick_row(cell, index):
        # Return row `index` of `cell`, or None when the cell has no such row.
        if isinstance(cell, np.ndarray):
            if cell.ndim == 0:
                return None
            if cell.ndim == 1 and cell.dtype != object:
                # A flat numeric vector is one series: expose it as row 0.
                cell = np.expand_dims(cell, axis=0)
        elif not isinstance(cell, (list, tuple)):
            # Scalars, NaN placeholders, etc. carry no rows at all.
            return None
        if -len(cell) <= index < len(cell):
            return cell[index]
        return None

    cells = df[column_name]
    # Look each row up by its own index (not by its position in
    # rows_to_include) so the column label always matches the row it contains.
    for index in rows_to_include:
        df[f'{column_name}_{index}'] = cells.apply(pick_row, args=(index,))
    return df
def create_plot(dfs:list[pd.DataFrame], ids:list[str]):
    """
    Build one line chart containing a trace for every value column of every
    DataFrame passed in.

    The first column of each DataFrame supplies the x values; every remaining
    column becomes its own trace named ``item_<id> - <column>``.

    Args:
        dfs: DataFrames to plot.
        ids: Identifiers paired with ``dfs`` by position and used in trace names.

    Returns:
        The populated ``go.Figure``.
    """
    figure = go.Figure()

    for frame, frame_id in zip(dfs, ids):
        value_columns = frame.columns[1:]
        for position, column in enumerate(value_columns):
            # Only the first value column of each frame starts out drawn; the
            # others stay in the legend until the user toggles them on.
            visibility = 'legendonly' if position else True
            trace = go.Scatter(
                x=frame[frame.columns[0]],
                y=frame[column],
                mode='lines',
                name=f"item_{frame_id} - {column}",
                visible=visibility,
            )
            figure.add_trace(trace)

    # Horizontal legend, centred and placed below the plotting area.
    legend_settings = dict(
        title="Variables",
        orientation="h",
        yanchor="top",
        y=-0.2,
        xanchor="center",
        x=0.5,
    )
    figure.update_layout(
        legend=legend_settings,
        xaxis_title='Time',
        yaxis_title='Values',
    )
    return figure
def create_statistic(dfs: list[pd.DataFrame], ids: list[str]):
    """
    Compute summary statistics for a list of datasets.

    For each DataFrame the first column is skipped and the mean, standard
    deviation, maximum and minimum of every remaining column are computed.

    Args:
        dfs: DataFrames to summarise.
        ids: Identifiers paired with ``dfs`` by position; each is used as the
            prefix of the ``Variables`` label (``<id>_<column>``).

    Returns:
        A DataFrame with the columns ``Variables``, ``mean``, ``std``, ``max``
        and ``min``, one row per value column, with int/float values rounded
        to two decimals. An empty DataFrame with those columns is returned when
        there is nothing to summarise.
    """
    stat_columns = ['Variables', 'mean', 'std', 'max', 'min']
    stats_list = []
    for df, df_id in zip(dfs, ids):
        df_values = df.iloc[:, 1:]
        stats_list.append(pd.DataFrame({
            'Variables': [f"{df_id}_{col}" for col in df_values.columns],
            'mean': df_values.mean().values,
            'std': df_values.std().values,
            'max': df_values.max().values,
            'min': df_values.min().values,
        }))

    # pd.concat raises on an empty list, so short-circuit with an empty table.
    if not stats_list:
        return pd.DataFrame(columns=stat_columns)

    combined_stats_df = pd.concat(stats_list, ignore_index=True)

    # DataFrame.applymap was renamed to DataFrame.map in pandas 2.1 and the old
    # name is deprecated; prefer the new one and fall back on older releases.
    elementwise = getattr(combined_stats_df, 'map', None)
    if elementwise is None:
        elementwise = combined_stats_df.applymap
    return elementwise(lambda x: round(x, 2) if isinstance(x, (int, float)) else x)
def clean_up_df(df: pd.DataFrame, rows_to_include: list[int]) -> pd.DataFrame:
    """
    Tidy a dataset: add a ``timestamp`` column and expand the nested ``target``
    column into separate columns.

    The DataFrame is modified in place: ``timestamp`` is added, ``target`` is
    expanded through ``flatten_ndarray_column``, and the ``start``, ``freq`` and
    ``target`` columns (plus ``past_feat_dynamic_real`` when present) are dropped.

    Args:
        df: DataFrame with ``start``, ``freq`` and ``target`` columns.
        rows_to_include: Row indices of ``target`` to keep; processed in sorted
            order.

    Returns:
        The cleaned DataFrame.
    """
    def build_timestamps(row):
        # One timestamp per observation of this row's series.
        start = row['start']
        target = row['target']
        first = target[0]
        # When the first element is itself an array, the length of that inner
        # array is used as the number of periods; otherwise len(target) is.
        if isinstance(first, np.ndarray):
            n_periods = len(first)
        else:
            n_periods = len(target)
        stamps = pd.date_range(start=start, periods=n_periods, freq=row['freq'])
        return stamps.to_pydatetime().tolist()

    ordered_rows = sorted(rows_to_include)
    df['timestamp'] = df.apply(build_timestamps, axis=1)
    df = flatten_ndarray_column(df, 'target', ordered_rows)

    # The source columns are no longer needed once expanded.
    obsolete = ['start', 'freq', 'target']
    if 'past_feat_dynamic_real' in df.columns:
        obsolete.append('past_feat_dynamic_real')
    df.drop(columns=obsolete, inplace=True)
    return df
if __name__ == '__main__':
    # Manual smoke test: build two small sample datasets and plot them together.
    data1 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value1': [10, 15, 20],
        'Value2': [20, 25, 30],
    }
    data2 = {
        'Time': ['2023-01-01', '2023-01-02', '2023-01-03'],
        'Value3': [5, 10, 15],
        'Value4': [15, 20, 25],
    }
    df1 = pd.DataFrame(data1)
    df2 = pd.DataFrame(data2)

    # Convert the time column to real datetimes.
    df1['Time'] = pd.to_datetime(df1['Time'])
    df2['Time'] = pd.to_datetime(df2['Time'])

    # create_plot expects a list of DataFrames and a parallel list of ids;
    # passing the two frames positionally would make it iterate column labels.
    fig = create_plot([df1, df2], ['1', '2'])

    # Display the chart.
    fig.show()