quantizer / quantizer_gr.py
John6666's picture
Upload 6 files
c557532 verified
import os
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Local stand-in for the ``spaces`` package when not running on ZeroGPU.

        Only ``GPU`` is provided. It is a transparent pass-through decorator,
        so the same code runs unchanged on ordinary hardware.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """Pass-through replacement for ``spaces.GPU``.

            Supports both the bare form ``@spaces.GPU`` and the called form
            ``@spaces.GPU(duration=...)``. Any keyword options are accepted
            and ignored.

            Args:
                func: The function being decorated (bare form), or None when
                    used as ``@spaces.GPU(...)``.
                **kwargs: Ignored; accepted for call-compatibility.

            Returns:
                The wrapped function (bare form) or a decorator (called form).
            """
            def decorator(fn):
                # functools.wraps keeps __name__/__doc__ and sets __wrapped__,
                # so inspect.signature (used by Gradio) still sees the real
                # parameters instead of (*args, **kwargs).
                @functools.wraps(fn)
                def wrapper(*args, **kw):
                    return fn(*args, **kw)
                return wrapper

            if callable(func):
                return decorator(func)
            return decorator
import gradio as gr
from pathlib import Path
import gc
import shutil
import torch
from utils import set_token, upload_repo, is_repo_exists, is_repo_name
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
@spaces.GPU
def fake_gpu():
    """Do nothing; exists only to carry the ``@spaces.GPU`` decorator.

    NOTE(review): presumably kept so the ZeroGPU runtime finds at least one
    GPU-decorated function at startup — confirm.
    """
# Model-class label -> [model loader class, tokenizer loader class].
# get_model() reads index 0 and get_tokenizer() reads index 1.
MODEL_CLASS = dict(
    AutoModelForCausalLM=[AutoModelForCausalLM, AutoTokenizer],
)
DTYPE_DICT = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
"fp8": torch.float8_e4m3fn
}
def get_model_class():
    """Return the registered model-class labels in insertion order."""
    return [label for label in MODEL_CLASS]
def get_model(mclass: str):
    """Return the model loader class registered under ``mclass``.

    Unknown labels fall back to ``AutoModelForCausalLM``.
    """
    if mclass in MODEL_CLASS:
        return MODEL_CLASS[mclass][0]
    return AutoModelForCausalLM
def get_tokenizer(mclass: str):
    """Return the tokenizer loader class registered under ``mclass``.

    Unknown labels fall back to ``AutoTokenizer``.
    """
    if mclass in MODEL_CLASS:
        return MODEL_CLASS[mclass][1]
    return AutoTokenizer
def get_dtype(dtype: str):
    """Map a short dtype name (see DTYPE_DICT) to a torch dtype.

    Any unrecognised name, including "default", yields ``torch.bfloat16``.
    """
    try:
        return DTYPE_DICT[dtype]
    except KeyError:
        return torch.bfloat16
