import os
from typing import Union, Tuple, Optional, List

import torch
import torch.nn as nn
from torch.autograd import Function
from transformers import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding
)
from transformers.utils import logging
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    SlidingWindowCache,
    StaticCache
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    BaseModelOutputWithPast
)
from transformers.utils import ModelOutput

from .configuration_pure_qwen2 import PureQwen2Config


logger = logging.get_logger(__name__)


def _valid_lengths(attention_mask, batch_size, seq_len):
    """Return, per sample, the number of leading positions to treat as real tokens.

    Args:
        attention_mask: ``None`` (no padding) or a 2D ``(batch_size, seq_len)``
            mask whose row sums give the number of real tokens.
        batch_size: number of samples.
        seq_len: padded sequence length; lengths are clamped to it.

    Returns:
        A list of ``batch_size`` Python ints.

    NOTE(review): like the rest of this module, this assumes RIGHT padding
    (real tokens form a prefix of each row) — confirm the tokenizer setting.
    """
    if attention_mask is None:
        return [seq_len] * batch_size
    # One host sync for the whole batch instead of one `.item()` per sample.
    return [min(int(n), seq_len) for n in attention_mask.sum(dim=1).tolist()]


class CovarianceFunction(Function):
    """Autograd function computing per-batch channel covariance.

    Input ``[b, c, h, w]`` is flattened to ``[b, c, m]`` with ``m = h * w`` and
    ``y = x @ I_hat @ x^T`` is returned (shape ``[b, c, c]``), where
    ``I_hat = (1/m) * I - (1/m^2) * 1 1^T``. This equals the biased (divide by
    ``m``) covariance of the ``c`` channels over the ``m`` samples.
    """

    @staticmethod
    def forward(ctx, inputs):
        x = inputs
        b, c, h, w = x.data.shape
        m = h * w
        x = x.view(b, c, m)
        # Centering matrix scaled by 1/m: x @ I_hat @ x^T == E[x x^T] - mu mu^T.
        I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
            1.0 / m
        ) * torch.eye(m, m, device=x.device)
        I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
        y = x @ I_hat @ x.transpose(-1, -2)
        ctx.save_for_backward(inputs, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        inputs, I_hat = ctx.saved_tensors
        x = inputs
        b, c, h, w = x.data.shape
        m = h * w
        x = x.view(b, c, m)
        # d(x A x^T)/dx = (G + G^T) x A for symmetric A (I_hat is symmetric).
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input @ x @ I_hat
        grad_input = grad_input.reshape(b, c, h, w)
        return grad_input


class Covariance(nn.Module):
    """Module wrapper around :class:`CovarianceFunction` for 2D inputs."""

    def __init__(self):
        super(Covariance, self).__init__()

    def _covariance(self, x):
        return CovarianceFunction.apply(x)

    def forward(self, x):
        """Return the ``[F, F]`` feature covariance of a 2D ``[T, F]`` input.

        Only 2D input is supported: it is transposed to ``[F, T]`` and reshaped
        to ``[1, F, T, 1]`` before calling :class:`CovarianceFunction`.
        """
        if x.dim() == 2:
            x = x.transpose(-1, -2)
        C = self._covariance(x[None, :, :, None])
        C = C.squeeze(dim=0)
        return C


class PFSA(torch.nn.Module):
    """Parameter-free self-attention re-weighting of token features.

    Reference: https://openreview.net/pdf?id=isodM5jTA7h

    Each token is scaled by ``(1 - sigmoid(corr)) ** alpha``, where ``corr`` is
    the correlation (over the feature axis) between that token and the mean
    token of the sample.
    """

    def __init__(self, input_dim, alpha=1):
        super(PFSA, self).__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward_one_sample(self, x):
        """Apply PFSA to a single un-padded sample ``x`` of shape ``[1, T, F]``."""
        x = x.transpose(1, 2)[..., None]  # [1, F, T, 1]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)  # mean token: [1, F, 1, 1]
        kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
        qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
        # Pearson-style correlation between every token and the mean token.
        C_qk = (
            ((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)
        ) / (qd * kd)
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        out = out.squeeze(dim=-1).transpose(1, 2)  # back to [1, T, F]
        return out

    def forward(self, input_values, attention_mask=None):
        """Apply PFSA per sample, ignoring padding.

        Args:
            input_values: ``[B, T, F]`` features.
            attention_mask: optional ``[B, T]`` mask; ``None`` means no padding.

        Returns:
            ``[B, T, F]`` tensor; padded positions are set to zero.
        """
        b, t, f = input_values.shape
        lengths = _valid_lengths(attention_mask, b, t)
        out = []
        for x, n in zip(input_values, lengths):
            x = x.view(1, t, f)
            x_out = self.forward_one_sample(x[:, :n, :])
            # Padding positions stay zero so they cannot leak into later pooling.
            x_expanded = torch.zeros_like(x)
            x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
            out.append(x_expanded)
        return torch.vstack(out).view(b, t, f)


class PURE(torch.nn.Module):
    """Principal-component removal (PCR) followed by PFSA.

    Args:
        in_dim: feature dimension ``F``.
        svd_rank: ``q`` passed to ``torch.pca_lowrank``.
        num_pc_to_remove: number of leading principal directions removed.
        center: ``center`` flag for ``torch.pca_lowrank``.
        num_iters: ``niter`` for ``torch.pca_lowrank``.
        alpha: PFSA exponent.
        disable_pcr / disable_pfsa: skip the corresponding stage.
        disable_covariance: when False, PCA runs on the ``[F, F]`` covariance
            matrix of each sample instead of on its ``[T, F]`` tokens.
    """

    def __init__(
        self,
        in_dim,
        svd_rank=16,
        num_pc_to_remove=1,
        center=False,
        num_iters=2,
        alpha=1,
        disable_pcr=False,
        disable_pfsa=False,
        disable_covariance=True,
        *args, **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.svd_rank = svd_rank
        self.num_pc_to_remove = num_pc_to_remove
        self.center = center
        self.num_iters = num_iters
        self.do_pcr = not disable_pcr
        self.do_pfsa = not disable_pfsa
        self.do_covariance = not disable_covariance
        self.attention = PFSA(in_dim, alpha=alpha)

    def _compute_pc(self, X, attention_mask):
        """Return a list with one ``[K, F]`` tensor of principal directions per sample.

        ``X`` is ``[B, T, F]``; only the un-padded prefix of each sample is used.
        """
        pcs = []
        bs, seqlen, dim = X.shape
        lengths = _valid_lengths(attention_mask, bs, seqlen)
        cov = Covariance() if self.do_covariance else None
        for x, rank in zip(X, lengths):
            x = x[:rank, :]
            if cov is not None:
                x = cov(x)  # [F, F]
                q = self.svd_rank
            else:
                # pca_lowrank needs q <= number of rows.
                q = min(self.svd_rank, rank)
            _, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
            pc = V.transpose(0, 1)[:self.num_pc_to_remove, :]  # pc: [K, F]
            pcs.append(pc)
        return pcs

    def _remove_pc(self, X, pcs):
        """Project each sample of ``X`` ``[B, T, F]`` off its ``[K, F]`` directions."""
        out = []
        for x, pc in zip(X, pcs):
            # [T, F] - [T, F] @ [F, K] @ [K, F]
            v = x - x @ pc.transpose(0, 1) @ pc
            out.append(v[None, ...])
        return torch.vstack(out)

    def forward(self, input_values, attention_mask=None, *args, **kwargs):
        """Run PCR then PFSA on ``input_values`` ``[B, T, F]``.

        ``attention_mask`` is an optional ``[B, T]`` mask (``None`` = no padding).
        """
        x = input_values
        if self.do_pcr:
            pc = self._compute_pc(x, attention_mask)  # B x [K, F]
            xx = self._remove_pc(x, pc)
        else:
            xx = x
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        return xx


class StatisticsPooling(torch.nn.Module):
    """Mean and/or standard-deviation pooling over the time axis (dim 1).

    Adapted from SpeechBrain's ``StatisticsPooling``. A small positive noise
    term in ``[eps, 9 * eps]`` is added to the mean (see ``_get_gauss_noise``)
    and ``eps`` is added to the std.
    """

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both of statistics are equal to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, attention_mask=None):
        """Calculate mean and/or std over dim 1 of a ``[B, T, ...]`` batch.

        Args:
            input_values: tensor of shape ``[B, T, ...]``.
            attention_mask: optional ``[B, T]`` mask; when given, only the first
                ``attention_mask[i].sum()`` positions of sample ``i`` are pooled
                (right padding assumed).

        Returns:
            Tensor of shape ``[B, 1, D]`` (``D`` doubles when mean and std are
            both returned).
        """
        x = input_values
        if attention_mask is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            # Absolute lengths; the previous `len / max(len) * T` rescaling
            # over-counted whenever the batch was padded past its longest sample.
            lengths = _valid_lengths(attention_mask, x.shape[0], x.shape[1])
            for snt_id, actual_size in enumerate(lengths):
                # Avoiding padded time steps
                if self.return_mean:
                    mean.append(torch.mean(x[snt_id, 0:actual_size, ...], dim=0))
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Return random noise of the given shape rescaled into ``[eps, 9 * eps]``.

        Arguments
        ---------
        shape_of_tensor : torch.Size
            Shape of the noise tensor to generate.
        device : str or torch.device
            Device on which to create the noise.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        # Min-max normalize to [0, 1], then map to [eps, 9 * eps].
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise


class PureQwen2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PureQwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding; zero biases and padding rows."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class PureQwen2Model(PureQwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: PureQwen2Config
    """

    def __init__(self, config: PureQwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the Qwen2 decoder stack and return the final normed hidden states."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co./docs/transformers/kv_cache#legacy-cache-format)"
                )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: PureQwen2Config,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`PureQwen2Config`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask


class PureQwen2ForSequenceClassification(PureQwen2PreTrainedModel):
    """Qwen2 decoder -> PURE (PCR + PFSA) -> masked mean pooling -> linear classifier."""

    def __init__(
        self,
        config,
        label_smoothing=0.0,
    ):
        super().__init__(config)
        self.label_smoothing = label_smoothing
        self.num_labels = config.num_labels
        self.config = config

        self.model = PureQwen2Model(config)
        self.pure = PURE(
            in_dim=config.hidden_size,
            svd_rank=config.svd_rank,
            num_pc_to_remove=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            disable_pcr=config.disable_pcr,
            disable_pfsa=config.disable_pfsa,
            disable_covariance=config.disable_covariance
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward_pure_embeddings(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ModelOutput]:
        r"""
        Return the token embeddings after the PURE transform (no pooling, no classifier).

        Returns a `ModelOutput` whose `last_hidden_state` has the shape of the base model's
        last hidden state. `labels` is accepted for signature compatibility with `forward`
        and is ignored.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = transformer_outputs[0]
        token_embeddings = self.pure(token_embeddings, attention_mask)

        return ModelOutput(
            last_hidden_state=token_embeddings,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a dict internally so optional outputs can be picked by name
        # regardless of whether a cache entry is present in the tuple form.
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        token_embeddings = self.pure(outputs.last_hidden_state, attention_mask)
        # Pool over real tokens only: averaging over padding made the result depend
        # on how much padding other samples in the batch introduced.
        pooled_output = self.mean(token_embeddings, attention_mask).squeeze(1)
        logits = self.score(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            extras = tuple(v for v in (outputs.hidden_states, outputs.attentions) if v is not None)
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )