---
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
---

# transformer-lm-japanese-0.1b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).

## Update Log

* 2024/05/13: FlaxAutoModelForCausalLM is now supported via newly added custom model code.

## Source Code

We modified Flax's lm1b example to train on a Japanese dataset. You can find the code on GitHub:

* [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)

## Our Blog Post

* [\[Building an LLM from 0.1B\] A Transformer Language Model in JAX/Flax](https://zenn.dev/fukugawa/articles/4446573ec0f697) (in Japanese)

## Model Details

| Model | Params | Layers | Dim | Heads | PPL | Dataset | Training time |
|---|---|---|---|---|---|---|---|
| lm1b-default | 0.05B | 6 | 512 | 8 | 22.67 | lm1b | 0.5 days |
| transformer-lm-japanese-default | 0.05B | 6 | 512 | 8 | 66.38 | cc100/ja | 0.5 days |
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | 35.22 | wiki40b/ja | 1.5 days |

![tensor-board](./tensorboard-v1.png)

## Usage: FlaxAutoModel

#### Requirements:

```
pip install "transformers>=4.39.0"
pip install jax==0.4.13
pip install flax==0.6.11
pip install sentencepiece==0.1.99

# For CPU
pip install "jax[cpu]==0.4.13"

# For GPU
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Note: Set **trust_remote_code=True** to load our custom model code.

~~~~python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

# Sample up to 100 new tokens with temperature and top-k filtering.
output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
~~~~

We tested text generation in a Python 3.10 environment on GCP with the following setup:

* Machine Type: c2-standard-4 (4 vCPUs, 16 GB memory)
* Disk: 100 GB (Standard Persistent Disk)
* OS: Ubuntu 22.04 LTS x86/64

## Dataset

* wiki40b/ja

## Tokenization

* [sentencepiece](https://github.com/google/sentencepiece)

## Author

[Ryoichi Fukugawa](https://zenn.dev/fukugawa)
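
## Appendix: Tokenizer Check

For a quick look at how the sentencepiece tokenizer splits Japanese text before generation, the following minimal sketch prints the subword pieces and their ids for the same prompt used above. It assumes the packages from the Requirements section are installed and uses only standard Transformers tokenizer methods:

~~~~python
from transformers import AutoTokenizer

# Load the sentencepiece-based tokenizer bundled with the model.
tokenizer = AutoTokenizer.from_pretrained(
    "fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True
)

text = "日本の首都は、"  # "The capital of Japan is ..."

tokens = tokenizer.tokenize(text)              # subword pieces
ids = tokenizer.convert_tokens_to_ids(tokens)  # vocabulary ids

print(tokens)
print(ids)
print(tokenizer.decode(ids))  # should round-trip back to the original text
~~~~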
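
## Appendix: Greedy Decoding

The usage example above samples with `temperature` and `top_k`, so its output varies between runs. If you want repeatable output for smoke testing, a minimal variant is to pass `do_sample=False`, which makes `generate` fall back to greedy decoding. This sketch assumes the same environment as the usage example:

~~~~python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

token_ids = tokenizer.encode("日本の首都は、", return_tensors="jax", add_special_tokens=False)

# do_sample=False switches generate() to greedy decoding: the most
# probable token is picked at every step, so the output is identical
# across runs.
output_ids = model.generate(token_ids, do_sample=False, max_new_tokens=100)
print(tokenizer.decode(output_ids[0][0], skip_special_tokens=True))
~~~~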