def save_readme_md(dir, repo_id):
    """Write a README.md model card into ``dir`` for a quantized copy of ``repo_id``.

    The card has YAML front matter naming ``repo_id`` as ``base_model`` and a
    one-line body that links back to the source repo on the Hub.

    Args:
        dir: Existing output directory. The name shadows the builtin but is
            kept for caller compatibility.
        repo_id: Source Hub repo id, e.g. ``"owner/name"``.

    Raises:
        OSError: If ``dir`` does not exist or is not writable.
    """
    orig_name = repo_id
    # Canonical Hub URL. The old "https://huggingface.co./..." had a stray
    # trailing dot in the hostname.
    orig_url = f"https://huggingface.co/{repo_id}"
    md = f"""---
license: other
language:
- en
library_name: transformers
base_model: {repo_id}
tags:
- transformers
---
Quants of [{orig_name}]({orig_url}).
"""
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)
@spaces.GPU
def quantize_repo(repo_id: str, dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Load ``repo_id``, optionally NF4-quantize it, and save it to a local dir.

    Args:
        repo_id: Source Hub repo id ("owner/name").
        dtype: Short dtype name resolved via get_dtype(). "default" means no
            ``torch_dtype`` is passed to from_pretrained. The NF4 config
            still uses get_dtype("default"), which falls back to bfloat16.
        qtype: "nf4" enables bitsandbytes 4-bit NF4 quantization. Any other
            value saves the model without a quantization config.
        mclass: Model-class label resolved via get_model()/get_tokenizer().
        progress: Gradio progress tracker.

    Returns:
        The local output directory: the last path segment of ``repo_id``.
    """
    progress(0, desc="Start quantizing...")
    out_dir = repo_id.split("/")[-1]

    type_kwargs = {}
    if dtype != "default":
        type_kwargs["torch_dtype"] = get_dtype(dtype)

    quant_kwargs = {}
    if qtype == "nf4":
        # Build the config only when it is actually used.
        quant_dtype = get_dtype(dtype)
        quant_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=quant_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=quant_dtype,
        )

    tokenizer = model = None
    try:
        progress(0.1, desc="Loading...")
        # `legacy=False` was previously misspelt `legathy`, so it was not
        # applied as the tokenizer's `legacy` option.
        tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legacy=False)
        model = get_model(mclass).from_pretrained(repo_id, **type_kwargs, **quant_kwargs)
        progress(0.5, desc="Saving...")
        tokenizer.save_pretrained(out_dir)
        model.save_pretrained(out_dir, safe_serialization=True)
        if Path(out_dir).exists():
            save_readme_md(out_dir, repo_id)
    finally:
        # Release references even when loading or saving fails. Collect
        # garbage first so empty_cache() can return the freed blocks.
        del tokenizer, model
        gc.collect()
        torch.cuda.empty_cache()

    progress(1, desc="Quantized.")
    return out_dir
def _format_repo_links(urls) -> str:
    """Render repo URLs as Markdown ``[owner/name](url)`` links joined by <br>."""
    links = []
    for u in urls:
        url = str(u)
        # The last two path segments are owner/name. Strip a trailing slash
        # so they are still picked correctly.
        owner, name = url.rstrip("/").split("/")[-2:]
        links.append(f"[{owner}/{name}]({url})<br>")
    return "### Your new repo:\n" + "".join(links)


def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool=True, is_overwrite: bool=False,
                dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Quantize ``repo_id`` and upload the result to ``newrepo_id``.

    Args:
        repo_id: Source Hub repo id.
        hf_token: HF write token. Falls back to the HF_TOKEN env var.
        urls: Previously produced repo URLs. Copied, not mutated.
        newrepo_id: Destination repo id. Falls back to HF_OUTPUT_REPO.
        is_private: Create or upload the destination as private.
        is_overwrite: Allow uploading into an existing repo.
        dtype, qtype, mclass: Forwarded to quantize_repo().
        progress: Gradio progress tracker.

    Returns:
        A pair of ``gr.update`` objects. The first sets value/choices to the
        updated URL list (presumably a dropdown — confirm against the UI).
        The second sets the Markdown link list.

    Raises:
        gr.Error: On a missing token, invalid repo names, or an existing
            destination repo when ``is_overwrite`` is False.
    """
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")  # default huggingface token
    if not hf_token:
        raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)
    if not newrepo_id:
        newrepo_id = os.environ.get("HF_OUTPUT_REPO")  # default repo id
    if not is_repo_name(repo_id):
        raise gr.Error(f"Invalid repo name: {repo_id}")
    # Guard against None when HF_OUTPUT_REPO is unset.
    if not newrepo_id or not is_repo_name(newrepo_id):
        raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id):
        raise gr.Error(f"Repo already exists: {newrepo_id}")

    # Work on a copy so the caller's list (e.g. Gradio state) is not mutated.
    urls = list(urls) if urls else []

    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    if not new_path:
        # Keep the two-output shape of the success path; a bare "" does not
        # match a handler that returns two updates.
        return gr.update(), gr.update()

    try:
        progress(0.5, desc="Start uploading...")
        repo_url = upload_repo(newrepo_id, new_path, is_private)
        progress(1, desc="Processing...")
    finally:
        # Remove the local copy even when the upload fails so it cannot
        # accumulate on disk.
        shutil.rmtree(new_path, ignore_errors=True)

    urls.append(repo_url)
    md = _format_repo_links(urls)
    gc.collect()
    torch.cuda.empty_cache()
    return gr.update(value=urls, choices=urls), gr.update(value=md)