File size: 9,324 Bytes
c6e7238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import mesh_tensorflow.transformer as mtf_transformer
from models.utils import entmax, sample_categorical
from models.gpt2 import gpt2
def sample_autoregressive(partial_sequences,
                          other_features,
                          params,
                          stop_at_token=50256,
                          max_steps=None,
                          temperature=0.9,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          sampling_use_entmax=False,
                          bos_id=50256,
                          ):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are non-padding, representing the given
    partial sequences, and the last tokens of each sequence are padding
    (``params["padding_id"]``, 0 by default), representing what needs to be
    filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False. In that case the states recorded by the
    initial pass over the inputs are replaced with zeros before decoding.

    Args:
        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
        other_features: a dictionary forwarded to ``gpt2.model``; it must
            contain ``"vocab_dim"``, the dimension that is sampled over.
        params: a dictionary forwarded to ``gpt2.model``. This function reads
            ``"padding_id"`` (default 0) and ``"slow_sampling"`` (default
            False). With slow sampling no decoding Context is built, no
            states are carried through the loop, and the model is called
            with ``context=None`` at every step.
        stop_at_token: an optional integer eos id. Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to the predicted
            distribution. Not used when sampling_use_entmax is set.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
        never_end: a boolean - if set (and stop_at_token is not None), then
            avoid generating stop_at_token by adding a large negative bias to
            its logit before sampling.
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the top
            k logits. The special value -2 keeps 10% of the size of the last
            logits dimension. Not used when sampling_use_entmax is set.
        sampling_use_entmax: a boolean - if set, draw each token with
            ``sample_categorical(entmax(logits))`` instead of top-k /
            temperature sampling.
        bos_id: beginning of sequence id (currently not used by this
            function).

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]

    Raises:
        ValueError: if sampling_keep_top_k is zero or negative after the -1
            and -2 special values have been handled.
    """
    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Number of non-padding tokens per sequence, i.e. the position where the
    # padding starts and where decoding begins.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # for now hardcode this to true bc lazy
    if input_full_attention:
        # Priority is 0 for every position up to and including
        # initial_position and equal to the position index after it.
        # NOTE(review): presumably this lets the prompt positions attend to
        # each other freely - confirm against how the attention layers use
        # read/write priority.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally.
    # The 'first part' context records initial states of k / v / x.
    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        # The logits of this pass are discarded; it is run for the states it
        # appends to context_first_part.new_states.
        with tf.variable_scope("gpt2"):
            logits, _, _ = gpt2.model(
                {"inputs": inputs}, other_features, params, inputs.mesh,
                variable_dtype=variable_dtype, context=context_first_part)

        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states]
        else:
            initial_states = context_first_part.new_states
    else:
        initial_states = []

    # Stop tokens already present in the given sequences must not count as
    # "finished"; cond_fn compares against this baseline.
    partial_sequences_eos_count = 0
    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(
                eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)

        # Keep looping until every sequence is done.
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model(
                {"inputs": ids}, other_features, params, inputs.mesh,
                variable_dtype=variable_dtype, context=context)

        if never_end and stop_at_token is not None:
            # Bias the stop token's logit far down so it is never sampled.
            logits += mtf.one_hot(
                mtf.constant(inputs.mesh, stop_at_token, dtype=tf.int32),
                other_features["vocab_dim"],
                on_value=-1e9, off_value=0.0, dtype=logits.dtype)

        if not sampling_use_entmax:
            keep_top_k = sampling_keep_top_k
            # -2 is a shorthand for "keep 10% of the last logits dimension".
            if keep_top_k == -2:
                keep_top_k = int(logits.shape[-1].size * 0.1)

            if keep_top_k != -1:
                if keep_top_k <= 0:
                    raise ValueError("sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits, n=keep_top_k,
                    reduced_dim=other_features["vocab_dim"])
                # Everything at or below the threshold gets a very low logit.
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, other_features["vocab_dim"], temperature)
        else:
            ids_this_step = sample_categorical(entmax(logits))

        if slow_sampling:
            # Without a context the sampled ids still carry length_dim;
            # move each one a position to the right.
            ids_this_step = mtf.shift(
                ids_this_step, offset=1, dim=length_dim, wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        # Write the new id at `position` and leave every other slot as is.
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    # Only the ids are needed; the final position and states are dropped.
    outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[1]

    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs by shifting each sequence
        # left by the length of its given prefix.
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs
|