jdrechsel's picture
Update README.md
6cb30f5
|
raw
history blame
3.63 kB
---
datasets:
- ddrg/named_math_formulas
- ddrg/math_formula_retrieval
- ddrg/math_formulas
- ddrg/math_text
---
Pretrained model based on [microsoft/deberta-v3-base](https://huggingface.co/microsoft/deberta-v3-base) with further mathematical pre-training.
Compared to deberta-v3-base, 300 additional mathematical LaTeX tokens were added to the vocabulary before the mathematical pre-training. Because this additional pre-training used NSP-like tasks, a pooling layer was added to the model; its parameters are stored in the checkpoint under the top-level keys `bias` and `weight`. If you don't need this pooling layer, just use the standard transformers DeBERTa model. If you want to use the additional pooling layer (analogous to BERT's pooler), a wrapper class like the following may be used:
```python
import os
from typing import Mapping, Any

import torch
from torch import nn
from transformers import DebertaV2Model, DebertaV2Tokenizer, AutoConfig, AutoTokenizer
class DebertaV2ModelWithPoolingLayer(nn.Module):
    """DeBERTa-v2 encoder with a BERT-style pooling layer on the first (CLS) token.

    The pooling layer is ``Linear(hidden_size, hidden_size)`` followed by
    ``tanh``. In checkpoints written by :meth:`save_pretrained`, its parameters
    are stored under the top-level keys ``weight`` and ``bias``, next to the
    plain DeBERTa weights.
    """

    # Top-level checkpoint keys that belong to the pooling Linear layer.
    _POOLER_KEYS = ("bias", "weight")

    def __init__(self, pretrained_model_name):
        super().__init__()
        # Load the DeBERTa encoder and its tokenizer.
        self.deberta = DebertaV2Model.from_pretrained(pretrained_model_name)
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name)
        # Pooling layer (Linear + tanh) applied to the CLS token representation.
        hidden_size = self.deberta.config.hidden_size
        self.pooling_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )
        # Aliases so callers can use the wrapper like the bare DeBERTa model.
        self.config = self.deberta.config
        self.embeddings = self.deberta.embeddings

    def forward(self, input_ids, attention_mask=None, *args, **kwargs):
        """Run DeBERTa and add the pooled CLS representation to its output.

        Extra positional and keyword arguments are forwarded to the DeBERTa
        model after ``input_ids`` and ``attention_mask``.

        Returns:
            The DeBERTa output object with an added ``"pooler_output"`` entry
            holding ``tanh(Linear(cls_token))``.
        """
        # Pass attention_mask positionally so extra positional args bind to the
        # following parameters instead of colliding with attention_mask.
        outputs = self.deberta(input_ids, attention_mask, *args, **kwargs)
        # NOTE: assumes a dict-style model output (return_dict enabled); a
        # plain tuple output has no .last_hidden_state and would fail here.
        hidden_states = outputs.last_hidden_state
        # CLS token representation = first token of every sequence.
        cls_token = hidden_states[:, 0, :]
        outputs["pooler_output"] = self.pooling_layer(cls_token)
        return outputs

    def save_pretrained(self, path):
        """Save weights (DeBERTa + pooler), config and tokenizer into directory ``path``.

        The pooler's ``weight``/``bias`` are merged into the DeBERTa state dict
        as top-level keys, which is the layout :meth:`load_state_dict` expects.
        """
        # torch.save does not create missing directories, so create it up front.
        os.makedirs(path, exist_ok=True)
        state_dict = self.deberta.state_dict()
        state_dict.update(self.pooling_layer[0].state_dict())
        torch.save(state_dict, os.path.join(path, "pytorch_model.bin"))
        self.deberta.config.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Split a combined checkpoint into DeBERTa and pooler parts and load both.

        Args:
            state_dict: Mapping in the layout written by ``save_pretrained``.
                The keys ``bias`` and ``weight`` go to the pooling Linear layer;
                all other keys go to the DeBERTa model.
            strict: Forwarded to both underlying ``load_state_dict`` calls.
        """
        deberta_state_dict = {
            k: v for k, v in state_dict.items() if k not in self._POOLER_KEYS
        }
        pooler_state_dict = {
            k: v for k, v in state_dict.items() if k in self._POOLER_KEYS
        }
        self.deberta.load_state_dict(deberta_state_dict, strict=strict)
        # Honour ``strict`` here too, so strict=False tolerates a missing pooler.
        self.pooling_layer[0].load_state_dict(pooler_state_dict, strict=strict)

    @classmethod
    def from_pretrained(cls, name):
        """Build the wrapper from ``name`` and restore the pooler from a local checkpoint.

        ``name`` is passed to the transformers ``from_pretrained`` loaders. The
        combined checkpoint is looked up at ``<name>/pytorch_model.bin``. If that
        file does not exist locally (for example when ``name`` is a Hub model id
        rather than a directory), the pooling layer keeps its fresh random
        initialization and a message is printed.
        """
        instance = cls(name)
        checkpoint = os.path.join(name, "pytorch_model.bin")
        try:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # machines. NOTE: torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(checkpoint, map_location="cpu")
        except FileNotFoundError:
            print("Could not find DeBERTa pooling layer. Initialize new values")
        else:
            instance.load_state_dict(state_dict)
        # Reload configuration and tokenizer; keep the ``config`` alias in sync.
        instance.deberta.config = AutoConfig.from_pretrained(name)
        instance.config = instance.deberta.config
        instance.tokenizer = AutoTokenizer.from_pretrained(name)
        return instance
```