import os
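# On HF Spaces with ZeroGPU the real `spaces` package is importable; otherwise
# define a stub whose GPU decorator is a no-op so the same code runs locally.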
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
import gradio as gr
from pathlib import Path
import gc
import shutil
import torch
from utils import set_token, upload_repo, is_repo_exists, is_repo_name
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
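# No-op function decorated with @spaces.GPU (registers a GPU entry point on ZeroGPU).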
@spaces.GPU
def fake_gpu():
    pass
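# Selectable model classes: name -> [model loader, tokenizer loader].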
MODEL_CLASS = {
"AutoModelForCausalLM": [AutoModelForCausalLM, AutoTokenizer],
}
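# Mapping from dtype names used in the UI to torch dtypes.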
DTYPE_DICT = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
"fp8": torch.float8_e4m3fn
}
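# Lookup helpers that fall back to safe defaults for unknown keys.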
def get_model_class():
    return list(MODEL_CLASS.keys())
def get_model(mclass: str):
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])[0]
def get_tokenizer(mclass: str):
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])[1]
def get_dtype(dtype: str):
    return DTYPE_DICT.get(dtype, torch.bfloat16)
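# Write a minimal README.md (model card) for the quantized repo into `dir`.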
def save_readme_md(dir, repo_id):
    orig_name = repo_id
    orig_url = f"https://huggingface.co./{repo_id}/"
    md = f"""---
license: other
language:
- en
library_name: transformers
base_model: {repo_id}
tags:
- transformers
---
Quants of [{orig_name}]({orig_url}).
"""
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)
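# Download the model, apply 4-bit NF4 quantization when requested, and save
# model + tokenizer to a local directory named after the repo; returns that directory.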
@spaces.GPU
def quantize_repo(repo_id: str, dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
progress(0, desc="Start quantizing...")
out_dir = repo_id.split("/")[-1]
type_kwargs = {}
if dtype != "default": type_kwargs["torch_dtype"] = get_dtype(dtype)
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_quant_storage=get_dtype(dtype),
bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=get_dtype(dtype))
quant_kwargs = {}
if qtype == "nf4": quant_kwargs["quantization_config"] = nf4_config
progress(0.1, desc="Loading...")
tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legathy=False)
model = get_model(mclass).from_pretrained(repo_id, **type_kwargs, **quant_kwargs)
progress(0.5, desc="Saving...")
tokenizer.save_pretrained(out_dir)
model.save_pretrained(out_dir, safe_serialization=True)
if Path(out_dir).exists(): save_readme_md(out_dir, repo_id)
del tokenizer
del model
torch.cuda.empty_cache()
gc.collect()
progress(1, desc="Quantized.")
return out_dir
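# Gradio event handler: validate the token and repo names, quantize the source repo,
# upload the result to the Hub, then return the updated URL list and a markdown summary.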
def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool=True, is_overwrite: bool=False,
dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    if not hf_token: hf_token = os.environ.get("HF_TOKEN")  # default huggingface token
    if not hf_token: raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)
    if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")  # default repo id
    if not is_repo_name(repo_id): raise gr.Error(f"Invalid repo name: {repo_id}")
    if not is_repo_name(newrepo_id): raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id): raise gr.Error(f"Repo already exists: {newrepo_id}")
    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    if not new_path: return ""
    if not urls: urls = []
    progress(0.5, desc="Start uploading...")
    repo_url = upload_repo(newrepo_id, new_path, is_private)
    progress(1, desc="Processing...")
    shutil.rmtree(new_path)  # remove the local copy after upload
    urls.append(repo_url)
    # build a markdown list of all repos created so far
    md = "### Your new repo:\n"
    for u in urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    torch.cuda.empty_cache()
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=md)