new

Get trending papers in your email inbox!

Subscribe

Daily Papers

by AK and the research community

Learning Rates as a Function of Batch Size: A Random Matrix Theory Approach to Neural Network Training

We study the effect of mini-batching on the loss landscape of deep neural networks using spiked, field-dependent random matrix theory. We demonstrate that the magnitude of the extremal values of the batch Hessian are larger than those of the empirical Hessian. We also derive similar results for the Generalised Gauss-Newton matrix approximation of the Hessian. As a consequence of our theorems we derive an analytical expressions for the maximal learning rates as a function of batch size, informing practical training regimens for both stochastic gradient descent (linear scaling) and adaptive algorithms, such as Adam (square root scaling), for smooth, non-convex deep neural networks. Whilst the linear scaling for stochastic gradient descent has been derived under more restrictive conditions, which we generalise, the square root scaling rule for adaptive optimisers is, to our knowledge, completely novel. %For stochastic second-order methods and adaptive methods, we derive that the minimal damping coefficient is proportional to the ratio of the learning rate to batch size. We validate our claims on the VGG/WideResNet architectures on the CIFAR-100 and ImageNet datasets. Based on our investigations of the sub-sampled Hessian we develop a stochastic Lanczos quadrature based on the fly learning rate and momentum learner, which avoids the need for expensive multiple evaluations for these key hyper-parameters and shows good preliminary results on the Pre-Residual Architecure for CIFAR-100.

EcoTTA: Memory-Efficient Continual Test-time Adaptation via Self-distilled Regularization

This paper presents a simple yet effective approach that improves continual test-time adaptation (TTA) in a memory-efficient manner. TTA may primarily be conducted on edge devices with limited memory, so reducing memory is crucial but has been overlooked in previous TTA studies. In addition, long-term adaptation often leads to catastrophic forgetting and error accumulation, which hinders applying TTA in real-world deployments. Our approach consists of two components to address these issues. First, we present lightweight meta networks that can adapt the frozen original networks to the target domain. This novel architecture minimizes memory consumption by decreasing the size of intermediate activations required for backpropagation. Second, our novel self-distilled regularization controls the output of the meta networks not to deviate significantly from the output of the frozen original networks, thereby preserving well-trained knowledge from the source domain. Without additional memory, this regularization prevents error accumulation and catastrophic forgetting, resulting in stable performance even in long-term test-time adaptation. We demonstrate that our simple yet effective strategy outperforms other state-of-the-art methods on various benchmarks for image classification and semantic segmentation tasks. Notably, our proposed method with ResNet-50 and WideResNet-40 takes 86% and 80% less memory than the recent state-of-the-art method, CoTTA.

Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning

Deep neural networks are susceptible to adversarial examples, posing a significant security risk in critical applications. Adversarial Training (AT) is a well-established technique to enhance adversarial robustness, but it often comes at the cost of decreased generalization ability. This paper proposes Robustness Critical Fine-Tuning (RiFT), a novel approach to enhance generalization without compromising adversarial robustness. The core idea of RiFT is to exploit the redundant capacity for robustness by fine-tuning the adversarially trained model on its non-robust-critical module. To do so, we introduce module robust criticality (MRC), a measure that evaluates the significance of a given module to model robustness under worst-case weight perturbations. Using this measure, we identify the module with the lowest MRC value as the non-robust-critical module and fine-tune its weights to obtain fine-tuned weights. Subsequently, we linearly interpolate between the adversarially trained weights and fine-tuned weights to derive the optimal fine-tuned model weights. We demonstrate the efficacy of RiFT on ResNet18, ResNet34, and WideResNet34-10 models trained on CIFAR10, CIFAR100, and Tiny-ImageNet datasets. Our experiments show that \method can significantly improve both generalization and out-of-distribution robustness by around 1.5% while maintaining or even slightly enhancing adversarial robustness. Code is available at https://github.com/microsoft/robustlearn.