Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,93 @@
|
|
1 |
-
---
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- ja
|
4 |
+
license: apache-2.0
|
5 |
+
tags:
|
6 |
+
- ja
|
7 |
+
- japanese
|
8 |
+
- text-generation
|
9 |
+
- lm
|
10 |
+
- jax
|
11 |
+
- flax
|
12 |
+
- lm1b
|
13 |
+
datasets:
|
14 |
+
- wiki40b
|
15 |
+
---
|
16 |
+
# transformer-lm-japanese-1.0b
|
17 |
+
|
18 |
+
This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).
|
19 |
+
|
20 |
+
## Source Code
|
21 |
+
|
22 |
+
We've modified Flax's 'lm1b' example to train on Japanese dataset. You can find the code on Github.
|
23 |
+
|
24 |
+
* [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)
|
25 |
+
|
26 |
+
## Our Blog Post
|
27 |
+
|
28 |
+
* [【0.1Bから作るLLM】 JAX/Flaxで作るTransformer言語モデル](https://zenn.dev/fukugawa/articles/4446573ec0f697)
|
29 |
+
|
30 |
+
## Model Details
|
31 |
+
|
32 |
+
| Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
|
33 |
+
|-|-|-|-|-|-|-|-|-|
|
34 |
+
| transformer-lm-japanese-1.0b | 1.0B | 18 | 2048 | 16 | wiki40b/ja | 2.19GB | 4 days | 31.47 |
|
35 |
+
|
36 |
+
## Usage: FlaxAutoModel
|
37 |
+
|
38 |
+
#### Requirements:
|
39 |
+
|
40 |
+
```
|
41 |
+
pip install transformers>=4.39.0
|
42 |
+
pip install jax==0.4.31
|
43 |
+
pip install flax==0.8.3
|
44 |
+
pip install sentencepiece==0.1.99
|
45 |
+
|
46 |
+
# For CPU
|
47 |
+
pip install -U "jax[cpu]==0.4.31"
|
48 |
+
|
49 |
+
# For GPU
|
50 |
+
pip install -U "jax[cuda12]==0.4.31"
|
51 |
+
```
|
52 |
+
|
53 |
+
Note: Set **trust_remote_code=True** to load our custom model.
|
54 |
+
|
55 |
+
~~~~python
|
56 |
+
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
57 |
+
|
58 |
+
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
|
59 |
+
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
|
60 |
+
|
61 |
+
text = "日本の首都は、"
|
62 |
+
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)
|
63 |
+
|
64 |
+
output_ids = model.generate(
|
65 |
+
token_ids,
|
66 |
+
do_sample=True,
|
67 |
+
temperature=0.6,
|
68 |
+
top_k=20,
|
69 |
+
max_new_tokens=100
|
70 |
+
)
|
71 |
+
|
72 |
+
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
|
73 |
+
print(output)
|
74 |
+
~~~~
|
75 |
+
|
76 |
+
We tested text generation in a Python 3.10 environment on GCP as follows
|
77 |
+
|
78 |
+
* GPU Type: NVIDIA L4 (x 1)
|
79 |
+
* Machine Type: g2-standard-16 (16 CPUs, 64GB Memory)
|
80 |
+
* Disk: 256GB
|
81 |
+
* OS: Ubuntu 22.04 LTS x86/64
|
82 |
+
|
83 |
+
## Dataset
|
84 |
+
|
85 |
+
* [wiki40b/ja](https://www.tensorflow.org/datasets/catalog/wiki40b?hl=ja#wiki40bja) (2.19GB)
|
86 |
+
|
87 |
+
## Tokenization
|
88 |
+
|
89 |
+
* [sentencepiece](https://github.com/google/sentencepiece)
|
90 |
+
|
91 |
+
## Author
|
92 |
+
|
93 |
+
[Ryoichi Fukugawa](https://zenn.dev/fukugawa)
|