yuhaofeng-shiba's picture
Update models/llama.py
6259eea
raw
history blame
7.17 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.norm import RMSNorm
from models.rope import precompute_freqs_cis, apply_rotary_emb
import bitsandbytes as bnb
import math
class NormalLinear(nn.Linear):
    """``nn.Linear`` whose default weight/bias initialisation is skipped.

    ``nn.Linear.__init__`` calls ``reset_parameters``; overriding it with a
    no-op leaves the freshly allocated parameters un-initialised. This is
    presumably because they are overwritten by a loaded checkpoint (assumption,
    verify against the loader).
    """

    def reset_parameters(self) -> None:
        # Deliberately do nothing: no random init.
        return None
class BnbInt8Linear(bnb.nn.Linear8bitLt):
    """bitsandbytes ``Linear8bitLt`` preconfigured for int8 inference.

    All positional/keyword arguments are forwarded to ``Linear8bitLt``. It always
    passes ``has_fp16_weights=False`` (weights are kept in int8) and
    ``threshold=6.0`` (the LLM.int8() outlier threshold, per the bitsandbytes
    docs).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, has_fp16_weights=False, threshold=6.0, **kwargs)

    def reset_parameters(self) -> None:
        # Deliberately do nothing: skip default initialisation.
        return None
def get_linear_layer(use_int8):
    """Pick the linear-layer class to instantiate.

    Args:
        use_int8: truthy to get the bitsandbytes int8 layer (``BnbInt8Linear``),
            falsy for the plain ``NormalLinear``.

    Returns:
        The class itself (not an instance).
    """
    return BnbInt8Linear if use_int8 else NormalLinear
class WordEmbedding(nn.Module):
    """Token-id lookup table with ``args.vocab_size`` rows of size ``args.emb_size``."""

    def __init__(self, args):
        super().__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.emb_size)

    def forward(self, src):
        """Map integer token ids ``src`` to embeddings (adds a trailing ``emb_size`` dim)."""
        return self.embedding(src)
class MultiHeadedAttention(nn.Module):
    """Multi-head attention with rotary position embeddings and a key/value cache.

    New keys/values are written into per-layer caches at ``start_pos`` so that
    later calls only need to project the newly added tokens.
    """

    def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num
        self.per_head_size = attention_head_size
        self.inner_hidden_size = heads_num * attention_head_size
        Linear = get_linear_layer(use_int8)
        # Projections for query, key and value, in that order (see ``forward``).
        self.linear_layers = nn.ModuleList(
            [Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )
        self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)
        # Key/value caches laid out as (batch, position, head, head_dim). They are
        # plain tensor attributes, not registered buffers, so ``Module.to``/``half``
        # do not touch them; ``forward`` moves/casts them lazily instead.
        self.cache_k = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )
        self.cache_v = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )

    def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
        """Attend ``query`` over all cached positions up to ``start_pos + seq_length``.

        Args:
            key, value, query: input hidden states; ``query.size()`` must unpack to
                (batch, seq_length, hidden). Note the argument order is
                (key, value, query).
            start_pos: cache position at which the new tokens are written.
            continue_exsample: index into the cache's batch dimension selecting
                which cache rows to write/read (presumably the still-active
                examples; its length must match the input batch).
            mask: additive attention mask broadcastable to
                (batch, heads, seq_length, start_pos + seq_length), or None.
            freqs_cis: rotary tables for positions
                [start_pos, start_pos + seq_length).

        Returns:
            Tensor of shape (batch, seq_length, hidden_size) after the output
            projection.
        """
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        # Project and split into heads: (batch, seq, heads, head_dim).
        query, key, value = [
            l(x).view(batch_size, -1, heads_num, per_head_size)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]
        query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)

        # Lazily follow the activations' device AND dtype. Checking only the
        # device left a float32 cache in place for e.g. a half-precision model
        # on CPU, so the keys/values read back below were float32 and the matmul
        # with the half-precision query failed on mixed dtypes.
        if self.cache_k.device != key.device or self.cache_k.dtype != key.dtype:
            self.cache_k = self.cache_k.to(key)
        if self.cache_v.device != value.device or self.cache_v.dtype != value.dtype:
            self.cache_v = self.cache_v.to(value)

        # Store the new positions, then read back everything seen so far.
        self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
        self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value
        key = self.cache_k[continue_exsample, : start_pos + seq_length]
        value = self.cache_v[continue_exsample, : start_pos + seq_length]

        # (batch, heads, positions, head_dim)
        query, key, value = [x.transpose(1, 2) for x in (query, key, value)]
        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / math.sqrt(float(per_head_size))
        if mask is not None:
            scores += mask
        # Softmax in float32, then cast back to the activation dtype.
        probs = F.softmax(scores.float(), dim=-1).type_as(query)
        output = torch.matmul(probs, value).transpose(1, 2).\
            contiguous().view(batch_size, seq_length, -1)
        return self.final_linear(output)
class GatedFeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward network.

    Computes ``linear_2(silu(linear_gate(x)) * linear_1(x))``, with the gate
    cast back to ``x``'s dtype before the product.
    """

    def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
        super().__init__()
        make_linear = get_linear_layer(use_int8)
        self.linear_gate = make_linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_1 = make_linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_2 = make_linear(feedforward_size, hidden_size, bias=has_bias)
        self.act = F.silu

    def forward(self, x):
        activated_gate = self.act(self.linear_gate(x)).type_as(x)
        return self.linear_2(activated_gate * self.linear_1(x))
