Ashhar commited on
Commit
db7b744
·
1 Parent(s): 71a2c91

added auth

Browse files
Files changed (5) hide show
  1. app.py +113 -105
  2. auth.py +56 -0
  3. requirements.txt +1 -0
  4. sidebar.py +57 -0
  5. utils.py +15 -0
app.py CHANGED
@@ -7,13 +7,15 @@ import re
7
  from typing import List, Literal, TypedDict, Tuple
8
  from transformers import AutoTokenizer
9
  from gradio_client import Client
10
- import constants as C
11
- import utils as U
12
-
13
  from openai import OpenAI
14
  import anthropic
15
  from groq import Groq
16
 
 
 
 
 
 
17
  from dotenv import load_dotenv
18
  load_dotenv()
19
 
@@ -226,7 +228,7 @@ def __logLlmRequest(messagesFormatted: list):
226
  # U.pprint(f"{messagesFormatted=}")
227
 
228
 
229
- def predict():
230
  messagesFormatted = []
231
 
232
  try:
@@ -352,112 +354,118 @@ U.pprint("\n")
352
  U.applyCommonStyles()
353
  st.title("Kommuneity Story Creator 🪄")
354
 
355
- if "startMsg" not in st.session_state:
356
- __setStartMsg("")
357
- st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))
358
-
359
- for chat in st.session_state.chatHistory:
360
- role = chat["role"]
361
- content = chat["content"]
362
- imagePath = chat.get("image")
363
- avatar = C.AI_ICON if role == "assistant" else C.USER_ICON
364
- with st.chat_message(role, avatar=avatar):
365
- st.markdown(content)
366
- if imagePath:
367
- st.image(imagePath)
368
-
369
- # U.pprint(f"{st.session_state.buttonValue=}")
370
- # U.pprint(f"{st.session_state.selectedStory=}")
371
- # U.pprint(f"{st.session_state.startMsg=}")
372
-
373
- if prompt := (
374
- st.chat_input()
375
- or st.session_state["buttonValue"]
376
- or st.session_state["selectedStoryTitle"]
377
- or st.session_state["startMsg"]
378
- ):
379
- __resetButtonState()
380
- __setStartMsg("")
381
- if st.session_state["selectedStoryTitle"] != prompt:
382
- __resetSelectedStory()
383
- st.session_state.selectedStoryTitle = ""
384
 
385
- with st.chat_message("user", avatar=C.USER_ICON):
386
- st.markdown(prompt)
387
- U.pprint(f"{prompt=}")
388
- st.session_state.chatHistory.append({"role": "user", "content": prompt })
389
- st.session_state.messages.append({"role": "user", "content": prompt})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
- with st.chat_message("assistant", avatar=C.AI_ICON):
392
- responseContainer = st.empty()
 
 
 
393
 
394
- def __printAndGetResponse():
395
- response = ""
396
- responseContainer.image(C.TEXT_LOADER)
397
- responseGenerator = predict()
398
 
399
- for chunk in responseGenerator:
400
- response += chunk
401
- if __isInvalidResponse(response):
402
- U.pprint(f"InvalidResponse={response}")
403
- return
404
 
405
- if C.JSON_SEPARATOR not in response:
406
- responseContainer.markdown(response)
 
 
 
407
 
408
- return response
 
409
 
410
- response = __printAndGetResponse()
411
- while not response:
412
- U.pprint("Empty response. Retrying..")
413
- time.sleep(0.7)
414
- response = __printAndGetResponse()
415
 
416
- U.pprint(f"{response=}")
417
-
418
- def selectButton(optionLabel):
419
- st.session_state["buttonValue"] = optionLabel
420
- U.pprint(f"Selected: {optionLabel}")
421
-
422
- rawResponse = response
423
- responseParts = response.split(C.JSON_SEPARATOR)
424
-
425
- jsonStr = None
426
- if len(responseParts) > 1:
427
- [response, jsonStr] = responseParts
428
-
429
- imageContainer = st.empty()
430
- imagePath = __paintImageIfApplicable(imageContainer, prompt, response)
431
-
432
- st.session_state.chatHistory.append({
433
- "role": "assistant",
434
- "content": response,
435
- "image": imagePath,
436
- })
437
- st.session_state.messages.append({
438
- "role": "assistant",
439
- "content": rawResponse,
440
- })
441
-
442
- if jsonStr:
443
- try:
444
- json.loads(jsonStr)
445
- jsonObj = json.loads(jsonStr)
446
- options = jsonObj.get("options")
447
- action = jsonObj.get("action")
448
-
449
- if options:
450
- for option in options:
451
- st.button(
452
- option["label"],
453
- key=option["id"],
454
- on_click=lambda label=option["label"]: selectButton(label)
455
- )
456
- elif action:
457
- U.pprint(f"{action=}")
458
- if action == "SHOW_STORY_DATABASE":
459
- time.sleep(0.5)
460
- st.switch_page("pages/popular-stories.py")
461
- # st.code(jsonStr, language="json")
462
- except Exception as e:
463
- U.pprint(e)
 
 
 
 
 
 
 
 
 
 
 
