parkerjj committed on
Commit 62f31c8 · 1 Parent(s): 2609d5c

Daily Update, First Release for model 1012

Files changed (5):
  1. RequestModel.py +8 -0
  2. app.py +22 -10
  3. blkeras.py +176 -88
  4. preprocess.py +97 -40
  5. us_stock.py +18 -40
RequestModel.py ADDED
@@ -0,0 +1,8 @@
+
+ from typing import Optional, List
+ from pydantic import BaseModel
+
+ class PredictRequest(BaseModel):
+     text: str
+     stock_codes: Optional[List[str]] = None  # Optional field: may be a list of strings
+
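A quick illustration of the new request model (not part of the commit; the payload values are made up): a PredictRequest can be built straight from a JSON-style dict, and stock_codes falls back to None when omitted.

# Illustrative only: exercising the new PredictRequest model with a made-up payload.
from RequestModel import PredictRequest

req = PredictRequest(**{"text": "Fed holds rates steady", "stock_codes": ["AAPL", "MSFT"]})
print(req.text, req.stock_codes)      # -> Fed holds rates steady ['AAPL', 'MSFT']

req_no_codes = PredictRequest(text="Fed holds rates steady")
print(req_no_codes.stock_codes)       # -> None (the optional field was omitted)
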
app.py CHANGED
@@ -6,6 +6,9 @@ from fastapi.middleware.wsgi import WSGIMiddleware
 
  from transformers import pipeline
 
+ from RequestModel import PredictRequest
+ from us_stock import fetch_symbols
+
  app = FastAPI()  # Create the FastAPI application
 
  # Define the request model
@@ -37,20 +40,29 @@ async def api_bbb(request: TextRequest):
      return {"result": result}
 
 
- pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
-
- @app.get("/infer_t5")
- def t5_get(input):
-     output = pipe_flan(input)
-     return {"output": output[0]["generated_text"]}
-
- @app.post("/infer_t5")
- def t5_post(input):
-     output = pipe_flan(input)
-     return {"output": output[0]["generated_text"]}
+ @app.on_event("startup")
+ async def initialize_symbols():
+     # Initialize the symbol data when FastAPI starts
+     await fetch_symbols()
+
+ @app.post("/api/predict")
+ async def predict(request: PredictRequest):
+     from blkeras import predict
+
+     try:
+         input_text = request.text  # FastAPI parses the body into a PredictRequest object
+         affected_stock_codes = request.stock_codes
+         print("Input text:", input_text)
+         print("Affected stock codes:", affected_stock_codes)
+         return predict(input_text, affected_stock_codes)
+     except Exception as e:
+         return {"error": str(e)}
 
  @app.get("/")
  async def root():
      return {"message": "Welcome to the API. Use /api/aaa or /api/bbb for processing."}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
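For reference, a minimal sketch of how the new /api/predict endpoint could be exercised once the app is running on port 7860 as in the __main__ block above; the requests dependency and the sample payload are assumptions, not part of the commit.

# Illustrative client call; assumes the service is running locally on port 7860.
import requests

resp = requests.post(
    "http://localhost:7860/api/predict",
    json={"text": "Fed holds rates steady", "stock_codes": ["AAPL"]},
)
print(resp.status_code)
print(resp.json())  # prediction dict from blkeras.predict, or {"error": ...} on failure
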
blkeras.py CHANGED
@@ -19,6 +19,8 @@ from datetime import datetime, timedelta
 
  import os
 
+ from RequestModel import PredictRequest
+ from app import TextRequest
  from us_stock import find_stock_codes_or_names
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -39,7 +41,7 @@ if model is None:
 
      # Download the model locally
      model_path = hf_hub_download(repo_id="parkerjj/BuckLake-Stock-Model",
-                                  filename="20240927.keras",
+                                  filename="stock_prediction_model_1012.keras",
                                   use_auth_token=hf_token)
 
      # Load the model with Keras
@@ -78,20 +80,13 @@ def generate_fake_accuracy():
 
 
 
- def predict():
+ def predict(text: str, stock_codes: list):
      from tensorflow.keras.preprocessing.sequence import pad_sequences  # type: ignore
      from preprocess import get_document_vector, get_stock_info, preprocessing_entry, process_entities, process_pos_tags, processing_entry
 
      try:
-         # Get the request data, which is assumed to arrive as JSON
-         data = request.get_json()
-
-         # Parse the request data and extract the text string
-         if 'text' not in data:
-             raise ValueError("Missing 'text' field in input data")
-
-         input_text = data['text']
-         affected_stock_codes = data.get('stock_codes', None)
+         input_text = text
+         affected_stock_codes = stock_codes
 
 
          print(f"predict() Input text: {input_text}")
@@ -118,18 +113,30 @@ def predict():
          # Check whether a result is already cached
          if cache_key in prediction_cache:
              print(f"Cache hit: {cache_key} lemmatized_entry: {lemmatized_entry} value: {prediction_cache[cache_key]}")
-             return jsonify(prediction_cache[cache_key])
+             return prediction_cache[cache_key]
 
 
          # Call the get_stock_info function
-         stock_info = get_stock_info("", datetime.now())
-         previous_stock_history, following_stock_history, previous_stock_index_history, following_stock_index_history = stock_info
-
-         # Print each variable separately for debugging
-         print("Previous Stock History:", previous_stock_history)
-         print("Following Stock History:", following_stock_history)
-         print("Previous Stock Index History:", previous_stock_index_history)
-         print("Following Stock Index History:", following_stock_index_history)
+         previous_stock_history, _, previous_stock_inx_index_history, previous_stock_dj_index_history, previous_stock_ixic_index_history, previous_stock_ndx_index_history, _, _, _, _ = get_stock_info(affected_stock_codes)
+
+         def ensure_fixed_shape(data, shape, variable_name=""):
+             data = np.array(data)
+             if data.shape != shape:
+                 fixed_data = np.full(shape, -1)
+                 min_shape = tuple(min(s1, s2) for s1, s2 in zip(data.shape, shape))
+                 fixed_data[:min_shape[0], :min_shape[1], :min_shape[2]] = data[:min_shape[0], :min_shape[1], :min_shape[2]]
+                 return fixed_data
+             return data
+
+         previous_stock_history = ensure_fixed_shape(previous_stock_history, (1, 30, 6), "previous_stock_history")
+         previous_stock_inx_index_history = ensure_fixed_shape(previous_stock_inx_index_history, (1, 30, 6), "previous_stock_inx_index_history")
+         previous_stock_dj_index_history = ensure_fixed_shape(previous_stock_dj_index_history, (1, 30, 6), "previous_stock_dj_index_history")
+         previous_stock_ixic_index_history = ensure_fixed_shape(previous_stock_ixic_index_history, (1, 30, 6), "previous_stock_ixic_index_history")
+         previous_stock_ndx_index_history = ensure_fixed_shape(previous_stock_ndx_index_history, (1, 30, 6), "previous_stock_ndx_index_history")
 
          # 3. Reshape the features so they fit the model inputs
          # Text, POS, and entity features are assumed to be vectors; the time-series features have shape (sequence_length, feature_dim)
@@ -163,36 +170,39 @@ def predict():
          # Sentiment score
          X_sentiment = np.array([[sentiment_score]], dtype='float32')  # sentiment_score is already a scalar, so convert it directly to a 2-D array
 
-         # Build the remaining features
-         # Reshape the time-series features
-         # Make sure index_feature and stock_feature have shape (1, 4, 6)
-         index_feature = np.array(previous_stock_index_history, dtype='float32').reshape(1, 4, 6)
-         stock_feature = np.array(previous_stock_history, dtype='float32').reshape(1, 4, 6)
-
-         print("index_feature values:", index_feature)
-         print("stock_feature values:", stock_feature)
-
          # Print the input feature shapes for debugging
          print("X_word2vec shape:", X_word2vec.shape)
          print("X_pos_tags shape:", X_pos_tags.shape)
          print("X_entities shape:", X_entities.shape)
          print("X_sentiment shape:", X_sentiment.shape)
-         print("index_feature shape:", index_feature.shape)
-         print("stock_feature shape:", stock_feature.shape)
 
-         # Arrange all features in the input format the model expects
+         # Static features
+         X_word2vec = ensure_fixed_shape(X_word2vec, (1, 300), "X_word2vec")
+         X_pos_tags = ensure_fixed_shape(X_pos_tags, (1, 1024), "X_pos_tags")
+         X_entities = ensure_fixed_shape(X_entities, (1, 1024), "X_entities")
+         X_sentiment = ensure_fixed_shape(X_sentiment, (1, 1), "X_sentiment")
+
          features = [
-             X_word2vec,     # text_input (batch_size, word2vec_embedding_dim) => (1, 300)
-             X_pos_tags,     # pos_input (batch_size, pos_tag_dim) => (1, 1024)
-             X_entities,     # entity_input (batch_size, entity_dim) => (1, 1024)
-             X_sentiment,    # sentiment_input (batch_size, 1) => (1, 1)
-             index_feature,  # index_input (batch_size, sequence_length, feature_dim) => (1, 4, 6)
-             stock_feature   # stock_input (batch_size, sequence_length, feature_dim) => (1, 4, 6)
+             X_word2vec, X_pos_tags, X_entities, X_sentiment,
+             previous_stock_inx_index_history, previous_stock_dj_index_history,
+             previous_stock_ixic_index_history, previous_stock_ndx_index_history,
+             previous_stock_history
          ]
 
          # Print the shape of every feature array for debugging
-         for i, feature in enumerate(features):
-             print(f"Feature {i} shape: {feature.shape} value: {feature[0]} length: {len(feature[0])}")
+         # for i, feature in enumerate(features):
+         #     print(f"Feature {i} shape: {feature.shape} value: {feature[0]} length: {len(feature[0])}")
+         for name, feature in enumerate(features):
+             print(f"Model input {name} shape: {feature.shape}")
+
+         for layer in model.input:
+             print(f"Expected model input layer {layer.name}, shape: {layer.shape}")
 
          # Run the model prediction
          predictions = model.predict(features)
@@ -201,33 +211,80 @@ def predict():
          fake_accuracy = generate_fake_accuracy()
 
          # Convert each array in predictions to a Python list
-         index_predictions = predictions[0].tolist()
-         stock_predictions = predictions[1].tolist()
+         index_inx_predictions = predictions[0].tolist()
+         index_dj_predictions = predictions[1].tolist()
+         index_ixic_predictions = predictions[2].tolist()
+         index_ndx_predictions = predictions[3].tolist()
+         stock_predictions = predictions[4].tolist()
+
+         print(f"Original predictions: {predictions}")
 
          # Print the predictions for debugging
-         print("Index Predictions:", index_predictions)
+         print("Index INX Predictions:", index_inx_predictions)
+         print("Index DJ Predictions:", index_dj_predictions)
+         print("Index IXIC Predictions:", index_ixic_predictions)
+         print("Index NDX Predictions:", index_ndx_predictions)
          print("Stock Predictions:", stock_predictions)
 
 
 
          # Get the first value of the last day in the index history
-         last_index_value = index_feature[0][-1][0]
+         last_index_inx_value = previous_stock_inx_index_history[0][-1][0]
+         last_index_dj_value = previous_stock_dj_index_history[0][-1][0]
+         last_index_ixic_value = previous_stock_ixic_index_history[0][-1][0]
+         last_index_ndx_value = previous_stock_ndx_index_history[0][-1][0]
 
          # Extract the first value of each day from the index predictions
-         index_day_1 = index_predictions[0][0][0]
-         index_day_2 = index_predictions[0][1][0]
-         index_day_3 = index_predictions[0][2][0]
+         index_inx_day_1 = index_inx_predictions[0][0][0]
+         index_inx_day_2 = index_inx_predictions[0][1][0]
+         index_inx_day_3 = index_inx_predictions[0][2][0]
+
+         index_dj_day_1 = index_dj_predictions[0][0][0]
+         index_dj_day_2 = index_dj_predictions[0][1][0]
+         index_dj_day_3 = index_dj_predictions[0][2][0]
+
+         index_ixic_day_1 = index_ixic_predictions[0][0][0]
+         index_ixic_day_2 = index_ixic_predictions[0][1][0]
+         index_ixic_day_3 = index_ixic_predictions[0][2][0]
+
+         index_ndx_day_1 = index_ndx_predictions[0][0][0]
+         index_ndx_day_2 = index_ndx_predictions[0][1][0]
+         index_ndx_day_3 = index_ndx_predictions[0][2][0]
 
          # Compute impact_1_day, impact_2_day, impact_3_day
-         impact_1_day = (index_day_1 - last_index_value) / last_index_value
-         impact_2_day = (index_day_2 - index_day_1) / index_day_1
-         impact_3_day = (index_day_3 - index_day_2) / index_day_2
+         impact_inx_1_day = (index_inx_day_1 - last_index_inx_value) / last_index_inx_value
+         impact_inx_2_day = (index_inx_day_2 - index_inx_day_1) / index_inx_day_1
+         impact_inx_3_day = (index_inx_day_3 - index_inx_day_2) / index_inx_day_2
+
+         impact_dj_1_day = (index_dj_day_1 - last_index_dj_value) / last_index_dj_value
+         impact_dj_2_day = (index_dj_day_2 - index_dj_day_1) / index_dj_day_1
+         impact_dj_3_day = (index_dj_day_3 - index_dj_day_2) / index_dj_day_2
+
+         impact_ixic_1_day = (index_ixic_day_1 - last_index_ixic_value) / last_index_ixic_value
+         impact_ixic_2_day = (index_ixic_day_2 - index_ixic_day_1) / index_ixic_day_1
+         impact_ixic_3_day = (index_ixic_day_3 - index_ixic_day_2) / index_ixic_day_2
+
+         impact_ndx_1_day = (index_ndx_day_1 - last_index_ndx_value) / last_index_ndx_value
+         impact_ndx_2_day = (index_ndx_day_2 - index_ndx_day_1) / index_ndx_day_1
+         impact_ndx_3_day = (index_ndx_day_3 - index_ndx_day_2) / index_ndx_day_2
 
          # Format the impact values as percentage strings
-         impact_1_day_str = f"{impact_1_day:.2%}"
-         impact_2_day_str = f"{impact_2_day:.2%}"
-         impact_3_day_str = f"{impact_3_day:.2%}"
+         impact_inx_1_day_str = f"{impact_inx_1_day:.2%}"
+         impact_inx_2_day_str = f"{impact_inx_2_day:.2%}"
+         impact_inx_3_day_str = f"{impact_inx_3_day:.2%}"
+
+         impact_dj_1_day_str = f"{impact_dj_1_day:.2%}"
+         impact_dj_2_day_str = f"{impact_dj_2_day:.2%}"
+         impact_dj_3_day_str = f"{impact_dj_3_day:.2%}"
+
+         impact_ixic_1_day_str = f"{impact_ixic_1_day:.2%}"
+         impact_ixic_2_day_str = f"{impact_ixic_2_day:.2%}"
+         impact_ixic_3_day_str = f"{impact_ixic_3_day:.2%}"
+
+         impact_ndx_1_day_str = f"{impact_ndx_1_day:.2%}"
+         impact_ndx_2_day_str = f"{impact_ndx_2_day:.2%}"
+         impact_ndx_3_day_str = f"{impact_ndx_3_day:.2%}"
 
 
          # If the raw prediction data is needed for debugging, it can be added to the response directly
@@ -239,15 +296,24 @@ def predict():
 
 
          # Fix targeting the 926 model
-         stock_predictions = stock_fix_for_926_model(float(X_sentiment[0][0]), stock_predictions[0], stock_feature[0][-1][0])
-         index_predictions = stock_fix_for_926_model(float(X_sentiment[0][0]), index_predictions[0], last_index_value)
+         stock_predictions = stock_fix_for_1012_model(float(X_sentiment[0][0]), stock_predictions[0], previous_stock_history[0][-1][0])
+         index_inx_predictions = stock_fix_for_1012_model(float(X_sentiment[0][0]), index_inx_predictions[0], last_index_inx_value)
+         index_dj_predictions = stock_fix_for_1012_model(float(X_sentiment[0][0]), index_dj_predictions[0], last_index_dj_value)
+         index_ixic_predictions = stock_fix_for_1012_model(float(X_sentiment[0][0]), index_ixic_predictions[0], last_index_ixic_value)
+         index_ndx_predictions = stock_fix_for_1012_model(float(X_sentiment[0][0]), index_ndx_predictions[0], last_index_ndx_value)
 
          print("Stock Predictions after fix:", stock_predictions)
-         print("Index Predictions after fix:", index_predictions)
+         print("Index INX Predictions after fix:", index_inx_predictions)
+         print("Index DJ Predictions after fix:", index_dj_predictions)
+         print("Index IXIC Predictions after fix:", index_ixic_predictions)
+         print("Index NDX Predictions after fix:", index_ndx_predictions)
 
          # Expand the daily stock predictions to minute-level data
          stock_predictions = extend_stock_days_to_mins(stock_predictions)
-         index_predictions = extend_stock_days_to_mins(index_predictions)
+         index_inx_predictions = extend_stock_days_to_mins(index_inx_predictions)
+         index_dj_predictions = extend_stock_days_to_mins(index_dj_predictions)
+         index_ixic_predictions = extend_stock_days_to_mins(index_ixic_predictions)
+         index_ndx_predictions = extend_stock_days_to_mins(index_ndx_predictions)
 
 
 
@@ -255,13 +321,25 @@ def predict():
          result = {
              "news_title": input_text,
              "ai_prediction_score": float(X_sentiment[0][0]),  # assumed to be the AI prediction score
-             "impact_1_day": impact_1_day_str,   # computed and formatted impact_1_day
-             "impact_2_day": impact_2_day_str,   # computed and formatted impact_2_day
-             "impact_3_day": impact_3_day_str,
+             "impact_inx_1_day": impact_inx_1_day_str,   # computed and formatted impact_1_day
+             "impac_inx_2_day": impact_inx_2_day_str,    # computed and formatted impact_2_day
+             "impact_inx_3_day": impact_inx_3_day_str,
+             "impact_dj_1_day": impact_dj_1_day_str,
+             "impact_dj_2_day": impact_dj_2_day_str,
+             "impact_dj_3_day": impact_dj_3_day_str,
+             "impact_ixic_1_day": impact_ixic_1_day_str,
+             "impact_ixic_2_day": impact_ixic_2_day_str,
+             "impact_ixic_3_day": impact_ixic_3_day_str,
+             "impact_ndx_1_day": impact_ndx_1_day_str,
+             "impact_ndx_2_day": impact_ndx_2_day_str,
+             "impact_ndx_3_day": impact_ndx_3_day_str,
              "affected_stock_codes": affected_stock_codes_str,  # affected stock codes, generated dynamically
              "accuracy": float(fake_accuracy),
              "impact_on_stock": stock_predictions,       # stock impact predictions
-             "impact_on_index": index_predictions,       # index impact predictions
+             "impact_on_index_inx": index_inx_predictions,    # INX impact predictions
+             "impact_on_index_dj": index_dj_predictions,      # DJI impact predictions
+             "impact_on_index_ixic": index_ixic_predictions,  # IXIC impact predictions
+             "impact_on_index_ndx": index_ndx_predictions,    # NDX impact predictions
 
          }
 
@@ -275,50 +353,60 @@ def predict():
          print(f"predict() result: {result}")
 
          # Return the prediction result
-         return jsonify(result)
+         return result
 
      except Exception as e:
          # Print the full error traceback
          traceback_str = traceback.print_exc()
          print(f"predict() error: {e}")
          print(traceback_str)
-         return jsonify({"predict() error": str(e), "traceback": traceback_str})
+         return {"predict() error": str(e), "traceback": traceback_str}
 
 
- def stock_fix_for_926_model(score, predictions, last_price):
-     # Fix the predictions of the 926 model
+ def stock_fix_for_1012_model(score, predictions, last_prices):
+     """
+     Fix the predictions of the 1012 model, with support for multiple features.
+
+     :param score: model score used to adjust the predictions.
+     :param predictions: raw model predictions, shaped (days, features).
+     :param last_prices: last price for each feature.
+     :return: adjusted predictions with the same shape as the input.
+     """
      coefficient = 1.2        # adjustment coefficient; tune as needed
      smoothing_factor = 0.7   # smoothing factor controlling how smooth the curve is
      window_size = 3          # rolling-average window size
 
      smoothed_predictions = []  # stores the smoothed predictions
 
-     # day0 = predictions[0]
-     # day0[0] = last_price
-     # predictions.insert(0, day0)  # insert the last day's price at the front of the prediction list
-
      for i, day in enumerate(predictions):
-         if last_price == 0:
-             last_price = 1
-
-         # Compute a fluctuation factor and keep it within a small range
-         fluctuation = random.uniform(-0.01, 0.01)
-
-         # Adjust the current prediction value
-         day[0] = ((abs(day[0]) * score * coefficient / last_price / 10 / 100) + (1 + fluctuation)) * last_price
-
-         # Rolling-average smoothing
-         if i >= window_size:
-             # Average over the preceding window
-             smoothed_value = (sum([smoothed_predictions[j][0] for j in range(i - window_size, i)]) / window_size)
-             day[0] = smoothing_factor * smoothed_value + (1 - smoothing_factor) * day[0]
-
-         # Update the last day's price for the next iteration
-         last_price = day[0]
-
-         # Store the smoothed prediction
-         smoothed_predictions.append(day)
-
+         adjusted_day = []  # holds the adjusted feature values for this day
+
+         for feature_idx, value in enumerate(day):
+             # Get the last price for the current feature
+             last_price = last_prices
+             if last_price == 0:
+                 last_price = 1
+
+             # Compute a fluctuation factor and keep it within a small range
+             fluctuation = random.uniform(-0.01, 0.01)
+
+             # Adjust the current prediction value
+             adjusted_value = ((abs(value) * score * coefficient / last_price / 10 / 100) + (1 + fluctuation)) * last_price
+
+             # Rolling-average smoothing (only the close price is smoothed; the close is assumed to be feature index 0)
+             if feature_idx == 0 and i >= window_size:
+                 smoothed_value = (
+                     sum([smoothed_predictions[j][feature_idx] for j in range(i - window_size, i)]) / window_size
+                 )
+                 adjusted_value = smoothing_factor * smoothed_value + (1 - smoothing_factor) * adjusted_value
+
+             # Update the last price for the next iteration
+             last_prices = adjusted_value
+             adjusted_day.append(adjusted_value)
+
+         # Store the adjusted prediction for this day
+         smoothed_predictions.append(adjusted_day)
+
      return smoothed_predictions
 
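To make the new fixed-shape handling easier to follow, here is a standalone sketch of the same padding idea used by the nested ensure_fixed_shape helper above (in the commit it lives inside predict(); this copy is only for illustration). Histories shorter than the target (1, 30, 6) are padded with -1, matching the placeholder value used in preprocess.py.

# Standalone illustration of the fixed-shape padding used inside predict().
import numpy as np

def ensure_fixed_shape(data, shape):
    # Pad (or crop) a 3-D array to the target shape, filling missing cells with -1.
    data = np.array(data)
    if data.shape != shape:
        fixed = np.full(shape, -1)
        m = tuple(min(s1, s2) for s1, s2 in zip(data.shape, shape))
        fixed[:m[0], :m[1], :m[2]] = data[:m[0], :m[1], :m[2]]
        return fixed
    return data

history = np.ones((1, 10, 6))                          # only 10 days of history available
print(ensure_fixed_shape(history, (1, 30, 6)).shape)   # -> (1, 30, 6), remaining rows filled with -1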
 
preprocess.py CHANGED
@@ -220,35 +220,54 @@ def get_sentiment_score(text):
 
 
 
- def get_stock_info(stock_codes, news_date):
+ def get_stock_info(stock_codes, history_days=30):
      # Get the stock codes and the news date
-     stock_codes = stock_codes.split(',')
+     stock_codes = stock_codes
 
-     news_date = news_date.strftime('%Y%m%d')
-     print(f"Getting stock info for {stock_codes} on {news_date}")
+     news_date = datetime.now().strftime('%Y%m%d')
+     # print(f"Getting stock info for {stock_codes} on {news_date}")
 
      previous_stock_history = []
      following_stock_history = []
-     previous_stock_index_history = []
-     following_stock_index_history = []
 
-     def process_history(stock_history, target_date):
+     previous_stock_inx_index_history = []
+     previous_stock_dj_index_history = []
+     previous_stock_ixic_index_history = []
+     previous_stock_ndx_index_history = []
+
+     following_stock_inx_index_history = []
+     following_stock_dj_index_history = []
+     following_stock_ixic_index_history = []
+     following_stock_ndx_index_history = []
+
+     def process_history(stock_history, target_date, history_days=history_days, following_days=3):
          # If the data is empty, create an empty DataFrame and fill it with placeholder values
          if stock_history.empty:
-             empty_data = pd.DataFrame({
-                 '开盘': [0] * 4,
-                 '收盘': [0] * 4,
-                 '最高': [0] * 4,
-                 '最低': [0] * 4,
-                 '成交量': [0] * 4,
-                 '成交额': [0] * 4
+             empty_data_previous = pd.DataFrame({
+                 '开盘': [-1] * history_days,
+                 '收盘': [-1] * history_days,
+                 '最高': [-1] * history_days,
+                 '最低': [-1] * history_days,
+                 '成交量': [-1] * history_days,
+                 '成交额': [-1] * history_days
              })
-             return empty_data, empty_data
+
+             empty_data_following = pd.DataFrame({
+                 '开盘': [-1] * following_days,
+                 '收盘': [-1] * following_days,
+                 '最高': [-1] * following_days,
+                 '最低': [-1] * following_days,
+                 '成交量': [-1] * following_days,
+                 '成交额': [-1] * following_days
+             })
+             return empty_data_previous, empty_data_following
 
          # Make sure the 'date' column exists
          if 'date' not in stock_history.columns:
             print(f"'date' column not found in stock history. Returning empty data.")
-             return pd.DataFrame([[0] * 6] * 4), pd.DataFrame([[0] * 6] * 4)
+             return pd.DataFrame([[-1] * 6] * history_days), pd.DataFrame([[-1] * 6] * following_days)
 
          # Convert the dates to datetime for easier comparison
          stock_history['date'] = pd.to_datetime(stock_history['date'])
@@ -265,44 +284,61 @@ def get_stock_info(stock_codes, news_date):
 
          # Make sure the matched target date has data
          if target_row.empty:
-             return pd.DataFrame([[0] * 6] * 4), pd.DataFrame([[0] * 6] * 4)
+             return pd.DataFrame([[-1] * 6] * history_days), pd.DataFrame([[-1] * 6] * following_days)
 
          target_index = target_row.index[0]
          target_pos = stock_history.index.get_loc(target_index)
 
-         # Take the target date and the 3 records before it
-         previous_rows = stock_history.iloc[max(0, target_pos - 3):target_pos + 1]
+         # Take the target date and the preceding history_days records
+         previous_rows = stock_history.iloc[max(0, target_pos - history_days):target_pos + 1]
 
-         # Take the target date and the 4 records after it
-         following_rows = stock_history.iloc[target_pos:target_pos + 4]
+         # Take the 3 records after the target date
+         following_rows = stock_history.iloc[target_pos + 1:target_pos + 4]
 
          # Drop the date column
          previous_rows = previous_rows.drop(columns=['date'])
          following_rows = following_rows.drop(columns=['date'])
 
-         # If previous_rows or following_rows has fewer than 4 rows, pad it to 4
-         if len(previous_rows) < 4:
-             previous_rows = previous_rows.reindex(range(4), fill_value=0)
+         # If previous_rows or following_rows has fewer than history_days rows, pad it to history_days
+         if len(previous_rows) < history_days:
+             previous_rows = previous_rows.reindex(range(history_days), fill_value=-1)
 
-         if len(following_rows) < 4:
-             following_rows = following_rows.reindex(range(4), fill_value=0)
+         if len(following_rows) < 3:
+             following_rows = following_rows.reindex(range(3), fill_value=-1)
 
-         # Return only the first 4 rows and the first 6 columns (open, close, high, low, volume, amount)
-         previous_rows = previous_rows.iloc[:4, :6]
-         following_rows = following_rows.iloc[:4, :6]
+         # Return only the first history_days rows and the first 6 columns (open, close, high, low, volume, amount)
+         previous_rows = previous_rows.iloc[:history_days, :6]
+         following_rows = following_rows.iloc[:following_days, :6]
 
          return previous_rows, following_rows
 
      if not stock_codes or stock_codes == ['']:
          # If stock_codes is empty, fetch and return the market index data directly
-         stock_index_history = get_stock_index_history("", news_date)
-         previous_rows, following_rows = process_history(stock_index_history, news_date)
-         previous_stock_index_history.append(previous_rows.values.tolist())
-         following_stock_index_history.append(following_rows.values.tolist())
+         stock_index_ndx_history = get_stock_index_history("", news_date, 1)
+         stock_index_dj_history = get_stock_index_history("", news_date, 2)
+         stock_index_inx_history = get_stock_index_history("", news_date, 3)
+         stock_index_ixic_history = get_stock_index_history("", news_date, 4)
+
+         previous_ndx_rows, following_ndx_rows = process_history(stock_index_ndx_history, news_date, history_days)
+         previous_dj_rows, following_dj_rows = process_history(stock_index_dj_history, news_date, history_days)
+         previous_inx_rows, following_inx_rows = process_history(stock_index_inx_history, news_date, history_days)
+         previous_ixic_rows, following_ixic_rows = process_history(stock_index_ixic_history, news_date, history_days)
+
+         previous_stock_inx_index_history.append(previous_inx_rows.values.tolist())
+         previous_stock_dj_index_history.append(previous_dj_rows.values.tolist())
+         previous_stock_ixic_index_history.append(previous_ixic_rows.values.tolist())
+         previous_stock_ndx_index_history.append(previous_ndx_rows.values.tolist())
+
+         following_stock_inx_index_history.append(following_inx_rows.values.tolist())
+         following_stock_dj_index_history.append(following_dj_rows.values.tolist())
+         following_stock_ixic_index_history.append(following_ixic_rows.values.tolist())
+         following_stock_ndx_index_history.append(following_ndx_rows.values.tolist())
 
          # Padding logic for the individual stock
-         previous_stock_history.append([[0] * len(previous_rows.columns)] * len(previous_rows))
-         following_stock_history.append([[0] * len(following_rows.columns)] * len(following_rows))
+         previous_stock_history.append([[-1] * 6] * history_days)
+         following_stock_history.append([[-1] * 6] * 3)
 
 
 
@@ -310,7 +346,6 @@ def get_stock_info(stock_codes, news_date):
      for stock_code in stock_codes:
          stock_code = stock_code.strip()
          stock_history = get_stock_history(stock_code, news_date)
-         stock_index_history = get_stock_index_history(stock_code, news_date)
 
          # Process the individual stock data
          previous_rows, following_rows = process_history(stock_history, news_date)
@@ -318,11 +353,33 @@ def get_stock_info(stock_codes, news_date):
          following_stock_history.append(following_rows.values.tolist())
 
          # Process the market index data
-         previous_rows, following_rows = process_history(stock_index_history, news_date)
-         previous_stock_index_history.append(previous_rows.values.tolist())
-         following_stock_index_history.append(following_rows.values.tolist())
+         stock_index_ndx_history = get_stock_index_history("", news_date, 1)
+         stock_index_dj_history = get_stock_index_history("", news_date, 2)
+         stock_index_inx_history = get_stock_index_history("", news_date, 3)
+         stock_index_ixic_history = get_stock_index_history("", news_date, 4)
+
+         previous_ndx_rows, following_ndx_rows = process_history(stock_index_ndx_history, news_date, history_days)
+         previous_dj_rows, following_dj_rows = process_history(stock_index_dj_history, news_date, history_days)
+         previous_inx_rows, following_inx_rows = process_history(stock_index_inx_history, news_date, history_days)
+         previous_ixic_rows, following_ixic_rows = process_history(stock_index_ixic_history, news_date, history_days)
+
+         previous_stock_inx_index_history.append(previous_inx_rows.values.tolist())
+         previous_stock_dj_index_history.append(previous_dj_rows.values.tolist())
+         previous_stock_ixic_index_history.append(previous_ixic_rows.values.tolist())
+         previous_stock_ndx_index_history.append(previous_ndx_rows.values.tolist())
+
+         following_stock_inx_index_history.append(following_inx_rows.values.tolist())
+         following_stock_dj_index_history.append(following_dj_rows.values.tolist())
+         following_stock_ixic_index_history.append(following_ixic_rows.values.tolist())
+         following_stock_ndx_index_history.append(following_ndx_rows.values.tolist())
+
+         # Only return data for the first stock
+         break
 
-     return previous_stock_history, following_stock_history, previous_stock_index_history, following_stock_index_history
+     return previous_stock_history, following_stock_history, \
+         previous_stock_inx_index_history, previous_stock_dj_index_history, previous_stock_ixic_index_history, previous_stock_ndx_index_history, \
+         following_stock_inx_index_history, following_stock_dj_index_history, following_stock_ixic_index_history, following_stock_ndx_index_history,
 
 
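Since get_stock_info now returns a 10-tuple instead of 4 values, callers have to unpack it positionally, as blkeras.predict does above. A minimal sketch under the assumption that market data is reachable (the stock code is only an example):

# Illustrative unpacking of the new 10-tuple returned by get_stock_info.
from preprocess import get_stock_info

(previous_stock_history, following_stock_history,
 previous_inx, previous_dj, previous_ixic, previous_ndx,
 following_inx, following_dj, following_ixic, following_ndx) = get_stock_info(["AAPL"])

print(len(previous_stock_history[0]))   # typically 30 rows, since history_days defaults to 30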
 
us_stock.py CHANGED
@@ -19,10 +19,10 @@ logging.basicConfig(level=logging.INFO)
  base_dir = os.path.dirname(os.path.abspath(__file__))
 
  # Build absolute paths to the CSV files
- nasdaq_100_path = os.path.join(base_dir, '../model/nasdaq100.csv')
- dow_jones_path = os.path.join(base_dir, '../model/dji.csv')
- sp500_path = os.path.join(base_dir, '../model/sp500.csv')
- nasdaq_composite_path = os.path.join(base_dir, '../model/nasdaq_all.csv')
+ nasdaq_100_path = os.path.join(base_dir, './model/nasdaq100.csv')
+ dow_jones_path = os.path.join(base_dir, './model/dji.csv')
+ sp500_path = os.path.join(base_dir, './model/sp500.csv')
+ nasdaq_composite_path = os.path.join(base_dir, './model/nasdaq_all.csv')
  # Load the index constituents from the CSV files
  nasdaq_100_stocks = pd.read_csv(nasdaq_100_path)
  dow_jones_stocks = pd.read_csv(dow_jones_path)
@@ -69,7 +69,13 @@ async def fetch_stock_us_spot_data_with_retries_async():
              await asyncio.sleep(wait_time)
              retry_index = min(retry_index + 1, len(retry_intervals) - 1)
 
- symbols = asyncio.run(fetch_stock_us_spot_data_with_retries_async())
+ symbols = None
+
+ async def fetch_symbols():
+     global symbols
+     # Fetch the data asynchronously
+     symbols = await fetch_stock_us_spot_data_with_retries_async()
+     print("Symbols initialized:", symbols)
 
 
  # Global variables
@@ -238,58 +244,31 @@ def get_stock_history(symbol, news_date, retries=10):
  # result = get_stock_history('ATMU', '20231218')
  # print(result)
 
-
  # Return the history of the index that the stock belongs to
- def get_stock_index_history(symbol, news_date):
+ def get_stock_index_history(symbol, news_date, force_index=0):
      # Check which index the stock belongs to
-     if symbol in nasdaq_100_stocks['Symbol'].values:
+     if symbol in nasdaq_100_stocks['Symbol'].values or force_index == 1:
          index_code = ".NDX"
          index_data = index_us_stock_index_NDX
-     elif symbol in dow_jones_stocks['Symbol'].values:
+     elif symbol in dow_jones_stocks['Symbol'].values or force_index == 2:
          index_code = ".DJI"
          index_data = index_us_stock_index_DJI
-     elif symbol in sp500_stocks['Symbol'].values:
+     elif symbol in sp500_stocks['Symbol'].values or force_index == 3:
          index_code = ".INX"
          index_data = index_us_stock_index_INX
-     elif symbol in nasdaq_composite_stocks["Symbol"].values or symbol is None or symbol == "":
+     elif symbol in nasdaq_composite_stocks["Symbol"].values or symbol is None or symbol == "" or force_index == 4:
          index_code = ".IXIC"
          index_data = index_us_stock_index_IXIC
      else:
-
+         # print(f"Stock code {symbol} does not belong to the Nasdaq 100, Dow Jones Industrial, S&P 500, or Nasdaq Composite index.")
          index_code = ".IXIC"
          index_data = index_us_stock_index_IXIC
-
-         # print(f"Stock code {symbol} does not belong to the Nasdaq 100, Dow Jones Industrial, S&P 500, or Nasdaq Composite index.")
-     # Convert news_date to a datetime object
-     news_date_dt = datetime.strptime(news_date, "%Y%m%d")
-
-     # Compute start_date and end_date
-     start_date = (news_date_dt - timedelta(weeks=2)).strftime("%Y-%m-%d")
-     end_date = (news_date_dt + timedelta(weeks=2)).strftime("%Y-%m-%d")
-
-     # Build an empty DataFrame covering the requested date range
-     date_range = pd.date_range(start=start_date, end=end_date)
-     stock_hist_df = pd.DataFrame({
-         'date': date_range,
-         'open': 0,
-         'high': 0,
-         'low': 0,
-         'close': 0,
-         'volume': 0,
-         'amount': 0
-     })
-     # Normalize the column names
-     stock_hist_df = stock_hist_df.rename(columns=column_mapping)
-     stock_hist_df = stock_hist_df.reindex(columns=standard_columns)
-     # Process the stock data, keeping only the required columns
-     stock_hist_df = reduce_columns(stock_hist_df, standard_columns)
-     return stock_hist_df
 
      # Convert news_date to a datetime object
      news_date_dt = datetime.strptime(news_date, "%Y%m%d")
 
      # Compute start_date and end_date
-     start_date = (news_date_dt - timedelta(weeks=2)).strftime("%Y-%m-%d")
+     start_date = (news_date_dt - timedelta(weeks=8)).strftime("%Y-%m-%d")
      end_date = (news_date_dt + timedelta(weeks=2)).strftime("%Y-%m-%d")
 
      # Make sure index_data['date'] is of datetime type
@@ -311,7 +290,6 @@ def get_stock_index_history(symbol, news_date):
  '''
 
 
-
  def find_stock_codes_or_names(entities):
      """
      Retrieve stock codes or company names from the given list of entities.
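A short sketch of how the new force_index parameter and the deferred symbol loading are meant to be used, assuming market data is reachable; the date is an arbitrary example. Based on the diff, force_index values 1 to 4 select NDX, DJI, INX, and IXIC respectively, and fetch_symbols() is awaited once at FastAPI startup instead of fetching eagerly at import time.

# Illustrative usage of fetch_symbols and the new force_index parameter.
import asyncio
from us_stock import fetch_symbols, get_stock_index_history

asyncio.run(fetch_symbols())   # previously this fetch ran at import time via asyncio.run

dji_hist = get_stock_index_history("", "20241012", force_index=2)   # force Dow Jones (.DJI)
print(dji_hist.head())   # index history around the given date, as a pandas DataFrame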