import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_estimator
import mesh_tensorflow.transformer as mtf_transformer
from optimizers import get_optimizer
from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,
get_batch_size, auto_layout, auto_layout_and_mesh_shape)
from models.utils import biasmask_attn_weights
from tensorflow.python.ops import resources
from sample import sample_autoregressive
from models.gpt2 import gpt2
import math
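
# model_fn follows the tf.estimator model_fn contract: it receives
# (features, labels, mode, params) and returns a tpu_estimator.TPUEstimatorSpec
# for whichever of PREDICT, TRAIN, or EVAL it is called with.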
def model_fn(features, labels, mode, params):
# Get global step
global_step = tf.train.get_global_step()
# Construct mtf graph + mesh from params
graph = mtf.Graph()
mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
layout_rules = mtf.convert_to_layout_rules(params["layout"])
# Mesh setup
if params["use_tpu"]:
var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
else:
var_placer = None
gpu_ids = params["gpu_ids"]
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
mesh_shape, layout_rules, gpu_ids)
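    # On TPU, simd_mesh_setup builds a SIMD mesh implementation spanning the TPU cores;
    # on GPU/CPU, PlacementMeshImpl instead pins each mesh slice to a device in gpu_ids.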
# Trainable variable precision
# Store to checkpoints in master type, train in slice type, compute in activation type
if params["precision"] == "bfloat16":
variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
activation_dtype=tf.bfloat16)
else:
variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)
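    # e.g. with precision="bfloat16", checkpoints and activations use bfloat16 while the
    # trainable slices stay in float32 so optimizer updates remain numerically stable.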
# Build mtf mesh object
mesh = mtf.Mesh(graph, "my_mesh", var_placer)
# Build mtf_features & seq length dict for getting number of microbatches
# We need to pack inputs into a dict to pass into serialize_training_step
features_dict = {"inputs": features, "labels": labels}
sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}
params = add_mode_to_params(params, mode)
batch_size = get_batch_size(params)
batch_dim = mtf.Dimension("batch", batch_size)
batch_dims = [batch_dim]
feature_length = sequence_length_dict["inputs"]
length_dim = mtf.Dimension("sequence", feature_length)
mtf_features = {}
for key, x in features_dict.items():
if x is not None:
feature_shape = mtf.Shape(batch_dims + [length_dim])
            if isinstance(features_dict[key], dict):
                features_dict[key] = features_dict[key]["feature"]
x = tf.cast(features_dict[key], tf.int32)
x = tf.reshape(x, feature_shape.to_integer_list)
mtf_features[key] = mtf.import_fully_replicated(
mesh, x, feature_shape, name=key)
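    # Each feature is now an mtf tensor that is identical on every processor; the
    # layout rules decide how it is actually split once ops are lowered.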
# Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
other_features = {}
memory_length_dim = mtf.Dimension("memory_length", length_dim.size)
attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None
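    # When params["causal"] is set, attn_bias masks out future positions (typically by
    # adding a large negative value above the diagonal); otherwise no bias is applied.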
# Add attn_bias into mtf_features
other_features["attn_bias"] = attn_bias
# Define other Dimensions that we'll need inside the model
embd_dim = mtf.Dimension("embd", params["n_embd"])
vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
# We need this because gathering when both the args have the same dimension in them breaks things
# This dim is specifically for the weights
# This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])
other_features["embd_dim"] = embd_dim
other_features["vocab_dim"] = vocab_dim
other_features["embed_sequence_dim"] = embed_sequence_dim
other_features["memory_length_dim"] = memory_length_dim
if mode == tf.estimator.ModeKeys.PREDICT:
# Set up the model for prediction
inputs = mtf_features["inputs"]
if params["remove_partial_sequences"] is None:
params["remove_partial_sequences"] = False
export = params.get("export", False)
if not export:
mtf_samples = sample_autoregressive(
inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
                sampling_use_entmax=params["sampling_use_entmax"], max_steps=params["predict_max_steps"])
else:
with mtf.utils.outside_all_rewrites():
with tf.variable_scope('gpt2'):
mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
variable_dtype=variable_dtype, context=None)
mtf_samples = mtf.anonymize(mtf_samples)
inputs = mtf.anonymize(inputs)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
inputs = lowering.export_to_tf_tensor(inputs)
outputs = lowering.export_to_tf_tensor(mtf_samples)
predictions = {
"inputs": inputs,
"outputs": outputs}
def scaffold_fn():
return tf.train.Scaffold(
local_init_op=tf.group(
tf.train.Scaffold.default_local_init_op(),
lowering.copy_masters_to_slices(),
name="mtf_local_init_op"),
ready_op=tf.concat(
[tf.report_uninitialized_variables(),
resources.report_uninitialized_resources()],
axis=0,
name="mtf_ready_op"))
return tpu_estimator.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
scaffold_fn=scaffold_fn,
prediction_hooks=[mtf.MtfRestoreHook(lowering)])
    # We're not predicting, so we must be training or evaluating
    assert mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL)
if mode == tf.estimator.ModeKeys.TRAIN:
        # Get the number of microbatches per batch for serialized training
        # If params["tokens_per_mb_per_replica"] is None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(
            batch_dim=batch_dim,
            sequence_length=sequence_length_dict,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"]))
else:
num_microbatches = 1
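    # Rough sketch of the arithmetic (an approximation, not mtf's exact implementation):
    # num_microbatches ~= (batch_size * sequence_length) / (tokens_per_mb_per_replica * num_replicas),
    # constrained so that it evenly divides the per-replica batch.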
params["num_microbatches"] = num_microbatches # Add num microbatches to params
if num_microbatches > 1:
# For serialize_training_step we need to modify the model to output results in a dict
def serialized_fn(mtf_features):
if params["model"] == "GPT":
with tf.variable_scope('gpt2'):
logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
variable_dtype=variable_dtype)
return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
else:
                raise ValueError(f"'{params['model']}' is not a valid model - please select from [GPT]")
# Serialize the training step - Gradients are accumulated locally and reduced once.
var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
loss = output_dict["loss"]
loss_batch = output_dict["loss_batch"]
logits = output_dict["logits"]
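        # serialize_training_step has already run serialized_fn once per microbatch and
        # summed the gradients, so var_grads here covers the full batch.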
else:
# If we're not splitting into microbatches, return logits & loss as is
if params["model"] == "GPT":
with mtf.utils.outside_all_rewrites():
with tf.variable_scope('gpt2'):
logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
variable_dtype=variable_dtype, context=None)
else:
            raise ValueError(f"'{params['model']}' is not a valid model - please select from [GPT]")
# Auto layout generation
if params["auto_layout"]:
auto_layout(graph, mesh_shape, logits, loss)
if params["auto_layout_and_mesh_shape"]:
auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)
if mode == tf.estimator.ModeKeys.TRAIN:
# In TRAIN mode, get optimizer
if params["num_microbatches"] > 1:
# If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
# So we pass them in here
_, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
inp_var_grads=var_grads)
else:
# Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
_, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
# Log summaries to tensorboard
mtf.scalar_summary("loss", loss)
        # Log gradient norms if enabled in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                # g.name[:-2] strips the trailing ":0" from the tensor name
                mtf.scalar_summary("grads/norm/" + g.name[:-2], grad_norm)
else:
# For now, we can only export fully-replicated tensors.
# This has to be done before lowering or they will not be included in the graph
mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
max_logits = mtf.argmax(logits, vocab_dim)
del logits
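        # mtf.anonymize re-imports a tensor as fully-replicated (un-split), which is what
        # makes it exportable to a plain tf tensor after lowering.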
fully_replicated_mean_logits = mtf.anonymize(mean_logits)
fully_replicated_max_logits = mtf.anonymize(max_logits)
fully_replicated_loss_batch = mtf.anonymize(loss_batch)
# Gets & prints info about no. trainable vars in the model & dimension names
get_graph_info(graph)
# 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
tf_loss = lowering.export_to_tf_tensor(loss)
tf_loss = tf.cast(tf_loss, tf.float32)
if mode == tf.estimator.ModeKeys.TRAIN:
# Use our patched version until mtf updates theirs
        host_call = create_host_call(params["model_path"])
mtf.utils.remove_summaries()
# Creates train_op
tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
tf_update_ops.append(tf.assign_add(global_step, 1)) # Need to manually increment global_step
tf.logging.info(f"tf_update_ops: {tf_update_ops}")
train_op = tf.group(tf_update_ops)
else:
tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.cast(lowering.export_to_tf_tensor(fully_replicated_loss_batch), tf.float32)
with mtf.utils.outside_all_rewrites():
# Copy master variables to slices. Must be called first.
restore_hook = mtf.MtfRestoreHook(lowering)
if mode == tf.estimator.ModeKeys.TRAIN:
# Set up the checkpoint server and return the TPUEstimatorSpec
saver = tf.train.Saver(
tf.global_variables(),
sharded=True,
max_to_keep=10,
keep_checkpoint_every_n_hours=2,
defer_build=False,
save_relative_paths=True)
tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
saver_listener = mtf.MtfCheckpointSaverListener(lowering)
saver_hook = tf.train.CheckpointSaverHook(
params["model_path"],
save_steps=params["steps_per_checkpoint"],
saver=saver,
listeners=[saver_listener])
return tpu_estimator.TPUEstimatorSpec(
tf.estimator.ModeKeys.TRAIN,
loss=tf_loss,
host_call=host_call,
train_op=train_op,
training_hooks=[restore_hook, saver_hook])
elif mode == tf.estimator.ModeKeys.EVAL:
# Evaluation metrics
def _perplexity(loss):
perplexity = tf.exp(loss)
return tf.metrics.mean(perplexity)
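            # Assumption: 0.29335 is the tokens-per-byte ratio of the training corpus, so
            # bits-per-byte = loss_per_token * tokens_per_byte / ln(2)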
def _bits_per_byte(loss):
bpb = loss * (0.29335 / math.log(2))
return tf.metrics.mean(bpb)
def _metric_fn(tf_mean_logits, tf_loss_batch):
mean_logits = tf.metrics.mean(tf_mean_logits)
loss = tf.reduce_mean(tf_loss_batch)
perp = _perplexity(loss)
bpb = _bits_per_byte(loss)
return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}
def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
eos_token = params["eos_id"]
answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))
            # Note: tf_loss_batch may include z_loss and other auxiliary terms,
            # so this may need to be computed separately in the future
answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
log_perplexity = tf.metrics.mean(answer_loss)
return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}
eval_task = params["eval_task"]
if eval_task == "lambada":
eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
else:
eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])
return tpu_estimator.TPUEstimatorSpec(
tf.estimator.ModeKeys.EVAL,
evaluation_hooks=[restore_hook],
loss=tf_loss,
eval_metrics=eval_metrics)
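
# Minimal usage sketch (hedged: the RunConfig values and input_fn below are
# illustrative assumptions, not taken from this repo's entry point):
#
#   run_config = tf.estimator.tpu.RunConfig(model_dir=params["model_path"])
#   estimator = tpu_estimator.TPUEstimator(
#       model_fn=model_fn,
#       config=run_config,
#       use_tpu=params["use_tpu"],
#       train_batch_size=params["train_batch_size"],
#       eval_batch_size=params["eval_batch_size"],
#       params=params)
#   estimator.train(input_fn=my_input_fn, max_steps=params["train_steps"])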