import torch
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch import nn
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from torch.nn import Identity
from transformers.activations import get_activation
import numpy as np
from torch_scatter import scatter_add  # required for GATpooling and the bertsum pooling path
from .utils import input_check, pos_encoding

class classification_model(torch.nn.Module):
    def __init__(self, pretrained_model, config, num_classifier=1, num_pos_emb_layer=1, bertsum=False, device=None):
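        """
        Classification head on top of a pretrained transformer encoder.

        Args:
            pretrained_model: backbone whose first output is the sequence of token hidden states.
            config: the backbone's config; must expose `num_labels` and either `d_model` or `hidden_size`.
            num_classifier: number of stacked linear layers in the classification head.
            num_pos_emb_layer: number of linear layers applied after adding the paragraph positional encoding.
            bertsum: if True, pool sentence vectors with scatter_add and GATpooling instead of SequenceSummary.
            device: device used for the positional encodings and the default pooling batch index.
        """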
        super(classification_model, self).__init__()
        
        self.config = config
        self.num_labels = config.num_labels
        self.pretrained_model = pretrained_model
        if hasattr(config, 'd_model'):
            self.pretrained_hidden = config.d_model
        elif hasattr(config, 'hidden_size'):
            self.pretrained_hidden = config.hidden_size
        self.sequence_summary = SequenceSummary(config)
        self.bertsum = bertsum
        self.device = device
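        # flags consumed by forward(): with return_hidden set, the pooled vector just before the
        # classifier head is also returned; if return_hidden_pretrained is set as well, the
        # backbone's second output is returned instead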
        self.return_hidden = False
        self.return_hidden_pretrained = False
        
        if self.bertsum:
            #self.pooling_1 = GATpooling(self.pretrained_hidden)
            #self.fnn_1 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)
            self.pooling_2 = GATpooling(self.pretrained_hidden, self.device)
            self.fnn_2 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)

        self.pos_emb_layer = nn.Sequential(*[nn.Linear(self.pretrained_hidden, self.pretrained_hidden) for _ in range(num_pos_emb_layer)])

        dim_list = np.linspace(self.pretrained_hidden, config.num_labels, num_classifier+1, dtype=np.int32)
        #dim_list = np.linspace(768, config.num_labels, num_classifier+1, dtype=np.int32)
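        # classification head: a stack of linear layers whose widths step down linearly
        # from the encoder width to num_labels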
        self.classifiers = nn.ModuleList()
        for c in range(num_classifier):
            self.classifiers.append(nn.Linear(dim_list[c], dim_list[c+1]))
        
    def forward(self, inputs):
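        """
        Run the backbone and the classification head.

        `inputs` is a dict with 'input_ids', 'token_type_ids', 'attention_mask' and
        'position' (the paragraph position fed to `pos_encoding`); 'sentence_batch'
        is required when `bertsum` is set and 'labels' is optional.

        Returns a `(loss, logits, hidden_states)` tuple; `loss` is None when no
        labels are given and `hidden_states` is None unless the return_hidden
        flags are set.
        """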
        hidden_states = None
        input_ids = inputs['input_ids']
        token_type_ids = inputs['token_type_ids']
        attention_mask = inputs['attention_mask']
        position = inputs['position']
        transformer_inputs = input_check({'input_ids':input_ids, 'token_type_ids':token_type_ids, 'attention_mask':attention_mask}, self.pretrained_model)
        
        pretrained_output = self.pretrained_model(**transformer_inputs)
        output = pretrained_output[0]
        
        if self.return_hidden_pretrained and self.return_hidden:
            hidden_states = pretrained_output[1]
        if self.bertsum:
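            # tokens are first summed into sentence vectors (scatter_add over 'sentence_batch'),
            # then GATpooling collapses the sentence vectors into a single summary vector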
            output = scatter_add(output, inputs['sentence_batch'], dim=-2)
            #output = self.pooling_1(output, inputs['sentence_batch'])
            #output = self.fnn_1(output)
            output = self.pooling_2(output)
            output = output.squeeze()
            output = self.fnn_2(output)
        else:
            output = self.sequence_summary(output)
            
        # add the paragraph positional-encoding vector, then pass through the positional embedding layers
        pos_emb = pos_encoding(position, self.pretrained_hidden).to(self.device, dtype=torch.float)
        output = torch.add(output, pos_emb)
        output = self.pos_emb_layer(output)
            
        if self.return_hidden and not self.return_hidden_pretrained:
            hidden_states = output
        for layer in self.classifiers:
            output = layer(output)
        
        logits = output
        
        if 'labels' in inputs:
            loss = self.classification_loss_f(inputs, logits)
        else:
            loss = None
        
        return loss, output, hidden_states
    
    def classification_loss_f(self, inputs, logits):
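        """
        Compute the classification loss, following the transformers convention:
        `config.problem_type` is inferred from `num_labels` and the label dtype
        when unset, then MSE (regression), cross-entropy (single-label) or
        BCE-with-logits (multi-label) is applied.
        """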
        labels = inputs['labels']
        loss = None
        
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return loss


class GATpooling(nn.Module):
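    """Gated attention pooling: a linear gate scores each row of the input, the
    scores are normalised with a softmax, and the weighted rows are summed per
    segment index with torch_scatter's scatter_add."""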
    def __init__(self, hidden_size, device=None):
        super(GATpooling, self).__init__()
        self.gate_nn = nn.Linear(hidden_size, 1)
        self.device = device
        
    def forward(self, x, batch=None):
        if batch is None:
            batch = torch.zeros(x.shape[-2], dtype=torch.long).to(self.device)
        gate = self.gate_nn(x)
        gate = F.softmax(gate, dim=-2)  # normalise gate scores over the token dimension
        out = scatter_add(gate * x, batch, dim=-2)
        return out
    

class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.
    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):
            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention
            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "mean")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.
        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
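
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). The
# backbone name, the shape of the 'position' entry and the exact behaviour of
# the `input_check` / `pos_encoding` helpers imported from .utils are
# assumptions about the surrounding project.
#
#   from transformers import AutoModel, AutoTokenizer
#
#   backbone = AutoModel.from_pretrained("bert-base-uncased")
#   backbone.config.num_labels = 3
#   model = classification_model(backbone, backbone.config,
#                                num_classifier=2, device="cpu")
#
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   enc = tok("an example paragraph", return_tensors="pt",
#             return_token_type_ids=True)
#   batch = {"input_ids": enc["input_ids"],
#            "token_type_ids": enc["token_type_ids"],
#            "attention_mask": enc["attention_mask"],
#            "position": torch.tensor([0]),   # assumed paragraph index
#            "labels": torch.tensor([1])}
#   loss, logits, hidden = model(batch)   # hidden is None unless return_hidden is set
# ---------------------------------------------------------------------------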