import tensorflow as tf
import mesh_tensorflow as mtf

from functools import partial


def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs,
                    output_grads, alpha=1.3, dim=None, n_iter=50):
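    """Backward pass for entmax, in the form expected by mtf.custom_gradient.

    Uses the closed-form gradient of alpha-entmax (Peters et al., 2019,
    "Sparse Sequence-to-Sequence Models"): with g = y ** (2 - alpha) on the
    support of y and 0 elsewhere,

        dX = g * dY - (sum(g * dY) / sum(g)) * g

    i.e. the incoming gradient is rescaled and then shifted so that it sums
    to zero over `dim`.
    """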
    x, = explicit_inputs
    y, = outputs
    dY, = output_grads

    # g = y ** (2 - alpha) on the support of y, zero elsewhere
    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
    dX = dY * gppr

    # Subtract the g-weighted mean so the gradient sums to zero over dim
    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
    dX = dX - q * gppr

    return dX,


def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
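    """Forward pass of alpha-entmax over dimension `dim`, via bisection.

    Solves for the threshold tau such that

        p = relu((alpha - 1) * x - tau) ** (1 / (alpha - 1))

    sums to one over `dim`, using `n_iter` halvings of the bracketing
    interval [tau_lo, tau_hi]. alpha interpolates between softmax
    (alpha -> 1) and sparsemax (alpha = 2); this implementation requires
    1 < alpha < 2.
    """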
    assert 1 < alpha < 2, 'alpha must be between 1 and 2'

    # Power transform g, its inverse, and p(t) = g^{-1}(relu(t))
    _gp = lambda x, alpha: x ** (alpha - 1)
    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    # Bracket tau: at tau_lo the largest coordinate maps to 1 so sum(p) >= 1,
    # at tau_hi every coordinate is <= 1/d so sum(p) <= 1; the root of
    # f(tau) = sum(p) - 1 therefore lies in [tau_lo, tau_hi]
    max_val = mtf.reduce_max(x, reduced_dim=dim)
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo

    # Bisection: halve the interval n_iter times, moving tau_lo up to the
    # midpoint whenever f at the midpoint has the same sign as f_lo
    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    # Renormalize to absorb any residual error from the finite bisection
    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m


def entmax(x, alpha=1.3, dim=None, n_iter=50):
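    """alpha-entmax over `dim`, usable as a drop-in replacement for softmax.

    Wires entmax_forward and entmax_backward together via mtf.custom_gradient,
    so backpropagation uses the exact entmax Jacobian instead of
    differentiating through the bisection loop.
    """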
    kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)

    return mtf.custom_gradient(
        partial(entmax_forward, **kwargs),
        partial(entmax_backward, **kwargs),
        [x]
    )


def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
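    """Cross-entropy loss against an entmax output distribution.

    Mirrors mtf.layers.softmax_cross_entropy_with_logits, accepting either
    integer (hard) or distribution-valued (soft) targets; z_loss is accepted
    for signature compatibility but is currently unused. Note that entmax can
    assign exactly zero probability, so mtf.log yields -inf for tokens
    outside the support.
    """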
    if targets.dtype.is_integer:
        # Hard targets: check shapes, then convert class ids to one-hot vectors
        if set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim]):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    # Negative log-likelihood of the targets under the entmax distribution
    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss


def sample_categorical(x, dim=None):
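    """Sample an index from a categorical distribution over `dim`.

    Inverse-CDF sampling: draw u ~ Uniform(0, 1) and return the first index
    whose cumulative probability exceeds u. Assumes x is a normalized
    distribution over `dim` (e.g. a softmax or entmax output).
    """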
    dim = x.shape[-1] if dim is None else dim

    # mask is 1 wherever the running CDF exceeds the uniform draw; argmax
    # then picks out the first such position, i.e. the sampled bucket
    cdf = mtf.cumsum(x, dim)
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    return mtf.argmax(mask, dim)


def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
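    """Causal attention bias of shape [nd, ns].

    Returns 0 where destination position i may attend to source position j
    and -1e10 where it may not, to be added to the attention logits before
    normalization. Position i may attend to j when j <= i + ns - nd, i.e.
    queries are aligned to the end of the key sequence. nd and ns are mtf
    Dimensions, not ints.
    """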
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    # 1 where j > i (future positions), scaled to a large negative bias
    return mtf.cast(mtf.less(i, j), dtype) * -1e10


def parse_inputs(mtf_features, other_features):
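    """Unpack the input tensor and the Dimensions needed to build the model.

    Expects mtf_features["inputs"] to be a [batch, sequence] Tensor and
    other_features to carry the embedding, vocab and embedded-sequence
    Dimensions.
    """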
    x = mtf_features["inputs"]

    batch_dim = x.shape[0]
    sequence_dim = x.shape[1]
    embd_dim = other_features["embd_dim"]
    vocab_dim = other_features["vocab_dim"]
    embed_sequence_dim = other_features["embed_sequence_dim"]

    return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim