Batch Normalization

Definition

Normalizes pre-activations within a mini-batch using batch statistics, then rescales with learned parameters $\gamma$ (scale) and $\beta$ (shift). Subsumes input normalization as a special case applied at the data layer.

Intuition

Input normalization: equalizing feature scales removes elongated, axis-misaligned loss contours that force small learning rates and cause slow, zigzagging descent. The same principle applied deep inside a network is batch normalization. As activations shift during training (internal covariate shift), each layer must constantly readapt to a moving distribution. Batch norm pins each layer’s input distribution near zero mean and unit variance, smoothing the optimization landscape and permitting higher learning rates.

Formal Description

Input normalization (applied to raw features before training):

$$\tilde{x}_j = \frac{x_j - \mu_j}{\sigma_j}$$

where $\mu_j$ and $\sigma_j$ are the mean and standard deviation of feature $j$ over the training set. Ensures all input features have comparable scale, improving gradient descent convergence.
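As a concrete illustration, here is a minimal pure-Python sketch of per-feature input normalization (the function name and list-of-rows data layout are assumptions for the example; a small `eps` guards against constant features):

```python
def normalize_features(X, eps=1e-8):
    """Normalize each feature (column) of X to zero mean and unit variance."""
    n, d = len(X), len(X[0])
    # Per-feature mean and standard deviation over the training set
    means = [sum(row[j] for row in X) / n for j in range(d)]
    stds = [(sum((row[j] - means[j]) ** 2 for row in X) / n) ** 0.5
            for j in range(d)]
    # Subtract the mean and divide by the std, feature by feature
    return [[(row[j] - means[j]) / (stds[j] + eps) for j in range(d)]
            for row in X]
```

After this transform, every feature column has (approximately) zero mean and unit variance, so the loss contours are closer to spherical.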

Batch normalization (applied to pre-activations $z^{[\ell]}$ at layer $\ell$):

Forward pass (training), for a mini-batch of size $m$:

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} z^{(i)}, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(z^{(i)} - \mu_B\right)^2$$

$$\hat{z}^{(i)} = \frac{z^{(i)} - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad \tilde{z}^{(i)} = \gamma\,\hat{z}^{(i)} + \beta$$

$\gamma$ and $\beta$ are learned per-feature parameters; $\epsilon$ prevents division by zero.
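The training-time forward pass can be sketched in pure Python (function name and list-based batch layout are illustrative assumptions, not a library API):

```python
def batchnorm_forward(Z, gamma, beta, eps=1e-5):
    """Batch norm forward pass (training mode).

    Z: mini-batch of m pre-activation vectors, each with d features.
    gamma, beta: learned per-feature scale and shift (length d).
    Returns the normalized batch plus the batch statistics.
    """
    m, d = len(Z), len(Z[0])
    # Batch statistics, computed per feature over the mini-batch
    mu = [sum(z[j] for z in Z) / m for j in range(d)]
    var = [sum((z[j] - mu[j]) ** 2 for z in Z) / m for j in range(d)]
    # Normalize, then rescale with the learned parameters
    Z_out = [[gamma[j] * (z[j] - mu[j]) / (var[j] + eps) ** 0.5 + beta[j]
              for j in range(d)]
             for z in Z]
    return Z_out, mu, var
```

With $\gamma = 1$ and $\beta = 0$ the output has roughly zero mean and unit variance per feature; during training, $\gamma$ and $\beta$ let the network recover any other scale it needs.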

Test time: replace $\mu_B$ and $\sigma_B^2$ with exponential moving averages $\bar{\mu}$, $\bar{\sigma}^2$ accumulated during training:

$$\bar{\mu} \leftarrow \alpha\,\bar{\mu} + (1-\alpha)\,\mu_B, \qquad \bar{\sigma}^2 \leftarrow \alpha\,\bar{\sigma}^2 + (1-\alpha)\,\sigma_B^2$$

$$\tilde{z} = \gamma\,\frac{z - \bar{\mu}}{\sqrt{\bar{\sigma}^2 + \epsilon}} + \beta$$

This makes inference deterministic and independent of batch composition.
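A minimal sketch of the moving-average bookkeeping and the test-time path (the momentum convention and function names are assumptions for illustration; libraries differ on whether "momentum" weights the old or the new value):

```python
def update_running(running, batch_stat, momentum=0.9):
    """Exponential moving average update, applied after each training batch.

    Convention assumed here: new = momentum * old + (1 - momentum) * batch.
    """
    return [momentum * r + (1 - momentum) * b
            for r, b in zip(running, batch_stat)]

def batchnorm_inference(z, running_mu, running_var, gamma, beta, eps=1e-5):
    """Test-time batch norm: uses stored running stats, not batch stats."""
    return [gamma[j] * (z[j] - running_mu[j]) / (running_var[j] + eps) ** 0.5
            + beta[j]
            for j in range(len(z))]
```

Because `batchnorm_inference` reads only the stored running statistics, the same input always yields the same output, no matter what other examples it is batched with.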

Placement: typically inserted before the activation function (e.g., Linear → BN → ReLU), though post-activation is also used.

Applications

  • Virtually all modern deep networks use it or a variant (ResNets use batch norm; Transformers typically use layer norm)
  • Enables significantly deeper networks by stabilizing gradient flow
  • Acts as a mild regularizer, sometimes reducing the need for dropout

Trade-offs

  • Noise as regularization: batch statistics introduce per-batch noise, which has a regularizing effect but also means training and test behavior differ
  • Train vs. test mismatch: moving average estimates must be tracked carefully; bugs here are a common source of subtle errors
  • Small batch sizes: batch statistics become unreliable; prefer Layer Normalization (normalizes over features instead of batch) for small batches, RNNs, and Transformers
  • Overhead: adds parameters and a forward-pass computation per layer, though cost is generally negligible
  • Weight initialization becomes less critical when batch norm is used, since activations are re-centered at each layer
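To make the Layer Normalization contrast above concrete, here is a sketch of the unscaled core of layer norm (omitting the learned scale/shift for brevity; the function name is illustrative). Note it normalizes over the features of a single example, so it needs no batch statistics and no running averages:

```python
def layernorm(z, eps=1e-5):
    """Normalize one example's feature vector across its features."""
    d = len(z)
    mu = sum(z) / d                               # mean over features
    var = sum((x - mu) ** 2 for x in z) / d       # variance over features
    return [(x - mu) / (var + eps) ** 0.5 for x in z]
```

Since each example is normalized independently, layer norm behaves identically at batch size 1, in RNN time steps, and at test time, which is why it is preferred in those settings.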