7
  from typing import List, Literal, TypedDict, Tuple
8
  from transformers import AutoTokenizer
9
  from gradio_client import Client
 
 
 
10
  from openai import OpenAI
11
  import anthropic
12
  from groq import Groq
13
 
14
+ import constants as C
15
+ import utils as U
16
+ from auth import authenticateFunc
17
+ from sidebar import showSidebar
18
+
19
  from dotenv import load_dotenv
20
  load_dotenv()
21
 
 
228
  # U.pprint(f"{messagesFormatted=}")
229
 
230
 
231
+ def __predict():
232
  messagesFormatted = []
233
 
234
  try:
 
354
  U.applyCommonStyles()
355
  st.title("Kommuneity Story Creator 🪄")
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
+ def mainApp():
359
+ if "startMsg" not in st.session_state:
360
+ __setStartMsg("")
361
+ st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))
362
+
363
+ for chat in st.session_state.chatHistory:
364
+ role = chat["role"]
365
+ content = chat["content"]
366
+ imagePath = chat.get("image")
367
+ avatar = C.AI_ICON if role == "assistant" else C.USER_ICON
368
+ with st.chat_message(role, avatar=avatar):
369
+ st.markdown(content)
370
+ if imagePath:
371
+ st.image(imagePath)
372
+
373
+ # U.pprint(f"{st.session_state.buttonValue=}")
374
+ # U.pprint(f"{st.session_state.selectedStoryTitle=}")
375
+ # U.pprint(f"{st.session_state.startMsg=}")
376
+
377
+ if prompt := (
378
+ st.chat_input()
379
+ or st.session_state["buttonValue"]
380
+ or st.session_state["selectedStoryTitle"]
381
+ or st.session_state["startMsg"]
382
+ ):
383
+ __resetButtonState()
384
+ __setStartMsg("")
385
+ if st.session_state["selectedStoryTitle"] != prompt:
386
+ __resetSelectedStory()
387
+ st.session_state.selectedStoryTitle = ""
388
 
389
+ with st.chat_message("user", avatar=C.USER_ICON):
390
+ st.markdown(prompt)
391
+ U.pprint(f"{prompt=}")
392
+ st.session_state.chatHistory.append({"role": "user", "content": prompt })
393
+ st.session_state.messages.append({"role": "user", "content": prompt})
394
 
395
+ with st.chat_message("assistant", avatar=C.AI_ICON):
396
+ responseContainer = st.empty()
 
 
397
 
398
+ def __printAndGetResponse():
399
+ response = ""
400
+ responseContainer.image(C.TEXT_LOADER)
401
+ responseGenerator = __predict()
 
402
 
403
+ for chunk in responseGenerator:
404
+ response += chunk
405
+ if __isInvalidResponse(response):
406
+ U.pprint(f"InvalidResponse={response}")
407
+ return
408
 
409
+ if C.JSON_SEPARATOR not in response:
410
+ responseContainer.markdown(response)
411
 
412
+ return response
 
 
 
 
413
 
