fukugawa committed on
Commit
cc04067
1 Parent(s): a1c0ad6

Upload FlaxTransformerLMForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_transformerlm_flax.py +3 -3
modeling_transformerlm_flax.py CHANGED
@@ -402,12 +402,12 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
402
  input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))
403
  last, all_logits = lax.scan(self.scan_body_fn, initial_state, input_tokens)
404
  last_logits, last_cache = last
405
- # lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
406
 
407
  if not return_dict:
408
- outputs = (last_logits,) + (last_cache,)
409
  else:
410
- outputs = (FlaxCausalLMOutput(logits=last_logits, hidden_states=None, attentions=None), {"cache": last_cache})
411
  else:
412
  output = self.module.apply(
413
  inputs,
 
402
  input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))
403
  last, all_logits = lax.scan(self.scan_body_fn, initial_state, input_tokens)
404
  last_logits, last_cache = last
405
+ lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
406
 
407
  if not return_dict:
408
+ outputs = (lm_logits,) + (last_cache,)
409
  else:
410
+ outputs = (FlaxCausalLMOutput(logits=lm_logits, hidden_states=None, attentions=None), {"cache": last_cache})
411
  else:
412
  output = self.module.apply(
413
  inputs,