File size: 2,295 Bytes
17f3a9b
 
 
f56051d
 
068fdbc
 
d4b1508
62f31c8
 
068fdbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ae9fb3
f56051d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
 
2ae9fb3
17f3a9b
 
d4b1508
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
2ae9fb3
17f3a9b
 
 
 
8d84024
f56051d
62f31c8
 
 
 
068fdbc
f56051d
62f31c8
558076d
d4b1508
efad2c7
 
 
a65e7e5
62f31c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
import asyncio
from contextlib import asynccontextmanager

from RequestModel import PredictRequest

# Module-level state tracking whether start-up initialization has already run.
is_initialized: bool = False
# Lock held by `lifespan` around the check-and-initialize step.
initialization_lock: asyncio.Lock = asyncio.Lock()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the start-up initialization once, then yield control to the app.

    Everything before the ``yield`` runs at start-up; anything placed after
    it would run at shutdown (no cleanup is performed at present).
    """
    global is_initialized

    # Start-up phase: the lock guards the flag so the initialization is
    # awaited at most once, even if this context is entered concurrently.
    async with initialization_lock:
        needs_init = not is_initialized
        if needs_init:
            await initialize_application()
            is_initialized = True

    yield
    # Shutdown phase: nothing to clean up yet.

async def initialize_application():
    """Perform all start-up initialization for the application.

    At present this only awaits ``fetch_symbols`` from ``us_stock``.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to defer loading `us_stock` until start-up — confirm.
    from us_stock import fetch_symbols as load_symbols

    await load_symbols()
    # Additional initialization steps go here.

app = FastAPI(lifespan=lifespan)

# Add the CORS middleware: every origin, method and header is allowed.
# NOTE(review): the original comment also mentioned rate limiting, but no
# rate-limiting configuration is present here.
# NOTE(review): the FastAPI CORS docs advise against wildcard values when
# allow_credentials=True — confirm the intended origins and list them explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add the trusted-host middleware; the "*" entry places no restriction on hosts.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["*"]
)

# Request model: a single string field `text`, used by the /aaa and /bbb routes.
class TextRequest(BaseModel):
    text: str

# POST /api/aaa: return the submitted text with 'aaa' appended, under "result".
@app.post("/api/aaa")
async def api_aaa_post(request: TextRequest):
    # `text` is declared as `str` on TextRequest, so formatting it is
    # equivalent to string concatenation.
    suffixed = f"{request.text}aaa"
    return {"result": suffixed}

# POST /aaa: same behavior as POST /api/aaa, exposed without the /api prefix.
@app.post("/aaa")
async def aaa(request: TextRequest):
    # `text` is declared as `str` on TextRequest, so formatting it is
    # equivalent to string concatenation.
    suffixed = f"{request.text}aaa"
    return {"result": suffixed}


# GET /aaa: return the submitted text with 'aaa' appended, under "result".
# NOTE(review): the parameter is a model, so this GET route presumably expects
# a request body, which many HTTP clients do not send — confirm this is intended.
@app.get("/aaa")
async def api_aaa_get(request: TextRequest):
    # `text` is declared as `str` on TextRequest, so formatting it is
    # equivalent to string concatenation.
    suffixed = f"{request.text}aaa"
    return {"result": suffixed}

# POST /api/bbb: return the submitted text with 'bbb' appended, under "result".
@app.post("/api/bbb")
async def api_bbb(request: TextRequest):
    # `text` is declared as `str` on TextRequest, so formatting it is
    # equivalent to string concatenation.
    suffixed = f"{request.text}bbb"
    return {"result": suffixed}

# POST /api/predict: run the prediction for the request's text and stock codes.
#
# Parameters:
#   request: PredictRequest; its `text` and `stock_codes` attributes are
#            passed straight to `blkeras.predict`.
# Returns:
#   Whatever `blkeras.predict` returns, or an empty list if it raises.
@app.post("/api/predict")
async def predict(request: PredictRequest):
    import logging

    # Aliased so the imported function does not shadow this handler's own name.
    from blkeras import predict as run_prediction

    try:
        # The call is dispatched to a worker thread so the event loop is not
        # blocked while it runs.
        return await asyncio.to_thread(
            run_prediction, request.text, request.stock_codes
        )
    except Exception:
        # Best-effort contract preserved: callers still receive an empty list
        # on failure, but the traceback is now logged instead of being lost.
        logging.getLogger(__name__).exception("Prediction failed")
        return []

# GET /: return a static welcome message pointing at the processing routes.
@app.get("/")
async def root():
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return {"message": welcome}