Gakki-7B / README.md
ryota39's picture
Update README.md
4813b63 verified
metadata
library_name: transformers
license: apache-2.0
language:
  - ja
  - en

Model Merge

Gakki-7B was build by Chat Vector

A recipe shows as below

Rakuten/RakutenAI-7B-instruct + (prometheus-eval/prometheus-7b-v2.0 - mistralai/Mistral-7B-Instruct-v0.2)

Source Code

import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
    ):

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # 英語ベースモデル
    for k, v in base_model.state_dict().items():
        print(k, v.shape)

    # 日本語継続事前学習モデル
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    # 除外対象
    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    for k, v in target_model.state_dict().items():
        # layernormも除外
        if (k in skip_layers) or ("layernorm" in k):
            continue
        chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k]
        new_v = v + chat_vector.to(v.device)
        v.copy_(new_v)

    target_model.save_pretrained("./Gakki-7B")

    return


if __name__ == '__main__':

    base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    inst_model_name = "prometheus-eval/prometheus-7b-v2.0"
    target_model_name = "Rakuten/RakutenAI-7B-instruct"

    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers
    )