Amardeep Kumar

Normalization: BatchNorm, LayerNorm, and Why Transformers Need a Different One

In post 1.5 we set up a good optimizer - AdamW with adaptive learning rates and weight decay. But there’s a separate problem that shows up as networks get deeper: the values flowing through the network drift as they pass from layer to layer. This post is about why that happens and how normalisation fixes it.


Why it exists

Run a 20-layer network with no normalisation and look at the distribution of values coming out of each layer. Layer 1 might output values between -1 and 1. By layer 10, the values are in the thousands. By layer 20, they’ve either exploded to huge numbers or collapsed to near-zero.

This happens because each layer multiplies its inputs by its weights. If a layer tends to scale its inputs up by even slightly more than 1, the values grow with every layer; slightly less than 1, and they shrink. Over 20 layers, small multiplicative factors compound: 1.01^20 ≈ 1.22 and 0.99^20 ≈ 0.82, and a factor of 1.5 per layer gives 1.5^20 ≈ 3,300.
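To see the compounding inside an actual stack of layers, here is a minimal NumPy sketch: a 20-layer linear chain with no activations or biases, where each weight matrix is drawn from N(0, gain²/width) so that each layer scales the typical activation size by roughly `gain`. The helper `chain_std`, the width, and the gain values are illustrative assumptions, not part of the original code.

```python
import numpy as np

def chain_std(gain, depth=20, width=512, batch=64, seed=0):
    """Push a random batch through `depth` linear layers and return the output std.

    Weights are drawn from N(0, gain**2 / width), so each layer multiplies the
    typical activation magnitude by roughly `gain`.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, width))  # input with std ~1
    for _ in range(depth):
        w = rng.standard_normal((width, width)) * gain / np.sqrt(width)
        x = x @ w  # no activation, no bias: only the multiplicative effect remains
    return x.std()

# The scalar arithmetic from the text
print(f"1.01**20 = {1.01**20:.2f}, 0.99**20 = {0.99**20:.2f}")

# The same compounding applied to real activations
for gain in (0.5, 0.99, 1.01, 1.5):
    print(f"gain={gain}: output std after 20 layers = {chain_std(gain):.3g}")
```

The printed standard deviations should land close to gain**20: a per-layer gain of 0.5 leaves almost nothing after 20 layers, while 1.5 pushes values into the thousands.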

Exploding or vanishing values cause two problems. First, gradient descent breaks: gradients computed on huge or tiny activations are unstable, so training either diverges or stalls. Second, activations saturate: give a sigmoid an input of 1000 and it outputs 1.0, exactly as it would for 50 or 500, and its gradient there is essentially zero. You’ve lost all information about the input.
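A quick check in plain Python shows the saturation. `sigmoid` and `sigmoid_grad` are small illustrative helpers written for this sketch: distinct large inputs become indistinguishable in float64, and the gradient that training relies on drops to zero.

```python
import math

def sigmoid(x):
    """Numerically stable logistic sigmoid for a Python float."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)  # x < 0, so this cannot overflow
    return e / (1.0 + e)

def sigmoid_grad(x):
    """Derivative of the sigmoid: s * (1 - s)."""
    s = sigmoid(x)
    return s * (1.0 - s)

for x in (0.5, 2.0, 50.0, 500.0, 1000.0):
    print(f"x={x:>6}: sigmoid={sigmoid(x):.17g}  grad={sigmoid_grad(x):.3g}")
```

Small inputs produce different outputs and a useful gradient; from 50 upwards every input rounds to exactly 1.0 and the gradient is 0.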

Normalisation fixes this by rescaling the activations at each layer so they have a predictable distribution - roughly zero mean and unit variance - before passing to the next layer.

The formula is z-score standardisation, the same thing you’ve seen in statistics or data preprocessing:

x_normalised = (x - mean) / sqrt(variance + ε)

ε is a small number added to prevent dividing by zero.
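Here is the standardisation formula as a minimal NumPy sketch. `standardise` is a hypothetical helper, and ε = 1e-5 is an assumed value (it happens to match PyTorch’s default for its normalisation layers). It uses the population (biased) variance, which is what BatchNorm and LayerNorm use to normalise.

```python
import numpy as np

def standardise(x, eps=1e-5):
    """z-score: subtract the mean, divide by sqrt(variance + eps)."""
    mean = x.mean()
    var = x.var()  # population (biased) variance
    return (x - mean) / np.sqrt(var + eps)

x = np.array([2.0, 4.0, 6.0, 8.0])
z = standardise(x)
print(z, z.mean(), z.var())  # mean ~0, variance ~1

# Without eps, a constant vector would give 0 / 0 = nan; with eps it gives zeros.
print(standardise(np.full(4, 3.0)))
```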

After normalising, a learned scale (γ) and shift (β) are applied to give the network the freedom to undo the normalisation if needed:

output = γ × x_normalised + β

γ and β are learnable parameters, just like weights. This means the network can learn “actually, this layer works better with a mean of 2 and a variance of 3” - the normalisation step doesn’t force a rigid distribution; it just stabilises the starting point.
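A minimal NumPy sketch of the scale-and-shift step, with γ and β fixed by hand rather than learned (an assumption for illustration; `normalise_affine` is a hypothetical helper). Landing on “mean 2, variance 3” needs γ = √3, because γ scales the standard deviation, not the variance. Setting γ to the original standard deviation and β to the original mean undoes the normalisation entirely.

```python
import numpy as np

def normalise_affine(x, gamma, beta, eps=1e-5):
    """Standardise x, then apply the scale (gamma) and shift (beta)."""
    x_hat = (x - x.mean()) / np.sqrt(x.var() + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000) * 7.0 - 40.0  # arbitrary mean and spread

# Target distribution: mean 2, variance 3 -> gamma = sqrt(3), beta = 2
out = normalise_affine(x, gamma=np.sqrt(3.0), beta=2.0)
print(f"mean={out.mean():.3f}  var={out.var():.3f}")

# gamma = original std, beta = original mean recovers the input
recovered = normalise_affine(x, gamma=np.sqrt(x.var() + 1e-5), beta=x.mean())
print(np.allclose(recovered, x))
```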

The question is: what values do you compute mean and variance over? That’s what separates BatchNorm from LayerNorm.


How it works

BatchNorm

BatchNorm (Batch Normalisation) computes mean and variance across the batch dimension.

Say your batch has 64 samples, and each sample has 512 features. BatchNorm looks at each of the 512 features and asks: across all 64 samples in this batch, what’s the mean and variance of this feature? Then it normalises using those batch-level statistics.
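In NumPy terms, “across the batch” means reducing over axis 0. This is a minimal training-mode sketch with γ = 1 and β = 0 assumed; `batch_norm_train` is a hypothetical helper, not PyTorch’s implementation.

```python
import numpy as np

def batch_norm_train(x, eps=1e-5):
    """Normalise each feature (column) using statistics computed across the batch (rows).

    x has shape (batch, features); gamma/beta are omitted (gamma=1, beta=0).
    """
    mean = x.mean(axis=0)  # shape (features,): one mean per feature
    var = x.var(axis=0)    # shape (features,): one variance per feature
    return (x - mean) / np.sqrt(var + eps), mean, var

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 512))  # 64 samples, 512 features
out, mean, var = batch_norm_train(x)
print(mean.shape, var.shape)                     # (512,) (512,)
print(out.mean(axis=0)[:3], out.std(axis=0)[:3])  # each column: mean ~0, std ~1
```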

The key word is “across the batch.” BatchNorm’s normalisation stats depend on the other samples sharing the batch with you.

That’s fine for image models. Images in a batch are usually resized to the same shape, batch sizes are typically large enough to compute stable statistics, and inference can use running averages collected during training instead of batch statistics.
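A minimal NumPy sketch of how running averages let BatchNorm work at inference time. `SimpleBatchNorm1d` is a hypothetical class for illustration, with γ and β left out. Its exponential-moving-average update, `momentum=0.1`, and the use of unbiased variance for the running estimate are assumptions modelled on PyTorch’s documented defaults.

```python
import numpy as np

class SimpleBatchNorm1d:
    """Minimal BatchNorm sketch (no gamma/beta) with running statistics for inference."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.training = True

    def __call__(self, x):
        if self.training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)               # biased variance normalises the batch
            unbiased_var = x.var(axis=0, ddof=1)
            m = self.momentum
            # exponential moving averages, kept for inference
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * unbiased_var
        else:
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)

rng = np.random.default_rng(0)
bn = SimpleBatchNorm1d(num_features=8)
for _ in range(200):                          # "training": many batches of 64
    bn(rng.normal(loc=3.0, scale=2.0, size=(64, 8)))

bn.training = False                           # inference: no batch statistics needed
single = bn(rng.normal(loc=3.0, scale=2.0, size=(1, 8)))
print(bn.running_mean.round(2), bn.running_var.round(2))
print(single.shape)
```

After training, the running mean and variance sit near the data’s true mean (3) and variance (4), so a single sample can be normalised at inference without any batch.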

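But it breaks in two situations: when the inputs in a batch have different shapes, such as text sequences of different lengths that need padding, and when the batch is too small to give meaningful statistics, down to the extreme of a single sample.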

Think of it like a global lock in concurrent programming. BatchNorm requires all samples in the batch to coordinate to compute a shared statistic before any single sample can proceed. That works when all inputs are the same shape - but it’s a bottleneck when inputs differ, and it fails entirely when you have nothing to aggregate across.

A grid with rows = samples in the batch, columns = features. BatchNorm highlights a column, showing normalisation happens down the batch. LayerNorm highlights a full row, showing normalisation happens across features for one sample.
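The “nothing to aggregate across” case is easy to demonstrate with a minimal NumPy sketch of training-mode BatchNorm (γ = 1, β = 0 assumed; `batch_stats_normalise` is a hypothetical helper). With a batch of one, each feature’s batch mean is the sample itself, so every input collapses to the same all-zero output. PyTorch’s BatchNorm layers refuse this case in training mode and raise an error instead.

```python
import numpy as np

def batch_stats_normalise(x, eps=1e-5):
    """Training-mode BatchNorm (gamma=1, beta=0): statistics computed down each column."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

a = np.array([[1.0, 5.0, -2.0, 0.3]])  # a "batch" containing one sample
b = np.array([[9.0, -4.0, 7.5, 2.0]])  # a completely different single sample

# With one sample, x - mean is zero everywhere, so all information is lost.
print(batch_stats_normalise(a))  # [[0. 0. 0. 0.]]
print(batch_stats_normalise(b))  # [[0. 0. 0. 0.]]
```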

LayerNorm

LayerNorm (Layer Normalisation) computes mean and variance across the feature dimension of a single sample.

For one sample with 512 features, LayerNorm looks at all 512 values and normalises them using their own mean and variance - no other samples involved.
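The same kind of NumPy sketch for LayerNorm (γ = 1, β = 0 assumed; `layer_norm` is a hypothetical helper). The only change from BatchNorm is the axis: statistics are taken over the last axis, giving one mean and one variance per sample.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalise each sample (row) using its own mean and variance over the features."""
    mean = x.mean(axis=-1, keepdims=True)  # shape (batch, 1): one mean per sample
    var = x.var(axis=-1, keepdims=True)    # shape (batch, 1): one variance per sample
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 512))
out = layer_norm(x)
print(out.mean(axis=-1)[:3], out.std(axis=-1)[:3])  # every row: mean ~0, std ~1
```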

Every sample normalises itself. Batch size doesn’t matter. Sequence length doesn’t matter. Whether you’re at training time or inference time, the statistics are computed from the sample you’re currently processing.
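One way to check the “no other samples involved” claim is a NumPy sketch that places the same sample in two different batches and also normalises it alone. Both helpers are hypothetical, with γ = 1 and β = 0 assumed. LayerNorm’s output for that sample never changes; BatchNorm’s depends on its batch-mates.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)  # per-sample statistics
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def batch_norm_train(x, eps=1e-5):
    mean = x.mean(axis=0, keepdims=True)   # per-feature statistics across the batch
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
sample = rng.standard_normal((1, 16))
batch_a = np.vstack([sample, rng.standard_normal((7, 16))])        # sample + 7 batch-mates
batch_b = np.vstack([sample, rng.normal(10.0, 5.0, size=(3, 16))])  # different batch-mates

# LayerNorm: row 0 comes out the same in either batch, and when it is on its own.
ln_alone = layer_norm(sample)[0]
print(np.allclose(layer_norm(batch_a)[0], ln_alone),
      np.allclose(layer_norm(batch_b)[0], ln_alone))                # True True

# BatchNorm: row 0's output depends on who else is in the batch.
print(np.allclose(batch_norm_train(batch_a)[0], batch_norm_train(batch_b)[0]))  # False
```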

That’s why transformers use LayerNorm. Text sequences vary in length. Batching sequences together requires padding shorter ones to match the longest, and those padding positions would be mixed into any statistics computed across the batch. LayerNorm sidesteps all of this - each token normalises its own 512-dimensional vector without consulting anything else.
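A minimal NumPy sketch of the padding problem, assuming zero-vector padding and a (batch, seq_len, features) layout. The “BatchNorm-style” statistic here pools over every (sample, position) pair; `layer_norm` is a hypothetical helper. Padding shifts the pooled mean, while LayerNorm’s outputs for the real tokens are unaffected.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalise over the last axis only, so every token is handled on its own."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
features = 8
short_seq = rng.standard_normal((1, 3, features))  # 1 sequence, 3 real tokens
long_seq = rng.standard_normal((1, 6, features))   # 1 sequence, 6 real tokens

# Pad the short sequence with zero vectors so both fit in one (batch, seq, features) array.
padded_short = np.concatenate([short_seq, np.zeros((1, 3, features))], axis=1)
batch = np.concatenate([padded_short, long_seq], axis=0)  # shape (2, 6, 8)

# Statistics pooled over batch and positions include the 3 padding vectors.
mean_with_padding = batch.reshape(-1, features).mean(axis=0)
real_tokens = np.concatenate([short_seq[0], long_seq[0]], axis=0)  # the 9 real tokens
mean_real_only = real_tokens.mean(axis=0)
print(np.allclose(mean_with_padding, mean_real_only))  # False: padding leaks in

# LayerNorm: the real tokens' outputs are the same with or without padding.
print(np.allclose(layer_norm(batch)[0, :3], layer_norm(short_seq)[0]))  # True
```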

There’s no global state. No coordination across samples. Like a function that only reads its own local variables - nothing shared, nothing to synchronise.

RMSNorm

RMSNorm (Root Mean Square Normalisation) is a simplified version of LayerNorm used in LLaMA, Mistral, and most modern open-source LLMs.

Standard LayerNorm subtracts the mean before scaling (centring + scaling). RMSNorm skips the centring and just scales by the root mean square:

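RMS(x) = sqrt(mean(x²) + ε)
x_normalised = x / RMS(x)
output = γ × x_normalised

As in LayerNorm, ε guards against division by zero and γ is a learned per-feature scale; RMSNorm typically has no shift β.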
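A minimal NumPy sketch of RMSNorm over the last axis. `rms_norm` is a hypothetical helper, and ε = 1e-6 and γ = 1 are assumed values for illustration. There is no mean subtraction and no β.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: divide by the root mean square over the last axis, then scale by gamma."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)

rng = np.random.default_rng(0)
x = rng.normal(loc=0.0, scale=4.0, size=(2, 512))
out = rms_norm(x, gamma=np.ones(512))
print(np.sqrt(np.mean(out ** 2, axis=-1)))  # RMS of each output row is ~1.0
```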

No mean subtraction - just divide by the magnitude. The intuition: most of LayerNorm’s stabilising effect comes from controlling the scale of the activations, not from re-centring them, so the mean subtraction can be dropped. Removing it saves a little compute without noticeably hurting quality in practice.

If LayerNorm is “normalise by subtracting mean and dividing by standard deviation,” RMSNorm is “just divide by the typical magnitude.” One fewer thing to compute, similar results in practice.
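To make the difference concrete, here is a NumPy comparison of the two (γ = 1, β = 0 assumed; both helpers are hypothetical). On a zero-mean vector, mean(x²) equals the variance, so the two agree. Once the input has an offset, LayerNorm removes it and RMSNorm keeps it.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def rms_norm(x, eps=1e-5):
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
v = rng.standard_normal(512)
v_centred = v - v.mean()   # zero mean
shifted = v_centred + 3.0  # same spread, mean moved to 3

# With zero mean, both normalisations give the same result.
print(np.allclose(layer_norm(v_centred), rms_norm(v_centred)))  # True

# With a non-zero mean, LayerNorm re-centres but RMSNorm keeps the offset.
print(f"LayerNorm mean: {layer_norm(shifted).mean():.3f}  RMSNorm mean: {rms_norm(shifted).mean():.3f}")
```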


Code

Step 1 - watch activations drift without normalisation:

A 10-layer network, each layer a Linear + ReLU. No normalisation. We’ll plot the distribution of activations at layer 1, layer 5, and layer 10.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

torch.manual_seed(42)

class DeepNetNoNorm(nn.Module):
    def __init__(self, depth=10, width=256):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(width, width), nn.ReLU())
            for _ in range(depth)
        ])

    def forward(self, x):
        snapshots = {}
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in (0, 4, 9):
                snapshots[f"layer {i+1}"] = x.detach().flatten().numpy()
        return x, snapshots

model = DeepNetNoNorm()
x = torch.randn(64, 256)  # batch of 64 samples, 256 features
_, snapshots = model(x)

fig, axes = plt.subplots(1, 3, figsize=(12, 3), sharex=True)  # shared x-axis so the spreads are comparable
for ax, (name, values) in zip(axes, snapshots.items()):
    ax.hist(values, bins=60, color="#3b82f6", alpha=0.8)
    ax.set_title(f"{name}  (std={values.std():.2f})")
    ax.set_xlabel("activation value")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

plt.suptitle("Activations without normalisation — distribution shifts across layers")
plt.tight_layout()
plt.savefig("activations-no-norm.png", dpi=150)  # saved next to the script
plt.show()

Three histograms on a shared x-axis. Layer 1 is spread from 0 to 1.2 with std=0.34. Layers 5 and 10 are a tight spike near zero with std=0.02 — the activations have collapsed across layers.

With the same x-axis across all three, the collapse is clear. Layer 1 has a spread of values reaching up to 1.2. By layer 5, almost everything is crushed into a spike near zero and stays there. The standard deviation drops from 0.34 to 0.02 in five layers.

Step 2 - add LayerNorm and watch the distributions stabilise:

class DeepNetWithLayerNorm(nn.Module):
    def __init__(self, depth=10, width=256):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(width, width), nn.LayerNorm(width), nn.ReLU())
            for _ in range(depth)
        ])

    def forward(self, x):
        snapshots = {}
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in (0, 4, 9):
                snapshots[f"layer {i+1}"] = x.detach().flatten().numpy()
        return x, snapshots

torch.manual_seed(42)  # same seed as Step 1, so the Linear weights and the input match
model_norm = DeepNetWithLayerNorm()
x = torch.randn(64, 256)
_, snapshots_norm = model_norm(x)

fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for ax, (name, values) in zip(axes, snapshots_norm.items()):
    ax.hist(values, bins=60, color="#10b981", alpha=0.8)
    ax.set_title(f"{name}  (std={values.std():.2f})")
    ax.set_xlabel("activation value")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

plt.suptitle("Activations with LayerNorm — distribution stays consistent")
plt.tight_layout()
plt.savefig("activations-with-layernorm.png", dpi=150)  # saved next to the script
plt.show()

Three histograms for layers 1, 5, and 10 with LayerNorm. All three have nearly identical shape and scale — the distribution is stable across all layers.

The distributions now look nearly identical across layers 1, 5, and 10. Because the forward activations stay on a consistent scale, the gradients flowing back through the network tend to be far more stable as well.

Step 3 - using nn.LayerNorm in practice:

import torch
import torch.nn as nn

# normalise a batch of sequences: (batch=4, seq_len=10, features=64)
x = torch.randn(4, 10, 64)
norm = nn.LayerNorm(64)   # normalise the last dimension (features)

out = norm(x)
print(f"input  shape: {x.shape}")
print(f"output shape: {out.shape}")
print(f"input  mean={x.mean():.3f}  std={x.std():.3f}")
print(f"output mean={out.mean():.3f}  std={out.std():.3f}")

input  shape: torch.Size([4, 10, 64])
output shape: torch.Size([4, 10, 64])
input  mean=0.006  std=0.998
output mean=0.000  std=1.000

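The shape is unchanged: LayerNorm returns a new tensor of the same shape, with each 64-dimensional feature vector normalised on its own to mean 0 and standard deviation 1 (before γ and β are trained). Because this random input is already close to mean 0 and std 1, the overall statistics barely move here; the effect is much more visible on activations that have actually drifted, like those in the deep network from Step 1.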

nn.LayerNorm(64) tells PyTorch: “normalise over the last dimension, which has size 64.” In a transformer, a LayerNorm is applied at every attention block and every feed-forward block. It’s one line, but it’s a large part of what keeps a 96-layer model trainable.
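The argument to nn.LayerNorm is the shape of the trailing dimensions to normalise over: nn.LayerNorm(64) normalises over the last dimension, and nn.LayerNorm([10, 64]) would normalise over the last two. Here is a NumPy sketch of that default behaviour, assuming the learned weight and bias are still at their initial values of 1 and 0. `layer_norm_like_torch` is a hypothetical helper for illustration; in real code you would call nn.LayerNorm directly.

```python
import numpy as np

def layer_norm_like_torch(x, normalized_shape, eps=1e-5):
    """Sketch of nn.LayerNorm(normalized_shape) at initialisation (weight=1, bias=0).

    Statistics are taken over the trailing axes named by normalized_shape,
    whose sizes must match the end of x.shape.
    """
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    normalized_shape = tuple(normalized_shape)
    k = len(normalized_shape)
    if x.shape[-k:] != normalized_shape:
        raise ValueError(f"trailing dims {x.shape[-k:]} != normalized_shape {normalized_shape}")
    axes = tuple(range(x.ndim - k, x.ndim))  # e.g. (2,) or (1, 2) for a 3-D input
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)    # biased variance
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10, 64))               # (batch, seq_len, features)
per_token = layer_norm_like_torch(x, 64)           # like nn.LayerNorm(64)
per_sequence = layer_norm_like_torch(x, (10, 64))  # like nn.LayerNorm([10, 64])
print(per_token.mean(axis=-1).shape)               # (4, 10): one mean per token
print(f"overall std: {per_token.std():.3f}")
```

With normalized_shape = 64, every token gets its own statistics; with (10, 64), a whole sequence shares one mean and variance, so individual tokens are no longer centred on their own.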


Key takeaways


Training is now stable thanks to normalisation. But there’s one more technique networks use to avoid overfitting: randomly switching off parts of the network during training so it can’t simply memorise the data.

Post 1.7 - Dropout and Overfitting

