# minGRU-LM / modeling_minGRULM.py
# Last commit by suayptalha: "Create modeling_minGRULM.py" (b27e0c7, verified), 2.45 kB.
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from typing import Optional
from .configuration_minGRULM import MinGRULMConfig
from minGRU_pytorch.minGRULM import minGRULM
class MinGRULMPreTrainedModel(PreTrainedModel):
    """Base class that binds ``MinGRULMConfig`` to the ``transformers`` model API.

    Subclasses get config handling and save/load support from
    ``PreTrainedModel``, plus the weight initialisation defined below.
    """

    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialise one submodule in place.

        ``nn.Linear`` weights and ``nn.Embedding`` tables are drawn from
        N(0, 0.02**2). Linear biases are zeroed, and so is the embedding row at
        ``padding_idx`` when one is set. Any other module type is left untouched.

        Args:
            module: the submodule being initialised.
        """
        init_std = 0.02

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=init_std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if not isinstance(module, nn.Embedding):
            return

        nn.init.normal_(module.weight, mean=0.0, std=init_std)
        if module.padding_idx is not None:
            # The padding token's embedding starts at zero.
            with torch.no_grad():
                module.weight[module.padding_idx].zero_()
class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
    """Causal language model exposing ``minGRU_pytorch``'s ``minGRULM`` via ``transformers``.

    ``forward`` feeds ``input_ids`` straight into the wrapped ``minGRULM``
    module and treats its output as the logits.
    """

    def __init__(self, config: MinGRULMConfig):
        super().__init__(config)
        # Backbone from the minGRU_pytorch library; hyper-parameters come from the config.
        self.model = minGRULM(
            num_tokens=config.vocab_size,
            dim=config.d_model,
            depth=config.n_layer,
            ff_mult=config.ff_mult,
            min_gru_expansion=config.expand,
            enable_conv=config.enable_conv,
        )
        # NOTE(review): ``lm_head`` is never applied in ``forward``; the output
        # of ``self.model`` is used as the logits directly. It is kept because
        # state dicts saved from this class contain ``lm_head.weight`` and
        # because ``get_output_embeddings`` returns it. Confirm whether the
        # output projection inside ``minGRULM`` should be exposed instead.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.token_emb

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.token_emb = value

    def get_output_embeddings(self):
        """Return ``lm_head`` (see the note in ``__init__``: it is not used in ``forward``)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace ``lm_head``; mirrors ``set_input_embeddings``."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Compute logits and, when ``labels`` are given, the next-token loss.

        Args:
            input_ids: token ids, passed unchanged to ``self.model``.
            labels: optional target ids aligned with ``input_ids``. They are
                shifted internally so that position t is scored against token
                t+1. Entries equal to -100 (``CrossEntropyLoss``'s default
                ``ignore_index``) do not contribute to the loss.
            return_dict: return a ``CausalLMOutputWithPast`` when true, or a
                tuple when false. ``None`` falls back to the config's
                ``use_return_dict`` (``transformers`` convention) instead of
                silently producing a tuple.
            attention_mask: accepted so callers such as ``Trainer`` with the
                standard data collators do not raise ``TypeError``. It is
                ignored: the backbone only receives ``input_ids``, so padded
                positions still pass through the recurrence.

        Returns:
            ``CausalLMOutputWithPast(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` resolves to false.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Logit at position t predicts token t+1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            # Reshape by the logits' real class dimension, not config.vocab_size,
            # so a mismatch raises instead of silently mis-shaping the batch.
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )