Upload FlaxTransformerLMForCausalLM

modeling_transformerlm_flax.py  CHANGED  (+33 -29)
@@ -278,6 +278,32 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
+
+        def token_id_to_logits(state, token_id):
+            logits, cache = state
+            output = self.module.apply(
+                {
+                    "params": self.params,
+                    "cache": cache
+                },
+                token_id,
+                None,
+                None,
+                True,
+                False,
+                False,
+                False,
+                True,
+                rngs={},
+                mutable=["cache"],
+            )
+            lm_output, new_vars = output
+            logits = lm_output.logits
+            cache = unfreeze(new_vars["cache"])
+            return (logits, cache), logits
+
+        self.scan_body_fn = token_id_to_logits
+
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
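The added token_id_to_logits is shaped as a lax.scan body: the carry is a (logits, cache) pair, the per-step input is a single token id, and apply(..., mutable=["cache"]) asks Flax to return the updated cache collection alongside the model output so the caller can thread it into the next step. A minimal sketch of that mutable-collection round trip, with a toy counter standing in for the real decoder (module, names, and shapes below are illustrative, not part of this repo):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    @nn.compact
    def __call__(self, x):
        # the "cache" collection holds state that persists between calls
        count = self.variable("cache", "count", lambda: jnp.zeros(()))
        count.value = count.value + 1
        return x * count.value

mod = Counter()
variables = mod.init(jax.random.PRNGKey(0), jnp.ones(()))
cache = variables["cache"]

# apply with mutable=["cache"] returns (output, mutated_collections);
# the caller is responsible for threading the new cache forward
out, new_vars = mod.apply({"params": {}, "cache": cache}, jnp.ones(()), mutable=["cache"])
cache = new_vars["cache"]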
@@ -366,37 +392,17 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
         if input_ids.shape[1] > 1:
             input_ids = jnp.insert(input_ids, 0, 0, axis=1)  # Insert 0 at the beginning of prompt
 
-        # Progressive cache loop
         if self.module.use_cache:
-
-            logits, cache = state
-            input_id = lax.dynamic_slice(input_ids, (0, i), (input_ids.shape[0], 1))
-            output = self.module.apply(
-                {
-                    "params": inputs["params"],
-                    "cache": cache
-                },
-                jnp.array(input_id, dtype="i4"),
-                jnp.array(attention_mask, dtype="i4"),
-                jnp.array(position_ids, dtype="i4"),
-                not train,
-                False,
-                output_attentions,
-                output_hidden_states,
-                return_dict,
-                rngs=rngs,
-                mutable=mutable,
-            )
-            lm_output, new_vars = output
-            logits = lm_output.logits
-            cache = new_vars["cache"]
-            return logits, unfreeze(cache)
-
+            # Progressive cache loop
             seq_length = input_ids.shape[1]
-
+            vcab_size = self.module.config.vocab_size
+            logits = jnp.zeros((1, 1, vcab_size), dtype=self.dtype)
             cache = inputs["cache"]
             initial_state = (logits, cache)
-
+            input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))
+            last, all_logits = lax.scan(self.scan_body_fn, initial_state, input_tokens)
+            last_logits, last_cache = last
+            # lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
 
         if not return_dict:
             outputs = (last_logits,) + (last_cache,)
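Two shape details drive the new scan call. lax.scan requires the carry to keep a fixed shape and dtype on every step, so a zero (1, 1, vocab_size) logits placeholder seeds initial_state; and scan iterates over the leading axis of its inputs, so input_ids is reshaped from (1, seq_length) to (seq_length, 1, 1) to feed one (1, 1) token per step. The per-step logits come back stacked along a new leading axis, which the commented-out reshape would flatten back to a standard (1, seq_length, vocab_size) tensor. A toy sketch of just this shape plumbing (the step function is a stand-in, not this model):

import jax.numpy as jnp
from jax import lax

vocab_size, seq_length = 8, 4  # arbitrary toy sizes
input_ids = jnp.arange(seq_length, dtype=jnp.int32).reshape(1, seq_length)
input_tokens = jnp.reshape(input_ids, (seq_length, 1, 1))  # scan over axis 0

def step(carry, token):
    logits, cache = carry
    # stand-in for one real decoding step: one-hot of the current token
    logits = jnp.zeros((1, 1, vocab_size)).at[0, 0, token[0, 0]].set(1.0)
    return (logits, cache), logits

init = (jnp.zeros((1, 1, vocab_size)), jnp.zeros(()))  # fixed-shape carry
(last_logits, last_cache), all_logits = lax.scan(step, init, input_tokens)

assert all_logits.shape == (seq_length, 1, 1, vocab_size)
lm_logits = jnp.reshape(all_logits, (1, seq_length, vocab_size))  # full-sequence logits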
@@ -454,7 +460,6 @@ class FlaxTransformerLMModule(nn.Module):
             for i in range(config.num_layers)]
         self.ln_f = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')
 
-    @nn.compact
     def __call__(
         self,
         input_ids,
@@ -543,7 +548,6 @@ class FlaxTransformerLMForCausalLMModule(nn.Module):
             name='logitdense',
         )
 
-    @nn.compact
     def __call__(
         self,
         input_ids,
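The last two hunks drop @nn.compact from modules whose submodules are declared in setup() (note the self.ln_f = nn.LayerNorm(...) assignment in the surrounding context). In Flax linen a module defines its submodules either in setup() with a plain __call__, or inline inside a single @nn.compact method; mixing the two styles is at best redundant and can raise errors. A minimal sketch of the two styles side by side (class names are illustrative):

import flax.linen as nn

class SetupStyle(nn.Module):
    features: int

    def setup(self):
        # submodules declared once, ahead of time
        self.ln_f = nn.LayerNorm()
        self.dense = nn.Dense(self.features)

    def __call__(self, x):  # no @nn.compact
        return self.dense(self.ln_f(x))

class CompactStyle(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):  # submodules declared inline
        x = nn.LayerNorm()(x)
        return nn.Dense(self.features)(x)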