PANH commited on
Commit
5638c25
1 Parent(s): 80e54ec

converting cptk to safe tensors

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from safetensors.torch import save_file
4
+ import requests
5
+ import os
6
+
7
+ def download_ckpt_file(ckpt_url):
8
+ """
9
+ Downloads the .ckpt file from the provided Hugging Face URL.
10
+ """
11
+ try:
12
+ # Get the filename from the URL
13
+ filename = ckpt_url.split("/")[-1]
14
+ response = requests.get(ckpt_url)
15
+
16
+ # Save the file locally
17
+ with open(filename, 'wb') as f:
18
+ f.write(response.content)
19
+
20
+ return filename
21
+ except Exception as e:
22
+ return None, f"Error downloading the file: {str(e)}"
23
+
24
+ def convert_ckpt_to_safetensors(ckpt_file):
25
+ """
26
+ Converts a .ckpt file to safetensors format.
27
+ """
28
+ try:
29
+ # Load the checkpoint
30
+ checkpoint = torch.load(ckpt_file, map_location='cpu')
31
+
32
+ # Ensure the checkpoint contains a 'state_dict'
33
+ if 'state_dict' in checkpoint:
34
+ state_dict = checkpoint['state_dict']
35
+ else:
36
+ state_dict = checkpoint
37
+
38
+ # Remove any prefixes if necessary (e.g., 'module.')
39
+ new_state_dict = {}
40
+ for key, value in state_dict.items():
41
+ if key.startswith('module.'):
42
+ new_key = key[len('module.'):]
43
+ else:
44
+ new_key = key
45
+ new_state_dict[new_key] = value
46
+
47
+ # Save to safetensors format
48
+ output_file = ckpt_file.replace(".ckpt", ".safetensors")
49
+ save_file(new_state_dict, output_file)
50
+ return output_file
51
+ except Exception as e:
52
+ return f"Error converting to safetensors: {str(e)}"
53
+
54
+ def handle_conversion(ckpt_url):
55
+ """
56
+ Handles the entire process of downloading the .ckpt file from the link,
57
+ converting it to safetensors, and providing the user with the output file.
58
+ """
59
+ # Download the .ckpt file
60
+ filename = download_ckpt_file(ckpt_url)
61
+ if not filename:
62
+ return None, "Failed to download the file."
63
+
64
+ # Convert the .ckpt file to safetensors
65
+ safetensors_file = convert_ckpt_to_safetensors(filename)
66
+
67
+ # If the conversion is successful, return the safetensors file
68
+ if safetensors_file.endswith(".safetensors"):
69
+ return safetensors_file
70
+ else:
71
+ return None, safetensors_file
72
+
73
+ # Gradio Interface
74
+ def convert_and_download(ckpt_url):
75
+ safetensors_file, message = handle_conversion(ckpt_url)
76
+
77
+ if safetensors_file:
78
+ return safetensors_file # Provide the converted file for download
79
+ else:
80
+ return message # Provide the error message
81
+
82
+ # Define the Gradio interface
83
+ gr.Interface(
84
+ fn=convert_and_download,
85
+ inputs=gr.Textbox(label="Hugging Face CKPT File URL", placeholder="Enter the link to a .ckpt file"),
86
+ outputs=gr.File(label="Download .safetensors file"),
87
+ title="CKPT to Safetensors Converter",
88
+ description="Enter the Hugging Face URL for a .ckpt file and convert it to safetensors format."
89
+ ).launch()