Saiga – Llama 3 8B – OmniQuant
Based on Saiga Llama 3 8B.
Quantized with OmniQuant.
Evaluation
PPL (↓)
| | wiki |
|---|---|
| FP | 7,862 |
| Quantized | 8,615 |
Accuracy on English Benchmarks, % (↑)
| | piqa | arc_easy | arc_challenge | boolq | hellaswag | winogrande | mmlu_humanities | mmlu_social_sciences | mmlu_stem | mmlu_other |
|---|---|---|---|---|---|---|---|---|---|---|
| FP | 78,5 | 82,2 | 50,4 | 82,7 | 58,1 | 72,4 | 65,5 | 72,6 | 53,8 | 68,4 |
| Quantized | 78,5 | 80,8 | 47,6 | 81,7 | 56,9 | 71,2 | 62,3 | 68,9 | 49,7 | 63,3 |
Accuracy on Russian Benchmarks, % (↑)
| | danetqa | terra | rwsd | muserc | rucos | lidirus | parus | rcb | russe | rucola |
|---|---|---|---|---|---|---|---|---|---|---|
| FP | 74,9 | 52,1 | 51,5 | 55,9 | 58,1 | 59,5 | 69,0 | 34,1 | 38,8 | 67,5 |
| Quantized | 65,4 | 50,5 | 49,5 | 60,7 | 53,7 | 50,9 | 71,0 | 33,6 | 40,8 | 67,5 |
Summary
| | Avg acc diff on Eng, % (↑) | Avg acc diff on Rus, % (↑) | Occupied disk space, % (↓) |
|---|---|---|---|
| FP | 0 | 0 | 100 |
| Quantized | -2,4 | -1,8 | 35,7 |
Examples
Imports and Model Loading
Expand
import gc
import auto_gptq.nn_modules.qlinear.qlinear_cuda as qlinear_cuda
import auto_gptq.nn_modules.qlinear.qlinear_triton as qlinear_triton
import torch
from accelerate import (
init_empty_weights,
infer_auto_device_map,
load_checkpoint_in_model,
)
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
)
def get_named_linears(model):
    """Collect every ``torch.nn.Linear`` submodule of *model*.

    Args:
        model: Module whose ``named_modules()`` are scanned (the module
            itself is included in that scan).

    Returns:
        A dict mapping each qualified submodule name, as reported by
        ``named_modules()``, to the ``torch.nn.Linear`` instance found there.
    """
    linears = {}
    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, torch.nn.Linear):
            linears[qualified_name] = submodule
    return linears
def set_module(model, name, module):
    """Replace the submodule at dotted path *name* inside *model*.

    Args:
        model: Root object the path is resolved against.
        name: Dot-separated path to the target, e.g. ``"self_attn.q_proj"``.
            Every component but the last is walked to find the parent; the
            last component is the attribute that gets overwritten.
        module: Object to install at that path.

    Raises:
        AttributeError: If a non-numeric path component does not exist.
        IndexError: If a numeric component is out of range for an
            integer-indexable parent.
    """
    *parent_path, leaf = name.split('.')
    parent = model
    for part in parent_path:
        if part.isdigit():
            try:
                parent = parent[int(part)]
            except (TypeError, KeyError):
                # The parent is not integer-indexable (not subscriptable, or
                # a mapping keyed by the string "0" rather than the int 0),
                # so resolve the numerically named child as an attribute.
                parent = getattr(parent, part)
        else:
            parent = getattr(parent, part)
    setattr(parent, leaf, module)
