File size: 30,358 Bytes
578867d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 |
import math
import torch
from tqdm import tqdm
from dataclasses import dataclass
from contextlib import nullcontext
from typing import Mapping, Optional, Tuple
from accelerate import Accelerator
from collections import defaultdict
from transformers.modeling_outputs import BaseModelOutputWithPast
def optional_grad_ctx(with_grad=False):
    """Return a context manager that optionally disables gradient tracking.

    Args:
        with_grad: When truthy, a no-op ``nullcontext`` is returned so the
            surrounding autograd mode is left as is; otherwise a
            ``torch.no_grad()`` context is returned.
    """
    return nullcontext() if with_grad else torch.no_grad()
def move_to_device(data, device):
    """Recursively move every tensor found in ``data`` to ``device``.

    Args:
        data: A tensor, or an arbitrarily nested mapping / list / tuple
            (including namedtuples) of tensors. Any other value is returned
            unchanged.
        device: Target device, forwarded to ``Tensor.to(device=...)``.

    Returns:
        A structure of the same container types as ``data`` whose tensors
        live on ``device``.
    """
    if isinstance(data, Mapping):
        return type(data)({k: move_to_device(v, device) for k, v in data.items()})
    if isinstance(data, (tuple, list)):
        moved = [move_to_device(v, device) for v in data]
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so they must be rebuilt by unpacking.
            return type(data)(*moved)
        return type(data)(moved)
    if isinstance(data, torch.Tensor):
        return data.to(device=device)
    return data
def get_shifted_labels(input_ids):
    """Build next-token labels by shifting ``input_ids`` one step to the left.

    The last position of every sequence has no next token and is filled with
    ``-100``.

    Args:
        input_ids: Either a tensor whose last dimension is the sequence
            dimension, a flat ``list`` of ints, or a ``list`` of ``list``s.

    Returns:
        A new object of the same kind and length as ``input_ids``; the input
        is never modified. Empty sequences stay empty.

    Raises:
        NotImplementedError: If ``input_ids`` is of an unsupported type.
    """

    def _shift(seq):
        # Keep the output the same length as the input, also when it is empty.
        return seq[1:] + [-100] if seq else []

    if isinstance(input_ids, torch.Tensor):
        # Shift along the last dimension so 1-D and batched inputs both work.
        pad = input_ids.new_full((*input_ids.shape[:-1], 1), -100)
        return torch.cat([input_ids[..., 1:], pad], dim=-1)

    if isinstance(input_ids, list):
        if not input_ids:
            # Nothing to inspect; previously this raised IndexError.
            return []
        if isinstance(input_ids[0], int):
            return _shift(input_ids)
        if isinstance(input_ids[0], list):
            return [_shift(seq) for seq in input_ids]

    raise NotImplementedError(f"Unsupported input_ids type: {type(input_ids)}")
def compute_loss(logits, labels, shift=False):
    """Compute cross-entropy at the token, sequence and batch level.

    Positions whose label is ``-100`` are ignored: they contribute zero loss
    and are not counted as valid tokens.

    Args:
        logits: Tensor flattened over its first two dimensions before the
            loss is taken, i.e. expected as (batch_size, seq_len, vocab).
        labels: Label tensor of shape (batch_size, seq_len).
        shift: When True, labels are first shifted left by one position via
            ``get_shifted_labels`` (next-token prediction).

    Returns:
        loss: Scalar, total token loss divided by the number of valid tokens
            in the whole batch (0 when there are none).
        batch_loss: (batch_size,), per-sequence mean over valid tokens
            (0 for sequences without valid tokens).
        token_loss: (batch_size, seq_len), un-reduced loss per position.
    """
    if shift:
        labels = get_shifted_labels(labels)

    labels = labels.to(logits.device)
    batch_size = logits.shape[0]

    # NOTE: the loss on -100 labels is 0 by default
    token_loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1),
        labels.reshape(-1),
        reduction="none",
    ).reshape(batch_size, -1)  # batch_size, seq_len

    valid_token_num = (labels != -100).sum(-1)  # batch_size

    # Clamp the denominators instead of dividing by zero and masking the NaN
    # afterwards: the numerator is already 0 wherever the count is 0, so the
    # values are identical, but no NaN is ever produced (forward or backward)
    # and no tensor-to-bool conversion (device sync) is required.
    loss = token_loss.sum() / valid_token_num.sum().clamp(min=1)
    batch_loss = token_loss.sum(-1) / valid_token_num.clamp(min=1)

    return loss, batch_loss, token_loss
