概要

llm-jp/llm-jp-3-13bichikara-instruction でSFTしたモデル。SFTの際は、モデルパラメータに対し8bit量子化を行ったQLoRAを用いている。

推論方法

本モデルを用いて elyza-tasks-100-TV_0.jsonl に対して推論する方法を示す。

データ

elyza-tasks-100-TV_0.jsonl を事前にダウンロードする。

サンプルコード

import json
import re

import peft
import torch
import transformers


def load_jsonl(fname):
    with open(fname, encoding="utf-8") as f:
        data = []
        for line in f:
            _data = json.loads(line.strip())
            data.append(_data)
    return data


# loading dataset
dataset = load_jsonl("./elyza-tasks-100-TV_0.jsonl")


# loading model
bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="llm-jp/llm-jp-3-13b", device_map="auto", quantization_config=bnb_config
)
model = peft.PeftModel.from_pretrained(model, "orihihsoy/llm-jp-3-13b_qlora_8bit")

tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=="llm-jp/llm-jp-3-13b"
)


# evaluation
PROMPT_TEMPLATE = """{instruction}

### 指示:
{input}

### 回答:
{output}"""

results = []
for data in dataset:
    input = data["input"]
    BOS_TOKEN = tokenizer.bos_token

    prompt = BOS_TOKEN + PROMPT_TEMPLATE.format(
        instruction="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", input=input, output="")

    tokenized_input = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    attention_mask = torch.ones_like(tokenized_input)
    with torch.no_grad():
        outputs = model.generate(
            tokenized_input,
            attention_mask=attention_mask,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.eos_token_id
        )[0]
    output = tokenizer.decode(
        outputs[tokenized_input.size(1):], skip_special_tokens=True)

    results.append({"task_id": data["task_id"],
                    "input": input, "output": output})

with open(f"gen.jsonl", 'w', encoding='utf-8') as f:
    for result in results:
        json.dump(result, f, ensure_ascii=False)
        f.write('\n')
Downloads last month
8
Inference API
Unable to determine this model’s pipeline type. Check the docs .

Model tree for orihihsoy/llm-jp-3-13b_qlora_8bit

Adapter
(15)
this model