# coding=utf-8 # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch CPMAnt""" import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_cpmant import CpmAntConfig logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "openbmb/cpm-ant-10b" _CONFIG_FOR_DOC = "CpmAntConfig" CPMANT_PRETRAINED_MODEL_ARCHIVE_LIST = [ "openbmb/cpm-ant-10b", # See all CPMAnt models at https://huggingface.co./models?filter=cpmant ] class CpmAntLayerNorm(nn.Module): """ We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details." """ def __init__(self, config: CpmAntConfig): super().__init__() self.eps = config.eps self.dim_norm = config.hidden_size self.weight = nn.Parameter(torch.empty(config.hidden_size)) def forward(self, hidden_states: torch.Tensor): """ Args: hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`) """ if hidden_states.size(-1) != self.dim_norm: raise AssertionError("hidden_states.size(-1) != self.dim_norm") old_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight return hidden_states class CpmAntAttention(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.dim_model = config.hidden_size self.num_heads = config.num_attention_heads self.dim_head = config.dim_head self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False) self.softmax = torch.nn.Softmax(dim=-1) if config.dropout_p is not None: self.dropout = torch.nn.Dropout(p=config.dropout_p) else: self.dropout = None def forward( self, hidden_q: torch.Tensor, hidden_kv: torch.Tensor, attention_mask: torch.BoolTensor, position_bias: torch.Tensor, output_attentions: Optional[bool] = False, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: Optional[bool] = None, ): """ Args: hidden_q (`torch.Tensor`): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences. hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)): Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)` attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`): Avoid invalid areas to participate in the calculation of self-attention. position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`): Provide positional information to self-attention block. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): Cached past key and value projection states. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ batch_size = hidden_q.size(0) len_q = hidden_q.size(1) len_k = hidden_kv.size(1) query = self.project_q(hidden_q) key = self.project_k(hidden_kv) value = self.project_v(hidden_kv) query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3) key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3) value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3) if past_key_values is not None: key = torch.cat([past_key_values[0], key], dim=-2) value = torch.cat([past_key_values[1], value], dim=-2) len_k = key.size(-2) # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k) score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head) score = score + position_bias score = torch.masked_fill( score, attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False), torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype), ) score = self.softmax(score) score = torch.masked_fill( score, attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False), torch.scalar_tensor(0, device=score.device, dtype=score.dtype), ) if output_attentions: attn_weights = score else: attn_weights = None if self.dropout is not None: score = self.dropout(score) # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head) score = torch.matmul(score, value) score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3) score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head) score = self.attention_out(score) past_key_values = None if use_cache: past_key_values = (key, value) return score, attn_weights, past_key_values class CpmAntSelfAttentionBlock(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.layernorm_before_attention = CpmAntLayerNorm(config) self.self_attention = CpmAntAttention(config) if config.dropout_p: self.dropout = torch.nn.Dropout(config.dropout_p) else: self.dropout = None def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: Optional[bool] = None, ): """ Args: hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences. attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`): Avoid invalid areas to participate in the calculation of self-attention. position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`): Provide positional information to self-attention block. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. past_key_values (`Tuple(torch.FloatTensor)`, *optional*): Cached past key and value projection states. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ outputs = self.layernorm_before_attention(hidden_states) outputs = self.self_attention( outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache ) outputs, attn_weights, current_key_value = outputs if self.dropout is not None: outputs = self.dropout(outputs) hidden_states = hidden_states + outputs return hidden_states, attn_weights, current_key_value class CpmAntDenseGatedACT(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False) self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False) self.act = torch.nn.GELU() def forward(self, hidden_states: torch.Tensor): """Transform an input tensor from one feature space to another via a nonlinear operation Args: hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`) """ gate_score = self.act(self.w_0(hidden_states)) hidden_states = self.w_1(hidden_states) hidden_states = gate_score * hidden_states return hidden_states class CpmAntFeedForward(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.w_in = CpmAntDenseGatedACT(config) if config.dropout_p is not None: self.dropout = torch.nn.Dropout(config.dropout_p) else: self.dropout = None self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False) def forward(self, hidden_states: torch.Tensor): """ Args: hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`) """ hidden_states = self.w_in(hidden_states) if self.dropout is not None: hidden_states = self.dropout(hidden_states) hidden_states = self.w_out(hidden_states) return hidden_states class CpmAntFFNBlock(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.layernorm_before_ffn = CpmAntLayerNorm(config) self.ffn = CpmAntFeedForward(config) if config.dropout_p: self.dropout = torch.nn.Dropout(config.dropout_p) else: self.dropout = None def forward( self, hidden_states: torch.Tensor, ): """ Args: hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`): Hidden states before feed forward layer. """ ln_outputs = self.layernorm_before_ffn(hidden_states) outputs = self.ffn(ln_outputs) if self.dropout is not None: outputs = self.dropout(outputs) hidden_states = hidden_states + outputs return hidden_states class CpmAntTransformerBlock(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.self_att = CpmAntSelfAttentionBlock(config) self.ffn = CpmAntFFNBlock(config) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: Optional[bool] = None, ): """ Args: hidden_states (`torch.Tensor`): Input to the layer of shape `(batch, seq_len, dim_model)` attention_mask (`torch.Tensor`): Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)` position_bias (`torch.Tensor`): Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)` output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*): Cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ hidden_states = self.self_att( hidden_states, attention_mask=attention_mask, position_bias=position_bias, output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, ) hidden_states, attn_weights, current_key_value = hidden_states hidden_states = self.ffn(hidden_states) return hidden_states, attn_weights, current_key_value class CpmAntEncoder(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.num_layers = config.num_hidden_layers self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for ith in range(self.num_layers)]) self.output_layernorm = CpmAntLayerNorm(config) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: torch.Tensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: Optional[bool] = None, ): """ Args: hidden_states (`torch.Tensor`): Input to the layer of shape `(batch, seq_len, dim_model)` attention_mask (`torch.Tensor`): Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)` position_bias (`torch.Tensor`): Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)` output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*): Cached past key and value projection states use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None current_key_values = () if use_cache else None for i, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = layer( hidden_states, attention_mask, position_bias, output_attentions=output_attentions, past_key_values=past_key_values[i] if past_key_values else None, use_cache=use_cache, ) hidden_states, attn_weights, current_key_value = layer_outputs if output_attentions: all_self_attns += (attn_weights,) if current_key_value is not None: current_key_values = current_key_values + (current_key_value,) hidden_states = self.output_layernorm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) return hidden_states, current_key_values, all_hidden_states, all_self_attns # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt class CpmAntIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class CpmAntSegmentPositionEmbedding(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.num_heads = config.num_attention_heads self.num_buckets = config.position_bias_num_buckets self.max_distance = config.position_bias_max_distance self.num_segments = config.segment_types self.relative_attention_bias = nn.Parameter( torch.empty( config.segment_types * config.segment_types + config.position_bias_num_buckets, config.num_attention_heads, ) ) def forward( self, key_pos: torch.Tensor, query_pos: torch.Tensor, key_segment: torch.Tensor, query_segment: torch.Tensor, ): with torch.no_grad(): batch = key_pos.size(0) keylen = key_pos.size(1) querylen = query_pos.size(1) if key_pos.size(0) != query_pos.size(0): raise AssertionError( f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!" ) if keylen != key_segment.size(1) or querylen != query_segment.size(1): raise AssertionError( f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!" ) if querylen != query_segment.size(1): raise AssertionError( f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.szie(1)}!" ) key_pos = key_pos.view(batch, -1, keylen) query_pos = query_pos.view(batch, querylen, -1) key_segment = key_segment.view(batch, -1, keylen) query_segment = query_segment.view(batch, querylen, -1) relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment) relative_position_bucket = relative_position_bucket + self.num_buckets # (batch, len_q, len_k) absolute_position_bucket = self._position_bucket( torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :] - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None], num_buckets=self.num_buckets, max_distance=self.max_distance, ) relative_position_bucket = torch.where( (key_segment == query_segment), absolute_position_bucket[None, :, :], relative_position_bucket, ) # (batch, len_q, len_k, num_heads) embeds = F.embedding(relative_position_bucket, self.relative_attention_bias) # (batch, num_heads, len_q, len_k) embeds = embeds.permute(0, 3, 1, 2).contiguous() return embeds def _segment_relative_position_bucket(self, query_segment, key_segment): return query_segment * self.num_segments + key_segment def _position_bucket(self, relative_position, num_buckets=32, max_distance=128): relative_buckets = 0 # always bidirectional in CPMAnt num_buckets //= 2 relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets relative_position = torch.abs(relative_position) max_exact = num_buckets // 2 is_small = relative_position < max_exact relative_postion_if_large = max_exact + ( torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.int32) relative_postion_if_large = torch.min( relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1), ) relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large) return relative_buckets # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt class CpmAntOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class CpmAntPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = CpmAntConfig base_model_prefix = "cpmant" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.init_std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, CpmAntLayerNorm): module.weight.data.fill_(1.0) elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, CpmAntEncoder): module.gradient_checkpointing = value CPMANT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ CPMANT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare CPMAnt Model outputting raw hidden-states without any specific head on top.", CPMANT_START_DOCSTRING, ) class CpmAntModel(CpmAntPreTrainedModel): def __init__(self, config: CpmAntConfig): super().__init__(config) self.encoder = CpmAntEncoder(config) self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size) self.input_embedding = nn.Embedding( config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size ) self.position_bias = CpmAntSegmentPositionEmbedding(config) self.prompt_length = config.prompt_length self.vocab_size = config.vocab_size self.post_init() def get_input_embeddings(self): return self.input_embedding def set_input_embeddings(self, embeddings, **kwargs): self.input_embedding = embeddings def _prepare_attention_mask(self, input_ids, span, context, length): batch = input_ids.size(0) seqlen = input_ids.size(1) device = input_ids.device directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1) attention_mask = context[:, None, :] | ( context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen) ) attention_mask = attention_mask & (span[:, None, :] == span[:, :, None]) # mask for left padding mask_1d = ( torch.tensor(list(range(seqlen - self.prompt_length))[::-1], device=device)[None, :].repeat(batch, 1) < length[:, None] ) mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1) attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask return attention_mask @add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache # add prompts ahead if input_ids.dtype != torch.int32: input_ids = input_ids.to(torch.int32) dtype, device = input_ids.dtype, input_ids.device segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device) length = (segment != 0).sum(-1).to(dtype=dtype, device=device) input_ids = torch.cat( ( torch.arange( self.prompt_length * 2 + self.vocab_size, self.prompt_length * 3 + self.vocab_size, dtype=dtype, device=device, ).repeat(input_ids.size(0), 1), input_ids, ), dim=1, ) batch, seq_length = input_ids.size() segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1) context = torch.full((batch, seq_length), 1, dtype=dtype, device=device) position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1) span = torch.full((batch, seq_length), 0, dtype=dtype, device=device) if past_key_values is None: past_length = 0 past_key_values = tuple([None] * self.encoder.num_layers) input_ids = input_ids.contiguous() hidden_states = self.input_embedding(input_ids) segment_states = self.segment_embedding(segment) hidden_states = hidden_states + segment_states else: past_length = past_key_values[0][0].size(-2) segment_states = self.segment_embedding(segment) hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :] attention_mask = self._prepare_attention_mask(input_ids, span, context, length) position_bias = self.position_bias(position, position, segment, segment) attention_mask = attention_mask[:, past_length:, :] position_bias = position_bias[:, :, past_length:, :] hidden_states = hidden_states[:, past_length:, :] hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder( hidden_states, attention_mask, position_bias, output_attentions, output_hidden_states, past_key_values, use_cache, ) if past_length == 0: hidden_states = hidden_states[:, self.prompt_length :, :] # drop the prompt if all_attentions is not None: new_attentions = () for attention in all_attentions: new_attentions += (attention[:, :, self.prompt_length :, self.prompt_length :],) all_attentions = new_attentions if all_hidden_states is not None: new_hidden_states = () for hidden_state in all_hidden_states: new_hidden_states += (hidden_state[:, self.prompt_length :, :],) all_hidden_states = new_hidden_states if not return_dict: return tuple( v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) @add_start_docstrings( """ The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings). """, CPMANT_START_DOCSTRING, ) class CpmAntForCausalLM(CpmAntPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: CpmAntConfig): super().__init__(config) self.cpmant = CpmAntModel(config) # lm_head.weight is tied to cpmant.input_embedding.weight self.lm_head = nn.Linear( config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False ) self.post_init() @add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, attention_mask: Optional[torch.Tensor] = None, # dummy parameter for text-generation pipeline **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): CPMAnt will process attention mask automatically, this parameter is a dummy parameter for text-generation pipeline. Example: Text Generation with CpmAntForCausalLM. ```python >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM >>> texts = "今天天气不错," >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b") >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b") >>> input_ids = tokenizer(texts, return_tensors="pt") >>> outputs = model.generate(**input_ids) >>> output_texts = tokenizer.batch_decode(outputs) >>> print(output_texts) ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的'] ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict model_output = self.cpmant( input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict ) hidden_states = model_output.last_hidden_state if return_dict else model_output[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: loss_func = CrossEntropyLoss() loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1)) if not return_dict: output = (logits,) + model_output[1:] return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=model_output.past_key_values, hidden_states=model_output.hidden_states, attentions=model_output.attentions, ) def get_input_embeddings(self): return self.cpmant.input_embedding def set_input_embeddings(self, embeddings): self.cpmant.input_embedding = embeddings def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, **kwargs): input_ids = input_ids.int() # save the memory usage of dummy attention mask if "attention_mask" in kwargs: kwargs["attention_mask"] = torch.zeros(1, 1) return { "input_ids": input_ids, "use_cache": kwargs["use_cache"], "past_key_values": kwargs.get("past_key_values", None), } def _reorder_cache(self, past_key_values, beam_idx): past_key_values = [list(each) if each is not None else each for each in past_key_values] for key_value_layer in past_key_values: key_value_layer[0] = key_value_layer[0][beam_idx] key_value_layer[1] = key_value_layer[1][beam_idx] return past_key_values