---
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
---
# transformer-lm-japanese-0.1b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset (wiki40b/ja). It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).

## Update Log
* 2024/05/20 Added JGLUE 4-task benchmark scores.
* 2024/05/13 Added custom model code; the model can now be loaded with FlaxAutoModelForCausalLM.

## Source Code

We've modified Flax's 'lm1b' example to train on a Japanese dataset. You can find the code on GitHub.

* [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)

## Our Blog Post

* [Building an LLM from 0.1B: A Transformer Language Model in JAX/Flax](https://zenn.dev/fukugawa/articles/4446573ec0f697) (in Japanese)

## Model Details

| Model | Params | Layers | Dim | Heads | PPL | Dataset | Training time |
|-|-|-|-|-|-|-|-|
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | 35.22 | wiki40b/ja | 1.5 days |
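
As a sanity check on the 0.1B figure, the parameter count can be estimated from the table. The sketch below assumes a standard decoder-only transformer (roughly 4·d² attention weights and 8·d² feed-forward weights per layer) and a hypothetical 30k SentencePiece vocabulary; the actual vocabulary size is not stated in this card and may differ.

~~~~python
# Rough parameter-count estimate from the Model Details table.
d_model = 768        # "Dim" column
n_layers = 12        # "Layers" column
vocab_size = 30_000  # assumption; not stated in this card

per_layer = 4 * d_model**2 + 8 * d_model**2  # attention + feed-forward weights
embeddings = vocab_size * d_model            # token embedding matrix

total = n_layers * per_layer + embeddings
print(f"~{total / 1e6:.0f}M parameters")     # ~108M, i.e. ~0.1B
~~~~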

## Benchmarking

* **JGLUE 4-task (2024/05/20)**

    - *We used the [Stability-AI/lm-evaluation-harness](https://github.com/Stability-AI/lm-evaluation-harness) library for evaluation.*
    - *We modified the harness to work with FlaxAutoModel so that JAX/Flax models can be evaluated. The code is [here](https://github.com/FookieMonster/lm-evaluation-harness).*
    - *We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.*
    - *All evaluations were zero-shot (0 few-shot examples for every task) and used version 0.3 of the prompt template; a hedged invocation sketch follows the results table below.*
   
    | Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
    | :-- | :-- | :-- | :-- | :-- | :-- |
    | transformer-lm-japanese-0.1b | 41.19 | 25.47 | 45.60 | 85.46 | 8.24 |
    | Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |
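
For reference, a harness run along the following lines should reproduce the table. This is a hedged sketch: the model-type string (`flax-causal` below is a placeholder) and the versioned task identifiers depend on what is registered in our modified harness, so check that fork before running.

```
# Hedged sketch: "flax-causal" is a placeholder model type; the fork's actual
# registered name and task identifiers may differ.
python main.py \
  --model flax-causal \
  --model_args pretrained=fukugawa/transformer-lm-japanese-0.1b,trust_remote_code=True \
  --tasks jcommonsenseqa-1.1-0.3,jnli-1.3-0.3,marc_ja-1.1-0.3,jsquad-1.1-0.3 \
  --num_fewshot 0,0,0,0
```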

## Usage: FlaxAutoModel

#### Requirements:

```
pip install "transformers>=4.39.0"
pip install flax==0.6.11
pip install sentencepiece==0.1.99

# For CPU
pip install "jax[cpu]==0.4.13"

# For GPU
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Note: Set **trust_remote_code=True** to load our custom model.

~~~~python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# trust_remote_code=True is required to load the custom Flax model code.
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"  # "The capital of Japan is,"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

# Sample up to 100 new tokens with temperature and top-k filtering.
output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

# generate() returns the sampled sequences first; take the first (and only) batch entry.
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
~~~~
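
Flax sampling is driven by an explicit PRNG key, so passing one makes the sampled output reproducible across runs. A minimal sketch, using the same model and inputs as above:

~~~~python
import jax

# Fix the PRNG key so repeated runs sample the same continuation.
output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100,
  prng_key=jax.random.PRNGKey(0)
)
~~~~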

We tested text generation in a Python 3.10 environment on GCP with the following configuration:

* Machine Type: c2-standard-4 (4 CPUs, 16GB Memory)
* Disk: 100GB (Standard Persistent Disk)
* OS: Ubuntu 22.04 LTS x86/64
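
To confirm which backend JAX picked up before generating (the CPU setup above should report CPU devices), you can use the standard JAX device APIs:

~~~~python
import jax

# Lists the devices JAX will run on, e.g. [CpuDevice(id=0)] on the CPU setup above.
print(jax.devices())
print(jax.default_backend())  # "cpu", "gpu", or "tpu"
~~~~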

## Dataset

* wiki40b/ja
 
## Tokenization

* [sentencepiece](https://github.com/google/sentencepiece)
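
To inspect how the SentencePiece tokenizer segments Japanese text, here is a minimal sketch using the same tokenizer loaded in the usage example above:

~~~~python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

# Show the subword pieces produced for a short Japanese sentence.
print(tokenizer.tokenize("日本の首都は、"))
~~~~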

## Author

[Ryoichi Fukugawa](https://zenn.dev/fukugawa)