import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer once and cache them across Streamlit reruns
@st.cache_resource
def load_model_and_tokenizer():
    model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        trust_remote_code=False,
        revision="main",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # The tokenizer ships without a pad token; fall back to the EOS token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()

# Define the prompt template
def generate_prompt(comment):
    instructions = """Virtual Psychologist, communicates with empathy and understanding, focusing on mental health support and providing advice within its expertise. \
It actively listens, acknowledges emotions, and avoids overly clinical or technical language unless specifically requested. \
It reacts to feedback with warmth and adjusts its tone to match the individual's needs, offering encouragement and validation as appropriate. \
Responses are tailored in length and tone to ensure a supportive and conversational experience."""
    return f"[INST] {instructions} \n{comment} \n[/INST]"

# Generate a response for a user comment
def get_response(comment):
    prompt = generate_prompt(comment)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    # Use CUDA if available, otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_new_tokens=140,
        pad_token_id=tokenizer.pad_token_id,  # ensure padding is handled properly
    )
    # generate() returns the prompt followed by the completion, so keep only
    # the text after the closing [/INST] tag
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("[/INST]")[-1].strip()

# Streamlit app layout
st.title("Virtual Psychologist")
st.markdown(
    "This virtual psychologist offers empathetic responses to your comments "
    "or questions. Enter your message below."
)

user_input = st.text_input("Your Comment/Question:", placeholder="Type here...")

if user_input:
    with st.spinner("Generating response..."):
        response = get_response(user_input)
    st.write("### Response:")
    st.write(response)

st.markdown(
    "Built with ❤️ using "
    "[Hugging Face Transformers](https://huggingface.co./transformers/) "
    "and [Streamlit](https://streamlit.io/)."
)
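
# ---------------------------------------------------------------------------
# Usage note (a sketch of the setup, assuming this file is saved as app.py):
# GPTQ checkpoints such as TheBloke/Mistral-7B-Instruct-v0.2-GPTQ are loaded
# through the optimum/auto-gptq backend, so those packages need to be
# installed alongside transformers and torch, e.g.:
#     pip install streamlit torch transformers optimum auto-gptq
# The app is then launched with:
#     streamlit run app.py
# ---------------------------------------------------------------------------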