fukugawa committed
Commit 265aded
1 Parent(s): 4abe164

Upload FlaxTransformerLMForCausalLM

Files changed (1)
  1. modeling_transformerlm_flax.py +33 -29
modeling_transformerlm_flax.py CHANGED
@@ -278,6 +278,32 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
+
+        def token_id_to_logits(state, token_id):
+            logits, cache = state
+            output = self.module.apply(
+                {
+                    "params": self.params,
+                    "cache": cache
+                },
+                token_id,
+                None,
+                None,
+                True,
+                False,
+                False,
+                False,
+                True,
+                rngs={},
+                mutable=["cache"],
+            )
+            lm_output, new_vars = output
+            logits = lm_output.logits
+            cache = unfreeze(new_vars["cache"])
+            return (logits, cache), logits
+
+        self.scan_body_fn = token_id_to_logits
+
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
@@ -366,37 +392,17 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
         if input_ids.shape[1] > 1:
             input_ids = jnp.insert(input_ids, 0, 0, axis=1)  # Insert 0 at the beginning of prompt
 
-        # Progressive cache loop
         if self.module.use_cache:
-            def loop_body_fn(i, state):
-                logits, cache = state
-                input_id = lax.dynamic_slice(input_ids, (0, i), (input_ids.shape[0], 1))
-                output = self.module.apply(
-                    {
-                        "params": inputs["params"],
-                        "cache": cache
-                    },
-                    jnp.array(input_id, dtype="i4"),
-                    jnp.array(attention_mask, dtype="i4"),
-                    jnp.array(position_ids, dtype="i4"),
-                    not train,
-                    False,
-                    output_attentions,
-                    output_hidden_states,
-                    return_dict,
-                    rngs=rngs,
-                    mutable=mutable,
-                )
-                lm_output, new_vars = output
-                logits = lm_output.logits
-                cache = new_vars["cache"]
-                return logits, unfreeze(cache)
-
+            # Progressive cache loop
             seq_length = input_ids.shape[1]
-            logits = jnp.zeros((1, 1, self.module.config.vocab_size), dtype=self.dtype)
+            vcab_size = self.module.config.vocab_size
+            logits = jnp.zeros((1, 1, vcab_size), dtype=self.dtype)
             cache = inputs["cache"]
             initial_state = (logits, cache)
-            last_logits, last_cache = lax.fori_loop(0, seq_length, loop_body_fn, initial_state)
+            input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))
+            last, all_logits = lax.scan(self.scan_body_fn, initial_state, input_tokens)
+            last_logits, last_cache = last
+            # lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
 
         if not return_dict:
             outputs = (last_logits,) + (last_cache,)
@@ -454,7 +460,6 @@ class FlaxTransformerLMModule(nn.Module):
                        for i in range(config.num_layers)]
         self.ln_f = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')
 
-    @nn.compact
     def __call__(
         self,
         input_ids,
@@ -543,7 +548,6 @@ class FlaxTransformerLMForCausalLMModule(nn.Module):
            name='logitdense',
        )
 
-    @nn.compact
    def __call__(
        self,
        input_ids,
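The substance of this commit is the switch from `lax.fori_loop` to `lax.scan` for the progressive cache pass: the old `loop_body_fn` sliced one token out of `input_ids` per iteration with `lax.dynamic_slice` and kept only the latest logits in the carry, while the new `token_id_to_logits` body (stored as `self.scan_body_fn`) lets `scan` feed the tokens in along the scanned axis. The following is a minimal, self-contained sketch of that pattern under toy assumptions; `step`, `hidden`, `params`, and the running-sum "cache" are invented stand-ins for `self.module.apply(..., mutable=["cache"])`, not the repository's model code.

```python
# Sketch: fori_loop vs. scan for threading a per-token cache (toy stand-ins only).
import jax
import jax.numpy as jnp
from jax import lax

vocab_size = 8
hidden = 4
seq_length = 5

params = jax.random.normal(jax.random.PRNGKey(0), (hidden, vocab_size))
init_cache = jnp.zeros((hidden,))        # toy stand-in for the attention KV cache
init_logits = jnp.zeros((vocab_size,))
input_ids = jnp.array([3, 1, 4, 1, 5], dtype=jnp.int32)

def step(cache, token_id):
    """One toy decoding step: fold the token into the cache, emit logits."""
    cache = cache + token_id.astype(cache.dtype)
    logits = cache @ params              # shape (vocab_size,)
    return cache, logits

# Old style (removed by this commit): fori_loop indexes into input_ids by position,
# and only the carry survives, so per-step logits are lost.
def loop_body_fn(i, state):
    logits, cache = state
    cache, logits = step(cache, input_ids[i])
    return logits, cache

last_logits_loop, last_cache_loop = lax.fori_loop(
    0, seq_length, loop_body_fn, (init_logits, init_cache)
)

# New style (added by this commit): scan consumes the token axis directly,
# mirroring token_id_to_logits, which returns ((logits, cache), logits).
def scan_body_fn(state, token_id):
    logits, cache = state
    cache, logits = step(cache, token_id)
    return (logits, cache), logits

(last_logits_scan, last_cache_scan), all_logits = lax.scan(
    scan_body_fn, (init_logits, init_cache), input_ids
)

print(jnp.allclose(last_logits_loop, last_logits_scan))  # True
print(jnp.allclose(last_cache_loop, last_cache_scan))    # True
print(all_logits.shape)                                  # (5, 8): every step's logits
```

Besides dropping the manual `lax.dynamic_slice` indexing, `scan` returns the stacked per-step outputs alongside the final carry, which is what the commented-out `lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))` line in the new code would reshape into `(1, seq_length, vocab_size)` if it were enabled.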
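Both the removed and the added loop bodies rely on Flax's mutable-collection mechanism: `module.apply(..., mutable=["cache"])` returns the model output together with the updated `cache` collection, which is then fed back in on the next step. The toy module below illustrates only that round-trip; `Counter` and its `count` variable are invented for this example and are unrelated to the repository's attention cache.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    """Toy module with a mutable 'cache' collection (stand-in for a KV cache)."""

    @nn.compact
    def __call__(self, x):
        count = self.variable("cache", "count", lambda: jnp.zeros((), jnp.int32))
        count.value = count.value + 1    # mutate the cache in place
        return x * count.value

m = Counter()
variables = m.init(jax.random.PRNGKey(0), jnp.ones(3))  # __call__ runs once during init

# apply(..., mutable=["cache"]) returns (output, updated_collections); the updated
# cache is threaded into the next call, as token_id_to_logits does per token.
y1, mutated = m.apply(variables, jnp.ones(3), mutable=["cache"])
y2, mutated = m.apply({"cache": mutated["cache"]}, jnp.ones(3), mutable=["cache"])
print(int(mutated["cache"]["count"]))  # 3: one increment at init, one per apply
```

The diff also drops `@nn.compact` from `__call__` on the two module classes, which is consistent with their submodules being declared in `setup()` (e.g., `self.ln_f`) rather than inline in `__call__`.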