yuhaofeng-shiba commited on
Commit
402c662
1 Parent(s): 7d77ab3

first upload code and model

Browse files
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import gradio as gr
4
+ import argparse
5
+ from utils import load_hyperparam, load_model
6
+ from models.tokenize import Tokenizer
7
+ from models.llama import *
8
+ from generate import LmGeneration
9
+
10
+
11
+ args = None
12
+ lm_generation = None
13
+
14
+
15
+ def init_args():
16
+ global args
17
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18
+ args = parser.parse_args()
19
+ args.load_model_path = './model_file/chatllama_7b.bin'
20
+ args.config_path = './config/llama_7b.json'
21
+ args.spm_model_path = './model_file/tokenizer.model'
22
+ args.batch_size = 1
23
+ args.seq_length = 1024
24
+ args.world_size = 1
25
+ args.use_int8 = False
26
+ args.top_p = 0
27
+ args.repetition_penalty_range = 1024
28
+ args.repetition_penalty_slope = 0
29
+ args.repetition_penalty = 1.15
30
+
31
+ args = load_hyperparam(args)
32
+
33
+ args.tokenizer = Tokenizer(model_path=args.spm_model_path)
34
+ args.vocab_size = args.tokenizer.sp_model.vocab_size()
35
+
36
+
37
+ def init_model():
38
+ global lm_generation
39
+ torch.set_default_tensor_type(torch.HalfTensor)
40
+ model = LLaMa(args)
41
+ torch.set_default_tensor_type(torch.FloatTensor)
42
+ model = load_model(model, args.load_model_path)
43
+ model.eval()
44
+
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ model.to(device)
47
+ lm_generation = LmGeneration(model, args.tokenizer)
48
+
49
+
50
+ def chat(prompt, top_k, temperature):
51
+ args.top_k = int(top_k)
52
+ args.temperature = temperature
53
+ response = lm_generation.generate(args, [prompt])
54
+ return response[0]
55
+
56
+
57
+ if __name__ == '__main__':
58
+ init_args()
59
+ init_model()
60
+ demo = gr.Interface(
61
+ fn=chat,
62
+ inputs=["text", gr.Slider(1, 60, value=40, step=1), gr.Slider(0.1, 2.0, value=1.2, step=0.1)],
63
+ outputs="text",
64
+ )
65
+ demo.launch()
66
+
config/llama_7b.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "emb_size": 4096,
3
+ "feedforward_size": 11008,
4
+ "hidden_size": 4096,
5
+ "hidden_act": "silu",
6
+ "heads_num": 32,
7
+ "layers_num": 32,
8
+ "dropout": 0.1,
9
+ "data_processor": "lm",
10
+ "max_seq_length": 2048,
11
+ "embedding": ["word"],
12
+ "remove_transformer_bias": true,
13
+ "remove_embedding_layernorm": true,
14
+ "rotary_position_embedding": true,
15
+ "encoder": "transformer",
16
+ "feed_forward": "gated",
17
+ "mask": "causal",
18
+ "layernorm_positioning": "pre",
19
+ "layernorm": "rms",
20
+ "target": ["lm"]
21
+ }
generate.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def apply_temperature(scores, tempt):
6
+ if tempt > 0:
7
+ scores = scores / tempt
8
+ return scores
9
+
10
+
11
+ def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
12
+ if top_p > 0 and top_p < 1:
13
+ sorted_logits, sorted_indices = torch.sort(scores, descending=False)
14
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
15
+
16
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
17
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
18
+ if min_tokens_to_keep > 1:
19
+ # Keep at least min_tokens_to_keep
20
+ sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
21
+
22
+ # scatter sorted tensors to original indexing
23
+ indices_to_remove = sorted_indices_to_remove.scatter(
24
+ 1, sorted_indices, sorted_indices_to_remove
25
+ )
26
+ scores = scores.masked_fill(indices_to_remove, filter_value)
27
+ return scores
28
+
29
+
30
+ def apply_top_k(logits, top_k):
31
+ top_k = min(top_k, logits.size(-1)) # Safety check
32
+ if top_k > 0:
33
+ # Remove all tokens with a probability less than the last token of the top-k
34
+ indices_to_remove = logits < torch.topk(logits.float(), top_k)[0][..., -1, None]
35
+ logits[indices_to_remove] = -float("Inf")
36
+
37
+ return logits
38
+
39
+ def apply_advanced_repetition_penalty(
40
+ input_ids, scores, penalty_range, penalty_slope, penalty
41
+ ):
42
+ penalty_range = int(penalty_range)
43
+ clipped_penalty_range = min(input_ids.shape[-1], penalty_range)
44
+
45
+ if penalty != 1.0:
46
+ if penalty_range > 0:
47
+ if clipped_penalty_range < input_ids.shape[1]:
48
+ input_ids = input_ids[..., -clipped_penalty_range:]
49
+
50
+ if penalty_slope != 0:
51
+ _penalty = (
52
+ torch.arange(
53
+ penalty_range, dtype=scores.dtype, device=scores.device
54
+ )
55
+ / (penalty_range - 1)
56
+ ) * 2.0 - 1
57
+ _penalty = (penalty_slope * _penalty) / (
58
+ 1 + torch.abs(_penalty) * (penalty_slope - 1)
59
+ )
60
+ _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
61
+ penalty = _penalty[..., -clipped_penalty_range:]
62
+
63
+ score = torch.gather(scores, 1, input_ids)
64
+ score = torch.where(score <= 0, score * penalty, score / penalty)
65
+ scores.scatter_(1, input_ids, score)
66
+
67
+ return scores
68
+
69
+
70
+ class LmGeneration:
71
+ def __init__(self, model, tokenizer):
72
+ self.model = model
73
+ self.tokenizer = tokenizer
74
+
75
+ def generate(self, args, prompts, cut_off=None, cut_off_times=1):
76
+ if cut_off is not None:
77
+ cut_off_times = [cut_off_times for i in range(len(prompts))]
78
+ batch = len(prompts)
79
+ assert batch <= args.batch_size
80
+
81
+ prompt_tokens = [args.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
82
+
83
+ min_prompt_len = min([len(x) for x in prompt_tokens])
84
+ # max_prompt_len = max([len(x) for x in prompt_tokens])
85
+
86
+ total_len = args.seq_length
87
+
88
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+ tokens = torch.full((batch, total_len), self.tokenizer.pad_id).to(device).long()
90
+ for idx, t in enumerate(prompt_tokens):
91
+ tokens[idx, : len(t)] = torch.tensor(t).long()
92
+ mask = tokens != self.tokenizer.pad_id
93
+ start_pos = min_prompt_len
94
+ prev_pos = 0
95
+ continue_exsample = [i for i in range(batch)]
96
+ with torch.no_grad():
97
+ for cur_pos in range(start_pos, total_len):
98
+ print(cur_pos)
99
+ logits = self.model.forward(tokens[continue_exsample, prev_pos:cur_pos], prev_pos, continue_exsample).float()
100
+ next_token_scores = apply_top_k(logits, top_k=args.top_k)
101
+ next_token_scores = apply_top_p(next_token_scores, args.top_p)
102
+ next_token_scores = apply_temperature(next_token_scores, args.temperature)
103
+ next_token_scores = apply_advanced_repetition_penalty(
104
+ tokens[continue_exsample, :cur_pos],
105
+ next_token_scores,
106
+ args.repetition_penalty_range,
107
+ args.repetition_penalty_slope,
108
+ args.repetition_penalty
109
+ )
110
+ scores = F.softmax(next_token_scores, dim=-1)
111
+ next_token = torch.multinomial(scores, num_samples=1).squeeze(1)
112
+ next_token = next_token.reshape(-1)
113
+ next_token = torch.where(
114
+ mask[continue_exsample, cur_pos], tokens[continue_exsample, cur_pos], next_token
115
+ )
116
+ tokens[continue_exsample, cur_pos] = next_token
117
+ prev_pos = cur_pos
118
+ # remove eos examples.
119
+ continue_exsample = []
120
+ for i, t in enumerate(tokens.tolist()):
121
+ try:
122
+ t.index(self.tokenizer.eos_id)
123
+ except ValueError:
124
+ if cut_off is not None:
125
+ if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
126
+ if cut_off_times[i] == 1:
127
+ continue
128
+ else:
129
+ cut_off_times[i] -= 1
130
+ continue_exsample.append(i)
131
+ if len(continue_exsample) == 0:
132
+ break
133
+
134
+ decoder = []
135
+ for i, t in enumerate(tokens.tolist()):
136
+ t = t[: args.seq_length]
137
+ try:
138
+ t = t[: t.index(self.tokenizer.pad_id)]
139
+ t = t[: t.index(self.tokenizer.eos_id)]
140
+ except ValueError:
141
+ pass
142
+ decoder.append(self.tokenizer.decode(t))
143
+
144
+ return decoder
model_file/chatllama_7b.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5bb1fb1ddf737e7f1fcbe0284ecd384dbe8f243d843b82fcdf59fd00e9b3c61
3
+ size 13476956615
model_file/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
models/llama.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.norm import RMSNorm
5
+ from models.rope import precompute_freqs_cis, apply_rotary_emb
6
+ import bitsandbytes as bnb
7
+ import math
8
+
9
+
10
+ class NormalLinear(nn.Linear):
11
+ def reset_parameters(self) -> None:
12
+ pass
13
+
14
+
15
+ class BnbInt8Linear(bnb.nn.Linear8bitLt):
16
+ def __init__(self, *args, **kwargs):
17
+ super().__init__(has_fp16_weights=False, threshold=6.0, *args, **kwargs)
18
+
19
+ def reset_parameters(self) -> None:
20
+ pass
21
+
22
+
23
+ def get_linear_layer(use_int8):
24
+ if use_int8:
25
+ return BnbInt8Linear
26
+ return NormalLinear
27
+
28
+
29
+ class WordEmbedding(nn.Module):
30
+ def __init__(self, args):
31
+ super(WordEmbedding, self).__init__()
32
+ self.embedding = nn.Embedding(args.vocab_size, args.emb_size)
33
+
34
+ def forward(self, src):
35
+ emb = self.embedding(src)
36
+ return emb
37
+
38
+
39
+ class MultiHeadedAttention(nn.Module):
40
+ def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
41
+ super(MultiHeadedAttention, self).__init__()
42
+ self.heads_num = heads_num
43
+
44
+ self.per_head_size = attention_head_size
45
+ self.inner_hidden_size = heads_num * attention_head_size
46
+
47
+ Linear = get_linear_layer(use_int8)
48
+ self.linear_layers = nn.ModuleList(
49
+ [Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
50
+ )
51
+
52
+ self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)
53
+
54
+ # add cache to reduce compute source.
55
+ self.cache_k = torch.zeros(
56
+ (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
57
+ )
58
+ self.cache_v = torch.zeros(
59
+ (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
60
+ )
61
+
62
+ def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
63
+ batch_size, seq_length, _ = query.size()
64
+ heads_num = self.heads_num
65
+ per_head_size = self.per_head_size
66
+ query, key, value = [l(x).view(batch_size, -1, heads_num, per_head_size) \
67
+ for l, x in zip(self.linear_layers, (query, key, value))]
68
+ query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)
69
+ if self.cache_k.device != key.device:
70
+ self.cache_k = self.cache_k.to(key)
71
+ if self.cache_v.device != value.device:
72
+ self.cache_v = self.cache_v.to(value)
73
+
74
+ self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
75
+ self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value
76
+
77
+ key = self.cache_k[continue_exsample, : start_pos + seq_length]
78
+ value = self.cache_v[continue_exsample, : start_pos + seq_length]
79
+
80
+ query, key, value = [x.transpose(1, 2) for x in (query, key, value)]
81
+
82
+ scores = torch.matmul(query, key.transpose(-2, -1))
83
+ scores = scores / math.sqrt(float(per_head_size))
84
+ if mask is not None:
85
+ scores += mask
86
+ # probs = nn.Softmax(dim=-1)(scores)
87
+ probs = F.softmax(scores.float(), dim=-1).type_as(query)
88
+ output = torch.matmul(probs, value).transpose(1, 2).\
89
+ contiguous().view(batch_size, seq_length, -1)
90
+ return self.final_linear(output)
91
+
92
+
93
+ class GatedFeedForward(nn.Module):
94
+ def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
95
+ super(GatedFeedForward, self).__init__()
96
+ Linear = get_linear_layer(use_int8)
97
+ self.linear_gate = Linear(hidden_size, feedforward_size, bias=has_bias)
98
+ self.linear_1 = Linear(hidden_size, feedforward_size, bias=has_bias)
99
+ self.linear_2 = Linear(feedforward_size, hidden_size, bias=has_bias)
100
+ self.act = F.silu
101
+
102
+ def forward(self, x):
103
+ # gate = self.act(self.linear_gate(x))
104
+ gate = self.act(self.linear_gate(x)).type_as(x)
105
+ inter_linear = self.linear_1(x)
106
+ inter = gate * inter_linear
107
+ output = self.linear_2(inter)
108
+ return output
109
+
110
+
111
+ class TransformerLayer(nn.Module):
112
+ def __init__(self, args):
113
+ super(TransformerLayer, self).__init__()
114
+
115
+ if hasattr(args, "attention_head_size"):
116
+ attention_head_size = args.attention_head_size
117
+ else:
118
+ attention_head_size = args.hidden_size // args.heads_num
119
+
120
+ has_bias = bool(1 - args.remove_transformer_bias)
121
+ # Multi-head Attention
122
+ self.self_attn = MultiHeadedAttention(
123
+ args, args.hidden_size, args.heads_num, attention_head_size, has_bias=has_bias,
124
+ use_int8=args.use_int8
125
+ )
126
+
127
+ # FFN
128
+ self.feed_forward = GatedFeedForward(
129
+ args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
130
+ )
131
+
132
+ self.layer_norm_1 = RMSNorm(args.hidden_size)
133
+ self.layer_norm_2 = RMSNorm(args.hidden_size)
134
+
135
+ def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
136
+ inter = self.layer_norm_1(hidden)
137
+ inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
138
+ hidden = hidden + inter
139
+ output = self.layer_norm_2(hidden)
140
+ output = self.feed_forward(output) + hidden
141
+ return output
142
+
143
+
144
+ class TransformerEncoder(nn.Module):
145
+ def __init__(self, args):
146
+ super(TransformerEncoder, self).__init__()
147
+ self.mask = args.mask
148
+ self.layers_num = args.layers_num
149
+
150
+ self.transformer = nn.ModuleList(
151
+ [TransformerLayer(args) for _ in range(self.layers_num)]
152
+ )
153
+
154
+ self.layer_norm = RMSNorm(args.hidden_size)
155
+ self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)
156
+
157
+ def forward(self, emb, start_pos, continue_exsample):
158
+ batch_size, seq_length, _ = emb.size()
159
+ mask = None
160
+ if seq_length > 1:
161
+ mask = torch.ones(seq_length, seq_length, device=emb.device)
162
+ mask = torch.tril(mask)
163
+ mask = (1.0 - mask) * -10000
164
+ mask = mask.repeat(batch_size, 1, 1, 1)
165
+
166
+ hidden = emb
167
+ freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)
168
+
169
+ for i in range(self.layers_num):
170
+ hidden = self.transformer[i](hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
171
+ return self.layer_norm(hidden)
172
+
173
+
174
+ class LmOutput(nn.Module):
175
+ def __init__(self, args):
176
+ super(LmOutput, self).__init__()
177
+ # update: lm output not use int8
178
+ Linear = get_linear_layer(False)
179
+ self.lm = Linear(args.hidden_size, args.vocab_size, bias=False)
180
+
181
+ def forward(self, x):
182
+ return self.lm(x[:, -1, :])
183
+
184
+
185
+ class LLaMa(nn.Module):
186
+ def __init__(self, args):
187
+ super(LLaMa, self).__init__()
188
+ self.embedding = WordEmbedding(args)
189
+ self.encoder = TransformerEncoder(args)
190
+ self.target = LmOutput(args)
191
+
192
+ #@torch.inference_mode()
193
+ def forward(self, src, start_pos, continue_exsample):
194
+ emb = self.embedding(src)
195
+ output = self.encoder(emb, start_pos, continue_exsample)
196
+ output = self.target(output)
197
+ return output
models/norm.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+
4
+
5
+ class RMSNorm(torch.nn.Module):
6
+ def __init__(self, hidden_size, eps=1e-6):
7
+ super().__init__()
8
+ self.eps = eps
9
+ self.weight = nn.Parameter(torch.ones(hidden_size))
10
+
11
+ def _norm(self, x):
12
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
13
+
14
+ def forward(self, x):
15
+ output = self._norm(x.float()).type_as(x)
16
+ return output * self.weight
models/rope.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
5
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
6
+ t = torch.arange(end, device=freqs.device) # type: ignore
7
+ freqs = torch.outer(t, freqs).float() # type: ignore
8
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
9
+ return freqs_cis
10
+
11
+
12
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
13
+ ndim = x.ndim
14
+ assert 0 <= 1 < ndim
15
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
16
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
17
+ return freqs_cis.view(*shape)
18
+
19
+
20
+ def apply_rotary_emb(
21
+ xq: torch.Tensor,
22
+ xk: torch.Tensor,
23
+ freqs_cis: torch.Tensor,
24
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
25
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
26
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
27
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
28
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
29
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
30
+ return xq_out.type_as(xq), xk_out.type_as(xk)
models/tokenize.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy from
2
+ # https://github.com/tloen/llama-int8/blob/ce74669c767e42b5082391dd0cfcb621ba40c7f9/llama/tokenizer.py
3
+
4
+ from sentencepiece import SentencePieceProcessor
5
+ from logging import getLogger
6
+ from typing import List
7
+ import os
8
+
9
+
10
+ logger = getLogger()
11
+
12
+
13
+ class Tokenizer:
14
+ def __init__(self, model_path: str):
15
+ # reload tokenizer
16
+ assert os.path.isfile(model_path), model_path
17
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
18
+ logger.info(f"Reloaded SentencePiece model from {model_path}")
19
+
20
+ # BOS / EOS token IDs
21
+ self.n_words: int = self.sp_model.vocab_size()
22
+ self.bos_id: int = self.sp_model.bos_id()
23
+ self.eos_id: int = self.sp_model.eos_id()
24
+ self.pad_id: int = self.sp_model.pad_id()
25
+ logger.info(
26
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
27
+ )
28
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
29
+
30
+ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
31
+ assert type(s) is str
32
+ t = self.sp_model.encode(s)
33
+ if bos:
34
+ t = [self.bos_id] + t
35
+ if eos:
36
+ t = t + [self.eos_id]
37
+ return t
38
+
39
+ def decode(self, t: List[int]) -> str:
40
+ return self.sp_model.decode(t)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.9.0
2
+ bitsandbytes==0.37.2
3
+ sentencepiece
4
+ argparse
utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ from argparse import Namespace
4
+ import torch
5
+ import os
6
+
7
+
8
+ def load_hyperparam(default_args):
9
+ """
10
+ Load arguments form argparse and config file
11
+ Priority: default options < config file < command line args
12
+ """
13
+ with open(default_args.config_path, mode="r", encoding="utf-8") as f:
14
+ config_args_dict = json.load(f)
15
+
16
+ default_args_dict = vars(default_args)
17
+
18
+ command_line_args_dict = {k: default_args_dict[k] for k in [
19
+ a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a)
20
+ ]}
21
+ default_args_dict.update(config_args_dict)
22
+ default_args_dict.update(command_line_args_dict)
23
+ args = Namespace(**default_args_dict)
24
+
25
+ return args
26
+
27
+
28
+ def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
29
+ # Convert old format to new format if needed from a PyTorch state_dict
30
+
31
+ # copy state_dict so _load_from_state_dict can modify it
32
+ state_dict = torch.load(model_path, map_location="cpu")
33
+ metadata = getattr(state_dict, "_metadata", None)
34
+ state_dict = state_dict.copy()
35
+ state_dict['target.lm.weight'] = state_dict['target.lm.output_layer.weight']
36
+ del state_dict['target.lm.output_layer.weight']
37
+ state_dict['embedding.embedding.weight'] = state_dict['embedding.word.embedding.weight']
38
+ del state_dict['embedding.word.embedding.weight']
39
+
40
+ if metadata is not None:
41
+ metadata['embedding.embedding'] = metadata['embedding.word.embedding']
42
+ metadata['target.lm'] = metadata['target.lm.output_layer']
43
+ if metadata.get('embedding.dropout', None) is not None:
44
+ del metadata['embedding.dropout']
45
+ del metadata['embedding.word']
46
+ del metadata['embedding.word.embedding']
47
+ del metadata['target.lm.output_layer']
48
+ del metadata['target.lm.softmax']
49
+ del metadata['target.lm.criterion']
50
+ state_dict._metadata = metadata
51
+
52
+ error_msgs = []
53
+
54
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
55
+ # so we need to apply the function recursively.
56
+ def load(module, state_dict, prefix=""):
57
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
58
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
59
+ # Parameters of module and children will start with prefix. We can exit early if there are none in this
60
+ # state_dict
61
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
62
+ import deepspeed
63
+ # In sharded models, each shard has only part of the full state_dict, so only gather
64
+ # parameters that are in the current state_dict.
65
+ named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
66
+ params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
67
+ if len(params_to_gather) > 0:
68
+ # because zero3 puts placeholders in model params, this context
69
+ # manager gathers (unpartitions) the params of the current layer, then loads from
70
+ # the state dict and then re-partitions them again
71
+ with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
72
+ if torch.distributed.get_rank() == 0:
73
+ module._load_from_state_dict(*args)
74
+
75
+ for name, child in module._modules.items():
76
+ if child is not None:
77
+ load(child, state_dict, prefix + name + ".")
78
+
79
+ load(model_to_load, state_dict, prefix=start_prefix)
80
+ # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
81
+ # it's safe to delete it.
82
+ del state_dict
83
+
84
+ return model_to_load
85
+
86
+
87
+ def convert_normal_parameter_to_int8(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
88
+ import bitsandbytes as bnb
89
+ modules_to_not_convert = ["lm"] if modules_to_not_convert is None else modules_to_not_convert
90
+ for name, module in model.named_children():
91
+ if current_key_name is None:
92
+ current_key_name = []
93
+ current_key_name.append(name)
94
+
95
+ if len(list(module.children())) > 0:
96
+ convert_normal_parameter_to_int8(module, threshold, modules_to_not_convert, current_key_name)
97
+
98
+ if isinstance(module, bnb.nn.Linear8bitLt) and name not in modules_to_not_convert:
99
+ # Check if the current key is not in the `modules_to_not_convert`
100
+ if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
101
+ model._modules[name].weight = bnb.nn.Int8Params(
102
+ module.weight.data,
103
+ requires_grad=False,
104
+ has_fp16_weights=False
105
+ )
106
+ # Force requires grad to False to avoid unexpected errors
107
+ model._modules[name].requires_grad_(False)
108
+ # Remove the last key for recursion
109
+ current_key_name.pop(-1)
110
+ return model
111
+
112
+
113
+ def load_model(model, model_path):
114
+ if os.path.isdir(model_path):
115
+ index_filename = os.path.join(model_path, 'pytorch_model.bin.index.json')
116
+ with open(index_filename, "r") as f:
117
+ index = json.loads(f.read())
118
+ shard_filenames = sorted(set(index["weight_map"].values()))
119
+ shard_filenames = [os.path.join(model_path, f) for f in shard_filenames]
120
+ for shard_file in shard_filenames:
121
+ shard_checkpoint = torch.load(shard_file, map_location='cpu')
122
+ for name, parameter in model.named_parameters():
123
+ if shard_checkpoint.get(name, None) is not None:
124
+ if 'target' in name:
125
+ parameter.data = shard_checkpoint['target.lm.output_layer.weight']
126
+ elif 'embedding' in name:
127
+ parameter.data = shard_checkpoint['embedding.word.embedding.weight']
128
+ else:
129
+ parameter.data = shard_checkpoint[name]
130
+ parameter.requires_grad = False
131
+ del shard_checkpoint
132
+ else:
133
+ checkpoint = torch.load(model_path, map_location='cpu')
134
+ for parameter_name, parameter in model.named_parameters():
135
+ if 'target' in parameter_name:
136
+ parameter.data = checkpoint['target.lm.output_layer.weight']
137
+ elif 'embedding' in parameter_name:
138
+ parameter.data = checkpoint['embedding.word.embedding.weight']
139
+ else:
140
+ parameter.data = checkpoint[parameter_name]
141
+ parameter.requires_grad = False
142
+ del checkpoint
143
+ return model