MaxJeblick committed
Commit 7581874
1 Parent(s): ba8a171

Update README.md

Files changed (1): README.md (+40 -13)
README.md CHANGED
@@ -27,17 +27,44 @@ tokenizer.push_to_hub(repo_name, private=False)
 config.push_to_hub(repo_name, private=False)
 ```
 
-Use the following configuration in [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio) to run a complete experiment in **5 seconds** using the default dataset and default settings otherwise:
+Below is a small example that will run in ~1 second.
 
-```yaml
-Validation Size: 0.1
-Data Sample: 0.1
-Max Length Prompt: 32
-Max Length Answer: 32
-Max Length: 64
-Backbone Dtype: float16
-Gradient Checkpointing: False
-Batch Size: 8
-Max Length Inference: 16
-```
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+
+def test_manual_greedy_generate():
+    max_new_tokens = 10
+
+    # note this is on CPU!
+    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
+    input_ids = model.dummy_inputs["input_ids"]
+
+    y = model.generate(input_ids, max_new_tokens=max_new_tokens)
+
+    assert y.shape == (3, input_ids.shape[1] + max_new_tokens)  # dummy_inputs has batch size 3
+
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            outputs = model(input_ids)
+
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
+
+        input_ids = torch.cat([input_ids, next_token_id], dim=-1)  # greedy decoding, one token per step
+
+    assert torch.allclose(y, input_ids)
+```
+
+Tip:
+
+Use fixtures with session scope to load the model only once. This will decrease test runtime further.
+
+```python
+import pytest
+from transformers import AutoModelForCausalLM
+@pytest.fixture(scope="session")
+def model():
+    return AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
+```
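
For illustration, a minimal sketch of a test that consumes the session-scoped fixture; the test name and body are illustrative additions, not part of the commit:

```python
def test_generate_shape(model):
    # pytest injects the session-scoped `model` fixture defined above,
    # so the checkpoint is loaded only once per test session
    input_ids = model.dummy_inputs["input_ids"]
    y = model.generate(input_ids, max_new_tokens=5)
    # batch size is preserved and 5 new tokens are appended
    assert y.shape == (input_ids.shape[0], input_ids.shape[1] + 5)
```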