Andrew DalPino committed
Commit a1fecbb · 1 Parent(s): c7028d4

Fix unsqueeze missing dimension argument

Files changed (3)
  1. README.md +1 -1
  2. instruction-tune.py +3 -3
  3. model.py +12 -17
README.md CHANGED
@@ -127,7 +127,7 @@ Then navigate to the dashboard using your favorite web browser.
 | Argument | Default | Type | Description |
 |---|---|---|---|
 | --base_model_path | "./checkpoints/checkpoint.pt" | string | The path to the base checkpoint on disk. |
-| --max_tokens_per_sample | 4096 | int | The maximum number of tokens to pack into a single training sequence. |
+| --max_tokens_per_sample | 2048 | int | The maximum number of tokens to pack into a single training sequence. |
 | --mask_input | False | bool | Should we mask the input part of the training sequences i.e. only train on the supervised output? |
 | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
 | --gradient_accumulation_steps | 64 | int | The number of batches to pass through the network before updating the weights. |
instruction-tune.py CHANGED
@@ -27,7 +27,7 @@ def main():
     parser.add_argument(
         "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
     )
-    parser.add_argument("--max_tokens_per_sample", default=4096, type=int)
+    parser.add_argument("--max_tokens_per_sample", default=2048, type=int)
     parser.add_argument("--mask_input", action="store_true")
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
@@ -62,7 +62,7 @@ def main():
         else torch.float32
     )
 
-    forward_context = autocast(device_type=args.device, dtype=dtype)
+    amp_context = autocast(device_type=args.device, dtype=dtype)
 
     if args.seed:
         torch.manual_seed(args.seed)
@@ -160,7 +160,7 @@ def main():
        x = x.to(args.device, non_blocking=True)
        y = y.to(args.device, non_blocking=True)
 
-       with forward_context:
+       with amp_context:
            y_pred, loss = model(x, y)
 
        scaled_loss = loss / args.gradient_accumulation_steps
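
Note on the renamed autocast context above: torch.autocast is a reusable context manager, so it can be built once and entered on every training step. A minimal sketch of that pattern, with the same variable name as the diff (the model, optimizer, and batch below are placeholders for illustration, not the repo's actual objects):

import torch
from torch import autocast

# Placeholder model, optimizer, and batch for illustration only -- not the repo's objects.
model = torch.nn.Linear(16, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 16)
y = torch.randint(0, 8, (4,))

# Built once and reused on every step, mirroring amp_context in the script.
amp_context = autocast(device_type="cpu", dtype=torch.bfloat16)

with amp_context:
    # Eligible ops (e.g. the linear layer) run in the reduced dtype inside the context.
    logits = model(x)
    loss = torch.nn.functional.cross_entropy(logits, y)

# The backward pass and optimizer step stay outside the autocast region.
loss.backward()
optimizer.step()
optimizer.zero_grad()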
model.py CHANGED
@@ -92,9 +92,7 @@ class GPT(Module):
         """Instead of memorizing the activations of the forward pass, recompute them at various checkpoints."""
         self.checkpoint = partial(torch_checkpoint, use_reentrant=False)
 
-    def forward(
-        self, x: Tensor, y: Tensor | None = None
-    ) -> tuple[Tensor, Tensor | None]:
+    def forward(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
         """A forward pass optimized for batch training."""
 
         z = self.token_embeddings(x)
@@ -110,17 +108,15 @@ class GPT(Module):
         z = self.output_norm(z)
         z = self.output_layer(z)
 
-        if y is not None:
-            # Flatten the batch dimension before calculating loss.
-            y_pred = z.view(-1, z.size(-1))
-            labels = y.view(-1)
+        # Concatenate the batches along the time dimension.
+        y_pred = z.view(-1, z.size(-1))
+        labels = y.view(-1)
 
-            loss = self.loss_function(y_pred, labels)
-        else:
-            loss = None
+        loss = self.loss_function(y_pred, labels)
 
         return z, loss
 
+    @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
         """A forward pass optimized for batch next-token prediction."""
 
@@ -136,7 +132,7 @@ class GPT(Module):
 
         z = self.output_norm(z)
 
-        # Pluck only the last token embedding in the time dimension.
+        # Pluck only the last token embedding from each batch.
         z = z[:, -1, :]
 
         z = self.output_layer(z)
@@ -200,7 +196,7 @@ class GPT(Module):
 
         probabilities = softmax(logits, dim=0)
 
-        offset = torch.multinomial(probabilities, num_samples=1).squeeze(0)
+        offset = torch.multinomial(probabilities, num_samples=1).squeeze()
 
         next_token = indices[offset]
 
@@ -251,7 +247,8 @@ class GPT(Module):
            reverse=True,
        )
 
-       candidates, completed = [], []
+       candidates: list[Candidate] = []
+       completed: list[Candidate] = []
 
        tokens = torch.tensor([], dtype=prompt.dtype).to(prompt.device)
 
@@ -372,9 +369,7 @@ class GPTWithLoRA(Module):
        for name in lora_params:
            remove_parametrizations(module, name, leave_parametrized=True)
 
-    def forward(
-        self, x: Tensor, y: Tensor | None = None
-    ) -> tuple[Tensor, Tensor | None]:
+    def forward(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
         return self.model.forward(x, y)
 
     def predict(self, x: Tensor) -> Tensor:
@@ -407,7 +402,7 @@
 
 
 class ONNXModel(Module):
-    """This wrapper provides a cleaner inferencing API for production models."""
+    """This wrapper provides a clean inferencing API for production models."""
 
     def __init__(self, model: GPT | GPTWithLoRA):
         super().__init__()
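
For reference, a small standalone sketch of what the sampling change in model.py does (the probability values and vocabulary indices below are made up for illustration):

import torch

# A toy distribution over three candidate tokens and their vocabulary indices.
probabilities = torch.tensor([0.7, 0.2, 0.1])
indices = torch.tensor([42, 7, 99])

# For a 1-D input, torch.multinomial returns a tensor of shape (num_samples,),
# so the sampled offset still carries a singleton dimension.
offset = torch.multinomial(probabilities, num_samples=1)
print(offset.shape)  # torch.Size([1])

# Squeezing drops that dimension, leaving a 0-d index tensor, so indexing the
# candidate list yields a scalar token id rather than a 1-element tensor.
offset = offset.squeeze()
next_token = indices[offset]
print(next_token.shape)  # torch.Size([])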