"""Gradio app that estimates the checkpoint sizes of a Diffusers pipeline.

Sizes are read from the Hub file metadata (``model_info(..., files_metadata=True)``);
only ``model_index.json`` is actually downloaded.
"""

import json
import re

import gradio as gr
from huggingface_hub import hf_hub_download, model_info

# `model_index.json` entries that are never sized: the pipeline metadata keys
# (`_class_name`, `_diffusers_version`) and the listed non-model components.
# NOTE: treated as read-only; it must not be mutated per request because the
# app is long-running and the list is shared by every call.
COMPONENT_FILTER = [
    "scheduler",
    "feature_extractor",
    "tokenizer",
    "tokenizer_2",
    "_class_name",
    "_diffusers_version",
]

ARTICLE = """
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields

Both `controlnet_id` and `t2i_adapter_id` fields support passing multiple checkpoint ids,
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
`t2i_adapter_id`, this could be like - "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".

Users should take care of passing the underlying base `pipeline_id` appropriately. For example,
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
won't result in an error but these two things aren't meant to be compatible. You should pass
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".

For further clarification on this topic, feel free to open a discussion in the Community tab of this Space.

📔 Also, note that the `revision` field is only reserved for `pipeline_id`. It won't have any effect on the
`controlnet_id` or `t2i_adapter_id`.
"""

# Selectable precisions. "fp32" is the default: its checkpoints carry no variant
# tag in the file name, so only the remaining entries are used as file-name tags.
ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]

# File-name suffixes that indicate a checkpoint of some format. A component folder
# containing none of these (e.g. a config-only folder) has nothing to size.
_WEIGHT_SUFFIXES = (".safetensors", ".bin", ".msgpack", ".ckpt", ".pt", ".pth", ".onnx")

# Shard suffix of split checkpoints, e.g. "model-00001-of-00004.safetensors" or
# "model.fp16-00001-of-00002.safetensors".
_SHARD_PATTERN = re.compile(r"-\d+-of-\d+")


def format_size(num: int) -> str:
    """Format size in bytes into a human-readable string.

    Uses decimal prefixes (powers of 1000), e.g. ``format_size(1_500_000) == "1.5M"``.
    Taken from https://stackoverflow.com/a/1094933
    """
    num_f = float(num)
    for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
        if abs(num_f) < 1000.0:
            return f"{num_f:3.1f}{unit}"
        num_f /= 1000.0
    return f"{num_f:.1f}Y"


def format_output(pipeline_id, memory_mapping, variant=None, controlnet_mapping=None, t2i_adapter_mapping=None):
    """Render the computed sizes as a Markdown report.

    Args:
        pipeline_id: Repo id shown in the main heading.
        memory_mapping: ``{component_name: size_in_bytes}`` for the pipeline; may be empty.
        variant: Precision label for the headings; ``None`` is shown as "fp32".
        controlnet_mapping: Optional ``{controlnet_id: size_in_bytes}``.
        t2i_adapter_mapping: Optional ``{t2i_adapter_id: size_in_bytes}``.

    Returns:
        A Markdown string with one bullet per component / checkpoint.
    """
    if variant is None:
        variant = "fp32"

    lines = [f"## {pipeline_id} ({variant})"]
    for component, memory in (memory_mapping or {}).items():
        lines.append(f"* {component}: {format_size(memory)}")

    if controlnet_mapping:
        lines.append(f"\n## ControlNet(s) ({variant})")
        for controlnet_id, memory in controlnet_mapping.items():
            lines.append(f"* {controlnet_id}: {format_size(memory)}")

    if t2i_adapter_mapping:
        lines.append(f"\n## T2I-Adapter(s) ({variant})")
        for t2i_adapter_id, memory in t2i_adapter_mapping.items():
            lines.append(f"* {t2i_adapter_id}: {format_size(memory)}")

    return "\n".join(lines) + "\n"


def load_model_index(pipeline_id, token=None, revision=None):
    """Download ``model_index.json`` of `pipeline_id` (at `revision`) and return it as a dict."""
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    with open(index_path, "r", encoding="utf-8") as f:
        index_dict = json.load(f)
    return index_dict


def _basename(rfilename):
    """Return the last path segment of a repo-relative file name."""
    return rfilename.rsplit("/", 1)[-1]


def _is_weight_file(rfilename, extension, variant):
    """Return whether `rfilename` is a checkpoint with `extension` in the requested `variant`.

    A falsy `variant` (fp32) matches only files that carry none of the other
    precision tags from ``ALLOWED_VARIANTS``; otherwise the tag must appear in the name.
    """
    # `endswith` (not `in`) so that e.g. "*.safetensors.index.json" is not taken for weights.
    if not rfilename.endswith(extension):
        return False
    if variant:
        return variant in rfilename
    return all(tag not in rfilename for tag in ALLOWED_VARIANTS[1:])


