import warnings

import torch
import torch.nn as nn
from transformers import AutoConfig
import gradio as gr
import numpy as np
from sklearn.preprocessing import StandardScaler

# Input feature order. predict() stacks its arguments in exactly this order and the
# Gradio input components are generated from this list, so the two cannot drift apart.
# NOTE(review): assumed to match the column order used when the model was trained — confirm.
FEATURE_LABELS = [
    "Time",
    "Open Price",
    "Close Price",
    "High Price",
    "Low Price",
    "Volume",
    "Amount",
    "Real Open",
    "Real Close",
    "Real High",
    "Real Low",
    "MA 5",
    "MA 10",
    "Volume Diff",
]


# Custom model class
class CustomTransformerModel(nn.Module):
    """Transformer-encoder regressor: linear embedding -> encoder stack -> linear head.

    The head reads only the last position of the encoded sequence.
    """

    def __init__(self, config):
        """Build the layers.

        Args:
            config: object exposing ``input_dim``, ``model_dim``, ``num_heads``,
                ``num_layers`` and ``output_dim``.
        """
        super().__init__()
        self.embedding = nn.Linear(config.input_dim, config.model_dim)
        # nn.TransformerEncoder deep-copies this layer ``num_layers`` times, so the
        # template's own parameters are not used in forward(). It must nevertheless
        # stay a registered submodule: a checkpoint saved from this class contains
        # ``encoder_layer.*`` keys, and dropping it would break strict load_state_dict.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model_dim, nhead=config.num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=config.num_layers
        )
        self.fc = nn.Linear(config.model_dim, config.output_dim)

    def forward(self, src):
        """Map ``src`` of shape (batch, seq, input_dim) to (batch, output_dim)."""
        src = self.embedding(src)
        output = self.transformer_encoder(src)
        # Regress from the representation of the last sequence position.
        return self.fc(output[:, -1, :])


def _load_feature_stats(cfg):
    """Return ``(mean, scale)`` arrays used to standardize inputs, or ``(None, None)``.

    The original code fitted a fresh StandardScaler on each single request row, which
    always produces an all-zero vector (mean == row, zero variance -> scale 1), making
    every prediction identical. Inference must instead reuse the training statistics.

    NOTE(review): the ``scaler_mean`` / ``scaler_scale`` config keys are an assumption —
    export them from the training pipeline (e.g. the fitted StandardScaler's ``mean_``
    and ``scale_``) so inference matches training normalization.
    """
    mean = getattr(cfg, "scaler_mean", None)
    scale = getattr(cfg, "scaler_scale", None)
    if mean is None or scale is None:
        warnings.warn(
            "No training normalization statistics (scaler_mean/scaler_scale) found in "
            "the model config; features are passed to the model unscaled.",
            RuntimeWarning,
        )
        return None, None
    mean = np.asarray(mean, dtype=np.float64)
    scale = np.asarray(scale, dtype=np.float64)
    # Mirror StandardScaler: zero-variance features are left unscaled (scale 1).
    scale = np.where(scale == 0, 1.0, scale)
    return mean, scale


# Load the model configuration.
# NOTE(review): confirm AutoConfig actually honours ``config_file_name`` for this repo;
# otherwise the default config.json is read.
config = AutoConfig.from_pretrained("ckcl/mexc_price_model", config_file_name="setting.json")

if config.input_dim != len(FEATURE_LABELS):
    raise ValueError(
        f"Model expects {config.input_dim} input features, "
        f"but the app supplies {len(FEATURE_LABELS)}."
    )

# Build the model and load its weights.
model = CustomTransformerModel(config)
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
# NOTE(review): "ckcl/mexc_price.pth" is resolved as a local file path, not a Hub repo id.
# torch.load unpickles the file — only load trusted checkpoints (weights_only=True is
# available on torch >= 1.13).
model.load_state_dict(torch.load("ckcl/mexc_price.pth", map_location="cpu"))
model.eval()  # Inference mode: disables dropout.

_FEATURE_MEAN, _FEATURE_SCALE = _load_feature_stats(config)


def predict(time, open_price, close_price, high_price, low_price, vol, amount,
            real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff):
    """Predict a price from one row of market features.

    Arguments are the features in ``FEATURE_LABELS`` order (Gradio passes them
    positionally). Returns the model output as a Python float.

    Raises:
        gr.Error: if any input field is empty.
    """
    row = [time, open_price, close_price, high_price, low_price, vol, amount,
           real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff]
    if any(value is None for value in row):
        raise gr.Error("All input fields are required.")

    features = np.array([row], dtype=np.float64)  # (1, n_features)
    if _FEATURE_MEAN is not None:
        # Apply the fixed training statistics; never re-fit on the request itself.
        features = (features - _FEATURE_MEAN) / _FEATURE_SCALE

    # Add a sequence axis of length 1 -> (batch=1, seq=1, n_features).
    input_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(1)
    with torch.no_grad():
        prediction = model(input_tensor)
    # NOTE(review): .item() assumes config.output_dim == 1.
    return prediction.squeeze().item()


# Gradio interface
inputs = [gr.Number(label=label) for label in FEATURE_LABELS]
outputs = gr.Number(label="Predicted Price")

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="MEXC Price Prediction",
    description="Predict MEXC contract price using custom Transformer model",
)

if __name__ == "__main__":
    demo.launch()