import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
big_text = """
<div style='text-align: center;'>
    <h1 style='font-size: 30px;'>Knowledge Extraction 1</h1>
</div>
"""
st.markdown(big_text, unsafe_allow_html=True)
# One-time initialization: load the tokenizer and fine-tuned model into
# session state so they survive Streamlit reruns.
if 'is_initialized' not in st.session_state:
    st.session_state['is_initialized'] = True
    model_name = "EleutherAI/gpt-neo-125M"
    st.session_state.model_name = model_name
    # The tokenizer comes from the base GPT-Neo model; the weights are the
    # fine-tuned checkpoint.
    st.session_state.tokenizer = AutoTokenizer.from_pretrained(model_name)
    st.session_state.model = AutoModelForCausalLM.from_pretrained("zmbfeng/gpt-neo-125M_untethered_100_epochs_multiple_paragraph")
    if torch.cuda.is_available():
        st.session_state.device = torch.device("cuda")
        print("Using GPU:", torch.cuda.get_device_name(0))
    else:
        st.session_state.device = torch.device("cpu")
        print("GPU is not available, using CPU instead.")
    st.session_state.model.to(st.session_state.device)
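
# Note: the session_state guard above caches the model once per browser
# session. An alternative (not part of the original code) is Streamlit's
# @st.cache_resource, which shares one loaded model across all sessions.
# A minimal sketch, assuming the same checkpoint:
#
#   @st.cache_resource
#   def load_model():
#       return AutoModelForCausalLM.from_pretrained(
#           "zmbfeng/gpt-neo-125M_untethered_100_epochs_multiple_paragraph")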
#prompt = "Discuss the impact of artificial intelligence on modern society."
#prompt = "What is one of the best teachers in all of life?"
#prompt = "What is the necessary awareness for deep and meaningful relationships?"
#prompt = "What would happen if you knew you were going to die within a week or month?"
#prompt = "question: What is one of the best teachers in all of life? "
#prompt = "question: What would happen if death were to happen in an hour, week, or year?"
#=============
#prompt = "question: What if you live life fully?"
#prompt = "question: What does death do to you?"
#============
#prompt = "question: Do you understand that every minute you're on the verge of death?"
#most recent:
#prompt = "question: Are you going to wait until the last moment to let death be your teacher?"
query = st.text_input("Enter your query")
if query:
    prompt = "question: " + query
    input_ids = st.session_state.tokenizer(prompt, return_tensors="pt").input_ids.to(st.session_state.device)
    # Generate a response
    output = st.session_state.model.generate(
        input_ids,
        max_length=2048,
        do_sample=True,
        temperature=0.01,  # near-greedy sampling; gives an (almost) exact result for a single paragraph
        pad_token_id=st.session_state.tokenizer.eos_token_id,
    )
    # Decode the output
    response = st.session_state.tokenizer.decode(output[0], skip_special_tokens=True)
    st.write(response)
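
# Optional post-processing (an assumption, not part of the original app):
# a causal LM echoes the prompt at the start of its generated text, so the
# answer alone could be displayed like this:
#
#   answer = response[len(prompt):].strip() if response.startswith(prompt) else response
#   st.write(answer)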