Andrew DalPino committed
Commit · a1fecbb
1 Parent(s): c7028d4

Fix unsqueeze missing dimension argument

Files changed:
- README.md +1 -1
- instruction-tune.py +3 -3
- model.py +12 -17
README.md CHANGED

@@ -127,7 +127,7 @@ Then navigate to the dashboard using your favorite web browser.
 | Argument | Default | Type | Description |
 |---|---|---|---|
 | --base_model_path | "./checkpoints/checkpoint.pt" | string | The path to the base checkpoint on disk. |
-| --max_tokens_per_sample |
+| --max_tokens_per_sample | 2048 | int | The maximum number of tokens to pack into a single training sequence. |
 | --mask_input | False | bool | Should we mask the input part of the training sequences i.e. only train on the supervised output? |
 | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
 | --gradient_accumulation_steps | 64 | int | The number of batches to pass through the network before updating the weights. |
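For reference, the documented defaults correspond to an invocation along the lines of the sketch below (a hypothetical command line; any additional arguments that instruction-tune.py requires are not shown in this diff):

    python instruction-tune.py --base_model_path "./checkpoints/checkpoint.pt" --max_tokens_per_sample 2048 --batch_size 1 --gradient_accumulation_steps 64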
instruction-tune.py CHANGED

@@ -27,7 +27,7 @@ def main():
     parser.add_argument(
         "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
     )
-    parser.add_argument("--max_tokens_per_sample", default=
+    parser.add_argument("--max_tokens_per_sample", default=2048, type=int)
     parser.add_argument("--mask_input", action="store_true")
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--gradient_accumulation_steps", default=64, type=int)

@@ -62,7 +62,7 @@ def main():
         else torch.float32
     )

-
+    amp_context = autocast(device_type=args.device, dtype=dtype)

     if args.seed:
         torch.manual_seed(args.seed)

@@ -160,7 +160,7 @@ def main():
         x = x.to(args.device, non_blocking=True)
         y = y.to(args.device, non_blocking=True)

-        with
+        with amp_context:
             y_pred, loss = model(x, y)

         scaled_loss = loss / args.gradient_accumulation_steps
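Taken together, the three instruction-tune.py hunks define a single mixed-precision context and reuse it inside the gradient-accumulation loop. The sketch below shows that pattern end to end; TinyModel, the random dataset, and the SGD optimizer are stand-ins for illustration only and are not part of the repository.

# A self-contained sketch of the autocast + gradient accumulation pattern wired
# up by the hunks above. TinyModel, the random data, and SGD are stand-ins.
import torch
from torch import nn
from torch.amp import autocast
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Mirrors the script: bfloat16 under CUDA, otherwise float32 (which leaves
# autocast effectively disabled on CPU).
dtype = torch.bfloat16 if device == "cuda" else torch.float32

amp_context = autocast(device_type=device, dtype=dtype)


class TinyModel(nn.Module):
    """Stand-in that mimics GPT's (y_pred, loss) return signature."""

    def __init__(self, vocab_size: int = 128, embed_dim: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, x, y):
        z = self.head(self.embed(x))
        loss = self.loss_function(z.view(-1, z.size(-1)), y.view(-1))
        return z, loss


model = TinyModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randint(0, 128, (64, 16))
y = torch.randint(0, 128, (64, 16))
loader = DataLoader(TensorDataset(x, y), batch_size=1)

gradient_accumulation_steps = 64

for step, (x, y) in enumerate(loader, start=1):
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    # The forward pass runs under the shared mixed-precision context.
    with amp_context:
        y_pred, loss = model(x, y)

    # Average gradients over the accumulation window before stepping.
    scaled_loss = loss / gradient_accumulation_steps
    scaled_loss.backward()

    if step % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)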
model.py CHANGED

@@ -92,9 +92,7 @@ class GPT(Module):
         """Instead of memorizing the activations of the forward pass, recompute them at various checkpoints."""
         self.checkpoint = partial(torch_checkpoint, use_reentrant=False)

-    def forward(
-        self, x: Tensor, y: Tensor | None = None
-    ) -> tuple[Tensor, Tensor | None]:
+    def forward(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
         """A forward pass optimized for batch training."""

         z = self.token_embeddings(x)

@@ -110,17 +108,15 @@ class GPT(Module):
         z = self.output_norm(z)
         z = self.output_layer(z)

-
-
-
-            labels = y.view(-1)
+        # Concatenate the batches along the time dimension.
+        y_pred = z.view(-1, z.size(-1))
+        labels = y.view(-1)

-
-        else:
-            loss = None
+        loss = self.loss_function(y_pred, labels)

         return z, loss

+    @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
         """A forward pass optimized for batch next-token prediction."""

@@ -136,7 +132,7 @@ class GPT(Module):

         z = self.output_norm(z)

-        # Pluck only the last token embedding
+        # Pluck only the last token embedding from each batch.
         z = z[:, -1, :]

         z = self.output_layer(z)

@@ -200,7 +196,7 @@ class GPT(Module):

         probabilities = softmax(logits, dim=0)

-        offset = torch.multinomial(probabilities, num_samples=1).squeeze(
+        offset = torch.multinomial(probabilities, num_samples=1).squeeze()

         next_token = indices[offset]

@@ -251,7 +247,8 @@ class GPT(Module):
             reverse=True,
         )

-        candidates
+        candidates: list[Candidate] = []
+        completed: list[Candidate] = []

         tokens = torch.tensor([], dtype=prompt.dtype).to(prompt.device)

@@ -372,9 +369,7 @@ class GPTWithLoRA(Module):
         for name in lora_params:
             remove_parametrizations(module, name, leave_parametrized=True)

-    def forward(
-        self, x: Tensor, y: Tensor | None = None
-    ) -> tuple[Tensor, Tensor | None]:
+    def forward(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
         return self.model.forward(x, y)

     def predict(self, x: Tensor) -> Tensor:

@@ -407,7 +402,7 @@ class GPTWithLoRA(Module):


 class ONNXModel(Module):
-    """This wrapper provides a
+    """This wrapper provides a clean inferencing API for production models."""

     def __init__(self, model: GPT | GPTWithLoRA):
         super().__init__()