|
import re

import gradio as gr
from huggingface_hub import model_info
|
|
|
|
|
def bytes_to_giga_bytes(bytes):
    """Render a byte count as a gibibyte string with three decimal places.

    Args:
        bytes: Number of bytes (int or float). The name shadows the builtin,
            but it is kept because callers may pass it by keyword.

    Returns:
        ``bytes / 1024**3`` formatted with three decimals, e.g. ``"1.500"``.
    """
    gibibytes = bytes / 1024 ** 3
    return format(gibibytes, ".3f")
|
|
|
|
|
# Top-level folders that never hold weights worth measuring. Matched as
# substrings of the folder name, so e.g. "tokenizer_2" is skipped as well.
_SKIPPED_COMPONENTS = ("scheduler", "feature_extractor", "safety_checker", "tokenizer")

# Stem (file name minus extension) of a weight file as written by
# Diffusers/Transformers `save_pretrained`: a known prefix, an optional
# ".<variant>" tag and an optional "-XXXXX-of-YYYYY" shard suffix, e.g.
# "diffusion_pytorch_model", "model.fp16", "model-00001-of-00002",
# "diffusion_pytorch_model.fp16-00001-of-00003".
_WEIGHT_STEM_RE = re.compile(
    r"^(?:diffusion_pytorch_model|pytorch_model|model)"
    r"(?:\.(?P<variant>[^.]+?))?"
    r"(?:-\d+-of-\d+)?$"
)


def _weight_file_variant(basename, extension):
    """Classify a file name inside a component folder.

    Returns ``(True, tag)`` for a recognised weight file, where ``tag`` is the
    variant segment (e.g. ``"fp16"``, ``"non_ema"``), or ``None`` for a
    default-precision file. Returns ``(False, None)`` for anything else.
    """
    if not basename.endswith(extension):
        return False, None
    match = _WEIGHT_STEM_RE.match(basename[: -len(extension)])
    if match is None:
        return False, None
    return True, match.group("variant")


def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
    """Report the on-disk weight size of each component folder of a Diffusers repo.

    Args:
        pipeline_id: Hub repository id, e.g. ``"runwayml/stable-diffusion-v1-5"``.
        token: Hub token for private repositories; ``""`` is treated as ``None``.
        variant: Weight variant such as ``"fp16"``. ``""``, ``None`` and
            ``"fp32"`` all select the default (non-variant) weights.
        revision: Branch, tag or commit; ``""`` is treated as ``None``.
        extension: Weight file extension to measure (``".safetensors"`` or
            ``".bin"``). An empty value falls back to ``".safetensors"``.

    Returns:
        Dict mapping a component folder (e.g. ``"unet"``) to its size in GiB,
        as a string formatted by ``bytes_to_giga_bytes``. The shards of a
        sharded checkpoint are summed. When a variant is requested but a
        component has no files for it, that component's default-precision
        files are reported instead.
    """
    # Gradio hands over "" for empty textboxes; the Hub client expects None.
    if token == "":
        token = None
    if variant in ("", "fp32"):
        variant = None
    if revision == "":
        revision = None
    # An untouched Gradio dropdown yields None, which would break endswith().
    if not extension:
        extension = ".safetensors"

    print(pipeline_id, variant, revision, extension)

    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings

    default_bytes = {}  # component -> total size of its default-precision files
    variant_bytes = {}  # component -> total size of its files for `variant`
    for current_file in files_in_repo:
        parts = current_file.rfilename.split("/")
        # Only files directly inside a top-level component folder count.
        if len(parts) != 2:
            continue
        component, basename = parts
        if any(skipped in component for skipped in _SKIPPED_COMPONENTS):
            continue
        is_weight, tag = _weight_file_variant(basename, extension)
        if not is_weight:
            continue
        # Sizes are requested via files_metadata=True; guard against a missing value anyway.
        size = current_file.size or 0
        if tag is None:
            default_bytes[component] = default_bytes.get(component, 0) + size
        elif tag == variant:
            variant_bytes[component] = variant_bytes.get(component, 0) + size
        # Any other tag (another precision, "ema"/"non_ema" checkpoints) is ignored.

    component_wise_memory = {}
    for component in {**default_bytes, **variant_bytes}:
        # Prefer the requested variant; fall back to default-precision weights.
        total = variant_bytes[component] if component in variant_bytes else default_bytes[component]
        component_wise_memory[component] = bytes_to_giga_bytes(total)
    return component_wise_memory
|
|
|
|
|
# Gradio UI: the five inputs map positionally onto the parameters of
# get_component_wise_memory (pipeline_id, token, variant, revision, extension).
demo = gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description="Sizes will be reported in GB.",
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        # Masked so the access token is not displayed on screen.
        gr.components.Textbox(
            lines=1,
            type="password",
            label="hf_token",
            info="Pass this in case of private repositories.",
        ),
        # Explicit defaults so an untouched dropdown does not send None.
        gr.components.Dropdown(
            ["fp32", "fp16"],
            value="fp32",
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Dropdown(
            [".bin", ".safetensors"],
            value=".safetensors",
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs="text",
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    # Gradio expects "never" / "auto" / "manual"; a bool is only accepted with a deprecation warning.
    allow_flagging="never",
)
demo.launch()
|
|