class TransformerLayer(nn.Module):
    """Pre-norm transformer block.

    ``h = x + attn(rmsnorm_1(x))`` then ``out = h + ffn(rmsnorm_2(h))``.
    """

    def __init__(self, args):
        super(TransformerLayer, self).__init__()
        # Per-head size: explicit override if configured, else hidden / heads.
        attention_head_size = getattr(args, "attention_head_size", None)
        if attention_head_size is None:
            attention_head_size = args.hidden_size // args.heads_num
        # Plain boolean negation. The old ``bool(1 - flag)`` form returned True
        # (keep bias) for any truthy flag other than 1.
        has_bias = not args.remove_transformer_bias
        # Multi-head attention
        self.self_attn = MultiHeadedAttention(
            args, args.hidden_size, args.heads_num, attention_head_size,
            has_bias=has_bias, use_int8=args.use_int8,
        )
        # Gated feed-forward network
        self.feed_forward = GatedFeedForward(
            args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
        )
        self.layer_norm_1 = RMSNorm(args.hidden_size)
        self.layer_norm_2 = RMSNorm(args.hidden_size)

    def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
        """Apply attention and feed-forward sub-layers, each with a residual add.

        ``start_pos``, ``continue_exsample``, ``mask`` and ``freqs_cis`` are
        passed straight through to the attention sub-layer.
        """
        inter = self.layer_norm_1(hidden)
        inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
        hidden = hidden + inter
        output = self.layer_norm_2(hidden)
        output = self.feed_forward(output) + hidden
        return output
class TransformerEncoder(nn.Module):
    """Stack of ``args.layers_num`` transformer layers followed by a final RMSNorm."""

    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num
        self.transformer = nn.ModuleList(
            [TransformerLayer(args) for _ in range(self.layers_num)]
        )
        self.layer_norm = RMSNorm(args.hidden_size)
        # The rotary table must use the same per-head size as the attention
        # layers, which honour an explicit ``args.attention_head_size``.
        head_size = getattr(args, "attention_head_size", None)
        if head_size is None:
            head_size = args.hidden_size // args.heads_num
        self.freqs_cis = precompute_freqs_cis(head_size, args.max_seq_length * 2)

    @staticmethod
    def _causal_mask(batch_size, seq_length, start_pos, device=None):
        """Build the additive causal mask for ``seq_length`` new tokens.

        Returns None for single-token steps, where nothing needs masking.
        Otherwise it returns a float tensor of shape
        (batch_size, 1, seq_length, start_pos + seq_length). The tensor holds 0
        where a query may attend and -10000 where it may not. Query row ``i``
        sits at absolute position ``start_pos + i``, so it sees every cached key
        at or before that position. The key axis covers the cached prefix as
        well, so the mask broadcasts against the attention scores even when
        ``start_pos > 0``.
        """
        if seq_length <= 1:
            return None
        start_pos = int(start_pos)
        allowed = torch.ones(seq_length, start_pos + seq_length, device=device)
        allowed = torch.tril(allowed, diagonal=start_pos)
        mask = (1.0 - allowed) * -10000
        return mask.repeat(batch_size, 1, 1, 1)

    def forward(self, emb, start_pos, continue_exsample):
        """Run all layers over ``emb``, writing new keys/values at ``start_pos``.

        Args:
            emb: input embeddings; ``emb.size()`` must unpack to
                (batch, seq_length, hidden).
            start_pos: cache position of the first token in ``emb``.
            continue_exsample: batch-row selector forwarded to each attention
                layer's cache.

        Returns:
            The normalised hidden states of the last layer.
        """
        batch_size, seq_length, _ = emb.size()
        mask = self._causal_mask(batch_size, seq_length, start_pos, device=emb.device)
        hidden = emb
        freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)
        for layer in self.transformer:
            hidden = layer(hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
        return self.layer_norm(hidden)
class LmOutput(nn.Module):
    """Language-model head: projects the final position's hidden state to vocab logits."""

    def __init__(self, args):
        super().__init__()
        # The LM head always uses the regular (non-int8) linear layer.
        head_cls = get_linear_layer(False)
        self.lm = head_cls(args.hidden_size, args.vocab_size, bias=False)

    def forward(self, x):
        """Return logits for the last sequence position only.

        Assumes ``x`` is (batch, seq, hidden); the result is (batch, vocab_size).
        """
        last_hidden = x[:, -1, :]
        return self.lm(last_hidden)
class LLaMa(nn.Module):
    """Causal LM: token embedding -> transformer stack -> last-position LM head."""

    def __init__(self, args):
        super().__init__()
        self.embedding = WordEmbedding(args)
        self.encoder = TransformerEncoder(args)
        self.target = LmOutput(args)

    # NOTE(review): not decorated with torch.inference_mode(); callers running
    # generation may want to wrap calls in it.
    def forward(self, src, start_pos, continue_exsample):
        """Return next-token logits for ``src`` tokens written at ``start_pos``."""
        hidden = self.encoder(self.embedding(src), start_pos, continue_exsample)
        return self.target(hidden)