anto18671 commited on
Commit
e17f601
·
verified ·
1 Parent(s): 41645fb

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_lumenspark.py +52 -0
  2. modeling_lumenspark.py +295 -0
configuration_lumenspark.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ # ----------------------------
4
+ # Define Lumenspark Configuration
5
+ # ----------------------------
6
+
7
+ class LumensparkConfig(PretrainedConfig):
8
+ """
9
+ Configuration class for the Lumenspark model.
10
+ Stores model hyperparameters like sequence length, embedding dimension, number of layers, and others.
11
+ """
12
+ model_type = "lumenspark"
13
+
14
+ def __init__(
15
+ self,
16
+ seq_length=768,
17
+ vocab_size=50257,
18
+ embed_dim=768,
19
+ depth=8,
20
+ heads=12,
21
+ dropout=1/17,
22
+ k=384,
23
+ rank=256,
24
+ **kwargs
25
+ ):
26
+ super().__init__(**kwargs)
27
+ self.vocab_size = vocab_size
28
+ self.embed_dim = embed_dim
29
+ self.depth = depth
30
+ self.heads = heads
31
+ self.seq_length = seq_length
32
+ self.dropout = dropout
33
+ self.k = k
34
+ self.rank = rank
35
+
36
+ def to_dict(self):
37
+ """
38
+ Converts the configuration parameters to a dictionary format.
39
+ Useful for saving the configuration or inspecting model settings.
40
+ """
41
+ output = super().to_dict()
42
+ output.update({
43
+ "vocab_size": self.vocab_size,
44
+ "embed_dim": self.embed_dim,
45
+ "depth": self.depth,
46
+ "heads": self.heads,
47
+ "seq_length": self.seq_length,
48
+ "dropout": self.dropout,
49
+ "k": self.k,
50
+ "rank": self.rank,
51
+ })
52
+ return output
modeling_lumenspark.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lumenspark.configuration_lumenspark import LumensparkConfig
2
+ from transformers import PreTrainedModel, GPT2Tokenizer
3
+ from huggingface_hub import hf_hub_download
4
+ from safetensors import safe_open
5
+ from torch import nn
6
+ import torch
7
+ import math
8
+ import os
9
+
10
+ # ----------------------------
11
+ # Low-Rank Linear Layer Implementation
12
+ # ----------------------------
13
+
14
+ class LowRankLinear(nn.Module):
15
+ """
16
+ A low-rank linear layer that factorizes a standard linear layer into two smaller ones.
17
+ This allows for reduced parameter count and faster computation.
18
+ """
19
+ def __init__(self, in_features, out_features, rank, init_std=0.02):
20
+ super().__init__()
21
+ self.U = nn.Linear(in_features, rank, bias=False)
22
+ self.V = nn.Linear(rank, out_features, bias=False)
23
+ nn.init.normal_(self.U.weight, std=init_std)
24
+ nn.init.normal_(self.V.weight, std=init_std)
25
+
26
+ def forward(self, x):
27
+ """
28
+ Forward pass through two low-rank linear layers (U and V).
29
+ """
30
+ return self.V(self.U(x))
31
+
32
+ # ----------------------------
33
+ # Lumenspark Self-Attention Implementation
34
+ # ----------------------------
35
+
36
+ class LumensparkSelfAttention(nn.Module):
37
+ """
38
+ Custom self-attention mechanism for the Lumenspark model.
39
+ It uses low-rank approximations to reduce computational cost and memory usage.
40
+ """
41
+ def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
42
+ super().__init__()
43
+ assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'
44
+
45
+ self.num_heads = num_heads
46
+ self.embed_dim = embed_dim
47
+ self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads
48
+
49
+ # Query, Key and Value transformations using LowRankLinear
50
+ self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
51
+ self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
52
+ self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
53
+
54
+ self.dropout_layer = nn.Dropout(dropout)
55
+ self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)
56
+
57
+ def stable_softmax(self, x, dim=-1):
58
+ # Subtract max for numerical stability
59
+ x_max = torch.max(x, dim=dim, keepdim=True)[0]
60
+ exp_x = torch.exp(x - x_max)
61
+ return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)
62
+
63
+ def forward(self, inputs, attention_mask=None):
64
+ batch_size, seq_len, _ = inputs.shape
65
+
66
+ q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
67
+ k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
68
+ v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
69
+
70
+ attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
71
+
72
+ if attention_mask is not None:
73
+ attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
74
+
75
+ attention_weights = self.stable_softmax(attention_scores, dim=-1)
76
+ attention_weights = self.dropout_layer(attention_weights)
77
+
78
+ attention_output = torch.matmul(attention_weights, v)
79
+ attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
80
+ return self.output_transform(attention_output)
81
+
82
+ # ----------------------------
83
+ # Define Lumenspark Model Wrapper
84
+ # ----------------------------
85
+
86
+ class LumensparkModel(PreTrainedModel):
87
+ config_class = LumensparkConfig
88
+
89
+ def __init__(self, config, tokenizer):
90
+ super().__init__(config)
91
+ self.config = config
92
+ self.tokenizer = tokenizer
93
+
94
+ # Token and position embeddings
95
+ self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
96
+ self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)
97
+
98
+ # Lumenspark transformer encoder layers with prenormalization and LayerScale
99
+ self.layers = nn.ModuleList()
100
+ for _ in range(config.depth):
101
+ layer = nn.ModuleDict({
102
+ "norm1": nn.LayerNorm(config.embed_dim),
103
+ "attn": LumensparkSelfAttention(
104
+ embed_dim=config.embed_dim,
105
+ num_heads=config.heads,
106
+ head_dim=config.embed_dim // config.heads,
107
+ dropout=config.dropout
108
+ ),
109
+ "norm2": nn.LayerNorm(config.embed_dim),
110
+ "ffn": nn.Sequential(
111
+ LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
112
+ nn.GELU(),
113
+ nn.Dropout(config.dropout),
114
+ LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
115
+ nn.Dropout(config.dropout)
116
+ ),
117
+ })
118
+ # Assign the parameters directly as attributes
119
+ layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
120
+ layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
121
+ self.layers.append(layer)
122
+
123
+ # Final LayerNorm layer
124
+ self.final_norm = nn.LayerNorm(config.embed_dim)
125
+
126
+ # Feed-forward output layer
127
+ self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
128
+ self.dropout = nn.Dropout(config.dropout)
129
+
130
+ # Initialize model weights
131
+ self.init_weights()
132
+
133
+ @classmethod
134
+ def from_pretrained(cls, model_id, cache_dir=None, **kwargs):
135
+ """
136
+ Downloads the pretrained weights from Hugging Face, and loads the GPT-2 tokenizer.
137
+ """
138
+ # Set cache directory for storing models
139
+ cache_dir = cache_dir or os.path.join(os.getcwd(), "lumenspark_weights")
140
+
141
+ # Download model weights in `.safetensors` format
142
+ weight_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
143
+
144
+ # Load the configuration
145
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
146
+ config = LumensparkConfig.from_json_file(config_path)
147
+
148
+ # Load GPT-2 tokenizer directly from Hugging Face
149
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
150
+
151
+ # Instantiate the model
152
+ model = cls(config, tokenizer=tokenizer)
153
+
154
+ # Load state_dict from safetensors file
155
+ with safe_open(weight_path, framework="pt") as f:
156
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
157
+
158
+ model.load_state_dict(state_dict)
159
+
160
+ return model
161
+
162
+ @staticmethod
163
+ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
164
+ """
165
+ Filter a distribution of logits using top-k and/or top-p filtering.
166
+ """
167
+ top_k = min(top_k, logits.size(-1))
168
+ if top_k > 0:
169
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
170
+ logits[indices_to_remove] = filter_value
171
+
172
+ if top_p < 1.0:
173
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
174
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
175
+
176
+ sorted_indices_to_remove = cumulative_probs > top_p
177
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
178
+ sorted_indices_to_remove[..., 0] = 0
179
+
180
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
181
+ logits[:, indices_to_remove] = filter_value
182
+ return logits
183
+
184
+ def generate(self, text, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
185
+ """
186
+ Text generation method that handles auto-regressive generation with repetition penalty.
187
+ The input is a string, and the output is a string generated by the model.
188
+ """
189
+ self.eval() # Set model to evaluation mode
190
+ # Tokenize input text using GPT-2 tokenizer
191
+ input_ids = torch.tensor([self.tokenizer.encode(text)], dtype=torch.long).to(self.device)
192
+
193
+ # Initialize attention mask
194
+ attention_mask = torch.ones_like(input_ids).to(self.device)
195
+
196
+ generated_tokens = input_ids
197
+
198
+ for _ in range(max_length - input_ids.size(1)):
199
+ outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
200
+ logits = outputs["logits"][:, -1, :]
201
+
202
+ # Adjust temperature for randomness
203
+ logits = logits / temperature
204
+
205
+ # Apply repetition penalty: reduce logits of tokens that have already been generated
206
+ for token in set(generated_tokens.view(-1).tolist()):
207
+ logits[:, token] /= repetition_penalty # Penalize repeated tokens
208
+
209
+ # Apply top-k and top-p sampling to select the next token
210
+ if do_sample:
211
+ filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
212
+ probs = torch.softmax(filtered_logits, dim=-1)
213
+ next_token = torch.multinomial(probs, num_samples=1)
214
+ else:
215
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
216
+
217
+ # Append the generated token
218
+ generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
219
+ # Update attention mask
220
+ attention_mask = torch.ones_like(generated_tokens).to(self.device)
221
+
222
+ # Prevent early stopping by ensuring min_length is reached before allowing EOS
223
+ if next_token.item() == self.tokenizer.eos_token_id and generated_tokens.size(1) < min_length:
224
+ continue # Skip EOS if output is too short
225
+
226
+ # Stop if the EOS token is generated and minimum length is reached
227
+ if next_token.item() == self.tokenizer.eos_token_id:
228
+ break
229
+
230
+ # Decode the generated tokens back to text
231
+ generated_text = self.tokenizer.decode(generated_tokens[0].tolist())
232
+
233
+ return generated_text
234
+
235
+ def forward(self, input_ids, attention_mask=None, labels=None):
236
+ """
237
+ Forward pass of the model. If labels are provided, the loss is also computed.
238
+ """
239
+ batch_size, seq_length = input_ids.size()
240
+
241
+ # Generate position ids
242
+ position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
243
+ position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
244
+
245
+ # Embed tokens and positions
246
+ token_embeddings = self.token_embedding(input_ids)
247
+ position_embeddings = self.position_embedding(position_ids)
248
+
249
+ # Combine token and position embeddings
250
+ embeddings = token_embeddings + position_embeddings
251
+ embeddings = self.dropout(embeddings)
252
+
253
+ # Create causal mask
254
+ device = embeddings.device
255
+ causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)
256
+
257
+ # Combine with attention mask if provided
258
+ if attention_mask is not None:
259
+ # Expand attention_mask to match dimensions
260
+ attention_mask = attention_mask[:, None, None, :].float()
261
+ combined_mask = attention_mask * causal_mask
262
+ else:
263
+ combined_mask = causal_mask
264
+
265
+ # Pass through each transformer layer with prenormalization and LayerScale
266
+ for layer in self.layers:
267
+ # Prenormalization before self-attention
268
+ embeddings_norm = layer["norm1"](embeddings)
269
+ attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
270
+ # Apply LayerScale for attention output
271
+ embeddings = embeddings + layer.layer_scale_attn * attn_output
272
+
273
+ # Prenormalization before feed-forward network
274
+ embeddings_norm = layer["norm2"](embeddings)
275
+ ffn_output = layer["ffn"](embeddings_norm)
276
+ # Apply LayerScale for feed-forward output
277
+ embeddings = embeddings + layer.layer_scale_ffn * ffn_output
278
+
279
+ # Apply final LayerNorm before output
280
+ embeddings = self.final_norm(embeddings)
281
+
282
+ # Compute logits (unnormalized scores)
283
+ logits = self.fc_out(embeddings)
284
+
285
+ # Compute loss if labels are provided
286
+ loss = None
287
+ if labels is not None:
288
+ shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
289
+ shift_labels = labels[:, 1:].contiguous().view(-1)
290
+
291
+ # Base cross-entropy loss
292
+ loss_fct = nn.CrossEntropyLoss()
293
+ loss = loss_fct(shift_logits, shift_labels)
294
+
295
+ return {"loss": loss, "logits": logits}