KoRWKV
A model file converted for use with RWKV-Runner.
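The script below loads a KoRWKV checkpoint through transformers, renames the RwkvForCausalLM state-dict keys back to the native RWKV layout (embeddings -> emb, attention -> att, feed_forward -> ffn, and so on), and saves the result as a bf16 .pth file that RWKV-Runner can load. convert_state_dict maps in the opposite direction (native -> transformers) and is kept for reference; the main block only uses revert_state_dict.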
import re
import torch
from transformers import RwkvForCausalLM
def convert_state_dict(state_dict):
    """Rename native RWKV keys to the transformers RwkvForCausalLM layout."""
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # emb -> embeddings
        if name.startswith("emb."):
            name = name.replace("emb.", "embeddings.")
        # ln0 -> pre_ln (only present at block 0)
        if name.startswith("blocks.0.ln0"):
            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
        # att -> attention
        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
        # ffn -> feed_forward
        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
        # time_mix_k -> time_mix_key
        if name.endswith(".time_mix_k"):
            name = name.replace(".time_mix_k", ".time_mix_key")
        # time_mix_v -> time_mix_value
        if name.endswith(".time_mix_v"):
            name = name.replace(".time_mix_v", ".time_mix_value")
        # time_mix_r -> time_mix_receptance
        if name.endswith(".time_mix_r"):
            name = name.replace(".time_mix_r", ".time_mix_receptance")
        # every key except the LM head gets the "rwkv." prefix
        if name != "head.weight":
            name = "rwkv." + name
        state_dict[name] = weight
    return state_dict

def revert_state_dict(state_dict):
    """Rename transformers RwkvForCausalLM keys back to the native RWKV layout."""
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # strip the "rwkv." prefix (str.removeprefix needs Python 3.9+)
        name = name.removeprefix("rwkv.")
        # embeddings -> emb
        if name.startswith("embeddings."):
            name = name.replace("embeddings.", "emb.")
        # pre_ln -> ln0 (only present at block 0)
        if name.startswith("blocks.0.pre_ln"):
            name = name.replace("blocks.0.pre_ln", "blocks.0.ln0")
        # attention -> att
        name = re.sub(r"blocks\.(\d+)\.attention", r"blocks.\1.att", name)
        # feed_forward -> ffn
        name = re.sub(r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn", name)
        # time_mix_key -> time_mix_k
        if name.endswith(".time_mix_key"):
            name = name.replace(".time_mix_key", ".time_mix_k")
        # time_mix_value -> time_mix_v
        if name.endswith(".time_mix_value"):
            name = name.replace(".time_mix_value", ".time_mix_v")
        # time_mix_receptance -> time_mix_r
        if name.endswith(".time_mix_receptance"):
            name = name.replace(".time_mix_receptance", ".time_mix_r")
        state_dict[name] = weight
    return state_dict

if __name__ == "__main__":
    # repo = "beomi/KoRWKV-6B"
    repo = "beomi/KoAlpaca-KoRWKV-6B"

    model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)
    state_dict = model.state_dict()
    # Map the transformers key names back to the native layout expected by
    # RWKV-Runner, then save the result as a bf16 .pth checkpoint.
    converted = revert_state_dict(state_dict)

    name = repo.split("/")[-1] + ".bf16.pth"
    torch.save(converted, name)
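A quick round-trip sanity check (a minimal sketch on made-up key names, not the real KoRWKV checkpoint): applying convert_state_dict and then revert_state_dict should give back the original native key names.

# Hypothetical round-trip check; the sample keys below are illustrative only.
sample = {
    "emb.weight": torch.zeros(1),
    "blocks.0.ln0.weight": torch.zeros(1),
    "blocks.0.att.time_mix_k": torch.zeros(1),
    "blocks.0.ffn.time_mix_r": torch.zeros(1),
    "head.weight": torch.zeros(1),
}
round_trip = revert_state_dict(convert_state_dict(dict(sample)))
assert set(round_trip.keys()) == set(sample.keys())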