Transformers documentation

๐Ÿค— PEFT๋กœ ์–ด๋Œ‘ํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ

You are viewing v4.45.1 version. A newer version v4.48.0 is available.
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

๐Ÿค— PEFT๋กœ ์–ด๋Œ‘ํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ

Parameter-Efficient Fine Tuning (PEFT) ๋ฐฉ๋ฒ•์€ ์‚ฌ์ „ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๋ฏธ์„ธ ์กฐ์ • ์ค‘ ๊ณ ์ •์‹œํ‚ค๊ณ , ๊ทธ ์œ„์— ํ›ˆ๋ จํ•  ์ˆ˜ ์žˆ๋Š” ๋งค์šฐ ์ ์€ ์ˆ˜์˜ ๋งค๊ฐœ๋ณ€์ˆ˜(์–ด๋Œ‘ํ„ฐ)๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. ์–ด๋Œ‘ํ„ฐ๋Š” ์ž‘์—…๋ณ„ ์ •๋ณด๋ฅผ ํ•™์Šตํ•˜๋„๋ก ํ›ˆ๋ จ๋ฉ๋‹ˆ๋‹ค. ์ด ์ ‘๊ทผ ๋ฐฉ์‹์€ ์™„์ „ํžˆ ๋ฏธ์„ธ ์กฐ์ •๋œ ๋ชจ๋ธ์— ํ•„์ ํ•˜๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ƒ์„ฑํ•˜๋ฉด์„œ, ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์ ์ด๊ณ  ๋น„๊ต์  ์ ์€ ์ปดํ“จํŒ… ๋ฆฌ์†Œ์Šค๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

๋˜ํ•œ PEFT๋กœ ํ›ˆ๋ จ๋œ ์–ด๋Œ‘ํ„ฐ๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ์ „์ฒด ๋ชจ๋ธ๋ณด๋‹ค ํ›จ์”ฌ ์ž‘๊ธฐ ๋•Œ๋ฌธ์— ๊ณต์œ , ์ €์žฅ ๋ฐ ๊ฐ€์ ธ์˜ค๊ธฐ๊ฐ€ ํŽธ๋ฆฌํ•ฉ๋‹ˆ๋‹ค.

Hub์— ์ €์žฅ๋œ OPTForCausalLM ๋ชจ๋ธ์˜ ์–ด๋Œ‘ํ„ฐ ๊ฐ€์ค‘์น˜๋Š” ์ตœ๋Œ€ 700MB์— ๋‹ฌํ•˜๋Š” ๋ชจ๋ธ ๊ฐ€์ค‘์น˜์˜ ์ „์ฒด ํฌ๊ธฐ์— ๋น„ํ•ด ์•ฝ 6MB์— ๋ถˆ๊ณผํ•ฉ๋‹ˆ๋‹ค.

๐Ÿค— PEFT ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์•Œ์•„๋ณด๋ ค๋ฉด ๋ฌธ์„œ๋ฅผ ํ™•์ธํ•˜์„ธ์š”.

์„ค์ •

๐Ÿค— PEFT๋ฅผ ์„ค์น˜ํ•˜์—ฌ ์‹œ์ž‘ํ•˜์„ธ์š”:

pip install peft

์ƒˆ๋กœ์šด ๊ธฐ๋Šฅ์„ ์‚ฌ์šฉํ•ด๋ณด๊ณ  ์‹ถ๋‹ค๋ฉด, ๋‹ค์Œ ์†Œ์Šค์—์„œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์„ค์น˜ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค:

pip install git+https://github.com/huggingface/peft.git

์ง€์›๋˜๋Š” PEFT ๋ชจ๋ธ

๐Ÿค— Transformers๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ ์ผ๋ถ€ PEFT ๋ฐฉ๋ฒ•์„ ์ง€์›ํ•˜๋ฉฐ, ๋กœ์ปฌ์ด๋‚˜ Hub์— ์ €์žฅ๋œ ์–ด๋Œ‘ํ„ฐ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ€์ ธ์˜ค๊ณ  ๋ช‡ ์ค„์˜ ์ฝ”๋“œ๋งŒ์œผ๋กœ ์‰ฝ๊ฒŒ ์‹คํ–‰ํ•˜๊ฑฐ๋‚˜ ํ›ˆ๋ จํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ ๋ฐฉ๋ฒ•์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค:

