import tensorflow as tf |
import mesh_tensorflow as mtf |
from functools import partial |
def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs,
                    output_grads, alpha=1.3, dim=None, n_iter=50):
    """Custom gradient of alpha-entmax for use with `mtf.custom_gradient`.

    Args:
      explicit_inputs, all_inputs, forward_operations: supplied by
        `mtf.custom_gradient`; not needed because the gradient depends only on
        the forward output.
      outputs: single-element sequence holding the entmax output `y`.
      output_grads: single-element sequence holding dL/dy.
      alpha: entmax alpha used in the forward pass.
      dim: dimension entmax was normalized over. None means the last
        dimension, matching the default in `entmax_forward`.
      n_iter: unused here; accepted so forward/backward share kwargs.

    Returns:
      One-element tuple holding dL/dx.
    """
    y, = outputs
    dY, = output_grads
    # Resolve the default exactly like entmax_forward does. Passing None to
    # reduce_sum would reduce over *every* dimension and give a wrong
    # gradient. The output has the same shape as the input, so
    # y.shape[-1] is the forward's dim.
    dim = y.shape[-1] if dim is None else dim
    # g''(y)^-1 term: y^(2 - alpha) on the support of y, 0 elsewhere.
    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)),
                     mtf.zeros_like(y))
    dX = dY * gppr
    # Per-slice correction that keeps the gradient inside the simplex
    # tangent space. q has `dim` reduced away and broadcasts back.
    q = (mtf.reduce_sum(dX, reduced_dim=dim)
         / mtf.reduce_sum(gppr, reduced_dim=dim))
    dX = dX - q * gppr
    return dX,
def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    """Compute alpha-entmax of `x` along `dim` by bisection on the threshold.

    It searches for a per-slice threshold tau such that
    sum(relu(x * (alpha - 1) - tau) ** (1 / (alpha - 1))) == 1 along `dim`.
    The search runs `n_iter` halvings of the bracket
    [max - 1, max - (1 / dim.size) ** (alpha - 1)].
    The final probabilities are renormalized to sum to one along `dim`.

    Args:
      x: input mtf.Tensor.
      alpha: entmax alpha; must lie strictly between 1 and 2.
      dim: dimension to normalize over; None means the last dimension of x.
      n_iter: number of bisection steps (must be >= 1, otherwise the result
        is never assigned).

    Returns:
      mtf.Tensor shaped like `x`. Entries where relu clips are exactly 0.
    """
    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'
    if dim is None:
        dim = x.shape[-1]
    inv_exponent = 1 / (alpha - 1)

    def unnormalized(shifted):
        # relu then power: anything at or below the threshold maps to 0.
        return mtf.pow(mtf.relu(shifted), inv_exponent)

    scaled = x * (alpha - 1)
    row_max = mtf.reduce_max(scaled, reduced_dim=dim)
    lower = row_max - 1 ** (alpha - 1)
    upper = row_max - (1 / dim.size) ** (alpha - 1)
    lower_residual = mtf.reduce_sum(unnormalized(scaled - lower),
                                    reduced_dim=dim) - 1
    step = upper - lower
    for _ in range(n_iter):
        step = step / 2
        midpoint = lower + step
        probs = unnormalized(scaled - midpoint)
        residual = mtf.reduce_sum(probs, reduced_dim=dim) - 1
        # Move the lower end up while the midpoint's residual has the same
        # sign as the residual at the original lower end.
        same_sign = mtf.greater_equal((residual * lower_residual), 0)
        lower = mtf.where(same_sign, midpoint, lower)
    return probs / mtf.reduce_sum(probs, reduced_dim=dim)
def entmax(x, alpha=1.3, dim=None, n_iter=50):
    """Alpha-entmax of `x` along `dim`, wired with its custom gradient.

    Both the forward (`entmax_forward`) and backward (`entmax_backward`)
    functions receive the same `alpha`, `dim` and `n_iter`.
    """
    forward_fn = partial(entmax_forward, alpha=alpha, dim=dim, n_iter=n_iter)
    backward_fn = partial(entmax_backward, alpha=alpha, dim=dim,
                          n_iter=n_iter)
    return mtf.custom_gradient(forward_fn, backward_fn, [x])
