vivicai committed
Commit 34b19b3 • Parent: 94d07d7

Update README.md

Files changed (1): README.md (+38 -0)
README.md CHANGED
@@ -10,15 +10,53 @@ license: apache-2.0
  <p align="center">
  🌐 <a href="https://tigerbot.com/" target="_blank">TigerBot</a> • 🤗 <a href="https://huggingface.co/TigerResearch" target="_blank">Hugging Face</a>
  </p>
+
  ## Github
+
  https://github.com/TigerResearch/TigerBot

  ## Usage

  ```python
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from accelerate import infer_auto_device_map, dispatch_model
+ from accelerate.utils import get_balanced_memory

  tokenizer = AutoTokenizer.from_pretrained("TigerResearch/tigerbot-7b-sft")

  model = AutoModelForCausalLM.from_pretrained("TigerResearch/tigerbot-7b-sft")
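+ # Balance the 7B model across all visible GPUs; BloomBlock is kept whole
+ # on a single device because TigerBot-7B is BLOOM-based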
+ max_memory = get_balanced_memory(model)
+ device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["BloomBlock"])
+ model = dispatch_model(model, device_map=device_map, offload_buffers=True)
+
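+ # dispatch_model adds hooks that move activations between devices, so
+ # inputs only need to be placed on the current CUDA device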
+ device = torch.cuda.current_device()
+
+
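+ # TigerBot's instruction-tuning prompt template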
+ tok_ins = "\n\n### Instruction:\n"
+ tok_res = "\n\n### Response:\n"
+ prompt_input = tok_ins + "{instruction}" + tok_res
+
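+ # Wrap a sample question in the template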
+ input_text = "What is the next number after this list: [1, 2, 3, 5, 8, 13, 21]"
+ input_text = prompt_input.format_map({'instruction': input_text})
+
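+ # Sampling settings; max_length counts prompt plus generated tokens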
+ max_input_length = 512
+ max_generate_length = 1024
+ generation_kwargs = {
+     "do_sample": True,  # enable sampling so top_p and temperature take effect
+     "top_p": 0.95,
+     "temperature": 0.8,
+     "max_length": max_generate_length,
+     "eos_token_id": tokenizer.eos_token_id,
+     "pad_token_id": tokenizer.pad_token_id,
+     "early_stopping": True,
+     "no_repeat_ngram_size": 4,
+ }
+
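+ # Tokenize, move the input tensors to the GPU, and generate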
+ inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
+ inputs = {k: v.to(device) for k, v in inputs.items()}
+ output = model.generate(**inputs, **generation_kwargs)
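+ # Decode only the tokens generated after the prompt, skipping EOS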
+ answer = ''
+ for tok_id in output[0][inputs['input_ids'].shape[1]:]:
+     if tok_id != tokenizer.eos_token_id:
+         answer += tokenizer.decode(tok_id)
+ print(answer)
  ```