"""Extract the Medusa heads (weights + config subset) from a full Medusa model checkpoint.

Usage:
    python extract_medusa_head.py -m <model_dir> -o <output_dir> [--use-full-key] [--verbose]

Writes ``<output_dir>/config.json`` (the ``medusa*`` config keys plus the
transformers version and base model path) and ``<output_dir>/medusa_lm_head.pt``
(the ``medusa*`` entries of the model's state dict).
"""
import argparse
import json
import os

import torch

from medusa.model.medusa_model_legacy import MedusaConfig  # noqa: F401  -- currently unused in this script
from medusa.model.medusa_model_new import MedusaModel

# Prefix under which the head parameters are expected in the model's state dict;
# stripped from saved keys unless --use-full-key is given.
MEDUSA_HEAD_PREFIX = "medusa_head."

# Default to GPU 7, but respect a device selection the caller already exported.
# This must be set before CUDA is first initialized to take effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")


def read_json(file_path):
    """Load and return the JSON document stored at *file_path*."""
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def write_json(data, file_path):
    """Serialize *data* to *file_path* as JSON indented by 4 spaces."""
    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)


def _build_head_config(config):
    """Return the subset of a full model config that describes the Medusa heads.

    Keeps every key starting with ``medusa`` and additionally records
    ``transformers_version`` and the base model path (taken from ``_name_or_path``).

    Raises:
        KeyError: if ``transformers_version`` or ``_name_or_path`` is missing from *config*.
    """
    head_config = {key: value for key, value in config.items() if key.startswith("medusa")}
    head_config["transformers_version"] = config["transformers_version"]
    head_config["base_model_name_or_path"] = config["_name_or_path"]
    return head_config


def _extract_head_state(state_dict, use_full_key=False, verbose=False):
    """Return the entries of *state_dict* whose keys start with ``medusa``.

    Args:
        state_dict: mapping of parameter name -> tensor (e.g. ``model.state_dict()``).
        use_full_key: keep keys verbatim; otherwise strip a leading ``"medusa_head."``.
        verbose: print each selected key.

    Keys that start with ``medusa`` but not with ``"medusa_head."`` are kept
    verbatim instead of having a fixed number of characters cut off.
    """
    heads = {}
    for key, tensor in state_dict.items():
        if not key.startswith("medusa"):
            continue
        if verbose:
            print(key)
        if not use_full_key and key.startswith(MEDUSA_HEAD_PREFIX):
            key = key[len(MEDUSA_HEAD_PREFIX):]
        heads[key] = tensor
    return heads


def main():
    """Parse CLI arguments, load the Medusa model and write its heads + config subset."""
    parser = argparse.ArgumentParser(description='Example extract medusa head')
    parser.add_argument('-m', '--model-path', type=str, required=True, help='Path to new medusa models')
    parser.add_argument('-o', '--output-dir', type=str, required=True, help='Path to save medusa head')
    parser.add_argument('--use-full-key', action='store_true', help='medusa head keys with medusa_head. or not')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose output')
    args = parser.parse_args()

    # Read the config before loading the (large) model so a bad path fails fast.
    config = read_json(os.path.join(args.model_path, "config.json"))
    new_config = _build_head_config(config)

    model = MedusaModel.from_pretrained(pretrained_model_name_or_path=args.model_path)
    if args.verbose:
        print(model)

    medusa_lm_head = _extract_head_state(
        model.state_dict(), use_full_key=args.use_full_key, verbose=args.verbose
    )

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)
    write_json(new_config, os.path.join(args.output_dir, "config.json"))
    model_save_path = os.path.join(args.output_dir, "medusa_lm_head.pt")
    torch.save(medusa_lm_head, model_save_path)
    print(f"Model saved to {model_save_path}")


if __name__ == "__main__":
    main()