File size: 3,054 Bytes
aa9ac9d
 
2b46c24
 
 
 
aa9ac9d
2b46c24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
---
license: mit
datasets:
- SkelterLabsInc/JaQuAD
language:
- ja
---
# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

このモデルはrinna/japanese-gpt-1bをベースモデルとして、
コンテキストからの抽出型QAと、解答を新たなコンテキストでリファインするための学習を行ったモデルです。

gpt-index(v0.2.5)で利用することを前提に学習をしており、通常のQAタスクで使用することは想定していません。

利用例はこのリポジトリを参照してください。
https://github.com/oshizo/gpt_index_japanese_trial


# Model Details

モデルは2種類のpromptテンプレートに対してQA応答するように訓練されています。

```python
DEFAULT_PROMPT = """
文脈情報は以下です。
---
{context_str}
---
事前知識ではなく、文脈情報を参考に質問に答えてください。:{query_str}
"""
```

```python
REFINE_PROMPT = """
質問は以下です。:{query_str}
すでに答えの候補があります。:{existing_answer}
必要な場合のみ、以下の文脈情報を使ってこの答えを改良することができます。
---
{context_msg}
---
この文脈情報により、元の答えを改良して質問に答えてください。
文脈情報が有用でない場合は元の答えをそのまま返してください。
"""
```

```python
import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
model = AutoModelForCausalLM.from_pretrained("oshizo/qa-refine-japanese-gpt-1b").to("cuda")


prompt = DEFAULT_PROMPT.format(
    context_str="山路を登りながら、こう考えた。智に働けば角が立つ。情に棹させば流される。意地を通せば窮屈だ。とかくに人の世は住みにくい。住みにくさが高じると、安い所へ引き越したくなる。どこへ越しても住みにくいと悟った時、詩が生れて、画が出来る。",
    query_str="意地を通すとどうなってしまう?"
    )

token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
n = len(token_ids[0])

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        max_length=n+100,
        min_length=n+2,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
output = tokenizer.decode(output_ids.tolist()[0][n:])
output.replace("</s>", "")

# -> 窮屈

```

# Training Details

JGLUE/JSQuADとJaQuADを用いて、コンテキストからの抽出型QAと、解答を新たなコンテキストでリファインするための学習を行いました。

学習スクリプトについてはこのリポジトリを参照してください。
https://github.com/oshizo/gpt_index_japanese_trial

Google Colab Pro A100 で約3.5時間、9.9kステップ学習しました。