import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer once per session (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    model_name = "tiiuae/falcon-7b-instruct"  # Replace with the desired Falcon model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",          # Automatically assign model layers to available GPUs/CPUs
        torch_dtype=torch.float16,  # Use FP16 for faster inference
    )
    return model, tokenizer

model, tokenizer = load_model()

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Sidebar configuration
st.sidebar.title("Chatbot Settings")
st.sidebar.write("Customize your chatbot:")
max_new_tokens = st.sidebar.slider("Max Response Length (Tokens)", 50, 500, 150)
temperature = st.sidebar.slider("Response Creativity (Temperature)", 0.1, 1.0, 0.7)

# App title
st.title("🤖 Falcon Chatbot")

# Chat interface
st.write("### Chat with the bot:")
user_input = st.text_input("You:", key="user_input", placeholder="Type your message here...")

# Only generate for a new message, so unrelated reruns (e.g. moving a slider)
# don't re-submit the same input and duplicate it in the history
if user_input and user_input != st.session_state.get("last_input"):
    st.session_state["last_input"] = user_input

    # Add user input to chat history
    st.session_state["messages"].append(f"User: {user_input}")

    # Prepare input for the model
    prompt = "\n".join(st.session_state["messages"]) + "\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Generate response
    with st.spinner("Thinking..."):
        output = model.generate(
            **inputs,                       # pass attention_mask along with input_ids
            max_new_tokens=max_new_tokens,  # cap new tokens only; max_length would count the prompt too
            do_sample=True,                 # required for temperature to have any effect
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, then cut off any hallucinated next user turn
    new_tokens = output[0][inputs.input_ids.shape[-1]:]
    bot_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    bot_response = bot_response.split("User:")[0].strip()

    # Add bot response to chat history
    st.session_state["messages"].append(f"Assistant: {bot_response}")

# Display chat history
for msg in st.session_state["messages"]:
    if msg.startswith("User:"):
        st.markdown(f"**{msg}**")
    elif msg.startswith("Assistant:"):
        st.markdown(f"> {msg}")

# Clear chat history button
if st.button("Clear Chat"):
    st.session_state["messages"] = []
    st.session_state["last_input"] = None
    st.rerun()  # st.rerun() replaces the deprecated st.experimental_rerun()
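
# Optional: in FP16 the 7B model needs roughly 14 GB of GPU memory. If that
# does not fit, loading the weights 4-bit quantized is a common fallback. The
# sketch below is a minimal, untested variant assuming the `bitsandbytes` and
# `accelerate` packages are installed; it would replace the from_pretrained()
# call inside load_model() above.
#
# from transformers import BitsAndBytesConfig
#
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,                     # quantize weights to 4-bit
#     bnb_4bit_compute_dtype=torch.float16,  # run compute in FP16
# )
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map="auto",
#     quantization_config=bnb_config,
# )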