Joanna30 committed on
Commit
c97b6ba
·
verified ·
1 Parent(s): 30f1ad1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import streamlit as st
2
  from streamlit_chat import message
3
- from langchain.chains import ConversationChain
4
  from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationSummaryMemory
5
- from langchain.llms import GooglePalm
6
- from langchain.prompts import PromptTemplate
 
 
 
7
 
8
  # Initialize session state variables
9
  if 'conversation' not in st.session_state:
@@ -11,31 +14,35 @@ if 'conversation' not in st.session_state:
11
  if 'messages' not in st.session_state:
12
  st.session_state['messages'] = []
13
  if 'API_Key' not in st.session_state:
14
- st.session_state['API_Key'] = ''
15
 
16
  # Setting page title and header
17
- st.set_page_config(page_title="ChatGPT Clone", page_icon=":robot_face:")
18
  st.markdown("<h1 style='text-align: center;'>How can I assist you?</h1>", unsafe_allow_html=True)
19
 
20
- # Sidebar input for the API key
21
- st.sidebar.title("⚙️ Settings")
22
- st.session_state['API_Key'] = st.sidebar.text_input("Enter your Google API key", type="password")
23
 
24
- # Summarization button
25
  summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
26
  if summarise_button:
27
- summary = st.session_state['conversation'].memory.buffer
28
- st.sidebar.write("Conversation Summary:\n\n" + summary)
 
 
29
 
30
  # Defining the get_response function
31
  def get_response(user_input, api_key):
32
  if st.session_state['conversation'] is None:
33
- # Use Google Palm as the LLM
34
- llm = GooglePalm(google_api_key=api_key, temperature=0.7)
 
 
 
35
  st.session_state['conversation'] = ConversationChain(
36
  llm=llm,
37
  verbose=True,
38
- memory=ConversationSummaryMemory(llm=llm)
39
  )
40
  response = st.session_state['conversation'].predict(input=user_input)
41
  return response
@@ -45,16 +52,18 @@ response_container = st.container()
45
  container = st.container()
46
  with container:
47
  with st.form(key='my_form', clear_on_submit=True):
48
- user_input = st.text_area("Your question here:", key='input', height=100)
49
  submit_button = st.form_submit_button(label='Send')
50
- if submit_button and user_input:
51
  st.session_state['messages'].append(user_input)
52
- model_response = get_response(user_input, st.session_state['API_Key'])
 
53
  st.session_state['messages'].append(model_response)
54
 
55
  with response_container:
56
- for i in range(len(st.session_state['messages'])):
57
- if (i % 2) == 0:
58
- message(st.session_state['messages'][i], is_user=True, key=str(i) + '_user')
59
- else:
60
- message(st.session_state['messages'][i], key=str(i) + '_AI')
 
 
1
  import streamlit as st
2
  from streamlit_chat import message
3
+ from langchain import HuggingFaceHub, ConversationChain
4
  from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationSummaryMemory
5
+ from langchain.memory import ConversationBufferWindowMemory
6
+ import os
7
+
8
+ # Get API key from environment variable
9
+ hf_api_key = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
10
 
11
  # Initialize session state variables
12
  if 'conversation' not in st.session_state:
 
14
  if 'messages' not in st.session_state:
15
  st.session_state['messages'] = []
16
  if 'API_Key' not in st.session_state:
17
+ st.session_state['API_Key'] = '' # No need for API key input in this case
18
 
19
  # Setting page title and header
20
+ st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")
21
  st.markdown("<h1 style='text-align: center;'>How can I assist you?</h1>", unsafe_allow_html=True)
22
 
23
+ # Sidebar (no API key input needed)
24
+ st.sidebar.title("Options")
 
25
 
26
+ # Summarization Button
27
  summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
28
  if summarise_button:
29
+ with st.spinner("Summarizing..."): # Add a spinner for visual feedback
30
+ # Access the buffer for ConversationBufferMemory
31
+ summary = st.session_state['conversation'].memory.buffer
32
+ st.sidebar.write("Conversation Summary:\n\n" + summary)
33
 
34
  # Defining the get_response function
35
  def get_response(user_input, api_key):
36
  if st.session_state['conversation'] is None:
37
+ llm = HuggingFaceHub(
38
+ repo_id="google/gemini-pro-flash", # Use Gemini Pro Flash
39
+ model_kwargs={"temperature": 0.1, "max_new_tokens": 512}
40
+ )
41
+ # Use ConversationBufferMemory for summarization
42
  st.session_state['conversation'] = ConversationChain(
43
  llm=llm,
44
  verbose=True,
45
+ memory=ConversationBufferMemory()
46
  )
47
  response = st.session_state['conversation'].predict(input=user_input)
48
  return response
 
52
  container = st.container()
53
  with container:
54
  with st.form(key='my_form', clear_on_submit=True):
55
+ user_input = st.text_area("Your question goes here:", key='input', height=100)
56
  submit_button = st.form_submit_button(label='Send')
57
+ if submit_button and user_input: # Check if user_input is not empty
58
  st.session_state['messages'].append(user_input)
59
+ with st.spinner("Thinking..."): # Add a spinner
60
+ model_response = get_response(user_input, hf_api_key)
61
  st.session_state['messages'].append(model_response)
62
 
63
  with response_container:
64
+ if st.session_state['messages']: # Check if there are messages to display
65
+ for i in range(len(st.session_state['messages'])):
66
+ if (i % 2) == 0:
67
+ message(st.session_state['messages'][i], is_user=True, key=str(i) + '_user')
68
+ else:
69
+ message(st.session_state['messages'][i], key=str(i) + '_AI')