File size: 5,185 Bytes
601c9fd
4b23311
601c9fd
4b23311
 
1ea8dd9
 
 
 
 
 
 
 
 
 
 
 
46b001c
 
1ea8dd9
4511c84
1ea8dd9
 
4b23311
601c9fd
 
 
 
 
 
4b23311
 
 
 
 
 
 
 
 
 
601c9fd
4b23311
bf03ba4
601c9fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b23311
 
601c9fd
4b23311
601c9fd
 
4b23311
 
 
 
 
 
 
601c9fd
 
 
4b23311
601c9fd
 
59db7fd
4b23311
46b001c
4b23311
 
 
 
3364fa1
4b23311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b36710
4b23311
 
 
13d87c4
4b23311
 
 
2b36710
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from huggingface_hub import model_info, hf_hub_download
import gradio as gr
import json


def format_size(num: int) -> str:
    """Render a byte count as a short string with a decimal unit prefix.

    The value is repeatedly divided by 1000 until its magnitude drops below
    1000, then printed with one decimal place followed by the prefix reached
    ("" for plain bytes, then "K", "M", "G", ...). A value that is still too
    large after the "Z" step is printed with a "Y" prefix.

    Adapted from https://stackoverflow.com/a/1094933
    """
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z")
    scaled = float(num)
    for prefix in prefixes:
        fits_this_prefix = abs(scaled) < 1000.0
        if fits_this_prefix:
            return f"{scaled:3.1f}{prefix}"
        scaled = scaled / 1000.0
    # Ran out of prefixes: whatever is left is expressed in "Y".
    return f"{scaled:.1f}Y"

def format_output(pipeline_id, memory_mapping):
    """Build the Markdown report for one pipeline.

    Args:
        pipeline_id: Identifier printed in the "##" heading.
        memory_mapping: Mapping of component name to size in bytes. When it is
            empty or None, only the heading is produced.

    Returns:
        The heading line followed by one "* component: size" bullet per entry,
        every line terminated by a newline.
    """
    heading = f"## {pipeline_id}\n"
    if not memory_mapping:
        return heading

    bullets = [f"* {name}: {format_size(size)}\n" for name, size in memory_mapping.items()]
    return heading + "".join(bullets)

def load_model_index(pipeline_id, token=None, revision=None):
    """Download the pipeline's ``model_index.json`` and return its parsed content.

    Args:
        pipeline_id: Repository id handed to ``hf_hub_download`` as ``repo_id``.
        token: Optional access token forwarded to ``hf_hub_download``.
        revision: Optional repository revision forwarded to ``hf_hub_download``.

    Returns:
        The object produced by ``json.load`` for the downloaded file.
    """
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding (the default for text-mode ``open``).
    with open(index_path, encoding="utf-8") as f:
        return json.load(f)

def _is_selected_weight(rfilename, variant, extension):
    """Tell whether a repo file is a weight file for the requested variant.

    A file qualifies when its path ends with ``extension`` (and is not a
    ``.json`` file) and, if a variant was requested, the variant string occurs
    somewhere in the path. With ``variant=None`` any such weight file qualifies.
    """
    if rfilename.endswith(".json") or not rfilename.endswith(extension):
        return False
    return variant is None or variant in rfilename


def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
    """Report the checkpoint size of each component of a pipeline repository.

    Args:
        pipeline_id: Repository id to inspect.
        token: Optional access token; an empty string is treated as "no token".
        variant: Weight variant to look for in file names. ``"fp32"`` (and an
            empty value) means "no variant".
        revision: Optional repository revision; an empty string means default.
        extension: Suffix of the weight files to consider. Falls back to
            ``".safetensors"`` when empty or None.

    Returns:
        The Markdown string produced by ``format_output`` for the collected
        ``{component: size_in_bytes}`` mapping.
    """
    # Empty UI fields arrive as "" (the original code already handled that for
    # token/revision) or as None (see the examples table); normalise both.
    token = token or None
    revision = revision or None
    if not variant or variant == "fp32":
        variant = None
    # NOTE(review): presumably None when no dropdown choice is made; without a
    # fallback ``str.endswith(None)`` raises TypeError further down.
    if not extension:
        extension = ".safetensors"

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)

    # A ".index.json" file anywhere in the repo is taken as the sign of a
    # sharded text encoder.
    is_text_encoder_sharded = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
    component_wise_memory = {}

    # Handle text encoder separately when it's sharded: add up all its shards.
    if is_text_encoder_sharded:
        for current_file in files_in_repo:
            rfilename = current_file.rfilename
            # Substring test: every path containing "text_encoder" is pooled
            # under the single "text_encoder" key.
            if "text_encoder" not in rfilename:
                continue
            # Previously both branches of the variant check selected the file,
            # so the requested variant was ignored here; apply the same variant
            # rule as for the other components.
            if not _is_selected_weight(rfilename, variant, extension):
                continue
            component_wise_memory["text_encoder"] = component_wise_memory.get("text_encoder", 0) + current_file.size

    print(component_wise_memory)

    # Handle pipeline components.
    component_filter = ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
    if is_text_encoder_sharded:
        # Already accounted for above.
        component_filter.append("text_encoder")

    for current_file in files_in_repo:
        rfilename = current_file.rfilename
        if any(substring in rfilename for substring in component_filter):
            continue

        # Only files sitting directly inside a component folder that the
        # model index lists, i.e. "<component>/<file>".
        path_parts = rfilename.split("/")
        if len(path_parts) != 2 or path_parts[0] not in index_dict:
            continue

        # Skip any path containing "ema", whatever the variant.
        if "ema" in rfilename or not _is_selected_weight(rfilename, variant, extension):
            continue

        print(rfilename)
        # NOTE(review): assignment, not accumulation - when several files of one
        # component qualify, the last one listed wins (unchanged behaviour).
        component_wise_memory[path_parts[0]] = current_file.size

    return format_output(pipeline_id, component_wise_memory)


# Build the web UI around get_component_wise_memory and start serving it.
# The order of `inputs` must match the function's positional parameters:
# pipeline_id, token, variant, revision, extension.
gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    # format_size picks the unit per value, so sizes are not always in GB.
    description="Sizes are reported in bytes with decimal unit prefixes (K, M, G, ...). Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
        gr.components.Dropdown(
            [
                "fp32",
                "fp16",
            ],
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Dropdown(
            [".bin", ".safetensors"],
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs=[gr.Markdown(label="Output")],
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    # The flagging option takes a string mode; the boolean form is deprecated.
    allow_flagging="never",
).launch(show_error=True)