# medusa-maker / src / train_workflow.py
"""
Holds the interface between the gradio app and the medusa training script
"""
import os
import multiprocessing as mp
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
import torch.distributed.run as distributed_run

OUTPUT_DIR = "medusa_heads"
DATASET = "vicuna"

# These can't be changed (e.g. they control the output path)
FIXED_TRAINING_ARGS = \
"""src/medusa_training_script.py
--model_name_or_path {model_id}
--output_dir {output_dir}
--run_name {model_id}-medusa-{dataset}
--dataset {dataset}"""

# These can be freely changed
DEFAULT_TRAINING_ARGS = \
"""--medusa_num_heads 3
--medusa_num_layers 1
--model_max_length 2048
--bf16 True
--num_train_epochs 1
--per_device_train_batch_size 64
--per_device_eval_batch_size 64
--gradient_accumulation_steps 8
--evaluation_strategy no
--save_strategy no
--weight_decay 0.0
--warmup_ratio 0.1
--lr_scheduler_type cosine
--logging_steps 10
--tf32 True
--auto_find_batch_size True
--learning_rate 1e-3"""
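
# For illustration: the two templates above compose into what is effectively a
# torchrun invocation. With a hypothetical model id "some-org/some-model" and
# the defaults, it would look roughly like:
#
#   torchrun src/medusa_training_script.py \
#       --model_name_or_path some-org/some-model \
#       --output_dir medusa_heads \
#       --run_name some-org/some-model-medusa-vicuna \
#       --dataset vicuna \
#       --medusa_num_heads 3 --medusa_num_layers 1 ... --learning_rate 1e-3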


def train_medusa_heads(model_id: str, training_args: str, dataset: str):
    """Composes the full argument list and launches the training script via torchrun."""
all_training_args = FIXED_TRAINING_ARGS.format(
model_id=model_id, output_dir=OUTPUT_DIR, dataset=dataset,
) + "\n" + training_args
    # Flatten the newline- and space-separated argument string into an argv-style list
    all_training_arg_list = []
    for arg in all_training_args.split("\n"):
        all_training_arg_list += arg.split(" ")
print("Full argument list:", all_training_arg_list)
    # Parse with torchrun's own parser and launch: equivalent to `torchrun <args>`
    parser = distributed_run.get_args_parser()
    args = parser.parse_args(all_training_arg_list)
    distributed_run.run(args)


def run(model_id: str, training_args: str, dataset: str) -> str:
    """Validates the inputs, trains the Medusa heads, uploads them to the Hub, and returns a markdown status message."""
print(f"\n\n\nNEW RUN: {model_id}")
api = HfApi()
model_name = model_id.split("/")[-1]
repo_id = f"joaogante/{model_name}-medusa-{dataset}"
# Input validation
if model_id == "":
return """
### Invalid input 🐞
Please fill in a model_id.
"""
if api.repo_exists(repo_id):
return f"""
### Invalid input 🐞
{repo_id} already exists, which means that {model_id} has already been used to create medusa heads.
"""
print(f"Valid inputs βœ…\nValidating model_id: {model_id}")
# Attempt to load the base model
try:
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
del config, tokenizer, model
except Exception as e:
return f"""
### {model_id} can't be loaded with AutoClasses 🐞
{e}
"""
print(f"{model_id} can be loaded βœ…\nCreating medusa heads (will take a few hours)")
# Run the medusa heads creation
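    # Training runs in a child process so that a failure inside torchrun or the
    # training script is isolated from the app process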
try:
proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, dataset))
proc.start()
        proc.join()
        # The exit code is not checked here; a failed run is caught below by the empty-output-folder check
        print("Medusa heads training process completed (it might have crashed!)")
except Exception as e:
print("Error ❌\n", e)
return f"""
### Error 😒😒😒
{e}
"""
# Upload the medusa heads to the Hub
try:
# Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
        folder_path = f"{OUTPUT_DIR}_medusa_{model_name}"
        # The training script saves the heads as .pt files in the output folder
        if not any(x.endswith(".pt") for x in os.listdir(folder_path)):
            raise Exception(
                "No model data in the expected model folder, the training run probably failed. Check the logs "
                "for more information."
            )
api.create_repo(
repo_id=repo_id,
exist_ok=True,
)
api.upload_folder(
folder_path=folder_path,
repo_id=repo_id,
)
print("Medusa heads upload success βœ…\n Uploaded to: ", repo_id)
return f"""
### Success 🔥
Yay! Medusa heads were successfully created and uploaded to the following repo: {repo_id}
"""
except Exception as e:
print("Error ❌\n", e)
try:
api.delete_repo(repo_id)
except RepositoryNotFoundError:
pass
return f"""
### Error 😒😒😒
{e}
"""