---
license: apache-2.0
---

## Upstream model config

This is the configuration of the upstream Llama 2 checkpoint the architecture is based on. The training code below builds the model from a custom `config` object, which is not reproduced in this card.

```json
{
  "_name_or_path": "output/hermes-llama2-4k/checkpoint-2259",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.32.0.dev0",
  "use_cache": false,
  "vocab_size": 32000
}
```

### Dataset

```python
DATASET = "abideen/Cosmopedia-100k-pretrain"  # @param

from datasets import load_dataset

dataset = load_dataset(DATASET)
# Tokenization into `tokenized_data` (used by the Trainer below) is not shown.

# Linear layers inside the attention and MLP blocks are converted to BitLinear.
# `activation_quant`, `weight_quant` and `convert_to_bitnet` are defined in the
# Inference section below.
class BitLinear(nn.Linear):
    """nn.Linear that quantizes its input activations and weights in forward()."""

    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        # The norm is re-created on every call, so it is not a registered
        # submodule and its scale is never trained.
        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = RMSNorm(x)
        # A trick for implementing Straight-Through-Estimator (STE) using detach():
        # the forward pass sees the quantized values, the backward pass sees identity.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y

### Create the llama model with our custom config. Convert it to bitnet.
model = LlamaForCausalLM(config)
convert_to_bitnet(model, copy_weights=False)
```

### Training

```python
args = TrainingArguments(
    output_dir=output_path,
    per_device_train_batch_size=BATCH_SIZE,
    logging_steps=100,
    gradient_accumulation_steps=2,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    # Warm up over 10% of training. `warmup_steps` is an integer step count,
    # so passing 0.1 there gives effectively no warmup.
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    learning_rate=LEARNING_RATE,
    # max_steps=5000,
    save_steps=0.25,  # values below 1 are a fraction of total training steps
    fp16=True,
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_data["train"],
)

trainer.train()
```

### Inference

```python
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load a pretrained BitNet model
model_id = "saadnaeem/Llama2-70M-Cosmopedia-100k-Pretrain"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def activation_quant(x):
    """Per-row absmax quantization of `x` to the int8 range along the last
    dimension, returned dequantized (same dtype/scale as the input)."""
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    y = y / scale
    return y


def weight_quant(w):
    """Ternary quantization: scale by 1 / mean(|w|), round and clamp to
    {-1, 0, 1}, then dequantize back to the original scale."""
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1)
    u = u / scale
    return u


class BitLinear(nn.Linear):
    """nn.Linear that quantizes its input activations and weights in forward()."""

    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        # The norm is re-created on every call, so it is not a registered
        # submodule and its scale is never trained.
        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = RMSNorm(x)
        # A trick for implementing Straight-Through-Estimator (STE) using detach():
        # the forward pass sees the quantized values, the backward pass sees identity.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y


def convert_to_bitnet(model, copy_weights):
    """Replace the nn.Linear layers of every attention/MLP block with BitLinear
    and drop each decoder layer's input_layernorm, in place.

    If `copy_weights` is True the existing weight/bias parameters are reused by
    the new BitLinear layers; otherwise the BitLinear layers keep their fresh
    default initialization.
    """
    # Snapshot the module list so replacing children does not affect traversal.
    for module in list(model.modules()):
        # Replace linear layers with BitNet. LlamaAttention is the base class of
        # the eager/SDPA/flash-attention variants, so all of them are matched.
        if isinstance(module, (LlamaAttention, LlamaMLP)):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, nn.Linear):
                    # Create the layer where the original lives instead of a
                    # hard-coded GPU, so conversion also works on CPU.
                    bitlinear = BitLinear(
                        child_module.in_features,
                        child_module.out_features,
                        bias=child_module.bias is not None,
                        device=child_module.weight.device,
                        dtype=child_module.weight.dtype,
                    )
                    if copy_weights:
                        bitlinear.weight = child_module.weight
                        if child_module.bias is not None:
                            bitlinear.bias = child_module.bias
                    setattr(module, child_name, bitlinear)
        # Remove redundant input_layernorms: BitLinear normalizes its own input.
        elif isinstance(module, LlamaDecoderLayer):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
                    setattr(module, child_name, nn.Identity())


convert_to_bitnet(model, copy_weights=True)
model.to(device=device)

prompt = "What is Machine Learning?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Pass the attention mask along with the input ids.
generate_ids = model.generate(**inputs, max_length=50)
print(
    tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
)
```