|
|
|
import logging
import os

import gradio as gr
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceEndpoint
|
|
|
|
|
# Hugging Face API token read from the environment (None when HF_TOKEN is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# LangChain LLM wrapper around a Hugging Face endpoint serving the
# instruction-tuned google/gemma-1.1-7b-it model for text generation.
# A low temperature and a small top_k favour focused, low-variance replies;
# each reply is capped at 512 newly generated tokens.
llm = HuggingFaceEndpoint(
    repo_id="google/gemma-1.1-7b-it",
    task="text-generation",
    max_new_tokens=512,
    top_k=5,
    temperature=0.1,
    repetition_penalty=1.03,
    huggingfacehub_api_token=HF_TOKEN,
)
|
# Prompt sent to the model on every turn. `context` receives the prior
# conversation passed in by `predict`, and `question` receives the latest user
# message. Each section is labelled so the model can tell the
# instruction, the history and the question apart.
template = """
You are a Mental Health Chatbot. Help the user with their mental health concerns.

Use the conversation so far, shown under "Context", to answer the question.

Context:
{context}

Question: {question}

Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)
|
|
|
def _format_history(history):
    """Render Gradio chat history as a plain-text transcript for the prompt.

    Handles both Gradio history formats: ``[user, assistant]`` pairs
    ("tuples" format) and ``{"role": ..., "content": ...}`` dicts
    ("messages" format). Empty turns, such as a still-pending assistant
    reply, are skipped. Returns an empty string when there is no history.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if content:
                speaker = "User" if turn.get("role") == "user" else "Assistant"
                lines.append(f"{speaker}: {content}")
        else:
            user_msg, bot_msg = turn
            if user_msg:
                lines.append(f"User: {user_msg}")
            if bot_msg:
                lines.append(f"Assistant: {bot_msg}")
    return "\n".join(lines)


def predict(message, history):
    """Return the chatbot's reply to ``message`` given the conversation so far.

    This is the callback passed to ``gr.ChatInterface``.

    Args:
        message: The user's latest message.
        history: Earlier turns as supplied by Gradio. They are rendered into
            the prompt's ``context`` slot by ``_format_history``.

    Returns:
        The model's reply with surrounding whitespace removed, or a fixed
        apology when the model produced no non-blank text.
    """
    # Pass a readable transcript instead of the raw Python repr of the
    # history list, which is what formatting `history` directly would insert.
    input_prompt = QA_CHAIN_PROMPT.format(
        question=message, context=_format_history(history)
    )
    result = llm.generate([input_prompt])
    # Log at debug level instead of calling print(). The result holds the
    # model's reply to potentially sensitive user messages, so it should not
    # go to stdout by default.
    logging.getLogger(__name__).debug("LLM result: %s", result)

    # `generations` holds one list of candidates per prompt. Guard both levels
    # so an empty response falls back to the apology instead of raising
    # IndexError, and treat a blank reply the same way.
    generations = result.generations
    if generations and generations[0]:
        reply = generations[0][0].text.strip()
        if reply:
            return reply
    return "I'm sorry, I couldn't generate a response for that input."
|
|
|
|
|
# Build the Gradio chat UI around `predict` and start serving it.
# NOTE(review): this runs at import time, so importing this module launches
# the app. Consider an `if __name__ == "__main__":` guard if that is unwanted.
demo = gr.ChatInterface(predict)
demo.launch()
|
|