metadata
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
transformer-lm-japanese-0.1b
This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code (lm1b).
Update Log
- 2024/05/20 Added JGLUE 4-task benchmark scores.
- 2024/05/13 Added custom model code so that the model can be loaded with FlaxAutoModelForCausalLM.
Source Code
We modified Flax's 'lm1b' example to train on a Japanese dataset. The code is available on GitHub.
Our Blog Post
Model Details
Model | Params | Layers | Dim | Heads | PPL | Dataset | Training time |
---|---|---|---|---|---|---|---|
transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | 35.22 | wiki40b/ja | 1.5 days |
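To make the PPL column easier to compare with a training loss, the short sketch below converts it back to a mean per-token cross-entropy. It assumes the reported perplexity follows the usual definition, the exponential of the mean per-token cross-entropy in nats; the model card itself does not spell this out, so treat the result as an estimate under that assumption.

```python
import math

# PPL value taken from the Model Details table above.
ppl = 35.22

# Assumption: PPL = exp(mean per-token cross-entropy in nats).
loss_nats = math.log(ppl)

# The same quantity expressed in bits per token.
loss_bits = math.log2(ppl)

print(f"cross-entropy: {loss_nats:.3f} nats/token, {loss_bits:.3f} bits/token")
```

Under this assumption, a lower loss of 0.1 nats per token would correspond to a perplexity roughly 10% lower.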
Benchmarking
JGLUE 4-task (2024/05/20)
- We used the Stability-AI/lm-evaluation-harness library for evaluation.
- We modified the harness to support FlaxAutoModel so that JAX/Flax models can be evaluated. See the code here.
- We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.
- All evaluations used version 0.3 of the prompt template and were zero-shot (the number of few-shot examples is 0 for each of the four tasks).
Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
---|---|---|---|---|---|
transformer-lm-japanese-0.1b | 41.19 | 25.47 | 45.60 | 85.46 | 8.24 |
Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |
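As a quick sanity check on the scores above, the Average column appears to be the unweighted mean of the four task scores. The sketch below recomputes it from the reported numbers; the dictionary layout is only for illustration and is not part of the evaluation harness.

```python
# Task scores copied from the JGLUE table above, in the order
# JCommonsenseQA, JNLI, MARC-ja, JSQuAD.
scores = {
    "transformer-lm-japanese-0.1b": [25.47, 45.60, 85.46, 8.24],
    "rinna/japanese-gpt-neox-small": [40.39, 29.13, 85.48, 8.02],
}

# Unweighted mean over the four tasks for each model.
averages = {name: sum(vals) / len(vals) for name, vals in scores.items()}

for name, avg in averages.items():
    print(f"{name}: {avg:.4f}")
```

Both recomputed means agree with the reported Average values to within rounding.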
Usage: FlaxAutoModel
Requirements:
# Quote version specifiers so the shell does not treat ">=" as a redirect
pip install "transformers>=4.39.0"
pip install jax==0.4.13
pip install flax==0.6.11
pip install sentencepiece==0.1.99
# For CPU
pip install "jax[cpu]==0.4.13"
# For GPU
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note: Set trust_remote_code=True to load our custom model.
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# trust_remote_code=True is required because the model uses custom code
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

# Prompt: "The capital of Japan is,"
text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

# Sample up to 100 new tokens with temperature and top-k filtering
output_ids = model.generate(
    token_ids,
    do_sample=True,
    temperature=0.6,
    top_k=20,
    max_new_tokens=100
)

# output_ids[0] holds the generated sequences; decode the first one
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
We tested text generation in a Python 3.10 environment on GCP with the following configuration:
- Machine Type: c2-standard-4 (4 CPUs, 16GB Memory)
- Disk: 100GB (Standard Persistent Disk)
- OS: Ubuntu 22.04 LTS x86/64
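The generate call above uses do_sample=True with temperature=0.6 and top_k=20. As a rough illustration of what those two parameters do, the sketch below samples one token from a toy list of logits using only the Python standard library. The helper name sample_top_k and the toy logits are hypothetical, and this is a conceptual sketch, not the actual sampling code in transformers.

```python
import math
import random

def sample_top_k(logits, temperature=0.6, top_k=20, rng=random):
    """Sample one token id using temperature scaling and top-k filtering."""
    # Keep only the indices of the top_k largest logits.
    k = min(top_k, len(logits))
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Divide by the temperature: values below 1 sharpen the distribution.
    scaled = [logits[i] / temperature for i in top]
    # Unnormalized softmax weights (subtract the max for numerical stability).
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights internally.
    return rng.choices(top, weights=weights, k=1)[0]

# Toy vocabulary of 5 tokens with made-up logits (illustration only).
toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
rng = random.Random(0)
samples = [sample_top_k(toy_logits, temperature=0.6, top_k=2, rng=rng)
           for _ in range(1000)]
print({t: samples.count(t) for t in sorted(set(samples))})
```

With top_k=2 in this toy example, only the two highest-scoring tokens can ever be drawn, and the lower temperature shifts more probability onto the best one. Raising the temperature or top_k makes the output more varied.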
Dataset
- wiki40b/ja