from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import streamlit as st
from streamlit_chat import message

checkpoint = "."
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
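
# Load the fine-tuned causal LM from the Space's root directory (the same
# `checkpoint` path the tokenizer is loaded from).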
def get_model():
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return model

st.title("Chat with myGPT")
st.write("This is an LLM that was fine-tuned on a dataset of daily conversations.")
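
# Session-state bookkeeping: a turn counter that, once it reaches 3, clears the
# cached chat history and the last response so the dialogue context does not
# grow without bound across reruns.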
if 'count' not in st.session_state or st.session_state.count >= 3:
    st.session_state.count = 0
    st.session_state.chat_history_ids = None
    st.session_state.old_response = ''
else:
    st.session_state.count += 1

if 'message_history' not in st.session_state:
    st.session_state.message_history = []
if 'response_history' not in st.session_state:
    st.session_state.response_history = []
if 'input' not in st.session_state:
    st.session_state.input = ''
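
# Callback for the text box below: copy the submitted text into `input` and
# clear the widget so the field is empty on the next rerun.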
def submit():
    st.session_state.input = st.session_state.user_input
    st.session_state.user_input = ''

# prompt = "How long will it take for the poc to finish?"
# inputs = tokenizer(prompt, return_tensors="pt")
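
# Load the model and fix the decoding settings: at most 32 new tokens, 4 beams
# with sampling, no repeated 2-grams, and EOS reused as the pad token.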
model = get_model()

generation_config = GenerationConfig(max_new_tokens=32,
                                     num_beams=4,
                                     early_stopping=True,
                                     no_repeat_ngram_size=2,
                                     do_sample=True,
                                     penalty_alpha=0.6,
                                     top_k=4,
                                     # top_p=0.95,
                                     # temperature=0.8,
                                     pad_token_id=tokenizer.eos_token_id)
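
# Re-render the conversation so far: every stored user message followed by the
# model response at the same index, when one exists.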
for i in range(0, len(st.session_state.message_history)):
    message(st.session_state.message_history[i], is_user=True, key=str(i) + '_user', avatar_style="identicon", seed='You')  # display all the previous messages
    if i < len(st.session_state.response_history):
        message(st.session_state.response_history[i], key=str(i), avatar_style="bottts", seed='mera GPT')

placeholder = st.empty()  # placeholder for the latest message
st.text_input('You:', key='user_input', on_change=submit)
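
# Main turn: encode the new user message, prepend the running chat history when
# this is not the first exchange, generate a reply, and decode only the newly
# generated tokens. If the reply repeats the previous one, retry without history.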
if st.session_state.input:
    st.session_state.message_history.append(st.session_state.input)
    new_user_input_ids = tokenizer.encode(tokenizer.eos_token + st.session_state.input, return_tensors="pt")
    bot_input_ids = torch.cat([st.session_state.chat_history_ids, new_user_input_ids], dim=-1) if st.session_state.count > 1 else new_user_input_ids

    st.session_state.chat_history_ids = model.generate(bot_input_ids, generation_config=generation_config)
    response = tokenizer.decode(st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    if st.session_state.old_response == response:
        bot_input_ids = new_user_input_ids
        st.session_state.chat_history_ids = model.generate(bot_input_ids, generation_config=generation_config)
        response = tokenizer.decode(st.session_state.chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # st.write(f"meraGPT: {response}")
    st.session_state.old_response = response
    st.session_state.response_history.append(response)

    with placeholder.container():
        message(st.session_state.message_history[-1], is_user=True, key=str(-1) + '_user', avatar_style="identicon", seed='You')  # display the latest message
        message(st.session_state.response_history[-1], key=str(-1), avatar_style="bottts", seed='mera GPT')  # display the latest message