{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "6bYaCABobL5q" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T03:38:18.997731Z", "iopub.status.busy": "2022-12-14T03:38:18.997516Z", "iopub.status.idle": "2022-12-14T03:38:19.001439Z", "shell.execute_reply": "2022-12-14T03:38:19.000889Z" }, "id": "FlUw7tSKbtg4" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "yAMJsAn7NDbc" }, "source": [ "# Validating correctness & numerical equivalence" ] }, { "cell_type": "markdown", "metadata": { "id": "vyddl2kckpdN" }, "source": [ "When migrating your TensorFlow code from TF1.x to TF2, it is good practice to ensure that your migrated code behaves the same way in TF2 as it did in TF1.x.\n", "\n", "This guide covers migration code examples with the `tf.compat.v1.keras.utils.track_tf1_style_variables` modeling shim applied to `tf.keras.layers.Layer` methods. Read the [model mapping guide](./model_mapping.ipynb) to find out more about the TF2 modeling shims.\n", "\n", "This guide details approaches you can use to:\n", "* Validate the correctness of the results obtained from training models using the migrated code\n", "* Validate the numerical equivalence of your code across TensorFlow versions" ] }, { "cell_type": "markdown", "metadata": { "id": "TaYgaekzOAHf" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:19.005343Z", "iopub.status.busy": "2022-12-14T03:38:19.004839Z", "iopub.status.idle": "2022-12-14T03:38:21.539848Z", "shell.execute_reply": "2022-12-14T03:38:21.538733Z" }, "id": "FkHX044DzVsd" }, "outputs": [], "source": [ "!pip uninstall -y -q tensorflow" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:21.544536Z", "iopub.status.busy": "2022-12-14T03:38:21.543790Z", "iopub.status.idle": "2022-12-14T03:38:44.840458Z", "shell.execute_reply": "2022-12-14T03:38:44.839299Z" }, "id": "M1ZgieHtyzKI" }, "outputs": [], "source": [ "# Install tf-nightly, as the DeterministicRandomTestTool is only available in\n", "# TensorFlow 2.8 and later\n", "!pip install -q tf-nightly" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:44.845262Z", "iopub.status.busy": "2022-12-14T03:38:44.844580Z", "iopub.status.idle": "2022-12-14T03:38:46.873821Z", "shell.execute_reply": "2022-12-14T03:38:46.872724Z" }, "id": "ohYETq4NCX4J" }, "outputs": [], "source": [ "!pip install -q tf_slim" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:46.878200Z", "iopub.status.busy": "2022-12-14T03:38:46.877941Z", "iopub.status.idle": "2022-12-14T03:38:49.308586Z", "shell.execute_reply": "2022-12-14T03:38:49.307853Z" }, "id": "MFey2HxcktP6" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 03:38:47.140140: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as v1\n", "\n", "import numpy as np\n", "import tf_slim as slim\n", "import sys\n", "\n", "\n", "from contextlib import contextmanager" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:49.312559Z", "iopub.status.busy": "2022-12-14T03:38:49.312144Z", "iopub.status.idle": "2022-12-14T03:38:53.362680Z", "shell.execute_reply": "2022-12-14T03:38:53.361644Z" }, "id": "OriidSSAmRtW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'models'...\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "remote: Enumerating objects: 3590, done.\u001b[K\r\n", "remote: Counting objects: 100% (3590/3590), done.\u001b[K\r\n", "remote: Compressing objects: 100% (3005/3005), done.\u001b[K\r\n", "remote: Total 3590 (delta 943), reused 1501 (delta 531), pack-reused 0\u001b[K\r\n", "Receiving objects: 100% (3590/3590), 47.08 MiB | 35.58 MiB/s, done.\r\n", "Resolving deltas: 100% (943/943), done.\r\n" ] } ], "source": [ "!git clone --depth=1 https://github.com/tensorflow/models.git\n", "import models.research.slim.nets.inception_resnet_v2 as inception" ] }, { "cell_type": "markdown", "metadata": { "id": "TRacYNxnN-nk" }, "source": [ "If you're putting a nontrivial chunk of forward pass code into the shim, you want to know that it is behaving the same way as it did in TF1.x. 
For example, consider putting an entire TF-Slim Inception-ResNet-v2 model into the shim as follows:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:53.366962Z", "iopub.status.busy": "2022-12-14T03:38:53.366681Z", "iopub.status.idle": "2022-12-14T03:38:53.370866Z", "shell.execute_reply": "2022-12-14T03:38:53.370295Z" }, "id": "IijQZtxeaErg" }, "outputs": [], "source": [ "# TF1 Inception-ResNet-v2 forward pass based on slim layers\n", "def inception_resnet_v2(inputs, num_classes, is_training):\n", "  with slim.arg_scope(\n", "      inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):\n", "    return inception.inception_resnet_v2(inputs, num_classes, is_training=is_training)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:53.374046Z", "iopub.status.busy": "2022-12-14T03:38:53.373505Z", "iopub.status.idle": "2022-12-14T03:38:53.697683Z", "shell.execute_reply": "2022-12-14T03:38:53.697035Z" }, "id": "Z_-Oxg9OlSd4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_118303/2131234657.py:8: The name tf.keras.utils.track_tf1_style_variables is deprecated. 
Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.\n", "\n" ] } ], "source": [ "class InceptionResnetV2(tf.keras.layers.Layer):\n", "  \"\"\"Slim InceptionResnetV2 forward pass as a Keras layer\"\"\"\n", "\n", "  def __init__(self, num_classes, **kwargs):\n", "    super().__init__(**kwargs)\n", "    self.num_classes = num_classes\n", "\n", "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n", "  def call(self, inputs, training=None):\n", "    # Slim does not accept `None` as a value for `is_training`. Keras still\n", "    # passes `None` to layers when constructing functional models, so that the\n", "    # layer is not forced to always run in training or in inference mode.\n", "    # Since `None` is generally treated as inference, map it to `False`.\n", "    is_training = training or False\n", "\n", "    with slim.arg_scope(\n", "        inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):\n", "      return inception.inception_resnet_v2(\n", "          inputs, self.num_classes, is_training=is_training)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "EqFmpktjlvh9" }, "source": [ "As it happens, this layer works out of the box, complete with accurate regularization loss tracking.\n", "\n", "However, this is not something you should take for granted. Follow the steps below to verify that it actually behaves as it did in TF1.x, down to perfect numerical equivalence. These steps can also help you triangulate the source of a divergence from TF1.x, for example whether it arises in the model forward pass or in a different part of the model." ] }, { "cell_type": "markdown", "metadata": { "id": "mmgubd9vkevp" }, "source": [ "## Step 1: Verify variables are only created once\n", "\n", "The first thing to verify is that you have built the model so that it reuses its variables on every call rather than accidentally creating and using new variables each time. 
For example, if your model creates a new Keras layer or calls `tf.Variable` on every forward pass, it is most likely failing to capture its variables and is creating new ones each time.\n", "\n", "Below are two context manager scopes you can use to detect when your model is creating new variables and to debug which part of the model is doing it." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:53.701383Z", "iopub.status.busy": "2022-12-14T03:38:53.700896Z", "iopub.status.idle": "2022-12-14T03:38:53.707635Z", "shell.execute_reply": "2022-12-14T03:38:53.707033Z" }, "id": "VMTfTXC0zW97" }, "outputs": [], "source": [ "@contextmanager\n", "def assert_no_variable_creations():\n", "  \"\"\"Assert no variables are created in this context manager scope.\"\"\"\n", "  def invalid_variable_creator(next_creator, **kwargs):\n", "    raise ValueError(\"Attempted to create a new variable instead of reusing an existing one. Args: {}\".format(kwargs))\n", "\n", "  with tf.variable_creator_scope(invalid_variable_creator):\n", "    yield\n", "\n", "@contextmanager\n", "def catch_and_raise_created_variables():\n", "  \"\"\"Raise an error listing all variables created within this context manager scope (if any).\"\"\"\n", "  created_vars = []\n", "  def variable_catcher(next_creator, **kwargs):\n", "    var = next_creator(**kwargs)\n", "    created_vars.append(var)\n", "    return var\n", "\n", "  with tf.variable_creator_scope(variable_catcher):\n", "    yield\n", "  if created_vars:\n", "    raise ValueError(\"Created vars:\", created_vars)" ] }, { "cell_type": "markdown", "metadata": { "id": "WOKUtciktQqv" }, "source": [ "The first scope (`assert_no_variable_creations()`) raises an error as soon as you try to create a variable within the scope. 
This allows you to inspect the stack trace (and use interactive debugging) to figure out exactly which lines of code created a variable instead of reusing an existing one.\n", "\n", "The second scope (`catch_and_raise_created_variables()`) will raise an exception at the end of the scope if any variables ended up being created. This exception will include the list of all variables created in the scope. This is useful for seeing the full set of weights your model is creating, in case you can spot general patterns. However, it is less useful for identifying the exact lines of code where those variables got created.\n", "\n", "Use both scopes below to verify that the shim-based InceptionResnetV2 layer does not create any new variables after the first call (presumably because it reuses them)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:38:53.711113Z", "iopub.status.busy": "2022-12-14T03:38:53.710699Z", "iopub.status.idle": "2022-12-14T03:39:03.248779Z", "shell.execute_reply": "2022-12-14T03:39:03.248055Z" }, "id": "O9FAGotiuLbK" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n", " warnings.warn('`layer.apply` is deprecated and '\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:332: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.\n", " warnings.warn('`tf.layers.flatten` is deprecated and '\n" ] } ], "source": [ "model = InceptionResnetV2(1000)\n", "height, width = 299, 299\n", "num_classes = 1000\n", "\n", "inputs = tf.ones((1, height, width, 3))\n", "# Create all weights on the first call\n", "model(inputs)\n", "\n", "# Verify that no new weights are created in follow-up calls\n", "with assert_no_variable_creations():\n", "    model(inputs)\n", "with catch_and_raise_created_variables():\n", "    model(inputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "9ylT-EIhu1lK" }, "source": [ "In the example below, observe how these context managers behave on a layer that incorrectly creates new weights each time instead of reusing existing ones." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:03.253141Z", "iopub.status.busy": "2022-12-14T03:39:03.252469Z", "iopub.status.idle": "2022-12-14T03:39:03.256619Z", "shell.execute_reply": "2022-12-14T03:39:03.256047Z" }, "id": "gXqhPQWWtMAw" }, "outputs": [], "source": [ "class BrokenScalingLayer(tf.keras.layers.Layer):\n", "    \"\"\"Scaling layer that incorrectly creates new weights each time.\"\"\"\n", "\n", "    @tf.compat.v1.keras.utils.track_tf1_style_variables\n", "    def call(self, inputs):\n", "        var = tf.Variable(initial_value=2.0)\n", "        bias = tf.Variable(initial_value=2.0, name='bias')\n", "        return inputs * var + bias" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:03.260192Z", "iopub.status.busy": "2022-12-14T03:39:03.259627Z", "iopub.status.idle": "2022-12-14T03:39:03.277902Z", "shell.execute_reply": "2022-12-14T03:39:03.277304Z" }, "id": "ztUKlMdGvHSq" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_118303/1128777590.py\", line 7, in \n", " model(inputs)\n", " File 
\"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/utils/traceback_utils.py\", line 70, in error_handler\n", " raise e.with_traceback(filtered_tb) from None\n", " File \"/tmpfs/tmp/ipykernel_118303/3224979076.py\", line 6, in call\n", " var = tf.Variable(initial_value=2.0)\n", " File \"/tmpfs/tmp/ipykernel_118303/1829430118.py\", line 5, in invalid_variable_creator\n", " raise ValueError(\"Attempted to create a new variable instead of reusing an existing one. Args: {}\".format(kwargs))\n", "ValueError: Exception encountered when calling layer 'broken_scaling_layer' (type BrokenScalingLayer).\n", "\n", "Attempted to create a new variable instead of reusing an existing one. Args: {'initial_value': 2.0, 'trainable': None, 'validate_shape': True, 'caching_device': None, 'name': None, 'variable_def': None, 'dtype': None, 'import_scope': None, 'constraint': None, 'synchronization': , 'aggregation': , 'shape': None, 'experimental_enable_variable_lifting': None}\n", "\n", "Call arguments received by layer 'broken_scaling_layer' (type BrokenScalingLayer):\n", " • inputs=tf.Tensor(shape=(1, 299, 299, 3), dtype=float32)\n" ] } ], "source": [ "model = BrokenScalingLayer()\n", "inputs = tf.ones( (1, height, width, 3))\n", "model(inputs)\n", "\n", "try:\n", " with assert_no_variable_creations():\n", " model(inputs)\n", "except ValueError as err:\n", " import traceback\n", " traceback.print_exc()\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:03.281111Z", "iopub.status.busy": "2022-12-14T03:39:03.280539Z", "iopub.status.idle": "2022-12-14T03:39:03.289140Z", "shell.execute_reply": "2022-12-14T03:39:03.288582Z" }, "id": "6VyfMJ50vZqZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('Created vars:', [, ])\n" ] } ], "source": [ "model = BrokenScalingLayer()\n", "inputs = tf.ones( (1, height, width, 3))\n", "model(inputs)\n", "\n", "try:\n", " with 
catch_and_raise_created_variables():\n", "        model(inputs)\n", "except ValueError as err:\n", "    print(err)" ] }, { "cell_type": "markdown", "metadata": { "id": "JDaiTArcv49M" }, "source": [ "You can fix the layer by making sure it creates the weights only once and then reuses them on every call." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:03.292412Z", "iopub.status.busy": "2022-12-14T03:39:03.291983Z", "iopub.status.idle": "2022-12-14T03:39:03.301456Z", "shell.execute_reply": "2022-12-14T03:39:03.300908Z" }, "id": "FN1Oa10iviv8" }, "outputs": [], "source": [ "class FixedScalingLayer(tf.keras.layers.Layer):\n", "    \"\"\"Scaling layer that creates its weights once and reuses them on later calls.\"\"\"\n", "    def __init__(self, **kwargs):\n", "        super().__init__(**kwargs)\n", "        self.var = None\n", "        self.bias = None\n", "\n", "    @tf.compat.v1.keras.utils.track_tf1_style_variables\n", "    def call(self, inputs):\n", "        if self.var is None:\n", "            self.var = tf.Variable(initial_value=2.0)\n", "            self.bias = tf.Variable(initial_value=2.0, name='bias')\n", "        return inputs * self.var + self.bias\n", "\n", "model = FixedScalingLayer()\n", "inputs = tf.ones((1, height, width, 3))\n", "model(inputs)\n", "\n", "with assert_no_variable_creations():\n", "    model(inputs)\n", "with catch_and_raise_created_variables():\n", "    model(inputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "MuiZZ7ktwCcn" }, "source": [ "### Troubleshooting\n", "\n", "Here are some common reasons why your model might accidentally be creating new weights instead of reusing existing ones:\n", "\n", "1. It uses an explicit `tf.Variable` call without reusing already-created `tf.Variable`s. Fix this by creating the variable only if it does not exist yet, and reusing the existing one on later calls.\n", "2. It creates a Keras layer or model directly in the forward pass each time (as opposed to using `tf.compat.v1.layers`). Fix this by creating the layer or model only if it does not exist yet, and reusing the existing one on later calls.\n", "3. It is built on top of `tf.compat.v1.layers` but does not give every `compat.v1.layers` layer an explicit name or wrap its `compat.v1.layers` usage inside a named `variable_scope`, causing the autogenerated layer names to increment in each model call. Fix this by putting a named `tf.compat.v1.variable_scope` inside your shim-decorated method that wraps all of your `tf.compat.v1.layers` usage." ] }, { "cell_type": "markdown", "metadata": { "id": "V4iZLV9BnwKM" }, "source": [ "## Step 2: Check that variable counts, names, and shapes match\n", "\n", "The second step is to make sure your layer running in TF2 creates the same number of weights, with the same names and shapes, as the corresponding code does in TF1.x.\n", "\n", "You can check them manually, programmatically in a unit test as shown below, or with a mix of both." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:03.305013Z", "iopub.status.busy": "2022-12-14T03:39:03.304461Z", "iopub.status.idle": "2022-12-14T03:39:21.004861Z", "shell.execute_reply": "2022-12-14T03:39:21.004090Z" }, "id": "m_aqag5fpun5" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. 
Please use `layer.__call__` method instead.\n", " warnings.warn('`layer.apply` is deprecated and '\n" ] } ], "source": [ "# Build the forward pass inside a TF1.x graph, and \n", "# get the counts, shapes, and names of the variables\n", "graph = tf.Graph()\n", "with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " height, width = 299, 299\n", " num_classes = 1000\n", " inputs = tf.ones( (1, height, width, 3))\n", "\n", " out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)\n", "\n", " tf1_variable_names_and_shapes = {\n", " var.name: (var.trainable, var.shape) for var in tf.compat.v1.global_variables()}\n", " num_tf1_variables = len(tf.compat.v1.global_variables())" ] }, { "cell_type": "markdown", "metadata": { "id": "WT1-cm99vfNU" }, "source": [ "Next, do the same for the shim-wrapped layer in TF2.\n", "Notice that the model is also called multiple times before grabbing the weights. This is done to effectively test for variable reuse." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:21.008893Z", "iopub.status.busy": "2022-12-14T03:39:21.008637Z", "iopub.status.idle": "2022-12-14T03:39:24.226380Z", "shell.execute_reply": "2022-12-14T03:39:24.225669Z" }, "id": "S7ND-lBSqmnE" }, "outputs": [], "source": [ "height, width = 299, 299\n", "num_classes = 1000\n", "\n", "model = InceptionResnetV2(num_classes)\n", "# The weights will not be created until you call the model\n", "\n", "inputs = tf.ones( (1, height, width, 3))\n", "# Call the model multiple times before checking the weights, to verify variables\n", "# get reused rather than accidentally creating additional variables\n", "out, endpoints = model(inputs, training=False)\n", "out, endpoints = model(inputs, training=False)\n", "\n", "# Grab the name: shape mapping and the total number of variables separately,\n", "# because in TF2 variables can be created with the same name\n", "num_tf2_variables = 
len(model.variables)\n", "tf2_variable_names_and_shapes = {\n", "    var.name: (var.trainable, var.shape) for var in model.variables}" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:24.230091Z", "iopub.status.busy": "2022-12-14T03:39:24.229816Z", "iopub.status.idle": "2022-12-14T03:39:24.233961Z", "shell.execute_reply": "2022-12-14T03:39:24.233379Z" }, "id": "pY2P_4wqsOYw" }, "outputs": [], "source": [ "# Verify that the variable counts, names, and shapes all match:\n", "assert num_tf1_variables == num_tf2_variables\n", "assert tf1_variable_names_and_shapes == tf2_variable_names_and_shapes" ] }, { "cell_type": "markdown", "metadata": { "id": "N4YKJzSVwWkc" }, "source": [ "The shim-based InceptionResnetV2 layer passes this test. If the two mappings don't match, you can diff them (as text or otherwise) to see where they differ.\n", "\n", "This can provide a clue as to which part of the model isn't behaving as expected. With eager execution, you can use pdb, interactive debugging, and breakpoints to dig into the parts of the model that seem suspicious and debug what is going wrong in more depth." ] }, { "cell_type": "markdown", "metadata": { "id": "2gYrt-_0xpRM" }, "source": [ "### Troubleshooting\n", "\n", "* Pay close attention to the names of any variables created directly by explicit `tf.Variable` calls and by Keras layers/models, as their name generation semantics may differ slightly between TF1.x graphs and TF2 features such as eager execution and `tf.function`, even if everything else is working properly. If this is the case for you, adjust your test to account for the different naming semantics.\n", "\n", "* You may sometimes find that the `tf.Variable`s, `tf.keras.layers.Layer`s, or `tf.keras.Model`s created in your training loop's forward pass are missing from your TF2 variables list even if they were captured by the variables collection in TF1.x. Fix this by assigning the variables/layers/models that your forward pass creates to instance attributes in your model. See the [custom layers and models guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models) for more info." ] }, { "cell_type": "markdown", "metadata": { "id": "fOQJ_hUGnzkq" }, "source": [ "## Step 3: Reset all variables, check numerical equivalence with all randomness disabled\n", "\n", "The next step is to verify numerical equivalence of both the actual outputs and the regularization loss tracking when you set up the model so that no random number generation is involved (such as during inference).\n", "\n", "The exact way to do this may depend on your specific model, but in most models (such as this one), you can do this by:\n", "1. Initializing the weights to the same value with no randomness. This can be done by resetting them to a fixed value after they have been created.\n", "2. Running the model in inference mode to avoid triggering any dropout layers, which can be sources of randomness.\n", "\n", "The following code demonstrates how you can compare the TF1.x and TF2 results this way." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:24.237784Z", "iopub.status.busy": "2022-12-14T03:39:24.237280Z", "iopub.status.idle": "2022-12-14T03:39:49.563479Z", "shell.execute_reply": "2022-12-14T03:39:49.562787Z" }, "id": "kL4PzD2Cxzmp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: 0.001182976\n" ] }, { "data": { "text/plain": [ "array([0.00299837, 0.00299837, 0.00299837, 0.00299837, 0.00299837],\n", " dtype=float32)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "graph = tf.Graph()\n", "with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " height, width = 299, 299\n", " num_classes = 1000\n", " inputs = tf.ones( (1, height, width, 3))\n", "\n", " out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)\n", "\n", " # Rather than running the global variable initializers,\n", " # reset all variables to a constant value\n", " var_reset = tf.group([var.assign(tf.ones_like(var) * 0.001) for var in tf.compat.v1.global_variables()])\n", " sess.run(var_reset)\n", "\n", " # Grab the outputs & regularization loss\n", " reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)\n", " tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))\n", " tf1_output = sess.run(out)\n", "\n", "print(\"Regularization loss:\", tf1_regularization_loss)\n", "tf1_output[0][:5]" ] }, { "cell_type": "markdown", "metadata": { "id": "IKkoM_x72rUa" }, "source": [ "Get the TF2 results." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:49.567413Z", "iopub.status.busy": "2022-12-14T03:39:49.567128Z", "iopub.status.idle": "2022-12-14T03:39:53.933510Z", "shell.execute_reply": "2022-12-14T03:39:53.932839Z" }, "id": "kb086gJwzsNo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: tf.Tensor(0.0011829757, shape=(), dtype=float32)\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "height, width = 299, 299\n", "num_classes = 1000\n", "\n", "model = InceptionResnetV2(num_classes)\n", "\n", "inputs = tf.ones((1, height, width, 3))\n", "# Call the model once to create the weights\n", "out, endpoints = model(inputs, training=False)\n", "\n", "# Reset all variables to the same fixed value as above, with no randomness\n", "for var in model.variables:\n", " var.assign(tf.ones_like(var) * 0.001)\n", "tf2_output, endpoints = model(inputs, training=False)\n", "\n", "# Get the regularization loss\n", "tf2_regularization_loss = tf.math.add_n(model.losses)\n", "\n", "print(\"Regularization loss:\", tf2_regularization_loss)\n", "tf2_output[0][:5]" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:53.936827Z", "iopub.status.busy": "2022-12-14T03:39:53.936548Z", "iopub.status.idle": "2022-12-14T03:39:53.939905Z", "shell.execute_reply": "2022-12-14T03:39:53.939251Z" }, "id": "CUfWqlgIK6ej" }, "outputs": [], "source": [ "# Create a dict of tolerance values\n", "tol_dict={'rtol':1e-06, 'atol':1e-05}" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:53.943322Z", "iopub.status.busy": "2022-12-14T03:39:53.942814Z", "iopub.status.idle": "2022-12-14T03:39:53.947107Z", "shell.execute_reply": "2022-12-14T03:39:53.946575Z" }, "id": "R-C07eTo0WTr" }, 
"outputs": [], "source": [ "# Verify that the regularization loss and output both match\n", "# when we fix the weights and avoid randomness by running inference:\n", "np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)\n", "np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "5UUq_Fuc2zDO" }, "source": [ "The numbers match between TF1.x and TF2 when you remove sources of randomness, and the TF2-compatible `InceptionResnetV2` layer passes the test.\n", "\n", "If you are observing the results diverging for your own models, you can use printing or pdb and interactive debugging to identify where and why the results start to diverge. Eager execution can make this significantly easier. You can also use an ablation approach to run only small portions of the model on fixed intermediate inputs and isolate where the divergence happens.\n", "\n", "Conveniently, many slim nets (and other models) also expose intermediate endpoints that you can probe." ] }, { "cell_type": "markdown", "metadata": { "id": "btRbak-0ou15" }, "source": [ "## Step 4: Align random number generation, check numerical equivalence in both training and inference\n", "\n", "The final step is to verify that the TF2 model numerically matches the TF1.x model, even when accounting for random number generation in variable initialization and in the forward pass itself (such as dropout layers during the forward pass).\n", "\n", "You can do this by using the testing tool below to make random number generation semantics match between TF1.x graphs/sessions and eager execution." 
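] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see why a dedicated testing tool is needed, it can help to observe eager stateful random number generation directly. The following cell is a minimal sketch (the seed values `1234` and `1` are arbitrary choices for illustration): in eager execution, calling `tf.random.uniform` twice with the same global and operation seeds returns different values, because the random op keeps an internal counter that advances on every call. Calling `tf.random.set_seed` again resets that counter and reproduces the first result. Note that this cell changes the global seed for the rest of the notebook session." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "\n", "# Fix the global seed (this changes the global seed for the session)\n", "tf.random.set_seed(1234)\n", "# Two eager calls with the same global and operation seeds\n", "first = tf.random.uniform(shape=(3,), seed=1)\n", "second = tf.random.uniform(shape=(3,), seed=1)\n", "\n", "# Re-setting the global seed resets the op's internal counter\n", "tf.random.set_seed(1234)\n", "first_again = tf.random.uniform(shape=(3,), seed=1)\n", "\n", "# The second call should differ; the re-seeded call should reproduce the first\n", "print(\"Second call matches first:\", np.allclose(first.numpy(), second.numpy()))\n", "print(\"Re-seeded call matches first:\", np.allclose(first.numpy(), first_again.numpy()))"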
] }, { "cell_type": "markdown", "metadata": { "id": "jYq-JHiC39QC" }, "source": [ "TF1 legacy graphs/sessions and TF2 eager execution use different stateful random number generation semantics.\n", "\n", "In `tf.compat.v1.Session`s, if no seeds are specified, the random number generation depends on how many operations are in the graph at the time when the random operation is added, and how many times the graph is run. In eager execution, stateful random number generation depends on the global seed, the operation random seed, and how many times the operation with the given random seed is run. See `tf.random.set_seed` for more info." ] }, { "cell_type": "markdown", "metadata": { "id": "BQbb8Hyk5YVi" }, "source": [ "The [`v1.keras.utils.DeterministicRandomTestTool`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/keras/utils/DeterministicRandomTestTool) class provides a context manager `scope()` that can make stateful random operations use the same seed across both TF1 graphs/sessions and eager execution.\n", "\n", "The tool provides two testing modes:\n", "1. `constant`, which uses the same seed for every single operation no matter how many times it has been called.\n", "2. `num_random_ops`, which uses the number of previously-observed stateful random operations as the operation seed.\n", "\n", "This applies both to the stateful random operations used for creating and initializing variables, and to the stateful random operations used in computation (such as for dropout layers)." ] }, { "cell_type": "markdown", "metadata": { "id": "MoyZenhGHDA-" }, "source": [ "Generate three random tensors to show how to use this tool to make stateful random number generation match between sessions and eager execution." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:53.950786Z", "iopub.status.busy": "2022-12-14T03:39:53.950284Z", "iopub.status.idle": "2022-12-14T03:39:53.989064Z", "shell.execute_reply": "2022-12-14T03:39:53.988502Z" }, "id": "DDFfjrbXEWED" }, "outputs": [ { "data": { "text/plain": [ "(array([[2.5063772],\n", " [2.7488918],\n", " [1.4839486]], dtype=float32),\n", " array([[2.5063772, 2.7488918, 1.4839486],\n", " [1.5633398, 2.1358476, 1.3693532],\n", " [0.3598416, 1.8287641, 2.5314465]], dtype=float32),\n", " array([[2.5063772, 2.7488918, 1.4839486],\n", " [1.5633398, 2.1358476, 1.3693532],\n", " [0.3598416, 1.8287641, 2.5314465]], dtype=float32))" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool()\n", "with random_tool.scope():\n", " graph = tf.Graph()\n", " with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " b = tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", " c = tf.random.uniform(shape=(3,3))\n", " c = c * 3\n", " graph_a, graph_b, graph_c = sess.run([a, b, c])\n", "\n", "graph_a, graph_b, graph_c" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:53.992308Z", "iopub.status.busy": "2022-12-14T03:39:53.991806Z", "iopub.status.idle": "2022-12-14T03:39:54.011113Z", "shell.execute_reply": "2022-12-14T03:39:54.010470Z" }, "id": "o9bkdPuTFpYr" }, "outputs": [ { "data": { "text/plain": [ "(,\n", " ,\n", " )" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool()\n", "with random_tool.scope():\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " b = tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", " c = tf.random.uniform(shape=(3,3))\n", 
"    c = c * 3\n", "\n", "a, b, c" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.014331Z", "iopub.status.busy": "2022-12-14T03:39:54.013797Z", "iopub.status.idle": "2022-12-14T03:39:54.017855Z", "shell.execute_reply": "2022-12-14T03:39:54.017276Z" }, "id": "qRJYFydsGIbF" }, "outputs": [], "source": [ "# Demonstrate that the generated random numbers match\n", "np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)\n", "np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict)\n", "np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "J8IWCnS-WFrB" }, "source": [ "Notice, however, that in `constant` mode `b` and `c` have exactly the same values, because they were generated with the same seed and have the same shape." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.021063Z", "iopub.status.busy": "2022-12-14T03:39:54.020543Z", "iopub.status.idle": "2022-12-14T03:39:54.024066Z", "shell.execute_reply": "2022-12-14T03:39:54.023479Z" }, "id": "IdxV89q2WPid" }, "outputs": [], "source": [ "np.testing.assert_allclose(b.numpy(), c.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "vQTm7joHHh57" }, "source": [ "### Trace order\n", "\n", "If you are concerned that random numbers matching in `constant` mode reduce your confidence in your numerical equivalence test (for example, if several weights take on the same initializations), you can use the `num_random_ops` mode to avoid this. In `num_random_ops` mode, the generated random numbers depend on the ordering of random ops in the program." 
] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.027409Z", "iopub.status.busy": "2022-12-14T03:39:54.026970Z", "iopub.status.idle": "2022-12-14T03:39:54.054591Z", "shell.execute_reply": "2022-12-14T03:39:54.054025Z" }, "id": "L-AeD148VygJ" }, "outputs": [ { "data": { "text/plain": [ "(array([[2.5063772],\n", " [2.7488918],\n", " [1.4839486]], dtype=float32),\n", " array([[0.45038545, 1.9197761 , 2.4536333 ],\n", " [1.0371652 , 2.9898582 , 1.924583 ],\n", " [0.25679827, 1.6579313 , 2.8418403 ]], dtype=float32),\n", " array([[2.9634383 , 1.0862181 , 2.6042497 ],\n", " [0.70099247, 2.3920312 , 1.0470468 ],\n", " [0.18173039, 0.8359269 , 1.0508587 ]], dtype=float32))" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " graph = tf.Graph()\n", " with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " b = tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", " c = tf.random.uniform(shape=(3,3))\n", " c = c * 3\n", " graph_a, graph_b, graph_c = sess.run([a, b, c])\n", "\n", "graph_a, graph_b, graph_c" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.057715Z", "iopub.status.busy": "2022-12-14T03:39:54.057212Z", "iopub.status.idle": "2022-12-14T03:39:54.070419Z", "shell.execute_reply": "2022-12-14T03:39:54.069817Z" }, "id": "CedD41NuVygK" }, "outputs": [ { "data": { "text/plain": [ "(,\n", " ,\n", " )" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " b = 
tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", " c = tf.random.uniform(shape=(3,3))\n", " c = c * 3\n", "\n", "a, b, c" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.073744Z", "iopub.status.busy": "2022-12-14T03:39:54.073122Z", "iopub.status.idle": "2022-12-14T03:39:54.077214Z", "shell.execute_reply": "2022-12-14T03:39:54.076668Z" }, "id": "5We2xSnLVygL" }, "outputs": [], "source": [ "# Demonstrate that the generated random numbers match\n", "np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)\n", "np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict )\n", "np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.080256Z", "iopub.status.busy": "2022-12-14T03:39:54.079793Z", "iopub.status.idle": "2022-12-14T03:39:54.082990Z", "shell.execute_reply": "2022-12-14T03:39:54.082356Z" }, "id": "BBFG1xehWneM" }, "outputs": [], "source": [ "# Demonstrate that with the 'num_random_ops' mode,\n", "# b & c took on different values even though\n", "# their generated shape was the same\n", "assert not np.allclose(b.numpy(), c.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "OfX_VexcVqSA" }, "source": [ "However, notice that in this mode random generation is sensitive to program order, and so the following generated random numbers do not match." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.086042Z", "iopub.status.busy": "2022-12-14T03:39:54.085613Z", "iopub.status.idle": "2022-12-14T03:39:54.102152Z", "shell.execute_reply": "2022-12-14T03:39:54.101559Z" }, "id": "cZt__ElEIDl_" }, "outputs": [], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " b = tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", "\n", "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " b_prime = tf.random.uniform(shape=(3,3))\n", " b_prime = b_prime * 3\n", " a_prime = tf.random.uniform(shape=(3,1))\n", " a_prime = a_prime * 3\n", "\n", "assert not np.allclose(a.numpy(), a_prime.numpy())\n", "assert not np.allclose(b.numpy(), b_prime.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "nHhOLHyQIkAe" }, "source": [ "To allow for debugging variations due to tracing order, `DeterministicRandomTestTool` in `num_random_ops` mode allows you to see how many random operations have been traced with the `operation_seed` property." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.105583Z", "iopub.status.busy": "2022-12-14T03:39:54.105101Z", "iopub.status.idle": "2022-12-14T03:39:54.114616Z", "shell.execute_reply": "2022-12-14T03:39:54.114047Z" }, "id": "33RCSICuJEyV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1\n", "2\n" ] } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " print(random_tool.operation_seed)\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " print(random_tool.operation_seed)\n", " b = tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", " print(random_tool.operation_seed)" ] }, { "cell_type": "markdown", "metadata": { "id": "bkQD3NpOMxIv" }, "source": [ "If you need to account for varying trace order in your tests, you can even set the auto-incrementing `operation_seed` explicitly. For example, you can use this to make random number generation match across two different program orders." 
] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.117828Z", "iopub.status.busy": "2022-12-14T03:39:54.117272Z", "iopub.status.idle": "2022-12-14T03:39:54.134599Z", "shell.execute_reply": "2022-12-14T03:39:54.134050Z" }, "id": "6W4sS_wOM8CH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1\n" ] } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " print(random_tool.operation_seed)\n", " a = tf.random.uniform(shape=(3,1))\n", " a = a * 3\n", " print(random_tool.operation_seed)\n", " b = tf.random.uniform(shape=(3,3))\n", " b = b * 3\n", "\n", "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " random_tool.operation_seed = 1\n", " b_prime = tf.random.uniform(shape=(3,3))\n", " b_prime = b_prime * 3\n", " random_tool.operation_seed = 0\n", " a_prime = tf.random.uniform(shape=(3,1))\n", " a_prime = a_prime * 3\n", "\n", "np.testing.assert_allclose(a.numpy(), a_prime.numpy(), **tol_dict)\n", "np.testing.assert_allclose(b.numpy(), b_prime.numpy(), **tol_dict)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "bP5Kx1OcNbvM" }, "source": [ "However, `DeterministicRandomTestTool` disallows reusing already-used operation seeds, so make sure the auto-incremented sequences cannot overlap. This is because eager execution generates different numbers for follow-on usages of the same operation seed while TF1 graphs and sessions do not, so raising an error helps keep session and eager stateful random number generation in line." 
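] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can check the reasoning above directly with a minimal sketch (the seed values `42` and `7` are arbitrary choices for illustration). In a TF1.x graph, two separate random ops created with identical graph-level and operation-level seeds each keep their own generator state, so they produce the same values when the session runs them. In eager execution, repeated calls with the same seeds reuse a single kernel whose internal counter advances, so the values differ. Note that this cell also sets the global seed with `tf.random.set_seed`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "\n", "# TF1.x graph: two distinct ops with the same graph-level and op-level seeds\n", "graph = tf.Graph()\n", "with graph.as_default():\n", "    tf.compat.v1.set_random_seed(42)\n", "    x = tf.random.uniform(shape=(3,), seed=7)\n", "    y = tf.random.uniform(shape=(3,), seed=7)\n", "    with tf.compat.v1.Session(graph=graph) as sess:\n", "        graph_x, graph_y = sess.run([x, y])\n", "\n", "# Eager execution: two calls with the same global and op-level seeds\n", "tf.random.set_seed(42)\n", "eager_x = tf.random.uniform(shape=(3,), seed=7)\n", "eager_y = tf.random.uniform(shape=(3,), seed=7)\n", "\n", "print(\"Graph ops with the same seeds match:\", np.allclose(graph_x, graph_y))\n", "print(\"Eager calls with the same seeds match:\",\n", "      np.allclose(eager_x.numpy(), eager_y.numpy()))"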
] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.138009Z", "iopub.status.busy": "2022-12-14T03:39:54.137470Z", "iopub.status.idle": "2022-12-14T03:39:54.147707Z", "shell.execute_reply": "2022-12-14T03:39:54.147126Z" }, "id": "GmBgg5hzNa5H" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "This `DeterministicRandomTestTool` object is trying to re-use the already-used operation seed 1. It cannot guarantee random numbers will match between eager and sessions when an operation seed is reused. You most likely set `operation_seed` explicitly but used a value that caused the naturally-incrementing operation seed sequences to overlap with an already-used seed.\n" ] } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " random_tool.operation_seed = 1\n", " b_prime = tf.random.uniform(shape=(3,3))\n", " b_prime = b_prime * 3\n", " random_tool.operation_seed = 0\n", " a_prime = tf.random.uniform(shape=(3,1))\n", " a_prime = a_prime * 3\n", " try:\n", " c = tf.random.uniform(shape=(3,1))\n", " raise RuntimeError(\"An exception should have been raised before this, \" +\n", " \"because the auto-incremented operation seed will \" +\n", " \"overlap an already-used value\")\n", " except ValueError as err:\n", " print(err)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "U-bLOeCmOn-4" }, "source": [ "### Verifying Inference\n", "\n", "You can now use the `DeterministicRandomTestTool` to make sure the `InceptionResnetV2` model matches in inference, even with random weight initialization. Because the program order of the stateful random ops matches, you can use the stricter `num_random_ops` mode."
] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:39:54.151016Z", "iopub.status.busy": "2022-12-14T03:39:54.150495Z", "iopub.status.idle": "2022-12-14T03:40:15.741319Z", "shell.execute_reply": "2022-12-14T03:40:15.740627Z" }, "id": "8TWOrflkPa7T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: 1.2254326\n" ] } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " graph = tf.Graph()\n", " with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " height, width = 299, 299\n", " num_classes = 1000\n", " inputs = tf.ones( (1, height, width, 3))\n", "\n", " out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)\n", "\n", " # Initialize the variables\n", " sess.run(tf.compat.v1.global_variables_initializer())\n", "\n", " # Grab the outputs & regularization loss\n", " reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)\n", " tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))\n", " tf1_output = sess.run(out)\n", "\n", " print(\"Regularization loss:\", tf1_regularization_loss)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:40:15.744904Z", "iopub.status.busy": "2022-12-14T03:40:15.744601Z", "iopub.status.idle": "2022-12-14T03:40:18.420342Z", "shell.execute_reply": "2022-12-14T03:40:18.419659Z" }, "id": "Qcx6ur4KPMI1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: tf.Tensor(1.2254325, shape=(), dtype=float32)\n" ] } ], "source": [ "height, width = 299, 299\n", "num_classes = 1000\n", "\n", "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " model = InceptionResnetV2(num_classes)\n", "\n", " inputs = tf.ones((1, height, width, 
3))\n", " tf2_output, endpoints = model(inputs, training=False)\n", "\n", " # Grab the regularization loss as well\n", " tf2_regularization_loss = tf.math.add_n(model.losses)\n", "\n", "print(\"Regularization loss:\", tf2_regularization_loss)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:40:18.423973Z", "iopub.status.busy": "2022-12-14T03:40:18.423477Z", "iopub.status.idle": "2022-12-14T03:40:18.427976Z", "shell.execute_reply": "2022-12-14T03:40:18.427410Z" }, "id": "m_SS2b6qPFl1" }, "outputs": [], "source": [ "# Verify that the regularization loss and output both match\n", "# when using the DeterministicRandomTestTool:\n", "np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)\n", "np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "TKSktIRaP-5b" }, "source": [ "### Verifying Training\n", "\n", "Because `DeterministicRandomTestTool` works for *all* stateful random operations (including both weight initialization and computation such as dropout layers), you can use it to verify the models match in training mode as well. You can again use the `num_random_ops` mode because the program order of the stateful random ops matches." 
] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:40:18.431282Z", "iopub.status.busy": "2022-12-14T03:40:18.430734Z", "iopub.status.idle": "2022-12-14T03:40:42.196309Z", "shell.execute_reply": "2022-12-14T03:40:42.195578Z" }, "id": "nMBFVa1kQTJH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/normalization/batch_normalization.py:581: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Colocations handled automatically by placer.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: 1.22548\n" ] } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " graph = tf.Graph()\n", " with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " height, width = 299, 299\n", " num_classes = 1000\n", " inputs = tf.ones( (1, height, width, 3))\n", "\n", " out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)\n", "\n", " # Initialize the variables\n", " sess.run(tf.compat.v1.global_variables_initializer())\n", "\n", " # Grab the outputs & regularization loss\n", " reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)\n", " tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))\n", " tf1_output = sess.run(out)\n", "\n", " print(\"Regularization loss:\", tf1_regularization_loss)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:40:42.199907Z", "iopub.status.busy": "2022-12-14T03:40:42.199657Z", "iopub.status.idle": "2022-12-14T03:40:45.015643Z", "shell.execute_reply": "2022-12-14T03:40:45.014921Z" }, "id": "-jlBkwI5QTJI" }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "Regularization loss: tf.Tensor(1.2254798, shape=(), dtype=float32)\n" ] } ], "source": [ "height, width = 299, 299\n", "num_classes = 1000\n", "\n", "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", "with random_tool.scope():\n", " model = InceptionResnetV2(num_classes)\n", "\n", " inputs = tf.ones((1, height, width, 3))\n", " tf2_output, endpoints = model(inputs, training=True)\n", "\n", " # Grab the regularization loss as well\n", " tf2_regularization_loss = tf.math.add_n(model.losses)\n", "\n", "print(\"Regularization loss:\", tf2_regularization_loss)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:40:45.019242Z", "iopub.status.busy": "2022-12-14T03:40:45.018962Z", "iopub.status.idle": "2022-12-14T03:40:45.023395Z", "shell.execute_reply": "2022-12-14T03:40:45.022782Z" }, "id": "IL9mjTLnQTJJ" }, "outputs": [], "source": [ "# Verify that the regularization loss and output both match\n", "# when using the DeterministicRandomTestTool\n", "np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)\n", "np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "uJTZvmfnQqZH" }, "source": [ "You have now verified that the `InceptionResnetV2` model running eagerly with decorators around `tf.keras.layers.Layer` numerically matches the slim network running in TF1 graphs and sessions." ] }, { "cell_type": "markdown", "metadata": { "id": "xpOAei5vRAPa" }, "source": [ "Note: When using the `DeterministicRandomTestTool` in `num_random_ops` mode, it is best to call the decorated `tf.keras.layers.Layer` directly when testing for numerical equivalence. Embedding it within a Keras functional model or other Keras models can change the tracing order of stateful random operations in ways that are tricky to reason about or to match exactly when comparing TF1.x graphs/sessions with eager execution.\n", "\n", "For example, calling the `InceptionResnetV2` layer directly with `training=True` interleaves variable initialization with dropout, following the order in which the network is created.\n", "\n", "On the other hand, first wrapping the decorated `tf.keras.layers.Layer` in a Keras functional model and only then calling the model with `training=True` initializes all variables first and applies dropout afterwards. This produces a different tracing order and therefore a different set of random numbers.\n", "\n", "However, the default `mode='constant'` is not sensitive to these differences in tracing order and will pass without extra work even when embedding the layer in a Keras functional model." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:40:45.026795Z", "iopub.status.busy": "2022-12-14T03:40:45.026567Z", "iopub.status.idle": "2022-12-14T03:41:08.547454Z", "shell.execute_reply": "2022-12-14T03:41:08.546677Z" }, "id": "0dSR4ZNvYNYm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: 1.2239965\n" ] } ], "source": [ "random_tool = v1.keras.utils.DeterministicRandomTestTool()\n", "with random_tool.scope():\n", " graph = tf.Graph()\n", " with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n", " height, width = 299, 299\n", " num_classes = 1000\n", " inputs = tf.ones( (1, height, width, 3))\n", "\n", " out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)\n", "\n", " # Initialize the variables\n", " sess.run(tf.compat.v1.global_variables_initializer())\n", "\n", " # Get the outputs & regularization losses\n", " reg_losses = 
tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)\n", " tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))\n", " tf1_output = sess.run(out)\n", "\n", " print(\"Regularization loss:\", tf1_regularization_loss)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:41:08.551835Z", "iopub.status.busy": "2022-12-14T03:41:08.551188Z", "iopub.status.idle": "2022-12-14T03:41:16.167387Z", "shell.execute_reply": "2022-12-14T03:41:16.166695Z" }, "id": "iMPMMnPtYUY7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n", " warnings.warn('`layer.updates` will be removed in a future version. '\n", "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. 
This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n", " self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Regularization loss: tf.Tensor(1.2239964, shape=(), dtype=float32)\n" ] } ], "source": [ "height, width = 299, 299\n", "num_classes = 1000\n", "\n", "random_tool = v1.keras.utils.DeterministicRandomTestTool()\n", "with random_tool.scope():\n", " keras_input = tf.keras.Input(shape=(height, width, 3))\n", " layer = InceptionResnetV2(num_classes)\n", " model = tf.keras.Model(inputs=keras_input, outputs=layer(keras_input))\n", "\n", " inputs = tf.ones((1, height, width, 3))\n", " tf2_output, endpoints = model(inputs, training=True)\n", "\n", " # Get the regularization loss\n", " tf2_regularization_loss = tf.math.add_n(model.losses)\n", "\n", "print(\"Regularization loss:\", tf2_regularization_loss)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T03:41:16.171428Z", "iopub.status.busy": "2022-12-14T03:41:16.170699Z", "iopub.status.idle": "2022-12-14T03:41:16.175329Z", "shell.execute_reply": "2022-12-14T03:41:16.174712Z" }, "id": "jf46KUVyYUY8" }, "outputs": [], "source": [ "# Verify that the regularization loss and output both match\n", "# when using the DeterministicRandomTestTool\n", "np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)\n", "np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "hWXHjtkiZ09V" }, "source": [ "## Step 3b or 4b (optional): Testing with pre-existing checkpoints\n", "\n", "After step 3 or step 4 above, it can be useful to run your numerical equivalence tests when starting from pre-existing name-based checkpoints, if you have any. This can test both that your legacy checkpoint loading works correctly and that the model itself behaves as expected. 
The [Reusing TF1.x checkpoints guide](./reuse_checkpoints.ipynb) covers how to reuse your pre-existing TF1.x checkpoints and transfer them over to TF2 checkpoints.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "v6i3MFmGcxYx" }, "source": [ "## Additional Testing & Troubleshooting\n", "\n", "As you add more numerical equivalence tests, you may also choose to add a test that verifies that your gradient computations (or even your optimizer updates) match.\n", "\n", "Backpropagation and gradient computation are more prone to floating point numerical instabilities than model forward passes. This means that as your equivalence tests cover more non-isolated parts of your training, you may begin to see non-trivial numerical differences between running fully eagerly and running your TF1 graphs. These differences may come from TensorFlow's graph optimizations, which can, for example, replace subexpressions in a graph with equivalent ones that use fewer mathematical operations.\n", "\n", "To check whether this is the cause, you can compare your TF1 code to a TF2 computation running inside a `tf.function` (which applies graph optimization passes just as your TF1 graph does) rather than to a purely eager computation. Alternatively, you can use `tf.config.optimizer.set_experimental_options` to disable optimization passes such as `\"arithmetic_optimization\"` before your TF1 computation, and then check whether the result ends up numerically closer to your TF2 results. In your actual training runs, keep `tf.function` with optimization passes enabled for performance, but you may find it useful to disable them in your numerical equivalence unit tests.\n", "\n", "Similarly, you may find that `tf.compat.v1.train` optimizers have slightly different floating point numerics properties than TF2 optimizers, even if the mathematical formulas they represent are the same. 
This is less likely to be an issue in your training runs, but it may require a higher numerical tolerance in equivalence unit tests." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "validate_correctness.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }