# Hugging Face Spaces status: Runtime error
import torch | |
import torch.nn as nn | |
from transformers import AutoConfig | |
import gradio as gr | |
import numpy as np | |
from sklearn.preprocessing import StandardScaler | |
# Custom model: linear embedding -> Transformer encoder -> linear head on the last step.
class CustomTransformerModel(nn.Module):
    """Transformer-encoder regressor over a sequence of feature vectors.

    Reads ``input_dim``, ``model_dim``, ``num_heads``, ``num_layers`` and
    ``output_dim`` from ``config``.

    The submodule attribute names (and the order they are assigned in) must
    not change: the state_dict keys of a saved checkpoint are derived from them.
    """

    def __init__(self, config):
        super().__init__()
        width = config.model_dim
        # Project each raw feature vector to the encoder width.
        self.embedding = nn.Linear(config.input_dim, width)
        # Template layer for the encoder stack. Because it is also stored as an
        # attribute, it is registered as a submodule of its own and its weights
        # appear under "encoder_layer.*" in the state_dict.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=config.num_heads,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=config.num_layers,
        )
        self.fc = nn.Linear(width, config.output_dim)

    def forward(self, src):
        """Map ``src`` of shape (batch, seq_len, input_dim) to (batch, output_dim)."""
        encoded = self.transformer_encoder(self.embedding(src))
        # Only the representation of the final sequence position feeds the head.
        last_step = encoded[:, -1, :]
        return self.fc(last_step)
# Load the model configuration.
# NOTE(review): `config_file_name` does not look like a keyword that
# AutoConfig.from_pretrained uses to choose which file to read, so this most
# likely loads the repo's default config.json rather than setting.json, and
# AutoConfig additionally expects a recognised `model_type` in that file.
# Confirm against the contents of the ckcl/mexc_price_model repo.
config = AutoConfig.from_pretrained("ckcl/mexc_price_model", config_file_name="setting.json")

# Create the model instance and load the trained weights.
model = CustomTransformerModel(config)
# map_location="cpu" lets a checkpoint saved from a GPU process load on a
# CPU-only host; load_state_dict copies the values into the model's (CPU)
# parameters either way, so the resulting model is the same.
state_dict = torch.load("model_repo/mexc_price.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # evaluation mode, so dropout layers are inactive during inference
# Prediction function called by the Gradio interface.
def predict(time, open_price, close_price, high_price, low_price, vol, amount,
            real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff,
            *, scaler=None):
    """Run the model on a single row of 14 features and return one number.

    Args:
        time, open_price, ..., vol_diff: The 14 feature values, in the column
            order the model was trained on (assumes ``config.input_dim == 14``
            — TODO confirm).
        scaler: Optional *already fitted* scaler exposing ``transform`` — e.g.
            the StandardScaler fitted on the training data. When ``None`` the
            features are passed to the model unscaled.

    Returns:
        The model output as a Python float. ``.item()`` requires the output to
        contain exactly one element (i.e. ``config.output_dim == 1``).
    """
    features = np.array(
        [[time, open_price, close_price, high_price, low_price, vol, amount,
          real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff]],
        dtype=np.float32,
    )
    # Never fit a scaler here: fitting StandardScaler on this single row makes
    # its mean equal to the row itself, so every feature standardizes to 0 and
    # the prediction ignores the inputs. Only reuse training-time statistics.
    if scaler is not None:
        features = scaler.transform(features)
    # Shape (batch=1, seq_len=1, n_features) for the batch_first encoder.
    input_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(1)
    with torch.no_grad():
        prediction = model(input_tensor)
    return prediction.squeeze().item()
# Gradio interface definition.
# Labels for the 14 numeric inputs, in the same order as predict()'s parameters.
_INPUT_LABELS = [
    "Time",
    "Open Price",
    "Close Price",
    "High Price",
    "Low Price",
    "Volume",
    "Amount",
    "Real Open",
    "Real Close",
    "Real High",
    "Real Low",
    "MA 5",
    "MA 10",
    "Volume Diff",
]
# gr.inputs / gr.outputs were removed in Gradio 4; gr.Number works both as an
# input and as an output component.
inputs = [gr.Number(label=label) for label in _INPUT_LABELS]
outputs = gr.Number(label="Predicted Price")
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="MEXC Price Prediction",
    description="Predict MEXC contract price using custom Transformer model",
)
demo.launch()