sayakpaul HF staff commited on
Commit
5d813dc
·
verified ·
1 Parent(s): 2b36710

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -40
app.py CHANGED
@@ -2,10 +2,10 @@ from huggingface_hub import model_info, hf_hub_download
2
  import gradio as gr
3
  import json
4
 
 
5
 
6
  def format_size(num: int) -> str:
7
  """Format size in bytes into a human-readable string.
8
-
9
  Taken from https://stackoverflow.com/a/1094933
10
  """
11
  num_f = float(num)
@@ -43,10 +43,25 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
43
  files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
44
  index_dict = load_model_index(pipeline_id, token=token, revision=revision)
45
 
46
- is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
47
- component_wise_memory = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Handle text encoder separately when it's sharded.
 
 
50
  if is_text_encoder_shared:
51
  for current_file in files_in_repo:
52
  if "text_encoder" in current_file.rfilename:
@@ -60,10 +75,7 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
60
  else:
61
  component_wise_memory["text_encoder"] += selected_file.size
62
 
63
- print(component_wise_memory)
64
-
65
  # Handle pipeline components.
66
- component_filter = ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
67
  if is_text_encoder_shared:
68
  component_filter.append("text_encoder")
69
 
@@ -87,37 +99,4 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
87
  print(selected_file.rfilename)
88
  component_wise_memory[component] = selected_file.size
89
 
90
- return format_output(pipeline_id, component_wise_memory)
91
-
92
-
93
- gr.Interface(
94
- title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
95
- description="Sizes will be reported in GB. Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
96
- fn=get_component_wise_memory,
97
- inputs=[
98
- gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
99
- gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
100
- gr.components.Dropdown(
101
- [
102
- "fp32",
103
- "fp16",
104
- ],
105
- label="variant",
106
- info="Precision to use for calculation.",
107
- ),
108
- gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
109
- gr.components.Dropdown(
110
- [".bin", ".safetensors"],
111
- label="extension",
112
- info="Extension to use.",
113
- ),
114
- ],
115
- outputs=[gr.Markdown(label="Output")],
116
- examples=[
117
- ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
118
- ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
119
- ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
120
- ],
121
- theme=gr.themes.Soft(),
122
- allow_flagging=False,
123
- ).launch(show_error=True)
 
2
  import gradio as gr
3
  import json
4
 
5
+ component_filter = ["scheduler", "safety_checker", "tokenizer"]
6
 
7
  def format_size(num: int) -> str:
8
  """Format size in bytes into a human-readable string.
 
9
  Taken from https://stackoverflow.com/a/1094933
10
  """
11
  num_f = float(num)
 
43
  files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
44
  index_dict = load_model_index(pipeline_id, token=token, revision=revision)
45
 
46
+ # Check if all the concerned components have the checkpoints in the requested "variant" and "extension".
47
+ index_filter = component_filter.copy()
48
+ index_filter.extend(["_class_name", "_diffusers_version"])
49
+ for current_component in index_dict:
50
+ if current_component not in index_filter:
51
+ current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
52
+ if current_component_fileobjs:
53
+ current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
54
+ condition = lambda filename: extension in filename and variant in filename if variant is not None else lambda filename: extension in filename
55
+ variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
56
+ if not variant_present_with_extension:
57
+ raise ValueError(f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}.")
58
+ else:
59
+ raise ValueError(f"Problem with {current_component}.")
60
+
61
 
62
  # Handle text encoder separately when it's sharded.
63
+ is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
64
+ component_wise_memory = {}
65
  if is_text_encoder_shared:
66
  for current_file in files_in_repo:
67
  if "text_encoder" in current_file.rfilename:
 
75
  else:
76
  component_wise_memory["text_encoder"] += selected_file.size
77
 
 
 
78
  # Handle pipeline components.
 
79
  if is_text_encoder_shared:
80
  component_filter.append("text_encoder")
81
 
 
99
  print(selected_file.rfilename)
100
  component_wise_memory[component] = selected_file.size
101
 
102
+ return format_output(pipeline_id, component_wise_memory)