๐Ÿค— PEFT์™€ ๊ด€๋ จ๋œ ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•(์˜ˆ: ํ”„๋กฌํ”„ํŠธ ํ›ˆ๋ จ ๋˜๋Š” ํ”„๋กฌํ”„ํŠธ ํŠœ๋‹) ๋˜๋Š” ์ผ๋ฐ˜์ ์ธ ๐Ÿค— PEFT ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์•Œ์•„๋ณด๋ ค๋ฉด ๋ฌธ์„œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

PEFT ์–ด๋Œ‘ํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ

๐Ÿค— Transformers์—์„œ PEFT ์–ด๋Œ‘ํ„ฐ ๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ค๊ณ  ์‚ฌ์šฉํ•˜๋ ค๋ฉด Hub ์ €์žฅ์†Œ๋‚˜ ๋กœ์ปฌ ๋””๋ ‰ํ„ฐ๋ฆฌ์— adapter_config.json ํŒŒ์ผ๊ณผ ์–ด๋Œ‘ํ„ฐ ๊ฐ€์ค‘์น˜๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์‹ญ์‹œ์˜ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ AutoModelFor ํด๋ž˜์Šค๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ PEFT ์–ด๋Œ‘ํ„ฐ ๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ์ธ๊ณผ ๊ด€๊ณ„ ์–ธ์–ด ๋ชจ๋ธ์šฉ PEFT ์–ด๋Œ‘ํ„ฐ ๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ค๋ ค๋ฉด ๋‹ค์Œ ๋‹จ๊ณ„๋ฅผ ๋”ฐ๋ฅด์‹ญ์‹œ์˜ค:

  1. PEFT ๋ชจ๋ธ ID๋ฅผ ์ง€์ •ํ•˜์‹ญ์‹œ์˜ค.
  2. AutoModelForCausalLM ํด๋ž˜์Šค์— ์ „๋‹ฌํ•˜์‹ญ์‹œ์˜ค.
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id)

AutoModelFor ํด๋ž˜์Šค๋‚˜ ๊ธฐ๋ณธ ๋ชจ๋ธ ํด๋ž˜์Šค(์˜ˆ: OPTForCausalLM ๋˜๋Š” LlamaForCausalLM) ์ค‘ ํ•˜๋‚˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ PEFT ์–ด๋Œ‘ํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

load_adapter ๋ฉ”์†Œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ PEFT ์–ด๋Œ‘ํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ฌ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-350m"
peft_model_id = "ybelkada/opt-350m-lora"

model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)

8๋น„ํŠธ ๋˜๋Š” 4๋น„ํŠธ๋กœ ๊ฐ€์ ธ์˜ค๊ธฐ

bitsandbytes ํ†ตํ•ฉ์€ 8๋น„ํŠธ์™€ 4๋น„ํŠธ ์ •๋ฐ€๋„ ๋ฐ์ดํ„ฐ ์œ ํ˜•์„ ์ง€์›ํ•˜๋ฏ€๋กœ ํฐ ๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜ฌ ๋•Œ ์œ ์šฉํ•˜๋ฉด์„œ ๋ฉ”๋ชจ๋ฆฌ๋„ ์ ˆ์•ฝํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋ธ์„ ํ•˜๋“œ์›จ์–ด์— ํšจ๊ณผ์ ์œผ๋กœ ๋ถ„๋ฐฐํ•˜๋ ค๋ฉด from_pretrained()์— load_in_8bit ๋˜๋Š” load_in_4bit ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜๊ณ  device_map="auto"๋ฅผ ์„ค์ •ํ•˜์„ธ์š”:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True))

์ƒˆ ์–ด๋Œ‘ํ„ฐ ์ถ”๊ฐ€

์ƒˆ ์–ด๋Œ‘ํ„ฐ๊ฐ€ ํ˜„์žฌ ์–ด๋Œ‘ํ„ฐ์™€ ๋™์ผํ•œ ์œ ํ˜•์ธ ๊ฒฝ์šฐ์— ํ•œํ•ด ๊ธฐ์กด ์–ด๋Œ‘ํ„ฐ๊ฐ€ ์žˆ๋Š” ๋ชจ๋ธ์— ์ƒˆ ์–ด๋Œ‘ํ„ฐ๋ฅผ ์ถ”๊ฐ€ํ•˜๋ ค๋ฉด ~peft.PeftModel.add_adapter๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ๋ชจ๋ธ์— ๊ธฐ์กด LoRA ์–ด๋Œ‘ํ„ฐ๊ฐ€ ์—ฐ๊ฒฐ๋˜์–ด ์žˆ๋Š” ๊ฒฝ์šฐ:

