# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import inspect
# from dataclasses import asdict

import torch.distributed as dist
from torch.utils.data import DistributedSampler

from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from omegaconf import OmegaConf

import logging
logger = logging.getLogger(__name__)


# def update_config(config, **kwargs):
#     if isinstance(config, (tuple, list)):
#         for c in config:
#             update_config(c, **kwargs)
#     else:
#         for k, v in kwargs.items():
#             if hasattr(config, k):
#                 setattr(config, k, v)
#             elif "." in k:
#                 # allow --some_config.some_param=True
#                 config_name, param_name = k.split(".")
#                 if type(config).__name__ == config_name:
#                     if hasattr(config, param_name):
#                         setattr(config, param_name, v)
#                     else:
#                         # In case of a specialized config we can warn the user
#                         logger.warning(f"Warning: {config_name} does not accept parameter: {k}")
#             elif isinstance(config, train_config):
#                 logger.warning(f"Warning: unknown parameter {k}")


def generate_peft_config(train_config):
    """Build a peft config object from ``train_config.peft_config``.

    The ``peft_method`` key selects the peft class; all remaining keys are
    forwarded to that class's constructor.
    """
    # configs = (lora_config, llama_adapter_config, prefix_config)
    # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    peft_configs = {
        "lora": LoraConfig,
        "llama_adapter": AdaptionPromptConfig,
        "prefix": PrefixTuningConfig,
    }
    # names = tuple(c.__name__.rstrip("_config") for c in configs)
    #
    # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
    #
    # config = configs[names.index(train_config.peft_method)]()
    config = train_config.peft_config
    params = OmegaConf.to_container(config, resolve=True)
    # peft_config = peft_configs[names.index(train_config.peft_method)](**params)
    params.pop("peft_method", None)  # (FIX:MZY): remove peft_method from params to avoid error
    peft_config = peft_configs[config.get("peft_method", "lora")](**params)

    return peft_config
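
# Illustrative shape of ``train_config.peft_config`` (a sketch; the concrete
# values below are assumptions, but every key other than ``peft_method`` must
# be a valid constructor argument of the selected peft class, here LoraConfig):
#
#   peft_config:
#     peft_method: lora
#     r: 8
#     lora_alpha: 32
#     lora_dropout: 0.05
#     target_modules: ["q_proj", "v_proj"]
#     task_type: CAUSAL_LM
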
def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
    """Assemble ``torch.utils.data.DataLoader`` keyword arguments for the given
    batching strategy ("padding", "packing", or a custom strategy that relies on
    ``dataset.collator``), switching to distributed samplers when FSDP, DDP, or
    DeepSpeed is enabled."""
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.batching_strategy == "padding":
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(
                dataset, batch_size, drop_last=True, shuffle=mode == "train"
            )
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    elif train_config.batching_strategy == "packing":
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
        if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = dataset.collator

    logger.info(f"Using batching strategy: {train_config.batching_strategy}")
    return kwargs
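

if __name__ == "__main__":
    # Minimal usage sketch added for illustration; it is not part of the training
    # pipeline. The flat ``demo_train_config`` fields below are assumptions that
    # only cover what get_dataloader_kwargs reads, not the project's full schema.
    import torch
    from torch.utils.data import DataLoader

    demo_train_config = OmegaConf.create(
        {
            "batch_size_training": 4,
            "val_batch_size": 2,
            "batching_strategy": "packing",  # this branch needs no tokenizer-backed collator
            "enable_fsdp": False,
            "enable_ddp": False,
            "enable_deepspeed": False,
        }
    )
    # Toy "pre-packed" dataset: every sample already has a fixed sequence length,
    # so transformers' default_data_collator can simply stack the tensors.
    demo_dataset = [{"input_ids": torch.arange(16), "labels": torch.arange(16)} for _ in range(32)]

    dl_kwargs = get_dataloader_kwargs(demo_train_config, demo_dataset, tokenizer=None, mode="train")
    dataloader = DataLoader(demo_dataset, **dl_kwargs)
    print(next(iter(dataloader))["input_ids"].shape)  # expected: torch.Size([4, 16])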