GouthamVarma commited on
Commit
5886e34
·
verified ·
1 Parent(s): 3b5600e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ # Load the saved model and tokenizer
6
+ model_path = "GouthamVarma/mentalhealth_coversational_chatbot" # You'll need to upload your model to HF Hub first
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ model_path,
10
+ device_map="auto",
11
+ torch_dtype=torch.float16,
12
+ trust_remote_code=True
13
+ )
14
+
15
+ def chat_response(message, history):
16
+ formatted_prompt = f"User: {message}\nAssistant: "
17
+
18
+ inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
19
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
20
+
21
+ outputs = model.generate(
22
+ **inputs,
23
+ max_length=512,
24
+ num_return_sequences=1,
25
+ temperature=0.7,
26
+ do_sample=True,
27
+ top_p=0.85,
28
+ top_k=40,
29
+ no_repeat_ngram_size=3,
30
+ repetition_penalty=1.3,
31
+ pad_token_id=tokenizer.eos_token_id
32
+ )
33
+
34
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+ response = response.split("Assistant: ")[-1].strip()
36
+
37
+ return response
38
+
39
+ # Create Gradio Interface
40
+ demo = gr.ChatInterface(
41
+ fn=chat_response,
42
+ title="Mental Health Support Assistant",
43
+ description="A supportive AI assistant trained to provide empathetic responses to mental health concerns. Please note: This is not a replacement for professional mental health support.",
44
+ theme="soft",
45
+ examples=[
46
+ "I've been feeling really anxious lately about work.",
47
+ "I can't sleep at night because of stress.",
48
+ "I feel lonely and isolated."
49
+ ]
50
+ )
51
+
52
+ if __name__ == "__main__":
53
+ demo.launch()