Upload FlaxTransformerLMForCausalLM
Browse files
modeling_transformerlm_flax.py
CHANGED
@@ -402,12 +402,12 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
|
|
402 |
input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))
|
403 |
last, all_logits = lax.scan(self.scan_body_fn, initial_state, input_tokens)
|
404 |
last_logits, last_cache = last
|
405 |
-
|
406 |
|
407 |
if not return_dict:
|
408 |
-
outputs = (
|
409 |
else:
|
410 |
-
outputs = (FlaxCausalLMOutput(logits=
|
411 |
else:
|
412 |
output = self.module.apply(
|
413 |
inputs,
|
|
|
402 |
input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))
|
403 |
last, all_logits = lax.scan(self.scan_body_fn, initial_state, input_tokens)
|
404 |
last_logits, last_cache = last
|
405 |
+
lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
|
406 |
|
407 |
if not return_dict:
|
408 |
+
outputs = (lm_logits,) + (last_cache,)
|
409 |
else:
|
410 |
+
outputs = (FlaxCausalLMOutput(logits=lm_logits, hidden_states=None, attentions=None), {"cache": last_cache})
|
411 |
else:
|
412 |
output = self.module.apply(
|
413 |
inputs,
|