import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional class RMSNorm(nn.Module): """ Root Mean Square Layer Normalization (RMSNorm). """ def __init__(self, hidden_size: int, eps: float = 1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) return self.weight * x class RotaryPositionalEmbedding(nn.Module): """ Rotary Positional Embedding (RoPE) for transformers. """ def __init__(self, dim: int, theta: float = 10000.0): super().__init__() self.dim = dim self.theta = theta def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor: """ Apply rotary positional embedding to the input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, num_heads, head_dim). seq_len (int): Sequence length. Returns: torch.Tensor: Output tensor with rotary positional embeddings applied. """ batch_size, seq_len, num_heads, head_dim = x.shape # Generate position indices position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(-1) # Generate frequencies freqs = torch.exp( torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device) * -(torch.log(torch.tensor(self.theta)) / head_dim) ) # Compute sinusoids sinusoid = position * freqs sin = torch.sin(sinusoid) cos = torch.cos(sinusoid) # Reshape sin and cos to match the input tensor's shape sin = sin.unsqueeze(0).unsqueeze(2) # Shape: (1, seq_len, 1, head_dim // 2) cos = cos.unsqueeze(0).unsqueeze(2) # Shape: (1, seq_len, 1, head_dim // 2) # Apply rotary embeddings x_rotated = x.clone() x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin return x_rotated from torch.utils.checkpoint import checkpoint class TransformerBlock(nn.Module): """ A single transformer block with self-attention and feed-forward layers. """ def __init__( self, hidden_size: int, num_attention_heads: int, intermediate_size: int, num_key_value_heads: int, rms_norm_eps: float, hidden_act: str = "silu", ): super().__init__() self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.head_dim = hidden_size // num_attention_heads # Ensure the hidden size is divisible by the number of attention heads if hidden_size % num_attention_heads != 0: raise ValueError( f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" ) # Self-attention layers self.q_proj = nn.Linear(hidden_size, hidden_size) self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim) self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(hidden_size, hidden_size) # Feed-forward layers self.gate_proj = nn.Linear(hidden_size, intermediate_size) self.up_proj = nn.Linear(hidden_size, intermediate_size) self.down_proj = nn.Linear(intermediate_size, hidden_size) # Normalization layers self.input_norm = RMSNorm(hidden_size, eps=rms_norm_eps) self.post_attention_norm = RMSNorm(hidden_size, eps=rms_norm_eps) # Activation function self.act = nn.SiLU() if hidden_act == "silu" else nn.GELU() # Rotary positional embedding self.rope = RotaryPositionalEmbedding(self.head_dim) def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: def create_custom_forward(module): def custom_forward(*inputs): return module._forward(inputs[0], inputs[1]) return custom_forward # Use gradient checkpointing return checkpoint(create_custom_forward(self), x, attention_mask) def _forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # Self-attention residual = x x = self.input_norm(x) # Project inputs to query, key, and value batch_size, seq_len, _ = x.shape # Reshape queries for multi-head attention q = self.q_proj(x).view(batch_size, seq_len, self.num_attention_heads, self.head_dim) # Reshape keys and values for key-value heads k = self.k_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim) v = self.v_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim) # Apply rotary positional embedding q = self.rope(q, seq_len) k = self.rope(k, seq_len) # Scaled dot-product attention attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask) attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size) attn_output = self.o_proj(attn_output) # Add residual connection x = residual + attn_output # Feed-forward network residual = x x = self.post_attention_norm(x) gate = self.act(self.gate_proj(x)) up = self.up_proj(x) ff_output = self.down_proj(gate * up) # Add residual connection x = residual + ff_output return x class TransformerModel(nn.Module): def __init__( self, vocab_size: int, hidden_size: int, num_hidden_layers: int, num_attention_heads: int, intermediate_size: int, num_key_value_heads: int, max_position_embeddings: int, rms_norm_eps: float, hidden_act: str = "silu", tie_word_embeddings: bool = True, ): super().__init__() self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.max_position_embeddings = max_position_embeddings # Embedding layers (skip quantization for these) self.embed_tokens = nn.Embedding(vocab_size, hidden_size) self.embed_positions = nn.Embedding(max_position_embeddings, hidden_size) # Transformer blocks self.layers = nn.ModuleList([ TransformerBlock( hidden_size=hidden_size, num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, num_key_value_heads=num_key_value_heads, rms_norm_eps=rms_norm_eps, hidden_act=hidden_act, ) for _ in range(num_hidden_layers) ]) # Final normalization layer self.final_norm = RMSNorm(hidden_size, eps=rms_norm_eps) # Output layer (tied to input embeddings if specified) self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) if tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # Embed tokens and positions seq_len = input_ids.size(1) position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) token_embeddings = self.embed_tokens(input_ids) position_embeddings = self.embed_positions(position_ids) x = token_embeddings + position_embeddings # Pass through transformer layers for layer in self.layers: x = layer(x, attention_mask) # Final normalization x = self.final_norm(x) # Output logits logits = self.lm_head(x) return logits def generate( self, input_ids: torch.Tensor, max_length: int = 50, temperature: float = 1.0, top_k: int = 50, do_sample: bool = True, ) -> torch.Tensor: """ Generate text autoregressively. Args: input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len). max_length (int): Maximum length of the generated sequence. temperature (float): Sampling temperature. Higher values mean more random sampling. top_k (int): Top-k sampling. Only the top-k tokens are considered. do_sample (bool): Whether to sample from the distribution or take the argmax. Returns: torch.Tensor: Generated token IDs of shape (batch_size, max_length). """ self.eval() with torch.no_grad(): for _ in range(max_length - input_ids.size(1)): # Get the logits for the last token logits = self(input_ids)[:, -1, :] # Apply temperature logits = logits / temperature # Top-k sampling if top_k > 0: top_k_values, top_k_indices = torch.topk(logits, top_k) logits[logits < top_k_values[:, -1].unsqueeze(-1)] = -float("Inf") # Convert logits to probabilities probs = F.softmax(logits, dim=-1) # Sample or take the argmax if do_sample: next_token = torch.multinomial(probs, num_samples=1) else: next_token = torch.argmax(probs, dim=-1, keepdim=True) # Append the next token to the input_ids input_ids = torch.cat([input_ids, next_token], dim=-1) return input_ids # Create the model based on the configuration def create_model_from_config(config: dict) -> TransformerModel: model_config = config["model"]["model_config"] return TransformerModel( vocab_size=model_config["vocab_size"], hidden_size=model_config["hidden_size"], num_hidden_layers=model_config["num_hidden_layers"], num_attention_heads=model_config["num_attention_heads"], intermediate_size=model_config["intermediate_size"], num_key_value_heads=model_config["num_key_value_heads"], max_position_embeddings=model_config["max_position_embeddings"], rms_norm_eps=model_config["rms_norm_eps"], hidden_act=model_config["hidden_act"], tie_word_embeddings=model_config["tie_word_embeddings"], ) # Example usage if __name__ == "__main__": import json # Load the configuration file with open("config_smollm2_135M.json", "r") as f: config = json.load(f) # Create the model model = create_model_from_config(config) print(model)