414
+ response = __printAndGetResponse()
415
+ while not response:
416
+ U.pprint("Empty response. Retrying..")
417
+ time.sleep(0.7)
418
+ response = __printAndGetResponse()
419
+
420
+ U.pprint(f"{response=}")
421
+
422
+ def selectButton(optionLabel):
423
+ st.session_state["buttonValue"] = optionLabel
424
+ U.pprint(f"Selected: {optionLabel}")
425
+
426
+ rawResponse = response
427
+ responseParts = response.split(C.JSON_SEPARATOR)
428
+
429
+ jsonStr = None
430
+ if len(responseParts) > 1:
431
+ [response, jsonStr] = responseParts
432
+
433
+ imageContainer = st.empty()
434
+ imagePath = __paintImageIfApplicable(imageContainer, prompt, response)
435
+
436
+ st.session_state.chatHistory.append({
437
+ "role": "assistant",
438
+ "content": response,
439
+ "image": imagePath,
440
+ })
441
+ st.session_state.messages.append({
442
+ "role": "assistant",
443
+ "content": rawResponse,
444
+ })
445
+
446
+ if jsonStr:
447
+ try:
448
+ json.loads(jsonStr)
449
+ jsonObj = json.loads(jsonStr)
450
+ options = jsonObj.get("options")
451
+ action = jsonObj.get("action")
452
+
453
+ if options:
454
+ for option in options:
455
+ st.button(
456
+ option["label"],
457
+ key=option["id"],
458
+ on_click=lambda label=option["label"]: selectButton(label)
459
+ )
460
+ elif action:
461
+ U.pprint(f"{action=}")
462
+ if action == "SHOW_STORY_DATABASE":
463
+ time.sleep(0.5)
464
+ st.switch_page("pages/popular-stories.py")
465
+ # st.code(jsonStr, language="json")
466
+ except Exception as e:
467
+ U.pprint(e)
468
+
469
+
470
+ authenticateFunc(mainApp)
471
+ showSidebar()
auth.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Callable, Any
3
+ import streamlit as st
4
+ from descope.descope_client import DescopeClient
5
+ from descope.exceptions import AuthException
6
+
7
+ DESCOPE_PROJECT_ID = os.environ.get("DESCOPE_PROJECT_ID")
8
+ descopeClient = DescopeClient(project_id=DESCOPE_PROJECT_ID)
9
+
10
+
11
+ def authenticateFunc(func: Callable[[], Any]):
12
+ if "token" not in st.session_state:
13
+ if "code" in st.query_params:
14
+ code = st.query_params["code"]
15
+ st.query_params.clear()
16
+ try:
17
+ with st.spinner("Authenticating ..."):
18
+ jwtResponse = descopeClient.sso.exchange_token(code)
19
+ st.session_state["token"] = jwtResponse["sessionToken"].get("jwt")
20
+ st.session_state["refreshToken"] = jwtResponse["refreshSessionToken"].get(
21
+ "jwt"
22
+ )
23
+ st.session_state["user"] = jwtResponse["user"]
24
+ st.rerun()
25
+ except AuthException:
26
+ st.error("Login failed!")
27
+
28
+ st.warning("You're not logged in. Please login to continue")
29
+ with st.container(border=False):
30
+ if st.button(
31
+ "Sign in with Google",
32
+ type="primary",
33
+ use_container_width=True
34
+ ):
35
+ oauthResponse = descopeClient.oauth.start(
36
+ provider="google", return_url=st.context.headers["Origin"]
37
+ )
38
+ url = oauthResponse["url"]
39
+ # Redirect to Google
40
+ st.markdown(
41
+ f'<meta http-equiv="refresh" content="0; url={url}">',
42
+ unsafe_allow_html=True,
43
+ )
44
+ else:
45
+ try:
46
+ with st.spinner("Authenticating ..."):
47
+ jwtResponse = descopeClient.validate_and_refresh_session(
48
+ st.session_state.token, st.session_state.refreshToken
49
+ )
50
+ st.session_state["token"] = jwtResponse["sessionToken"].get("jwt")
51
+
52
+ func()
53
+ except AuthException:
54
+ # Log out user
55
+ del st.session_state.token
56
+ st.rerun()
requirements.txt CHANGED
@@ -5,3 +5,4 @@ transformers
5
  gradio_client
6
  anthropic
7
  supabase
 
 
5
  gradio_client
6
  anthropic
7
  supabase
8
+ descope
sidebar.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def showSidebar():
5
+ with st.sidebar:
6
+ if "user" in st.session_state and "token" in st.session_state:
7
+ st.markdown("""
8
+ <style>
9
+ .user-info {
10
+ display: flex;
11
+ align-items: center;
12
+ padding: 10px;
13
+ background-color: rgba(255, 255, 255, 0.1);
14
+ border-radius: 5px;
15
+ margin-bottom: 10px;
16
+ margin-top: -2rem;
17
+ }
18
+ .user-avatar {
19
+ width: 40px;
20
+ height: 40px;
21
+ border-radius: 50%;
22
+ margin-right: 10px;
23
+ }
24
+ .user-details {
25
+ flex-grow: 1;
26
+ }
27
+ .user-name {
28
+ font-weight: bold;
29
+ margin: 0;
30
+ }
31
+ .user-email {
32
+ font-size: 0.8em;
33
+ color: #888;
34
+ margin: 0;
35
+ }
36
+ </style>
37
+ """, unsafe_allow_html=True)
38
+
39
+ user_avatar = st.session_state["user"].get("picture", "https://example.com/default-avatar.png")
40
+ user_name = st.session_state["user"].get("name", "User")
41
+ user_email = st.session_state["user"].get("email", "")
42
+
43
+ st.markdown(f"""
44
+ <div class="user-info">
45
+ <img src="{user_avatar}" class="user-avatar">
46
+ <div class="user-details">
47
+ <p class="user-name">{user_name}</p>
48
+ <p class="user-email">{user_email}</p>
49
+ </div>
50
+ </div>
51
+ """, unsafe_allow_html=True)
52
+
53
+ if st.button("Logout", key="logout_button", type="secondary", use_container_width=True):
54
+ for key in ["token", "user"]:
55
+ if key in st.session_state:
56
+ del st.session_state[key]
57
+ st.rerun()
utils.py CHANGED
@@ -100,6 +100,21 @@ def applyCommonStyles():
100
  height: 620px;
101
  }
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  </style>
104
  """,
105
  unsafe_allow_html=True
 
100
  height: 620px;
101
  }
102
 
103
+ div.stSpinner {
104
+ margin-left: auto;
105
+ margin-right: auto;
106
+ margin-top: 1rem;
107
+ width: fit-content;
108
+ }
109
+
110
+ div.stAlert {
111
+ margin-top: 1rem;
112
+ }
113
+
114
+ # section[data-testid="stSidebar"] {
115
+ # width: 20vw !important;
116
+ # }
117
+
118
  </style>
119
  """,
120
  unsafe_allow_html=True