lora_gemma_9b

このモデルは、Gemma-2-9b をベースに微調整されたモデルです。 LLM講座の最終課題のために作成されたものになります。

出力方法

以下のコードを使用して、モデルをロードし、結果を生成できます。

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from datasets import load_dataset
from tqdm import tqdm
import torch
import json

# トークナイザーのロード
tokenizer = AutoTokenizer.from_pretrained("ultimatemagic79/lora_gemma_9b", use_fast=False)

# ベースモデルのロード
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)

# 微調整モデルのロード
model = PeftModel.from_pretrained(base_model, "ultimatemagic79/lora_gemma_9b")

# Few-Shot Promptの設定
dataset = load_dataset("elyza/ELYZA-tasks-100")
num_samples = 3
few_shot_samples = dataset["test"].select(range(num_samples))

# 推論の実行
# ELYZA-tasks-100-TVデータセットのロード
def load_elyza_tasks(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return data

test_dataset = load_elyza_tasks('elyza-tasks-100-TV_0.jsonl')

def generate_prompt(input_text, examples):
    prompt = ""
    for idx, example in enumerate(examples, 1):
        prompt += f"[例{idx}]\n"
        prompt += f"入力: {example['input']}\n"
        prompt += f"出力: {example['output']}\n\n"
    prompt += "[あなたの質問]\n"
    prompt += f"入力: {input_text}\n"
    prompt += "出力:"
    return prompt

def generate_response(model, tokenizer, prompt):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            temperature=0.7,
            repetition_penalty=1.1,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    output_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return output_text

# 推論と結果の収集
results = []
for test_data in tqdm(test_dataset):
    task_id = test_data["task_id"]
    input_text = test_data["input"]
    prompt = generate_prompt(input_text, few_shot_samples)
    response = generate_response(model, tokenizer, prompt)
    results.append({
        "task_id": task_id,
        "output": response,
    })

ELYZA-tasks-100データセットの利用について

このモデルは、ELYZA社が公開する ELYZA-tasks-100を使用してファインチューニング,プロンプトエンジニアリングを行っています。

ELYZA-tasks-100は CC BY-SA 4.0でライセンスされています。

詳細なライセンス情報は、ELYZA-tasks-100のモデルカードをご参照ください。

Gemma-2-9bの使用権利について

このモデルは、Google社が提供する gemma-2-9b をベースに微調整されています。

Gemma-2-9bは、商用利用が許可されたライセンスの下で公開されています。

詳細なライセンス情報は、gemma-2-9bのモデルカードをご参照ください。

注意: このモデルを使用する際は、Gemma-2-9bのライセンスに従ってください。

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.

Model tree for ultimatemagic79/lora_gemma_9b

Base model

google/gemma-2-9b
Finetuned
(225)
this model

Dataset used to train ultimatemagic79/lora_gemma_9b