sharpenb's picture
Upload folder using huggingface_hub (#1)
a1d0506 verified
"""GPT Blocks used for the GPT Model."""
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from .attention import ATTN_CLASS_REGISTRY
from .ffn import FFN_CLASS_REGISTRY, build_ffn
from .norm import NORM_CLASS_REGISTRY
# flash_attn is an optional dependency: when it cannot be imported the two
# padding helpers are bound to ``None`` and callers must check before use.
# ``except Exception`` (instead of a bare ``except``) keeps this best-effort
# fallback for any import-time failure while still letting
# KeyboardInterrupt / SystemExit propagate.
try:
    from flash_attn.bert_padding import unpad_input, pad_input
except Exception:
    unpad_input, pad_input = None, None
# Baseline attention settings. ``MPTBlock.__init__`` uses this mapping as-is
# when it is called with ``attn_config=None``.
attn_config_defaults: Dict = dict(
    attn_type="multihead_attention",
    attn_pdrop=0.0,
    attn_impl="flash",
    qk_ln=True,
    qk_gn=False,
    clip_qkv=None,
    softmax_scale=None,
    prefix_lm=False,
    attn_uses_sequence_id=False,
    sliding_window_size=-1,
    alibi=False,
    alibi_bias_max=8,
    rope=False,
    rope_theta=10000,
    rope_impl="dail",
    rope_dail_config=dict(
        type="original",
        pos_idx_in_fp32=True,
        xpos_scale_base=512,
    ),
    rope_hf_config=dict(type="no_scaling", factor=1.0),
)
class MPTBlock(nn.Module):
    """Transformer block: attention then feed-forward, each with a residual add.

    The forward pass is ``norm_1 -> attn -> dropout -> add`` followed by
    ``(optional norm_2) -> ffn -> dropout -> add``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        ffn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        fc_type: str = "torch",
        device: Optional[str] = None,
        no_bias: bool = False,
        use_pad_tok_in_ffn: bool = True,
        **kwargs: Any
    ):
        """Build the norm, attention, FFN and dropout sub-modules.

        Args:
            d_model: Forwarded to the norm, attention and FFN constructors.
            n_heads: Forwarded to the attention constructor.
            expansion_ratio: Forwarded to ``build_ffn``.
            attn_config: Attention settings. ``attn_type`` selects the class
                from ``ATTN_CLASS_REGISTRY``; keys that the attention class
                does not take (see ``args_to_exclude_in_attn_class``) are
                dropped and the rest are passed through as keyword arguments.
                Defaults to ``attn_config_defaults``.
            ffn_config: Keyword arguments for ``build_ffn``; ``ffn_type``
                also selects the entry looked up in ``FFN_CLASS_REGISTRY``.
                Defaults to ``{"ffn_type": "mptmlp"}``.
            resid_pdrop: Dropout probability applied to both residual branches.
            norm_type: Key (lower-cased) into ``NORM_CLASS_REGISTRY``.
            fc_type: Forwarded to the attention constructor.
            device: Forwarded to every sub-module constructor.
            no_bias: When true, sub-modules are built with ``bias=False``.
            use_pad_tok_in_ffn: When false, ``forward`` strips padded
                positions before the FFN and restores them afterwards.
            **kwargs: Accepted and discarded.
        """
        if attn_config is None:
            attn_config = attn_config_defaults
        if ffn_config is None:
            ffn_config = {"ffn_type": "mptmlp"}
        del kwargs
        super().__init__()
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        assert isinstance(attn_config["attn_type"], str)
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
        # Keys present in the attention config that are NOT forwarded to the
        # attention class constructor.
        args_to_exclude_in_attn_class = {
            "attn_type",
            "prefix_lm",
            "alibi",
            "attn_uses_sequence_id",
            "alibi_bias_max",
            "rope",
            "rope_theta",
            "rope_impl",
            "rope_dail_config",
            "rope_hf_config",
        }
        attn_config_subset_for_attn_class = {
            k: v
            for (k, v) in attn_config.items()
            if k not in args_to_exclude_in_attn_class
        }
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            d_model=d_model,
            n_heads=n_heads,
            fc_type=fc_type,
            device=device,
            **attn_config_subset_for_attn_class,
            bias=not no_bias
        )
        # FFN classes flagged with ``_has_norm`` get no separate pre-FFN norm.
        self.norm_2 = None
        if not getattr(FFN_CLASS_REGISTRY[ffn_config["ffn_type"]], "_has_norm", False):
            self.norm_2 = norm_class(d_model, device=device)
        self.ffn = build_ffn(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
            bias=not no_bias,
            **ffn_config
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        rotary_emb_w_meta_info: Optional[Dict] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        output_attentions: bool = False,
        alibi_slopes: Optional[torch.Tensor] = None,
        flash_attn_padding_info: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Run the attention and feed-forward sub-layers on ``x``.

        Args:
            x: Input tensor; its first two dimensions are read as
                ``(batch_size, seq_len)``.
            past_key_value: Forwarded to the attention module, which returns
                the updated value.
            attn_bias: Forwarded to the attention module.
            rotary_emb_w_meta_info: Forwarded to the attention module.
            attention_mask: Forwarded to the attention module. Also used to
                strip padding around the FFN when ``use_pad_tok_in_ffn`` is
                false; if it is ``None`` no padding is stripped.
            is_causal: Forwarded to the attention module.
            output_attentions: Forwarded to the attention module as
                ``needs_weights``.
            alibi_slopes: Forwarded to the attention module.
            flash_attn_padding_info: Forwarded to the attention module.

        Returns:
            ``(x, attn_weights, past_key_value)`` where ``x`` is the block
            output and the other two items come from the attention module.
        """
        a = self.norm_1(x)
        b, attn_weights, past_key_value = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            rotary_emb_w_meta_info=rotary_emb_w_meta_info,
            attention_mask=attention_mask,
            is_causal=is_causal,
            needs_weights=output_attentions,
            alibi_slopes=alibi_slopes,
            flash_attn_padding_info=flash_attn_padding_info,
        )
        x = x + self.resid_attn_dropout(b)
        m = x
        if self.norm_2 is not None:
            m = self.norm_2(x)
        batch_size, seq_len = m.size()[:2]
        # Padding can only be stripped when a mask says where it is; without
        # one, every position is fed through the FFN.
        strip_padding = (not self.use_pad_tok_in_ffn) and attention_mask is not None
        indices = None
        if strip_padding:
            assert unpad_input is not None, (
                "use_pad_tok_in_ffn=False requires flash_attn.bert_padding.unpad_input"
            )
            # Star-unpack the tail: the number of trailing return values
            # differs between flash-attn releases.
            m, indices, *_ = unpad_input(m, attention_mask)
        n = self.ffn(m)
        if strip_padding:
            assert pad_input is not None, (
                "use_pad_tok_in_ffn=False requires flash_attn.bert_padding.pad_input"
            )
            n = pad_input(n, indices, batch_size, seq_len)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)