from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .configuration_lumenspark import LumensparkConfig
from torch import nn
import torch
import math


# ----------------------------
# Low-Rank Linear Layer Implementation
# ----------------------------

class LowRankLinear(nn.Module):
    """
    A low-rank linear layer that factorizes a standard linear layer into two smaller ones.
    This reduces the parameter count from in_features * out_features to
    rank * (in_features + out_features) and speeds up the matrix multiplications.
    """
    def __init__(self, in_features, out_features, rank, init_std=0.02):
        super().__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.U.weight, std=init_std)
        nn.init.normal_(self.V.weight, std=init_std)

    def forward(self, x):
        """
        Forward pass through the two low-rank factors (U then V).
        """
        return self.V(self.U(x))


# ----------------------------
# Lumenspark Self-Attention Implementation
# ----------------------------

class LumensparkSelfAttention(nn.Module):
    """
    Custom multi-head self-attention mechanism for the Lumenspark model.
    Queries, keys and values use standard linear projections, and a numerically
    stable softmax is applied to the masked attention scores.
    """
    def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
        super().__init__()
        assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads

        # Query, key and value projections (standard linear layers)
        self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)

        self.dropout_layer = nn.Dropout(dropout)
        self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)

    def stable_softmax(self, x, dim=-1):
        # Subtract the row-wise max for numerical stability before exponentiating
        x_max = torch.max(x, dim=dim, keepdim=True)[0]
        exp_x = torch.exp(x - x_max)
        return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)

    def forward(self, inputs, attention_mask=None):
        batch_size, seq_len, _ = inputs.shape

        # Project to (batch, heads, seq_len, head_dim)
        q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Mask out disallowed positions (causal and/or padding) before the softmax
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))

        attention_weights = self.stable_softmax(attention_scores, dim=-1)
        attention_weights = self.dropout_layer(attention_weights)

        # Weighted sum of values, then merge heads back into the embedding dimension
        attention_output = torch.matmul(attention_weights, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        return self.output_transform(attention_output)


# ----------------------------
# Define Lumenspark Model Wrapper
# ----------------------------

class LumensparkModel(PreTrainedModel):
    config_class = LumensparkConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and position embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)

        # Lumenspark transformer decoder layers with pre-normalization and LayerScale
        self.layers = nn.ModuleList()
        for _ in range(config.depth):
            layer = nn.ModuleDict({
                "norm1": nn.LayerNorm(config.embed_dim),
                "attn": LumensparkSelfAttention(
                    embed_dim=config.embed_dim,
                    num_heads=config.heads,
                    head_dim=config.embed_dim // config.heads,
                    dropout=config.dropout
                ),
                "norm2": nn.LayerNorm(config.embed_dim),
                "ffn": nn.Sequential(
                    LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
                    nn.GELU(),
                    nn.Dropout(config.dropout),
                    LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
                    nn.Dropout(config.dropout)
                ),
            })
            # LayerScale gains, registered as parameters directly on the ModuleDict
            layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            self.layers.append(layer)

        self.final_norm = nn.LayerNorm(config.embed_dim)
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

        # Call init_weights at the end to ensure proper initialization
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        batch_size, seq_length = input_ids.size()

        # Learned absolute position ids, broadcast across the batch
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)

        embeddings = token_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)

        # Lower-triangular causal mask, optionally combined with the padding mask
        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=embeddings.device)).unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].float()
            combined_mask = attention_mask * causal_mask
        else:
            combined_mask = causal_mask

        # Pre-norm residual blocks with LayerScale applied to both sub-layers
        for layer in self.layers:
            embeddings_norm = layer["norm1"](embeddings)
            attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
            embeddings = embeddings + layer.layer_scale_attn * attn_output

            embeddings_norm = layer["norm2"](embeddings)
            ffn_output = layer["ffn"](embeddings_norm)
            embeddings = embeddings + layer.layer_scale_ffn * ffn_output

        embeddings = self.final_norm(embeddings)
        logits = self.fc_out(embeddings)

        # Shift logits and labels by one position for the next-token prediction loss
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits
        )


# Register the custom config and model with the Auto classes
AutoConfig.register("lumenspark", LumensparkConfig)
AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
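

# ----------------------------
# Usage Sketch (illustrative)
# ----------------------------
# A minimal smoke test, not part of the model definition. It assumes LumensparkConfig
# accepts the field names the model reads (vocab_size, embed_dim, seq_length, depth,
# heads, dropout, rank) as keyword arguments; the values below are placeholders, not a
# published configuration. Because this module uses a relative import, run it as a
# module from the package root (e.g. `python -m <package>.modeling_lumenspark`).
if __name__ == "__main__":
    config = LumensparkConfig(
        vocab_size=1000,   # placeholder vocabulary size
        embed_dim=256,
        seq_length=128,
        depth=2,
        heads=4,
        dropout=0.1,
        rank=32,
    )
    model = LumensparkModel(config)
    model.eval()  # disable dropout for a deterministic check

    # Dummy batch: random token ids with a full attention mask
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    # Passing labels=input_ids also exercises the shifted next-token loss
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)

    print(outputs.logits.shape)  # expected: torch.Size([2, 16, 1000])
    print(outputs.loss)          # scalar cross-entropy loss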