---
thumbnail: https://github.com/rinnakk/japanese-pretrained-models/blob/master/rinna.png
license: mit
datasets:
- Anthropic/hh-rlhf
language:
- ja
- en
inference: false
base_model: rinna/bilingual-gpt-neox-4b
---

# bilingual-gpt-neox-4b-instruction-ppo

![rinna-icon](./rinna.png)

---

# Overview
This repository provides a 3.8-billion-parameter English-Japanese bilingual GPT-NeoX model.

The model is based on [`rinna/bilingual-gpt-neox-4b-instruction-sft`](https://huggingface.co./rinna/bilingual-gpt-neox-4b-instruction-sft) and has been aligned to serve as an instruction-following conversational agent.

* **Model architecture**

    A 36-layer, 2816-hidden-size transformer-based language model (these numbers can be checked against the released config, as sketched after this list).

* **RLHF**
    
    Following the [OpenAI InstructGPT paper](https://arxiv.org/abs/2203.02155), **Reinforcement Learning from Human Feedback** (RLHF) has been applied to align the model's behaviour with input instructions. Specifically, the model was trained in two stages: **Supervised Fine-Tuning** (SFT) followed by [PPO](https://arxiv.org/abs/1707.06347)-based **Reinforcement Learning** (RL).
    * The first SFT stage produces [`rinna/bilingual-gpt-neox-4b-instruction-sft`](https://huggingface.co./rinna/bilingual-gpt-neox-4b-instruction-sft).
    * The second RL stage produces this model.

* **Reinforcement learning**
    
    We used [CarperAI/trlx](https://github.com/CarperAI/trlx) and its implementation of the PPO algorithm for the RL stage; a minimal sketch of such a call appears after this list.

    The RL data is a subset of the following dataset, translated into Japanese.
    * [Anthropic HH RLHF data](https://huggingface.co./datasets/Anthropic/hh-rlhf)

* **Model Series**

    | Variant | Link |
    | :-- | :--|
    | Bilingual 4B MiniGPT4 | https://huggingface.co./rinna/bilingual-gpt-neox-4b-minigpt4 |
    | Bilingual 4B PPO | https://huggingface.co./rinna/bilingual-gpt-neox-4b-instruction-ppo |
    | Bilingual 4B SFT | https://huggingface.co./rinna/bilingual-gpt-neox-4b-instruction-sft |
    | Bilingual 4B 8K | https://huggingface.co./rinna/bilingual-gpt-neox-4b-8k |
    | Bilingual 4B | https://huggingface.co./rinna/bilingual-gpt-neox-4b |
    | Japanese 3.6B PPO | https://huggingface.co./rinna/japanese-gpt-neox-3.6b-instruction-ppo |
    | Japanese 3.6B SFT-v2 | https://huggingface.co./rinna/japanese-gpt-neox-3.6b-instruction-sft-v2 |
    | Japanese 3.6B SFT | https://huggingface.co./rinna/japanese-gpt-neox-3.6b-instruction-sft |
    | Japanese 3.6B | https://huggingface.co./rinna/japanese-gpt-neox-3.6b |

* **Contributors**
    
    [Tianyu Zhao](https://huggingface.co./tianyuz) and [Kei Sawada](https://huggingface.co./keisawada)
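
The architecture numbers above can be checked directly against the released configuration. A quick illustrative check, not part of the official usage:

~~~python
from transformers import AutoConfig

# Load the released config and verify the architecture described above.
config = AutoConfig.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo")
print(config.model_type)         # gpt_neox
print(config.num_hidden_layers)  # expected: 36
print(config.hidden_size)        # expected: 2816
~~~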

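Below is a minimal, heavily simplified sketch of what a trlx-based PPO call looks like. The dummy reward function and the prompt are illustrative placeholders only; the actual training used the translated Anthropic HH-RLHF preference data described above rather than this toy reward.

~~~python
# A minimal sketch of PPO training with CarperAI/trlx (illustrative only).
import trlx

def reward_fn(samples, **kwargs):
    # Placeholder reward preferring longer responses; the real setup scores
    # responses based on human preference data instead.
    return [float(len(sample)) for sample in samples]

trainer = trlx.train(
    "rinna/bilingual-gpt-neox-4b-instruction-sft",  # SFT model initializes the PPO policy
    reward_fn=reward_fn,
    prompts=["ユーザー: VRはなんですか。\nシステム: "],  # prompt in the model's I/O format
)
~~~
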
---

# Benchmarking

  Our evaluation experiments suggest that PPO does not substantially improve the model's scores on the Japanese LLM benchmark compared with [Bilingual GPT-NeoX 4B SFT](https://huggingface.co./rinna/bilingual-gpt-neox-4b-instruction-sft), but we have observed a **better conversation experience** with the PPO model than with its SFT counterpart.
  - *The 4-task average accuracy is based on results of JCommonsenseQA, JNLI, MARC-ja, and JSQuAD.*
  - *The 6-task average accuracy is based on results of JCommonsenseQA, JNLI, MARC-ja, JSQuAD, XWinograd, and JAQKET-v2.*
   
  | Model | 4-task average accuracy | 6-task average accuracy |
  | :-- | :-- | :-- |
  | **bilingual-gpt-neox-4b-instruction-ppo** | **61.01** | **61.16** |
  | bilingual-gpt-neox-4b-instruction-sft | 61.02 | 61.69 |
  | bilingual-gpt-neox-4b | 56.12 | 51.83 |
  | japanese-gpt-neox-3.6b-instruction-ppo | 59.86 | 60.07 |
  | japanese-gpt-neox-3.6b | 55.07 | 50.32 |

---

# I/O Format
A special format has been adopted to construct inputs.
* An input prompt is formatted as a conversation between `ユーザー` and `システム`.
* Each input utterance consists of (1) its speaker (`"ユーザー"` or `"システム"`), (2) a colon (`":"`), (3) a single whitespace (`" "`), and (4) utterance text (e.g. `"世界で一番高い山は?"`).
* The input prompt should end with `"システム: "` to signal the model to generate a response.
* All utterances in the input prompt should be separated by a newline `\n`.

The following example constructs an input prompt from a conversation.
~~~python
prompt = [
    {
        "speaker": "ユーザー",
        "text": "Hello, you are an assistant that helps me learn Japanese."
    },
    {
        "speaker": "システム",
        "text": "Sure, what can I do for you?"
    },
    {
        "speaker": "ユーザー",
        "text": "VRはなんですか。"
    }
]
prompt = [
    f"{uttr['speaker']}: {uttr['text']}"
    for uttr in prompt
]
prompt = "\n".join(prompt)
prompt = (
    prompt
    + "\n"
    + "システム: "
)
print(prompt)
"""
ユーザー: Hello, you are an assistant that helps me learn Japanese.
システム: Sure, what can I do for you?
ユーザー: VRはなんですか。
システム:
"""
~~~
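
The same construction can be wrapped in a small helper for reuse. The function below is a convenience sketch, not part of the model's API:

~~~python
def build_prompt(conversation):
    """Format {"speaker", "text"} turns and append the generation cue."""
    lines = [f"{uttr['speaker']}: {uttr['text']}" for uttr in conversation]
    return "\n".join(lines) + "\n" + "システム: "
~~~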

---

# How to use the model

**Notice:** Since the model is **sensitive to decoding hyper-parameters** (e.g. `temperature`, `top_p`, `top_k`, `repetition_penalty`), we suggest exploring the best settings for your task; a small illustrative sweep follows the example below.

~~~~python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-ppo")

if torch.cuda.is_available():
    model = model.to("cuda")

# `prompt` is the string constructed in the "I/O Format" section above
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_p=0.85,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
print(output)
"""VRとはVirtual Realityの略で、仮想現実とも呼ばれます。これは、コンピューターを使用して仮想世界を作り出し、仮想世界上でコンピューターのゲームや仮想世界を体験するための技術です。この技術は、コンピューターやモバイ ルデバイスの進歩によって、2015年以降、ますます普及しています。VRは、ゲームや仮想世界、その他のアプリケー ションなどのさまざまな分野で、コンピューターと人間の相互作用の新しい方法を提供しています。</s>"""
~~~~
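
As an illustration of the sensitivity note above, the sketch below compares two sampling temperatures; the values are starting points, not tuned recommendations. It reuses `model`, `tokenizer`, and `token_ids` from the block above.

~~~python
# Illustrative comparison of decoding settings (values are examples only).
for temperature in [0.7, 1.0]:
    with torch.no_grad():
        sample_ids = model.generate(
            token_ids.to(model.device),
            max_new_tokens=128,
            do_sample=True,
            temperature=temperature,
            top_p=0.85,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    print(f"--- temperature={temperature} ---")
    print(tokenizer.decode(sample_ids.tolist()[0][token_ids.size(1):]))
~~~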

---

# Tokenization
The model uses a [sentencepiece](https://github.com/google/sentencepiece)-based tokenizer.
* The tokenizer has a vocabulary size of 65,536.
* It uses *byte fallback* to decompose unknown text pieces into UTF-8 byte pieces to avoid producing `<UNK>` tokens.
* It can recognize *consecutive whitespaces*, *newlines*, and *tabs* to handle structured texts better.
* We turned off the default behaviour of prepending leading whitespace because it is not beneficial for processing Japanese.
* Specifically, a single whitespace is always processed as one token, so English words are not fused with a preceding whitespace as in many other tokenizers (e.g. `▁Hello`).
  * This decision trades English processing efficiency for a unified treatment of whitespaces.
  * It also leads to a significantly lower next-token prediction loss on English data, because whitespaces are easy to predict.
* **Don't forget to set `use_fast=False` to make the above features function correctly.**
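
The behaviours above can be observed directly; the exact pieces printed depend on the trained vocabulary, so treat the output as illustrative:

~~~python
from transformers import AutoTokenizer

# `use_fast=False` is required for the sentencepiece-specific behaviour above.
tokenizer = AutoTokenizer.from_pretrained(
    "rinna/bilingual-gpt-neox-4b-instruction-ppo", use_fast=False
)

# Consecutive whitespaces, newlines, and tabs are preserved, not collapsed.
print(tokenizer.tokenize("def  f():\n\treturn 1"))

# A single whitespace is one token; English words carry no whitespace prefix.
print(tokenizer.tokenize("Hello world"))
~~~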

---

# How to cite
```bibtex
@misc{rinna-bilingual-gpt-neox-4b-instruction-ppo,
    title = {rinna/bilingual-gpt-neox-4b-instruction-ppo},
    author = {Zhao, Tianyu and Sawada, Kei},
    url = {https://huggingface.co./rinna/bilingual-gpt-neox-4b-instruction-ppo}
}

@inproceedings{sawada2024release,
    title = {Release of Pre-Trained Models for the {J}apanese Language},
    author = {Sawada, Kei and Zhao, Tianyu and Shing, Makoto and Mitsui, Kentaro and Kaga, Akio and Hono, Yukiya and Wakatsuki, Toshiaki and Mitsuda, Koh},
    booktitle = {Proceedings of the 2024 Joint International Conference on Computational Linguistics, Language Resources and Evaluation (LREC-COLING 2024)},
    month = {5},
    year = {2024},
    pages = {13898--13905},
    url = {https://aclanthology.org/2024.lrec-main.1213},
    note = {\url{https://arxiv.org/abs/2404.01657}}
}
```

---

# License
[The MIT license](https://opensource.org/licenses/MIT)