@torch.no_grad()
def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=None):
    """Compute the perplexity of ``model`` over ``dataloader``.

    Every batch must carry an ``index`` tensor (the sequence id of each row)
    and ``labels``. Rows sharing an id are treated as pieces of one sequence:
    their token losses and valid-token counts are summed before averaging.
    The result is ``exp`` of the mean per-sequence loss.

    Args:
        model: Callable as ``model(**batch)``. If the output exposes
            ``batch_loss`` it is used directly, otherwise the loss is
            derived from ``output.logits`` with ``compute_loss``.
        dataloader: Iterable of dict batches.
        accelerator: Optional ``Accelerator`` used to prepare a raw
            dataloader and to gather results across processes.

    Returns:
        The perplexity as a float, or ``nan`` when no sequence has any
        valid (non ``-100``) token.
    """
    # Exact type match on purpose: only a raw DataLoader is prepared. If the
    # dataloader has been prepared already we shall not prepare it twice,
    # especially in case of deepspeed.
    if accelerator is not None and type(dataloader) is torch.utils.data.DataLoader:
        dataloader = accelerator.prepare(dataloader)

    all_loss = defaultdict(list)
    for x in tqdm(dataloader, desc="Computing Perplexity"):
        # NOTE: important to reset memory for every batch
        if hasattr(model, "memory"):
            model.memory.reset()

        # the seq id
        index = x.pop("index")
        # length is used to group training data, no use here
        x.pop("length", None)

        output = model(**x)
        labels = x["labels"]

        # NOTE: we need the loss for each element in the batch for accurate
        # computation, because the number of valid tokens may differ among
        # elements
        if hasattr(output, "batch_loss"):
            # output from our model has batch_loss by default
            batch_loss = output.batch_loss
            # NOTE(review): assumes the model averages over the same token
            # count as the un-shifted labels; confirm against the model.
            valid_token_num = (labels != -100).sum(-1)
        else:
            # output from other models does not
            _, batch_loss, _ = compute_loss(output.logits, labels, shift=True)
            # compute_loss averages over the *shifted* labels, i.e. the first
            # label of each row is never a target; count accordingly so that
            # batch_loss * count recovers the exact summed token loss.
            valid_token_num = (labels[:, 1:] != -100).sum(-1)

        index = index.tolist()
        batch_loss = batch_loss.tolist()
        valid_token_num = valid_token_num.tolist()

        if accelerator is not None and accelerator.num_processes > 1:
            # num_device * batch_size
            index = accelerator.gather_for_metrics(index)
            batch_loss = accelerator.gather_for_metrics(batch_loss)
            valid_token_num = accelerator.gather_for_metrics(valid_token_num)

        for _id, _loss, _num in zip(index, batch_loss, valid_token_num):
            # loss times num is the total loss of all valid tokens
            all_loss[_id].append((_loss * _num, _num))

    sequence_losses = []
    for loss_and_num in all_loss.values():
        token_count = sum(num for _, num in loss_and_num)
        if token_count == 0:
            # A sequence without any valid token has no defined loss; skip it
            # instead of dividing by zero.
            continue
        # sum up the loss for all valid tokens in the entire sequence, and
        # divide by the number of valid tokens
        sequence_losses.append(sum(total for total, _ in loss_and_num) / token_count)

    if not sequence_losses:
        return float("nan")

    # average across sequences, then take exp
    return math.exp(sum(sequence_losses) / len(sequence_losses))
