add support for controlnet and t2i adapter too
Browse files
app.py
CHANGED
@@ -11,6 +11,24 @@ COMPONENT_FILTER = [
|
|
11 |
"_diffusers_version",
|
12 |
]
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def format_size(num: int) -> str:
|
16 |
"""Format size in bytes into a human-readable string.
|
@@ -24,11 +42,21 @@ def format_size(num: int) -> str:
|
|
24 |
return f"{num_f:.1f}Y"
|
25 |
|
26 |
|
27 |
-
def format_output(pipeline_id, memory_mapping):
|
28 |
markdown_str = f"## {pipeline_id}\n"
|
|
|
29 |
if memory_mapping:
|
30 |
for component, memory in memory_mapping.items():
|
31 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
return markdown_str
|
33 |
|
34 |
|
@@ -39,7 +67,35 @@ def load_model_index(pipeline_id, token=None, revision=None):
|
|
39 |
return index_dict
|
40 |
|
41 |
|
42 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
if token == "":
|
44 |
token = None
|
45 |
|
@@ -49,12 +105,31 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
49 |
if variant == "fp32":
|
50 |
variant = None
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")
|
53 |
|
|
|
54 |
files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
|
55 |
index_dict = load_model_index(pipeline_id, token=token, revision=revision)
|
56 |
|
57 |
-
# Check if all the concerned components have the checkpoints in
|
|
|
58 |
print(f"Index dict: {index_dict}")
|
59 |
for current_component in index_dict:
|
60 |
if (
|
@@ -63,6 +138,7 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
63 |
and len(index_dict[current_component]) == 2
|
64 |
):
|
65 |
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
|
|
66 |
if current_component_fileobjs:
|
67 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
68 |
condition = ( # noqa: E731
|
@@ -119,16 +195,20 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
119 |
if selected_file is not None:
|
120 |
component_wise_memory[component] = selected_file.size
|
121 |
|
122 |
-
return format_output(pipeline_id, component_wise_memory)
|
123 |
|
124 |
|
125 |
with gr.Interface(
|
126 |
title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
|
127 |
description="Pipelines containing text encoders with sharded checkpoints are also supported"
|
128 |
-
" (PixArt-Alpha, for example) 🤗"
|
|
|
|
|
129 |
fn=get_component_wise_memory,
|
130 |
inputs=[
|
131 |
gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
|
|
|
|
|
132 |
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
133 |
gr.components.Radio(
|
134 |
["fp32", "fp16", "bf16"],
|
@@ -144,11 +224,20 @@ with gr.Interface(
|
|
144 |
],
|
145 |
outputs=[gr.Markdown(label="Output")],
|
146 |
examples=[
|
147 |
-
["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
|
148 |
-
["
|
149 |
-
["
|
150 |
-
[
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
],
|
153 |
theme=gr.themes.Soft(),
|
154 |
allow_flagging="never",
|
|
|
11 |
"_diffusers_version",
|
12 |
]
|
13 |
|
14 |
+
ARTICLE = """
|
15 |
+
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields
|
16 |
+
|
17 |
+
Both `controlnet_id` and `t2i_adapter_id` fields support passing multiple checkpoint ids,
|
18 |
+
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
|
19 |
+
`t2i_adapter_id`, this could be like - "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".
|
20 |
+
|
21 |
+
Users should take care of passing the underlying base `pipeline_id` appropriately. For example,
|
22 |
+
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
|
23 |
+
won't result in an error but these two things aren't meant to compatible. You should pass
|
24 |
+
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".
|
25 |
+
|
26 |
+
For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).
|
27 |
+
|
28 |
+
📔 Also, note that `revision` field is only reserved for `pipeline_id`. It won't have any effect on the
|
29 |
+
`controlnet_id` or `t2i_adapter_id`.
|
30 |
+
"""
|
31 |
+
|
32 |
|
33 |
def format_size(num: int) -> str:
|
34 |
"""Format size in bytes into a human-readable string.
|
|
|
42 |
return f"{num_f:.1f}Y"
|
43 |
|
44 |
|
45 |
+
def format_output(pipeline_id, memory_mapping, controlnet_mapping=None, t2i_adapter_mapping=None):
|
46 |
markdown_str = f"## {pipeline_id}\n"
|
47 |
+
|
48 |
if memory_mapping:
|
49 |
for component, memory in memory_mapping.items():
|
50 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
51 |
+
if controlnet_mapping:
|
52 |
+
markdown_str += "\n## ControlNet(s)\n"
|
53 |
+
for controlnet_id, memory in controlnet_mapping.items():
|
54 |
+
markdown_str += f"* {controlnet_id}: {format_size(memory)}\n"
|
55 |
+
if t2i_adapter_mapping:
|
56 |
+
markdown_str += "\n## T2I-Adapters(s)\n"
|
57 |
+
for t2_adapter_id, memory in t2i_adapter_mapping.items():
|
58 |
+
markdown_str += f"* {t2_adapter_id}: {format_size(memory)}\n"
|
59 |
+
|
60 |
return markdown_str
|
61 |
|
62 |
|
|
|
67 |
return index_dict
|
68 |
|
69 |
|
70 |
+
def get_individual_model_memory(id, token, variant, extension):
|
71 |
+
files_in_repo = model_info(id, token=token, files_metadata=True).siblings
|
72 |
+
for x in files_in_repo:
|
73 |
+
if extension in x.rfilename:
|
74 |
+
if variant:
|
75 |
+
if variant in x.rfilename:
|
76 |
+
return x.size
|
77 |
+
else:
|
78 |
+
return x.size
|
79 |
+
|
80 |
+
|
81 |
+
def get_component_wise_memory(
|
82 |
+
pipeline_id,
|
83 |
+
controlnet_id=None,
|
84 |
+
t2i_adapter_id=None,
|
85 |
+
token=None,
|
86 |
+
variant=None,
|
87 |
+
revision=None,
|
88 |
+
extension=".safetensors",
|
89 |
+
):
|
90 |
+
if controlnet_id == "":
|
91 |
+
controlnet_id = None
|
92 |
+
|
93 |
+
if t2i_adapter_id == "":
|
94 |
+
t2i_adapter_id = None
|
95 |
+
|
96 |
+
if controlnet_id and t2i_adapter_id:
|
97 |
+
raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
|
98 |
+
|
99 |
if token == "":
|
100 |
token = None
|
101 |
|
|
|
105 |
if variant == "fp32":
|
106 |
variant = None
|
107 |
|
108 |
+
# Handle ControlNet and T2I-Adapter.
|
109 |
+
controlnet_mapping = t2_adapter_mapping = None
|
110 |
+
if controlnet_id is not None:
|
111 |
+
controlnet_id = controlnet_id.split(",")
|
112 |
+
controlnet_sizes = [
|
113 |
+
get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
|
114 |
+
for id_ in controlnet_id
|
115 |
+
]
|
116 |
+
controlnet_mapping = dict(zip(controlnet_id, controlnet_sizes))
|
117 |
+
elif t2i_adapter_id is not None:
|
118 |
+
t2i_adapter_id = t2i_adapter_id.split(",")
|
119 |
+
t2i_adapter_sizes = [
|
120 |
+
get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
|
121 |
+
for id_ in t2i_adapter_id
|
122 |
+
]
|
123 |
+
t2_adapter_mapping = dict(zip(t2i_adapter_id, t2i_adapter_sizes))
|
124 |
+
|
125 |
print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")
|
126 |
|
127 |
+
# Load pipeline metadata.
|
128 |
files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
|
129 |
index_dict = load_model_index(pipeline_id, token=token, revision=revision)
|
130 |
|
131 |
+
# Check if all the concerned components have the checkpoints in
|
132 |
+
# the requested "variant" and "extension".
|
133 |
print(f"Index dict: {index_dict}")
|
134 |
for current_component in index_dict:
|
135 |
if (
|
|
|
138 |
and len(index_dict[current_component]) == 2
|
139 |
):
|
140 |
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
141 |
+
|
142 |
if current_component_fileobjs:
|
143 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
144 |
condition = ( # noqa: E731
|
|
|
195 |
if selected_file is not None:
|
196 |
component_wise_memory[component] = selected_file.size
|
197 |
|
198 |
+
return format_output(pipeline_id, component_wise_memory, controlnet_mapping, t2_adapter_mapping)
|
199 |
|
200 |
|
201 |
with gr.Interface(
|
202 |
title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
|
203 |
description="Pipelines containing text encoders with sharded checkpoints are also supported"
|
204 |
+
" (PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass"
|
205 |
+
" `controlnet_id` or `t2_adapter_id`.",
|
206 |
+
article=ARTICLE,
|
207 |
fn=get_component_wise_memory,
|
208 |
inputs=[
|
209 |
gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
|
210 |
+
gr.components.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny"),
|
211 |
+
gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
|
212 |
gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
|
213 |
gr.components.Radio(
|
214 |
["fp32", "fp16", "bf16"],
|
|
|
224 |
],
|
225 |
outputs=[gr.Markdown(label="Output")],
|
226 |
examples=[
|
227 |
+
["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
|
228 |
+
["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
|
229 |
+
["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", None, None, "fp32", None, ".safetensors"],
|
230 |
+
[
|
231 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
232 |
+
None,
|
233 |
+
"TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
|
234 |
+
None,
|
235 |
+
"fp32",
|
236 |
+
None,
|
237 |
+
".safetensors",
|
238 |
+
],
|
239 |
+
["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
|
240 |
+
["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
|
241 |
],
|
242 |
theme=gr.themes.Soft(),
|
243 |
allow_flagging="never",
|