def _checkpoint_size(file_objs):
    """Return the byte size of a checkpoint from its (non-empty) list of candidate Hub files.

    If any candidate is a shard, all shards are summed; otherwise the size of the
    first candidate is returned.
    """
    shards = [f for f in file_objs if _SHARD_PATTERN.search(_basename(f.rfilename))]
    if shards:
        return sum(f.size for f in shards)
    return file_objs[0].size


def get_individual_model_memory(id, token, variant, extension):
    """Return the checkpoint size in bytes of a standalone model repo (e.g. a ControlNet or T2I-Adapter).

    Args:
        id: Hub repo id. (Shadows the builtin; the name is kept for keyword callers.)
        token: Optional Hub token for private/gated repos.
        variant: Precision tag such as "fp16"; falsy means untagged (fp32) files.
        extension: Checkpoint file extension, e.g. ".safetensors".

    Raises:
        ValueError: if no file matches `extension` and `variant`.
    """
    files_in_repo = model_info(id, token=token, files_metadata=True).siblings
    candidates = [x for x in files_in_repo if _is_weight_file(x.rfilename, extension, variant)]
    if not candidates:
        if variant:
            raise ValueError(f"Requested variant ({variant}) for {id} couldn't be found with {extension} extension.")
        raise ValueError(f"No file for {id} could be found with {extension} extension without specified variants.")
    return _checkpoint_size(candidates)


def _blank_to_none(value):
    """Map ``None`` and whitespace-only strings (Gradio's empty textbox) to ``None``; strip the rest."""
    if value is None:
        return None
    value = value.strip()
    return value or None


def _individual_model_sizes(ids, token, variant, extension):
    """Return ``{repo_id: size_in_bytes}`` for a comma-separated string of repo ids."""
    repo_ids = [repo_id.strip() for repo_id in ids.split(",") if repo_id.strip()]
    return {
        repo_id: get_individual_model_memory(repo_id, token=token, variant=variant, extension=extension)
        for repo_id in repo_ids
    }


def _sized_components(index_dict):
    """Return the names of ``model_index.json`` entries that refer to components to be sized.

    Components are entries of the form ``[library, class_name]``. Entries set to
    ``[null, null]`` (presumably an unset optional component) have nothing to size.
    """
    return [
        name
        for name, value in index_dict.items()
        if name not in COMPONENT_FILTER
        and isinstance(value, list)
        and len(value) == 2
        and any(part is not None for part in value)
    ]


def _component_files(files_in_repo, component):
    """Return the Hub file objects stored directly inside the ``component/`` folder."""
    prefix = f"{component}/"
    return [
        f for f in files_in_repo if f.rfilename.startswith(prefix) and "/" not in f.rfilename[len(prefix) :]
    ]


def _component_size(component, files_in_repo, variant, extension):
    """Return the checkpoint size in bytes of one pipeline component, or ``None`` if there is nothing to size.

    Raises:
        ValueError: if the component folder has no files, or has checkpoints but
            none in the requested `extension` / `variant`.
    """
    file_objs = _component_files(files_in_repo, component)
    if not file_objs:
        raise ValueError(f"Problem with {component}.")

    # Config-only folders (no checkpoint of any format) are skipped rather than rejected.
    if not any(f.rfilename.endswith(_WEIGHT_SUFFIXES) for f in file_objs):
        return None

    weights = [f for f in file_objs if _is_weight_file(f.rfilename, extension, variant)]
    if not weights:
        formatted_filenames = ", ".join(f.rfilename for f in file_objs)
        raise ValueError(
            f"Requested extension ({extension}) and variant ({variant}) not present for {component}."
            f" Available files for this component: {formatted_filenames}."
        )

    # Files whose name contains "ema" (EMA / non-EMA copies) are not counted.
    non_ema = [f for f in weights if "ema" not in _basename(f.rfilename)]
    if not non_ema:
        return None
    # Sharded checkpoints (any component, e.g. a T5 text encoder) are summed.
    return _checkpoint_size(non_ema)


