File size: 12,907 Bytes
f8f5cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from .utils import ExplicitEnum, is_torch_available, logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class DebugUnderflowOverflow:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` weight and activation elements.

    There are 2 working modes:

    1. Underflow/overflow detection (default)
    2. Specific batch absolute min/max tracing without detection

    Mode 1: Underflow/overflow detection

    To activate the underflow/overflow detection, initialize the object with the model :

    ```python
    debug_overflow = DebugUnderflowOverflow(model)
    ```

    then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
    elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
    each frame reporting

    1. the fully qualified module name plus the class name whose `forward` was run
    2. the absolute min and max value of all elements for each module weights, and the inputs and output

    For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
    mixed precision :

    ```
    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
    [...]
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output
    ```

    You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was
    around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
    renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
    64K, and we get an overlow.

    As you can see it's the previous frames that we need to look into when the numbers start going into very large for
    fp16 numbers.

    The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.

    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
    ```

        To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
        may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
        the next section.


        Mode 2. Specific batch absolute min/max tracing without detection

        The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.

        Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
    ```

    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
    fast-forward right to that area.


    Early stopping:

    You can also specify the batch number after which to stop the training, with :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
    ```

    This feature is mainly useful in the tracing mode, but you can use it for any mode.


    **Performance**:

    As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training
    down. Therefore remember to turn it off once the debugging needs have been met.

    Args:
        model (`nn.Module`):
            The model to debug.
        max_frames_to_save (`int`, *optional*, defaults to 21):
            How many frames back to record
        trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
            Which batch numbers to trace (turns detection off)
        abort_after_batch_num  (`int``, *optional*):
            Whether to abort after a certain batch number has finished
    """

    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
        self.model = model
        self.trace_batch_nums = trace_batch_nums
        self.abort_after_batch_num = abort_after_batch_num

        # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
        self.frames = collections.deque([], max_frames_to_save)
        self.frame = []
        self.batch_number = 0
        self.total_calls = 0
        self.detected_overflow = False
        self.prefix = "                 "

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame=None):
        if frame is not None:
            self.expand_frame(frame)
        self.frames.append("\n".join(self.frame))
        self.frame = []  # start a new frame

    def expand_frame(self, line):
        self.frame.append(line)

    def trace_frames(self):
        print("\n".join(self.frames))
        self.frames = []

    def reset_saved_frames(self):
        self.frames = []

    def dump_saved_frames(self):
        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
        print(f"Last {len(self.frames)} forward frames:")
        print(f"{'abs min':8} {'abs max':8} metadata")
        print("\n".join(self.frames))
        print("\n\n")
        self.frames = []

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}
        # self.longest_module_name = max(len(v) for v in self.module_names.values())

    def analyse_variable(self, var, ctx):
        if torch.is_tensor(var):
            self.expand_frame(get_abs_min_max(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        elif var is None:
            self.expand_frame(f"{'None':>17} {ctx}")
        else:
            self.expand_frame(f"{'not a tensor':>17} {ctx}")

    def batch_start_frame(self):
        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")

    def batch_end_frame(self):
        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")

    def create_frame(self, module, input, output):
        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

        # params
        for name, p in module.named_parameters(recurse=False):
            self.analyse_variable(p, name)

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                self.analyse_variable(x, f"input[{i}]")
        else:
            self.analyse_variable(input, "input")

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"output[{i}]")
        else:
            self.analyse_variable(output, "output")

        self.save_frame()

    def register_forward_hook(self):
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        last_frame_of_batch = False

        trace_mode = True if self.batch_number in self.trace_batch_nums else False
        if trace_mode:
            self.reset_saved_frames()

        if self.total_calls == 0:
            self.batch_start_frame()
        self.total_calls += 1

        # count batch numbers - the very first forward hook of the batch will be called when the
        # batch completes - i.e. it gets called very last - we know this batch has finished
        if module == self.model:
            self.batch_number += 1
            last_frame_of_batch = True

        self.create_frame(module, input, output)

        # if last_frame_of_batch:
        #     self.batch_end_frame()

        if trace_mode:
            self.trace_frames()

        if last_frame_of_batch:
            self.batch_start_frame()

        if self.detected_overflow and not trace_mode:
            self.dump_saved_frames()

            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )


def get_abs_min_max(var, ctx):
    abs_var = var.abs()
    return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"


def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math that
    modified the tensor in question.

    This function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    # if needed to monitor large elements can enable the following
    if 0:  # and detected:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}:  n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return detected


class DebugOption(ExplicitEnum):
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"