---
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
---
# transformer-lm-japanese-0.1b
This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).
## Update Log
* 2024/05/20 Added JGLUE 4-task benchmark scores.
* 2024/05/13 Added custom model code, so the model can now be loaded with FlaxAutoModelForCausalLM.
## Source Code
We've modified Flax's 'lm1b' example to train on a Japanese dataset. You can find the code on GitHub.
* [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)
## Our Blog Post
* [Building an LLM from 0.1B: A Transformer Language Model in JAX/Flax (in Japanese)](https://zenn.dev/fukugawa/articles/4446573ec0f697)
## Model Details
| Model | Params | Layers | Dim | Heads | PPL | Dataset | Training time |
|-|-|-|-|-|-|-|-|
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | 35.22 | wiki40b/ja | 1.5 days |
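The 0.1B parameter figure can be sanity-checked from the architecture columns above. This is a back-of-the-envelope sketch, assuming a vocabulary of about 30,000 (the Flax lm1b example's default, not stated in this card), a 4x MLP expansion, and tied input/output embeddings:

```python
# Rough parameter count for a decoder-only transformer.
# Assumptions (not from the model card): vocab ~30,000, 4x MLP
# expansion, tied embeddings; biases and LayerNorms are ignored.
def transformer_params(layers=12, d_model=768, vocab=30_000, mlp_ratio=4):
    attn = 4 * d_model * d_model             # Q, K, V, and output projections
    mlp = 2 * mlp_ratio * d_model * d_model  # up- and down-projections
    embedding = vocab * d_model              # shared with the output head
    return layers * (attn + mlp) + embedding

total = transformer_params()
print(f"{total / 1e9:.2f}B")  # → 0.11B
```

The estimate lands at roughly 0.11B, consistent with the table's 0.1B.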
## Benchmarking
* **JGLUE 4-task (2024/05/20)**
- *We used [Stability-AI/lm-evaluation-harness](https://github.com/Stability-AI/lm-evaluation-harness) library for evaluation.*
- *We modified the harness to work with the FlaxAutoModel for evaluating JAX/Flax models. See the code [here](https://github.com/FookieMonster/lm-evaluation-harness).*
- *We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.*
  - *All evaluations were zero-shot (num_fewshot = 0 for every task), using version 0.3 of the prompt template.*

| Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
| :-- | :-- | :-- | :-- | :-- | :-- |
| transformer-lm-japanese-0.1b | 41.19 | 25.47 | 45.60 | 85.46 | 8.24 |
| Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |
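As a quick sanity check, the Average column is the unweighted arithmetic mean of the four task scores:

```python
# Reproduce the Average column for transformer-lm-japanese-0.1b
# from the four per-task scores in the table above.
scores = {"JCommonsenseQA": 25.47, "JNLI": 45.60, "MARC-ja": 85.46, "JSQuAD": 8.24}
average = sum(scores.values()) / len(scores)
print(round(average, 2))  # → 41.19
```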
## Usage: FlaxAutoModel
#### Requirements:
```sh
pip install "transformers>=4.39.0"
pip install flax==0.6.11
pip install sentencepiece==0.1.99

# For CPU
pip install jax[cpu]==0.4.13

# For GPU
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Note: Set **trust_remote_code=True** to load our custom model.
~~~~python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)
output_ids = model.generate(
token_ids,
do_sample=True,
temperature=0.6,
top_k=20,
max_new_tokens=100
)
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
~~~~
We tested text generation in a Python 3.10 environment on GCP with the following configuration:
* Machine Type: c2-standard-4 (4 CPUs, 16GB Memory)
* Disk: 100GB (Standard Persistent Disk)
* OS: Ubuntu 22.04 LTS x86/64
## Dataset
* wiki40b/ja
## Tokenization
* [sentencepiece](https://github.com/google/sentencepiece)
## Author
[Ryoichi Fukugawa](https://zenn.dev/fukugawa)