"""Gradio app: download a .ckpt checkpoint from a URL and convert it to .safetensors."""

import contextlib
import os
import shutil
import tempfile
from urllib.parse import unquote, urlparse

import gradio as gr
import requests
import torch
from safetensors.torch import save_file

# Network timeout in seconds (applies to connecting and to each read), so a
# stalled server cannot hang a worker indefinitely.
DOWNLOAD_TIMEOUT = 60
# Stream downloads to disk in 1 MiB chunks instead of buffering the whole
# checkpoint in memory.
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
# Errors that can occur while parsing the URL, fetching it, or writing it out.
_DOWNLOAD_ERRORS = (requests.RequestException, OSError, ValueError)


def _to_raw_hf_url(url):
    """Rewrite a Hugging Face ``/blob/`` (web page) URL to its ``/resolve/`` form.

    The ``.../blob/<rev>/<file>`` link shown when browsing a file on
    huggingface.co serves an HTML page; ``.../resolve/<rev>/<file>`` serves the
    raw file. Any other URL is returned unchanged.
    """
    parsed = urlparse(url)
    if not parsed.netloc.endswith("huggingface.co"):
        return url
    parts = parsed.path.split("/")
    # Model repos look like /<owner>/<repo>/blob/...; dataset and space repos
    # carry an extra leading "datasets" / "spaces" segment.
    blob_index = 4 if parts[1:2] in (["datasets"], ["spaces"]) else 3
    if len(parts) > blob_index and parts[blob_index] == "blob":
        parts[blob_index] = "resolve"
        return parsed._replace(path="/".join(parts)).geturl()
    return url


def _filename_from_url(url):
    """Return a local file name taken from the last path segment of *url*.

    The query string and fragment are ignored (``.../model.ckpt?download=true``
    gives ``model.ckpt``). Falls back to ``model.ckpt`` when the path has no
    usable final segment (empty, ``.`` or ``..``).
    """
    name = os.path.basename(unquote(urlparse(url).path))
    return name if name not in ("", ".", "..") else "model.ckpt"


def _download(ckpt_url, dest_dir):
    """Stream *ckpt_url* to a file and return its local path.

    The file is written into *dest_dir*, or the current directory when
    *dest_dir* is None. Raises one of ``_DOWNLOAD_ERRORS`` on failure.
    """
    url = _to_raw_hf_url((ckpt_url or "").strip())
    filename = _filename_from_url(url)
    path = filename if dest_dir is None else os.path.join(dest_dir, filename)
    with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
        # Without this, an HTTP error page (404, 401, ...) would be saved to
        # disk and later fed to torch.load as if it were the checkpoint.
        response.raise_for_status()
        with open(path, "wb") as f:
            for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                f.write(chunk)
    return path


def download_ckpt_file(ckpt_url, dest_dir=None):
    """
    Downloads the .ckpt file from the provided URL.

    Hugging Face ``/blob/`` page links are rewritten to raw ``/resolve/``
    download links.

    Args:
        ckpt_url: URL of the checkpoint file.
        dest_dir: Directory to save into; defaults to the current directory.

    Returns:
        The local path of the downloaded file, or ``(None, error_message)`` if
        the download failed.
    """
    try:
        return _download(ckpt_url, dest_dir)
    except _DOWNLOAD_ERRORS as e:
        return None, f"Error downloading the file: {str(e)}"


def _convert(ckpt_file):
    """Convert *ckpt_file* to a ``.safetensors`` file next to it; return its path.

    Raises on any failure (unreadable checkpoint, nothing to convert, save error).
    """
    # SECURITY NOTE(review): torch.load unpickles the file, and unpickling can
    # execute arbitrary code. The file comes from a user-supplied URL, so this
    # runs untrusted data on the server. Consider ``weights_only=True``
    # (PyTorch >= 1.13) if the checkpoints you must support load with it.
    checkpoint = torch.load(ckpt_file, map_location="cpu")

    # Unwrap checkpoints that nest the weights under a "state_dict" entry.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"expected a dict of tensors, got {type(state_dict).__name__}"
        )

    tensors = {}
    for key, value in state_dict.items():
        # safetensors can only store str -> tensor entries; skip anything else
        # rather than letting save_file reject the whole checkpoint.
        if not isinstance(key, str) or not isinstance(value, torch.Tensor):
            continue
        # Strip the "module." prefix that DataParallel-style wrappers add.
        new_key = key[len("module."):] if key.startswith("module.") else key
        # save_file refuses non-contiguous tensors.
        tensors[new_key] = value.contiguous()
    if not tensors:
        raise ValueError("the checkpoint contains no tensors to convert")

    # splitext swaps only the final extension, so the output never overwrites
    # an input that lacks a ".ckpt" suffix.
    output_file = os.path.splitext(ckpt_file)[0] + ".safetensors"
    save_file(tensors, output_file)
    return output_file


def convert_ckpt_to_safetensors(ckpt_file):
    """
    Converts a .ckpt file to safetensors format.

    Non-tensor entries are skipped and a leading "module." is stripped from
    each key. The output is written next to the input with a ``.safetensors``
    extension.

    Returns:
        The output path, or an error message string if conversion failed.
    """
    try:
        return _convert(ckpt_file)
    except Exception as e:  # UI-facing boundary: report any failure as text
        return f"Error converting to safetensors: {str(e)}"


def handle_conversion(ckpt_url):
    """
    Downloads the .ckpt file from the link and converts it to safetensors.

    Each request works in its own temporary directory, so concurrent users do
    not overwrite each other's files. The downloaded .ckpt is deleted once
    converted.

    Returns:
        ``(safetensors_path, None)`` on success, or ``(None, error_message)``.
    """
    work_dir = tempfile.mkdtemp(prefix="ckpt2safetensors-")

    try:
        ckpt_file = _download(ckpt_url, work_dir)
    except _DOWNLOAD_ERRORS as e:
        shutil.rmtree(work_dir, ignore_errors=True)  # drop any partial download
        return None, f"Error downloading the file: {e}"

    try:
        output_file = _convert(ckpt_file)
    except Exception as e:  # UI-facing boundary: report any failure as text
        shutil.rmtree(work_dir, ignore_errors=True)
        return None, f"Error converting to safetensors: {e}"

    # Keep only the converted file; guard against the download itself having
    # been named *.safetensors, in which case input and output are the same.
    if os.path.abspath(ckpt_file) != os.path.abspath(output_file):
        with contextlib.suppress(OSError):
            os.remove(ckpt_file)
    return output_file, None


# Gradio Interface
def convert_and_download(ckpt_url):
    """Gradio callback: return the converted file's path for the File output.

    Raises:
        gr.Error: with the failure message, which Gradio shows to the user.
            (Returning the message itself would make gr.File treat it as a
            path.)
    """
    safetensors_file, message = handle_conversion(ckpt_url)
    if safetensors_file is None:
        raise gr.Error(message)
    return safetensors_file


# Define the Gradio interface
demo = gr.Interface(
    fn=convert_and_download,
    inputs=gr.Textbox(
        label="Hugging Face CKPT File URL",
        placeholder="Enter the link to a .ckpt file",
    ),
    outputs=gr.File(label="Download .safetensors file"),
    title="CKPT to Safetensors Converter",
    description="Enter the Hugging Face URL for a .ckpt file and convert it to safetensors format.",
)

if __name__ == "__main__":
    demo.launch()