# Llama3-8B-ITCL-Bitnet1.6B πŸš€

## Description πŸ“œ
**Llama3-8B-ITCL-Bitnet1.6B** is an experimental LLM transformed from Llama3 and optimized with BitLinear layers to improve memory efficiency and inference speed. The model is designed for natural language processing tasks and is particularly useful in environments where resource-efficient performance is required. 🌟

![image/png](https://cdn-uploads.huggingface.co/production/uploads/6419c2f6b4adb0e101b17b6c/WOcl2k9xdLT5aVqh-aERz.png)

## Features 🌈
- **Model Size:** 8B parameters 🧠
- **Architecture:** BitNet πŸ—οΈ
- **BitLinear Layers:** Reduce weights to the values 1, 0, and -1 (see the sketch after this list). βž–
- **Optimized for:** Fast inference and memory efficiency ⚑
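
As a rough illustration of this ternarization, here is a minimal sketch using the same absmean scheme as the `weight_quant` helper shown later in this README (the example tensor is made-up data, purely for illustration):

```python
import torch

# Made-up example weights, purely for illustration
w = torch.tensor([[0.4, -1.2, 0.05],
                  [0.9, -0.3, 0.0]])

# Absmean ternarization, as in the weight_quant helper below
scale = 1.0 / w.abs().mean().clamp(min=1e-5)
w_ternary = (w * scale).round().clamp(-1, 1)  # every entry is -1, 0, or 1
w_dequant = w_ternary / scale                 # rescaled values used in the matmul

print(w_ternary)
# tensor([[ 1., -1.,  0.],
#         [ 1., -1.,  0.]])
```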

## Requirements πŸ“¦
Make sure you have the following libraries installed:

```bash
pip install transformers torch huggingface_hub wandb coloredlogs
```

You can install these dependencies using pip! πŸŽ‰

## Usage πŸ”
### Loading the Model
To use this model, load it from Hugging Face and run the following code:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaSdpaAttention,
)
import torch
from torch import nn
import torch.nn.functional as F
import coloredlogs
import logging

coloredlogs.install(level='INFO', fmt='%(asctime)s - %(levelname)s - %(message)s', logger=logging.getLogger())
logger = logging.getLogger(__name__)

HF_TOKEN = "your_api_key_here"  # replace with your Hugging Face access token

model_id = "ejbejaranos/Llama3-8B-ITCL-Bitnet1.6B"

# Load the pretrained BitNet checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN
)

# Set the pad_token_id
model.config.pad_token_id = tokenizer.eos_token_id

def count_parameters(model):
    # Number of trainable parameters, in billions
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 10**9
    return num_params

def activation_quant(x):
    # Per-token 8-bit absmax quantization of the activations
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    y = y / scale
    return y

def weight_quant(w):
    # Ternary (1, 0, -1) absmean quantization of the weights
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1)
    u = u / scale
    return u

class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight  # weight tensor with shape [d, k]
        x = x.to(w.device)
        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = RMSNorm(x)
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y

def convert_to_bitnet(model, copy_weights):
    for name, module in model.named_modules():
        # Replace the nn.Linear layers inside attention and MLP blocks with BitLinear
        if isinstance(module, (LlamaSdpaAttention, LlamaMLP)):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, nn.Linear):
                    bitlinear = BitLinear(
                        child_module.in_features,
                        child_module.out_features,
                        child_module.bias is not None
                    ).to(device="cuda:0")
                    if copy_weights:
                        bitlinear.weight = child_module.weight
                        if child_module.bias is not None:
                            bitlinear.bias = child_module.bias
                    setattr(module, child_name, bitlinear)
        # Remove the input_layernorm, since BitLinear applies its own RMSNorm
        elif isinstance(module, LlamaDecoderLayer):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
                    setattr(module, child_name, nn.Identity().to(device="cuda:0"))

convert_to_bitnet(model, copy_weights=True)
model.to(device="cuda:0")

logger.info(f"πŸ”’ Number of parameters after BitNet conversion: {count_parameters(model):.3f}B")
logger.info(f"πŸ“ Reduced model structure:\n{model}")

prompt = "What is the color of sky?"
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
inputs['attention_mask'] = inputs['input_ids'] != model.config.pad_token_id

generate_ids = model.generate(inputs.input_ids, attention_mask=inputs['attention_mask'], max_length=250)
decoded_output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(decoded_output[0])  # Print the generated response
```

### Performing Inference
Generate text with the model to unleash its power! πŸ’¬βœ¨ The loading snippet above already ends with a sample generation; a standalone sketch follows below.
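
Here is a minimal inference sketch, assuming `model` and `tokenizer` are already loaded and converted as in the snippet above (the prompt and sampling settings are just examples):

```python
# Minimal inference sketch; assumes `model` and `tokenizer` from the loading snippet above
prompt = "Explain in one sentence what a BitLinear layer is."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

output_ids = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```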

## Training πŸ‹οΈ
To train the model, prepare your dataset, configure your training settings, and implement your training loop. πŸ› οΈ A hedged starting point is sketched below.
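
The following is a minimal sketch of causal-LM fine-tuning with the Hugging Face `Trainer`. The dataset, hyperparameters, and output directory are placeholders, not the settings used to produce this checkpoint; it assumes `model` and `tokenizer` are loaded and converted as shown above, and that the `datasets` package is installed.

```python
# Fine-tuning sketch with placeholder dataset and hyperparameters; assumes
# `model` and `tokenizer` from the loading snippet above.
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

# Causal LM collator builds labels from the input ids
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="./bitnet-finetune",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=1,
    logging_steps=50,
    save_steps=500,
    report_to="wandb",  # wandb is listed in the requirements above
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=collator,
)

trainer.train()
```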

## Contributions 🀝
If you would like to contribute to this project, please follow these steps:
1. Fork the repository. 🍴
2. Create your branch (`git checkout -b feature-new-feature`). 🌿
3. Make your changes and commit. πŸ“…
4. Push to the branch. πŸ“€
5. Open a Pull Request. πŸ“¬

## License πŸ“„
This project is licensed under the MIT License. See the `LICENSE` file for details.

## Contact πŸ“«
For questions or suggestions, feel free to reach out to me:
- **Email:** [email protected]
- **GitHub:** [ejbejaranos](https://github.com/ejbejaranos) 🌐