File size: 11,284 Bytes
56fe6da
cdebfc7
cd77b48
cdebfc7
56fe6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd77b48
56fe6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd77b48
cdebfc7
 
56fe6da
 
 
 
 
 
 
 
 
cdebfc7
 
 
cd77b48
 
cdebfc7
cd77b48
 
 
cdebfc7
56fe6da
cdebfc7
 
 
 
 
 
 
 
56fe6da
cd77b48
 
56fe6da
cdebfc7
 
 
 
 
 
56fe6da
 
 
cd77b48
 
 
 
 
 
56fe6da
 
 
cd77b48
56fe6da
 
 
 
 
 
 
cd77b48
56fe6da
 
 
 
 
 
 
 
 
cd77b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdebfc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd77b48
 
 
 
 
 
 
 
56fe6da
cd77b48
 
 
 
56fe6da
 
cd77b48
56fe6da
 
cd77b48
56fe6da
 
 
 
 
 
 
 
 
cd77b48
56fe6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import re
import warnings
from typing import Dict, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
)

_HF_ARCH_DICT = {
    # https://huggingface.co./docs/transformers/model_doc/roberta#roberta
    'roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co./docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    'xlm-roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co./docs/transformers/model_doc/bert
    'bert': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
        },
        'pooler': 'cls_pooler',
    },
}

_POOLERS = {}


def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


