from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .configuration_lumenspark import LumensparkConfig
from torch import nn
import torch
import math
# ----------------------------
# Low-Rank Linear Layer Implementation
# ----------------------------
class LowRankLinear(nn.Module):
"""
    A low-rank linear layer that factorizes a standard linear layer into two smaller
    projections (in_features -> rank -> out_features), reducing the parameter count
    and compute when the rank is small relative to the layer dimensions.
"""
def __init__(self, in_features, out_features, rank, init_std=0.02):
super().__init__()
self.U = nn.Linear(in_features, rank, bias=False)
self.V = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(self.U.weight, std=init_std)
nn.init.normal_(self.V.weight, std=init_std)
def forward(self, x):
"""
        Forward pass through the two factor projections: x -> U(x) -> V(U(x)).
"""
return self.V(self.U(x))
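
# A rough parameter-count illustration for LowRankLinear (hypothetical sizes, not
# taken from the model's configuration): a 512 -> 2048 projection at rank 64 stores
# 512*64 + 64*2048 = 163,840 weights, versus 512*2048 = 1,048,576 for a dense
# nn.Linear, at the cost of constraining the effective weight matrix to rank <= 64.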
# ----------------------------
# Lumenspark Self-Attention Implementation
# ----------------------------
class LumensparkSelfAttention(nn.Module):
"""
    Multi-head self-attention for the Lumenspark model with a numerically stable softmax.
    The query, key and value projections are standard linear layers; the model's
    low-rank factorization (LowRankLinear) is applied in the feed-forward blocks.
"""
def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
super().__init__()
assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'
self.num_heads = num_heads
self.embed_dim = embed_dim
self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads
        # Query, key and value projections (standard nn.Linear layers)
self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
self.dropout_layer = nn.Dropout(dropout)
self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)
def stable_softmax(self, x, dim=-1):
        # Subtract the row-wise max before exponentiating to avoid overflow;
        # the small epsilon keeps the denominator non-zero for fully masked rows.
x_max = torch.max(x, dim=dim, keepdim=True)[0]
exp_x = torch.exp(x - x_max)
return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)
def forward(self, inputs, attention_mask=None):
batch_size, seq_len, _ = inputs.shape
q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if attention_mask is not None:
            # Use the dtype's minimum instead of -inf so fully masked rows do not produce NaNs
            attention_scores = attention_scores.masked_fill(attention_mask == 0, torch.finfo(attention_scores.dtype).min)
attention_weights = self.stable_softmax(attention_scores, dim=-1)
attention_weights = self.dropout_layer(attention_weights)
attention_output = torch.matmul(attention_weights, v)
attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
return self.output_transform(attention_output)
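
# A minimal shape sketch for LumensparkSelfAttention (illustrative sizes, not the
# model's actual hyperparameters); kept as a comment so it does not run at import time:
#
#   attn = LumensparkSelfAttention(embed_dim=512, num_heads=8, dropout=0.1)
#   x = torch.randn(2, 128, 512)               # (batch, seq_len, embed_dim)
#   causal = torch.tril(torch.ones(128, 128))  # broadcasts over batch and heads
#   out = attn(x, attention_mask=causal)       # -> (2, 128, 512)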
# ----------------------------
# Define Lumenspark Model Wrapper
# ----------------------------
class LumensparkModel(PreTrainedModel):
config_class = LumensparkConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# Token and position embeddings
self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)
        # Lumenspark transformer decoder layers with pre-normalization and LayerScale
self.layers = nn.ModuleList()
for _ in range(config.depth):
layer = nn.ModuleDict({
"norm1": nn.LayerNorm(config.embed_dim),
"attn": LumensparkSelfAttention(
embed_dim=config.embed_dim,
num_heads=config.heads,
head_dim=config.embed_dim // config.heads,
dropout=config.dropout
),
"norm2": nn.LayerNorm(config.embed_dim),
"ffn": nn.Sequential(
LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
nn.GELU(),
nn.Dropout(config.dropout),
LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
nn.Dropout(config.dropout)
),
})
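            # LayerScale: learnable per-channel gains on the residual branches,
            # initialized to a small value (1e-2) so each block starts close to identity.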
layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
self.layers.append(layer)
self.final_norm = nn.LayerNorm(config.embed_dim)
self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
self.dropout = nn.Dropout(config.dropout)
# Call init_weights at the end to ensure proper initialization
self.init_weights()
def forward(self, input_ids, attention_mask=None, labels=None):
batch_size, seq_length = input_ids.size()
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
token_embeddings = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = token_embeddings + position_embeddings
embeddings = self.dropout(embeddings)
causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=embeddings.device)).unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :].float()
combined_mask = attention_mask * causal_mask
else:
combined_mask = causal_mask
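        # combined_mask broadcasts against the (batch, heads, seq_len, seq_len) attention
        # scores inside each attention layer; zero entries mark disallowed positions.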
for layer in self.layers:
embeddings_norm = layer["norm1"](embeddings)
attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
embeddings = embeddings + layer.layer_scale_attn * attn_output
embeddings_norm = layer["norm2"](embeddings)
ffn_output = layer["ffn"](embeddings_norm)
embeddings = embeddings + layer.layer_scale_ffn * ffn_output
embeddings = self.final_norm(embeddings)
logits = self.fc_out(embeddings)
loss = None
if labels is not None:
shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
shift_labels = labels[:, 1:].contiguous().view(-1)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits
)
AutoConfig.register("lumenspark", LumensparkConfig)
AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
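
# ----------------------------
# Usage Sketch
# ----------------------------
# A minimal smoke test, assuming LumensparkConfig accepts the keyword arguments used
# below (vocab_size, embed_dim, seq_length, depth, heads, dropout, rank); the values
# are illustrative only and not a published configuration.
if __name__ == "__main__":
    config = LumensparkConfig(
        vocab_size=50257,
        embed_dim=512,
        seq_length=256,
        depth=4,
        heads=8,
        dropout=0.1,
        rank=64,
    )
    model = LumensparkModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 32))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    print(outputs.logits.shape)  # expected: torch.Size([2, 32, 50257])
    print(outputs.loss)          # scalar language-modeling loss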