llm-jp-3-13b-it / README.md
togepi55's picture
Update README.md
12be96c verified
|
raw
history blame
5.34 kB
---
base_model: llm-jp/llm-jp-3-13b
library_name: peft
tags:
- text-generation-inference
- llama
- trl
license: apache-2.0
---
# Model Card for Model ID
- **ベースモデル :** llm-jp/llm-jp-3-13b
- **対応言語 :** English, Japanese
- **ライセンス :** apache-2.0
### 注意
1. プロンプトは次の形式でのみ学習しています。
2. モデルはアダプターのみですので,利用する際はベースモデルのllm-jp/llm-jp-3-13bも読み込むようにしてください。
~~~: プロンプトの形式
"""
<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい
### 指示:
{instruction}
### 応答:
"""
~~~
### テキスト生成のサンプルコード
~~~python
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from peft import LoraConfig, PeftModel
from transformers import TextStreamer
BASE_MODEL = "llm-jp/llm-jp-3-13b"
PEFT_MODEL = "togepi55/llm-jp-3-13b-it"
tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=False,
)
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
quantization_config=bnb_config,
torch_dtype="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
instruction = "東京は日本の"
prompt = f"<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい\n\n### 指示:\n{instruction}\n\n### 応答:\n"
print(prompt)
model_input = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = model_input["input_ids"]
model.eval()
with torch.no_grad():
result = model.generate(
input_ids,
max_new_tokens=300,
attention_mask = model_input.attention_mask,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
streamer=streamer,
repetition_penalty=1.02,
)
print("----"*20)
del input_ids
torch.cuda.empty_cache()
~~~
## Bias, Risks, and Limitations
RLHF,DPOを実施していないため不適切な表現が出力される可能性があります。
### Training Details
指示チューニングデータとして下記のものを利用しました。
* ichikara-instruction-003-001-1.json
* ichikara-instruction-003-002-1.json
* elyza/ELYZA-tasks-100
### ライセンス
* ichikara-instructionデータセットのライセンスはcc-by-nc-sa,ELYZA-tasks-100のライセンスはcc-by-sa-4.0になっております。
### SFTの概要
* 4bit量子化
* LoRAによるSFT
* learning_rate = 2e-4
* optim="adamw_torch_fused"
* lr_scheduler_type="cosine"
* weight_decay=0.01
# elyza-tasks-100-TV_0.jsonlでの出力方法
特定タスクであるelyza-tasks-100-TV_0.jsonlに記載されている指示に対する返答のサンプル出力コードは次のようになります。
~~~
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from peft import LoraConfig, PeftModel
from datasets import load_dataset
BASE_MODEL = "llm-jp/llm-jp-3-13b"
PEFT_MODEL = "togepi55/llm-jp-3-13b-it"
tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=False,
)
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
quantization_config=bnb_config,
torch_dtype="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
# elyza-tasks-100-TV_0.jsonl データの読み込み
from datasets import load_dataset
dataset = load_dataset("json", data_files="./elyza-tasks-100-TV_0.jsonl", split="train")
results = []
for num in tqdm(range(100)):
instruction = dataset["input"][num]
prompt = f"<s>以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい\n\n### 指示:\n{instruction}\n\n### 応答:\n"
model_input = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = model_input["input_ids"]
with torch.no_grad():
outputs = model.generate(
input_ids,
max_new_tokens=300,
attention_mask = model_input.attention_mask,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=False,
repetition_penalty=1.02,
)[0]
output = tokenizer.decode(outputs[input_ids.size(1):], skip_special_tokens=True)
results.append({"task_id": num, "input": instruction, "output": output})
# 保存する場合
import json
with open("output.jsonl", "wt", encoding='utf-8') as f:
for result in results:
json.dump(result, f, ensure_ascii=False)
f.write('\n')
~~~