Upload 2 files
modeling_lumenspark.py  CHANGED  (+10 -138)
@@ -1,11 +1,9 @@
-from transformers import …
-from …
-from …
-from safetensors import safe_open
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
+from .configuration_lumenspark import LumensparkConfig
 from torch import nn
 import torch
 import math
-import os
 
 # ----------------------------
 # Low-Rank Linear Layer Implementation
@@ -86,10 +84,9 @@ class LumensparkSelfAttention(nn.Module):
 class LumensparkModel(PreTrainedModel):
     config_class = LumensparkConfig
 
-    def __init__(self, config, tokenizer):
+    def __init__(self, config):
         super().__init__(config)
         self.config = config
-        self.tokenizer = tokenizer
 
         # Token and position embeddings
         self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
@@ -115,184 +112,59 @@ class LumensparkModel(PreTrainedModel):
                     nn.Dropout(config.dropout)
                 ),
             })
-            # Assign the parameters directly as attributes
             layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
             layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
             self.layers.append(layer)
 
-        # Final LayerNorm layer
         self.final_norm = nn.LayerNorm(config.embed_dim)
-
-        # Feed-forward output layer
         self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
         self.dropout = nn.Dropout(config.dropout)
 
-        # …
+        # Call init_weights at the end to ensure proper initialization
        self.init_weights()
 
-    @classmethod
-    def from_pretrained(cls, model_id, cache_dir=None, **kwargs):
-        """
-        Downloads the pretrained weights from Hugging Face, and loads the GPT-2 tokenizer.
-        """
-        # Set cache directory for storing models
-        cache_dir = cache_dir or os.path.join(os.getcwd(), "lumenspark_weights")
-
-        # Download model weights in `.safetensors` format
-        weight_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
-
-        # Load the configuration
-        config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
-        config = LumensparkConfig.from_json_file(config_path)
-
-        # Load GPT-2 tokenizer directly from Hugging Face
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-
-        # Instantiate the model
-        model = cls(config, tokenizer=tokenizer)
-
-        # Load state_dict from safetensors file
-        with safe_open(weight_path, framework="pt") as f:
-            state_dict = {k: f.get_tensor(k) for k in f.keys()}
-
-        model.load_state_dict(state_dict)
-
-        return model
-
-    @staticmethod
-    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
-        """
-        Filter a distribution of logits using top-k and/or top-p filtering.
-        """
-        top_k = min(top_k, logits.size(-1))
-        if top_k > 0:
-            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-            logits[indices_to_remove] = filter_value
-
-        if top_p < 1.0:
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-            sorted_indices_to_remove[..., 0] = 0
-
-            indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            logits[:, indices_to_remove] = filter_value
-        return logits
-
-    def generate(self, text, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
-        """
-        Text generation method that handles auto-regressive generation with repetition penalty.
-        The input is a string, and the output is a string generated by the model.
-        """
-        self.eval()  # Set model to evaluation mode
-        # Tokenize input text using GPT-2 tokenizer
-        input_ids = torch.tensor([self.tokenizer.encode(text)], dtype=torch.long).to(self.device)
-
-        # Initialize attention mask
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        generated_tokens = input_ids
-
-        for _ in range(max_length - input_ids.size(1)):
-            outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
-            logits = outputs["logits"][:, -1, :]
-
-            # Adjust temperature for randomness
-            logits = logits / temperature
-
-            # Apply repetition penalty: reduce logits of tokens that have already been generated
-            for token in set(generated_tokens.view(-1).tolist()):
-                logits[:, token] /= repetition_penalty  # Penalize repeated tokens
-
-            # Apply top-k and top-p sampling to select the next token
-            if do_sample:
-                filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
-                probs = torch.softmax(filtered_logits, dim=-1)
-                next_token = torch.multinomial(probs, num_samples=1)
-            else:
-                next_token = torch.argmax(logits, dim=-1, keepdim=True)
-
-            # Append the generated token
-            generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
-            # Update attention mask
-            attention_mask = torch.ones_like(generated_tokens).to(self.device)
-
-            # Prevent early stopping by ensuring min_length is reached before allowing EOS
-            if next_token.item() == self.tokenizer.eos_token_id and generated_tokens.size(1) < min_length:
-                continue  # Skip EOS if output is too short
-
-            # Stop if the EOS token is generated and minimum length is reached
-            if next_token.item() == self.tokenizer.eos_token_id:
-                break
-
-        # Decode the generated tokens back to text
-        generated_text = self.tokenizer.decode(generated_tokens[0].tolist())
-
-        return generated_text
-
     def forward(self, input_ids, attention_mask=None, labels=None):
-        """
-        Forward pass of the model. If labels are provided, the loss is also computed.
-        """
         batch_size, seq_length = input_ids.size()
 
-        # Generate position ids
         position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
         position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
 
-        # Embed tokens and positions
         token_embeddings = self.token_embedding(input_ids)
         position_embeddings = self.position_embedding(position_ids)
-
-        # Combine token and position embeddings
         embeddings = token_embeddings + position_embeddings
         embeddings = self.dropout(embeddings)
 
-
-        device = embeddings.device
-        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)
+        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=embeddings.device)).unsqueeze(0).unsqueeze(0)
 
-        # Combine with attention mask if provided
         if attention_mask is not None:
-            # Expand attention_mask to match dimensions
             attention_mask = attention_mask[:, None, None, :].float()
             combined_mask = attention_mask * causal_mask
         else:
             combined_mask = causal_mask
 
-        # Pass through each transformer layer with prenormalization and LayerScale
         for layer in self.layers:
-            # Prenormalization before self-attention
             embeddings_norm = layer["norm1"](embeddings)
             attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
-            # Apply LayerScale for attention output
             embeddings = embeddings + layer.layer_scale_attn * attn_output
 
-            # Prenormalization before feed-forward network
             embeddings_norm = layer["norm2"](embeddings)
             ffn_output = layer["ffn"](embeddings_norm)
-            # Apply LayerScale for feed-forward output
             embeddings = embeddings + layer.layer_scale_ffn * ffn_output
 
-        # Apply final LayerNorm before output
         embeddings = self.final_norm(embeddings)
-
-        # Compute logits (unnormalized scores)
         logits = self.fc_out(embeddings)
 
-        # Compute loss if labels are provided
         loss = None
         if labels is not None:
             shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
             shift_labels = labels[:, 1:].contiguous().view(-1)
-
-            # Base cross-entropy loss
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits, shift_labels)
 
-        return …
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits
+        )
 
 AutoConfig.register("lumenspark", LumensparkConfig)
 AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
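Since this commit drops the custom from_pretrained() and generate() helpers and has forward() return a standard CausalLMOutputWithCrossAttentions, the model is now meant to be loaded and driven through the regular transformers entry points. The sketch below shows what that usage might look like after the change; it assumes the lumenspark package has been imported so that the AutoConfig/AutoModelForCausalLM registration at the end of this file has run, the repo id is a placeholder, and the GPT-2 tokenizer is created separately because the model no longer stores one.

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# Placeholder repo id; loading via AutoModelForCausalLM relies on the
# AutoConfig.register / AutoModelForCausalLM.register calls above having run.
model = AutoModelForCausalLM.from_pretrained("your-namespace/lumenspark")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # tokenizer is no longer bundled with the model
model.eval()

input_ids = tokenizer("Low-rank attention lets the model", return_tensors="pt").input_ids

# Greedy decoding driven from outside the model, since the old generate() method is gone;
# the returned output object exposes .logits where the old code indexed outputs["logits"].
with torch.no_grad():
    for _ in range(20):
        out = model(input_ids=input_ids)
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)

print(tokenizer.decode(input_ids[0]))

For training, the same output object carries the loss when labels are passed (for example model(input_ids=batch, labels=batch).loss), which is the convention the generic transformers training utilities expect.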