File size: 12,832 Bytes
41dfa78
4b23311
601c9fd
4b23311
4eeb10d
 
 
 
 
 
 
 
4b23311
3f6a1fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d1d63a
 
78f1f97
1ea8dd9
 
 
 
 
 
 
 
 
 
 
78f1f97
c2058c7
 
 
 
 
3f6a1fe
1ea8dd9
4511c84
1ea8dd9
3f6a1fe
41dfa78
3f6a1fe
 
 
41dfa78
3f6a1fe
 
 
1ea8dd9
4b23311
78f1f97
601c9fd
 
 
 
 
 
78f1f97
3f6a1fe
1632490
3f6a1fe
1632490
 
0d1d63a
1632490
 
 
0d1d63a
1632490
 
 
 
 
 
 
 
 
 
3f6a1fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b23311
 
 
 
 
 
 
 
 
3f6a1fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601c9fd
4b23311
3f6a1fe
bf03ba4
601c9fd
 
3f6a1fe
 
4eeb10d
5d813dc
4eeb10d
 
 
 
 
78f1f97
3f6a1fe
5d813dc
 
4eeb10d
78f1f97
 
 
 
5d813dc
 
a673dbd
78f1f97
aa30618
a673dbd
78f1f97
5d813dc
 
 
601c9fd
5d813dc
 
601c9fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51957d4
4b23311
 
51957d4
4b23311
601c9fd
 
4b23311
 
 
 
 
 
 
601c9fd
 
 
4b23311
601c9fd
59db7fd
4b23311
c2058c7
78f1f97
 
41dfa78
 
 
 
 
 
b7823b4
 
41dfa78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
from huggingface_hub import hf_hub_download, model_info
import gradio as gr
import json

# Keys of model_index.json (and substrings of repo file paths) that this tool
# does not size: schedulers, tokenizers, feature extractors, and the index's
# own metadata keys. Treat this list as read-only; copy it before extending.
COMPONENT_FILTER = [
    "scheduler",
    "feature_extractor",
    "tokenizer",
    "tokenizer_2",
    "_class_name",
    "_diffusers_version",
]

# User-facing notes rendered at the bottom of the UI.
ARTICLE = """
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields

Both the `controlnet_id` and `t2i_adapter_id` fields support passing multiple comma-separated checkpoint ids,
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
`t2i_adapter_id`, this could look like: "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".

Users should take care to pass a compatible base `pipeline_id`. For example,
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
won't result in an error, but these two checkpoints aren't meant to be compatible. You should pass
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".

For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).

📔 Also, note that the `revision` field only applies to `pipeline_id`. It won't have any effect on the
`controlnet_id` or `t2i_adapter_id`.
"""

# Precision choices offered in the UI. "fp32" is the default and means "no
# variant tag"; the remaining entries are tags looked up in checkpoint file names.
ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]


def format_size(num: int) -> str:
    """Render a byte count with a decimal (power-of-1000) unit prefix.

    The result has one decimal place and no trailing "B", e.g. 1500 -> "1.5K".
    Values beyond the "Z" range are reported in "Y".
    Adapted from https://stackoverflow.com/a/1094933
    """
    value = float(num)
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z")
    for prefix in prefixes:
        if abs(value) >= 1000.0:
            value /= 1000.0
            continue
        return f"{value:3.1f}{prefix}"
    return f"{value:.1f}Y"


def format_output(pipeline_id, memory_mapping, variant=None, controlnet_mapping=None, t2i_adapter_mapping=None):
    """Render computed checkpoint sizes as a Markdown report.

    Args:
        pipeline_id: Repo id shown in the top-level heading.
        memory_mapping: Mapping of pipeline component name -> size in bytes
            (may be empty or None, in which case only the heading is emitted).
        variant: Precision label shown in the headings; None is shown as "fp32".
        controlnet_mapping: Optional mapping of ControlNet repo id -> size in bytes.
        t2i_adapter_mapping: Optional mapping of T2I-Adapter repo id -> size in bytes.

    Returns:
        The Markdown string.
    """
    if variant is None:
        variant = "fp32"

    def _bullets(mapping):
        # One "* name: size" line per entry, sizes formatted by format_size.
        return "".join(f"* {name}: {format_size(size)}\n" for name, size in mapping.items())

    parts = [f"## {pipeline_id} ({variant})\n"]
    if memory_mapping:
        parts.append(_bullets(memory_mapping))
    if controlnet_mapping:
        parts.append(f"\n## ControlNet(s) ({variant})\n")
        parts.append(_bullets(controlnet_mapping))
    if t2i_adapter_mapping:
        # Heading matches the "ControlNet(s)" style above.
        parts.append(f"\n## T2I-Adapter(s) ({variant})\n")
        parts.append(_bullets(t2i_adapter_mapping))

    return "".join(parts)


