File size: 6,192 Bytes
b27e0c7 bc73b41 b27e0c7 9fdcc4b 0d741ac 783bbd7 0d741ac 109fa21 b27e0c7 8042f2b b27e0c7 8042f2b 109fa21 783bbd7 b27e0c7 0d741ac b27e0c7 0d741ac b27e0c7 0d741ac b27e0c7 0d741ac b27e0c7 b08ebd4 b27e0c7 b08ebd4 783bbd7 b08ebd4 b27e0c7 0d741ac b27e0c7 0d741ac b27e0c7 bc73b41 b27e0c7 bc73b41 b27e0c7 0d741ac b27e0c7 783bbd7 109fa21 783bbd7 3a3c504 783bbd7 f2f029b 3a3c504 783bbd7 f2f029b 109fa21 3a3c504 109fa21 3a3c504 109fa21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import torch
from torch import nn
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from typing import Optional
from .configuration_minGRULM import MinGRULMConfig
from minGRU_pytorch.minGRULM import minGRULM
# Wrapper class for device compatibility
class MinGRULMWrapped(nn.Module):
    """Wrap a model so tensor inputs are moved onto the model's device before the call.

    The wrapped model is stored as ``min_gru_model``. ``device`` mirrors the
    device of the wrapped model's first parameter (or buffer) and is refreshed
    on every ``forward`` and ``to`` call, so it stays correct even when an
    enclosing module is moved with ``.to()`` / ``.cuda()`` (which does not go
    through this class's ``to``).

    Args:
        min_gru_model: The ``nn.Module`` to wrap.
    """

    def __init__(self, min_gru_model):
        super().__init__()
        self.min_gru_model = min_gru_model
        # Only a model without any parameters/buffers falls back to the
        # "CUDA if available" guess; otherwise the weights decide.
        self.device = self._resolve_device(
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

    def _resolve_device(self, fallback):
        """Return the device of the wrapped model's first parameter or buffer, else *fallback*."""
        first = next(self.min_gru_model.parameters(), None)
        if first is None:
            first = next(self.min_gru_model.buffers(), None)
        return fallback if first is None else first.device

    def forward(self, *args, **kwargs):
        """Move every tensor argument to the model's device and call the wrapped model.

        Non-tensor positional and keyword arguments are passed through untouched.

        Returns:
            Whatever the wrapped model returns.
        """
        # Re-read the device on each call: the weights may have been moved
        # through a parent module since the last call.
        device = self.device = self._resolve_device(self.device)
        args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        return self.min_gru_model(*args, **kwargs)

    def to(self, *args, **kwargs):
        """Move/cast the wrapper and the wrapped model; accepts everything ``nn.Module.to`` does.

        Returns:
            ``self``, to allow chaining.
        """
        super().to(*args, **kwargs)

        # For a parameter-less wrapped model there is no tensor to inspect, so
        # remember the explicitly requested device (if one was given).
        requested = kwargs.get("device")
        if requested is None and args and isinstance(args[0], (str, int, torch.device)):
            requested = args[0]
        fallback = self.device if requested is None else torch.device(requested)

        self.device = self._resolve_device(fallback)
        return self
class MinGRULMPreTrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base bound to ``MinGRULMConfig`` that supplies weight initialization."""

    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize one sub-module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std ``config.initializer_range``; a linear
        bias (when present) and an embedding's padding row (when set) are
        zeroed. ``nn.LayerNorm`` gets a zero bias and unit weight. Any other
        module type is left untouched.

        Args:
            module: The sub-module to initialize.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            linear_bias = module.bias
            if linear_bias is not None:
                linear_bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            table = module.weight.data
            table.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                table[pad_row].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class MinGRULMForCausalLM(PreTrainedModel):
    """Causal language model exposing ``minGRULM`` through the Hugging Face ``PreTrainedModel`` API.

    The network is built from ``MinGRULMConfig`` and wrapped in
    ``MinGRULMWrapped`` as ``self.model``. ``self.lm_head`` shares its weight
    with the wrapped model's token embedding (see ``tie_weights``).

    Serialization is delegated to ``PreTrainedModel`` so that checkpoints
    written by ``save_pretrained`` are flat tensor state dicts that
    ``from_pretrained`` can read back.

    NOTE(review): ``forward`` uses the wrapped model's output directly as
    logits and never applies ``lm_head`` - confirm that is intended.
    NOTE(review): this class derives from ``PreTrainedModel``, not from
    ``MinGRULMPreTrainedModel``, so that class's ``_init_weights`` is not used
    here - confirm that is intended.
    """

    config_class = MinGRULMConfig
    base_model_prefix = "model"
    # ``lm_head.weight`` is the very same tensor as the token embedding weight.
    # Declaring it as tied lets the base-class save logic keep a single copy
    # instead of tripping over two names that share storage.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MinGRULMConfig):
        """Build the wrapped ``minGRULM`` network and the language-modeling head.

        Args:
            config: Supplies ``vocab_size``, ``d_model``, ``n_layer``,
                ``ff_mult``, ``min_gru_expansion`` and ``enable_conv``.
        """
        super().__init__(config)
        # Load model from minGRULM library and wrap it
        raw_min_gru = minGRULM(
            num_tokens=config.vocab_size,
            dim=config.d_model,
            depth=config.n_layer,
            ff_mult=config.ff_mult,
            min_gru_expansion=config.min_gru_expansion,
            enable_conv=config.enable_conv,
        )
        self.model = MinGRULMWrapped(raw_min_gru)
        # Language modeling head
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def post_init(self):
        """Run the base-class post-init hook, then (re-)tie the head to the embedding."""
        super().post_init()
        self.tie_weights()

    def tie_weights(self):
        """Make ``lm_head.weight`` the same parameter as the token embedding weight."""
        self.lm_head.weight = self.model.min_gru_model.token_emb.weight

    def get_input_embeddings(self):
        """Return the wrapped model's token embedding module."""
        return self.model.min_gru_model.token_emb

    def set_input_embeddings(self, value):
        """Replace the wrapped model's token embedding module with *value*."""
        self.model.min_gru_model.token_emb = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
        """Assemble the keyword arguments passed to ``forward`` during generation.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (``None`` when no
            mask was supplied). ``forward`` accepts the mask but does not use it.
        """
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ):
        """Compute logits and, when *labels* are given, the next-token cross-entropy loss.

        Args:
            input_ids: Token ids fed to the wrapped model.
            labels: Optional targets. Logits at position ``t`` are scored
                against the label at position ``t + 1``.
            return_dict: When falsy (including ``None``) a tuple is returned
                instead of a ``CausalLMOutputWithPast``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is falsy.
        """
        # Forward pass through the wrapped model (it moves input_ids to its device).
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # The wrapper may have moved the inputs; keep the targets on the
            # same device as the logits so the loss can be computed.
            labels = labels.to(logits.device)
            # Drop the last prediction and the first label to align t -> t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def state_dict(self, *args, **kwargs):
        """Return the standard flat ``nn.Module`` state dict.

        Kept as an explicit pass-through: all ``nn.Module.state_dict``
        arguments are forwarded, so this model can be nested inside other
        modules and serialized by the base-class save/load code.
        """
        return super().state_dict(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a pretrained checkpoint via the base-class loader."""
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """Save the model weights and configuration to a directory.

        Args:
            save_directory (str): Directory to save the model.
            safe_serialization (bool, optional): Forwarded to the base class to
                choose the safe serialization format. Defaults to True. ``None``
                leaves the choice to the base-class default.
            **kwargs: Any further ``PreTrainedModel.save_pretrained`` arguments
                (for example ``state_dict``).
        """
        if safe_serialization is not None:
            kwargs["safe_serialization"] = bool(safe_serialization)
        # The base class creates the directory and writes both weights and config.
        super().save_pretrained(save_directory, **kwargs)
|