@torch.no_grad()
def evaluate_generation(model, dataloader, accelerator:Optional[Accelerator]=None, tokenizer=None, return_new_tokens_only=True, **generation_config):
    """Run ``model.generate`` over ``dataloader`` and decode the results.

    Args:
        model: Object exposing ``generate(**batch, **generation_config)``.
        dataloader: Iterable of dict batches containing ``input_ids``. An
            optional ``index`` tensor identifies each row; when it is absent
            a running counter is used instead.
        accelerator: Optional ``Accelerator`` used to prepare a raw
            dataloader and to gather results across processes.
        tokenizer: Tokenizer providing ``batch_decode``; required as soon as
            there is at least one batch.
        return_new_tokens_only: When True, the first ``input_ids.shape[1]``
            positions of every generated row are dropped before decoding.
        **generation_config: Extra keyword arguments for ``model.generate``.

    Returns:
        ``(all_indices, all_outputs)``: parallel lists of row ids and
        decoded strings.

    Raises:
        ValueError: If a batch has to be decoded but ``tokenizer`` is None.
    """
    # Exact type match on purpose: only a raw DataLoader is prepared. If the
    # dataloader has been prepared already we shall not prepare it twice,
    # especially in case of deepspeed.
    if accelerator is not None and type(dataloader) is torch.utils.data.DataLoader:
        dataloader = accelerator.prepare(dataloader)

    all_indices = []
    all_outputs = []
    next_index = 0

    for x in tqdm(dataloader, desc="Computing Generation"):
        if tokenizer is None:
            # Fail before spending time on generation rather than with an
            # AttributeError on None afterwards.
            raise ValueError("evaluate_generation requires a tokenizer to decode the generated ids.")

        # NOTE: important to reset memory for every batch
        if hasattr(model, "memory"):
            model.memory.reset()

        # length is used to group training data, no use here
        x.pop("length", None)

        # if indices are None, we number the rows with a running counter
        indices = x.pop("index", None)
        if indices is None:
            batch_size = x["input_ids"].shape[0]
            # NOTE(review): the counter is local to each process, so with
            # several processes the generated ids may collide; confirm the
            # dataset supplies "index" in that setting.
            indices = list(range(next_index, next_index + batch_size))
            next_index += batch_size
        else:
            indices = indices.tolist()

        outputs = model.generate(**x, **generation_config)
        if return_new_tokens_only:
            prompt_len = x["input_ids"].shape[1]
            outputs = outputs[:, prompt_len:]

        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if accelerator is not None and accelerator.num_processes > 1:
            outputs = accelerator.gather_for_metrics(outputs)
            indices = accelerator.gather_for_metrics(indices)

        all_indices.extend(indices)
        all_outputs.extend(outputs)

    return all_indices, all_outputs
@torch.no_grad()
def evaluate_nll(model, dataloader, accelerator:Optional[Accelerator]=None):
    """Collect the per-row mean loss of ``model`` for every sequence id.

    Args:
        model: Callable as ``model(**batch)``. If the output exposes
            ``batch_loss`` it is used directly, otherwise the loss is
            derived from ``output.logits`` and ``batch["labels"]`` with
            ``compute_loss``.
        dataloader: Iterable of dict batches, each carrying an ``index``
            tensor holding the sequence id of every row.
        accelerator: Optional ``Accelerator`` used to prepare a raw
            dataloader and to gather results across processes.

    Returns:
        A ``defaultdict(list)`` mapping each sequence id to the list of
        ``batch_loss`` values observed for it.
    """
    # Exact type match on purpose: only a raw DataLoader is prepared. If the
    # dataloader has been prepared already we shall not prepare it twice,
    # especially in case of deepspeed.
    if accelerator is not None and type(dataloader) is torch.utils.data.DataLoader:
        dataloader = accelerator.prepare(dataloader)

    all_loss = defaultdict(list)
    for x in tqdm(dataloader, desc="Computing NLL"):
        # NOTE: important to reset memory for every batch
        if hasattr(model, "memory"):
            model.memory.reset()

        # the seq id
        index = x.pop("index")
        # length is used to group training data, no use here
        x.pop("length", None)

        output = model(**x)

        if hasattr(output, "batch_loss"):
            # output from our model has batch_loss by default
            batch_loss = output.batch_loss
        else:
            # output from other models does not
            _, batch_loss, _ = compute_loss(output.logits, x["labels"], shift=True)

        if accelerator is not None and accelerator.num_processes > 1:
            # num_device * batch_size
            index = accelerator.gather_for_metrics(index)
            batch_loss = accelerator.gather_for_metrics(batch_loss)

        for _id, _loss in zip(index.tolist(), batch_loss.tolist()):
            # keep every per-row mean loss under its sequence id
            all_loss[_id].append(_loss)

    return all_loss
