fukugawa committed
Commit e6e20ca
1 Parent(s): 63f7abe

Update README.md

Files changed (1)
  1. README.md +93 -3
README.md CHANGED
@@ -1,3 +1,93 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language:
+ - ja
+ license: apache-2.0
+ tags:
+ - ja
+ - japanese
+ - text-generation
+ - lm
+ - jax
+ - flax
+ - lm1b
+ datasets:
+ - wiki40b
+ ---
+ # transformer-lm-japanese-1.0b
+
+ This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).
+
+ ## Source Code
+
+ We've modified Flax's lm1b example to train on a Japanese dataset. The source code is available on GitHub:
+
+ * [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)
+
+ ## Our Blog Post
+
+ * [Building an LLM from 0.1B: A Transformer Language Model in JAX/Flax](https://zenn.dev/fukugawa/articles/4446573ec0f697)
+
+ ## Model Details
+
+ | Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
+ |-|-|-|-|-|-|-|-|-|
+ | transformer-lm-japanese-1.0b | 1.0B | 18 | 2048 | 16 | wiki40b/ja | 2.19GB | 4 days | 31.47 |
36
+ ## Usage: FlaxAutoModel
37
+
38
+ #### Requirements:
39
+
40
+ ```
41
+ pip install transformers>=4.39.0
42
+ pip install jax==0.4.31
43
+ pip install flax==0.8.3
44
+ pip install sentencepiece==0.1.99
45
+
46
+ # For CPU
47
+ pip install -U "jax[cpu]==0.4.31"
48
+
49
+ # For GPU
50
+ pip install -U "jax[cuda12]==0.4.31"
51
+ ```
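+
+ After installing, you can confirm which devices JAX will use (CPU or GPU); this is a generic JAX check, not specific to this repository:
+
+ ```
+ python -c "import jax; print(jax.devices())"
+ ```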
+
+ Note: Set **trust_remote_code=True** to load our custom model.
+
+ ~~~~python
+ from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
+ model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
+
+ text = "日本の首都は、"  # "The capital of Japan is,"
+ token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)
+
+ output_ids = model.generate(
+     token_ids,
+     do_sample=True,
+     temperature=0.6,
+     top_k=20,
+     max_new_tokens=100
+ )
+
+ # generate() returns a ModelOutput; output_ids[0] is the sequences array.
+ output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
+ print(output)
+ ~~~~
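+
+ Sampling in JAX is driven by an explicit PRNG key, so fixing the key makes the output reproducible across runs. A minimal sketch reusing `tokenizer`, `model`, and `token_ids` from above (the `prng_key` argument is part of transformers' Flax `generate()` API):
+
+ ~~~~python
+ import jax
+
+ # A fixed key makes do_sample=True produce the same text on every run.
+ output_ids = model.generate(
+     token_ids,
+     do_sample=True,
+     temperature=0.6,
+     top_k=20,
+     max_new_tokens=100,
+     prng_key=jax.random.PRNGKey(42)
+ )
+ print(tokenizer.decode(output_ids[0][0], skip_special_tokens=True))
+ ~~~~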
+
+ We tested text generation in a Python 3.10 environment on GCP with the following configuration:
+
+ * GPU Type: NVIDIA L4 (x 1)
+ * Machine Type: g2-standard-16 (16 vCPUs, 64 GB memory)
+ * Disk: 256GB
+ * OS: Ubuntu 22.04 LTS x86/64
+
+ ## Dataset
+
+ * [wiki40b/ja](https://www.tensorflow.org/datasets/catalog/wiki40b?hl=ja#wiki40bja) (2.19GB)
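+
+ For reference, the training corpus can be browsed directly through TensorFlow Datasets. A minimal sketch, assuming `tensorflow-datasets` is installed (it is only needed for this inspection, not for inference):
+
+ ~~~~python
+ import tensorflow_datasets as tfds
+
+ # Loads the Japanese Wiki40B training split (downloads ~2GB on first use).
+ ds = tfds.load("wiki40b/ja", split="train")
+ for example in ds.take(1):
+     # Each record contains raw page text with section markers.
+     print(example["text"].numpy().decode("utf-8")[:200])
+ ~~~~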
+
+ ## Tokenization
+
+ * [sentencepiece](https://github.com/google/sentencepiece)
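+
+ To see how the SentencePiece model segments Japanese text, a small sketch reusing the `tokenizer` loaded in the usage example above (assumes it exposes the standard `tokenize`/`encode`/`decode` methods):
+
+ ~~~~python
+ # Subword pieces produced by the SentencePiece model.
+ print(tokenizer.tokenize("日本の首都は、"))
+
+ # Round-trip: text -> ids -> text.
+ ids = tokenizer.encode("日本の首都は、", add_special_tokens=False)
+ print(tokenizer.decode(ids))
+ ~~~~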
+
+ ## Author
+
+ [Ryoichi Fukugawa](https://zenn.dev/fukugawa)