def get_component_wise_memory(
    pipeline_id,
    controlnet_id=None,
    t2i_adapter_id=None,
    token=None,
    variant=None,
    revision=None,
    extension=".safetensors",
):
    """Compute per-component checkpoint sizes of a Diffusers pipeline and render them as Markdown.

    Args:
        pipeline_id: Hub repo id of the pipeline.
        controlnet_id: Optional comma-separated ControlNet repo ids.
        t2i_adapter_id: Optional comma-separated T2I-Adapter repo ids (exclusive with `controlnet_id`).
        token: Optional Hub token for private/gated repos.
        variant: "fp32" (or ``None``), "fp16" or "bf16".
        revision: Optional revision of `pipeline_id` (not applied to ControlNets/adapters).
        extension: Checkpoint extension; ``None`` (unselected radio) falls back to ".safetensors".

    Returns:
        The Markdown report produced by :func:`format_output`.

    Raises:
        ValueError: if `pipeline_id` is empty, both `controlnet_id` and `t2i_adapter_id`
            are given, or a component lacks checkpoints in the requested extension/variant.
            Errors raised by the Hub client calls propagate unchanged.
    """
    # Gradio hands over "" for empty textboxes; treat blank input as "not provided".
    pipeline_id = _blank_to_none(pipeline_id)
    if pipeline_id is None:
        raise ValueError("`pipeline_id` must be provided.")
    controlnet_id = _blank_to_none(controlnet_id)
    t2i_adapter_id = _blank_to_none(t2i_adapter_id)
    if controlnet_id and t2i_adapter_id:
        raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
    token = _blank_to_none(token)
    revision = _blank_to_none(revision)
    # fp32 checkpoints carry no variant tag in their file names.
    if variant == "fp32":
        variant = None
    extension = extension or ".safetensors"

    # Handle ControlNet and T2I-Adapter.
    controlnet_mapping = t2i_adapter_mapping = None
    if controlnet_id is not None:
        controlnet_mapping = _individual_model_sizes(controlnet_id, token, variant, extension)
    elif t2i_adapter_id is not None:
        t2i_adapter_mapping = _individual_model_sizes(t2i_adapter_id, token, variant, extension)

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    # Load pipeline metadata.
    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)
    print(f"Index dict: {index_dict}")

    component_wise_memory = {}
    for component in _sized_components(index_dict):
        size = _component_size(component, files_in_repo, variant, extension)
        if size is not None:
            component_wise_memory[component] = size

    return format_output(pipeline_id, component_wise_memory, variant, controlnet_mapping, t2i_adapter_mapping)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        gr.Markdown(
            """
# 🧨 Diffusers Pipeline Memory Calculator

This tool will help you to gauge the memory requirements of a Diffusers pipeline. Pipelines containing
sharded checkpoints (such as the text encoder of PixArt-Alpha) are also supported 🤗 See instructions
below the form on how to pass `controlnet_id` or `t2i_adapter_id`.

When performing inference, expect to add up to an additional 20% to this as found by
[EleutherAI](https://blog.eleuther.ai/transformer-math/). The final memory requirement will also depend
on the requested resolution. You can click on one of the examples below the "Calculate Memory Usage"
button to get started.

Design adapted from [this Space](https://huggingface.co/spaces/hf-accelerate/model-memory-usage).
"""
        )
        out_text = gr.Markdown()
        with gr.Row():
            pipeline_id = gr.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5")
        with gr.Row():
            controlnet_id = gr.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny")
            t2i_adapter_id = gr.Textbox(
                lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"
            )
        with gr.Row():
            token = gr.Textbox(lines=1, label="hf_token", info="Pass this in case of private/gated repositories.")
            variant = gr.Radio(
                ALLOWED_VARIANTS,
                label="variant",
                info="Precision to use for calculation.",
            )
            revision = gr.Textbox(lines=1, label="revision", info="Repository revision to use.")
            extension = gr.Radio(
                [".bin", ".safetensors"],
                label="extension",
                info="Extension to use.",
            )
        with gr.Row():
            btn = gr.Button("Calculate Memory Usage")

    gr.Markdown("## Examples")
    gr.Examples(
        [
            ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
            ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
            [
                "runwayml/stable-diffusion-v1-5",
                "lllyasviel/sd-controlnet-canny",
                None,
                None,
                "fp32",
                None,
                ".safetensors",
            ],
            [
                "stabilityai/stable-diffusion-xl-base-1.0",
                None,
                "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
                None,
                "fp16",
                None,
                ".safetensors",
            ],
            ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
            ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
        ],
        [pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
        out_text,
        get_component_wise_memory,
        cache_examples=False,
    )

    gr.Markdown(ARTICLE)

    btn.click(
        get_component_wise_memory,
        inputs=[pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
        outputs=[out_text],
        api_name=False,
    )

demo.launch(show_error=True)