Upload 2 files
- hf_model.py +297 -0
- lm_harness.py +112 -0
hf_model.py
ADDED
@@ -0,0 +1,297 @@
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from .hf_config import HFConfig
import torch.nn as nn
from lit_gpt.model import Block, MBlock
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from typing import Any, Literal, Optional, Type, Union, List, Tuple
from lit_gpt.config import Config

RoPECache = Tuple[torch.Tensor, torch.Tensor]
KVCache = Tuple[torch.Tensor, torch.Tensor]


class HF_GPTPreTrainedModel(PreTrainedModel):
    config_class = HFConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)


class HF_GPTModel(HF_GPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.lit_config = Config.from_name(config.name)
        config = self.lit_config
        self.embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.h = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
        self.ln_f = config.norm_class(config.n_embd, eps=config.norm_eps)

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[torch.Tensor] = None
        self.kv_caches: List[KVCache] = []
        self.max_len = self.lit_config.block_size
        self.mamba_init = self.lit_config.mamba or self.lit_config.mamba_init

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        max_seq_length: Optional[int] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        idx = input_ids
        input_pos = position_ids

        assert inputs_embeds is None

        if self.lit_config.mamba:
            hidden_states = self.embeddings(idx)
            residual = None
            for block in self.h:
                hidden_states, residual = block(
                    hidden_states, residual, inference_params=None
                )
            norm_f = self.ln_f
            if not self.lit_config.fused_add_norm:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                hidden_states = norm_f(residual.to(dtype=norm_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                fused_add_norm_fn = rms_norm_fn if isinstance(norm_f, RMSNorm) else layer_norm_fn
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    norm_f.weight,
                    norm_f.bias,
                    eps=norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.lit_config.residual_in_fp32,
                )

        else:
            B, T = idx.size()
            use_kv_cache = input_pos is not None
            block_size = self.lit_config.block_size
            if max_seq_length is None:
                max_seq_length = block_size
            if use_kv_cache:  # not relevant otherwise
                assert (
                    max_seq_length >= T
                ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
            # assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
            # assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"
            if not self.lit_config.nope:
                if self.rope_cache is None:
                    self.rope_cache = self.build_rope_cache(idx, self.max_len)
                elif T > self.max_len:
                    self.max_len = T
                    self.rope_cache = self.build_rope_cache(idx, self.max_len)
                cos, sin = self.rope_cache
            # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
            # for the kv-cache support (only during inference), we only create it in that situation
            # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
            if use_kv_cache and self.mask_cache is None:
                self.mask_cache = self.build_mask_cache(idx)
            if use_kv_cache:
                if not self.lit_config.nope:
                    cos = cos.index_select(0, input_pos)
                    sin = sin.index_select(0, input_pos)
                mask = self.mask_cache.index_select(2, input_pos)
                mask = mask[:, :, :, :max_seq_length]
            else:
                if not self.lit_config.nope:
                    cos = cos[:T]
                    sin = sin[:T]
                mask = None
            if self.lit_config.nope:
                rope = None
            else:
                rope = (cos, sin)
            # forward the model itself
            x = self.embeddings(idx)  # token embeddings of shape (b, t, n_embd)
            if not use_kv_cache:
                for block in self.h:
                    x, *_ = block(x, rope, max_seq_length)
            else:
                if self.lit_config.nope:
                    self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, None)
                else:
                    self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
                for i, block in enumerate(self.h):
                    x, self.kv_caches[i] = block(x, rope, max_seq_length, mask, input_pos, self.kv_caches[i])

            hidden_states = self.ln_f(x)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=hidden_states,
            attentions=None,
        )

    def build_rope_cache(self, idx: torch.Tensor, seq_len: int) -> RoPECache:
        return build_rope_cache(
            seq_len=seq_len,
            n_elem=int(self.lit_config.rotary_percentage * self.lit_config.head_size),
            dtype=torch.float32,
            device=idx.device,
            condense_ratio=self.lit_config.condense_ratio,
        )

    def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
        ones = torch.ones((self.lit_config.block_size, self.lit_config.block_size), device=idx.device, dtype=torch.bool)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:
        B = idx.size(0)
        heads = 1 if self.lit_config.n_query_groups == 1 else self.lit_config.n_query_groups
        if rope_cache_length is not None:
            k_cache_shape = (
                B,
                max_seq_length,
                heads,
                rope_cache_length + self.lit_config.head_size - int(self.lit_config.rotary_percentage * self.lit_config.head_size),
            )
        else:
            k_cache_shape = (
                B,
                max_seq_length,
                heads,
                self.lit_config.head_size,
            )
        v_cache_shape = (B, max_seq_length, heads, self.lit_config.head_size)
        device = idx.device
        return [
            (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
            for _ in range(self.lit_config.n_layer)
        ]


class HF_GPTForCausalLM(HF_GPTPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.lit_config = Config.from_name(config.name)
        self.config = config
        self.transformer = HF_GPTModel(config)
        config = self.lit_config
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.n_layer = config.n_layer
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embeddings

    def set_input_embeddings(self, value):
        self.transformer.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        inputs = {"input_ids": input_ids}
        if past is not None:
            inputs["past_key_values"] = past
        return inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(self, past, beam_idx):
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past
        )


def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
) -> RoPECache:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta)

    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)

    # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
    if dtype == torch.bfloat16:
        return cos.bfloat16(), sin.bfloat16()
    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin
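
For reference, the classes above can be smoke-tested directly, without going through the conversion script below. This is a minimal sketch and not part of the upload: it assumes the hf_gpt package layout used by lm_harness.py and that the name "Samba_421M" (the default model_name below) resolves in the lit_gpt Config registry.

import torch

from hf_gpt.hf_config import HFConfig
from hf_gpt.hf_model import HF_GPTForCausalLM

# Build a randomly initialized model from the lit_gpt config name (assumed to exist).
config = HFConfig(name="Samba_421M")
model = HF_GPTForCausalLM(config)

# Tiny forward pass; passing labels exercises the shifted cross-entropy loss.
input_ids = torch.randint(0, model.lit_config.vocab_size, (1, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.logits.shape)  # (1, 16, vocab_size)
print(out.loss)          # scalar language-modeling loss
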
lm_harness.py
ADDED
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import fla  # noqa
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from hf_gpt.hf_model import HF_GPTForCausalLM
from hf_gpt.hf_config import HFConfig
import requests
import wandb
import lm_eval
# from lm_eval.loggers import WandbLogger
import argparse
AutoConfig.register("hf_gpt", HFConfig)
AutoModelForCausalLM.register(HFConfig, HF_GPTForCausalLM)
import logging
logging.basicConfig(level=logging.INFO)
import torch
import os
import pdb

os.environ['HF_HOME'] = '/lustre/fs8/portfolios/nvr/users/ahatamizadeh/hf_cache/'


def is_directory_non_empty(directory):
    if not os.path.isdir(directory):
        return "The provided path is not a directory."

    return len(os.listdir(directory)) > 0


def main(args):
    ### First convert to Huggingface models when necessary
    import datasets

    datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
    hf_save_dir = args.hf_save_dir or os.path.dirname(args.ckpt_path)
    # for example: /lustre/fsw/portfolios/nvr/users/soyang/code/next_gen_llm-1/checkpoint/outputs/tsz512x4k_20B_Samba_421M_tsz512x4k_20B_Samba_421M_sy_stream_v11/iter-009198-ckpt.pth
    ckpt = torch.load(args.ckpt_path)
    print("Checkpoint loaded")
    hf_config = HFConfig(name=args.model_name)
    hf_model = HF_GPTForCausalLM(hf_config)
    model_weight = ckpt['model']
    new_weight = {}
    for k, v in model_weight.items():
        if 'wte' in k:
            new_weight[k.replace("wte", "embeddings")] = v
        elif 'beta_proj' in k:
            new_weight[k.replace("beta_proj", "b_proj")] = v
        elif 'bias_proj' in k:
            new_weight[k.replace("bias_proj", "b_proj")] = v
        else:
            new_weight[k] = v

    hf_model.load_state_dict(new_weight)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    hf_model.save_pretrained(hf_save_dir)
    tokenizer.save_pretrained(hf_save_dir)
    print("Huggingface model saved")

    ### Then call lm_eval
    tasks = args.tasks.split(',')
    assert hf_save_dir is not None
    assert args.dtype in ['bfloat16', 'float32']
    # wandb_logger = wandb.init(project="llm_next_gen", name=args.exp_name, id=args.exp_name, group=args.wandb_group_name)  # or empty if wandb.init(...) already called before
    print("Start lm eval....")

    results = lm_eval.simple_evaluate(
        model="hf",
        model_args=f"pretrained={hf_save_dir},trust_remote_code=True,dtype={args.dtype}",
        tasks=tasks,
        device="cuda",
        log_samples=False,
        batch_size=1,
        num_fewshot=args.num_fewshot,
    )['results']

    # swde and squad_completion are not in the default task list, so only report them when they were evaluated
    if 'swde' in results:
        print('swde: {}'.format(results['swde']['contains,none']))
    if 'squad_completion' in results:
        print('squad_completion: {}'.format(results['squad_completion']['contains,none']))
    # print('mmlu: {}'.format(results['mmlu']['acc,none']))
    print('piqa: {}'.format(results['piqa']['acc,none']))
    print('hellaswag: {}'.format(results['hellaswag']['acc_norm,none']))
    print('winogrande: {}'.format(results['winogrande']['acc,none']))
    print('arc_easy: {}'.format(results['arc_easy']['acc,none']))
    print('arc_challenge: {}'.format(results['arc_challenge']['acc_norm,none']))
    print('wikitext, ppl: {}'.format(results['wikitext']['word_perplexity,none']))
    print('lambada_openai, acc: {}'.format(results['lambada_openai']['acc,none']))
    print('lambada_openai, ppl: {}'.format(results['lambada_openai']['perplexity,none']))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LM evaluation harness for converted checkpoints')
    parser.add_argument('--ckpt_path', type=str, default=None, help='Path to the ckpt directory')
    parser.add_argument('--hf_save_dir', type=str, default=None, help='(Optional) Path to the saved HF model directory')
    parser.add_argument('--dtype', type=str, default='bfloat16', help='Data type to use for inference')
    parser.add_argument('--model_name', type=str, default='Samba_421M', help='Model name')
    parser.add_argument('--exp_name', type=str, default='hf_eval', help='Experiment name')
    parser.add_argument('--wandb_dir', type=str, default='/lustre/fsw/portfolios/nvr/users/soyang/code/next_gen_llm-1/checkpoint/outputs', help='Wandb directory')
    parser.add_argument('--wandb_group_name', type=str, default='lm-eval-harness', help='Wandb group name')
    parser.add_argument('--tasks', type=str, default='wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,mmlu', help='Tasks to evaluate')
    parser.add_argument('--tokenizer_name', type=str, default="TinyLlama/TinyLlama_v1.1", help="tokenizer name or path")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_fewshot', type=int, default=0)
    # do convert or not
    parser.add_argument('--skip_convert', action='store_true', help='Whether to skip converting to a Huggingface model')

    args = parser.parse_args()
    main(args)
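
Once main() has written the converted model and tokenizer to hf_save_dir, that directory can be reloaded through the registered auto classes for a quick sanity check before (or instead of) running the harness. This is a minimal illustrative sketch, not part of the upload; the hf_save_dir path is a placeholder.

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from hf_gpt.hf_config import HFConfig
from hf_gpt.hf_model import HF_GPTForCausalLM

# Re-register the custom architecture in a fresh process, mirroring the calls at the top of lm_harness.py.
AutoConfig.register("hf_gpt", HFConfig)
AutoModelForCausalLM.register(HFConfig, HF_GPTForCausalLM)

hf_save_dir = "/path/to/hf_save_dir"  # placeholder: the directory written by save_pretrained() in main()
tokenizer = AutoTokenizer.from_pretrained(hf_save_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(hf_save_dir, trust_remote_code=True)

# Plain forward pass on a short prompt; logits should have shape (1, prompt_len, vocab_size).
inputs = tokenizer("The capital of France is", return_tensors="pt")
logits = model(input_ids=inputs["input_ids"]).logits
print(logits.shape)
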