Rhaps360 committed on
Commit 053fd9f · verified · 1 Parent(s): bea88ec

updated usage card


This is the Gemma 2B model from Google, fine-tuned on a mental health conversations dataset.

Files changed (1)
  1. README.md +24 -20
README.md CHANGED
@@ -4,6 +4,9 @@ tags:
  - text-generation-inference
  - text-generation
  - peft
+ - chatbot
+ - depression
+ - therapy
  library_name: transformers
  widget:
  - messages:
@@ -18,28 +21,29 @@ This model was trained using AutoTrain. For more information, please visit [Auto
 
  # Usage
 
- ```python
- 
- from transformers import AutoModelForCausalLM, AutoTokenizer
- 
- model_path = "PATH_TO_THIS_REPO"
- 
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- model = AutoModelForCausalLM.from_pretrained(
-     model_path,
-     device_map="auto",
-     torch_dtype='auto'
- ).eval()
- 
- # Prompt content: "hi"
+ from transformers import AutoTokenizer, pipeline
+ import torch
+ 
+ model = "Rhaps360/gemma-dep-ins-ft"
+ 
+ tokenizer = AutoTokenizer.from_pretrained(model)
+ pipeline = pipeline(
+     "text-generation",
+     model=model,
+     model_kwargs={"torch_dtype": torch.bfloat16},
+     device="cuda" if(torch.cuda.is_available()) else "cpu",
+ )
+ 
  messages = [
-     {"role": "user", "content": "hi"}
+     {"role": "user", "content": "### Context: the input message goes here. ### Response: "}
  ]
- 
- input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
- output_ids = model.generate(input_ids.to('cuda'))
- response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
- 
- # Model response: "Hello! How can I assist you today?"
- print(response)
- ```
+ prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ outputs = pipeline(
+     prompt,
+     max_new_tokens=300,
+     do_sample=True,
+     temperature=0.2,
+     top_k=50,
+     top_p=0.95
+ )
+ print(outputs[0]["generated_text"][len(prompt):])
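For convenience, here is a runnable rendition of the usage snippet added in this diff. It is a sketch, not the card's official example: the pipeline object is bound to a separate name (`generator`, my own choice) so it does not shadow the imported `pipeline` factory, and the standalone `AutoTokenizer` load from the diff is dropped because the pipeline builds its own tokenizer. The model id `Rhaps360/gemma-dep-ins-ft`, the generation parameters, and the `### Context: ... ### Response:` prompt format are taken directly from the code added above.

```python
# Minimal sketch based on the usage code added in this commit.
# Assumptions not in the diff: the variable names `generator` and `model_id`.
import torch
from transformers import pipeline

model_id = "Rhaps360/gemma-dep-ins-ft"  # repo id as it appears in the diff

generator = pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# The fine-tune expects the "### Context: ... ### Response:" prompt format shown above.
messages = [
    {"role": "user", "content": "### Context: the input message goes here. ### Response: "}
]
prompt = generator.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

outputs = generator(
    prompt,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.2,
    top_k=50,
    top_p=0.95,
)
# Strip the prompt so only the model's reply is printed.
print(outputs[0]["generated_text"][len(prompt):])
```

On a CPU-only machine the bfloat16 weights may run slowly; switching `model_kwargs` to `{"torch_dtype": torch.float32}` is a reasonable fallback.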