anto18671 committed
Commit e2536b8 · verified · 1 Parent(s): d0e13a3

Upload 2 files

Files changed (1)
  1. modeling_lumenspark.py +10 -138
modeling_lumenspark.py CHANGED

@@ -1,11 +1,9 @@
-from transformers import PreTrainedModel, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM
-from lumenspark.configuration_lumenspark import LumensparkConfig
-from huggingface_hub import hf_hub_download
-from safetensors import safe_open
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
+from .configuration_lumenspark import LumensparkConfig
 from torch import nn
 import torch
 import math
-import os
 
 # ----------------------------
 # Low-Rank Linear Layer Implementation
@@ -86,10 +84,9 @@ class LumensparkSelfAttention(nn.Module):
 class LumensparkModel(PreTrainedModel):
     config_class = LumensparkConfig
 
-    def __init__(self, config, tokenizer):
+    def __init__(self, config):
         super().__init__(config)
         self.config = config
-        self.tokenizer = tokenizer
 
         # Token and position embeddings
         self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
@@ -115,184 +112,59 @@ class LumensparkModel(PreTrainedModel):
                     nn.Dropout(config.dropout)
                 ),
             })
-            # Assign the parameters directly as attributes
             layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
             layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
             self.layers.append(layer)
 
-        # Final LayerNorm layer
         self.final_norm = nn.LayerNorm(config.embed_dim)
-
-        # Feed-forward output layer
         self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
         self.dropout = nn.Dropout(config.dropout)
 
-        # Initialize model weights
+        # Call init_weights at the end to ensure proper initialization
         self.init_weights()
 
-    @classmethod
-    def from_pretrained(cls, model_id, cache_dir=None, **kwargs):
-        """
-        Downloads the pretrained weights from Hugging Face, and loads the GPT-2 tokenizer.
-        """
-        # Set cache directory for storing models
-        cache_dir = cache_dir or os.path.join(os.getcwd(), "lumenspark_weights")
-
-        # Download model weights in `.safetensors` format
-        weight_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
-
-        # Load the configuration
-        config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
-        config = LumensparkConfig.from_json_file(config_path)
-
-        # Load GPT-2 tokenizer directly from Hugging Face
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-
-        # Instantiate the model
-        model = cls(config, tokenizer=tokenizer)
-
-        # Load state_dict from safetensors file
-        with safe_open(weight_path, framework="pt") as f:
-            state_dict = {k: f.get_tensor(k) for k in f.keys()}
-
-        model.load_state_dict(state_dict)
-
-        return model
-
-    @staticmethod
-    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
-        """
-        Filter a distribution of logits using top-k and/or top-p filtering.
-        """
-        top_k = min(top_k, logits.size(-1))
-        if top_k > 0:
-            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-            logits[indices_to_remove] = filter_value
-
-        if top_p < 1.0:
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-            sorted_indices_to_remove[..., 0] = 0
-
-            indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            logits[:, indices_to_remove] = filter_value
-        return logits
-
-    def generate(self, text, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
-        """
-        Text generation method that handles auto-regressive generation with repetition penalty.
-        The input is a string, and the output is a string generated by the model.
-        """
-        self.eval()  # Set model to evaluation mode
-        # Tokenize input text using GPT-2 tokenizer
-        input_ids = torch.tensor([self.tokenizer.encode(text)], dtype=torch.long).to(self.device)
-
-        # Initialize attention mask
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        generated_tokens = input_ids
-
-        for _ in range(max_length - input_ids.size(1)):
-            outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
-            logits = outputs["logits"][:, -1, :]
-
-            # Adjust temperature for randomness
-            logits = logits / temperature
-
-            # Apply repetition penalty: reduce logits of tokens that have already been generated
-            for token in set(generated_tokens.view(-1).tolist()):
-                logits[:, token] /= repetition_penalty  # Penalize repeated tokens
-
-            # Apply top-k and top-p sampling to select the next token
-            if do_sample:
-                filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
-                probs = torch.softmax(filtered_logits, dim=-1)
-                next_token = torch.multinomial(probs, num_samples=1)
-            else:
-                next_token = torch.argmax(logits, dim=-1, keepdim=True)
-
-            # Append the generated token
-            generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
-            # Update attention mask
-            attention_mask = torch.ones_like(generated_tokens).to(self.device)
-
-            # Prevent early stopping by ensuring min_length is reached before allowing EOS
-            if next_token.item() == self.tokenizer.eos_token_id and generated_tokens.size(1) < min_length:
-                continue  # Skip EOS if output is too short
-
-            # Stop if the EOS token is generated and minimum length is reached
-            if next_token.item() == self.tokenizer.eos_token_id:
-                break
-
-        # Decode the generated tokens back to text
-        generated_text = self.tokenizer.decode(generated_tokens[0].tolist())
-
-        return generated_text
-
     def forward(self, input_ids, attention_mask=None, labels=None):
-        """
-        Forward pass of the model. If labels are provided, the loss is also computed.
-        """
         batch_size, seq_length = input_ids.size()
 
-        # Generate position ids
         position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
         position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
 
-        # Embed tokens and positions
         token_embeddings = self.token_embedding(input_ids)
         position_embeddings = self.position_embedding(position_ids)
-
-        # Combine token and position embeddings
         embeddings = token_embeddings + position_embeddings
         embeddings = self.dropout(embeddings)
 
-        # Create causal mask
-        device = embeddings.device
-        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)
+        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=embeddings.device)).unsqueeze(0).unsqueeze(0)
 
-        # Combine with attention mask if provided
         if attention_mask is not None:
-            # Expand attention_mask to match dimensions
             attention_mask = attention_mask[:, None, None, :].float()
             combined_mask = attention_mask * causal_mask
         else:
            combined_mask = causal_mask
 
-        # Pass through each transformer layer with prenormalization and LayerScale
         for layer in self.layers:
-            # Prenormalization before self-attention
             embeddings_norm = layer["norm1"](embeddings)
             attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
-            # Apply LayerScale for attention output
             embeddings = embeddings + layer.layer_scale_attn * attn_output
 
-            # Prenormalization before feed-forward network
             embeddings_norm = layer["norm2"](embeddings)
             ffn_output = layer["ffn"](embeddings_norm)
-            # Apply LayerScale for feed-forward output
             embeddings = embeddings + layer.layer_scale_ffn * ffn_output
 
-        # Apply final LayerNorm before output
         embeddings = self.final_norm(embeddings)
-
-        # Compute logits (unnormalized scores)
         logits = self.fc_out(embeddings)
 
-        # Compute loss if labels are provided
         loss = None
         if labels is not None:
             shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
             shift_labels = labels[:, 1:].contiguous().view(-1)
-
-            # Base cross-entropy loss
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits, shift_labels)
 
-        return {"loss": loss, "logits": logits}
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits
+        )
 
 AutoConfig.register("lumenspark", LumensparkConfig)
 AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
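
With the custom loader and the hand-rolled generate() removed, text generation has to be driven from outside the model, either through the standard transformers generation utilities or with an explicit sampling loop against forward(). Below is a minimal sketch of such a loop. It only relies on what this diff defines (the forward signature and the loss/logits output); the repo id, the trust_remote_code flag, and the GPT-2 tokenizer choice are assumptions carried over from the deleted code, not facts established by this commit.

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# Hypothetical repo id; trust_remote_code assumes the Hub repo maps the
# "lumenspark" architecture to this modeling file (e.g. via auto_map).
repo_id = "anto18671/lumenspark"
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # the tokenizer the deleted loader pulled in
model.eval()

prompt = "The night sky was"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
temperature, top_k, top_p, repetition_penalty = 0.6, 50, 0.9, 1.1

with torch.no_grad():
    for _ in range(120):  # max new tokens
        logits = model(input_ids=input_ids).logits[:, -1, :] / temperature

        # Repetition penalty as the deleted method applied it: scale down already-seen tokens.
        for token in set(input_ids[0].tolist()):
            logits[:, token] /= repetition_penalty

        # Top-k: keep only the k highest-scoring tokens.
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))

        # Top-p: drop the low-probability tail of the sorted cumulative distribution.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), float("-inf"))

        next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))

Depending on the transformers version, going through model.generate() instead may require the GenerationMixin plumbing (for example prepare_inputs_for_generation), which this file does not define; the explicit loop above sidesteps that dependency.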
 
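The other behavioural change is the return type: forward() now hands back a CausalLMOutputWithCrossAttentions instead of a plain dict, so loss and logits are available as attributes (and by index), which is the contract transformers.Trainer and similar training loops expect. A short sketch of that path, reusing the model loaded in the previous sketch with a hypothetical toy batch:

import torch

model.train()
batch = torch.randint(0, model.config.vocab_size, (2, 16))  # hypothetical token ids

# Passing labels triggers the shifted cross-entropy computed inside forward().
outputs = model(input_ids=batch, attention_mask=torch.ones_like(batch), labels=batch)

print(outputs.logits.shape)  # (batch, seq_len, vocab_size)
print(outputs.loss)          # scalar language-modelling loss
outputs.loss.backward()      # ready for an optimizer step or transformers.Trainer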