Joanna30 commited on
Commit
dbf29f5
·
verified ·
1 Parent(s): b14ce1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -21
app.py CHANGED
@@ -2,7 +2,8 @@ import streamlit as st
2
  from streamlit_chat import message
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
  from langchain.chains import ConversationChain
5
- from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationSummaryMemory
 
6
 
7
  # Step 1: Set up Google API key
8
  google_api_key = st.secrets["google_api_key"]
@@ -19,10 +20,32 @@ if 'API_Key' not in st.session_state:
19
  st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")
20
  st.markdown("<h1 style='text-align: center;'>How can I assist you? </h1>", unsafe_allow_html=True)
21
 
22
- # Sidebar for API key input
23
  st.sidebar.title("To start chatting,")
24
  st.session_state['API_Key'] = st.sidebar.text_input("Enter your Google API key below", type="password", key="google_api_key_input")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Summarization button
27
  summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
28
 
@@ -33,7 +56,7 @@ if summarise_button:
33
 
34
  # Split summary into sentences
35
  summary_sentences = summary.strip().split(". ")
36
-
37
  # Exclude the first two sentences
38
  filtered_summary = summary_sentences[2:]
39
 
@@ -49,22 +72,30 @@ if summarise_button:
49
  else:
50
  st.sidebar.write("No conversation history to summarize.")
51
 
52
-
53
  # Step 4: Define the getresponse function using Google's Gemini
54
- def getresponse(userInput, api_key):
55
- if st.session_state['conversation'] is None:
56
- # Initialize the Google generative model
57
- chat = ChatGoogleGenerativeAI(
58
- model="gemini-1.5-flash",
59
- google_api_key=api_key
60
- )
61
- st.session_state['conversation'] = ConversationChain(
62
- llm=chat,
63
- verbose=True,
64
- memory=ConversationSummaryMemory(llm=chat)
65
- )
66
- response = st.session_state['conversation'].predict(input=userInput)
67
- return response
 
 
 
 
 
 
 
 
 
68
 
69
  # Step 5: Creating the Chat UI
70
  response_container = st.container()
@@ -75,10 +106,14 @@ with container:
75
  user_input = st.text_area("Your question goes here:", key='input', height=100)
76
  submit_button = st.form_submit_button(label='Send')
77
  if submit_button:
78
- st.session_state['messages'].append(user_input)
79
- model_response = getresponse(user_input, st.session_state['API_Key'])
80
- st.session_state['messages'].append(model_response)
 
 
 
81
 
 
82
  with response_container:
83
  for i in range(len(st.session_state['messages'])):
84
  if (i % 2) == 0:
 
2
  from streamlit_chat import message
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
  from langchain.chains import ConversationChain
5
+ from langchain.chains.conversation.memory import ConversationSummaryMemory
6
+ from langchain_google_genai import GoogleGenerativeAIError
7
 
8
  # Step 1: Set up Google API key
9
  google_api_key = st.secrets["google_api_key"]
 
20
  st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")
21
  st.markdown("<h1 style='text-align: center;'>How can I assist you? </h1>", unsafe_allow_html=True)
22
 
23
+ # Sidebar for API key input and model selection
24
  st.sidebar.title("To start chatting,")
25
  st.session_state['API_Key'] = st.sidebar.text_input("Enter your Google API key below", type="password", key="google_api_key_input")
26
 
27
+ # Support multiple models
28
+ st.sidebar.markdown("### Select Model:")
29
+ model_name = st.sidebar.selectbox(
30
+ "Choose a model:",
31
+ ["gemini-1.5-flash", "gemini-1.5-pro"],
32
+ index=0
33
+ )
34
+
35
+ # Add instructions for users
36
+ if 'welcome' not in st.session_state:
37
+ st.session_state['welcome'] = True
38
+
39
+ if st.session_state['welcome']:
40
+ st.sidebar.info(
41
+ "### Instructions:\n"
42
+ "1. Enter your Google API key (optional if pre-configured).\n"
43
+ "2. Choose a model from the dropdown menu.\n"
44
+ "3. Type your question in the text area and click 'Send'.\n"
45
+ "4. Click 'Summarise the conversation' to view a summary of your chat."
46
+ )
47
+ st.session_state['welcome'] = False
48
+
49
  # Summarization button
50
  summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
51
 
 
56
 
57
  # Split summary into sentences
58
  summary_sentences = summary.strip().split(". ")
59
+
60
  # Exclude the first two sentences
61
  filtered_summary = summary_sentences[2:]
62
 
 
72
  else:
73
  st.sidebar.write("No conversation history to summarize.")
74
 
 
75
  # Step 4: Define the getresponse function using Google's Gemini
76
+ def getresponse(userInput, api_key, model_name):
77
+ try:
78
+ if st.session_state['conversation'] is None:
79
+ # Initialize the Google generative model
80
+ with st.spinner("Setting up the conversation..."):
81
+ chat = ChatGoogleGenerativeAI(
82
+ model=model_name,
83
+ google_api_key=api_key
84
+ )
85
+ st.session_state['conversation'] = ConversationChain(
86
+ llm=chat,
87
+ verbose=True,
88
+ memory=ConversationSummaryMemory(llm=chat)
89
+ )
90
+
91
+ # Get response with loading indicator
92
+ with st.spinner("Generating response..."):
93
+ response = st.session_state['conversation'].predict(input=userInput)
94
+ return response
95
+
96
+ except GoogleGenerativeAIError as e:
97
+ st.error(f"API Error: {str(e)}")
98
+ return "Sorry, there was an issue processing your request."
99
 
100
  # Step 5: Creating the Chat UI
101
  response_container = st.container()
 
106
  user_input = st.text_area("Your question goes here:", key='input', height=100)
107
  submit_button = st.form_submit_button(label='Send')
108
  if submit_button:
109
+ if user_input.strip(): # Check for empty input
110
+ st.session_state['messages'].append(user_input)
111
+ model_response = getresponse(user_input, st.session_state['API_Key'], model_name)
112
+ st.session_state['messages'].append(model_response)
113
+ else:
114
+ st.warning("Please enter a message before sending.")
115
 
116
+ # Display chat messages
117
  with response_container:
118
  for i in range(len(st.session_state['messages'])):
119
  if (i % 2) == 0: