File size: 16,394 Bytes
e431f0f ad56477 e431f0f ad56477 fc4824e ad56477 2947852 ad56477 2947852 e431f0f ad56477 e431f0f fc4824e ad56477 fc4824e ad56477 624da87 ad56477 fc4824e ad56477 fc4824e c7028d4 fc4824e a0ce4f0 c7028d4 ad56477 e431f0f ad56477 e431f0f ad56477 a0ce4f0 98bb233 ad56477 a0ce4f0 ad56477 a1fecbb c7028d4 98bb233 c7028d4 ad56477 c7028d4 e431f0f ad56477 e431f0f ad56477 e431f0f ad56477 c7028d4 ad56477 a1fecbb ad56477 e431f0f ad56477 e431f0f ad56477 a1fecbb ad56477 e431f0f ad56477 c7028d4 ad56477 530fce2 ad56477 e431f0f ad56477 111c2a3 ad56477 111c2a3 ad56477 98bb233 ad56477 c7028d4 ad56477 c7028d4 ad56477 c7028d4 ad56477 c7028d4 ad56477 c7028d4 ad56477 e431f0f 2947852 e431f0f 111c2a3 e431f0f c7028d4 e431f0f ad56477 624da87 ad56477 624da87 ad56477 624da87 ad56477 e431f0f ad56477 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 |
from math import sqrt
from dataclasses import dataclass
from functools import partial, cached_property
from typing import Iterator, Self
import torch
from torch import Tensor
from torch.nn import (
Module,
ModuleList,
Sequential,
Embedding,
MultiheadAttention,
Linear,
SiLU,
RMSNorm,
Dropout1d,
CrossEntropyLoss,
Parameter,
)
from torch.nn.functional import softmax, log_softmax
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from huggingface_hub import PyTorchModelHubMixin
class LightGPT(Module, PyTorchModelHubMixin):
    """A generative pretrained transformer.

    Token embeddings are passed through a stack of ``CausalSelfAttentionBlock`` layers,
    a final RMSNorm, and a linear output projection whose weight is tied to the token
    embedding table.
    """

    def __init__(
        self,
        vocabulary_size: int,
        embedding_dimensions: int,
        num_heads: int,
        num_layers: int,
        feed_forward_ratio: int,
        dropout: float,
        padding_index: int,
        eos_index: int,
    ):
        """
        Args:
            vocabulary_size: Number of token ids (rows of the embedding table).
            embedding_dimensions: Width of the token embeddings / residual stream.
            num_heads: Number of attention heads per block.
            num_layers: Number of stacked attention blocks.
            feed_forward_ratio: MLP hidden width as a multiple of the embedding width.
            dropout: Dropout probability used inside every block.
            padding_index: Token id used for padding; it is also ignored by the loss.
            eos_index: Token id that terminates generation and beam candidates.

        Raises:
            ValueError: If num_layers or vocabulary_size is not positive, or if
                feed_forward_ratio is not one of 1, 2 or 4.
        """
        super().__init__()

        if num_layers <= 0:
            raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")

        if feed_forward_ratio not in (1, 2, 4):
            raise ValueError("Feed-forward ratio must be either 1, 2, or 4.")

        if vocabulary_size <= 0:
            raise ValueError(
                f"Vocabulary size must be greater than 0, {vocabulary_size} given."
            )

        token_embeddings = Embedding(
            vocabulary_size, embedding_dimensions, padding_idx=padding_index
        )

        output_layer = Linear(embedding_dimensions, vocabulary_size, bias=False)

        # Tie weights: the embedding table and the output projection share one Parameter.
        token_embeddings.weight = output_layer.weight

        self.token_embeddings = token_embeddings

        self.body = ModuleList(
            [
                CausalSelfAttentionBlock(
                    embedding_dimensions,
                    num_heads,
                    feed_forward_ratio,
                    dropout,
                )
                for _ in range(num_layers)
            ]
        )

        # Default: call each layer directly. A named function (not a lambda) keeps
        # the module picklable; enable_activation_checkpointing() swaps this out.
        self.checkpoint = self._call_layer

        self.output_norm = RMSNorm(embedding_dimensions)
        self.output_layer = output_layer

        self.loss_function = CrossEntropyLoss(ignore_index=padding_index)

        self.vocabulary_size = vocabulary_size
        self.eos_index = eos_index

    @staticmethod
    def _call_layer(layer: Module, x: Tensor, attention_mask: Tensor) -> Tensor:
        """Run a layer normally, without activation checkpointing."""
        return layer(x, attention_mask)

    @staticmethod
    def _causal_mask(z: Tensor) -> Tensor:
        """Build an additive (t, t) mask with -inf above the diagonal and 0 elsewhere.

        The 3-way unpack requires ``z`` to be 3-D, with the sequence on dim 1.
        """
        _, t, _ = z.size()

        mask = torch.full((t, t), float("-inf"), dtype=z.dtype, device=z.device)

        return torch.triu(mask, diagonal=1)

    @property
    def num_trainable_params(self) -> int:
        """Number of parameter elements that currently require gradients.

        Recomputed on every access, so it reflects later freezing or unfreezing
        (e.g. by a LoRA wrapper). Tied weights are counted once.
        """
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def enable_activation_checkpointing(self) -> None:
        """Instead of memorizing the activations of the forward pass, recompute them at various checkpoints."""
        self.checkpoint = partial(torch_checkpoint, use_reentrant=False)

    def forward(
        self, x: Tensor, y: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """A forward pass optimized for batch training.

        Args:
            x: Token ids; the embedded tensor must be 3-D (presumably (batch, seq)
                ids, so embeddings are (batch, seq, dim)).
            y: Optional target ids whose element count equals batch * seq.

        Returns:
            The logits over the vocabulary for every position (last dim = vocabulary
            size), and the cross-entropy loss if ``y`` is given, else None.
        """
        z = self.token_embeddings(x)

        causal_mask = self._causal_mask(z)

        for layer in self.body:
            z = self.checkpoint(layer, z, causal_mask)

        z = self.output_norm(z)
        z = self.output_layer(z)

        if y is None:
            return z, None

        # reshape (not view) so non-contiguous targets such as tokens[:, 1:] also work.
        y_pred = z.reshape(-1, z.size(-1))
        labels = y.reshape(-1)  # Flatten the batch dimension.

        loss = self.loss_function(y_pred, labels)

        return z, loss

    @torch.no_grad()
    def predict(self, x: Tensor) -> Tensor:
        """A forward pass optimized for batch next-token prediction.

        Returns only the logits of the last position of each sequence.
        """
        z = self.token_embeddings(x)

        causal_mask = self._causal_mask(z)

        for layer in self.body:
            z = layer(z, causal_mask)

        z = self.output_norm(z)

        z = z[:, -1, :]  # Pluck only the last token embedding from each batch.

        z = self.output_layer(z)

        return z

    @torch.no_grad()
    def generate(
        self,
        prompt: Tensor,
        max_tokens: int = 2000,
        context_length: int = 1024,
        temperature: float = 1.0,
        top_k: int = 500,
        top_p: float = 0.9,
    ) -> Iterator:
        """
        Given a prompt, sample the next {max_tokens} tokens from the model weighted
        by their predicted probabilities and filtered by the {top_k} and {top_p}.

        Tokens are yielded one at a time as 0-d tensors. Generation stops early when
        the EOS token is sampled; the EOS token itself is not yielded.

        Raises:
            ValueError: If any sampling parameter is out of range.
        """
        if max_tokens <= 0:
            raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")

        if temperature <= 0:
            raise ValueError(
                f"Temperature must be greater than 0, {temperature} given."
            )

        if top_k <= 0 or top_k > self.vocabulary_size:
            raise ValueError(
                f"Top k must be between 1 and {self.vocabulary_size}, {top_k} given."
            )

        if top_p <= 0.0 or top_p > 1.0:
            raise ValueError(f"Top p must be between 0 and 1, {top_p} given.")

        context_window = prompt

        for _ in range(max_tokens):
            context_window = context_window[-context_length:]

            # [0] rather than squeeze() so a vocabulary of size 1 still yields a 1-D tensor.
            logits = self.predict(context_window.unsqueeze(0))[0]

            logits, indices = torch.topk(logits, top_k, sorted=True)

            # Nucleus filter: keep the sorted prefix whose cumulative mass is <= top_p,
            # raising the threshold if needed so the most likely token always survives.
            probabilities = softmax(logits, dim=0)

            cumulative_probability_mass = torch.cumsum(probabilities, dim=0)

            min_probability_mass = cumulative_probability_mass[0]

            threshold_p = max(top_p, min_probability_mass.item())

            selected_indices = cumulative_probability_mass <= threshold_p

            logits = logits[selected_indices]
            indices = indices[selected_indices]

            # Temperature is applied after filtering, so it only reshapes the
            # distribution over the surviving tokens.
            logits /= temperature

            probabilities = softmax(logits, dim=0)

            offset = torch.multinomial(probabilities, num_samples=1).squeeze()

            next_token = indices[offset]

            if next_token == self.eos_index:
                break

            yield next_token

            context_window = torch.cat((context_window, next_token.unsqueeze(0)))

    @torch.no_grad()
    def beam_search(
        self,
        prompt: Tensor,
        max_tokens: int = 100,
        context_length: int = 1024,
        num_candidates: int = 3,
        beam_width: int = 16,
    ) -> list:
        """
        Given a prompt, return the {num_candidates} highest probability sequences. Note that
        this method is often best for generating shorter sequences and is typically less
        natural sounding than sequences that are more random in nature.

        This is a best-first search over a frontier of at most {beam_width} open
        candidates. Candidates are scored by the sum of full-vocabulary token
        log-probabilities. The result is sorted best-first, and a trailing EOS token is
        stripped from each sequence.

        Raises:
            ValueError: If max_tokens, num_candidates or beam_width is not positive.
        """
        if max_tokens <= 0:
            raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")

        if num_candidates <= 0:
            raise ValueError(
                f"Num candidates must be greater than 0, {num_candidates} given."
            )

        if beam_width <= 0:
            raise ValueError(f"Beam width must be greater than 0, {beam_width} given.")

        @dataclass
        class Candidate:
            log_probability: float
            tokens: Tensor

            def priority(self) -> float:
                return self.log_probability

        sort_candidates = partial(
            sorted,
            key=lambda candidate: candidate.priority(),
            reverse=True,
        )

        candidates: list[Candidate] = []
        completed: list[Candidate] = []

        tokens = torch.tensor([], dtype=prompt.dtype).to(prompt.device)

        candidates.append(Candidate(0.0, tokens))

        while candidates:
            # The frontier is kept sorted best-first, so expand from the front.
            candidate = candidates.pop(0)

            if len(completed) >= num_candidates:
                completed = sort_candidates(completed)[:num_candidates]

                worst_candidate = completed[-1]

                # Log-probabilities never increase as a sequence grows. If the best open
                # candidate already trails the worst kept result, nothing left can improve it.
                if candidate.log_probability < worst_candidate.log_probability:
                    break

            if len(candidate.tokens) > 0 and candidate.tokens[-1] == self.eos_index:
                candidate.tokens = candidate.tokens[:-1]

                completed.append(candidate)

                continue

            if len(candidate.tokens) >= max_tokens:
                completed.append(candidate)

                continue

            context_window = torch.cat((prompt, candidate.tokens))

            context_window = context_window[-context_length:]

            logits = self.predict(context_window.unsqueeze(0))[0]

            # Normalize over the whole vocabulary before keeping the top tokens, so
            # candidate scores are true sequence log-probabilities.
            log_probabilities = log_softmax(logits, dim=0)

            log_probabilities, indices = torch.topk(
                log_probabilities, beam_width, sorted=False
            )

            for log_probability, index in zip(log_probabilities, indices):
                tokens = torch.cat((candidate.tokens, index.unsqueeze(0)))

                candidates.append(
                    Candidate(candidate.log_probability + log_probability, tokens)
                )

            candidates = sort_candidates(candidates)[:beam_width]

        return sort_candidates(completed)[:num_candidates]
class LightGPTInstruct(Module, PyTorchModelHubMixin):
    """
    A wrapper for pretrained GPT models that applies a LoRA reparameterization
    to the intermediate layers of the network.

    All base parameters are frozen. LoRA parametrizations are registered on each
    block's attention input projection, attention output projection, and every
    Linear layer of its MLP. Two extra rows are appended to the token embedding table.
    """

    def __init__(self, model: LightGPT, rank: int, alpha: float, dropout: float):
        """
        Args:
            model: The pretrained model to wrap; it is modified in place.
            rank: Rank of every LoRA decomposition.
            alpha: Scale applied to every LoRA delta.
            dropout: Dropout probability used inside every LoRA module.

        Raises:
            ValueError: If rank or alpha is not positive.
        """
        super().__init__()

        if rank <= 0:
            raise ValueError(f"Rank must be greater than 0, {rank} given.")

        if alpha <= 0.0:
            raise ValueError(f"Alpha must be greater than 0, {alpha} given.")

        for param in model.parameters():
            param.requires_grad = False

        embeddings = model.token_embeddings
        old_weight = embeddings.weight.data

        # Create the appended rows on the table's own device/dtype; otherwise torch.cat
        # fails for models already moved to a GPU or cast to reduced precision.
        new_rows = torch.randn(
            2,
            old_weight.size(dim=1),
            device=old_weight.device,
            dtype=old_weight.dtype,
        )

        embeddings.weight = Parameter(torch.cat((old_weight, new_rows)))

        # Keep the Embedding's bookkeeping in sync with the enlarged table.
        embeddings.num_embeddings = embeddings.weight.size(dim=0)

        # NOTE(review): the new Parameter requires grad, so the whole table (not just
        # the two new rows) is trainable. model.output_layer.weight is not resized,
        # so the original weight tie is broken and the two new ids cannot be
        # predicted. state_dict() below keeps only "lora" keys, so this table is not
        # saved with it. Confirm these are intended.

        for module in model.body:
            out_features, in_features = module.attention.in_proj_weight.shape

            register_parametrization(
                module.attention,
                "in_proj_weight",
                LoRA(in_features, out_features, rank, alpha, dropout),
            )

            out_features, in_features = module.attention.out_proj.weight.shape

            register_parametrization(
                module.attention.out_proj,
                "weight",
                LoRA(in_features, out_features, rank, alpha, dropout),
            )

            for layer in module.mlp.layers:
                if isinstance(layer, Linear):
                    register_parametrization(
                        layer,
                        "weight",
                        LoRA.from_linear(layer, rank, alpha, dropout),
                    )

        self.model = model

    @property
    def num_trainable_params(self) -> int:
        """Number of parameter elements that currently require gradients (recomputed each call)."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def state_dict(self, *args, **kwargs):
        """Return only the entries whose key contains "lora".

        Arguments are forwarded to ``Module.state_dict``, so ``prefix``/``keep_vars``
        and nesting inside a parent module keep working.
        """
        full_state = super().state_dict(*args, **kwargs)

        return {name: value for name, value in full_state.items() if "lora" in name}

    def merge_lora_parameters(self):
        """Merge the LoRA parameters with the original parameters.

        Every registered parametrization is removed with ``leave_parametrized=True``,
        so each weight keeps its current (base + LoRA delta) value as a plain tensor.
        """
        for module in self.model.modules():
            if hasattr(module, "parametrizations"):
                # Snapshot the names: removal mutates module.parametrizations.
                for name in list(module.parametrizations.keys()):
                    remove_parametrizations(module, name, leave_parametrized=True)

    def forward(
        self, x: Tensor, y: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """Delegate to the wrapped model (calling the module so hooks run)."""
        return self.model(x, y)

    def predict(self, x: Tensor) -> Tensor:
        """Delegate next-token prediction to the wrapped model."""
        return self.model.predict(x)

    def generate(
        self,
        prompt: Tensor,
        max_tokens: int = 2000,
        context_length: int = 1024,
        temperature: float = 1.0,
        top_k: int = 500,
        top_p: float = 0.9,
    ) -> Iterator:
        """Delegate sampling to the wrapped model's ``generate``."""
        return self.model.generate(
            prompt,
            max_tokens=max_tokens,
            context_length=context_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    def beam_search(
        self,
        prompt: Tensor,
        max_tokens: int = 100,
        context_length: int = 1024,
        num_candidates: int = 3,
        beam_width: int = 16,
    ) -> list:
        """Delegate beam search to the wrapped model's ``beam_search``."""
        return self.model.beam_search(
            prompt,
            max_tokens=max_tokens,
            context_length=context_length,
            num_candidates=num_candidates,
            beam_width=beam_width,
        )
class ONNXModel(Module):
    """Inference-only wrapper whose ``forward`` is the wrapped model's ``predict``.

    Gives ONNX production models a clean inferencing API: calling this module
    returns the wrapped model's next-token logits.
    """

    def __init__(self, model: LightGPT | LightGPTInstruct):
        super().__init__()

        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.model.predict(x)``."""
        wrapped = self.model

        return wrapped.predict(x)
class CausalSelfAttentionBlock(Module):
    """Causal self-attention block with residual connections.

    Pre-norm layout: h = x + Attention(RMSNorm(x)), then out = h + MLP(RMSNorm(h)).
    """

    def __init__(
        self,
        embedding_dimensions: int,
        num_heads: int,
        feed_forward_ratio: int,
        dropout: float,
    ):
        """
        Args:
            embedding_dimensions: Width of the residual stream.
            num_heads: Number of attention heads.
            feed_forward_ratio: MLP hidden width as a multiple of embedding_dimensions.
            dropout: Dropout probability for attention and the MLP.

        Raises:
            ValueError: If embedding_dimensions or num_heads is not positive, or
                dropout is outside [0, 1].
        """
        super().__init__()

        if embedding_dimensions <= 0:
            raise ValueError(
                f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
            )

        if num_heads <= 0:
            raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")

        if dropout < 0 or dropout > 1:
            raise ValueError(f"Dropout must be between 0 and 1, {dropout} given")

        self.norm1 = RMSNorm(embedding_dimensions)

        self.attention = MultiheadAttention(
            embedding_dimensions,
            num_heads,
            batch_first=True,
            dropout=dropout,
            bias=False,
        )

        hidden_dimensions = feed_forward_ratio * embedding_dimensions

        self.norm2 = RMSNorm(embedding_dimensions)
        self.mlp = MLP(embedding_dimensions, hidden_dimensions, dropout)

    def forward(self, x: Tensor, attention_mask: Tensor) -> Tensor:
        """
        Args:
            x: Batch-first input (the attention layer uses batch_first=True).
            attention_mask: Must be a genuine causal mask, such as the additive
                -inf upper-triangular mask built by LightGPT. It is passed together
                with ``is_causal=True``, which PyTorch treats as a promise that the
                mask is causal.

        Returns:
            A tensor with the same shape as ``x``.
        """
        z = self.norm1(x)

        # The attention weights are discarded, so don't request them. With
        # need_weights=False, MultiheadAttention can dispatch to the fused
        # scaled_dot_product_attention kernels instead of materializing and
        # head-averaging the full (t, t) weight matrix. The output is the same.
        z, _ = self.attention(
            z,
            z,
            z,
            attn_mask=attention_mask,
            is_causal=True,
            need_weights=False,
        )

        x = x + z  # Residual connection

        z = self.norm2(x)
        z = self.mlp(z)

        return x + z  # Residual connection
class MLP(Module):
    """Feed-forward network: Linear -> SiLU -> Linear (no biases), then Dropout1d.

    ``self.layers`` holds the three-stage Sequential so that sub-layer indices
    (0 = expansion, 2 = contraction) stay stable for callers that iterate it.
    """

    def __init__(
        self, embedding_dimensions: int, hidden_dimensions: int, dropout: float
    ):
        """
        Args:
            embedding_dimensions: Input and output width.
            hidden_dimensions: Width of the hidden layer.
            dropout: Probability passed to the Dropout1d applied to the output.

        Raises:
            ValueError: If either dimension is not positive.
        """
        super().__init__()

        size_checks = (
            (
                embedding_dimensions,
                f"Embedding dimensions must be greater than 0, {embedding_dimensions} given.",
            ),
            (
                hidden_dimensions,
                f"Hidden dimensions must be greater than 0, {hidden_dimensions} given.",
            ),
        )

        for size, message in size_checks:
            if size <= 0:
                raise ValueError(message)

        expand = Linear(embedding_dimensions, hidden_dimensions, bias=False)
        activation = SiLU()
        contract = Linear(hidden_dimensions, embedding_dimensions, bias=False)

        self.layers = Sequential(expand, activation, contract)

        self.dropout = Dropout1d(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the two projections, then dropout, to ``x``."""
        projected = self.layers(x)

        return self.dropout(projected)
class LoRA(Module):
"""Rank decomposition transformation."""
@classmethod
def from_linear(
cls, linear: Linear, rank: int, alpha: float, dropout: float
) -> Self:
out_features, in_features = linear.weight.shape
return cls(in_features, out_features, rank, alpha, dropout)
def __init__(
self,
in_features: int,
out_features: int,
rank: int,
alpha: float,
dropout: float,
):
super().__init__()
if rank <= 0:
raise ValueError(f"Rank must be greater than 0, {rank} given.")
if alpha <= 0.0:
raise ValueError(f"Alpha must be greater than 0, {alpha} given.")
std_dev = 1.0 / sqrt(rank)
self.lora_a = Parameter(torch.randn(rank, in_features) * std_dev)
self.lora_b = Parameter(torch.zeros(out_features, rank))
self.dropout = Dropout1d(p=dropout)
self.alpha = alpha
def forward(self, x: Tensor) -> Tensor:
z = self.lora_b @ self.dropout(self.lora_a)
z *= self.alpha
return x + z
|