import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import math
from einops import einsum
from tqdm import tqdm
from einops.layers.torch import Rearrange
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


class ProductKeyMemory(nn.Module):
    # Standalone product-key lookup (kept for reference; the PEER module below
    # implements its own product-key retrieval and does not use this class).
    def __init__(self, dim, num_keys):
        super().__init__()
        self.dim = dim
        self.num_keys = num_keys
        self.keys = nn.Parameter(torch.randn(num_keys, dim // 2))

    def forward(self, query):
        query = query.view(query.shape[0], 2, -1)
        dots = torch.einsum('bkd,nd->bkn', query, self.keys)
        return dots.view(query.shape[0], -1)


class PEER(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads=8,
        num_experts=1_000_000,
        num_experts_per_head=16,
        activation=nn.GELU,
        dim_key=None,
        product_key_topk=None,
        separate_embed_per_head=False,
        pre_rmsnorm=False,
        dropout=0.
    ):
        super().__init__()
        self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        self.heads = heads
        self.separate_embed_per_head = separate_embed_per_head
        self.num_experts = num_experts

        num_expert_sets = heads if separate_embed_per_head else 1

        # Each expert is a single-neuron MLP: one down-projection row and one
        # up-projection row, stored as embedding tables indexed by expert id.
        self.weight_down_embed = nn.Embedding(num_experts * num_expert_sets, dim)
        self.weight_up_embed = nn.Embedding(num_experts * num_expert_sets, dim)

        self.activation = activation()

        assert (num_experts ** 0.5).is_integer(), '`num_experts` needs to be a square'
        assert (dim % 2) == 0, 'feature dimension should be divisible by 2'

        dim_key = default(dim_key, dim // 2)
        self.num_keys = int(num_experts ** 0.5)

        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim_key * heads * 2, bias=False),
            Rearrange('b n (p h d) -> p b n h d', p=2, h=heads)
        )

        self.product_key_topk = default(product_key_topk, num_experts_per_head)
        self.num_experts_per_head = num_experts_per_head

        self.keys = nn.Parameter(torch.randn(heads, self.num_keys, 2, dim_key))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.norm(x)

        queries = self.to_queries(x)

        # Similarity of each half-query against its sub-key set: (2, b, n, h, num_keys)
        sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k')

        # topk returns (values, indices) per sub-key set, so unpack score/index pairs
        (scores_x, indices_x), (scores_y, indices_y) = [s.topk(self.product_key_topk, dim=-1) for s in sim]

        # Combine the two sub-key top-k's into candidate expert scores and indices
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)

        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        scores, pk_indices = all_scores.topk(self.num_experts_per_head, dim=-1)
        indices = all_indices.gather(-1, pk_indices)

        if self.separate_embed_per_head:
            head_expert_offsets = torch.arange(self.heads, device=x.device) * self.num_experts
            indices = indices + head_expert_offsets.view(1, 1, -1, 1)

        # Look up expert weights by expert index (not by position within the candidate top-k)
        weights_down = self.weight_down_embed(indices)
        weights_up = self.weight_up_embed(indices)

        x = einsum(x, weights_down, 'b n d, b n h k d -> b n h k')

        x = self.activation(x)
        x = self.dropout(x)

        x = x * F.softmax(scores, dim=-1)

        x = einsum(x, weights_up, 'b n h k, b n h k d -> b n d')

        return x
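
# Illustrative sanity check (not called anywhere in this script): a minimal shape
# check for the PEER layer, assuming a small config so it runs quickly on CPU.
def _peer_shape_check():
    peer = PEER(dim=64, heads=4, num_experts=64 * 64, num_experts_per_head=8, pre_rmsnorm=True)
    x = torch.randn(2, 16, 64)       # (batch, seq, dim)
    out = peer(x)
    assert out.shape == x.shape      # PEER is shape-preserving: (b, n, d) -> (b, n, d)
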
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, num_experts, num_experts_per_head, dropout=0.1):
        super().__init__()
        # batch_first=True so attention accepts (batch, seq, dim), matching the embedding output
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.peer1 = PEER(dim, heads=num_heads, num_experts=num_experts, num_experts_per_head=num_experts_per_head)
        self.peer2 = PEER(dim, heads=num_heads, num_experts=num_experts, num_experts_per_head=num_experts_per_head)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Causal mask so each position only attends to itself and earlier positions
        seq_len = x.shape[1]
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1)
        attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        peer_output1 = self.peer1(x)
        peer_output2 = self.peer2(F.gelu(peer_output1))
        x = x + self.dropout(peer_output2)
        x = self.norm2(x)
        return x


class PEERLanguageModel(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, num_heads, num_experts, top_k):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(512, dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim, num_heads, num_experts, top_k)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        b, s = x.shape
        positions = torch.arange(s, device=x.device).unsqueeze(0).expand(b, s)
        x = self.token_embedding(x) + self.position_embedding(positions)

        for layer in self.layers:
            x = layer(x)

        x = self.layer_norm(x)
        logits = self.lm_head(x)
        return logits


class PileDataset(Dataset):
    # Despite the name, this loads WikiText-103 (raw) from the Hugging Face hub.
    def __init__(self, file_path, tokenizer, split='train', max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = load_dataset(file_path, "wikitext-103-raw-v1", split=split)
        self.data = self.data.filter(lambda x: len(x['text']) > 0)
        if split == "train":
            self.data = self.data.select(range(0, 300000))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['text']
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return encoding['input_ids'].squeeze(), encoding['attention_mask'].squeeze()


def train(model, train_loader, optimizer, device, pad_token_id):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, disable=dist.get_rank() != 0):
        input_ids, attention_mask = batch
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

        optimizer.zero_grad()

        # Shift the input_ids and attention_mask to create next-token targets
        targets = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        attention_mask = attention_mask[:, :-1].contiguous()

        outputs = model(input_ids)

        # Reshape outputs and targets for loss calculation
        outputs = outputs.view(-1, outputs.size(-1))
        targets = targets.view(-1)

        # Calculate loss, ignoring padding positions (GPT-2 pads with the EOS token here,
        # so its id -- not 0 -- is the one to ignore)
        loss = F.cross_entropy(outputs, targets, ignore_index=pad_token_id)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(train_loader)


def validate(model, val_loader, device, pad_token_id):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, disable=dist.get_rank() != 0):
            input_ids, attention_mask = batch
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

            # Shift targets the same way as in training
            targets = input_ids[:, 1:].contiguous()
            input_ids = input_ids[:, :-1].contiguous()

            outputs = model(input_ids)
            loss = F.cross_entropy(
                outputs.view(-1, outputs.size(-1)),
                targets.view(-1),
                ignore_index=pad_token_id
            )
            total_loss += loss.item()
    return total_loss / len(val_loader)
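
# Illustrative helper (not called in the training loop below): with a mean token-level
# cross-entropy loss in nats, perplexity is simply its exponential, e.g. a validation
# loss of 4.0 corresponds to a perplexity of about 54.6.
def perplexity(mean_ce_loss):
    return math.exp(mean_ce_loss)
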
# main execution
if __name__ == "__main__":
    # Initialize distributed environment
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Hyperparameters
    vocab_size = 50257  # GPT-2 tokenizer vocab size
    dim = 256
    num_layers = 8
    num_heads = 8
    num_experts = 512 * 512  # 262,144 experts (512 keys per sub-key set)
    top_k = 16
    batch_size = 6
    num_epochs = 10
    learning_rate = 1e-4

    # Initialize tokenizer and model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    model = PEERLanguageModel(vocab_size, dim, num_layers, num_heads, num_experts, top_k).to(device)

    # Wrap the model with DistributedDataParallel
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # Load WikiText-103 dataset
    train_dataset = PileDataset('Salesforce/wikitext', tokenizer, split='train')
    val_dataset = PileDataset('Salesforce/wikitext', tokenizer, split='validation')

    # Use DistributedSampler for the training data
    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if local_rank == 0:
        print("Number of parameters:", sum(p.numel() for p in model.parameters()))

    # Training and validation loop
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)

        if local_rank == 0:
            print(f"Epoch Training {epoch+1}/{num_epochs}")
        train_loss = train(model, train_loader, optimizer, device, tokenizer.pad_token_id)

        if local_rank == 0:
            print(f"Epoch Validation {epoch+1}/{num_epochs}")
        val_loss = validate(model, val_loader, device, tokenizer.pad_token_id)

        if local_rank == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # Save the best model (rank 0 only, unwrapping the DDP module)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.module.state_dict(), 'best_peer_language_model.pth')

    # Save the final trained model
    if local_rank == 0:
        torch.save(model.module.state_dict(), 'final_peer_language_model.pth')

    # Clean up
    dist.destroy_process_group()
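
# Example launch (assumes 4 GPUs on a single node and that this script is saved as
# train_peer.py -- adjust both to your setup):
#
#   torchrun --nproc_per_node=4 train_peer.py
#
# torchrun sets LOCAL_RANK for each process, which the main block above reads.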