John6666 committed
Commit dfe0a88
1 Parent(s): 3d57406

Upload convert_url_to_diffusers_flux_gr.py

Files changed (1): convert_url_to_diffusers_flux_gr.py (+72 -52)
convert_url_to_diffusers_flux_gr.py CHANGED
@@ -10,7 +10,14 @@ import os
 import argparse
 import gradio as gr
 # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
+
+import subprocess
+subprocess.run('pip cache purge', shell=True)
+
 import spaces
+@spaces.GPU()
+def spaces_dummy():
+    pass
 
 flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
 flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
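
Note on this hunk: `pip cache purge` frees disk space on the Space at import time, and the empty `@spaces.GPU()` function is presumably there only so that ZeroGPU hardware gets provisioned for the app, even though the decorated function does nothing. For reference, a minimal sketch of how the decorator is ordinarily used on a real workload (the function and `duration` below are illustrative, not part of this commit):

    import spaces
    import torch

    @spaces.GPU(duration=60)  # seconds of GPU time requested per call
    def double_on_gpu(t: torch.Tensor) -> torch.Tensor:
        # the GPU is attached only while this function runs
        return (t.to("cuda") * 2).to("cpu")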
@@ -38,6 +45,24 @@ def is_repo_name(s):
     import re
     return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)
 
+def clear_cache():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+def clear_sd(sd: dict):
+    for k in list(sd.keys()):
+        sd.pop(k)
+    del sd
+    torch.cuda.empty_cache()
+    gc.collect()
+
+def clone_sd(sd: dict):
+    print("Cloning state dict.")
+    for k in list(sd.keys()):
+        sd[k] = sd.pop(k).detach().clone()
+    torch.cuda.empty_cache()
+    gc.collect()
+
 def print_resource_usage():
     import psutil
     cpu_usage = psutil.cpu_percent()
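
These three helpers consolidate the `del` / `torch.cuda.empty_cache()` / `gc.collect()` triplets deleted throughout the rest of this diff. The detail that matters in `clear_sd` is that `del sd` only unbinds the function-local name; it is the in-place `pop` loop that removes the tensors from the caller's dict so the collector can actually reclaim them. A standalone demonstration of that distinction (CPU-only; `empty_cache()` is a no-op when CUDA is uninitialized):

    import gc
    import torch

    def clear_sd(sd: dict):
        for k in list(sd.keys()):
            sd.pop(k)          # empties the caller's dict too
        del sd                 # unbinds only this local name
        torch.cuda.empty_cache()
        gc.collect()

    state = {"w": torch.zeros(4)}
    clear_sd(state)
    print(state)  # {} -- the tensors are no longer referenced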
@@ -363,7 +388,7 @@ def read_safetensors_metadata(path):
 
 def normalize_key(k: str):
     return k.replace("vae.", "").replace("model.diffusion_model.", "")\
-        .replace("text_encoders.clip_l.transformer.text_model.", "")\
+        .replace("text_encoders.clip_l.transformer.", "")\
         .replace("text_encoders.t5xxl.transformer.", "")
 
 def load_json_list(path: str):
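
The fix in this hunk strips the shorter `text_encoders.clip_l.transformer.` prefix, so the `text_model.` segment now survives normalization; Diffusers-style CLIP text-encoder keys begin with it. With a hypothetical checkpoint key:

    k = "text_encoders.clip_l.transformer.text_model.encoder.layers.0.mlp.fc1.weight"
    # old prefix -> "encoder.layers.0.mlp.fc1.weight"             (text_model. lost)
    # new prefix -> "text_model.encoder.layers.0.mlp.fc1.weight"
    print(k.replace("text_encoders.clip_l.transformer.", ""))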
@@ -465,9 +490,7 @@ with torch.no_grad():
             print(e)
             return
         finally:
-            del state_dict
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(state_dict)
         new_path = str(Path(savepath, Path(path).stem + "_fixed" + Path(path).suffix))
         metadata = read_safetensors_metadata(path)
         progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
@@ -476,9 +499,7 @@ with torch.no_grad():
         save_file(new_sd, new_path, metadata={"format": "pt", **metadata})
         progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
         print(f"Saved FLUX.1 safetensors: {new_path}")
-        del new_sd
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_sd(new_sd)
 
 with torch.no_grad():
     def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
@@ -506,9 +527,7 @@ with torch.no_grad():
         finally:
             progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
             print(f"Normalized FLUX.1 {name} safetensors: {path}")
-            del state_dict
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(state_dict)
         return new_sd
 
 with torch.no_grad():
@@ -541,9 +560,7 @@ with torch.no_grad():
                 for k, v in sharded_sd.items():
                     sharded_sd[k] = v.to(device="cpu")
                 sd = sd | sharded_sd.copy()
-                del sharded_sd
-                torch.cuda.empty_cache()
-                gc.collect()
+                clear_sd(sharded_sd)
         except Exception as e:
             print(e)
         return sd
@@ -561,9 +578,7 @@ with torch.no_grad():
         for k, v in sd.items():
             if k in set(keys_flux_transformer): sd[k] = v.to(device="cpu")
         save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
-        del sd
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_sd(sd)
         progress(0.25, desc=f"Saved temporary files to disk: {path}")
         print(f"Saved temporary files to disk: {path}")
         for filepath in glob.glob(f"{path}/*.safetensors"):
@@ -574,9 +589,7 @@ with torch.no_grad():
             for k, v in sharded_sd.items():
                 sharded_sd[k] = v.to(device="cpu")
             save_file(sharded_sd, str(filepath))
-            del sharded_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(sharded_sd)
         print(f"Loading temporary files from disk: {path}")
         sd = load_sharded_safetensors(path)
         print(f"Loaded temporary files from disk: {path}")
@@ -599,9 +612,7 @@ with torch.no_grad():
         for k, v in sd.items():
             sd[k] = v.to(device="cpu")
         save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
-        del sd
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_sd(sd)
         progress(0.25, desc=f"Saved temporary files to disk: {path}")
         print(f"Saved temporary files to disk: {path}")
         for filepath in glob.glob(f"{path}/*.safetensors"):
@@ -612,9 +623,7 @@ with torch.no_grad():
             for k, v in sharded_sd.items():
                 sharded_sd[k] = v.to(device="cpu")
             save_file(sharded_sd, str(filepath))
-            del sharded_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(sharded_sd)
         print(f"Processed temporary files: {str(filepath)}")
         print(f"Loading temporary files from disk: {path}")
         sd = load_sharded_safetensors(path)
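
The four hunks above all touch the same shard-and-reload roundtrip: the state dict is written to sharded safetensors with `save_torch_state_dict`, released with `clear_sd`, then rebuilt one shard at a time. A minimal sketch of the reload half, assuming `load_sharded_safetensors` works roughly like this:

    import glob
    from safetensors.torch import load_file

    def load_sharded_safetensors(path: str) -> dict:
        sd = {}
        for filepath in sorted(glob.glob(f"{path}/*.safetensors")):
            sd |= load_file(filepath, device="cpu")  # one shard resident at a time
        return sd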
@@ -678,8 +687,7 @@ with torch.no_grad():
                          quantization: bool = False, model_type: str = "dev", dequant: bool = False):
         save_flux_other_diffusers(savepath, model_type)
         normalize_flux_state_dict(loadpath, savepath, dtype, dequant)
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_cache()
 
 with torch.no_grad(): # Much lower memory consumption, but higher disk load
     def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
@@ -698,40 +706,46 @@ with torch.no_grad(): # Much lower memory consumption, but higher disk load
         vae_sd_path = savepath.removesuffix("/") + "/vae"
         vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
         vae_sd_size = "10GB"
+        print_resource_usage() #
         metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
+        clear_cache()
+        print_resource_usage() #
         if "vae" not in use_original:
             vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                                  keys_flux_vae)
             to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                        quantization, "VAE", None)
-            del vae_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(vae_sd)
+            print_resource_usage() #
         if "text_encoder" not in use_original:
             clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                                   keys_flux_clip)
             to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                        quantization, "Text Encoder", None)
-            del clip_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(clip_sd)
+            print_resource_usage() #
         if "text_encoder_2" not in use_original:
             te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                                 keys_flux_t5xxl)
             to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                        quantization, "Text Encoder 2", None)
-            del te_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(te_sd)
+            print_resource_usage() #
         unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
                                               keys_flux_transformer)
-        if not local: os.remove(loadpath)
+        clear_cache()
+        print_resource_usage() #
+        if not local:
+            os.remove(loadpath)
+            print("Deleted downloaded file.")
+        clear_cache()
+        print_resource_usage() #
         to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                                    quantization, "Transformer", metadata)
-        del unet_sd
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_sd(unet_sd)
+        print_resource_usage() #
         save_flux_other_diffusers(savepath, model_type, use_original)
+        print_resource_usage() #
 
 with torch.no_grad(): # lowest memory consumption, but highest disk load
     def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
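
Besides the `clear_sd`/`clear_cache` swaps and the `print_resource_usage() #` probes, this hunk expands the one-liner `if not local: os.remove(loadpath)` into a logged block bracketed by `clear_cache()` calls; deleting the source checkpoint before the converted shards are written keeps peak disk usage down. A standard-library way to check that headroom yourself (illustrative, not part of the commit):

    import shutil

    free_gb = shutil.disk_usage(".").free / 1024**3
    print(f"Free disk before writing transformer shards: {free_gb:.1f} GB")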
@@ -752,47 +766,52 @@ with torch.no_grad(): # lowest memory consumption, but highest disk load
         vae_sd_path = savepath.removesuffix("/") + "/vae"
         vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
         vae_sd_size = "10GB"
+        print_resource_usage() #
         metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
+        clear_cache()
+        print_resource_usage() #
         if "vae" not in use_original:
             vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                                  keys_flux_vae)
             to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                        quantization, "VAE", None)
-            del vae_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(vae_sd)
+            print_resource_usage() #
         if "text_encoder" not in use_original:
             clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                                   keys_flux_clip)
             to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                        quantization, "Text Encoder", None)
-            del clip_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(clip_sd)
+            print_resource_usage() #
         if "text_encoder_2" not in use_original:
             te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                                 keys_flux_t5xxl)
             to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                        quantization, "Text Encoder 2", None)
-            del te_sd
-            torch.cuda.empty_cache()
-            gc.collect()
+            clear_sd(te_sd)
+            print_resource_usage() #
         unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
                                                              unet_temp_path, unet_sd_pattern, unet_temp_size)
+        clear_cache()
+        print_resource_usage() #
         unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
                                                                    unet_sd_pattern, unet_temp_size)
+        clear_cache()
+        print_resource_usage() #
         to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                                    quantization, "Transformer", metadata)
-        del unet_sd
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_sd(unet_sd)
+        print_resource_usage() #
         save_flux_other_diffusers(savepath, model_type, use_original)
+        print_resource_usage() #
 
 def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                   model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
                                   hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
     progress(0, desc="Start converting...")
     temp_dir = "."
+    print_resource_usage() #
     new_file = get_download_file(temp_dir, url, civitai_key)
     if not new_file:
         print(f"Not found: {url}")
@@ -825,6 +844,7 @@ def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=Fals
                                           model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
     progress(0, desc="Start converting...")
     temp_dir = "."
+    print_resource_usage() #
     new_file = get_download_file(temp_dir, url, civitai_key)
     if not new_file:
         print(f"Not found: {url}")
 