import math

import torch
from torch import nn
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .configuration_lumenspark import LumensparkConfig


class LowRankLinear(nn.Module):
    """
    A low-rank linear layer that factorizes a standard linear layer into two smaller ones.
    This allows for reduced parameter count and faster computation.
    """
    def __init__(self, in_features, out_features, rank, init_std=0.02):
        super().__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.U.weight, std=init_std)
        nn.init.normal_(self.V.weight, std=init_std)

    def forward(self, x):
        """
        Forward pass through two low-rank linear layers (U and V).
        """
        return self.V(self.U(x))
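

# Sizing note (illustrative figures, not taken from the model config): a dense
# nn.Linear(768, 3072) stores 768 * 3072 = 2,359,296 weights, whereas
# LowRankLinear(768, 3072, rank=64) stores 768 * 64 + 64 * 3072 = 245,760,
# nearly a 10x reduction at that rank.

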
class LumensparkSelfAttention(nn.Module):
    """
    Multi-head self-attention for the Lumenspark model.
    Queries, keys and values use standard linear projections together with a numerically
    stable softmax; the low-rank factorization is applied in the model's feed-forward blocks.
    """
    def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
        super().__init__()
        assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads

        # Query, key and value projections for all heads.
        self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)

        self.dropout_layer = nn.Dropout(dropout)
        self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)

    def stable_softmax(self, x, dim=-1):
        """
        Numerically stable softmax: subtract the per-row maximum before exponentiating
        and add a small epsilon to the denominator to avoid division by zero.
        """
        x_max = torch.max(x, dim=dim, keepdim=True)[0]
        exp_x = torch.exp(x - x_max)
        return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)

    def forward(self, inputs, attention_mask=None):
        batch_size, seq_len, _ = inputs.shape

        # Project and reshape to (batch, heads, seq, head_dim).
        q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores of shape (batch, heads, seq, seq).
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Use the dtype's most negative finite value rather than -inf so that rows
            # that end up fully masked (e.g. padded positions) do not produce NaNs.
            attention_scores = attention_scores.masked_fill(attention_mask == 0, torch.finfo(attention_scores.dtype).min)

        attention_weights = self.stable_softmax(attention_scores, dim=-1)
        attention_weights = self.dropout_layer(attention_weights)

        # Weighted sum of values, then merge heads back to (batch, seq, heads * head_dim).
        attention_output = torch.matmul(attention_weights, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)
        return self.output_transform(attention_output)
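

# Shape walk-through (illustrative figures, not taken from the config): with
# batch_size=2, seq_len=16, embed_dim=512 and num_heads=8, q, k and v each have shape
# (2, 8, 16, 64), attention_scores has shape (2, 8, 16, 16), and the weighted values
# are merged back to (2, 16, 512) before output_transform projects them to embed_dim.

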
class LumensparkModel(PreTrainedModel):
    """
    Decoder-only Lumenspark transformer: token and learned position embeddings feed a
    stack of pre-norm blocks (self-attention and low-rank feed-forward, each scaled by a
    learnable LayerScale vector), followed by a final LayerNorm and a linear head over
    the vocabulary.
    """
    config_class = LumensparkConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and learned absolute position embeddings.
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)

        # Transformer blocks: pre-norm attention and low-rank feed-forward sub-layers.
        self.layers = nn.ModuleList()
        for _ in range(config.depth):
            layer = nn.ModuleDict({
                "norm1": nn.LayerNorm(config.embed_dim),
                "attn": LumensparkSelfAttention(
                    embed_dim=config.embed_dim,
                    num_heads=config.heads,
                    head_dim=config.embed_dim // config.heads,
                    dropout=config.dropout
                ),
                "norm2": nn.LayerNorm(config.embed_dim),
                "ffn": nn.Sequential(
                    LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
                    nn.GELU(),
                    nn.Dropout(config.dropout),
                    LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
                    nn.Dropout(config.dropout)
                ),
            })
            # Per-channel LayerScale parameters for the two residual branches.
            layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            self.layers.append(layer)

        self.final_norm = nn.LayerNorm(config.embed_dim)
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        batch_size, seq_length = input_ids.size()

        # Absolute position ids, shared across the batch.
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)
        embeddings = token_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)

        # Causal mask of shape (1, 1, seq, seq), combined with the padding mask if given.
        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=embeddings.device)).unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].float()
            combined_mask = attention_mask * causal_mask
        else:
            combined_mask = causal_mask

        # Pre-norm residual blocks with LayerScale on both sub-layers.
        for layer in self.layers:
            embeddings_norm = layer["norm1"](embeddings)
            attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
            embeddings = embeddings + layer.layer_scale_attn * attn_output

            embeddings_norm = layer["norm2"](embeddings)
            ffn_output = layer["ffn"](embeddings_norm)
            embeddings = embeddings + layer.layer_scale_ffn * ffn_output

        embeddings = self.final_norm(embeddings)
        logits = self.fc_out(embeddings)

        loss = None
        if labels is not None:
            # Next-token objective: position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits
        )
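

# Registering the config and model with the Auto classes lets downstream code build the
# model through the generic factory API (e.g. AutoModelForCausalLM.from_config(...)) and
# lets AutoModelForCausalLM.from_pretrained(...) resolve checkpoints whose config
# declares the "lumenspark" model type.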
AutoConfig.register("lumenspark", LumensparkConfig)
AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
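

# Minimal smoke test (a sketch, not part of the library API). The config values below
# are illustrative assumptions; LumensparkConfig is assumed to accept these keyword
# arguments because the model reads exactly these attributes. Because of the relative
# import above, run this with `python -m` from the package root.
if __name__ == "__main__":
    config = LumensparkConfig(
        vocab_size=1000,
        embed_dim=128,
        seq_length=64,
        depth=2,
        heads=4,
        dropout=0.1,
        rank=16,
    )
    model = LumensparkModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    attention_mask = torch.ones(2, 10, dtype=torch.long)

    # Forward pass with labels reused as targets to exercise the loss computation.
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print(outputs.logits.shape)  # expected: torch.Size([2, 10, 1000])
    print(outputs.loss)          # scalar next-token cross-entropy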