import os

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Fallback stub so the script also runs outside ZeroGPU Spaces:
    # spaces.GPU becomes a transparent pass-through decorator.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper


import gradio as gr
from pathlib import Path
import gc
import shutil
import torch
from utils import set_token, upload_repo, is_repo_exists, is_repo_name
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


@spaces.GPU
def fake_gpu():  # no-op placeholder so ZeroGPU sees at least one GPU-decorated function at startup
    pass


MODEL_CLASS = {
    "AutoModelForCausalLM": [AutoModelForCausalLM, AutoTokenizer],
}


DTYPE_DICT = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
    "fp8": torch.float8_e4m3fn,
}
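# Note: torch.float8_e4m3fn only exists in fairly recent PyTorch releases;
# on older versions this table would raise AttributeError at import time.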


def get_model_class():
    return list(MODEL_CLASS.keys())


def get_model(mclass: str):
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])[0]


def get_tokenizer(mclass: str):
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])[1]


def get_dtype(dtype: str):
    return DTYPE_DICT.get(dtype, torch.bfloat16)
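# Hypothetical lookups, showing the fallback behavior of the helpers above:
#   get_model("AutoModelForCausalLM")  # -> AutoModelForCausalLM
#   get_dtype("no-such-dtype")         # -> torch.bfloat16 (the default)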


def save_readme_md(dir, repo_id):
    orig_name = repo_id
    orig_url = f"https://huggingface.co./{repo_id}/"
    md = f"""---
license: other
language:
- en
library_name: transformers
base_model: {repo_id}
tags:
- transformers
---
Quants of [{orig_name}]({orig_url}).
"""
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)
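# Note: the generated card hardcodes `license: other`; the source model's
# actual license is not carried over automatically.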


@spaces.GPU
def quantize_repo(repo_id: str, dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start quantizing...")
    out_dir = repo_id.split("/")[-1]

    type_kwargs = {}
    if dtype != "default": type_kwargs["torch_dtype"] = get_dtype(dtype)

    # NF4 4-bit quantization with double quantization; storage and compute
    # dtypes follow the user-selected dtype.
    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_quant_storage=get_dtype(dtype),
                                    bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=get_dtype(dtype))
    quant_kwargs = {}
    if qtype == "nf4": quant_kwargs["quantization_config"] = nf4_config

    progress(0.1, desc="Loading...")
    tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legacy=False)
    model = get_model(mclass).from_pretrained(repo_id, **type_kwargs, **quant_kwargs)

    progress(0.5, desc="Saving...")
    tokenizer.save_pretrained(out_dir)
    model.save_pretrained(out_dir, safe_serialization=True)

    if Path(out_dir).exists(): save_readme_md(out_dir, repo_id)

    # Free VRAM before returning.
    del tokenizer
    del model
    torch.cuda.empty_cache()
    gc.collect()

    progress(1, desc="Quantized.")
    return out_dir
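# Hypothetical direct call (outside Gradio); the repo id is only an example:
#   out_dir = quantize_repo("Qwen/Qwen2-0.5B-Instruct", dtype="bf16", qtype="nf4")
#   # out_dir now holds the NF4-quantized weights plus the generated README.md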


def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool=True, is_overwrite: bool=False,
                dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    if not hf_token: hf_token = os.environ.get("HF_TOKEN")
    if not hf_token: raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)
    if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
    if not is_repo_name(repo_id): raise gr.Error(f"Invalid repo name: {repo_id}")
    if not is_repo_name(newrepo_id): raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id): raise gr.Error(f"Repo already exists: {newrepo_id}")
    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    # Return empty updates so both Gradio outputs stay consistent on failure.
    if not new_path: return gr.update(), gr.update()
    if not urls: urls = []
    progress(0.5, desc="Start uploading...")
    repo_url = upload_repo(newrepo_id, new_path, is_private)
    progress(1, desc="Processing...")
    shutil.rmtree(new_path)  # remove the local copy once it has been uploaded
    urls.append(repo_url)
    md = "### Your new repo:\n"
    for u in urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    torch.cuda.empty_cache()
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=md)
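

# Minimal sketch of how quantize_gr might be wired into a Gradio UI. All
# component names and the launch block below are assumptions for illustration;
# the app's real layout likely lives elsewhere (e.g. app.py).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        repo = gr.Textbox(label="Source repo ID")
        token = gr.Textbox(label="HF write token", type="password")
        new_repo = gr.Textbox(label="Output repo ID")
        urls_state = gr.CheckboxGroup(label="Uploaded repos", choices=[], value=[])
        result_md = gr.Markdown()
        # quantize_gr's remaining parameters (is_private, dtype, qtype, ...) use their defaults.
        gr.Button("Quantize").click(
            quantize_gr,
            inputs=[repo, token, urls_state, new_repo],
            outputs=[urls_state, result_md],
        )
    demo.launch()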