# pure-qwen2-0.5b / modeling_pure_qwen2.py
import os
from typing import Union, Tuple, Optional, List
import torch
import torch.nn as nn
from torch.autograd import Function
from transformers import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2DecoderLayer, Qwen2RMSNorm, Qwen2RotaryEmbedding
)
from transformers.utils import logging
from transformers.cache_utils import (
Cache, DynamicCache, SlidingWindowCache, StaticCache
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
SequenceClassifierOutput,
BaseModelOutputWithPast
)
from transformers.utils import ModelOutput
from .configuration_pure_qwen2 import PureQwen2Config
logger = logging.get_logger(__name__)
class CovarianceFunction(Function):
@staticmethod
def forward(ctx, inputs):
x = inputs
b, c, h, w = x.data.shape
m = h * w
x = x.view(b, c, m)
I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
1.0 / m
) * torch.eye(m, m, device=x.device)
I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
y = x @ I_hat @ x.transpose(-1, -2)
ctx.save_for_backward(inputs, I_hat)
return y
@staticmethod
def backward(ctx, grad_output):
inputs, I_hat = ctx.saved_tensors
x = inputs
b, c, h, w = x.data.shape
m = h * w
x = x.view(b, c, m)
grad_input = grad_output + grad_output.transpose(1, 2)
grad_input = grad_input @ x @ I_hat
grad_input = grad_input.reshape(b, c, h, w)
return grad_input
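# A note on the math (derived from the code above): per sample, the forward pass computes
# y = X @ I_hat @ X^T with X of shape [c, m], where I_hat = (1/m) * (I - (1/m) * 1 1^T) is the
# centering matrix over the m = h * w positions, i.e. a (scaled) covariance between the c
# channel vectors across the m positions. The backward pass uses dL/dX = (G + G^T) @ X @ I_hat
# for an upstream gradient G, which holds because I_hat is symmetric.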
class Covariance(nn.Module):
def __init__(self):
super(Covariance, self).__init__()
def _covariance(self, x):
return CovarianceFunction.apply(x)
def forward(self, x):
        # x is expected to be a 2D [seq_len, embed_dim] matrix here; it is transposed and
        # lifted to 4D so that CovarianceFunction returns an [embed_dim, embed_dim] covariance.
if x.dim() == 2:
x = x.transpose(-1, -2)
C = self._covariance(x[None, :, :, None])
C = C.squeeze(dim=0)
return C
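# Usage sketch (illustrative only, not part of the original model code): feeding a single
# [T, F] token-embedding matrix through Covariance yields an [F, F] feature covariance
# estimated over the T tokens, which is how PURE._compute_pc uses it when covariance is enabled.
def _example_covariance_usage():
    x = torch.randn(10, 32)      # [T, F] token embeddings for one sequence
    cov = Covariance()(x)        # [F, F] covariance between feature dimensions
    return cov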
class PFSA(torch.nn.Module):
"""
https://openreview.net/pdf?id=isodM5jTA7h
"""
def __init__(self, input_dim, alpha=1):
super(PFSA, self).__init__()
self.input_dim = input_dim
self.alpha = alpha
def forward_one_sample(self, x):
x = x.transpose(1, 2)[..., None]
k = torch.mean(x, dim=[-1, -2], keepdim=True)
kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True)) # [B, 1, 1, 1]
qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True)) # [B, 1, T, 1]
C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
A = (1 - torch.sigmoid(C_qk)) ** self.alpha
out = x * A
out = out.squeeze(dim=-1).transpose(1, 2)
return out
def forward(self, input_values, attention_mask=None):
"""
x: [B, T, F]
"""
out = []
b, t, f = input_values.shape
for x, mask in zip(input_values, attention_mask):
x = x.view(1, t, f)
# x_in = x[:, :sum(mask), :]
x_in = x[:, :int(mask.sum().item()), :]
x_out = self.forward_one_sample(x_in)
x_expanded = torch.zeros_like(x, device=x.device)
x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
out.append(x_expanded)
out = torch.vstack(out)
out = out.view(b, t, f)
return out
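# Usage sketch (illustrative only, not part of the original model code): PFSA re-weights each
# unpadded frame by (1 - sigmoid(corr))^alpha, where corr is the correlation between that
# frame's features and the mean frame; padded positions come back as zeros.
def _example_pfsa_usage():
    feats = torch.randn(2, 5, 16)                          # [B, T, F]
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])                 # [B, T] padding mask
    pfsa = PFSA(input_dim=16, alpha=1)
    return pfsa(feats, attention_mask=mask)                # [B, T, F]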
class PURE(torch.nn.Module):
def __init__(
self,
in_dim,
svd_rank=16,
num_pc_to_remove=1,
center=False,
num_iters=2,
alpha=1,
disable_pcr=False,
disable_pfsa=False,
disable_covariance=True,
*args, **kwargs
):
super().__init__()
self.in_dim = in_dim
self.svd_rank = svd_rank
self.num_pc_to_remove = num_pc_to_remove
self.center = center
self.num_iters = num_iters
self.do_pcr = not disable_pcr
self.do_pfsa = not disable_pfsa
self.do_covariance = not disable_covariance
self.attention = PFSA(in_dim, alpha=alpha)
def _compute_pc(self, X, attention_mask):
"""
x: (B, T, F)
"""
pcs = []
bs, seqlen, dim = X.shape
for x, mask in zip(X, attention_mask):
rank = int(mask.sum().item())
x = x[:rank, :]
if self.do_covariance:
x = Covariance()(x)
q = self.svd_rank
else:
q = min(self.svd_rank, rank)
_, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
# _, _, Vh = torch.linalg.svd(x_, full_matrices=False)
# V = Vh.mH
pc = V.transpose(0, 1)[:self.num_pc_to_remove, :] # pc: [K, F]
pcs.append(pc)
# pcs = torch.vstack(pcs)
# pcs = pcs.view(bs, self.num_pc_to_remove, dim)
return pcs
def _remove_pc(self, X, pcs):
"""
[B, T, F], [B, ..., F]
"""
b, t, f = X.shape
out = []
for i, (x, pc) in enumerate(zip(X, pcs)):
# v = []
# for j, t in enumerate(x):
# t_ = t
# for c_ in c:
# t_ = t_.view(f, 1) - c_.view(f, 1) @ c_.view(1, f) @ t.view(f, 1)
# v.append(t_.transpose(-1, -2))
# v = torch.vstack(v)
v = x - x @ pc.transpose(0, 1) @ pc
out.append(v[None, ...])
out = torch.vstack(out)
return out
def forward(self, input_values, attention_mask=None, *args, **kwargs):
"""
PCR -> Attention
x: (B, T, F)
"""
x = input_values
if self.do_pcr:
            pc = self._compute_pc(x, attention_mask)  # list of per-sample [K, F] PCs
xx = self._remove_pc(x, pc)
# xx = xt - xt @ pc.transpose(1, 2) @ pc # [B, T, F] * [B, F, K] * [B, K, F] = [B, T, F]
else:
xx = x
if self.do_pfsa:
xx = self.attention(xx, attention_mask)
return xx
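# Usage sketch (illustrative only, not part of the original model code): with the default
# settings PURE removes, per sample, each token's projection onto the top `num_pc_to_remove`
# principal directions (v = x - x @ pc^T @ pc) and then applies PFSA.
def _example_pure_usage():
    hidden = torch.randn(2, 6, 32)                         # [B, T, F]
    mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1, 1]])              # [B, T] padding mask
    pure = PURE(in_dim=32, svd_rank=4, num_pc_to_remove=1)
    return pure(hidden, attention_mask=mask)               # [B, T, F]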
class StatisticsPooling(torch.nn.Module):
def __init__(self, return_mean=True, return_std=True):
super().__init__()
# Small value for GaussNoise
self.eps = 1e-5
self.return_mean = return_mean
self.return_std = return_std
if not (self.return_mean or self.return_std):
            raise ValueError(
                "both return_mean and return_std are False; "
                "enable at least one statistic for pooling"
            )
def forward(self, input_values, attention_mask=None):
"""Calculates mean and std for a batch (input tensor).
Arguments
---------
x : torch.Tensor
It represents a tensor for a mini-batch.
"""
x = input_values
if attention_mask is None:
if self.return_mean:
mean = x.mean(dim=1)
if self.return_std:
std = x.std(dim=1)
else:
mean = []
std = []
for snt_id in range(x.shape[0]):
# Avoiding padded time steps
lengths = torch.sum(attention_mask, dim=1)
relative_lengths = lengths / torch.max(lengths)
actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
# actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
# computing statistics
if self.return_mean:
mean.append(
torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
)
if self.return_std:
std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
if self.return_mean:
mean = torch.stack(mean)
if self.return_std:
std = torch.stack(std)
if self.return_mean:
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
if self.return_std:
std = std + self.eps
# Append mean and std of the batch
if self.return_mean and self.return_std:
pooled_stats = torch.cat((mean, std), dim=1)
pooled_stats = pooled_stats.unsqueeze(1)
elif self.return_mean:
pooled_stats = mean.unsqueeze(1)
elif self.return_std:
pooled_stats = std.unsqueeze(1)
return pooled_stats
def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
"""Returns a tensor of epsilon Gaussian noise.
Arguments
---------
shape_of_tensor : tensor
It represents the size of tensor for generating Gaussian noise.
"""
gnoise = torch.randn(shape_of_tensor, device=device)
gnoise -= torch.min(gnoise)
gnoise /= torch.max(gnoise)
gnoise = self.eps * ((1 - 9) * gnoise + 9)
return gnoise
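# Usage sketch (illustrative only, not part of the original model code): with a padding mask,
# the pooled mean of each sequence is taken over its first `round(len_i / max_len * T)`
# positions, and a small Gaussian perturbation (scaled by eps) is added.
def _example_statistics_pooling_usage():
    pooling = StatisticsPooling(return_mean=True, return_std=False)
    hidden = torch.randn(2, 4, 8)                          # [B, T, F]
    mask = torch.tensor([[1, 1, 1, 1],
                         [1, 1, 0, 0]])                    # [B, T] padding mask
    return pooling(hidden, attention_mask=mask)            # [B, 1, F]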
class PureQwen2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = PureQwen2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class PureQwen2Model(PureQwen2PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
Args:
config: Qwen2Config
"""
def __init__(self, config: PureQwen2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co./docs/transformers/kv_cache#legacy-cache-format)"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: PureQwen2Config,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static
                cache, to account for the zero padding (the part of the cache that is not filled yet).
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`PureQwen2Config`):
                The model's configuration class.
            past_key_values (`Cache`):
                The cache class that is currently being used to generate.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
                # If we have a sliding window, we should not attend to tokens beyond the sliding window length,
                # so we mask them out here as well. The check verifies whether the current checkpoint was trained
                # with a sliding window or not.
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
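# Sketch of the mask construction (illustrative only, not part of the original model code):
# a 2D padding mask is expanded into a 4D additive mask where disallowed positions hold the
# dtype minimum. The stand-in config below only carries the `sliding_window` attribute that
# the static method reads.
def _example_causal_mask_sketch():
    from types import SimpleNamespace
    attn = torch.tensor([[1, 1, 1, 0]])                    # [B, T] padding mask with one padded position
    return PureQwen2Model._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=attn,
        sequence_length=4,
        target_length=4,
        dtype=torch.float32,
        device=torch.device("cpu"),
        cache_position=torch.arange(4),
        batch_size=1,
        config=SimpleNamespace(sliding_window=None),       # stand-in for PureQwen2Config
        past_key_values=None,
    )                                                      # [1, 1, 4, 4]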
class PureQwen2ForSequenceClassification(PureQwen2PreTrainedModel):
def __init__(
self,
config,
label_smoothing=0.0,
):
super().__init__(config)
self.label_smoothing = label_smoothing
self.num_labels = config.num_labels
self.config = config
self.model = PureQwen2Model(config)
self.pure = PURE(
in_dim=config.hidden_size,
svd_rank=config.svd_rank,
num_pc_to_remove=config.num_pc_to_remove,
center=config.center,
num_iters=config.num_iters,
alpha=config.alpha,
disable_pcr=config.disable_pcr,
disable_pfsa=config.disable_pfsa,
disable_covariance=config.disable_covariance
)
self.mean = StatisticsPooling(return_mean=True, return_std=False)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def forward_pure_embeddings(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
token_embeddings = transformer_outputs[0]
token_embeddings = self.pure(token_embeddings, attention_mask)
return ModelOutput(
last_hidden_state=token_embeddings,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
# position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
token_embeddings = outputs[0]
token_embeddings = self.pure(token_embeddings, attention_mask)
pooled_output = self.mean(token_embeddings).squeeze(1)
logits = self.score(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
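# End-to-end usage sketch (illustrative only; assumes the checkpoint shipping with this file,
# e.g. "yangwang825/pure-qwen2-0.5b", and a tokenizer with a pad token configured):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("yangwang825/pure-qwen2-0.5b")
#   model = PureQwen2ForSequenceClassification.from_pretrained(
#       "yangwang825/pure-qwen2-0.5b", num_labels=2
#   )
#   batch = tok(["an example sentence", "another one"], return_tensors="pt", padding=True)
#   out = model(**batch, labels=torch.tensor([0, 1]))
#   print(out.loss, out.logits.shape)      # scalar loss, [2, 2] logits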