import torch
import torch.nn as nn
import tqdm
from torch.nn import functional as F

from core.layers import Block


class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through a stack of ``n_layer`` ``Block`` modules, normalized with a final
    LayerNorm, and projected to ``vocab_size`` logits by a linear head.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name="GPT"):
        """Build the model.

        Args:
            vocab_size: Number of distinct token ids; sizes the token
                embedding table and the output projection.
            n_embd: Embedding width used throughout the model.
            block_size: Maximum context length; sizes the position embedding
                table and is passed to every ``Block``.
            n_head: Passed to each ``Block`` (presumably the number of
                attention heads — confirm against ``core.layers.Block``).
            n_layer: Number of stacked ``Block`` modules.
            dropout: Passed to each ``Block``.
            device: Stored as ``self.device``; not used inside this class.
            name: Human-readable model name stored as ``self.name``.
        """
        super().__init__()
        self.name = name
        self.block_size = block_size
        self.device = device
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Re-initialize every Linear/Embedding submodule (including those
        # inside each Block) with the scheme in _init_weights.
        self.apply(self._init_weights)
        # Not read or written elsewhere in this class; presumably filled by
        # an external training loop — confirm with callers.
        self.history = {}
        self.vocab_size = vocab_size

    def _init_weights(self, module):
        """Initialize one submodule in place (used via ``self.apply``).

        Linear and Embedding weights are drawn from N(0, 0.02**2); Linear
        biases are zeroed. Other module types keep their default init.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``, where
                ``T <= block_size`` and every id is in ``[0, vocab_size)``.
            targets: Optional integer tensor with ``B * T`` elements holding
                the target id for each position.

        Returns:
            ``(logits, loss)``. Without ``targets``, ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None``. With ``targets``,
            ``logits`` is flattened to ``(B * T, vocab_size)`` and ``loss`` is
            the mean cross-entropy as a scalar tensor.

        Raises:
            AssertionError: If an id is out of range or ``T > block_size``.
        """
        _, T = idx.shape
        # Validate ids up front: an out-of-range id (including a negative one)
        # would otherwise fail inside the embedding lookup, which on CUDA
        # surfaces only as an opaque device-side assert. torch.all is used
        # (rather than min/max) so an empty input still passes.
        assert torch.all((idx >= 0) & (idx < self.vocab_size)), f"Input indices must be in [0, vocab_size) (vocab_size={self.vocab_size})"
        assert T <= self.block_size, f"Input sequence length ({T}) must be <= block_size ({self.block_size})"
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb  # (T, n_embd) position table broadcasts over the batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            # reshape (not view) so non-contiguous targets, e.g. strided
            # slices of a larger batch, are accepted too.
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
        """Autoregressively sample tokens, yielding the running sequence.

        This is a generator: after each new token is appended it yields the
        full current sequence tensor of shape ``(B, L)``.

        Args:
            idx: Integer tensor of prompt ids with shape ``(B, T)``; must hold
                at least one token per row.
            max_new_tokens: Number of tokens to generate (number of yields).
            max_seq_length: Before each step, a running sequence longer than
                this is cut to its last ``max_seq_length`` tokens. The yielded
                tensor can therefore be up to ``max_seq_length + 1`` long, and
                earlier tokens are dropped from what is yielded.
            temperature: Logits are divided by this before sampling. ``0``
                selects greedy (argmax) decoding.

        Note:
            Dropout mode is not changed here; call ``model.eval()`` first if
            ``Block`` applies dropout and deterministic layers are wanted.
        """
        for _ in range(max_new_tokens):
            if idx.size(1) > max_seq_length:
                idx = idx[:, -max_seq_length:]
            # The model only accepts up to block_size positions.
            idx_cond = idx[:, -self.block_size:]
            # Inference only: skip autograd bookkeeping. The context wraps just
            # the forward pass so grad mode is not left disabled in the
            # caller's code while this generator is suspended at `yield`.
            with torch.no_grad():
                logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # logits for the last position only
            if temperature == 0:
                # Greedy decoding; dividing by zero would yield NaN probabilities.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            yield idx