import tensorflow as tf
import mesh_tensorflow as mtf
from functools import partial


def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None,
                    n_iter=50):
    # Custom gradient for entmax. Only the forward output y and the incoming
    # gradient dY are needed (x is unused): the Jacobian-vector product has the
    # closed form dX = gppr * dY - q * gppr, where gppr = y^(2 - alpha) on the
    # support of y (and 0 elsewhere) and q is chosen so that dX sums to zero
    # over `dim`, i.e. the gradient stays tangent to the probability simplex.
    x, = explicit_inputs
    y, = outputs
    dY, = output_grads

    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
    dX = dY * gppr

    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
    dX = dX - q * gppr

    return dX,


def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    # Forward pass of alpha-entmax: find the threshold tau such that
    # p = relu((alpha - 1) * x - tau) ** (1 / (alpha - 1)) sums to 1 over `dim`,
    # locating tau by bisection (n_iter halvings of the bracketing interval).
    assert 1 < alpha < 2, 'alpha must be between 1 and 2 (exclusive)'

    _gp = lambda x, alpha: x ** (alpha - 1)
    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    max_val = mtf.reduce_max(x, reduced_dim=dim)

    # tau_lo and tau_hi bracket the true threshold: f(tau_lo) >= 0 >= f(tau_hi),
    # where f(tau) = sum(p(x - tau)) - 1.
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo

    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

        # If f_m has the same sign as f_lo, the root lies above tau_m,
        # so move the lower end of the interval up to tau_m.
        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    # Renormalize to absorb the residual error left by the finite bisection.
    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m


def entmax(x, alpha=1.3, dim=None, n_iter=50):
    # Wrap the forward pass in mtf.custom_gradient so backprop uses the
    # closed-form Jacobian-vector product above instead of differentiating
    # through the bisection loop.
    kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)

    return mtf.custom_gradient(
        partial(entmax_forward, **kwargs),
        partial(entmax_backward, **kwargs),
        [x]
    )
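

# The sketch below is illustrative only (not part of the original module): it
# shows how entmax might be applied over a vocab dimension in a small
# mesh-tensorflow graph. The graph/mesh/dimension names and the single-device
# PlacementMeshImpl lowering are assumptions made for demonstration.
def _example_entmax_usage():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "entmax_example_mesh")
    batch_dim = mtf.Dimension("batch", 2)
    vocab_dim = mtf.Dimension("vocab", 8)

    logits = mtf.random_uniform(mesh, mtf.Shape([batch_dim, vocab_dim]),
                                minval=-3.0, maxval=3.0)
    # Sparse probabilities that sum to 1 over vocab_dim.
    probs = entmax(logits, alpha=1.3, dim=vocab_dim)

    # Lower to plain TensorFlow on a single device for inspection.
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    return lowering.export_to_tf_tensor(probs)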


def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
    # Cross entropy against the entmax distribution instead of softmax.
    # NOTE: z_loss is accepted but currently unused in this implementation.
    if targets.dtype.is_integer:
        # hard targets
        if (set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim])):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss
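

# Another illustrative sketch (assumed names, not from the original file):
# hard integer targets of shape [batch, sequence] scored against logits of
# shape [batch, sequence, vocab]; the returned loss keeps the batch and
# sequence dimensions. Lowering would follow the same pattern as the sketch
# above and is omitted here.
def _example_entmax_loss():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "entmax_loss_example_mesh")
    batch_dim = mtf.Dimension("batch", 2)
    sequence_dim = mtf.Dimension("sequence", 4)
    vocab_dim = mtf.Dimension("vocab", 8)

    logits = mtf.random_uniform(mesh, mtf.Shape([batch_dim, sequence_dim, vocab_dim]),
                                minval=-3.0, maxval=3.0)
    # Integer class ids in [0, vocab); fixed example ids are used for simplicity.
    targets = mtf.import_tf_tensor(
        mesh, tf.constant([[1, 2, 3, 4], [0, 1, 2, 3]], dtype=tf.int32),
        shape=mtf.Shape([batch_dim, sequence_dim]))

    loss = entmax_cross_entropy_with_logits(logits, targets, vocab_dim)
    return mtf.reduce_mean(loss)  # scalar mean loss over batch and sequence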


def sample_categorical(x, dim=None):
    # Sample from a categorical distribution given probabilities x (normalized
    # over `dim`) by inverse-CDF sampling: compare a uniform draw against the
    # running cumulative sum and take the index where it is first exceeded.
    dim = x.shape[-1] if dim is None else dim

    cdf = mtf.cumsum(x, dim)
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    return mtf.argmax(mask, dim)
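

# Illustrative sketch (assumed names): draw one token id per batch element from
# the sparse distribution produced by entmax. Lowering/evaluation is again
# omitted and left to the caller.
def _example_sampling():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "sampling_example_mesh")
    batch_dim = mtf.Dimension("batch", 2)
    vocab_dim = mtf.Dimension("vocab", 8)

    logits = mtf.random_uniform(mesh, mtf.Shape([batch_dim, vocab_dim]),
                                minval=-3.0, maxval=3.0)
    probs = entmax(logits, dim=vocab_dim)
    return sample_categorical(probs, dim=vocab_dim)  # int ids with shape [batch]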


def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    # The old mask_attn_weights applied the mask directly to QK^T;
    # this instead returns a bias that the attention code from mtf adds to the attention matrix.
    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    # nd (dst) and ns (src) have the same size, i.e. the sequence length.
    # ns is the renamed memory_length dimension, so the bias has shape
    # [batch, heads, memory_length, sequence] and lines up with QK^T.
    # Information flows from k and v (memory_length) to q (sequence).
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    # Future positions (j > i) get a large negative bias; the rest get 0.
    return mtf.cast(mtf.less(i, j), dtype) * -1e10
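

# Illustrative sketch (assumed names): for a 3-token sequence the returned bias
# is 0 where destination position i may attend to source position j (j <= i)
# and -1e10 on future positions:
#   [[0, -1e10, -1e10],
#    [0,     0, -1e10],
#    [0,     0,     0]]
def _example_causal_bias():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "bias_example_mesh")
    sequence_dim = mtf.Dimension("sequence", 3)
    memory_dim = mtf.Dimension("memory_length", 3)
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
    return biasmask_attn_weights(mesh, sequence_dim, memory_dim, variable_dtype)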


def parse_inputs(mtf_features, other_features):
    # Parse the input tensor and the model dimensions from the mtf_features / other_features input dicts
    # All dimensions are defined inside model_fn for efficiency
    x = mtf_features["inputs"]

    batch_dim = x.shape[0]
    sequence_dim = x.shape[1]
    embd_dim = other_features["embd_dim"]
    vocab_dim = other_features["vocab_dim"]
    embed_sequence_dim = other_features["embed_sequence_dim"]

    return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim
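

# Illustrative sketch of the dict structure parse_inputs expects. The keys
# match what parse_inputs reads above; the dimension names, sizes, and the
# zero-filled inputs tensor are stand-ins for what model_fn builds from the
# input pipeline and the model config, so treat them as assumptions.
def _example_parse_inputs():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "parse_inputs_example_mesh")
    batch_dim = mtf.Dimension("batch", 2)
    sequence_dim = mtf.Dimension("sequence", 4)

    mtf_features = {
        "inputs": mtf.import_tf_tensor(
            mesh, tf.zeros([2, 4], dtype=tf.int32),
            shape=mtf.Shape([batch_dim, sequence_dim]))
    }
    other_features = {
        "embd_dim": mtf.Dimension("embd", 16),
        "vocab_dim": mtf.Dimension("vocab", 8),
        "embed_sequence_dim": mtf.Dimension("embed_sequence", 4),
    }
    return parse_inputs(mtf_features, other_features)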