"""Hugging Face ``transformers`` wrapper classes for the minGRU language model."""

from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from minGRU_pytorch.minGRULM import minGRULM

from .configuration_minGRULM import MinGRULMConfig


class MinGRULMPreTrainedModel(PreTrainedModel):
    """Base class binding MinGRULM models to ``MinGRULMConfig``.

    Provides the ``_init_weights`` hook that initialises ``nn.Linear`` and
    ``nn.Embedding`` submodules.
    """

    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialise the weights of ``module`` in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and standard deviation 0.02. Linear biases
        (when present) and the embedding row at ``padding_idx`` (when set)
        are zeroed. Any other module type is left untouched.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()


class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
    """minGRU language model exposed through a causal-LM style interface.

    The backbone is ``minGRULM`` from the ``minGRU_pytorch`` package; its
    output is used directly as the logits in ``forward``.
    """

    def __init__(self, config: MinGRULMConfig):
        """Build the backbone and output head from ``config``.

        Reads ``vocab_size``, ``d_model``, ``n_layer``, ``ff_mult``,
        ``expand`` and ``enable_conv`` from ``config``.
        """
        super().__init__(config)

        # Load model from minGRULM library
        self.model = minGRULM(
            num_tokens=config.vocab_size,
            dim=config.d_model,
            depth=config.n_layer,
            ff_mult=config.ff_mult,
            min_gru_expansion=config.expand,
            enable_conv=config.enable_conv,
        )

        # NOTE(review): ``forward`` treats the output of ``self.model`` as the
        # logits and never applies ``lm_head``, although
        # ``get_output_embeddings`` returns it. The layer is kept so existing
        # state dicts keep loading; confirm whether it should be wired in or
        # removed.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.token_emb

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.token_emb = value

    def get_output_embeddings(self):
        """Return the ``lm_head`` linear layer."""
        return self.lm_head

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Run the backbone and optionally compute the next-token loss.

        Args:
            input_ids: Token ids passed unchanged to the backbone.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted by one position inside this method, so callers pass
                the unshifted sequence. Positions equal to the default
                ``CrossEntropyLoss`` ignore index (-100) do not contribute.
            return_dict: When true, return a ``CausalLMOutputWithPast``;
                when false, return a tuple. ``None`` falls back to the
                configuration's ``use_return_dict`` (true if unavailable).
            attention_mask: Accepted so tokenizer outputs can be passed with
                ``model(**encoding)``. It is NOT applied: the backbone is
                called with ``input_ids`` only, and padded positions must be
                excluded from the loss through ``labels``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` without
            labels) and ``logits``; or the tuple ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is false.
        """
        if return_dict is None:
            # ``Optional[bool]`` callers conventionally mean "use the config
            # default" when passing None, not "return a tuple".
            return_dict = getattr(self.config, "use_return_dict", True)

        # Forward pass through the model
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last prediction and
            # the first label so the two line up.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten using the logits' own last dimension rather than the
            # configured vocabulary size, so a mismatch cannot silently
            # reshape the tensor into the wrong number of rows.
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )