from typing import Callable, Optional

import numpy as np
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Identity
from torch.nn import functional as F
from torch_scatter import scatter_add  # required by GATpooling and the bertsum path below
from transformers.activations import get_activation

from .utils import input_check, pos_encoding

class classification_model(torch.nn.Module):
    def __init__(self, pretrained_model, config, num_classifier=1, num_pos_emb_layer=1, bertsum=False, device=None):
        super(classification_model, self).__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.pretrained_model = pretrained_model
        # The pretrained backbone exposes its hidden width as either `d_model` or `hidden_size`.
        if hasattr(config, 'd_model'):
            self.pretrained_hidden = config.d_model
        elif hasattr(config, 'hidden_size'):
            self.pretrained_hidden = config.hidden_size
        else:
            raise AttributeError("config must define either 'd_model' or 'hidden_size'")
        self.sequence_summary = SequenceSummary(config)
        self.bertsum = bertsum
        self.device = device
        self.return_hidden = False
        self.return_hidden_pretrained = False
        if self.bertsum:
            #self.pooling_1 = GATpooling(self.pretrained_hidden)
            #self.fnn_1 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)
            self.pooling_2 = GATpooling(self.pretrained_hidden, self.device)
            self.fnn_2 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)
        self.pos_emb_layer = nn.Sequential(*[nn.Linear(self.pretrained_hidden, self.pretrained_hidden) for _ in range(num_pos_emb_layer)])
        # Classifier head: the hidden width is stepped down towards num_labels over `num_classifier` linear layers.
        dim_list = np.linspace(self.pretrained_hidden, config.num_labels, num_classifier + 1, dtype=np.int32)
        self.classifiers = nn.ModuleList()
        for c in range(num_classifier):
            self.classifiers.append(nn.Linear(dim_list[c], dim_list[c + 1]))
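
    # Illustration (assumption, not original code): with pretrained_hidden=768, num_labels=2
    # and num_classifier=3, dim_list becomes [768, 512, 257, 2], so the head is
    # Linear(768, 512) -> Linear(512, 257) -> Linear(257, 2).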

    def forward(self, inputs):
        hidden_states = None
        input_ids = inputs['input_ids']
        token_type_ids = inputs['token_type_ids']
        attention_mask = inputs['attention_mask']
        position = inputs['position']
        transformer_inputs = input_check({'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}, self.pretrained_model)
        pretrained_output = self.pretrained_model(**transformer_inputs)
        output = pretrained_output[0]
        if self.return_hidden_pretrained and self.return_hidden:
            hidden_states = pretrained_output[1]
        if self.bertsum:
            # Sum token states per sentence, then pool the sentence states into a single vector.
            output = scatter_add(output, inputs['sentence_batch'], dim=-2)
            #output = self.pooling_1(output, inputs['sentence_batch'])
            #output = self.fnn_1(output)
            output = self.pooling_2(output)
            output = output.squeeze()
            output = self.fnn_2(output)
        else:
            output = self.sequence_summary(output)
        # Add the paragraph positional encoding vector.
        pos_emb = pos_encoding(position, self.pretrained_hidden).to(self.device, dtype=torch.float)
        output = torch.add(output, pos_emb)
        output = self.pos_emb_layer(output)
        if self.return_hidden and not self.return_hidden_pretrained:
            hidden_states = output
        for layer in self.classifiers:
            output = layer(output)
        logits = output
        if 'labels' in inputs:
            loss = self.classification_loss_f(inputs, logits)
        else:
            loss = None
        return loss, output, hidden_states

    def classification_loss_f(self, inputs, logits):
        labels = inputs['labels']
        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return loss
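
# Illustration (not part of the original module): classification_loss_f dispatches on
# config.problem_type, so the label tensors must match the chosen loss. A minimal check,
# assuming num_labels == 3:
#
#   logits = torch.randn(4, 3)
#   CrossEntropyLoss()(logits, torch.tensor([0, 2, 1, 0]))             # single_label_classification
#   BCEWithLogitsLoss()(logits, torch.randint(0, 2, (4, 3)).float())   # multi_label_classification
#   MSELoss()(logits, torch.randn(4, 3))                               # regression with num_labels > 1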

class GATpooling(nn.Module):
    """Gated attention pooling: a learned scalar gate per token, softmax-normalized, then a weighted sum."""
    def __init__(self, hidden_size, device=None):
        super(GATpooling, self).__init__()
        self.gate_nn = nn.Linear(hidden_size, 1)
        self.device = device

    def forward(self, x, batch=None):
        if batch is None:
            # Without explicit segment indices, treat the whole sequence as a single segment.
            batch = torch.zeros(x.shape[-2], dtype=torch.long).to(self.device)
        gate = self.gate_nn(x)          # [..., num_tokens, 1]
        # Normalize the gates over the token axis; dim=-1 would be a singleton dimension here.
        gate = F.softmax(gate, dim=-2)
        out = scatter_add(gate * x, batch, dim=-2)
        return out
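
# Illustration (assumption, not original code): GATpooling reduces the token axis to a
# single summary vector, e.g. (requires torch_scatter):
#
#   pool = GATpooling(hidden_size=768)
#   x = torch.randn(20, 768)          # [num_tokens, hidden_size]
#   pool(x).shape                     # torch.Size([1, 768])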

class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the
            actual config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "mean")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional
                leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification
                token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
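
# Minimal usage sketch (assumption, not original code): SequenceSummary only needs a config
# object exposing the summary_* attributes, so a SimpleNamespace is enough to exercise the
# "mean" summary on dummy hidden states:
#
#   from types import SimpleNamespace
#   demo_config = SimpleNamespace(summary_type="mean", summary_use_proj=False,
#                                 summary_activation=None, summary_first_dropout=0.0,
#                                 summary_last_dropout=0.0, hidden_size=768, num_labels=2)
#   summary = SequenceSummary(demo_config)
#   summary(torch.randn(4, 16, 768)).shape   # torch.Size([4, 768])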