@dataclass
class ModelOutput(BaseModelOutputWithPast):
    """Model output that adds loss fields and logits to ``BaseModelOutputWithPast``.

    Every field is optional and defaults to ``None``. The evaluation helpers
    in this file read ``batch_loss`` when it is present and fall back to
    ``logits`` otherwise.
    """
    # Presumably the three values returned by ``compute_loss`` (scalar loss,
    # per-sequence loss, per-token loss) -- confirm against the model code.
    loss: Optional[torch.FloatTensor] = None
    batch_loss: Optional[torch.FloatTensor] = None
    token_loss: Optional[torch.FloatTensor] = None
    # Prediction scores of the model.
    logits: torch.FloatTensor = None
    # Re-declared from the parent class with the same ``None`` defaults.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
########## Various RoPE Scaling Methods Below (wrap the encoding process within the module for convenience) ##########
def get_rope(head_dim, base, max_position_embeddings, rope_scaling=None):
    """Build the rotary embedding module selected by ``rope_scaling``.

    Supported variants: native (``rope_scaling`` is None), ``linear``,
    ``dynamic`` (NTK), ``yarn``, ``yarn-t``, ``yarn-t-logn`` and ``llama3``.

    Args:
        head_dim: Dimension of one attention head.
        base: RoPE base.
        max_position_embeddings: Maximum position count given to the module.
        rope_scaling: Optional dict with a ``"factor"`` entry and the variant
            name under ``"type"`` (the newer Hugging Face spelling
            ``"rope_type"`` is accepted when ``"type"`` is absent). For
            ``llama3`` the optional keys ``original_max_position_embeddings``,
            ``low_freq_factor`` and ``high_freq_factor`` are also read.

    Returns:
        The instantiated rotary embedding module.

    Raises:
        KeyError: If the variant name or ``"factor"`` is missing.
        ValueError: If the variant name is not supported.
    """
    if rope_scaling is None:
        return RotaryEmbedding(
            dim=head_dim,
            base=base,
            max_position_embeddings=max_position_embeddings,
        )

    # Prefer the legacy key, fall back to the newer one; a config carrying
    # neither still raises KeyError as before.
    scaling_type = rope_scaling["type"] if "type" in rope_scaling else rope_scaling["rope_type"]
    scaling_factor = rope_scaling["factor"]

    common_kwargs = dict(
        dim=head_dim,
        base=base,
        max_position_embeddings=max_position_embeddings,
        scaling_factor=scaling_factor,
    )

    # Variants that only need the common arguments.
    plain_variants = {
        "linear": LinearScalingRotaryEmbedding,
        "dynamic": DynamicNTKScalingRotaryEmbedding,
        "yarn": YarnRotaryEmbedding,
        "yarn-t": YarnDynamicTemperatureRotaryEmbedding,
        "yarn-t-logn": YarnDynamicTemperatureLogNRotaryEmbedding,
    }
    if scaling_type in plain_variants:
        return plain_variants[scaling_type](**common_kwargs)

    if scaling_type == "llama3":
        return Llama3RotaryEmbedding(
            **common_kwargs,
            original_max_position_embeddings=rope_scaling.get("original_max_position_embeddings", 8192),
            low_freq_factor=rope_scaling.get("low_freq_factor", 1),
            high_freq_factor=rope_scaling.get("high_freq_factor", 4),
        )

    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split at ``x.shape[-1] // 2`` the result is ``[-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
