ahatamiz committed
Commit 46a13cf
1 Parent(s): 318d126

Upload 2 files

Files changed (2)
  1. hf_model.py +297 -0
  2. lm_harness.py +112 -0
hf_model.py ADDED
@@ -0,0 +1,297 @@
+ import torch
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+ from transformers.modeling_utils import PreTrainedModel
+ from .hf_config import HFConfig
+ import torch.nn as nn
+ from lit_gpt.model import Block, MBlock
+ try:
+     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+ except ImportError:
+     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+ from typing import Any, Literal, Optional, Type, Union, List, Tuple
+ from lit_gpt.config import Config
+
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
+
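+ # Hugging Face wrappers around the lit_gpt blocks, so converted checkpoints can be loaded
+ # through the `Auto*` classes with `HFConfig` (see lm_harness.py for the registration).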
+ class HF_GPTPreTrainedModel(PreTrainedModel):
+     config_class = HFConfig
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Block"]
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+
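+ # Backbone: token embeddings, a stack of lit_gpt `Block`s (attention or Mamba), and a final norm.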
+ class HF_GPTModel(HF_GPTPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.lit_config = Config.from_name(config.name)
+         config = self.lit_config
+         self.embeddings = nn.Embedding(config.vocab_size, config.n_embd)
+         self.h = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
+         self.ln_f = config.norm_class(config.n_embd, eps=config.norm_eps)
+
+         self.rope_cache: Optional[RoPECache] = None
+         self.mask_cache: Optional[torch.Tensor] = None
+         self.kv_caches: List[KVCache] = []
+         self.max_len = self.lit_config.block_size
+         self.mamba_init = self.lit_config.mamba or self.lit_config.mamba_init
+
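+     # forward() has two code paths: the pure-Mamba stack passes (hidden_states, residual)
+     # through each block, while the attention path builds RoPE / mask / KV caches as in lit_gpt.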
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = True,
+         max_seq_length: Optional[int] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         idx = input_ids
+         input_pos = position_ids
+
+         assert inputs_embeds is None
+
+         if self.lit_config.mamba:
+             hidden_states = self.embeddings(idx)
+             residual = None
+             for block in self.h:
+                 hidden_states, residual = block(
+                     hidden_states, residual, inference_params=None
+                 )
+             norm_f = self.ln_f
+             if not self.lit_config.fused_add_norm:
+                 residual = (hidden_states + residual) if residual is not None else hidden_states
+                 hidden_states = norm_f(residual.to(dtype=norm_f.weight.dtype))
+             else:
+                 # Set prenorm=False here since we don't need the residual
+                 fused_add_norm_fn = rms_norm_fn if isinstance(norm_f, RMSNorm) else layer_norm_fn
+                 hidden_states = fused_add_norm_fn(
+                     hidden_states,
+                     norm_f.weight,
+                     norm_f.bias,
+                     eps=norm_f.eps,
+                     residual=residual,
+                     prenorm=False,
+                     residual_in_fp32=self.lit_config.residual_in_fp32,
+                 )
+
+         else:
+             B, T = idx.size()
+             use_kv_cache = input_pos is not None
+             block_size = self.lit_config.block_size
+             if max_seq_length is None:
+                 max_seq_length = block_size
+             if use_kv_cache:  # not relevant otherwise
+                 assert (
+                     max_seq_length >= T
+                 ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
+             # assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
+             # assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"
+             if not self.lit_config.nope:
+                 if self.rope_cache is None:
+                     self.rope_cache = self.build_rope_cache(idx, self.max_len)
+                 elif T > self.max_len:
+                     self.max_len = T
+                     self.rope_cache = self.build_rope_cache(idx, self.max_len)
+                 cos, sin = self.rope_cache
+             # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
+             # for the kv-cache support (only during inference), we only create it in that situation
+             # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
+             if use_kv_cache and self.mask_cache is None:
+                 self.mask_cache = self.build_mask_cache(idx)
+             if use_kv_cache:
+                 if not self.lit_config.nope:
+                     cos = cos.index_select(0, input_pos)
+                     sin = sin.index_select(0, input_pos)
+                 mask = self.mask_cache.index_select(2, input_pos)
+                 mask = mask[:, :, :, :max_seq_length]
+             else:
+                 if not self.lit_config.nope:
+                     cos = cos[:T]
+                     sin = sin[:T]
+                 mask = None
+             if self.lit_config.nope:
+                 rope = None
+             else:
+                 rope = (cos, sin)
+             # forward the model itself
+             x = self.embeddings(idx)  # token embeddings of shape (b, t, n_embd)
+             if not use_kv_cache:
+                 for block in self.h:
+                     x, *_ = block(x, rope, max_seq_length)
+             else:
+                 if self.lit_config.nope:
+                     self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, None)
+                 else:
+                     self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
+                 for i, block in enumerate(self.h):
+                     x, self.kv_caches[i] = block(x, rope, max_seq_length, mask, input_pos, self.kv_caches[i])
+
+             hidden_states = self.ln_f(x)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=None,
+             hidden_states=hidden_states,
+             attentions=None
+         )
+
+
+     def build_rope_cache(self, idx: torch.Tensor, seq_len: int) -> RoPECache:
+         return build_rope_cache(
+             seq_len=seq_len,
+             n_elem=int(self.lit_config.rotary_percentage * self.lit_config.head_size),
+             dtype=torch.float32,
+             device=idx.device,
+             condense_ratio=self.lit_config.condense_ratio,
+         )
+
+     def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
+         ones = torch.ones((self.lit_config.block_size, self.lit_config.block_size), device=idx.device, dtype=torch.bool)
+         return torch.tril(ones).unsqueeze(0).unsqueeze(0)
+
+     def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: Optional[int]) -> List[KVCache]:
+         B = idx.size(0)
+         heads = 1 if self.lit_config.n_query_groups == 1 else self.lit_config.n_query_groups
+         if rope_cache_length is not None:
+             k_cache_shape = (
+                 B,
+                 max_seq_length,
+                 heads,
+                 rope_cache_length + self.lit_config.head_size - int(self.lit_config.rotary_percentage * self.lit_config.head_size),
+             )
+         else:
+             k_cache_shape = (
+                 B,
+                 max_seq_length,
+                 heads,
+                 self.lit_config.head_size,
+             )
+         v_cache_shape = (B, max_seq_length, heads, self.lit_config.head_size)
+         device = idx.device
+         return [
+             (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
+             for _ in range(self.lit_config.n_layer)
+         ]
+
+
+
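+ # Causal-LM wrapper: HF_GPTModel backbone plus a linear `lm_head` over the vocabulary.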
+ class HF_GPTForCausalLM(HF_GPTPreTrainedModel):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.lit_config = Config.from_name(config.name)
+         self.config = config
+         self.transformer = HF_GPTModel(config)
+         config = self.lit_config
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.n_layer = config.n_layer
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.transformer.embeddings
+
+     def set_input_embeddings(self, value):
+         self.transformer.embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+         inputs = {"input_ids": input_ids}
+         if past is not None:
+             inputs["past_key_values"] = past
+         return inputs
+
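+     # Note: the backbone only uses a KV cache when `position_ids` is given; `generate()` does not
+     # pass positions here, so generation typically re-runs the full sequence at each step.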
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = True,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         transformer_outputs = self.transformer(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = transformer_outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
+
+     def _reorder_cache(self, past, beam_idx):
+         return tuple(
+             tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
+             for layer_past in past
+         )
+
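+ # Module-level RoPE cache builder; HF_GPTModel.build_rope_cache above delegates to it.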
+ def build_rope_cache(
+     seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
+ ) -> RoPECache:
+     """Enhanced Transformer with Rotary Position Embedding.
+
+     Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+     transformers/rope/__init__.py. MIT License:
+     https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+     """
+     # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
+
+     # Create position indexes `[0, 1, ..., seq_len - 1]`
+     seq_idx = torch.arange(seq_len, device=device) / condense_ratio
+
+     # Calculate the product of position index and $\theta_i$
+     idx_theta = torch.outer(seq_idx, theta)
+
+     cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
+
+     # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
+     if dtype == torch.bfloat16:
+         return cos.bfloat16(), sin.bfloat16()
+     # this is to mimic the behaviour of complex32, else we will get different results
+     if dtype in (torch.float16, torch.bfloat16, torch.int8):
+         return cos.half(), sin.half()
+     return cos, sin
lm_harness.py ADDED
@@ -0,0 +1,112 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import fla  # noqa
+ from lm_eval.__main__ import cli_evaluate
+ from lm_eval.api.registry import register_model
+ from lm_eval.models.huggingface import HFLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+ from hf_gpt.hf_model import HF_GPTForCausalLM
+ from hf_gpt.hf_config import HFConfig
+ import requests
+ import wandb
+ import lm_eval
+ # from lm_eval.loggers import WandbLogger
+ import argparse
+ AutoConfig.register("hf_gpt", HFConfig)
+ AutoModelForCausalLM.register(HFConfig, HF_GPTForCausalLM)
+ import logging
+ logging.basicConfig(level=logging.INFO)
+ import torch
+ import os
+ import pdb
+
+ os.environ['HF_HOME'] = '/lustre/fs8/portfolios/nvr/users/ahatamizadeh/hf_cache/'
+
+
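+ # Helper (currently unused in main): True if `directory` exists and contains at least one entry.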
+ def is_directory_non_empty(directory):
+     if not os.path.isdir(directory):
+         return False
+     return len(os.listdir(directory)) > 0
+
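+ # Convert a lit_gpt training checkpoint into a Hugging Face checkpoint, then evaluate it with lm-eval-harness.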
+ def main(args):
+     ### First convert to a Hugging Face model when necessary
+     import datasets
+
+     datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+     hf_save_dir = args.hf_save_dir or os.path.dirname(args.ckpt_path)
+     # for example: /lustre/fsw/portfolios/nvr/users/soyang/code/next_gen_llm-1/checkpoint/outputs/tsz512x4k_20B_Samba_421M_tsz512x4k_20B_Samba_421M_sy_stream_v11/iter-009198-ckpt.pth
+     if not args.skip_convert:
+         ckpt = torch.load(args.ckpt_path)
+         print("Checkpoint loaded")
+         hf_config = HFConfig(name=args.model_name)
+         hf_model = HF_GPTForCausalLM(hf_config)
+         model_weight = ckpt['model']
+         new_weight = {}
+         # rename checkpoint parameter keys to match the HF module names
+         for k, v in model_weight.items():
+             if 'wte' in k:
+                 new_weight[k.replace("wte", "embeddings")] = v
+             elif 'beta_proj' in k:
+                 new_weight[k.replace("beta_proj", "b_proj")] = v
+             elif 'bias_proj' in k:
+                 new_weight[k.replace("bias_proj", "b_proj")] = v
+             else:
+                 new_weight[k] = v
+
+         hf_model.load_state_dict(new_weight)
+         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
+         hf_model.save_pretrained(hf_save_dir)
+         tokenizer.save_pretrained(hf_save_dir)
+         print("Hugging Face model saved")
+
+     ### Then call lm_eval
+     tasks = args.tasks.split(',')
+     assert hf_save_dir is not None
+     assert args.dtype in ['bfloat16', 'float32']
+     # wandb_logger = wandb.init(project="llm_next_gen", name=args.exp_name, id=args.exp_name, group=args.wandb_group_name)  # or empty if wandb.init(...) already called before
+     print("Starting lm-eval ...")
+
+     results = lm_eval.simple_evaluate(
+         model="hf",
+         model_args=f"pretrained={hf_save_dir},trust_remote_code=True,dtype={args.dtype}",
+         tasks=tasks,
+         device="cuda",
+         log_samples=False,
+         batch_size=args.batch_size,
+         num_fewshot=args.num_fewshot,
+     )['results']
+
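+     # Report metrics only for the tasks that were actually evaluated; tasks missing from `results` are skipped.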
+     task_metrics = [
+         ('swde', 'contains,none'),
+         ('squad_completion', 'contains,none'),
+         # ('mmlu', 'acc,none'),
+         ('piqa', 'acc,none'),
+         ('hellaswag', 'acc_norm,none'),
+         ('winogrande', 'acc,none'),
+         ('arc_easy', 'acc,none'),
+         ('arc_challenge', 'acc_norm,none'),
+         ('wikitext', 'word_perplexity,none'),
+         ('lambada_openai', 'acc,none'),
+         ('lambada_openai', 'perplexity,none'),
+     ]
+     for task, metric in task_metrics:
+         if task in results and metric in results[task]:
+             print('{} ({}): {}'.format(task, metric, results[task][metric]))
+
+
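+ # Command-line interface; the default paths below reflect the author's cluster setup.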
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='LM evaluation with lm-eval-harness')
+     parser.add_argument('--ckpt_path', type=str, default=None, help='Path to the training checkpoint (.pth) file')
+     parser.add_argument('--hf_save_dir', type=str, default=None, help='Optional path to the directory where the converted HF model is saved')
+     parser.add_argument('--dtype', type=str, default='bfloat16', help='Data type to use for inference')
+     parser.add_argument('--model_name', type=str, default='Samba_421M', help='Model name')
+     parser.add_argument('--exp_name', type=str, default='hf_eval', help='Experiment name')
+     parser.add_argument('--wandb_dir', type=str, default='/lustre/fsw/portfolios/nvr/users/soyang/code/next_gen_llm-1/checkpoint/outputs', help='Wandb directory')
+     parser.add_argument('--wandb_group_name', type=str, default='lm-eval-harness', help='Wandb group name')
+     parser.add_argument('--tasks', type=str, default='wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,mmlu', help='Tasks to evaluate')
+     parser.add_argument('--tokenizer_name', type=str, default="TinyLlama/TinyLlama_v1.1", help="Tokenizer name or path")
+     parser.add_argument('--batch_size', type=int, default=64)
+     parser.add_argument('--num_fewshot', type=int, default=0)
+     # optionally skip the checkpoint conversion step
+     parser.add_argument('--skip_convert', action='store_true', help='Skip converting the checkpoint (hf_save_dir must already contain a saved HF model)')
+
+     args = parser.parse_args()
+     main(args)
+
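
For a quick sanity check outside of lm-eval-harness, the exported checkpoint can be loaded through the same `hf_gpt` registration used in lm_harness.py. The sketch below is illustrative only; the checkpoint directory and the prompt are placeholders, not part of this commit.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from hf_gpt.hf_config import HFConfig
from hf_gpt.hf_model import HF_GPTForCausalLM

# Same registration as in lm_harness.py so AutoModelForCausalLM resolves the "hf_gpt" model type.
AutoConfig.register("hf_gpt", HFConfig)
AutoModelForCausalLM.register(HFConfig, HF_GPTForCausalLM)

hf_save_dir = "/path/to/hf_save_dir"  # placeholder: directory written by lm_harness.py
tokenizer = AutoTokenizer.from_pretrained(hf_save_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(hf_save_dir, torch_dtype=torch.bfloat16).cuda().eval()

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(inputs["input_ids"], max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))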