Tusik commited on
Commit
6b8f7db
·
unverified ·
1 Parent(s): 40a0cc7

feat:替代并实现gradio logout route并添加退出按钮 (#1034)

Browse files

* feat:替代并实现gradio logout route并添加退出按钮

* 在无验证情况下隐藏

Files changed (2) hide show
  1. ChuanhuChatbot.py +12 -1
  2. modules/gradio_patch.py +114 -0
ChuanhuChatbot.py CHANGED
@@ -16,7 +16,9 @@ from modules.config import *
16
  from modules import config
17
  import gradio as gr
18
  import colorama
 
19
 
 
20
 
21
  logging.getLogger("httpx").setLevel(logging.WARNING)
22
 
@@ -33,6 +35,8 @@ def create_new_model():
33
 
34
  with gr.Blocks(theme=small_and_beautiful_theme) as demo:
35
  user_name = gr.Textbox("", visible=False)
 
 
36
  promptTemplates = gr.State(load_template(get_template_names()[0], mode=2))
37
  user_question = gr.State("")
38
  assert type(my_api_key) == str
@@ -391,6 +395,8 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
391
  single_turn_checkbox = gr.Checkbox(label=i18n(
392
  "单轮对话"), value=False, elem_classes="switch-checkbox", elem_id="gr-single-session-cb", visible=False)
393
  # checkUpdateBtn = gr.Button(i18n("🔄 检查更新..."), visible=check_update)
 
 
394
 
395
  with gr.Tab(i18n("网络")):
396
  gr.Markdown(
@@ -801,7 +807,12 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
801
  outputs=[saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt],
802
  _js='(a,b)=>{return bgSelectHistory(a,b);}'
803
  )
804
-
 
 
 
 
 
805
  # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
806
  demo.title = i18n("川虎Chat 🚀")
807
 
 
16
  from modules import config
17
  import gradio as gr
18
  import colorama
19
+ from modules.gradio_patch import reg_patch
20
 
21
+ reg_patch()
22
 
23
  logging.getLogger("httpx").setLevel(logging.WARNING)
24
 
 
35
 
36
  with gr.Blocks(theme=small_and_beautiful_theme) as demo:
37
  user_name = gr.Textbox("", visible=False)
38
+ # 激活/logout路由
39
+ logout_hidden_btn = gr.LogoutButton(visible=False)
40
  promptTemplates = gr.State(load_template(get_template_names()[0], mode=2))
41
  user_question = gr.State("")
42
  assert type(my_api_key) == str
 
395
  single_turn_checkbox = gr.Checkbox(label=i18n(
396
  "单轮对话"), value=False, elem_classes="switch-checkbox", elem_id="gr-single-session-cb", visible=False)
397
  # checkUpdateBtn = gr.Button(i18n("🔄 检查更新..."), visible=check_update)
398
+ logout_btn = gr.Button(
399
+ i18n("退出用户"), variant="primary", visible=authflag)
400
 
401
  with gr.Tab(i18n("网络")):
402
  gr.Markdown(
 
807
  outputs=[saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt],
808
  _js='(a,b)=>{return bgSelectHistory(a,b);}'
809
  )
810
+ logout_btn.click(
811
+ fn=None,
812
+ inputs=[],
813
+ outputs=[],
814
+ _js='self.location="/logout"'
815
+ )
816
  # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
817
  demo.title = i18n("川虎Chat 🚀")
818
 
modules/gradio_patch.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import fastapi
5
+ import gradio
6
+ from fastapi.responses import RedirectResponse
7
+ from gradio.oauth import MOCKED_OAUTH_TOKEN
8
+
9
+ from modules.presets import i18n
10
+
11
+ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
12
+ OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
13
+ OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
14
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
15
+ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
16
+ """Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
17
+ try:
18
+ from authlib.integrations.starlette_client import OAuth
19
+ except ImportError as e:
20
+ raise ImportError(
21
+ "Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
22
+ "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
23
+ ) from e
24
+
25
+ # Check environment variables
26
+ msg = (
27
+ "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
28
+ " setting `hf_oauth: true` in the Space metadata."
29
+ )
30
+ if OAUTH_CLIENT_ID is None:
31
+ raise ValueError(msg.format("OAUTH_CLIENT_ID"))
32
+ if OAUTH_CLIENT_SECRET is None:
33
+ raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
34
+ if OAUTH_SCOPES is None:
35
+ raise ValueError(msg.format("OAUTH_SCOPES"))
36
+ if OPENID_PROVIDER_URL is None:
37
+ raise ValueError(msg.format("OPENID_PROVIDER_URL"))
38
+
39
+ # Register OAuth server
40
+ oauth = OAuth()
41
+ oauth.register(
42
+ name="huggingface",
43
+ client_id=OAUTH_CLIENT_ID,
44
+ client_secret=OAUTH_CLIENT_SECRET,
45
+ client_kwargs={"scope": OAUTH_SCOPES},
46
+ server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
47
+ )
48
+
49
+ # Define OAuth routes
50
+ @app.get("/login/huggingface")
51
+ async def oauth_login(request: fastapi.Request):
52
+ """Endpoint that redirects to HF OAuth page."""
53
+ redirect_uri = str(request.url_for("oauth_redirect_callback"))
54
+ if ".hf.space" in redirect_uri:
55
+ # In Space, FastAPI redirect as http but we want https
56
+ redirect_uri = redirect_uri.replace("http://", "https://")
57
+ return await oauth.huggingface.authorize_redirect(request, redirect_uri)
58
+
59
+ @app.get("/login/callback")
60
+ async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
61
+ """Endpoint that handles the OAuth callback."""
62
+ token = await oauth.huggingface.authorize_access_token(request)
63
+ request.session["oauth_profile"] = token["userinfo"]
64
+ request.session["oauth_token"] = token
65
+ return RedirectResponse("/")
66
+
67
+ @app.get("/logout")
68
+ async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
69
+ """Endpoint that logs out the user (e.g. delete cookie session)."""
70
+ request.session.pop("oauth_profile", None)
71
+ request.session.pop("oauth_token", None)
72
+ # 清除cookie并跳转到首页
73
+ response = RedirectResponse(url="/", status_code=302)
74
+ response.delete_cookie(key=f"access-token")
75
+ response.delete_cookie(key=f"access-token-unsecure")
76
+ return response
77
+
78
+
79
+ def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
80
+ """Add fake oauth routes if Gradio is run locally and OAuth is enabled.
81
+
82
+ Clicking on a gr.LoginButton will have the same behavior as in a Space (i.e. gets redirected in a new tab) but
83
+ instead of authenticating with HF, a mocked user profile is added to the session.
84
+ """
85
+
86
+ # Define OAuth routes
87
+ @app.get("/login/huggingface")
88
+ async def oauth_login(request: fastapi.Request):
89
+ """Fake endpoint that redirects to HF OAuth page."""
90
+ return RedirectResponse("/login/callback")
91
+
92
+ @app.get("/login/callback")
93
+ async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
94
+ """Endpoint that handles the OAuth callback."""
95
+ request.session["oauth_profile"] = MOCKED_OAUTH_TOKEN["userinfo"]
96
+ request.session["oauth_token"] = MOCKED_OAUTH_TOKEN
97
+ return RedirectResponse("/")
98
+
99
+ @app.get("/logout")
100
+ async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
101
+ """Endpoint that logs out the user (e.g. delete cookie session)."""
102
+ request.session.pop("oauth_profile", None)
103
+ request.session.pop("oauth_token", None)
104
+ # 清除cookie并跳转到首页
105
+ response = RedirectResponse(url="/", status_code=302)
106
+ response.delete_cookie(key=f"access-token")
107
+ response.delete_cookie(key=f"access-token-unsecure")
108
+ return response
109
+
110
+
111
+ def reg_patch():
112
+ gradio.oauth._add_mocked_oauth_routes = _add_mocked_oauth_routes
113
+ gradio.oauth._add_oauth_routes = _add_oauth_routes
114
+ logging.info(i18n("覆盖gradio.oauth /logout路由"))