def register_pooler(cls):
    """Decorator registering pooler class.

    Stores *cls* in ``_POOLERS`` under its snake_case name (as produced by
    ``_camel2snake``) and hands the class back unchanged.
    """
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean of the last hidden states over non-padding positions."""

    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        """Return the masked average of ``x.last_hidden_state`` along dim 1.

        Args:
            x: model output exposing ``last_hidden_state``; assumed to be
                (batch, seq, dim) — TODO confirm for remote-code models.
            attention_mask: per-token mask, non-zero for real tokens; assumed
                (batch, seq).

        Returns:
            The pooled tensor with the sequence dimension reduced.
        """
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        # Clamp so a row with no real tokens yields zeros instead of NaN
        # from a 0 / 0 division.
        token_counts = attention_mask.sum(-1, keepdim=True).clamp(min=1)
        return masked_output.sum(dim=1) / token_counts


@register_pooler
class MaxPooler(nn.Module):
    """Element-wise max of the last hidden states over non-padding positions."""

    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        """Return the max of ``x.last_hidden_state`` along dim 1, ignoring padding.

        Args:
            x: model output exposing ``last_hidden_state``; assumed to be
                (batch, seq, dim) — TODO confirm for remote-code models.
            attention_mask: per-token mask, non-zero for real tokens; may be an
                integer tensor (``HFTextEncoder.forward`` passes ``.long()``).

        Returns:
            The pooled tensor with the sequence dimension reduced.
        """
        # ``masked_fill`` writes wherever the mask is True and requires a bool
        # mask, so select the *padding* positions (mask == 0) explicitly.
        # The previous code passed the raw mask, which filled the real tokens
        # with -inf (and is rejected outright for non-bool masks).
        padding = attention_mask.unsqueeze(-1) == 0
        masked_output = x.last_hidden_state.masked_fill(padding, -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """Pool by taking the hidden state at the CLS position.

    When ``use_pooler_output`` is enabled and the model output carries a
    non-None ``pooler_output``, that tensor is returned instead.
    """

    def __init__(self, use_pooler_output: bool = True):
        super().__init__()
        self.use_pooler_output = use_pooler_output
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, _: torch.Tensor):
        pooled = None
        if self.use_pooler_output and isinstance(
            x,
            (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions),
        ):
            pooled = x.pooler_output
        if pooled is not None:
            return pooled
        # fall back to slicing the CLS position out of the token states
        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """Text tower wrapping a HuggingFace ``AutoModel`` with pooling and projection.

    The transformer output is pooled by a pooler from ``_POOLERS`` and mapped
    to ``output_dim`` by ``self.proj``. Task instructions and LoRA adapter ids
    are exposed when the wrapped model defines non-empty
    ``_task_instructions`` / ``_adaptation_map`` attributes.
    """

    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        model_name_or_path: str,
        output_dim: int,
        config: Optional[PretrainedConfig] = None,
        pooler_type: Optional[str] = None,
        proj_type: Optional[str] = None,
        proj_bias: bool = False,
        pretrained: bool = True,
        output_tokens: bool = False,
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        default_instruction_task: Optional[str] = None,
        default_lora_task: Optional[str] = None,
        model_config_kwargs: Optional[Dict] = None,
    ):
        """
        Args:
            model_name_or_path: HF hub id or local path of the text model.
            output_dim: size of the embedding returned by ``forward``.
            config: explicit model config; when given, the model is built with
                ``AutoModel.from_config`` instead of being loaded by name.
            pooler_type: key into ``_POOLERS``; defaults to the ``pooler``
                entry of ``_HF_ARCH_DICT`` for the config's ``model_type``.
            proj_type: ``None``, ``'linear'`` or ``'mlp'``. With ``None`` a
                linear projection is still used when the hidden size differs
                from ``output_dim``.
            proj_bias: whether the projection layers have a bias.
            pretrained: load pretrained weights (only used when ``config`` is
                None).
            output_tokens: if True, ``forward`` also returns token states.
            trust_remote_code, revision, code_revision: forwarded to
                transformers loaders.
            default_instruction_task: task whose instruction is stored in
                ``default_instruction``.
            default_lora_task: task whose adapter id is stored in
                ``default_loraid``.
            model_config_kwargs: overrides applied to the model config.

        Raises:
            ValueError: for an unsupported ``proj_type`` when no projection
                would otherwise be built, or for an unknown default task on a
                model that supports tasks/adapters.
        """
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        model_config_kwargs = model_config_kwargs or {}

        if config is None:
            if pretrained:
                self.transformer = AutoModel.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    add_pooling_layer=False,
                    code_revision=code_revision,
                    **model_config_kwargs,
                )
                self.config = self.transformer.config
            else:
                # ``revision`` was previously ignored here, so a pinned
                # revision did not apply to the config download.
                self.config = AutoConfig.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    code_revision=code_revision,
                )
                self.config.update(model_config_kwargs)
                self.transformer = AutoModel.from_config(
                    self.config,
                    trust_remote_code=trust_remote_code,
                    add_pooling_layer=False,
                    code_revision=code_revision,
                )
            # For encoder-decoder models only the encoder is kept.
            if getattr(self.config, 'is_encoder_decoder', False):
                self.transformer = self.transformer.encoder

        else:
            self.config = config
            self.config.update(model_config_kwargs)
            self.transformer = AutoModel.from_config(
                self.config,
                trust_remote_code=trust_remote_code,
                revision=revision,
                code_revision=code_revision,
            )
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(
            self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
        )
        self.proj = self._build_proj(d_model, output_dim, proj_type, proj_bias)

        self._task_instructions = {}
        self._lora_adaptation_map = {}
        self._supports_task_instructions = False
        self._supports_lora = False
        if (
            hasattr(self.transformer, '_adaptation_map')
            and len(self.transformer._adaptation_map) > 0
        ):
            self._lora_adaptation_map = self.transformer._adaptation_map
            self._supports_lora = True
        if (
            hasattr(self.transformer, '_task_instructions')
            and len(self.transformer._task_instructions) > 0
        ):
            self._task_instructions = self.transformer._task_instructions
            self._supports_task_instructions = True

        self.default_instruction_task = None
        self.default_lora_task = None
        self.default_instruction = None
        self.default_loraid = None
        if default_instruction_task is not None:
            self.default_instruction_task = default_instruction_task
            self.default_instruction = self.get_instruction_from_task(
                default_instruction_task
            )
        if default_lora_task is not None:
            self.default_lora_task = default_lora_task
            self.default_loraid = self.get_loraid_from_task(default_lora_task)

    @staticmethod
    def _build_proj(
        d_model: int, output_dim: int, proj_type: Optional[str], proj_bias: bool
    ) -> nn.Module:
        """Build the head mapping pooled features of width d_model to output_dim."""
        # 'mlp' is checked first: previously a width mismatch made the
        # 'linear' branch win and 'mlp' was silently replaced by a Linear.
        if proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            return nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=proj_bias),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=proj_bias),
            )
        # A projection is mandatory whenever the widths differ.
        if proj_type == 'linear' or d_model != output_dim:
            return nn.Linear(d_model, output_dim, bias=proj_bias)
        if proj_type is None:
            return nn.Identity()
        # Previously this case left ``self.proj`` unset, failing later in forward.
        raise ValueError(
            f"Unsupported proj_type '{proj_type}'. Expected None, 'linear' or 'mlp'."
        )

    def get_instruction_from_task(self, task: str) -> Optional[str]:
        """Return the instruction string registered for ``task``.

        Raises ValueError for an unknown task when the model supports task
        instructions; otherwise warns and returns None.
        """
        if self._supports_task_instructions:
            if task not in self._task_instructions:
                raise ValueError(
                    f'Unsupported task \'{task}\'. Choose one of the following: '
                    f'{", ".join(self._task_instructions)} or set to None to disable '
                    f'task instructions completely'
                )
            return self._task_instructions[task]
        else:
            warnings.warn(
                'Model does not support task instructions, ignoring instruction '
                f"task '{task}'"
            )
        return None

    def get_loraid_from_task(self, task: str) -> Optional[int]:
        """Return the LoRA adapter id registered for ``task``.

        Raises ValueError for an unknown task when the model supports LoRA
        adapters; otherwise warns and returns None.
        """
        if self._supports_lora:
            if task not in self._lora_adaptation_map:
                # List the LoRA tasks (previously listed the instruction tasks).
                raise ValueError(
                    f'Unsupported task \'{task}\'. Choose one of the following: '
                    f'{", ".join(self._lora_adaptation_map)} or set to None to disable '
                    f'the LoRA adapters completely'
                )
            return self._lora_adaptation_map[task]
        else:
            warnings.warn(
                f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
            )
        return None

    @torch.jit.ignore
    def set_grad_checkpointing(self, _=True):
        """Enable (truthy argument, the default) or disable gradient checkpointing.

        The parameter keeps its original name ``_`` for backward compatibility.
        Previously a falsy argument still enabled checkpointing.
        """
        if _:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def init_parameters(self):
        """No-op: weights come from the HF model (loaded or built from config)."""
        pass

    def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
        """Encode token ids into projected embeddings.

        Args:
            x: token ids; positions equal to ``config.pad_token_id`` are masked.
            adapter_mask: optional tensor passed to the transformer as the
                ``adapter_mask`` keyword (presumably per-sample LoRA adapter
                selection — confirm against the wrapped model).

        Returns:
            The projected pooled embedding or, when ``output_tokens`` is set, a
            ``(projected, tokens)`` tuple. ``tokens`` is ``last_hidden_state``
            with the CLS position removed when a ``ClsPooler`` is used.
        """
        attn_mask = (x != self.config.pad_token_id).long()
        kwargs = {}
        if adapter_mask is not None:
            kwargs['adapter_mask'] = adapter_mask
        out = self.transformer(input_ids=x, attention_mask=attn_mask, **kwargs)
        projected = self.proj(self.pooler(out, attn_mask))
        if not self.output_tokens:
            # Token states are only materialised when they are returned.
            return projected

        tokens = out.last_hidden_state
        if isinstance(self.pooler, ClsPooler):
            # Drop the CLS position; build the index on the states' device.
            keep = (
                torch.arange(tokens.shape[1], device=tokens.device)
                != self.pooler.cls_token_position
            )
            tokens = tokens[:, keep, :]
        return projected, tokens

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze transformer parameters, optionally keeping the top modules trainable.

        Args:
            unlocked_layers: number of trailing modules of
                ``[embeddings, *encoder_layers]`` left untouched. 0 freezes
                every transformer parameter.
            freeze_layer_norm: when False, parameters under a ``LayerNorm``
                submodule of the processed modules are set trainable.

        With ``unlocked_layers > 0``, only the embeddings and encoder layers
        are visited; other transformer parameters keep their current flag.
        """
        if not unlocked_layers:
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )
            return

        names = _HF_ARCH_DICT[self.config.model_type]['config_names']
        encoder = (
            self.transformer.encoder
            if hasattr(self.transformer, 'encoder')
            else self.transformer
        )
        layer_list = getattr(encoder, names['layer_attr'])
        print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
        embeddings = getattr(self.transformer, names['token_embeddings_attr'])
        # Everything except the last ``unlocked_layers`` modules gets frozen.
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )