The fused expert parameters mean load_in_4bit doesn't work properly, nor does LoRA

#10
by tdrussell - opened

I noticed that in the modeling code, the expert parameters are "fused". You call .view() on the parameters, index the chosen expert, then do a matmul() manually, as opposed to something like the Transformers Mixtral code, where each expert is just a normal nn.Linear layer.

This has a few unfortunate downsides. Trying to load the model with bitsandbytes 4 bit quantization doesn't work as it should, in the sense that the expert parameters aren't quantized, since they aren't simple nn.Linear layers. I confirmed this myself. You should be able to load this model in 4 bit with 96GB of VRAM, but it's not even close if the expert parameters aren't quantized.
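A rough back-of-the-envelope check of why unquantized experts blow through a 96GB budget. The expert shapes are the ones used by the conversion script posted later in this thread (16 experts, hidden size 6144, FFN hidden size 10752). The layer count of 40 is an assumption taken from the published DBRX config, and the 4-bit figure ignores the small per-block quantization metadata:

```python
# Rough estimate of expert-weight memory. NUM_LAYERS = 40 is an assumption
# (from the published DBRX config); the other shapes match the conversion
# script posted later in this thread.
NUM_LAYERS = 40
NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
FFN_HIDDEN_SIZE = 10752


def expert_param_count() -> int:
    # Each layer holds three fused matrices (w1, v1, w2) of identical size.
    per_matrix = NUM_EXPERTS * FFN_HIDDEN_SIZE * HIDDEN_SIZE
    return NUM_LAYERS * 3 * per_matrix


def gib(num_bytes: float) -> float:
    return num_bytes / 2**30


params = expert_param_count()
bf16_gib = gib(params * 2)    # 2 bytes per parameter
nf4_gib = gib(params * 0.5)   # 4 bits per parameter, metadata ignored

print(f"expert parameters: {params / 1e9:.1f}B")
print(f"bf16: {bf16_gib:.0f} GiB, 4-bit: {nf4_gib:.0f} GiB")
```

Under these assumptions the experts alone come to roughly 236 GiB in bf16 versus about 59 GiB at 4 bits, which is why the model only fits once the expert weights are actually quantized.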

I also think that the PEFT code for training LoRA or QLoRA won't work, though I haven't tried it (as I can't even load the model). This is because PEFT is based around adapting nn.Linear layers.
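To illustrate the point, here is a minimal sketch with two toy modules (hypothetical names, tiny sizes, not the real DBRX classes). As far as I understand, tools that swap or wrap layers by type walk the module tree looking for nn.Linear, so a fused nn.Parameter is invisible to them:

```python
import torch
from torch import nn


class FusedExperts(nn.Module):
    """Toy stand-in for the fused layout: one big nn.Parameter, no nn.Linear."""

    def __init__(self, num_experts: int, ffn: int, hidden: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts * ffn, hidden))


class SplitExperts(nn.Module):
    """Toy stand-in for the Mixtral-style layout: one nn.Linear per expert."""

    def __init__(self, num_experts: int, ffn: int, hidden: int):
        super().__init__()
        self.w1 = nn.ModuleList(
            [nn.Linear(hidden, ffn, bias=False) for _ in range(num_experts)]
        )


def linear_module_names(model: nn.Module) -> list[str]:
    # A type-based layer replacement only ever sees the modules listed here.
    return [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]


fused_names = linear_module_names(FusedExperts(4, 8, 6))
split_names = linear_module_names(SplitExperts(4, 8, 6))
print(fused_names)
print(split_names)
```

The fused module yields an empty list, while the split module exposes one nn.Linear per expert.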

Is there any chance of adding some kind of option for structuring the expert layers like Mixtral does, which is a separate nn.Linear layer for each expert? This would make a lot of things easier, namely making it more convenient to load the model on limited VRAM and to fine tune it with LoRA. I don't know if this would require the actual weights on disk to change, so it might complicate things.

I tried both bitsandbytes 4-bit and 8-bit quantization; neither worked. Memory usage is still the same as with bf16.

How do I load this model in 8-bit?

I should have read this here first! https://github.com/TimDettmers/bitsandbytes/issues/1155

I changed the model format using this script:

import json
from safetensors import safe_open
from safetensors.torch import save_file
from pathlib import Path

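model_dir = Path('your_model_dir')
output_dir = Path('your_output_dir')
# make sure the output directory exists before any shard is written
output_dir.mkdir(parents=True, exist_ok=True)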

NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
FFN_HIDDEN_SIZE = 10752

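def change_tensor(tensor, reverse=False):
    # Split one fused (NUM_EXPERTS * FFN_HIDDEN_SIZE, HIDDEN_SIZE) tensor into
    # per-expert matrices. w2 is transposed (reverse=True) because the fused
    # forward computes x1.matmul(expert_w2), while nn.Linear computes x @ weight.T.
    output = [x.contiguous() if not reverse else x.t().contiguous() for x in tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)]

    return output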

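def change_mlp(tensors):

    keys = list(tensors.keys())
    for k in keys:
        if any([x in k for x in ['w1', 'v1', 'w2']]):
            # e.g. '<prefix>.w1' -> prefix='<prefix>', name='w1'
            prefix, name = k.rsplit('.', 1)
            tensor = tensors.pop(k)
            # only w2 needs the transpose (see change_tensor)
            output_tensor = change_tensor(tensor, name == 'w2')
            for i, t in enumerate(output_tensor):
                # one '<prefix>.<expert index>.<name>.weight' entry per expert
                tensors[f'{prefix}.{i}.{name}.weight'] = t

    return tensors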

for file in model_dir.glob('*.safetensors'):
    print(file)
    tensors = {}
    with safe_open(file, 'pt') as f:
        metadata = f.metadata()
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    tensors = change_mlp(tensors)
    save_file(tensors, (output_dir / file.name).as_posix(), metadata)

with open(model_dir / 'model.safetensors.index.json') as f:
    weight_map = json.load(f)

weight_keys = list(weight_map['weight_map'])
for k in weight_keys:
    if any([x in k for x in ['w1', 'v1', 'w2']]):
        prefix, name = k.rsplit('.', 1)
        value = weight_map['weight_map'].pop(k)
        # every per-expert tensor stays in the same shard file as its fused source
        for i in range(NUM_EXPERTS):
            weight_map['weight_map'][f'{prefix}.{i}.{name}.weight'] = value

sorted_map = sorted(list(weight_map['weight_map'].items()))
weight_map['weight_map'] = dict(sorted_map)

with open(output_dir / 'model.safetensors.index.json', 'w') as f:
    json.dump(weight_map, f, indent=4)
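
To make the renaming in the script easier to follow, here is a small stand-alone sketch of the key mapping only (no tensors involved). The helper name split_expert_key and the example key are my own assumptions for illustration; the actual key names depend on the checkpoint:

```python
def split_expert_key(key: str, num_experts: int) -> list[str]:
    """Map one fused checkpoint key to the per-expert keys nn.ModuleList expects."""
    prefix, _, name = key.rpartition('.')
    if name not in ('w1', 'v1', 'w2'):
        return [key]  # non-expert tensors keep their original name
    return [f'{prefix}.{i}.{name}.weight' for i in range(num_experts)]


# 'transformer.blocks.0.ffn.experts.mlp.w1' is an assumed example key.
new_keys = split_expert_key('transformer.blocks.0.ffn.experts.mlp.w1', 16)
print(new_keys[0])
print(new_keys[-1])
print(split_expert_key('lm_head.weight', 16))
```

Each fused entry becomes 16 entries, indexed by expert, which is the layout a ModuleList of per-expert modules with w1, v1 and w2 as nn.Linear layers would load.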

Then, inside modeling_dbrx.py, I changed the following:

from this

class DbrxExpertGLU(nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int,
                 moe_num_experts: int, ffn_act_fn: dict):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts

        self.w1 = nn.Parameter(
            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.v1 = nn.Parameter(
            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.w2 = nn.Parameter(
            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
                                 self.hidden_size)[expert_idx]
        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size,
                                 self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
                                 self.hidden_size)[expert_idx]

        x1 = x.matmul(expert_w1.t())
        x2 = x.matmul(expert_v1.t())
        x1 = self.activation_fn(x1)
        x1 = x1 * x2
        x1 = x1.matmul(expert_w2)
        return x1


class DbrxExperts(nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int,
                 moe_num_experts: int, ffn_act_fn: dict):
        super().__init__()
        self.moe_num_experts = moe_num_experts
        self.mlp = DbrxExpertGLU(hidden_size=hidden_size,
                                 ffn_hidden_size=ffn_hidden_size,
                                 moe_num_experts=moe_num_experts,
                                 ffn_act_fn=ffn_act_fn)

    def forward(self, x: torch.Tensor, weights: torch.Tensor,
                top_weights: torch.Tensor,
                top_experts: torch.LongTensor) -> torch.Tensor:
        bsz, q_len, hidden_size = x.shape
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)

        expert_mask = nn.functional.one_hot(
            top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(0, self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                continue

            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()

            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            expert_out = self.mlp(
                expert_tokens, expert_idx) * top_weights[token_list, topk_list,
                                                         None]

            out.index_add_(0, token_idx, expert_out)

        out = out.reshape(bsz, q_len, hidden_size)
        return out

to this

class DbrxMLP(nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int, ffn_act_fn: dict):
        super().__init__()

        self.w1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.v1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.w2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
        self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:

        return self.w2(self.activation_fn(self.w1(x)) * self.v1(x))


class DbrxExperts(nn.Module):

    def __init__(self, hidden_size: int, ffn_hidden_size: int,
                 moe_num_experts: int, ffn_act_fn: dict):
        super().__init__()
        self.moe_num_experts = moe_num_experts
        self.mlp = nn.ModuleList([DbrxMLP(hidden_size, ffn_hidden_size, ffn_act_fn) for _ in range(moe_num_experts)])

    def forward(self, x: torch.Tensor, weights: torch.Tensor,
                top_weights: torch.Tensor,
                top_experts: torch.LongTensor) -> torch.Tensor:
        bsz, q_len, hidden_size = x.shape
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)

        expert_mask = nn.functional.one_hot(
            top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(0, self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                continue

            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()

            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            expert_out = self.mlp[expert_idx](expert_tokens) * top_weights[token_list, topk_list, None]

            out.index_add_(0, token_idx, expert_out)

        out = out.reshape(bsz, q_len, hidden_size)
        return out

And from this

class DbrxPreTrainedModel(PreTrainedModel):
    config_class = DbrxConfig
    base_model_prefix = 'transformer'
    supports_gradient_checkpointing = True
    _no_split_modules = ['DbrxBlock']
    _skip_keys_device_placement = ['past_key_values']
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True

    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, DbrxExpertGLU):
            module.w1.data.normal_(mean=0.0, std=std)
            module.v1.data.normal_(mean=0.0, std=std)
            module.w2.data.normal_(mean=0.0, std=std)

To this

class DbrxPreTrainedModel(PreTrainedModel):
    config_class = DbrxConfig
    base_model_prefix = 'transformer'
    supports_gradient_checkpointing = True
    _no_split_modules = ['DbrxBlock']
    _skip_keys_device_placement = ['past_key_values']
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True

    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

I tried something very similar to the above last night, on my own, and can confirm it works. Once the experts are just normal nn.Linear layers they can be loaded with 4 bit quantization. Well, the model "works" in the sense that when you load it in 4 bit it's mostly okay, but will frequently misspell words, double up commas, output random garbage tokens, etc. So something is still off about it (might just be unusually sensitive to load_in_4bit quantization).

Would be nice if we could get some way to do this officially. Otherwise, I imagine someone will upload a model with these changes, and lots of people will just use that version, as you can train LoRA on it etc. Alternatively, bitsandbytes and PEFT would have to add explicit support somehow for the current fused weight architecture.

I uploaded an adjusted version based on the code from fahadh4ilyas, so people don't need to make the changes manually. Thank you so much for the scripts and the adjustments. Here is the model: SinclairSchneider/dbrx-instruct-quantization-fixed

@SinclairSchneider can you please upload the 4bit quantized version of the model? Thanks!

Thank you! So this is an 8-bit version?

Also, can you please share how the 4-bit/8-bit quantized model is performing? Thanks!

Based on the file size, it is the original model in bf16.

It's not quantized but you can use "load_in_4bit=True" or "load_in_8bit=True" without running into out of memory issues like when using the original model.
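For anyone asking how to load it in 4-bit or 8-bit, a minimal sketch follows. It uses BitsAndBytesConfig via quantization_config, which is the more explicit equivalent of passing load_in_4bit=True or load_in_8bit=True. The helper name load_quantized is hypothetical, bfloat16 as the compute dtype is my assumption, and trust_remote_code=True assumes the converted repo ships its own modeling_dbrx.py:

```python
def load_quantized(repo_id: str, bits: int = 4):
    """Load a converted checkpoint with on-the-fly bitsandbytes quantization."""
    if bits not in (4, 8):
        raise ValueError("bits must be 4 or 8")

    # Imported lazily so the argument check above works without the libraries.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    if bits == 4:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,  # assumed compute dtype
        )
    else:
        quant_config = BitsAndBytesConfig(load_in_8bit=True)

    return AutoModelForCausalLM.from_pretrained(
        repo_id,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,  # assumes the repo ships custom modeling code
    )


# Example (downloads a very large checkpoint, so it is left commented out):
# model = load_quantized("SinclairSchneider/dbrx-instruct-quantization-fixed", bits=4)
```

The weights on disk stay in bf16; quantization happens while the shards are loaded, so disk usage and download size are unchanged.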

Thanks! Understood

@SinclairSchneider is it possible for you to quantize the model using that option? Since you are loading the model in 4-bit mode, I think the script is converting the model to 4 bits on the fly, which means you could do the same to generate the quantized model files.

In my case, it works quite well. I literally only changed nn.Parameter into nn.Linear. Even when I load it on CPU without quantization and generate with it, the output is no different from this repo's.
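A quick way to sanity-check that claim on a small scale is to compare the fused forward pass with the per-expert nn.Linear version on random weights. This is only a sketch with tiny made-up sizes, and SiLU is an assumed activation (the real one is resolved from ffn_act_fn):

```python
import torch
from torch import nn

torch.manual_seed(0)
num_experts, ffn, hidden = 4, 8, 6  # tiny illustrative sizes, not DBRX's

# Fused layout: every expert stacked along dim 0, as in DbrxExpertGLU.
w1 = torch.randn(num_experts * ffn, hidden)
v1 = torch.randn(num_experts * ffn, hidden)
w2 = torch.randn(num_experts * ffn, hidden)
act = nn.SiLU()  # assumed activation; the real one comes from ffn_act_fn


def fused_forward(x: torch.Tensor, expert_idx: int) -> torch.Tensor:
    e_w1 = w1.view(num_experts, ffn, hidden)[expert_idx]
    e_v1 = v1.view(num_experts, ffn, hidden)[expert_idx]
    e_w2 = w2.view(num_experts, ffn, hidden)[expert_idx]
    return (act(x.matmul(e_w1.t())) * x.matmul(e_v1.t())).matmul(e_w2)


def split_forward(x: torch.Tensor, expert_idx: int) -> torch.Tensor:
    # Same slicing as the conversion script: w1/v1 copied as-is, w2 transposed.
    lin_w1 = nn.Linear(hidden, ffn, bias=False)
    lin_v1 = nn.Linear(hidden, ffn, bias=False)
    lin_w2 = nn.Linear(ffn, hidden, bias=False)
    with torch.no_grad():
        lin_w1.weight.copy_(w1.reshape(num_experts, ffn, hidden)[expert_idx])
        lin_v1.weight.copy_(v1.reshape(num_experts, ffn, hidden)[expert_idx])
        lin_w2.weight.copy_(w2.reshape(num_experts, ffn, hidden)[expert_idx].t())
        return lin_w2(act(lin_w1(x)) * lin_v1(x))


x = torch.randn(3, hidden)
max_diff = max(
    (fused_forward(x, i) - split_forward(x, i)).abs().max().item()
    for i in range(num_experts)
)
print(f"max abs difference: {max_diff:.2e}")
```

If the transpose on w2 were left out, the copy into lin_w2 would fail with a shape mismatch, which makes this a cheap check to run before converting the full checkpoint.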

@fahadh4ilyas You are awesome! For anyone interested, the converted dbrx-instruct model is currently being uploaded to HF (33% complete):

https://huggingface.co./LnL-AI/dbrx-instruct-converted/discussions/2

A hacky work-in-progress AutoGPTQ quantization attempt using the converted code/model is underway at: https://github.com/AutoGPTQ/AutoGPTQ/pull/625

Hi @fahadh4ilyas, I tried your modification but got this error:

 File "modeling_dbrx.py", line 642, in __init__
    self.norm_1 = nn.LayerNorm(hidden_size, bias=False)
TypeError: __init__() got an unexpected keyword argument 'bias'

I think you have to update your torch. That error implies that nn.LayerNorm in your installed version has no bias parameter in its init method, but the torch documentation clearly shows that bias exists.
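One way to confirm this before touching the model is to inspect the constructor signature. The sketch below uses a hypothetical helper, accepts_kwarg, and two stand-in classes so it runs without torch; the commented lines show how it would be applied to the real nn.LayerNorm:

```python
import inspect


def accepts_kwarg(cls: type, name: str) -> bool:
    """Return True if cls.__init__ declares a parameter called `name`."""
    return name in inspect.signature(cls.__init__).parameters


# Stand-in classes so the sketch runs without torch installed.
class OldLayerNorm:
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        pass


class NewLayerNorm:
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, bias=True):
        pass


print(accepts_kwarg(OldLayerNorm, "bias"))
print(accepts_kwarg(NewLayerNorm, "bias"))

# With torch installed, the real check would be:
# import torch
# print(torch.__version__, accepts_kwarg(torch.nn.LayerNorm, "bias"))
```

If the real check reports that bias is not accepted, upgrading torch should resolve the TypeError.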

@fahadh4ilyas Thanks for the clarification! I can now load the model in 4-bit, taking ~71 GB of RAM in total, but loading is very slow. The generation also contains many random garbage tokens.

dbrx-base-converted now on hf: https://huggingface.co./LnL-AI/dbrx-base-converted

(pending upload...should be complete in about 10 minutes). I have cancelled the broken upload of instruct-converted model now that base is here.

bnb 4 bit version can be found here: https://huggingface.co./PrunaAI/dbrx-base-bnb-4bit

dbrx-base-converted-v2 is now on HF: https://huggingface.co./LnL-AI/dbrx-base-converted-v2 with Wqkv split into q, k, v for potentially better quantization compatibility. The latest exllama v2 code also does this split. Testing and quantization have not been verified yet.
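For readers wondering what the split involves, here is a rough sketch of the bookkeeping only, in plain Python. It assumes the fused Wqkv weight stores the q rows first, then k, then v, with grouped-query attention. The head counts (48 query heads, 8 key/value heads) are assumptions from the DBRX config and the helper names are hypothetical; this is not the code used for the v2 upload:

```python
def qkv_split_sizes(d_model: int, n_heads: int, kv_n_heads: int) -> tuple[int, int, int]:
    """Row counts of the q, k and v blocks inside a fused Wqkv weight (GQA layout)."""
    head_dim = d_model // n_heads
    kv_dim = kv_n_heads * head_dim
    return d_model, kv_dim, kv_dim


def split_rows(rows: list, sizes: tuple[int, int, int]) -> tuple[list, list, list]:
    # Slice the fused matrix (given as a list of rows) into q, k, v in that order.
    q_end = sizes[0]
    k_end = q_end + sizes[1]
    return rows[:q_end], rows[q_end:k_end], rows[k_end:k_end + sizes[2]]


# d_model matches the script in this thread; the head counts are assumptions.
sizes = qkv_split_sizes(d_model=6144, n_heads=48, kv_n_heads=8)
print(sizes)

# Tiny demo: 4 query rows + 2 key rows + 2 value rows.
q, k, v = split_rows(list(range(8)), qkv_split_sizes(4, 4, 2))
print(q, k, v)
```

With real tensors the same sizes could be passed to a row-wise split of the fused weight, giving three separate nn.Linear weights that quantization tools can handle individually.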

Again thanks to @fahadh4ilyas for the v2 changes.

Is there bnb 4bit for dbrx-instruct? The base model is not good at instruction following. Thanks!

@MLDataScientist A 4-bit quant is available for Apple MLX (though you need something like an M-series Ultra), and exllama v2 has one too. But if you want the base model to follow instructions, it doesn't do that. It is base vs. instruct for good reason.

Yes you can find it here: https://huggingface.co./PrunaAI/dbrx-instruct-bnb-4bit

amazing. Thank you!

I know, right! That is why I needed the instruct bnb 4bit. Fortunately, we already have it as @johnrachwanpruna mentioned above.

@johnrachwanpruna, was this dbrx-instruct converted using the method described by @fahadh4ilyas?

Yes exactly :)

@fahadh4ilyas @Qubitium Looks like the converted models don't play nice with 4bit-qlora+fsdp:

https://gist.github.com/winglian/348f792e62386007bc589667f01d2cae

Which model are you using? v1 or v2?

@fahadh4ilyas @winglian was using v2 for that crash