class RotaryEmbedding(torch.nn.Module):
    """Plain rotary position embedding backed by a cos/sin lookup table.

    The table is built for ``max_position_embeddings`` positions at
    construction and rebuilt with a larger size whenever ``forward`` meets a
    longer sequence or a larger position id.
    """

    def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for ``seq_len`` positions.

        ``dtype`` is accepted for interface compatibility but unused here;
        the tables keep the dtype produced by the float32 computation.
        """
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, q, k, position_ids):
        """Rotate ``q`` and ``k`` according to ``position_ids``.

        ``q`` and ``k`` are laid out as
        [bs, num_attention_heads, seq_len, head_size]; ``position_ids``
        indexes the key positions and the query uses its last
        ``q.shape[2]`` entries.

        Returns:
            ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
        """
        required_len = max(position_ids.max().item() + 1, k.shape[2])
        if required_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=required_len, device=k.device, dtype=k.dtype)

        # batch_size, 1, key_len, head_dim
        k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
        k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)

        query_len = q.shape[2]
        q_cos = k_cos[..., -query_len:, :]
        q_sin = k_sin[..., -query_len:, :]

        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
        return q_embed, k_embed
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev

    Position indices are divided by ``scaling_factor`` before the angles are
    computed.
    """

    def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None, scaling_factor=1.0):
        # Must be set before the parent constructor runs, because that
        # constructor already calls `_set_cos_sin_cache`.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for ``seq_len`` scaled positions, cast to ``dtype``."""
        self.max_seq_len_cached = seq_len
        scaled_positions = (
            torch.arange(seq_len, device=device, dtype=torch.float32) / self.scaling_factor
        )
        angles = torch.outer(scaled_positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    When the table is built for more than ``max_position_embeddings``
    positions, the base is enlarged and ``inv_freq`` is recomputed from it.
    """

    def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None, scaling_factor=1.0):
        # Must be set before the parent constructor runs, because that
        # constructor already calls `_set_cos_sin_cache`.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for ``seq_len`` positions, cast to ``dtype``."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            stretch = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ntk_base = self.base * stretch ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (ntk_base ** exponents), persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)
class YarnRotaryEmbedding(torch.nn.Module):
    """Rotary embedding with YaRN frequency interpolation and a global temperature.

    Each frequency is a blend of the original value and the value divided by
    ``scaling_factor``; the blend weight comes from ``_get_factor``. The
    cos/sin tables are multiplied by the scalar from ``_get_temperature``.
    When ``forward`` needs more positions than cached, ``scaling_factor`` is
    reset to ``seq_len / max_position_embeddings`` and the tables are rebuilt.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
        super().__init__()

        self.base = base
        self.dim = dim
        self.scaling_factor = scaling_factor
        self.beta_slow = beta_slow
        self.beta_fast = beta_fast
        self.max_position_embeddings = max_position_embeddings

        initial_len = math.ceil(max_position_embeddings * scaling_factor)
        self._set_cos_sin_cache(seq_len=initial_len, device=device, dtype=torch.get_default_dtype())

    def _get_factor(self):
        """Return, per frequency index, the weight given to the un-scaled frequency."""
        log_base = math.log(self.base)

        def boundary(num_rotations):
            ratio = self.max_position_embeddings / (2 * math.pi * num_rotations)
            return self.dim / 2 * (math.log(ratio) / log_base)

        # the dimension whose index is smaller than fast_dim rotates more than beta_fast
        fast_dim = max(math.floor(boundary(self.beta_fast)), 0)
        # the dimension whose index is bigger than slow_dim rotates less than beta_slow
        slow_dim = min(math.ceil(boundary(self.beta_slow)), self.dim - 1)
        if fast_dim == slow_dim:
            # keep the ramp denominator away from zero
            slow_dim += 0.001

        # NOTE: very important to use full precision here so that the factor is correct
        indices = torch.arange(0, self.dim // 2, dtype=torch.float32)
        ramp = torch.clamp((indices - fast_dim) / (slow_dim - fast_dim), 0, 1)
        # align with the paper notation
        return 1 - ramp

    def _get_temperature(self):
        """Return the scalar multiplied into both the cos and sin tables."""
        if self.scaling_factor > 1:
            return 0.07 * math.log(self.scaling_factor) + 1.0
        return 1.0

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``inv_freq`` and the cos/sin tables for ``seq_len`` positions.

        ``dtype`` is accepted for interface compatibility but unused.
        """
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        # dim / 2
        theta = 1 / (self.base ** exponents)
        interpolated_theta = theta / self.scaling_factor

        keep = self._get_factor().to(device)
        self.register_buffer(
            "inv_freq",
            keep * theta + (1 - keep) * interpolated_theta,
            persistent=False,
        )

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)

        # get attention temperature
        temperature = self._get_temperature()
        self.register_buffer("cos_cached", table.cos() * temperature, persistent=False)
        self.register_buffer("sin_cached", table.sin() * temperature, persistent=False)
        self.max_seq_len_cached = seq_len

    def forward(self, q, k, position_ids):
        """Rotate ``q`` and ``k`` according to ``position_ids``.

        ``q`` and ``k`` are laid out as
        [bs, num_attention_heads, seq_len, head_size]; the query uses the
        last ``q.shape[2]`` entries of the key positions.
        """
        required_len = max(position_ids.max().item() + 1, k.shape[2])
        if required_len > self.max_seq_len_cached:
            self.scaling_factor = required_len / self.max_position_embeddings
            self._set_cos_sin_cache(seq_len=required_len, device=k.device, dtype=k.dtype)

        k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
        k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)

        query_len = q.shape[2]
        q_cos = k_cos[..., -query_len:, :]
        q_sin = k_sin[..., -query_len:, :]

        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
        return q_embed, k_embed
class YarnDynamicTemperatureRotaryEmbedding(torch.nn.Module):
    """YaRN rotary embedding with a position-dependent temperature on the query.

    Frequencies are blended as in YaRN. A ``temperature`` buffer holds one
    value per position: 1 for the first ``max_position_embeddings``
    positions and ``(0.07 * log((pos + 1) / max_position_embeddings) + 1) ** 2``
    afterwards. ``forward`` multiplies only the query's cos/sin by it.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
        super().__init__()

        self.base = base
        self.dim = dim
        self.scaling_factor = scaling_factor
        self.beta_slow = beta_slow
        self.beta_fast = beta_fast
        self.max_position_embeddings = max_position_embeddings

        initial_len = math.ceil(max_position_embeddings * scaling_factor)
        self._set_cos_sin_cache(seq_len=initial_len, device=device, dtype=torch.get_default_dtype())

    def _get_factor(self):
        """Return, per frequency index, the weight given to the un-scaled frequency."""
        log_base = math.log(self.base)

        def boundary(num_rotations):
            ratio = self.max_position_embeddings / (2 * math.pi * num_rotations)
            return self.dim / 2 * (math.log(ratio) / log_base)

        # the dimension whose index is smaller than fast_dim rotates more than beta_fast
        fast_dim = max(math.floor(boundary(self.beta_fast)), 0)
        # the dimension whose index is bigger than slow_dim rotates less than beta_slow
        slow_dim = min(math.ceil(boundary(self.beta_slow)), self.dim - 1)
        if fast_dim == slow_dim:
            # keep the ramp denominator away from zero
            slow_dim += 0.001

        # NOTE: very important to use full precision here so that the factor is correct
        indices = torch.arange(0, self.dim // 2, dtype=torch.float32)
        ramp = torch.clamp((indices - fast_dim) / (slow_dim - fast_dim), 0, 1)
        # align with the paper notation
        return 1 - ramp

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``inv_freq``, ``temperature`` and the cos/sin tables.

        ``dtype`` is accepted for interface compatibility but unused.
        """
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        # dim / 2
        theta = 1 / (self.base ** exponents)
        interpolated_theta = theta / self.scaling_factor

        keep = self._get_factor().to(device)
        self.register_buffer(
            "inv_freq",
            keep * theta + (1 - keep) * interpolated_theta,
            persistent=False,
        )

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)

        # NOTE: get attention temperature that will be applied on the query vector
        temperature = (0.07 * torch.log((positions + 1) / self.max_position_embeddings) + 1) ** 2
        # positions inside the original window are left untouched
        temperature[:self.max_position_embeddings] = 1

        self.register_buffer("temperature", temperature.unsqueeze(1), persistent=False)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)
        self.max_seq_len_cached = seq_len

    def forward(self, q, k, position_ids):
        """Rotate ``q`` and ``k``; the query is additionally scaled by its temperature.

        ``q`` and ``k`` are laid out as
        [bs, num_attention_heads, seq_len, head_size]; the query uses the
        last ``q.shape[2]`` entries of the key positions.
        """
        required_len = max(position_ids.max().item() + 1, k.shape[2])
        if required_len > self.max_seq_len_cached:
            self.scaling_factor = required_len / self.max_position_embeddings
            self._set_cos_sin_cache(seq_len=required_len, device=k.device, dtype=k.dtype)

        # batch_size, 1, key_len, head_dim
        k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
        k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)

        query_len = q.shape[2]
        query_positions = position_ids[:, -query_len:]
        temperature = self.temperature[query_positions].to(dtype=k.dtype).unsqueeze(1)

        q_cos = k_cos[..., -query_len:, :] * temperature
        q_sin = k_sin[..., -query_len:, :] * temperature

        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
        return q_embed, k_embed
class YarnDynamicTemperatureLogNRotaryEmbedding(torch.nn.Module):
    """YaRN rotary embedding with a log-n position-dependent temperature on the query.

    Frequencies are blended as in YaRN. A ``temperature`` buffer holds one
    value per position: 1 for the first ``max_position_embeddings``
    positions and ``log(pos + 1) / log(max_position_embeddings)`` afterwards.
    ``forward`` multiplies only the query's cos/sin by it.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, beta_slow=2, beta_fast=128):
        super().__init__()

        self.base = base
        self.dim = dim
        self.scaling_factor = scaling_factor
        self.beta_slow = beta_slow
        self.beta_fast = beta_fast
        self.max_position_embeddings = max_position_embeddings

        initial_len = math.ceil(max_position_embeddings * scaling_factor)
        self._set_cos_sin_cache(seq_len=initial_len, device=device, dtype=torch.get_default_dtype())

    def _get_factor(self):
        """Return, per frequency index, the weight given to the un-scaled frequency."""
        log_base = math.log(self.base)

        def boundary(num_rotations):
            ratio = self.max_position_embeddings / (2 * math.pi * num_rotations)
            return self.dim / 2 * (math.log(ratio) / log_base)

        # the dimension whose index is smaller than fast_dim rotates more than beta_fast
        fast_dim = max(math.floor(boundary(self.beta_fast)), 0)
        # the dimension whose index is bigger than slow_dim rotates less than beta_slow
        slow_dim = min(math.ceil(boundary(self.beta_slow)), self.dim - 1)
        if fast_dim == slow_dim:
            # keep the ramp denominator away from zero
            slow_dim += 0.001

        # NOTE: very important to use full precision here so that the factor is correct
        indices = torch.arange(0, self.dim // 2, dtype=torch.float32)
        ramp = torch.clamp((indices - fast_dim) / (slow_dim - fast_dim), 0, 1)
        # align with the paper notation
        return 1 - ramp

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``inv_freq``, ``temperature`` and the cos/sin tables.

        ``dtype`` is accepted for interface compatibility but unused.
        """
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        # dim / 2
        theta = 1 / (self.base ** exponents)
        interpolated_theta = theta / self.scaling_factor

        keep = self._get_factor().to(device)
        self.register_buffer(
            "inv_freq",
            keep * theta + (1 - keep) * interpolated_theta,
            persistent=False,
        )

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)

        # NOTE: get attention temperature that will be applied on the query vector
        temperature = torch.log(positions + 1) / math.log(self.max_position_embeddings)
        # positions inside the original window are left untouched
        temperature[:self.max_position_embeddings] = 1

        self.register_buffer("temperature", temperature.unsqueeze(1), persistent=False)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)
        self.max_seq_len_cached = seq_len

    def forward(self, q, k, position_ids):
        """Rotate ``q`` and ``k``; the query is additionally scaled by its temperature.

        ``q`` and ``k`` are laid out as
        [bs, num_attention_heads, seq_len, head_size]; the query uses the
        last ``q.shape[2]`` entries of the key positions.
        """
        required_len = max(position_ids.max().item() + 1, k.shape[2])
        if required_len > self.max_seq_len_cached:
            self.scaling_factor = required_len / self.max_position_embeddings
            self._set_cos_sin_cache(seq_len=required_len, device=k.device, dtype=k.dtype)

        # batch_size, 1, key_len, head_dim
        k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
        k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)

        query_len = q.shape[2]
        query_positions = position_ids[:, -query_len:]
        temperature = self.temperature[query_positions].to(dtype=k.dtype).unsqueeze(1)

        q_cos = k_cos[..., -query_len:, :] * temperature
        q_sin = k_sin[..., -query_len:, :] * temperature

        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
        return q_embed, k_embed
class Llama3RotaryEmbedding(torch.nn.Module):
    """Rotary embedding with Llama-3 style frequency scaling.

    Frequencies are treated by wavelength (``2 * pi / inv_freq``):

    * shorter than ``original_max_position_embeddings / high_freq_factor``:
      kept unchanged;
    * longer than ``original_max_position_embeddings / low_freq_factor``:
      divided by ``scaling_factor``;
    * in between: linear blend of the two.

    The cos/sin tables cover ``max(max_position_embeddings,
    int(original_max_position_embeddings * scaling_factor))`` positions and
    are rebuilt with a larger size when ``forward`` needs more.

    Raises:
        ValueError: If a frequency falls in the blended band while
            ``low_freq_factor == high_freq_factor`` (the blend is undefined).
    """

    def __init__(self, dim, max_position_embeddings=8192, base=10000, device=None, scaling_factor=1.0, original_max_position_embeddings=8192, low_freq_factor=1, high_freq_factor=4):
        super().__init__()

        self.base = base
        self.dim = dim
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.max_position_embeddings = max(max_position_embeddings, int(original_max_position_embeddings * scaling_factor))
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))

        low_freq_wavelen = self.original_max_position_embeddings / low_freq_factor
        high_freq_wavelen = self.original_max_position_embeddings / high_freq_factor
        wavelen = 2 * math.pi / inv_freq

        # Same precedence as an if / elif / else chain: the high-frequency
        # test wins, then the low-frequency test, the rest is blended.
        is_high_freq = wavelen < high_freq_wavelen
        is_low_freq = ~is_high_freq & (wavelen > low_freq_wavelen)
        is_mid_freq = ~(is_high_freq | is_low_freq)

        new_freq = torch.where(is_low_freq, inv_freq / scaling_factor, inv_freq)
        if is_mid_freq.any():
            if high_freq_factor == low_freq_factor:
                raise ValueError(
                    "low_freq_factor and high_freq_factor must differ when a frequency lies between the two wavelength bounds"
                )
            smooth = (self.original_max_position_embeddings / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            blended = (1 - smooth) * inv_freq / scaling_factor + smooth * inv_freq
            new_freq = torch.where(is_mid_freq, blended, new_freq)

        self.register_buffer("inv_freq", new_freq, persistent=False)
        self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build the cos/sin tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order
        # to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, q, k, position_ids):
        """Rotate ``q`` and ``k`` according to ``position_ids``.

        ``q`` and ``k`` are laid out as
        [bs, num_attention_heads, seq_len, head_size]; the query uses the
        last ``q.shape[2]`` entries of the key positions.

        Returns:
            ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
        """
        seq_len = max(position_ids.max().item() + 1, k.shape[2])
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=k.device)

        k_cos = self.cos_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)
        k_sin = self.sin_cached[position_ids].to(dtype=k.dtype).unsqueeze(1)

        q_cos = k_cos[..., -q.shape[2]:, :]
        q_sin = k_sin[..., -q.shape[2]:, :]

        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
        return q_embed, k_embed