# mexc_prediction / app.py
# ckcl's picture
# Update app.py
# 92edfc3 verified
# raw
# history blame
# 2.62 kB
import torch
import torch.nn as nn
from transformers import AutoConfig
import gradio as gr
import numpy as np
from sklearn.preprocessing import StandardScaler
# Custom model class definition.
class CustomTransformerModel(nn.Module):
    """Transformer encoder that maps a feature sequence to one output vector.

    The input is projected to the model width, passed through a stack of
    Transformer encoder layers, and the encoding of the last sequence
    position is projected to the output width.

    Args:
        config: Object exposing ``input_dim``, ``model_dim``, ``num_heads``,
            ``num_layers`` and ``output_dim`` attributes.
    """

    def __init__(self, config):
        super().__init__()
        in_features = config.input_dim
        hidden = config.model_dim

        self.embedding = nn.Linear(in_features, hidden)
        # Stored as an attribute (not a local) so it is registered as a
        # submodule and its parameters appear under "encoder_layer.*" in
        # the state_dict.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=config.num_heads,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=config.num_layers,
        )
        self.fc = nn.Linear(hidden, config.output_dim)

    def forward(self, src):
        """Run the model on ``src`` and return the last-position projection.

        Args:
            src: 3-D tensor laid out batch-first (the encoder layer is built
                with ``batch_first=True``); the last axis has ``input_dim``
                features.

        Returns:
            Tensor with one row per batch element and ``output_dim`` columns.
        """
        embedded = self.embedding(src)
        encoded = self.transformer_encoder(embedded)
        # Only the final sequence position feeds the output head.
        last_step = encoded[:, -1, :]
        return self.fc(last_step)
# Load the model configuration.
# NOTE(review): confirm that `config_file_name` is actually honoured by
# AutoConfig.from_pretrained and that the resulting config exposes
# input_dim / model_dim / num_heads / num_layers / output_dim, which
# CustomTransformerModel reads.
config = AutoConfig.from_pretrained("ckcl/mexc_price_model", config_file_name="setting.json")

# Create the model instance and load the trained weights.
model = CustomTransformerModel(config)
# map_location="cpu": the model is never moved off the CPU and predict()
# builds CPU tensors, so remap any CUDA-saved tensors here; without it a
# checkpoint saved on a GPU fails to load on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("model_repo/mexc_price.pth", map_location="cpu"))
model.eval()  # Put the model in evaluation mode.
# Prediction function.
def predict(time, open_price, close_price, high_price, low_price, vol, amount, real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff):
    """Return the model's scalar prediction for one row of 14 features.

    The arguments are packed, in signature order, into a single-row array,
    standardised, reshaped to a length-1 sequence and fed to the
    module-level ``model``.

    Returns:
        The prediction as a Python number (``.item()`` of the squeezed
        output, so the model must emit exactly one value for this input).
    """
    features = [
        time, open_price, close_price, high_price, low_price, vol, amount,
        real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff,
    ]
    row = np.array([features])

    # NOTE(review): the scaler is fitted on this single row only, so each
    # column's mean equals its one value and the scaled row is expected to
    # be all zeros -- the model input would then not depend on the user's
    # numbers. The scaler fitted at training time should presumably be
    # loaded and used with transform() instead; it is not available here.
    scaled_row = StandardScaler().fit_transform(row)

    # (1, 14) -> (1, 1, 14): add a sequence axis of length one.
    batch = torch.tensor(scaled_row, dtype=torch.float32).unsqueeze(1)

    with torch.no_grad():
        return model(batch).squeeze().item()
# Gradio interface definition.
# Uses the top-level gr.Number component: the gr.inputs / gr.outputs
# namespaces are deprecated in Gradio 3.x and removed in later releases.
# The label order must match the positional parameters of predict().
inputs = [
    gr.Number(label=label)
    for label in (
        "Time",
        "Open Price",
        "Close Price",
        "High Price",
        "Low Price",
        "Volume",
        "Amount",
        "Real Open",
        "Real Close",
        "Real High",
        "Real Low",
        "MA 5",
        "MA 10",
        "Volume Diff",
    )
]
outputs = gr.Number(label="Predicted Price")
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="MEXC Price Prediction",
    description="Predict MEXC contract price using custom Transformer model",
).launch()