""" |
|
Helpers for parameter-efficient finetuning via low-rank adapters (LoRA) |
|
-> Mainly follow PEFT / llama recipes |
|
|
|
Right now quantization not super tested |
|
""" |
|
from typing import Optional

import torch
from torch.nn import Module


def create_peft_config(model: Module,
                       peft_config: dict,
                       target_dtype: str = 'bfloat16',
                       preserve_requires_grad: bool = False,
                       use_gradient_checkpointing: Optional[bool] = None,
                       add_self_attn_prefix: bool = True):
    """
    Create a parameter-efficient finetuning model (e.g., attaching LoRAs)
    -> Assumes that all non-trainable weights have already been frozen;
       if not, freeze them before calling this function.
    """
    if peft_config['method'] == 'lora':
        from peft import (
            get_peft_model,
            LoraConfig,
            TaskType,
            prepare_model_for_kbit_training,
        )
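
        # Prefix bare projection names (e.g., 'q_proj' -> 'self_attn.q_proj') so that
        # PEFT only matches the attention projections when resolving target modules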
        try:
            target_modules = []
            for module_name in peft_config['kwargs']['target_modules']:
                if ('_proj' in module_name and 'self_attn' not in module_name
                        and add_self_attn_prefix):
                    target_modules.append(f'self_attn.{module_name}')
                elif '_proj' in module_name:
                    target_modules.append(module_name)
            peft_config['kwargs']['target_modules'] = target_modules
        except Exception as e:
            print(e)
            target_modules = []
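
        # Optionally restrict LoRA to a subset of decoder layers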
        if 'layers_to_ignore' in peft_config:
            peft_config['kwargs']['layers_to_transform'] = [
                i for i in range(len(model.model.layers))
                if i not in peft_config['layers_to_ignore']
            ]
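
        # Build the actual PEFT LoraConfig from the remaining keyword arguments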
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            **peft_config['kwargs'],
        )
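
        # Record which parameters were trainable before wrapping, so they can be
        # re-enabled below when `preserve_requires_grad` is True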
        trainable_weights = [
            n for n, p in model.named_parameters() if p.requires_grad
        ]
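
        # If the base model was loaded in 8-bit or 4-bit, enable input gradients and
        # run PEFT's k-bit preparation (with gradient checkpointing) before adding adapters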
        loaded_in_kbit = (getattr(model, "is_loaded_in_8bit", False) or
                          getattr(model, "is_loaded_in_4bit", False))
        if loaded_in_kbit:
            model.enable_input_require_grads()
            ugc = (use_gradient_checkpointing
                   if use_gradient_checkpointing is not None else True)
            print('-> use_gradient_checkpointing:', ugc)

            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=ugc,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
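
        # Attach the LoRA adapters and report how many parameters remain trainable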
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
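
        # get_peft_model freezes everything except the new adapter weights; optionally
        # restore requires_grad on parameters that were trainable before wrapping
        # (PEFT prefixes their names with 'base_model.model.'), and cast trainable
        # parameters to the target dtype for k-bit models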
        for n, p in model.named_parameters():
            if preserve_requires_grad:
                if n[len('base_model.model.'):] in trainable_weights:
                    p.requires_grad = True

            if p.requires_grad and loaded_in_kbit:
                p.data = p.data.to(getattr(torch, target_dtype))

        if not loaded_in_kbit:
            model.to(dtype=getattr(torch, target_dtype))

        return model, peft_config
    else:
        raise NotImplementedError(
            f"Sorry, PEFT method '{peft_config['method']}' is not implemented yet."
        )
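

if __name__ == '__main__':
    # Minimal usage sketch, not a tested entry point: assumes a LLaMA-style
    # checkpoint (modules named `self_attn.{q,k,v,o}_proj`) and the Hugging Face
    # `transformers` package; the model name below is only an example.
    from transformers import AutoModelForCausalLM

    base_model = AutoModelForCausalLM.from_pretrained(
        'meta-llama/Llama-2-7b-hf',  # example; any LLaMA-style checkpoint works
        torch_dtype=torch.bfloat16,
    )
    for p in base_model.parameters():  # freeze everything first (see docstring)
        p.requires_grad = False

    lora_config_dict = {
        'method': 'lora',
        'kwargs': {'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05,
                   'target_modules': ['q_proj', 'v_proj']},
    }
    peft_model, lora_config = create_peft_config(base_model, lora_config_dict,
                                                 target_dtype='bfloat16')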