from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import streamlit as st
from streamlit_chat import message
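# The checkpoint "." loads from the Space's root directory, which is assumed
# to hold the fine-tuned model weights, config.json, and tokenizer files.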
checkpoint = "."
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
@st.cache_resource  # st.cache is deprecated for models; cache_resource keeps one instance across reruns
def get_model():
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return model
st.title("Chat with myGPT 🦄")
st.write("This is an LLM that was fine-tuned on a dataset of daily conversations.")
# Reset the conversation every three turns to keep the generation context short
if 'count' not in st.session_state or st.session_state.count >= 3:
    st.session_state.count = 0
    st.session_state.chat_history_ids = None
    st.session_state.old_response = ''
else:
    st.session_state.count += 1
# Initialize per-session chat state on the first run
if 'message_history' not in st.session_state:
    st.session_state.message_history = []
if 'response_history' not in st.session_state:
    st.session_state.response_history = []
if 'input' not in st.session_state:
    st.session_state.input = ''
def submit():
    # Move the submitted text into session state and clear the input box
    st.session_state.input = st.session_state.user_input
    st.session_state.user_input = ''
# prompt = "How long will it take for the poc to finish?"
# inputs = tokenizer(prompt, return_tensors="pt")
model = get_model()
# Beam-sample decoding (num_beams > 1 with do_sample=True); penalty_alpha is a
# contrastive-search parameter and has no effect in this mode, while top_k
# restricts each sampling step to the 4 most likely tokens.
generation_config = GenerationConfig(
    max_new_tokens=32,
    num_beams=4,
    early_stopping=True,
    no_repeat_ngram_size=2,
    do_sample=True,
    penalty_alpha=0.6,
    top_k=4,
    # top_p=0.95,
    # temperature=0.8,
    pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
)
# Replay the conversation so far
for i in range(len(st.session_state.message_history)):
    message(st.session_state.message_history[i], is_user=True, key=str(i) + '_user',
            avatar_style="identicon", seed='You')  # display all previous user messages
    if i < len(st.session_state.response_history):
        message(st.session_state.response_history[i], key=str(i),
                avatar_style="bottts", seed='mera GPT')  # and the bot's replies
placeholder = st.empty() # placeholder for latest message
st.text_input('You:', key='user_input', on_change=submit)
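# Note: the placeholder was created above the input box, so the newest
# exchange renders above the box once submit() fires and the script reruns.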
if st.session_state.input:
    st.session_state.message_history.append(st.session_state.input)
    # DialoGPT-style formatting: the user's utterance followed by the EOS token
    new_user_input_ids = tokenizer.encode(st.session_state.input + tokenizer.eos_token, return_tensors="pt")
    # Append the new input to the running history, unless this is the first turn
    if st.session_state.count > 1 and st.session_state.chat_history_ids is not None:
        bot_input_ids = torch.cat([st.session_state.chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    st.session_state.chat_history_ids = model.generate(bot_input_ids, generation_config=generation_config)
    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    # If the model repeats its last answer, retry without the accumulated history
    if st.session_state.old_response == response:
        bot_input_ids = new_user_input_ids
        st.session_state.chat_history_ids = model.generate(bot_input_ids, generation_config=generation_config)
        response = tokenizer.decode(st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    # st.write(f"meraGPT: {response}")
    st.session_state.old_response = response
    st.session_state.response_history.append(response)
    with placeholder.container():
        message(st.session_state.message_history[-1], is_user=True, key=str(-1) + '_user',
                avatar_style="identicon", seed='You')  # display the latest user message
        message(st.session_state.response_history[-1], key=str(-1),
                avatar_style="bottts", seed='mera GPT')  # and the latest bot reply
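
# To run locally (assuming this script is saved as app.py next to the model
# files): streamlit run app.py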