ckcl commited on
Commit
f370b67
·
verified ·
1 Parent(s): 92edfc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -53
app.py CHANGED
@@ -1,62 +1,32 @@
1
  import torch
2
- import torch.nn as nn
3
- from transformers import AutoConfig
4
  import gradio as gr
5
- import numpy as np
6
- from sklearn.preprocessing import StandardScaler
7
 
8
- # 定义自定义模型类
9
- class CustomTransformerModel(nn.Module):
10
- def __init__(self, config):
11
- super(CustomTransformerModel, self).__init__()
12
- self.embedding = nn.Linear(config.input_dim, config.model_dim)
13
- self.encoder_layer = nn.TransformerEncoderLayer(d_model=config.model_dim, nhead=config.num_heads, batch_first=True)
14
- self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=config.num_layers)
15
- self.fc = nn.Linear(config.model_dim, config.output_dim)
16
-
17
- def forward(self, src):
18
- src = self.embedding(src)
19
- output = self.transformer_encoder(src)
20
- output = self.fc(output[:, -1, :])
21
- return output
22
-
23
- # 加载模型配置
24
- config = AutoConfig.from_pretrained("ckcl/mexc_price_model", config_file_name="setting.json")
25
 
26
- # 创建模型实例并加载权重
27
- model = CustomTransformerModel(config)
28
- model.load_state_dict(torch.load("model_repo/mexc_price.pth"))
29
- model.eval() # 设置模型为评估模式
30
 
31
  # 定义预测函数
32
- def predict(time, open_price, close_price, high_price, low_price, vol, amount, real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff):
33
- new_data = np.array([[time, open_price, close_price, high_price, low_price, vol, amount, real_open, real_close, real_high, real_low, MA_5, MA_10, vol_diff]])
34
- scaler = StandardScaler()
35
- new_data_scaled = scaler.fit_transform(new_data)
36
- input_tensor = torch.tensor(new_data_scaled, dtype=torch.float32).unsqueeze(1)
37
  with torch.no_grad():
38
- prediction = model(input_tensor)
39
- predicted_value = prediction.squeeze().item()
40
- return predicted_value
41
-
42
- # 定义 Gradio 接口
43
- inputs = [
44
- gr.inputs.Number(label="Time"),
45
- gr.inputs.Number(label="Open Price"),
46
- gr.inputs.Number(label="Close Price"),
47
- gr.inputs.Number(label="High Price"),
48
- gr.inputs.Number(label="Low Price"),
49
- gr.inputs.Number(label="Volume"),
50
- gr.inputs.Number(label="Amount"),
51
- gr.inputs.Number(label="Real Open"),
52
- gr.inputs.Number(label="Real Close"),
53
- gr.inputs.Number(label="Real High"),
54
- gr.inputs.Number(label="Real Low"),
55
- gr.inputs.Number(label="MA 5"),
56
- gr.inputs.Number(label="MA 10"),
57
- gr.inputs.Number(label="Volume Diff")
58
- ]
59
 
60
- outputs = gr.outputs.Number(label="Predicted Price")
 
61
 
62
- gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title="MEXC Price Prediction", description="Predict MEXC contract price using custom Transformer model").launch()
 
 
1
  import torch
2
+ from transformers import AutoAdapterModel, AutoTokenizer
3
+ from datasets import load_dataset
4
  import gradio as gr
 
 
5
 
6
+ # 加载模型和分词器
7
+ model_name = "ckcl/mexc_price_model"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoAdapterModel.from_pretrained(model_name)
10
+ model.load_adapter(model_name, set_active=True)
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # 加载数据集
13
+ ds = load_dataset("ckcl/BTC_USDT_dataset")
 
 
14
 
15
  # 定义预测函数
16
+ def predict(input_text):
17
+ # 处理输入
18
+ inputs = tokenizer(input_text, return_tensors="pt")
19
+
20
+ # 进行预测
21
  with torch.no_grad():
22
+ outputs = model(**inputs)
23
+
24
+ # 获取预测结果
25
+ predictions = torch.argmax(outputs.logits, dim=-1)
26
+ return str(predictions.item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # 创建 Gradio 界面
29
+ iface = gr.Interface(fn=predict, inputs="text", outputs="text", title="MEXC Contract Prediction", description="Predict contract prices for MEXC.")
30
 
31
+ # 启动应用
32
+ iface.launch()