File size: 6,192 Bytes
b27e0c7
 
bc73b41
b27e0c7
 
 
 
 
 
9fdcc4b
0d741ac
 
 
 
 
783bbd7
0d741ac
 
 
 
 
 
 
 
 
 
 
 
 
 
109fa21
b27e0c7
 
 
 
 
8042f2b
b27e0c7
 
 
 
 
 
 
 
8042f2b
 
 
109fa21
783bbd7
 
 
 
b27e0c7
 
 
 
0d741ac
 
b27e0c7
 
 
 
0d741ac
b27e0c7
 
0d741ac
b27e0c7
0d741ac
b27e0c7
b08ebd4
b27e0c7
 
b08ebd4
783bbd7
b08ebd4
 
 
 
 
 
 
b27e0c7
0d741ac
b27e0c7
 
0d741ac
b27e0c7
 
 
 
bc73b41
 
 
 
b27e0c7
 
 
 
 
bc73b41
b27e0c7
0d741ac
b27e0c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
783bbd7
 
109fa21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
783bbd7
 
 
 
 
 
 
 
3a3c504
783bbd7
 
f2f029b
3a3c504
 
 
783bbd7
f2f029b
 
109fa21
 
3a3c504
109fa21
 
 
 
 
 
 
 
3a3c504
109fa21
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import torch
from torch import nn
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from typing import Optional
from .configuration_minGRULM import MinGRULMConfig
from minGRU_pytorch.minGRULM import minGRULM


# Wrapper class for device compatibility
class MinGRULMWrapped(nn.Module):
    """Wrap a model and move tensor inputs onto that model's device.

    The target device is read from the wrapped model's own tensors every
    time it is needed instead of being cached. A cached value goes stale
    whenever the module is moved through a parent module, ``.cuda()``,
    ``.cpu()`` or any other path that does not call ``to`` on this wrapper
    directly, which then sends inputs to a device the weights are not on.

    Moving the wrapper uses the inherited ``nn.Module.to``, so every
    standard call form (device, dtype, keyword arguments) is supported and
    the wrapper itself is returned.

    Args:
        min_gru_model: The module to wrap. It is registered as the
            ``min_gru_model`` submodule.
    """

    def __init__(self, min_gru_model):
        super().__init__()
        self.min_gru_model = min_gru_model

    @property
    def device(self):
        """Return the device the wrapped model currently lives on.

        Uses the first parameter, then the first buffer; falls back to CPU
        when the wrapped model owns no tensors at all.
        """
        for tensor in self.min_gru_model.parameters():
            return tensor.device
        for tensor in self.min_gru_model.buffers():
            return tensor.device
        return torch.device("cpu")

    def forward(self, *args, **kwargs):
        """Call the wrapped model after moving tensor arguments to its device.

        Positional and keyword arguments that are ``torch.Tensor`` instances
        are moved; everything else is passed through unchanged.

        Returns:
            Whatever the wrapped model returns.
        """
        # Resolve once per call so all inputs agree on the target device.
        target = self.device
        args = [arg.to(target) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {
            key: value.to(target) if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        return self.min_gru_model(*args, **kwargs)



class MinGRULMPreTrainedModel(PreTrainedModel):
    """Base class that binds ``MinGRULMConfig`` and defines weight initialization."""

    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize a single submodule in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and standard deviation
        ``config.initializer_range``. Linear biases and the embedding
        padding row (when present) are zeroed. ``nn.LayerNorm`` is reset to
        weight 1 and bias 0. Any other module type is left untouched.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            embedding_weights = module.weight.data
            embedding_weights.normal_(mean=0.0, std=init_std)
            pad_index = module.padding_idx
            if pad_index is not None:
                # Keep the padding vector at exactly zero.
                embedding_weights[pad_index].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

            
class MinGRULMForCausalLM(PreTrainedModel):
    """Causal language model built around the ``minGRULM`` network.

    The network is constructed from ``MinGRULMConfig`` fields and wrapped in
    ``MinGRULMWrapped``. ``lm_head`` shares its weight tensor with the
    network's token embedding (see ``tie_weights``).

    Serialization relies on the inherited ``nn.Module.state_dict`` and
    ``PreTrainedModel.save_pretrained``. A custom nested ``state_dict``
    (sub-dicts plus a config entry) is intentionally not provided: it cannot
    be consumed by ``load_state_dict`` or by the base-class save/load
    machinery, both of which expect a flat mapping of parameter names to
    tensors.
    """

    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def __init__(self, config: MinGRULMConfig):
        super().__init__(config)

        # Load model from minGRULM library and wrap it
        raw_min_gru = minGRULM(
            num_tokens=config.vocab_size,
            dim=config.d_model,
            depth=config.n_layer,
            ff_mult=config.ff_mult,
            min_gru_expansion=config.min_gru_expansion,
            enable_conv=config.enable_conv,
        )
        self.model = MinGRULMWrapped(raw_min_gru)

        # Language modeling head
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def post_init(self):
        """Run the base-class post-init, then (re-)apply weight tying."""
        super().post_init()
        self.tie_weights()

    def tie_weights(self):
        """Make ``lm_head`` share the token embedding's weight tensor."""
        self.lm_head.weight = self.model.min_gru_model.token_emb.weight

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped network."""
        return self.model.min_gru_model.token_emb

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped network."""
        self.model.min_gru_model.token_emb = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
        """Build the keyword arguments passed to ``forward`` during generation.

        Only ``input_ids`` and the optional ``attention_mask`` are forwarded;
        all other generation kwargs are dropped.
        """
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: Token ids fed to the wrapped network.
            labels: Optional target ids. Position ``t`` of the logits is
                scored against position ``t + 1`` of ``labels``.
            return_dict: When falsy, return a tuple instead of a
                ``CausalLMOutputWithPast``.
            **kwargs: Accepted for call compatibility (e.g.
                ``attention_mask``) and ignored.

        Returns:
            ``CausalLMOutputWithPast(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is falsy.
        """
        # NOTE(review): the wrapped model's output is used as the logits
        # directly; ``lm_head`` is not applied here. Confirm the wrapped
        # model already projects to the vocabulary.
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            # The wrapper moves input_ids to the model's device but never
            # sees labels, so align them with the logits before the loss.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a pretrained checkpoint.

        Delegates unchanged to ``PreTrainedModel.from_pretrained``.
        """
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """Save the model weights and configuration to a directory.

        Delegates to ``PreTrainedModel.save_pretrained`` so that the files
        written are the ones ``from_pretrained`` knows how to read, and so
        that ``safe_serialization`` actually selects the serialization
        format rather than being inverted.

        Args:
            save_directory (str): Directory to save the model to; the base
                class creates it if needed.
            safe_serialization (bool, optional): Passed through to the base
                class. ``None`` is treated as the default (``True``).
            **kwargs: Any further ``PreTrainedModel.save_pretrained``
                options (for example ``state_dict`` or ``max_shard_size``),
                forwarded unchanged.
        """
        if safe_serialization is None:
            safe_serialization = True
        super().save_pretrained(
            save_directory,
            safe_serialization=safe_serialization,
            **kwargs,
        )