def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0,
                                     epsilon=1e-7):
    """Cross-entropy of `targets` against entmax(logits) over `vocab_dim`.

    Args:
      logits: mtf.Tensor of unnormalized scores; must contain `vocab_dim`.
      targets: either an integer mtf.Tensor of class ids whose dims are the
        dims of `logits` other than `vocab_dim` (hard targets, one-hot
        encoded here), or a tensor with exactly the dims of `logits`
        (soft targets).
      vocab_dim: the dimension entmax is normalized over.
      z_loss: accepted but not used by this implementation.
      epsilon: lower bound applied to the entmax probabilities before the
        log. entmax produces exact zeros, and log(0) = -inf would make
        `0 * -inf` = NaN in the loss and its gradient. Probabilities >=
        epsilon are unaffected.

    Returns:
      mtf.Tensor with the dims of `logits` minus `vocab_dim`.

    Raises:
      ValueError: if `vocab_dim` is not in `logits`, or `targets` has
        incompatible dims.
    """
    # Check this first: otherwise the hard-target check below fails with a
    # less specific message.
    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")
    if targets.dtype.is_integer:
        if (set(targets.shape.dims)
                != set(logits.shape.dims).difference([vocab_dim])):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))
    probs = entmax(logits, dim=vocab_dim)
    # Clamp away exact zeros so log never returns -inf (see `epsilon`).
    log_probs = mtf.log(mtf.maximum(probs, epsilon))
    loss = mtf.negative(
        mtf.reduce_sum(log_probs * targets, reduced_dim=vocab_dim))
    return loss
def sample_categorical(x, dim=None):
    """Draw one index along `dim` per slice of `x` by inverse-CDF sampling.

    NOTE(review): assumes `x` holds probabilities (non-negative, summing to 1
    along `dim`); this is not checked. The intended result is the first index
    whose cumulative sum exceeds a uniform draw. That relies on argmax
    returning the first of several tied maxima, which TF does not formally
    guarantee. If no cumulative value exceeds the draw, the mask is all zeros.

    Args:
      x: mtf.Tensor of (assumed) probabilities.
      dim: dimension to sample over; None means the last dimension of x.

    Returns:
      Integer mtf.Tensor with the dims of `x` minus `dim`.
    """
    if dim is None:
        dim = x.shape[-1]
    cumulative = mtf.cumsum(x, dim)
    threshold = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    passed = mtf.cast(mtf.greater(cumulative, threshold), tf.int32)
    return mtf.argmax(passed, dim)
def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    """Build an additive attention bias of shape [nd, ns].

    Query position i is shifted by (ns.size - nd.size), so the nd queries
    align with the last nd of the ns key positions. Entries where the key
    index j exceeds the shifted query index get -1e10; all others get 0.
    That is, each query cannot attend to keys ahead of it.

    Args:
      mesh: mtf mesh the tensors are created on.
      nd: query dimension (an mtf.Dimension; `.size` is read).
      ns: key dimension (an mtf.Dimension; `.size` is read).
      variable_dtype: object whose `activation_dtype` is the output dtype.

    Returns:
      mtf.Tensor of shape [nd, ns] in `variable_dtype.activation_dtype`.
    """
    shift = ns.size - nd.size
    query_pos = mtf.broadcast(mtf.range(mesh, nd, tf.int32) + shift, [nd, ns])
    key_pos = mtf.broadcast(mtf.range(mesh, ns, tf.int32), [nd, ns])
    is_future = mtf.less(query_pos, key_pos)
    return mtf.cast(is_future, variable_dtype.activation_dtype) * -1e10
def parse_inputs(mtf_features, other_features):
    """Unpack the input tensor and the model dimensions from feature dicts.

    Args:
      mtf_features: mapping with key "inputs". Its `.shape[0]` and
        `.shape[1]` are returned as the batch and sequence dims.
      other_features: mapping with keys "embd_dim", "vocab_dim" and
        "embed_sequence_dim", returned as-is.

    Returns:
      Tuple (x, batch_dim, sequence_dim, embd_dim, vocab_dim,
      embed_sequence_dim).

    Raises:
      KeyError: if any of the required keys is missing.
    """
    inputs = mtf_features["inputs"]
    batch_dim, sequence_dim = inputs.shape[0], inputs.shape[1]
    return (
        inputs,
        batch_dim,
        sequence_dim,
        other_features["embd_dim"],
        other_features["vocab_dim"],
        other_features["embed_sequence_dim"],
    )