anto18671 committed on

Commit 4306d2e · verified · 1 Parent(s): daa7d1e

Upload 3 files

config.json CHANGED
@@ -3,6 +3,10 @@
   "architectures": ["LumensparkModel"],
   "vocab_size": 50257,
   "embed_dim": 768,
+  "auto_map": {
+    "AutoConfig": "anto18671/lumenspark--configuration_lumenspark.LumensparkConfig",
+    "AutoModelForCausalLM": "anto18671/lumenspark--modeling_lumenspark.LumensparkModel"
+  },
   "tokenizer_class": "GPT2Tokenizer",
   "depth": 8,
   "heads": 12,
configuration_lumenspark.py ADDED
@@ -0,0 +1,52 @@
+ from transformers import PretrainedConfig
+
+ # ----------------------------
+ # Define Lumenspark Configuration
+ # ----------------------------
+
+ class LumensparkConfig(PretrainedConfig):
+     """
+     Configuration class for the Lumenspark model.
+     Stores model hyperparameters like sequence length, embedding dimension, number of layers, and others.
+     """
+     model_type = "lumenspark"
+
+     def __init__(
+         self,
+         seq_length=768,
+         vocab_size=50257,
+         embed_dim=768,
+         depth=8,
+         heads=12,
+         dropout=1/17,
+         k=384,
+         rank=256,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.heads = heads
+         self.seq_length = seq_length
+         self.dropout = dropout
+         self.k = k
+         self.rank = rank
+
+     def to_dict(self):
+         """
+         Converts the configuration parameters to a dictionary format.
+         Useful for saving the configuration or inspecting model settings.
+         """
+         output = super().to_dict()
+         output.update({
+             "vocab_size": self.vocab_size,
+             "embed_dim": self.embed_dim,
+             "depth": self.depth,
+             "heads": self.heads,
+             "seq_length": self.seq_length,
+             "dropout": self.dropout,
+             "k": self.k,
+             "rank": self.rank,
+         })
+         return output
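For reference, the added configuration class can be exercised on its own. A minimal sketch, assuming configuration_lumenspark.py is importable from the working directory; the overridden values are arbitrary:

from configuration_lumenspark import LumensparkConfig

# Instantiate with defaults, overriding a couple of hyperparameters.
config = LumensparkConfig(depth=4, heads=8)
print(config.model_type)            # "lumenspark"
print(config.to_dict()["depth"])    # 4

# PretrainedConfig supplies JSON (de)serialization on top of to_dict().
config.save_pretrained("./lumenspark-config")
reloaded = LumensparkConfig.from_pretrained("./lumenspark-config")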
modeling_lumenspark.py ADDED
@@ -0,0 +1,226 @@
+ from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
+ from .configuration_lumenspark import LumensparkConfig
+ from torch import nn
+ import torch
+ import math
+
+ # ----------------------------
+ # Low-Rank Linear Layer Implementation
+ # ----------------------------
+
+ class LowRankLinear(nn.Module):
+     def __init__(self, in_features, out_features, rank, init_std=0.02):
+         super().__init__()
+         self.U = nn.Linear(in_features, rank, bias=False)
+         self.V = nn.Linear(rank, out_features, bias=False)
+         nn.init.normal_(self.U.weight, std=init_std)
+         nn.init.normal_(self.V.weight, std=init_std)
+
+     def forward(self, x):
+         return self.V(self.U(x))
+
+ # ----------------------------
+ # Lumenspark Self-Attention Implementation
+ # ----------------------------
+
+ class LumensparkSelfAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
+         super().__init__()
+         assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'
+
+         self.num_heads = num_heads
+         self.embed_dim = embed_dim
+         self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads
+
+         self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
+         self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
+         self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
+
+         self.dropout_layer = nn.Dropout(dropout)
+         self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)
+
+     def stable_softmax(self, x, dim=-1):
+         x_max = torch.max(x, dim=dim, keepdim=True)[0]
+         exp_x = torch.exp(x - x_max)
+         return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)
+
+     def forward(self, inputs, attention_mask=None):
+         batch_size, seq_len, _ = inputs.shape
+
+         q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+         if attention_mask is not None:
+             attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
+
+         attention_weights = self.stable_softmax(attention_scores, dim=-1)
+         attention_weights = self.dropout_layer(attention_weights)
+
+         attention_output = torch.matmul(attention_weights, v)
+         attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
+         return self.output_transform(attention_output)
+
+ # ----------------------------
+ # Define Lumenspark Model Wrapper
+ # ----------------------------
+
+ class LumensparkModel(PreTrainedModel):
+     config_class = LumensparkConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # Token and position embeddings
+         self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
+         self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)
+
+         # Lumenspark transformer encoder layers with prenormalization and LayerScale
+         self.layers = nn.ModuleList()
+         for _ in range(config.depth):
+             layer = nn.ModuleDict({
+                 "norm1": nn.LayerNorm(config.embed_dim),
+                 "attn": LumensparkSelfAttention(
+                     embed_dim=config.embed_dim,
+                     num_heads=config.heads,
+                     head_dim=config.embed_dim // config.heads,
+                     dropout=config.dropout
+                 ),
+                 "norm2": nn.LayerNorm(config.embed_dim),
+                 "ffn": nn.Sequential(
+                     LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
+                     nn.GELU(),
+                     nn.Dropout(config.dropout),
+                     LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
+                     nn.Dropout(config.dropout)
+                 ),
+             })
+             # Assign the LayerScale parameters directly as attributes
+             layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
+             layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
+             self.layers.append(layer)
+
+         # Final LayerNorm layer
+         self.final_norm = nn.LayerNorm(config.embed_dim)
+
+         # Feed-forward output layer
+         self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
+         self.dropout = nn.Dropout(config.dropout)
+
+         # Initialize model weights
+         self.init_weights()
+
+     @staticmethod
+     def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
+         """
+         Filter a distribution of logits using top-k and/or top-p filtering.
+         """
+         top_k = min(top_k, logits.size(-1))
+         if top_k > 0:
+             indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+             logits[indices_to_remove] = filter_value
+         if top_p < 1.0:
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+             cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+             sorted_indices_to_remove = cumulative_probs > top_p
+             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+             sorted_indices_to_remove[..., 0] = 0
+             indices_to_remove = sorted_indices[sorted_indices_to_remove]
+             logits[:, indices_to_remove] = filter_value
+         return logits
+
+     def generate(self, input_ids, attention_mask=None, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
+         """
+         Text generation method that handles auto-regressive generation with repetition penalty.
+         Input `input_ids` should be a tensor. Returns generated tokens.
+         """
+         self.eval()
+         device = input_ids.device
+         generated_tokens = input_ids
+
+         for _ in range(max_length - input_ids.size(1)):
+             # Forward pass for logits
+             outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
+             logits = outputs["logits"][:, -1, :]
+
+             # Adjust logits by temperature
+             logits = logits / temperature
+
+             # Apply repetition penalty by reducing logits of tokens already generated
+             for token in set(generated_tokens.view(-1).tolist()):
+                 logits[:, token] /= repetition_penalty
+
+             # Apply sampling with top-k and top-p
+             if do_sample:
+                 filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+                 probs = torch.softmax(filtered_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+             # Append the generated token
+             generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
+             attention_mask = torch.ones_like(generated_tokens).to(device)
+
+             # Ensure min_length before stopping generation with end-of-sequence (EOS) token
+             if next_token.item() == self.config.eos_token_id and generated_tokens.size(1) < min_length:
+                 continue
+             if next_token.item() == self.config.eos_token_id:
+                 break
+         return generated_tokens
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         """
+         Forward pass of the model. If `labels` are provided, computes the loss.
+         """
+         batch_size, seq_length = input_ids.size()
+
+         # Generate position ids for input tokens
+         position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
+         position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
+
+         # Embed tokens and positions
+         token_embeddings = self.token_embedding(input_ids)
+         position_embeddings = self.position_embedding(position_ids)
+
+         # Combine token and position embeddings
+         embeddings = token_embeddings + position_embeddings
+         embeddings = self.dropout(embeddings)
+
+         # Create causal mask for self-attention to ensure autoregressive behavior
+         device = embeddings.device
+         causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)
+
+         # Combine with attention mask if provided
+         combined_mask = causal_mask if attention_mask is None else attention_mask[:, None, None, :].float() * causal_mask
+
+         # Pass through transformer layers
+         for layer in self.layers:
+             embeddings_norm = layer["norm1"](embeddings)
+             attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
+             embeddings = embeddings + layer.layer_scale_attn * attn_output
+
+             embeddings_norm = layer["norm2"](embeddings)
+             ffn_output = layer["ffn"](embeddings_norm)
+             embeddings = embeddings + layer.layer_scale_ffn * ffn_output
+
+         # Final normalization and output projection to logits
+         embeddings = self.final_norm(embeddings)
+         logits = self.fc_out(embeddings)
+
+         # Compute loss if labels are provided
+         loss = None
+         if labels is not None:
+             shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
+             shift_labels = labels[:, 1:].contiguous().view(-1)
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits, shift_labels)
+
+         return {"loss": loss, "logits": logits}
+
+ # Register the Lumenspark config and model with the Auto classes
+ AutoConfig.register("lumenspark", LumensparkConfig)
+ AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
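Because the last two lines register the custom classes with the transformers Auto machinery at import time, "lumenspark" can be resolved like a built-in model type once this module has been imported (for example through a trust_remote_code load, since the relative import requires the files to live in a package). A minimal sketch with a freshly initialized, untrained model; the smaller depth and the dummy prompt are illustrative:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Resolves to LumensparkConfig / LumensparkModel via the register() calls above.
config = AutoConfig.for_model("lumenspark", depth=4)
model = AutoModelForCausalLM.from_config(config)

# Dummy prompt: one sequence of 16 random token ids inside the vocabulary.
input_ids = torch.randint(0, config.vocab_size, (1, 16))

# forward() returns a dict with "loss" (here computed against the inputs) and "logits".
out = model(input_ids=input_ids, labels=input_ids)
print(out["logits"].shape)   # torch.Size([1, 16, 50257])

# Greedy decoding with the custom generate() defined above.
generated = model.generate(input_ids, max_length=32, do_sample=False)
print(generated.shape)       # up to torch.Size([1, 32])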