File size: 13,859 Bytes
2af5f09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
"""
Tests:

- custom_path false / no user auth:
    -- upload file(yes)
    -- download file(yes)
    -- websocket(yes)
    -- block __pycache__ access(yes)
        -- rel (yes)
        -- abs (yes)
    -- block user access(fail) http://localhost:45013/file=gpt_log/admin/chat_secrets.log
        -- fix(commit f6bf05048c08f5cd84593f7fdc01e64dec1f584a)-> block successful

- custom_path yes("/cc/gptac") / no user auth:
    -- upload file(yes)
    -- download file(yes)
    -- websocket(yes)
    -- block __pycache__ access(yes)
    -- block user access(yes)

- custom_path yes("/cc/gptac/") / no user auth:
    -- upload file(yes)
    -- download file(yes)
    -- websocket(yes)
    -- block user access(yes)

- custom_path yes("/cc/gptac/") / + user auth:
    -- upload file(yes)
    -- download file(yes)
    -- websocket(yes)
    -- block user access(yes)
    -- block user-wise access (yes)

- custom_path no + user auth:
    -- upload file(yes)
    -- download file(yes)
    -- websocket(yes)
    -- block user access(yes)
    -- block user-wise access (yes)

queue concurrency effectiveness:
    -- upload file(yes)
    -- download file(yes)
    -- websocket(yes)
"""

import os, requests, threading, time
import uvicorn

def validate_path_safety(path_or_url, user):
    """
    Check that a local path may be used as task input by ``user``.

    The path is normalized with ``os.path.relpath`` (relative to the current
    working directory) and then classified:

    * inside ``PATH_LOGGING`` or ``PATH_PRIVATE_UPLOAD`` (per-user folders):
      allowed only when the directory directly below that root is ``user`` or
      one of the shared folders ``'autogen'``, ``'arxiv_cache'`` or
      ``default_user_name``;
    * starting with ``tests`` or ``build`` (commonly used test directories):
      allowed;
    * anything else: rejected.

    Args:
        path_or_url: local file-system path to check.
        user: name of the requesting user; ``None`` matches no per-user folder.

    Returns:
        True when the path is allowed.

    Raises:
        FriendlyException: if the path lies outside the allowed locations, or
            inside another user's private folder.
    """
    from toolbox import get_conf, default_user_name
    from toolbox import FriendlyException
    PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
    path_or_url = os.path.relpath(path_or_url)
    path_parts = path_or_url.split(os.sep)

    # Locate the protected root containing the path. Whole path components are
    # compared (not raw string prefixes) so that a sibling such as
    # "gpt_log_backup" is not mistaken for "gpt_log"; the configured roots are
    # normalized so values like "./gpt_log/" or "logs/gpt" also work.
    # PATH_LOGGING (per-user log files) is checked before PATH_PRIVATE_UPLOAD
    # (per-user upload folders).
    sensitive_parts = None
    for root in (PATH_LOGGING, PATH_PRIVATE_UPLOAD):
        root_parts = os.path.normpath(root).split(os.sep)
        if path_parts[:len(root_parts)] == root_parts:
            sensitive_parts = root_parts
            break

    if sensitive_parts is None:
        if path_or_url.startswith('tests') or path_or_url.startswith('build'):   # commonly used test directories
            return True
        raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但位置非法。请将文件上传后再执行该任务。") # return False

    # The caller's own folder plus three shared folders are accessible.
    allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name]
    owner_parts = path_parts[:len(sensitive_parts) + 1]
    for user_allowed in allowed_users:
        # Exact match on the directory directly below the protected root; a
        # missing user (None) matches nothing instead of raising TypeError.
        if user_allowed is not None and owner_parts == sensitive_parts + [user_allowed]:
            return True
    raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但属于其他用户。请将文件上传后再执行该任务。") # return False

