from dataclasses import dataclass
from math import sqrt

import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from torch import cat, ones
from torch.nn import (
    GELU,
    Dropout,
    Embedding,
    LayerNorm,
    Linear,
    Module,
    ModuleDict,
    ModuleList,
    functional,
)
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.init import normal_, zeros_


@dataclass
class NBAConfig:
    players_per_team: int = None
    player_tokens: int = None
    age_tokens: int = None
    n_layer: int = None
    n_head: int = None
    n_embd: int = None
    dropout: float = None
    seed: int = None
    bias: bool = None
    dtype: type = None
    num_labels: int = None


class SelfAttention(Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        block_size = config.players_per_team * 2 + 1
        # combined query/key/value projection, followed by output projection
        self.c_attn = Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=config.dtype)
        self.c_proj = Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype)
        self.attn_dropout = Dropout(config.dropout)
        self.resid_dropout = Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # rotary position embedding, applied to the full-width q/k before head splitting
        self.rotary_emb = RotaryEmbedding(config.n_embd)
        # use fused attention when PyTorch provides it (>= 2.0)
        self.flash = hasattr(functional, 'scaled_dot_product_attention')
        if not self.flash:
            # all-ones mask: no positions are masked out (attention is not causal)
            self.register_buffer(
                "bias",
                ones(block_size, block_size).view(1, 1, block_size, block_size),
            )

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.flash:
            y = scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=False,
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=config.bias, dtype=config.dtype)
        self.gelu = GELU()
        self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype)
        self.dropout = Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype)
        self.attn = SelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))


class NBAModel(Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.transformer = ModuleDict(dict(
            home_player_embeddings=Embedding(config.player_tokens, config.n_embd, dtype=config.dtype),
            away_player_embeddings=Embedding(config.player_tokens, config.n_embd, dtype=config.dtype),
            home_age_embeddings=Embedding(config.age_tokens, config.n_embd, dtype=config.dtype),
            away_age_embeddings=Embedding(config.age_tokens, config.n_embd, dtype=config.dtype),
            drop=Dropout(config.dropout),
            h=ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype),
        ))
        self.head = Linear(config.n_embd, config.num_labels, dtype=config.dtype)
        self.apply(self._init_weights)
        # scaled init for residual projections, per GPT-2
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                normal_(p, mean=0.0, std=0.02 / sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, Linear):
            normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                zeros_(module.bias)
        elif isinstance(module, Embedding):
            normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, **batch):
        home_player_tokens = batch['home_player_tokens']
        away_player_tokens = batch['away_player_tokens']
        home_age_tokens = batch['home_age_tokens']
        away_age_tokens = batch['away_age_tokens']
        home_player_embeddings = self.transformer.home_player_embeddings(home_player_tokens)
        away_player_embeddings = self.transformer.away_player_embeddings(away_player_tokens)
        home_age_embeddings = self.transformer.home_age_embeddings(home_age_tokens)
        away_age_embeddings = self.transformer.away_age_embeddings(away_age_tokens)
        # sum player-identity and age embeddings for each roster slot
        home_emb = home_player_embeddings + home_age_embeddings
        away_emb = away_player_embeddings + away_age_embeddings
        # concatenate home and away rosters into a single sequence
        x = cat([home_emb, away_emb], dim=1)
        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.head(x)
        # classify from the first token's representation
        logits = logits[:, 0]
        loss = None
        if 'home_net_score_token' in batch:
            loss = F.cross_entropy(logits, batch['home_net_score_token'])
            loss = {'loss': loss}
        return logits, loss
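

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hyperparameter values and the
# random batch tensors below are assumptions for demonstration, not part of
# the model above; it just checks that a forward pass runs end to end and
# returns per-game logits plus a loss when a target token is supplied.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch import float32, manual_seed, randint

    config = NBAConfig(
        players_per_team=12,   # assumed roster size
        player_tokens=4096,    # assumed player vocabulary size
        age_tokens=64,         # assumed number of age buckets
        n_layer=4,
        n_head=4,
        n_embd=128,
        dropout=0.1,
        seed=0,
        bias=True,
        dtype=float32,
        num_labels=3,          # assumed number of net-score classes
    )
    manual_seed(config.seed)
    model = NBAModel(config)

    batch_size = 2
    batch = dict(
        home_player_tokens=randint(0, config.player_tokens, (batch_size, config.players_per_team)),
        away_player_tokens=randint(0, config.player_tokens, (batch_size, config.players_per_team)),
        home_age_tokens=randint(0, config.age_tokens, (batch_size, config.players_per_team)),
        away_age_tokens=randint(0, config.age_tokens, (batch_size, config.players_per_team)),
        home_net_score_token=randint(0, config.num_labels, (batch_size,)),
    )
    logits, loss = model(**batch)
    print(logits.shape, loss['loss'].item())  # (batch_size, num_labels) and a scalar loss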