def load_model(model_path):
    """Build a quantized causal LM and load its pre-computed weights.

    The model is first created under ``init_empty_weights()``, then every
    ``torch.nn.Linear`` inside ``model.model.layers`` is swapped for an
    AutoGPTQ ``QuantLinear`` of matching shape, and finally the checkpoint
    is loaded into the resulting structure.

    Args:
        model_path: Location of the quantized checkpoint. It is passed to
            ``AutoConfig.from_pretrained`` and, as ``checkpoint``, to
            ``load_checkpoint_in_model``.
            NOTE(review): accelerate documents ``checkpoint`` as a local
            file or folder -- confirm that a Hub repo id is downloaded
            before being passed here.

    Returns:
        The model with quantized linear layers and loaded weights.

    Raises:
        AttributeError: If the config has no ``quantization_config`` section.
        NotImplementedError: If the configured bit width is not 2, 3 or 4.
    """
    # Based on: https://github.com/OpenGVLab/OmniQuant/blob/main/runing_quantized_mixtral_7bx8.ipynb
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if not hasattr(config, 'quantization_config'):
        raise AttributeError(
            f'No quantization info found in model config "{model_path}"'
            f' (`quantization_config` section is missing).'
        )
    wbits = config.quantization_config['bits']
    group_size = config.quantization_config['group_size']

    # The kernel choice depends only on the bit width, so resolve it once
    # and reject unsupported widths before any model is instantiated.
    if wbits in (2, 4):
        quant_linear_cls = qlinear_triton.QuantLinear
    elif wbits == 3:
        quant_linear_cls = qlinear_cuda.QuantLinear
    else:
        raise NotImplementedError("Only 2, 3 and 4 bits are supported.")

    # We are going to init an ordinary model and then manually replace all
    # Linears with QuantLinears, so the quantization section is dropped.
    del config.quantization_config
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.float16, trust_remote_code=True
        )

    for layer in tqdm(model.model.layers):
        for name, module in get_named_linears(layer).items():
            q_linear = quant_linear_cls(
                wbits, group_size,
                module.in_features, module.out_features,
                module.bias is not None,
            )
            # Keep the replacement on the same device as the rest of the layer.
            q_linear.to(next(layer.parameters()).device)
            set_module(layer, name, q_linear)

    torch.cuda.empty_cache()
    gc.collect()
    model.tie_weights()
    device_map = infer_auto_device_map(model)

    print("Loading pre-computed quantized weights...")
    load_checkpoint_in_model(
        model, checkpoint=model_path,
        device_map=device_map, offload_state_dict=True,
    )
    print("Model loaded successfully!")
    return model
Inference
# Load the quantized model and its tokenizer, then move the model to the GPU.
model_path = "compressa-ai/Saiga-Llama-3-8B-OmniQuant"
model = load_model(model_path).cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path, use_fast=False, trust_remote_code=True
)

system_message = "Ты — дружелюбный чат-бот, который всегда отвечает как пират."
user_message = "Куда мы направляемся, капитан?"
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_message},
]

# Render the conversation to text, ending with the assistant header so the
# model continues as the assistant.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Chat templates are expected to emit the special tokens they need (see the
# transformers chat-templating guide), so the tokenizer must not add a second
# set on top of the rendered prompt.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
inputs = {k: v.cuda() for k, v in inputs.items()}

outputs = model.generate(
    **inputs, max_new_tokens=512,
    do_sample=True, temperature=0.7, top_p=0.95,
)
response = tokenizer.decode(outputs[0])

# Cut the prompt off by token count rather than by string prefix: decoded text
# is not guaranteed to reproduce the prompt string exactly, in which case a
# prefix strip would silently leave the whole prompt in the continuation.
prompt_length = inputs["input_ids"].shape[1]
continuation = tokenizer.decode(
    outputs[0][prompt_length:], skip_special_tokens=True
)

print(f'Prompt:\n{prompt}')
print(f'Continuation:\n{continuation}\n')
Inference Using Pipeline
# Wrap the already-loaded model and tokenizer in a text-generation pipeline
# that uses the same sampling settings as the manual example.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
)

# Render the chat messages into a single prompt string via the pipeline's
# own tokenizer.
prompt = pipe.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

outputs = pipe(prompt)
response = outputs[0]["generated_text"]
# Strip the leading prompt so only the newly generated text remains.
continuation = response.removeprefix(prompt)

print(f'Prompt:\n{prompt}')
print(f'Continuation:\n{continuation}\n')
- Downloads last month
- 15
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.
Model tree for compressa-ai/Saiga-Llama-3-8B-OmniQuant
Base model
IlyaGusev/saiga_llama3_8b