import os
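# On HF Spaces ZeroGPU the real `spaces` package supplies the GPU decorator;
# elsewhere, fall back to a no-op stub so @spaces.GPU works unchanged locally.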
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
import gradio as gr
from pathlib import Path
import gc
import shutil
import torch
from utils import set_token, upload_repo, is_repo_exists, is_repo_name
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig


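# Dummy GPU-decorated function so Spaces ZeroGPU can detect GPU usage at startup.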
@spaces.GPU
def fake_gpu():
    pass


MODEL_CLASS = {
    "AutoModelForCausalLM": [AutoModelForCausalLM, AutoTokenizer],
}
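# Further architectures could be registered here, e.g. (hypothetical entry,
# which would also need the corresponding import from transformers):
#   "AutoModelForSeq2SeqLM": [AutoModelForSeq2SeqLM, AutoTokenizer],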


DTYPE_DICT = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
    "fp8": torch.float8_e4m3fn
}
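# Note: torch.float8_e4m3fn requires PyTorch >= 2.1.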


def get_model_class():
    return list(MODEL_CLASS.keys())


def get_model(mclass: str):
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])[0]


def get_tokenizer(mclass: str):
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])[1]


def get_dtype(dtype: str):
    return DTYPE_DICT.get(dtype, torch.bfloat16)


def save_readme_md(out_dir, repo_id):
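    """Write a minimal model-card README.md into `out_dir`, linking back to the source repo."""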
    orig_name = repo_id
    orig_url = f"https://huggingface.co./{repo_id}/"
    md = f"""---

license: other

language:

- en

library_name: transformers

base_model: {repo_id}

tags:

- transformers

---

Quants of [{orig_name}]({orig_url}).

"""
    path = str(Path(out_dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)


@spaces.GPU
def quantize_repo(repo_id: str, dtype: str = "bf16", qtype: str = "nf4", mclass: str = get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
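    """Load `repo_id` with transformers, apply the selected quantization, and
    save model + tokenizer locally; returns the output directory name."""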
    progress(0, desc="Start quantizing...")
    out_dir = repo_id.split("/")[-1]

    type_kwargs = {}
    if dtype != "default": type_kwargs["torch_dtype"] = get_dtype(dtype)

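    # bitsandbytes NF4 config: 4-bit weights with double quantization; storage
    # and compute dtypes follow the user-selected dtype.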
    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_quant_storage=get_dtype(dtype),
                                    bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=get_dtype(dtype))
    quant_kwargs = {}
    if qtype == "nf4": quant_kwargs["quantization_config"] = nf4_config

    progress(0.1, desc="Loading...")
    tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legacy=False)
    model = get_model(mclass).from_pretrained(repo_id, **type_kwargs, **quant_kwargs)

    progress(0.5, desc="Saving...")
    tokenizer.save_pretrained(out_dir)
    model.save_pretrained(out_dir, safe_serialization=True)

    if Path(out_dir).exists(): save_readme_md(out_dir, repo_id)

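    # Drop references and free GPU memory before returning.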
    del tokenizer
    del model
    torch.cuda.empty_cache()
    gc.collect()

    progress(1, desc="Quantized.")
    return out_dir
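
# Example of calling quantize_repo directly (hypothetical model id):
#   out_dir = quantize_repo("org/some-model", dtype="bf16", qtype="nf4")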


def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool = True, is_overwrite: bool = False,
                dtype: str = "bf16", qtype: str = "nf4", mclass: str = get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
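    """Gradio handler: validate the token and repo names, quantize `repo_id`,
    upload the result to `newrepo_id`, and return updated UI components."""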
    if not hf_token: hf_token = os.environ.get("HF_TOKEN") # default huggingface token
    if not hf_token: raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)
    if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO") # default repo id
    if not is_repo_name(repo_id): raise gr.Error(f"Invalid repo name: {repo_id}")
    if not is_repo_name(newrepo_id): raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id): raise gr.Error(f"Repo already exists: {newrepo_id}")
    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    if not new_path: return ""
    if not urls: urls = []
    progress(0.5, desc="Start uploading...")
    repo_url = upload_repo(newrepo_id, new_path, is_private)
    progress(1, desc="Uploaded.")
    shutil.rmtree(new_path)
    urls.append(repo_url)
    md = "### Your new repo:\n"
    for u in urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    torch.cuda.empty_cache()
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=md)
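

# A minimal sketch of wiring quantize_gr into a Gradio Blocks UI (illustrative
# only; component names are assumptions and the real interface presumably
# lives in the Space's app file):
#
# with gr.Blocks() as demo:
#     repo_id = gr.Textbox(label="Source repo ID")
#     hf_token = gr.Textbox(label="HF write token", type="password")
#     newrepo_id = gr.Textbox(label="Output repo ID")
#     urls = gr.CheckboxGroup(label="Uploaded repos", choices=[], value=[])
#     result = gr.Markdown()
#     gr.Button("Quantize").click(
#         quantize_gr,
#         inputs=[repo_id, hf_token, urls, newrepo_id],
#         outputs=[urls, result],
#     )
# demo.launch()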