def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse a pipeline's ``model_index.json`` from the Hub.

    Args:
        pipeline_id: Hub repo id of the pipeline.
        token: Optional Hub token for private/gated repos.
        revision: Optional repo revision; None uses the default branch.

    Returns:
        The parsed JSON content.
    """
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_individual_model_memory(id, token, variant, extension):
    """Return the checkpoint size (bytes) of a single-model repo such as a ControlNet or T2I-Adapter.

    Args:
        id: Hub repo id.
        token: Optional Hub token (None for public repos).
        variant: Precision tag that must appear in the file name (e.g. "fp16"),
            or a falsy value to select a file carrying none of the non-fp32 tags
            in ALLOWED_VARIANTS.
        extension: Checkpoint file extension, e.g. ".safetensors" or ".bin".

    Returns:
        The ``size`` of the first matching file in the repo listing.

    Raises:
        ValueError: If no file matches the requested extension/variant.
    """
    # Retrieve all files in the repository, including their size metadata.
    files_in_repo = model_info(id, token=token, files_metadata=True).siblings

    # Match the extension as a suffix, not a substring: otherwise shard index
    # files like "diffusion_pytorch_model.safetensors.index.json" qualify as weights.
    weight_files = [x for x in files_in_repo if x.rfilename.endswith(extension)]

    if variant:
        candidates = [x for x in weight_files if variant in x.rfilename]
        if not candidates:
            raise ValueError(f"Requested variant ({variant}) for {id} couldn't be found with {extension} extension.")
    else:
        # No variant requested: skip files tagged with another precision.
        candidates = [x for x in weight_files if all(var not in x.rfilename for var in ALLOWED_VARIANTS[1:])]
        if not candidates:
            raise ValueError(f"No file for {id} could be found with {extension} extension without specified variants.")

    # NOTE(review): for a sharded checkpoint this is only the first shard's size.
    return candidates[0].size


def _blank_to_none(value):
    """Normalize a text input: None, empty, or whitespace-only becomes None; otherwise strip it."""
    if value is None:
        return None
    value = value.strip()
    return value or None


def _is_checkpoint_for_variant(filename, extension, variant):
    """Return True if ``filename`` is a weight file with ``extension`` for ``variant``.

    ``variant=None`` means full precision: files tagged with any of the other
    ALLOWED_VARIANTS (e.g. "fp16") are rejected.
    """
    if filename.endswith(".json") or not filename.endswith(extension):
        return False
    if variant is not None:
        return variant in filename
    return all(tag not in filename for tag in ALLOWED_VARIANTS[1:])


def _adapter_memory_mapping(ids, token, variant, extension):
    """Map each repo id in a comma-separated string to its checkpoint size in bytes."""
    repo_ids = [repo_id.strip() for repo_id in ids.split(",") if repo_id.strip()]
    return {
        repo_id: get_individual_model_memory(repo_id, token=token, variant=variant, extension=extension)
        for repo_id in repo_ids
    }


def _validate_components(index_dict, files_in_repo, variant, extension):
    """Raise ValueError unless every model component listed in ``index_dict`` has a
    checkpoint file for the requested ``variant`` and ``extension``."""
    for component, spec in index_dict.items():
        if component in COMPONENT_FILTER or not isinstance(spec, list) or len(spec) != 2:
            continue
        # Entries such as `"image_encoder": [null, null]` declare an optional
        # component that this pipeline doesn't ship; there is nothing to check.
        if all(part is None for part in spec):
            continue

        filenames = [fileobj.rfilename for fileobj in files_in_repo if component in fileobj.rfilename]
        if not filenames:
            raise ValueError(f"Problem with {component}.")
        if not any(_is_checkpoint_for_variant(name, extension, variant) for name in filenames):
            formatted_filenames = ", ".join(filenames)
            raise ValueError(
                f"Requested extension ({extension}) and variant ({variant}) not present for {component}."
                f" Available files for this component: {formatted_filenames}."
            )


def get_component_wise_memory(
    pipeline_id,
    controlnet_id=None,
    t2i_adapter_id=None,
    token=None,
    variant=None,
    revision=None,
    extension=".safetensors",
):
    """Compute per-component checkpoint sizes of a Diffusers pipeline and render them as Markdown.

    Sizes come from the Hub file metadata of each repo. Of the pipeline's files,
    only ``model_index.json`` is downloaded.

    Args:
        pipeline_id: Hub repo id of the pipeline.
        controlnet_id: Optional comma-separated ControlNet repo ids.
        t2i_adapter_id: Optional comma-separated T2I-Adapter repo ids.
        token: Optional Hub token for private/gated repos.
        variant: "fp32" (treated as no variant), "fp16", "bf16", or None.
        revision: Optional revision, applied to ``pipeline_id`` only.
        extension: Checkpoint extension (".safetensors" or ".bin"); a falsy value
            falls back to ".safetensors".

    Returns:
        The Markdown report produced by ``format_output``.

    Raises:
        ValueError: If both ``controlnet_id`` and ``t2i_adapter_id`` are given, or
            a component lacks a checkpoint for the requested variant/extension.
    """
    # Gradio passes "" for untouched text boxes and None for unselected radios.
    controlnet_id = _blank_to_none(controlnet_id)
    t2i_adapter_id = _blank_to_none(t2i_adapter_id)
    if controlnet_id and t2i_adapter_id:
        raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
    token = _blank_to_none(token)
    revision = _blank_to_none(revision)
    # fp32 checkpoints carry no variant tag in their file names.
    if not variant or variant == "fp32":
        variant = None
    if not extension:
        extension = ".safetensors"

    # Handle ControlNet and T2I-Adapter.
    controlnet_mapping = t2_adapter_mapping = None
    if controlnet_id is not None:
        controlnet_mapping = _adapter_memory_mapping(controlnet_id, token, variant, extension)
    elif t2i_adapter_id is not None:
        t2_adapter_mapping = _adapter_memory_mapping(t2i_adapter_id, token, variant, extension)

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    # Load pipeline metadata.
    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)

    # Check that every concerned component has a checkpoint in the requested
    # "variant" and "extension".
    print(f"Index dict: {index_dict}")
    _validate_components(index_dict, files_in_repo, variant, extension)

    component_wise_memory = {}
    # Per-call copy: extending the module-level COMPONENT_FILTER would leak
    # "text_encoder" into every later request.
    skipped_substrings = list(COMPONENT_FILTER)

    # A shard index file ("*.index.json") anywhere in the repo switches the text
    # encoder to summing all of its matching shard files.
    is_text_encoder_sharded = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
    if is_text_encoder_sharded:
        text_encoder_sizes = [
            file_obj.size
            for file_obj in files_in_repo
            if "text_encoder" in file_obj.rfilename
            and _is_checkpoint_for_variant(file_obj.rfilename, extension, variant)
        ]
        if text_encoder_sizes:
            component_wise_memory["text_encoder"] = sum(text_encoder_sizes)
        skipped_substrings.append("text_encoder")

    # Handle the remaining pipeline components: one weight file per component folder.
    for file_obj in files_in_repo:
        filename = file_obj.rfilename
        if any(substring in filename for substring in skipped_substrings):
            continue
        path_parts = filename.split("/")
        # Only files directly inside a folder named after a model_index.json entry.
        if len(path_parts) != 2 or path_parts[0] not in index_dict:
            continue
        # Names containing "ema" are skipped (presumably EMA / non-EMA weight copies).
        if "ema" in filename or not _is_checkpoint_for_variant(filename, extension, variant):
            continue
        component_wise_memory[path_parts[0]] = file_obj.size

    return format_output(pipeline_id, component_wise_memory, variant, controlnet_mapping, t2_adapter_mapping)


# Gradio UI: header, input form, examples, and usage notes.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        gr.Markdown(
            """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="150" height="175"><h1>🧨 Diffusers Pipeline Memory Calculator</h1>
    This tool will help you to gauge the memory requirements of a Diffusers pipeline. Pipelines containing text encoders with sharded checkpoints are also supported
    (PixArt-Alpha, for example). 🤗 See instructions below the form on how to pass `controlnet_id` or `t2i_adapter_id`. When performing inference, expect to add up to an
    additional 20% to this as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/). The final memory requirement will also depend on the requested resolution. You can click on one
    of the examples below the "Calculate Memory Usage" button to get started. Design adapted from [this Space](https://huggingface.co/spaces/hf-accelerate/model-memory-usage).
    """
        )
        # Receives the Markdown report returned by get_component_wise_memory.
        out_text = gr.Markdown()
        with gr.Row():
            pipeline_id = gr.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5")
        with gr.Row():
            controlnet_id = gr.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny")
            t2i_adapter_id = gr.Textbox(
                lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"
            )
        with gr.Row():
            token = gr.Textbox(lines=1, label="hf_token", info="Pass this in case of private/gated repositories.")
            variant = gr.Radio(
                ALLOWED_VARIANTS,
                label="variant",
                info="Precision to use for calculation.",
            )
            revision = gr.Textbox(lines=1, label="revision", info="Repository revision to use.")
            extension = gr.Radio(
                [".bin", ".safetensors"],
                label="extension",
                info="Extension to use.",
            )
        with gr.Row():
            btn = gr.Button("Calculate Memory Usage")

        gr.Markdown("## Examples")
        # Each example row lists values in the same order as the inputs list
        # below, which mirrors get_component_wise_memory's positional parameters.
        gr.Examples(
            [
                ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
                ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
                [
                    "runwayml/stable-diffusion-v1-5",
                    "lllyasviel/sd-controlnet-canny",
                    None,
                    None,
                    "fp32",
                    None,
                    ".safetensors",
                ],
                [
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    None,
                    "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
                    None,
                    "fp16",
                    None,
                    ".safetensors",
                ],
                ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
                ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
            ],
            [pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
            out_text,
            get_component_wise_memory,
            cache_examples=False,
        )

        gr.Markdown(ARTICLE)

    # Input order must match get_component_wise_memory's positional parameters.
    btn.click(
        get_component_wise_memory,
        inputs=[pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
        outputs=[out_text],
        api_name=False,
    )

# Launch only when run as a script, so importing this module doesn't start a server.
if __name__ == "__main__":
    demo.launch(show_error=True)