---
datasets:
- ddrg/named_math_formulas
- ddrg/math_formula_retrieval
- ddrg/math_formulas
- ddrg/math_text
---
Pretrained model based on [microsoft/deberta-v3-base](https://huggingface.co/microsoft/deberta-v3-base) with further mathematical pre-training.

Compared to deberta-v3-base, 300 additional mathematical LaTeX tokens were added before the mathematical pre-training. Because this additional pre-training used NSP-like tasks, a pooling layer was added to the model (stored under the keys `bias` and `weight`). If you don't need this pooling layer, just use the standard transformers DeBERTa model. If you want to use the additional pooling layer in the same way as BERT's pooler, a wrapper class like the following may be used:
```python
from typing import Mapping, Any

import torch
from torch import nn
from transformers import DebertaV2Model, DebertaV2Tokenizer, AutoConfig, AutoTokenizer

class DebertaV2ModelWithPoolingLayer(nn.Module):

    def __init__(self, pretrained_model_name):
        super(DebertaV2ModelWithPoolingLayer, self).__init__()

        # Load the Deberta model and tokenizer
        self.deberta = DebertaV2Model.from_pretrained(pretrained_model_name)
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name)

        # Add a pooling layer (Linear + tanh activation) for the CLS token
        self.pooling_layer = nn.Sequential(
            nn.Linear(self.deberta.config.hidden_size, self.deberta.config.hidden_size),
            nn.Tanh()
        )

        self.config = self.deberta.config
        self.embeddings = self.deberta.embeddings


    def forward(self, input_ids, attention_mask, *args, **kwargs):
        # Forward pass through the Deberta model
        outputs = self.deberta(input_ids, attention_mask=attention_mask, *args, **kwargs)

        # Extract the hidden states from the output
        hidden_states = outputs.last_hidden_state

        # Get the CLS token representation (first token)
        cls_token = hidden_states[:, 0, :]

        # Apply the pooling layer to the CLS token representation
        pooled_output = self.pooling_layer(cls_token)
        # Include the pooled_output in the output dictionary as 'pooler_output'
        outputs["pooler_output"] = pooled_output

        return outputs

    def save_pretrained(self, path):
        # Save the configuration and tokenizer first (this creates `path` if needed)
        self.deberta.config.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        # Save the DeBERTa weights together with the pooling layer's `weight` and `bias`
        state_dict = self.deberta.state_dict()
        state_dict.update(self.pooling_layer[0].state_dict())
        torch.save(state_dict, f"{path}/pytorch_model.bin")

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        pooler_keys = ['bias', 'weight']
        deberta_state_dict = {k: v for k, v in state_dict.items() if k not in pooler_keys}
        pooler_state_dict = {k: v for k, v in state_dict.items() if k in pooler_keys}
        self.deberta.load_state_dict(deberta_state_dict, strict=strict)
        self.pooling_layer[0].load_state_dict(pooler_state_dict)

    @classmethod
    def from_pretrained(cls, name):
        # Initialize the instance
        instance = cls(name)

        try:
            # Load the model's state_dict
            instance.load_state_dict(torch.load(f"{name}/pytorch_model.bin"))
        except FileNotFoundError:
            print("Could not find the DeBERTa pooling layer weights. Initializing new values.")

        # Load the configuration and tokenizer
        instance.deberta.config = AutoConfig.from_pretrained(name)
        instance.tokenizer = AutoTokenizer.from_pretrained(name)

        return instance
```
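
The wrapper keeps the pooling layer's parameters under the top-level keys `weight` and `bias`, next to the regular DeBERTa parameters. The following is a minimal sketch of just that key splitting and the CLS pooling step in plain PyTorch. The tiny embedding-only encoder, the hidden size of 8, and the helper name `split_pooler_keys` are stand-ins assumed for illustration; they are not part of the released model.

```python
import torch
from torch import nn

POOLER_KEYS = ("bias", "weight")  # top-level keys that hold the pooling layer


def split_pooler_keys(state_dict):
    """Split a combined checkpoint into encoder and pooler parameters (hypothetical helper)."""
    encoder_sd = {k: v for k, v in state_dict.items() if k not in POOLER_KEYS}
    pooler_sd = {k: v for k, v in state_dict.items() if k in POOLER_KEYS}
    return encoder_sd, pooler_sd


def make_encoder(hidden_size):
    # Stand-in for the DeBERTa stack; its parameter keys are nested ("embeddings.weight"),
    # so they do not collide with the pooler's top-level "weight" and "bias"
    return nn.ModuleDict({"embeddings": nn.Embedding(100, hidden_size)})


def make_pooler(hidden_size):
    # Linear + tanh on the first token, the same shape as the wrapper's pooling layer
    return nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())


torch.manual_seed(0)
hidden_size = 8  # assumption: toy size chosen for the sketch

encoder, pooler = make_encoder(hidden_size), make_pooler(hidden_size)

# Build a combined checkpoint: encoder parameters plus the Linear layer's weight and bias
checkpoint = dict(encoder.state_dict())
checkpoint.update(pooler[0].state_dict())

# Split the checkpoint again and load both parts into fresh modules
encoder_sd, pooler_sd = split_pooler_keys(checkpoint)
new_encoder, new_pooler = make_encoder(hidden_size), make_pooler(hidden_size)
new_encoder.load_state_dict(encoder_sd)
new_pooler[0].load_state_dict(pooler_sd)

# Pool the first (CLS) position of a small batch of token ids
input_ids = torch.tensor([[1, 5, 7, 2], [1, 9, 2, 0]])
with torch.no_grad():
    hidden_states = new_encoder["embeddings"](input_ids)  # (batch, seq_len, hidden)
    pooled_output = new_pooler(hidden_states[:, 0, :])     # (batch, hidden)

print(pooled_output.shape)
```

Because the split relies on the exact key names, it only works as long as no encoder parameter is itself stored under a bare `weight` or `bias` key.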