File size: 6,136 Bytes
62d807b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python
# -*- coding:utf-8 -*-

'''
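 Description  : Gradio helpers for the calflops Hugging Face Space: load Hub models on the meta device and report their FLOPs, MACs and parameter counts.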
 Version      : 1.0
 Author       : MrYXJ
 Mail         : [email protected]
 Github       : https://github.com/MrYxJ
 Date         : 2023-09-05 23:28:32
 LastEditTime : 2023-09-09 19:14:20
 Copyright (C) 2023 mryxj. All rights reserved.
'''


import gradio as gr
import torch

from accelerate.commands.estimate import check_has_model
from urllib.parse import urlparse
from huggingface_hub.utils import GatedRepoError
from huggingface_hub.utils import RepositoryNotFoundError

from calflops import create_empty_model
from calflops import calculate_flops_hf
from calflops import flops_to_string
from calflops import macs_to_string
from calflops import params_to_string

def calculate_flops_in_hugging_space(model_name: str,
                                     empty_model: torch.nn.Module,
                                     access_token: str,
                                     input_shape: tuple,
                                     bp_factor: float,
                                     output_unit: str):
    """Calculate FLOPs, MACs and parameter counts for a model initialised on the `meta` device.

    Args:
        model_name: Hub repo id, forwarded to ``calculate_flops_hf``.
        empty_model: model instance forwarded to ``calculate_flops_hf``
            (presumably built by ``create_empty_model`` -- confirm at call site).
        access_token: Hugging Face access token forwarded to ``calculate_flops_hf``.
        input_shape: input shape forwarded to ``calculate_flops_hf``.
        bp_factor: backward-pass cost as a multiple of the forward pass; the
            forward+backward figures are ``forward * (1 + bp_factor)``.
        output_unit: ``""`` keeps raw numbers, ``"auto"`` lets calflops choose
            the unit, and one of ``T/G/M/K/m/u`` forces that unit. Any other
            value also leaves the numbers unformatted.

    Returns:
        A tuple ``(data, report)``. ``data`` is a one-element list holding a dict
        with the five metrics. ``report`` is the text of calflops' printed
        report that follows the first line containing "Detailed", with the
        report's last two lines dropped.

    Raises:
        gr.Error: if ``calculate_flops_hf`` raises for any reason.
    """
    try:
        flops, macs, params, return_print = calculate_flops_hf(model_name=model_name,
                                                               empty_model=empty_model,
                                                               access_token=access_token,
                                                               input_shape=input_shape,
                                                               return_results=True,
                                                               output_as_string=False)
    except Exception as e:
        print("Error info:", e)
        # Chain the cause so the real traceback is still visible in the logs.
        raise gr.Error(
            f"Model `{model_name}` does not support inference on the meta device, You can download the complete model parameters to your local and using the python package calflops to calculate FLOPs and Params of model `{model_name}`."
        ) from e

    # Forward+backward cost is modelled as the forward cost scaled by (1 + bp_factor).
    fw_bp_flops = flops * (1.0 + bp_factor)
    fw_bp_macs = macs * (1.0 + bp_factor)

    # "" (and any unrecognised unit) leaves the raw numeric values untouched.
    if output_unit == "auto" or output_unit in ("T", "G", "M", "K", "m", "u"):
        # units=None asks the calflops formatters to pick the unit themselves.
        units = None if output_unit == "auto" else output_unit
        params = params_to_string(params, units=units, precision=3)
        flops = flops_to_string(flops, units=units, precision=3)
        macs = macs_to_string(macs, units=units, precision=3)
        fw_bp_flops = flops_to_string(fw_bp_flops, units=units, precision=3)
        fw_bp_macs = macs_to_string(fw_bp_macs, units=units, precision=3)

    # Keep only the part of the report after the first line mentioning
    # "Detailed", ignoring the report's final two lines.
    report_lines = return_print.split("\n")[:-2]
    start = next((i + 1 for i, line in enumerate(report_lines) if "Detailed" in line),
                 len(report_lines))
    return_print = "".join(line + "\n" for line in report_lines[start:])

    data = [{
        "Total Training Params": params,
        "Forward FLOPs": flops,
        "Forward MACs": macs,
        "Forward+Backward FLOPs": fw_bp_flops,
        "Forward+Backward MACs": fw_bp_macs,
    }]
    return data, return_print


def extract_from_url(name: str):
    """Convert a Hub URL into a repo id; return any non-URL input unchanged.

    A string counts as a URL when ``urlparse`` finds both a scheme and a
    network location, e.g. ``"https://huggingface.co/gpt2"`` -> ``"gpt2"``.
    Leading and trailing slashes are stripped from the URL path, so a copied
    URL ending in "/" still yields a clean repo id.
    """
    try:
        result = urlparse(name)
    except (AttributeError, TypeError, ValueError):
        # e.g. ValueError for a malformed IPv6 netloc such as "http://[::1";
        # treat unparsable input as "not a URL" and pass it through.
        return name
    if not (result.scheme and result.netloc):
        # Plain repo ids such as "org/model" have no scheme/netloc: pass through.
        return name
    return result.path.strip("/")
    

def translate_llama2(text):
    """Map a Llama-2 repo id to its "-hf" counterpart.

    The suffix is appended only when ``text`` does not already end with it,
    so calling this on an already-translated id is a no-op.
    """
    if text.endswith("-hf"):
        return text
    return f"{text}-hf"


def get_mode_from_hf(model_name: str, library: str, access_token: str):
    """Find a model on the Hub and build an empty instance of it via ``create_empty_model``.

    Args:
        model_name: Hub repo id or a full Hub URL (converted by ``extract_from_url``).
            Ids containing "meta-llama" are mapped to their "-hf" variant.
        library: library name passed as ``library_name``; ``"auto"`` is sent as
            ``None``.
        access_token: Hugging Face access token, needed for gated repos.

    Returns:
        The model object returned by ``create_empty_model``.

    Raises:
        gr.Error: for every failure, with a message tailored to the error type.
    """
    # Normalise a URL to a bare repo id *before* appending "-hf", otherwise a
    # URL with a trailing slash would become ".../model/-hf".
    model_name = extract_from_url(model_name)
    if "meta-llama" in model_name:
        model_name = translate_llama2(model_name)
    if library == "auto":
        library = None
    try:
        model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
    except GatedRepoError as e:
        # GatedRepoError must be handled before its parent RepositoryNotFoundError.
        raise gr.Error(
            f"Model `{model_name}` is a gated model, please ensure to pass in your access token and try again if you have access. You can find your access token here : https://huggingface.co/settings/tokens"
        ) from e
    except RepositoryNotFoundError as e:
        raise gr.Error(f"Model `{model_name}` was not found on the Hub, please try another model name.") from e
    except ValueError as e:
        raise gr.Error(
            f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
        ) from e
    except (RuntimeError, OSError) as e:
        found_library = check_has_model(e)
        if found_library != "unknown":
            raise gr.Error(
                f"Tried to load `{model_name}` with `{found_library}` but a possible model to load was not found inside the repo."
            ) from e
        raise gr.Error(
            f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
        ) from e
    except ImportError:
        # hacky way to check if it works with `trust_remote_code=False`;
        # failures of the retry are reported as gr.Error like every other path.
        try:
            model = create_empty_model(
                model_name, library_name=library, trust_remote_code=False, access_token=access_token
            )
        except Exception as e:
            raise gr.Error(
                f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
            ) from e
    except Exception as e:
        raise gr.Error(
            f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
        ) from e
    return model