PANH commited on
Commit
9b45214
1 Parent(s): 9115b9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -21
app.py CHANGED
@@ -6,47 +6,77 @@ import os
6
 
7
  def convert_ckpt_to_safetensors(input_path, output_path):
8
  # Load the .ckpt file
9
- state_dict = torch.load(input_path, map_location='cpu')
 
 
 
10
 
11
- # If the checkpoint has a 'state_dict' key, extract it
12
- if 'state_dict' in state_dict:
13
- state_dict = state_dict['state_dict']
14
- elif 'model' in state_dict:
15
- state_dict = state_dict['model']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Save the entire state dictionary, including non-tensor entries
18
- save_model(state_dict, output_path)
19
 
20
  def process(url, uploaded_file):
21
  if url:
22
  # Download the .ckpt file
23
  local_filename = 'model.ckpt'
24
- with requests.get(url, stream=True) as r:
25
- r.raise_for_status()
26
- with open(local_filename, 'wb') as f:
27
- for chunk in r.iter_content(chunk_size=8192):
28
- f.write(chunk)
 
 
 
29
  elif uploaded_file is not None:
30
  # Save uploaded file
31
- local_filename = uploaded_file.name
32
- with open(local_filename, 'wb') as f:
33
- f.write(uploaded_file.read())
 
 
 
34
  else:
35
- return "Please provide a URL or upload a .ckpt file."
36
 
37
  output_filename = local_filename.replace('.ckpt', '.safetensors')
 
38
  # Convert the .ckpt to .safetensors
39
  try:
40
- convert_ckpt_to_safetensors(local_filename, output_filename)
 
 
 
 
41
  except Exception as e:
42
  # Clean up the input file
43
  os.remove(local_filename)
44
- return f"An error occurred during conversion: {e}"
45
 
46
  # Clean up the input file
47
  os.remove(local_filename)
48
 
49
- return output_filename
 
50
 
51
  iface = gr.Interface(
52
  fn=process,
@@ -56,7 +86,10 @@ iface = gr.Interface(
56
  ],
57
  outputs=gr.File(label="Converted .safetensors file"),
58
  title="CKPT to SafeTensors Converter",
59
- description="Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file."
 
 
 
60
  )
61
 
62
  iface.launch()
 
6
 
7
  def convert_ckpt_to_safetensors(input_path, output_path):
8
  # Load the .ckpt file
9
+ # ⚠️ SECURITY WARNING:
10
+ # Loading untrusted .ckpt files with torch.load() can execute arbitrary code.
11
+ # Only load files from trusted sources.
12
+ obj = torch.load(input_path, map_location='cpu')
13
 
14
+ # Determine if obj is a state dict or a model object
15
+ if isinstance(obj, dict):
16
+ # Check for nested 'state_dict' or 'model' keys
17
+ if 'state_dict' in obj:
18
+ state_dict = obj['state_dict']
19
+ elif 'model' in obj:
20
+ state_dict = obj['model']
21
+ else:
22
+ # Assume obj is the state dict
23
+ state_dict = obj
24
+ elif hasattr(obj, 'state_dict'):
25
+ # If obj is a model object
26
+ state_dict = obj.state_dict()
27
+ else:
28
+ return "Unsupported checkpoint format."
29
+
30
+ # Save the state dictionary, including shared tensors and LM head
31
+ try:
32
+ save_model(state_dict, output_path)
33
+ except Exception as e:
34
+ return f"An error occurred during saving: {e}"
35
 
36
+ return "Success"
 
37
 
38
  def process(url, uploaded_file):
39
  if url:
40
  # Download the .ckpt file
41
  local_filename = 'model.ckpt'
42
+ try:
43
+ with requests.get(url, stream=True) as r:
44
+ r.raise_for_status()
45
+ with open(local_filename, 'wb') as f:
46
+ for chunk in r.iter_content(chunk_size=8192):
47
+ f.write(chunk)
48
+ except Exception as e:
49
+ return f"<p style='color:red;'>Failed to download file: {e}</p>"
50
  elif uploaded_file is not None:
51
  # Save uploaded file
52
+ local_filename = 'uploaded_model.ckpt'
53
+ try:
54
+ with open(local_filename, 'wb') as f:
55
+ f.write(uploaded_file.read())
56
+ except Exception as e:
57
+ return f"<p style='color:red;'>Failed to save uploaded file: {e}</p>"
58
  else:
59
+ return "<p style='color:red;'>Please provide a URL or upload a .ckpt file.</p>"
60
 
61
  output_filename = local_filename.replace('.ckpt', '.safetensors')
62
+
63
  # Convert the .ckpt to .safetensors
64
  try:
65
+ result = convert_ckpt_to_safetensors(local_filename, output_filename)
66
+ if result != "Success":
67
+ # Clean up the input file
68
+ os.remove(local_filename)
69
+ return f"<p style='color:red;'>An error occurred during conversion: {result}</p>"
70
  except Exception as e:
71
  # Clean up the input file
72
  os.remove(local_filename)
73
+ return f"<p style='color:red;'>An exception occurred: {e}</p>"
74
 
75
  # Clean up the input file
76
  os.remove(local_filename)
77
 
78
+ # Provide a download link for the output file
79
+ return gr.File.update(value=output_filename, visible=True)
80
 
81
  iface = gr.Interface(
82
  fn=process,
 
86
  ],
87
  outputs=gr.File(label="Converted .safetensors file"),
88
  title="CKPT to SafeTensors Converter",
89
+ description="""
90
+ Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file.
91
+ **Security Warning:** Loading .ckpt files can execute arbitrary code. Only use files from trusted sources.
92
+ """
93
  )
94
 
95
  iface.launch()