""" Hold the training script for the medusa model. Adapted from the original code here: https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py """ import os from dataclasses import dataclass, field import pathlib from typing import Dict, Optional import torch from torch.utils.data import Dataset import transformers from transformers import Trainer, BitsAndBytesConfig from transformers.trainer_pt_utils import LabelSmoother from torch.nn import CrossEntropyLoss from medusa.model.medusa_model import MedusaModel, MedusaConfig from calibration_datasets import CalibrationDataset IGNORE_TOKEN_ID = LabelSmoother.ignore_index # Customized for training Medusa heads class CustomizedTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): """ Compute the training loss for the model. Args: model (torch.nn.Module): The model for which to compute the loss. inputs (dict): The input data, including input IDs, attention mask, and labels. return_outputs (bool): Whether to return model outputs along with the loss. Returns: Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs. """ # DDP will give us model.module if hasattr(model, "module"): medusa = model.module.medusa else: medusa = model.medusa logits = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] ) labels = inputs["labels"] # Shift so that tokens < n predict n loss = 0 loss_fct = CrossEntropyLoss() log = {} for i in range(medusa): medusa_logits = logits[i, :, : -(2 + i)].contiguous() medusa_labels = labels[..., 2 + i :].contiguous() medusa_logits = medusa_logits.view(-1, logits.shape[-1]) medusa_labels = medusa_labels.view(-1) medusa_labels = medusa_labels.to(medusa_logits.device) loss_i = loss_fct(medusa_logits, medusa_labels) loss += loss_i not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID) medusa_labels = medusa_labels[not_ignore] # Add top-k accuracy for k in range(1, 6): _, topk = medusa_logits.topk(k, dim=-1) topk = topk[not_ignore] correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1) log[f"medusa{i}_top{k}"] = correct.float().mean().item() log[f"medusa{i}_loss"] = loss_i.item() self.log(log) return (loss, logits) if return_outputs else loss @dataclass class ModelArguments: model_name_or_path: Optional[str] = field() load_in_4bit: bool = field( default=False, metadata={"help": "Load in 4 bit."}, ) load_in_8bit: bool = field( default=False, metadata={"help": "Load in 8 bit."}, ) @dataclass class DataArguments: dataset: str = field( metadata={"help": "One of the datasets names in a CalibrationDataset subclass."}, ) @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") model_max_length: int = field( default=2048, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) medusa_num_heads: int = field( default=1, metadata={"help": "Number of Medusa heads."}, ) medusa_num_layers: int = field( default=1, metadata={"help": "Number of layers for each Medusa head."}, ) local_rank = None def rank0_print(*args): if local_rank == 0: print(*args) def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """ Save the model's state dictionary to a specified directory. Args: trainer (transformers.Trainer): The Hugging Face Trainer object. output_dir (str): The directory where the model state dictionary will be saved. 
""" state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning. Args: dataset (str): One of the datasets names in a CalibrationDataset subclass. tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. """ def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer): super(SupervisedDataset, self).__init__() rank0_print("Formatting inputs...") dataset_classes = CalibrationDataset.__subclasses__() for dataset_class in dataset_classes: if dataset_class.dataset == dataset: dataset = dataset_class(num_samples=int(1e6), seqlen=tokenizer.model_max_length, tokenizer=tokenizer) break tokenized = dataset.tokenize_dataset() self.input_ids = torch.tensor([data["input_ids"] for data in tokenized], dtype=torch.long) self.attention_mask = torch.tensor([data["attention_mask"] for data in tokenized], dtype=torch.long) def __len__(self): return self.input_ids.shape[0] def __getitem__(self, i) -> Dict[str, torch.Tensor]: return dict( input_ids=self.input_ids[i], labels=self.input_ids[i], attention_mask=self.attention_mask[i], ) def train(): global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments) ) model_args, data_args, training_args = parser.parse_args_into_dataclasses() local_rank = training_args.local_rank config = transformers.AutoConfig.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, ) config.use_cache = False quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) # Load model and tokenizer try: # Try loading with FA2 model = transformers.AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, quantization_config=quantization_config if model_args.load_in_4bit else None, load_in_4bit=model_args.load_in_4bit, load_in_8bit=model_args.load_in_8bit, attn_implementation="flash_attention_2", ) except: model = transformers.AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, quantization_config=quantization_config if model_args.load_in_4bit else None, load_in_4bit=model_args.load_in_4bit, load_in_8bit=model_args.load_in_8bit, ) # Freeze the base model for param in model.base_model.parameters(): param.requires_grad = False # Add Medusa heads medusa_lm_head = MedusaModel( model, medusa_num_heads=training_args.medusa_num_heads, medusa_num_layers=training_args.medusa_num_layers, base_model_name_or_path=model_args.model_name_or_path, ) # Format output dir training_args.output_dir = f"{training_args.output_dir}_medusa_{model_args.model_name_or_path.split('/')[-1]}" tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", use_fast=False, ) tokenizer.pad_token = tokenizer.unk_token # Load data data_module = {"train_dataset": SupervisedDataset(data_args.dataset, tokenizer), "eval_dataset": None} # Generate Medusa config for pushing to HF hub medusa_config = MedusaConfig( 
        medusa_num_heads=training_args.medusa_num_heads,
        medusa_num_layers=training_args.medusa_num_layers,
        base_model_name_or_path=model_args.model_name_or_path,
    )

    # Save the Medusa config
    medusa_config.save_pretrained(training_args.output_dir)

    # Start the trainer, resuming from the latest checkpoint if one exists
    trainer = CustomizedTrainer(
        model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
    )

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    model.config.use_cache = True

    # Save the Medusa heads separately
    if hasattr(medusa_lm_head, "module"):
        lm_head = medusa_lm_head.module.medusa_head
    else:
        lm_head = medusa_lm_head.medusa_head

    torch.save(
        lm_head.state_dict(),
        os.path.join(training_args.output_dir, "medusa_lm_head.pt"),
    )


if __name__ == "__main__":
    train()
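
# Example launch command (a hedged sketch, not part of the upstream Medusa code).
# The flag names come from the ModelArguments/DataArguments/TrainingArguments dataclasses
# above plus standard transformers.TrainingArguments fields; the script filename, base
# model, and dataset name are placeholders to fill in (the dataset must match the
# `dataset` attribute of one of your CalibrationDataset subclasses):
#
#   python train_medusa.py \
#       --model_name_or_path <base-model-name-or-path> \
#       --dataset <calibration-dataset-name> \
#       --output_dir ./out \
#       --medusa_num_heads 4 \
#       --medusa_num_layers 1 \
#       --load_in_4bit True \
#       --bf16 True \
#       --per_device_train_batch_size 4 \
#       --learning_rate 1e-3 \
#       --num_train_epochs 1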