#!/usr/bin/env python3
"""Benchmark loading a safetensors checkpoint directly onto the GPU vs. via CPU.

Usage::

    script.py <direct_on_gpu: 0|1> [checkpoint_path]

With ``1`` the file is loaded straight onto CUDA device 0 by safetensors.
With ``0`` it is loaded into CPU memory first and every tensor is then moved to
``cuda:0``. The elapsed time is printed in seconds.
"""
import sys
import time

DEFAULT_CHECKPOINT = (
    "/home/patrick_huggingface_co/stable-diffusion-v1-4/unet/"
    "diffusion_pytorch_model.safetensors"
)


def parse_args(argv):
    """Parse ``argv`` into ``(direct_on_gpu, checkpoint_path)``.

    ``argv[1]`` must be an integer; any non-zero value selects direct-to-GPU
    loading. ``argv[2]``, if given, overrides the checkpoint path.

    Raises:
        SystemExit: with a usage message when ``argv[1]`` is missing or is not
            an integer.
    """
    prog = argv[0] if argv else "script"
    usage = f"usage: {prog} <direct_on_gpu: 0|1> [checkpoint_path]"
    if len(argv) < 2:
        raise SystemExit(usage)
    try:
        direct_on_gpu = bool(int(argv[1]))
    except ValueError:
        raise SystemExit(usage) from None
    path = argv[2] if len(argv) > 2 else DEFAULT_CHECKPOINT
    return direct_on_gpu, path


def load_checkpoint(path, direct_on_gpu):
    """Load the safetensors file at ``path`` onto ``cuda:0`` and time it.

    Returns:
        ``(checkpoint, elapsed_seconds)``, where ``checkpoint`` maps tensor
        names to tensors on the GPU.
    """
    # Imported here (not at module top) so argument errors are reported
    # without first loading the torch/CUDA stack. Imports also stay outside
    # the timed region, as in the original script.
    import torch
    from safetensors.torch import load_file as safe_load_file

    # NOTE: nothing touches CUDA before this point, so CUDA context
    # initialisation is included in the measured time for both modes.
    start = time.perf_counter()
    if direct_on_gpu:
        checkpoint = safe_load_file(path, device=0)
    else:
        checkpoint = safe_load_file(path)
        checkpoint = {k: v.to("cuda:0") for k, v in checkpoint.items()}
    # CUDA work may still be queued when the Python calls return; wait for
    # every copy to finish so the timing reflects the real transfer cost.
    torch.cuda.synchronize()
    return checkpoint, time.perf_counter() - start


def main(argv=None):
    """Run the benchmark with ``argv`` (defaults to ``sys.argv``)."""
    argv = sys.argv if argv is None else argv
    direct_on_gpu, path = parse_args(argv)
    _, elapsed = load_checkpoint(path, direct_on_gpu)
    print("Directly on GPU" if direct_on_gpu else "On CPU", elapsed)


if __name__ == "__main__":
    main()