|
"""GPT Blocks used for the GPT Model.""" |
|
|
|
from typing import Any, Dict, Optional, Tuple |
|
import torch |
|
import torch.nn as nn |
|
from .attention import ATTN_CLASS_REGISTRY |
|
from .ffn import FFN_CLASS_REGISTRY, build_ffn |
|
from .norm import NORM_CLASS_REGISTRY |
|
|
|
try:
    from flash_attn.bert_padding import unpad_input, pad_input
except Exception:
    # flash_attn is an optional dependency. Catch ``Exception`` rather than
    # using a bare ``except:`` so KeyboardInterrupt/SystemExit still propagate,
    # while a missing (or otherwise failing) install leaves both helpers as
    # None. MPTBlock.forward checks for None before calling them.
    unpad_input, pad_input = None, None
|
# Default attention settings; MPTBlock falls back to this mapping when it is
# constructed with ``attn_config=None``. The rope/alibi/prefix-lm related keys
# are filtered out before the rest are passed to the attention class.
attn_config_defaults: Dict = dict(
    attn_type="multihead_attention",
    attn_pdrop=0.0,
    attn_impl="flash",
    qk_ln=True,
    qk_gn=False,
    clip_qkv=None,
    softmax_scale=None,
    prefix_lm=False,
    attn_uses_sequence_id=False,
    sliding_window_size=-1,  # presumably -1 means "no sliding window" — confirm
    alibi=False,
    alibi_bias_max=8,
    rope=False,
    rope_theta=10000,
    rope_impl="dail",
    rope_dail_config=dict(
        type="original",
        pos_idx_in_fp32=True,
        xpos_scale_base=512,
    ),
    rope_hf_config=dict(type="no_scaling", factor=1.0),
)
|
|
|
|
|
class MPTBlock(nn.Module):
    """One pre-norm transformer block: an attention sub-layer, then an FFN.

    ``forward`` computes::

        x = x + resid_attn_dropout(attn(norm_1(x)))
        x = x + resid_ffn_dropout(ffn(norm_2(x)))  # norm_2 may be absent

    The attention, FFN and normalization implementations are selected by name
    from ``ATTN_CLASS_REGISTRY``, ``FFN_CLASS_REGISTRY`` (via ``build_ffn``)
    and ``NORM_CLASS_REGISTRY`` respectively.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        ffn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        fc_type: str = "torch",
        device: Optional[str] = None,
        no_bias: bool = False,
        use_pad_tok_in_ffn: bool = True,
        **kwargs: Any
    ) -> None:
        """Build the block's sub-modules.

        Args:
            d_model: Width passed to the norm layers, attention and FFN.
            n_heads: Number of attention heads, passed to the attention class.
            expansion_ratio: Passed through to ``build_ffn`` (presumably the
                FFN hidden-width multiplier — confirm in ``build_ffn``).
            attn_config: Attention settings. ``attn_type`` selects the class in
                ``ATTN_CLASS_REGISTRY``; the keys listed in
                ``args_to_exclude_in_attn_class`` are dropped and all remaining
                keys are forwarded as keyword arguments. Defaults to
                ``attn_config_defaults``.
            ffn_config: Keyword arguments for ``build_ffn``; ``ffn_type``
                selects the FFN class. Defaults to ``{"ffn_type": "mptmlp"}``.
            resid_pdrop: Dropout probability applied to each sub-layer output
                before it is added to the residual stream.
            norm_type: Key into ``NORM_CLASS_REGISTRY`` (lower-cased first).
            fc_type: Forwarded to the attention class.
            device: Passed to every sub-module constructor.
            no_bias: If True, the attention and FFN are built with
                ``bias=False``.
            use_pad_tok_in_ffn: If False, ``forward`` runs the FFN input
                through flash_attn's ``unpad_input`` with ``attention_mask``
                (presumably dropping padding tokens) and re-pads the FFN output
                with ``pad_input``.
            **kwargs: Accepted and ignored.
        """
        if attn_config is None:
            attn_config = attn_config_defaults
        if ffn_config is None:
            ffn_config = {"ffn_type": "mptmlp"}
        del kwargs  # extra keyword arguments are tolerated but unused
        super().__init__()

        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        assert isinstance(attn_config["attn_type"], str)
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]

        # These keys are not constructor arguments of the attention class
        # (presumably they are handled by the enclosing model), so strip them.
        args_to_exclude_in_attn_class = {
            "attn_type",
            "prefix_lm",
            "alibi",
            "attn_uses_sequence_id",
            "alibi_bias_max",
            "rope",
            "rope_theta",
            "rope_impl",
            "rope_dail_config",
            "rope_hf_config",
        }
        attn_config_subset_for_attn_class = {
            k: v
            for (k, v) in attn_config.items()
            if k not in args_to_exclude_in_attn_class
        }

        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            d_model=d_model,
            n_heads=n_heads,
            fc_type=fc_type,
            device=device,
            **attn_config_subset_for_attn_class,
            bias=not no_bias
        )

        # An FFN class that declares ``_has_norm`` applies its own norm, so the
        # block does not add a second one in front of it.
        self.norm_2 = None
        if not getattr(FFN_CLASS_REGISTRY[ffn_config["ffn_type"]], "_has_norm", False):
            self.norm_2 = norm_class(d_model, device=device)
        self.ffn = build_ffn(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
            bias=not no_bias,
            **ffn_config
        )

        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        rotary_emb_w_meta_info: Optional[Dict] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        output_attentions: bool = False,
        alibi_slopes: Optional[torch.Tensor] = None,
        flash_attn_padding_info: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Apply attention and FFN sub-layers with residual connections.

        Args:
            x: Input hidden states. When ``use_pad_tok_in_ffn`` is False, its
                first two dims are read as ``(batch_size, seq_len)``; assumed
                ``(batch, seq, d_model)`` — confirm against the caller.
            past_key_value, attn_bias, rotary_emb_w_meta_info, is_causal,
            alibi_slopes, flash_attn_padding_info: Forwarded unchanged to the
                attention module.
            attention_mask: Forwarded to the attention module; additionally
                required (passed to ``unpad_input``) when
                ``use_pad_tok_in_ffn`` is False.
            output_attentions: Forwarded to the attention module as
                ``needs_weights``.

        Returns:
            ``(x, attn_weights, past_key_value)``; the last two are whatever the
            attention module returned.

        Raises:
            ValueError: ``use_pad_tok_in_ffn`` is False and ``attention_mask``
                is None.
            RuntimeError: ``use_pad_tok_in_ffn`` is False and flash_attn's
                ``unpad_input``/``pad_input`` could not be imported.
        """
        a = self.norm_1(x)
        b, attn_weights, past_key_value = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            rotary_emb_w_meta_info=rotary_emb_w_meta_info,
            attention_mask=attention_mask,
            is_causal=is_causal,
            needs_weights=output_attentions,
            alibi_slopes=alibi_slopes,
            flash_attn_padding_info=flash_attn_padding_info,
        )
        x = x + self.resid_attn_dropout(b)

        m = x if self.norm_2 is None else self.norm_2(x)

        if self.use_pad_tok_in_ffn:
            n = self.ffn(m)
        else:
            # Validate explicitly (not with ``assert``, which ``-O`` strips) and
            # fail early instead of deep inside flash_attn.
            if attention_mask is None:
                raise ValueError(
                    "attention_mask is required when use_pad_tok_in_ffn=False."
                )
            if unpad_input is None or pad_input is None:
                raise RuntimeError(
                    "use_pad_tok_in_ffn=False requires flash_attn.bert_padding "
                    "(unpad_input / pad_input), which could not be imported."
                )
            batch_size, seq_len = m.size()[:2]
            # Only the unpadded tensor and its indices are needed; ``*_``
            # absorbs the remaining return values however many the installed
            # flash_attn version returns.
            m, indices, *_ = unpad_input(m, attention_mask)
            n = self.ffn(m)
            n = pad_input(n, indices, batch_size, seq_len)

        x = x + self.resid_ffn_dropout(n)
        return x, attn_weights, past_key_value
|
|