def _authorize_user(path_or_url, request, gradio_app):
    """
    Decide whether the logged-in user behind ``request`` may read ``path_or_url``.

    The path is normalized with ``os.path.relpath``. If it lies inside
    ``PATH_PRIVATE_UPLOAD`` or ``PATH_LOGGING``, the directory directly below
    that root must be the requesting user or one of the shared folders
    ``'autogen'``, ``'arxiv_cache'`` or ``default_user_name``. Paths anywhere
    else are not restricted by this function.

    Args:
        path_or_url: requested file path (taken from the URL).
        request: incoming request; its ``access-token`` (or
            ``access-token-unsecure``) cookie identifies the user.
        gradio_app: gradio app whose ``tokens`` mapping resolves a token to a
            user name.

    Returns:
        True if access is allowed, False otherwise.
    """
    from toolbox import get_conf, default_user_name
    PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
    path_parts = os.path.relpath(path_or_url).split(os.sep)

    # Locate the protected root containing the path, comparing whole path
    # components so that e.g. "gpt_log_backup/..." is not treated as being
    # inside "gpt_log". If both roots match (one nested in the other),
    # PATH_PRIVATE_UPLOAD wins.
    sensitive_parts = None
    for root in (PATH_PRIVATE_UPLOAD, PATH_LOGGING):
        root_parts = os.path.normpath(root).split(os.sep)
        if path_parts[:len(root_parts)] == root_parts:
            sensitive_parts = root_parts
            break
    if sensitive_parts is None:
        return True

    token = request.cookies.get("access-token") or request.cookies.get("access-token-unsecure")
    user = gradio_app.tokens.get(token)  # get user; may be None for a missing/unknown token
    # The requester's own folder plus three shared folders are accessible.
    allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name]
    owner_parts = path_parts[:len(sensitive_parts) + 1]
    for user_allowed in allowed_users:
        # Exact match on the directory directly below the protected root; an
        # unknown user (None) matches nothing instead of raising TypeError.
        if user_allowed is not None and owner_parts == sensitive_parts + [user_allowed]:
            return True
    return False # "越权访问!" (unauthorized access)


class Server(uvicorn.Server):
    """A uvicorn ``Server`` that runs in a background daemon thread.

    ``run_in_thread`` starts serving and returns once uvicorn reports that it
    has started; ``close`` asks the server to exit and waits for its thread.
    """

    def install_signal_handlers(self):
        # Intentionally a no-op: this server runs off the main thread and should
        # not take over process signal handling (Python only allows installing
        # signal handlers from the main thread).
        pass

    def run_in_thread(self, poll_interval=5e-2):
        """Start ``self.run`` in a daemon thread and wait until it has started.

        Args:
            poll_interval: seconds to sleep between checks of ``self.started``.

        Raises:
            RuntimeError: if the server thread exits before ``self.started``
                becomes true (startup failed), instead of waiting forever.
        """
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()
        while not self.started:
            # Re-check `started` after the liveness test to avoid a false alarm
            # if the thread started and stopped between the two checks.
            if not self.thread.is_alive() and not self.started:
                raise RuntimeError(
                    "uvicorn server thread exited before startup completed "
                    "(startup failed, e.g. the port may already be in use)"
                )
            time.sleep(poll_interval)

    def close(self):
        """Signal the server to exit and wait for its thread to finish."""
        self.should_exit = True
        thread = getattr(self, "thread", None)
        if thread is not None:  # run_in_thread may never have been called
            thread.join()


