# lolcats/src/model/peft.py
"""
Helpers for parameter-efficient finetuning via low-rank adapters (LoRA)
-> Mainly follow PEFT / llama recipes
Right now quantization not super tested
"""
from typing import Optional

import torch
from torch.nn import Module
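
# Expected structure of the `peft_config` dict consumed below (keys inferred
# from the code in this module; hyperparameter names inside 'kwargs' follow
# peft.LoraConfig, and the values shown are illustrative only):
#   {'method': 'lora',
#    'kwargs': {'r': ..., 'lora_alpha': ..., 'lora_dropout': ...,
#               'target_modules': ['q_proj', 'v_proj', ...]},
#    'layers_to_ignore': [...]}  # optional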

# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/examples/quickstart.ipynb
def create_peft_config(model: Module,
                       peft_config: dict,
                       target_dtype: str = 'bfloat16',
                       preserve_requires_grad: bool = False,
                       use_gradient_checkpointing: Optional[bool] = None,
                       add_self_attn_prefix: bool = True):
"""
Create a parameter-efficient finetuning model (e.g., attaching LoRAs)
-> Assumes that all non-trainable weights have been frozen already.
If not, freeze them before calling this function.
"""
    if peft_config['method'] == 'lora':
        from peft import (
            get_peft_model,
            LoraConfig,
            TaskType,
            prepare_model_for_kbit_training,
        )
        try:
            target_modules = []  # Hack to only target self_attn projections
            for module_name in peft_config['kwargs']['target_modules']:
                if ('_proj' in module_name and 'self_attn' not in module_name
                        and add_self_attn_prefix):
                    target_modules.append(f'self_attn.{module_name}')
                elif '_proj' in module_name:
                    target_modules.append(module_name)
            peft_config['kwargs']['target_modules'] = target_modules
        except Exception as e:
            print(e)
            target_modules = []
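
        # Optionally restrict which decoder layers get LoRA adapters: layers listed
        # in `layers_to_ignore` are left untouched via LoraConfig's `layers_to_transform`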
        if 'layers_to_ignore' in peft_config:
            peft_config['kwargs']['layers_to_transform'] = [
                i for i in range(len(model.model.layers))
                if i not in peft_config['layers_to_ignore']
            ]
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            **peft_config['kwargs'],
        )
        # Save the names of parameters that were trainable before wrapping,
        # so we can unfreeze them again after get_peft_model() freezes them
        trainable_weights = [
            n for n, p in model.named_parameters() if p.requires_grad
        ]
        # Prepare int-8 or int-4 model for training
        loaded_in_kbit = (getattr(model, "is_loaded_in_8bit", False) or
                          getattr(model, "is_loaded_in_4bit", False))
        if loaded_in_kbit:
            # From https://huggingface.co./docs/peft/en/package_reference/peft_model:
            # prepare_model_for_kbit_training wraps the entire protocol for
            # preparing a model before training:
            # 1. Cast the layernorms to fp32
            # 2. Make the output embedding layer require grads
            # 3. Upcast the lm_head to fp32
            model.enable_input_require_grads()
            ugc = (use_gradient_checkpointing
                   if use_gradient_checkpointing is not None else True)
            print('-> use_gradient_checkpointing:', ugc)
            # model.gradient_checkpointing_enable()
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=ugc,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        for n, p in model.named_parameters():
            # Unfreeze weights frozen by get_peft_model() that were trainable before
            if preserve_requires_grad:
                if n[len('base_model.model.'):] in trainable_weights:
                    p.requires_grad = True
            # prepare_model_for_kbit_training will cast all non-INT8 parameters to fp32
            # -> https://github.com/huggingface/peft/blob/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/src/peft/utils/other.py#L103
            # So cast these back to their prior dtype
            if p.requires_grad and loaded_in_kbit:
                p.data = p.data.to(getattr(torch, target_dtype))
        if not loaded_in_kbit:
            model.to(dtype=getattr(torch, target_dtype))
        return model, peft_config
    else:
        raise NotImplementedError(f"Sorry, PEFT method '{peft_config['method']}' is not implemented yet.")
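

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The checkpoint name
    # and LoRA hyperparameters below are illustrative placeholders; substitute the
    # model and settings your experiment actually uses.
    from transformers import AutoModelForCausalLM

    base_model = AutoModelForCausalLM.from_pretrained(
        'meta-llama/Llama-2-7b-hf',  # placeholder; any Llama-style causal LM works
        torch_dtype=torch.bfloat16,
    )
    # create_peft_config() assumes all non-trainable weights are already frozen
    for p in base_model.parameters():
        p.requires_grad = False

    lora_config = {
        'method': 'lora',
        'kwargs': {
            'r': 8,
            'lora_alpha': 16,
            'lora_dropout': 0.05,
            'target_modules': ['q_proj', 'v_proj'],  # prefixed to 'self_attn.*' above
        },
    }
    peft_model, lora_config = create_peft_config(base_model, lora_config,
                                                 target_dtype='bfloat16')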