from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    init_lora_weights=False
)

model.add_adapter(lora_config, adapter_name="adapter_1")

์ƒˆ ์–ด๋Œ‘ํ„ฐ๋ฅผ ์ถ”๊ฐ€ํ•˜๋ ค๋ฉด:

# attach new adapter with same config
model.add_adapter(lora_config, adapter_name="adapter_2")

์ด์ œ ~peft.PeftModel.set_adapter๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์–ด๋Œ‘ํ„ฐ๋ฅผ ์‚ฌ์šฉํ•  ์–ด๋Œ‘ํ„ฐ๋กœ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

# use adapter_1
model.set_adapter("adapter_1")
output = model.generate(**inputs)
print(tokenizer.decode(output_disabled[0], skip_special_tokens=True))

# use adapter_2
model.set_adapter("adapter_2")
output_enabled = model.generate(**inputs)
print(tokenizer.decode(output_enabled[0], skip_special_tokens=True))

์–ด๋Œ‘ํ„ฐ ํ™œ์„ฑํ™” ๋ฐ ๋น„ํ™œ์„ฑํ™”

๋ชจ๋ธ์— ์–ด๋Œ‘ํ„ฐ๋ฅผ ์ถ”๊ฐ€ํ•œ ํ›„ ์–ด๋Œ‘ํ„ฐ ๋ชจ๋“ˆ์„ ํ™œ์„ฑํ™” ๋˜๋Š” ๋น„ํ™œ์„ฑํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์–ด๋Œ‘ํ„ฐ ๋ชจ๋“ˆ์„ ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด:

from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig

model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello"
inputs = tokenizer(text, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PeftConfig.from_pretrained(adapter_model_id)

# to initiate with random weights
peft_config.init_lora_weights = False

model.add_adapter(peft_config)
model.enable_adapters()
output = model.generate(**inputs)

์–ด๋Œ‘ํ„ฐ ๋ชจ๋“ˆ์„ ๋น„ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด:

model.disable_adapters()
output = model.generate(**inputs)

PEFT ์–ด๋Œ‘ํ„ฐ ํ›ˆ๋ จ

PEFT ์–ด๋Œ‘ํ„ฐ๋Š” Trainer ํด๋ž˜์Šค์—์„œ ์ง€์›๋˜๋ฏ€๋กœ ํŠน์ • ์‚ฌ์šฉ ์‚ฌ๋ก€์— ๋งž๊ฒŒ ์–ด๋Œ‘ํ„ฐ๋ฅผ ํ›ˆ๋ จํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ช‡ ์ค„์˜ ์ฝ”๋“œ๋ฅผ ์ถ”๊ฐ€ํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด LoRA ์–ด๋Œ‘ํ„ฐ๋ฅผ ํ›ˆ๋ จํ•˜๋ ค๋ฉด:

Trainer๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๋Š” ๊ฒƒ์ด ์ต์ˆ™ํ•˜์ง€ ์•Š๋‹ค๋ฉด ์‚ฌ์ „ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๊ธฐ ํŠœํ† ๋ฆฌ์–ผ์„ ํ™•์ธํ•˜์„ธ์š”.

  1. ์ž‘์—… ์œ ํ˜• ๋ฐ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ง€์ •ํ•˜์—ฌ ์–ด๋Œ‘ํ„ฐ ๊ตฌ์„ฑ์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ์— ๋Œ€ํ•œ ์ž์„ธํ•œ ๋‚ด์šฉ์€ ~peft.LoraConfig๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
  1. ๋ชจ๋ธ์— ์–ด๋Œ‘ํ„ฐ๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.
model.add_adapter(peft_config)
  1. ์ด์ œ ๋ชจ๋ธ์„ Trainer์— ์ „๋‹ฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!
trainer = Trainer(model=model, ...)
trainer.train()

ํ›ˆ๋ จํ•œ ์–ด๋Œ‘ํ„ฐ๋ฅผ ์ €์žฅํ•˜๊ณ  ๋‹ค์‹œ ๊ฐ€์ ธ์˜ค๋ ค๋ฉด:

model.save_pretrained(save_dir)
model = AutoModelForCausalLM.from_pretrained(save_dir)
< > Update on GitHub