def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SSL_CERTFILE):
    """
    Serve the gradio ``app_block`` through FastAPI/uvicorn and block forever.

    Steps:
      1. configure ``app_block`` (login, blocked paths, queue, favicon, ...);
      2. build the gradio app and replace its file-serving routes with a
         wrapper that refuses remote URLs and ``..`` traversal and, when
         ``AUTHENTICATION`` is non-empty, files in other users' folders
         (see ``_authorize_user``); a ``/academic_logout`` route is added too;
      3. if ``TTS_TYPE`` is not ``"DISABLE"``, add a ``/vits`` text-to-speech
         route;
      4. mount the gradio app at ``CUSTOM_PATH`` on an outer FastAPI app;
      5. start uvicorn in a background thread, trigger gradio's
         ``startup-events`` endpoint and call ``app_block.block_thread()``.

    Args:
        app_block: the ``gr.Blocks`` UI to serve.
        CONCURRENT_COUNT: passed to ``app_block.queue(concurrency_count=...)``.
        AUTHENTICATION: credentials assigned to ``app_block.auth``; an empty
            value disables login.
        PORT: TCP port to listen on (host ``0.0.0.0``).
        SSL_KEYFILE: TLS key file path, ``""`` for none.
        SSL_CERTFILE: TLS certificate file path, ``""`` for none.

    Raises:
        ValueError: if ``SSL_KEYFILE`` is given without ``SSL_CERTFILE``.
    """
    import uvicorn
    import fastapi
    import gradio as gr
    # `Request` is imported unconditionally: the middleware below annotates
    # with it even when neither the login branch nor the TTS branch runs
    # (otherwise defining it raises UnboundLocalError).
    from fastapi import FastAPI, Request
    from gradio.routes import App
    from toolbox import get_conf
    CUSTOM_PATH, PATH_LOGGING = get_conf('CUSTOM_PATH', 'PATH_LOGGING')

    # --- --- configure gradio app block --- ---
    app_block:gr.Blocks
    app_block.ssl_verify = False
    app_block.auth_message = '请登录'
    app_block.favicon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs/logo.png")
    app_block.auth = AUTHENTICATION if len(AUTHENTICATION) != 0 else None
    app_block.blocked_paths = ["config.py", "__pycache__", "config_private.py", "docker-compose.yml", "Dockerfile", f"{PATH_LOGGING}/admin"]
    app_block.dev_mode = False
    app_block.config = app_block.get_config_file()
    app_block.enable_queue = True
    app_block.queue(concurrency_count=CONCURRENT_COUNT)
    app_block.validate_queue_settings()
    app_block.show_api = False
    app_block.config = app_block.get_config_file()  # refresh after the queue / show_api changes above
    max_threads = 40
    app_block.max_threads = max(
        app_block._queue.max_thread_count if app_block.enable_queue else 0, max_threads
    )
    app_block.is_colab = False
    app_block.is_kaggle = False
    app_block.is_sagemaker = False

    gradio_app = App.create_app(app_block)

    # --- --- replace gradio endpoints to forbid access to sensitive files --- ---
    # One pass over the routes: drop the /proxy= route and the legacy /file/
    # route, and capture the original /file= handler (and its dependencies) so
    # the wrapper below can delegate to it.
    dependencies = []
    endpoint = None
    for route in list(gradio_app.router.routes):
        if route.path in ("/proxy={url_path:path}", "/file/{path:path}"):
            gradio_app.router.routes.remove(route)
        elif route.path == "/file={path_or_url:path}":
            dependencies = route.dependencies
            endpoint = route.endpoint
            gradio_app.router.routes.remove(route)

    auth_enabled = len(AUTHENTICATION) > 0

    @gradio_app.get("/file/{path:path}", dependencies=dependencies)
    @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
    @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
    async def file(path_or_url: str, request: fastapi.Request):
        # With login enabled, refuse files inside other users' private folders.
        if auth_enabled and not _authorize_user(path_or_url, request, gradio_app):
            return "越权访问!"
        stripped = path_or_url.lstrip().lower()
        # Never let the file route fetch remote URLs.
        if stripped.startswith(("https://", "http://")):
            return "账户密码授权模式下, 禁止链接!"
        # Reject directory traversal with either POSIX or Windows separators.
        if '../' in stripped or '..\\' in stripped:
            return "非法路径!"
        return await endpoint(path_or_url, request)

    if auth_enabled:
        from fastapi import status
        from fastapi.responses import RedirectResponse

        @gradio_app.get("/academic_logout")
        async def logout():
            # Delete the login cookies (the ones _authorize_user reads) and
            # redirect back to the app root.
            response = RedirectResponse(url=CUSTOM_PATH, status_code=status.HTTP_302_FOUND)
            response.delete_cookie('access-token')
            response.delete_cookie('access-token-unsecure')
            return response

    # --- --- enable TTS (text-to-speech) functionality --- ---
    TTS_TYPE = get_conf("TTS_TYPE")
    if TTS_TYPE != "DISABLE":
        # audio generation functionality
        import httpx
        from fastapi import HTTPException
        from starlette.responses import Response

        async def forward_request(request: Request, method: str) -> Response:
            """Produce speech audio for a /vits request with the configured backend."""
            async with httpx.AsyncClient() as client:
                try:
                    if TTS_TYPE == "EDGE_TTS":
                        import tempfile
                        import uuid
                        import edge_tts
                        from pydub import AudioSegment
                        payload = await request.json()
                        voice = get_conf("EDGE_TTS_VOICE")
                        tts = edge_tts.Communicate(text=payload['text'], voice=voice)
                        temp_file = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4().hex}.mp3')
                        try:
                            await tts.save(temp_file)
                            try:
                                # Re-encode the saved mp3 as wav in place (pydub needs ffmpeg).
                                mp3_audio = AudioSegment.from_file(temp_file, format="mp3")
                                mp3_audio.export(temp_file, format="wav")
                                with open(temp_file, 'rb') as wav_file:
                                    wav_bytes = wav_file.read()
                            except Exception as e:
                                raise RuntimeError("ffmpeg未安装,无法处理EdgeTTS音频。安装方法见`https://github.com/jiaaro/pydub#getting-ffmpeg-set-up`") from e
                        finally:
                            # Always remove the temp file, also when saving/conversion fails.
                            if os.path.exists(temp_file):
                                os.remove(temp_file)
                        return Response(content=wav_bytes)
                    if TTS_TYPE == "LOCAL_SOVITS_API":
                        # Forward the raw request body to the GPT-SoVITS service
                        TARGET_URL = get_conf("GPT_SOVITS_URL")
                        body = await request.body()
                        resp = await client.post(TARGET_URL, content=body, timeout=60)
                        # Return the response from the target service
                        return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers))
                except httpx.RequestError as e:
                    raise HTTPException(status_code=400, detail=f"Request to the target service failed: {str(e)}")

        @gradio_app.post("/vits")
        async def forward_post_request(request: Request):
            return await forward_request(request, "POST")

    # --- --- app_lifespan --- ---
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def app_lifespan(app):
        async def startup_gradio_app():
            if gradio_app.get_blocks().enable_queue:
                gradio_app.get_blocks().startup_events()

        async def shutdown_gradio_app():
            pass

        await startup_gradio_app()  # startup logic here
        yield  # The application will serve requests after this point
        await shutdown_gradio_app()  # cleanup/shutdown logic here

    # --- --- FastAPI --- ---
    fastapi_app = FastAPI(lifespan=app_lifespan)
    fastapi_app.mount(CUSTOM_PATH, gradio_app)

    # --- --- favicon and block fastapi api reference routes --- ---
    from starlette.responses import JSONResponse
    if CUSTOM_PATH != '/':
        from fastapi.responses import FileResponse

        @fastapi_app.get("/favicon.ico")
        async def favicon():
            return FileResponse(app_block.favicon_path)

        @fastapi_app.middleware("http")
        async def middleware(request: Request, call_next):
            # Hide the outer FastAPI app's auto-generated API reference pages.
            if request.scope['path'] in ["/docs", "/redoc", "/openapi.json"]:
                return JSONResponse(status_code=404, content={"message": "Not Found"})
            response = await call_next(request)
            return response

    # --- --- uvicorn.Config --- ---
    ssl_keyfile = None if SSL_KEYFILE == "" else SSL_KEYFILE
    ssl_certfile = None if SSL_CERTFILE == "" else SSL_CERTFILE
    server_name = "0.0.0.0"
    config = uvicorn.Config(
        fastapi_app,
        host=server_name,
        port=PORT,
        reload=False,
        log_level="warning",
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
    )
    server = Server(config)
    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
    if ssl_keyfile is not None:
        if ssl_certfile is None:
            raise ValueError(
                "ssl_certfile must be provided if ssl_keyfile is provided."
            )
        path_to_local_server = f"https://{url_host_name}:{PORT}/"
    else:
        path_to_local_server = f"http://{url_host_name}:{PORT}/"
    if CUSTOM_PATH != '/':
        path_to_local_server += CUSTOM_PATH.lstrip('/').rstrip('/') + '/'

    # --- --- begin --- ---
    server.run_in_thread()

    # --- --- after server launch --- ---
    app_block.server = server
    app_block.server_name = server_name
    app_block.local_url = path_to_local_server
    app_block.protocol = (
        "https"
        if app_block.local_url.startswith("https") or app_block.is_colab
        else "http"
    )

    if app_block.enable_queue:
        app_block._queue.set_url(path_to_local_server)

    # Call the local server directly, bypassing any configured HTTP(S) proxy.
    forbid_proxies = {
        "http": "",
        "https": "",
    }
    requests.get(f"{app_block.local_url}startup-events", verify=app_block.ssl_verify, proxies=forbid_proxies)
    app_block.is_running = True
    app_block.block_thread()