Writing an LLM from scratch, part 32b — Interventions: gradient clipping :: Giles’ blog


I’m still working on training the best GPT-2 small sized base model that I can
with a number of FLOPs roughly equal to two days on my own machine — my “extra credit”
exercise after having worked through
Sebastian Raschka’s book
“Build a Large Language Model (from Scratch)”.

In the last post I trained
a baseline model — one with the same architecture and almost the same training code as in
the minimal training run in the book, just modified to run using DDP on an 8x A100 40 GiB/GPU
machine in the cloud.
There are a bunch of “interventions” I want to try to see if they’ll make it better,
as measured by the loss they get on a test set. I’ll do a post for each intervention,
and this is the first: gradient clipping.

Why?

In the training chart for the baseline model, you can see that there are three
places where the loss suddenly spiked up, at around global steps 4,200, 13,000,
and 23,000:

Baseline training run on an 8x A100 with 40 GiB/GPU

There are a number of things that could cause loss spikes like that:

  • A “bad batch” — that is, one batch, or even one sequence in a batch, was
    massively different in structure to the others that the model had seen, so it just
    had much worse loss. That doesn’t seem likely in this case, though: the numbers
    on the chart are averages over 617 global steps each, and it would take a truly pathological
    sequence to move the needle that much.
  • Something weird in the optimiser. That’s not something I understand well, but
    according to the various LLMs I’m working with, it’s a possibility.
  • Exploding gradients. This is my working hypothesis, and so in this post I’ll
    try out gradient clipping, the normal solution to that problem.

What?

Exploding gradients are common in RNNs, and also happen in LLMs like this one. I spent a bit
of time reading around to find out how they happen, and the ah-ha moment came when
I came across this post from Wanshun Wong.
Not only is the post itself a good intro to how exploding gradients affect RNNs, but in the
“further reading” at the end, there’s some gold:

Chapter 10.11 of [1] has a good overview of how gradient clipping works.

References

  1. I. Goodfellow, Y. Bengio, and A. Courville. Deep Learning (2016), MIT Press.

Now, I bought a copy of “Deep Learning” at the
same time as I bought Raschka’s book, but I’d only glanced through it. Now was the
time to get it down from the shelf — and, indeed, section 10.11.1 is all about clipping
to handle exploding gradients. I’ll put the explanation of how they happen into my
own words, to see if I can clarify things (at least in my mind).

Normally, when we learn about gradient descent, it’s illustrated with nice smooth
loss charts like this imaginary one for a single-parameter model:

A simple loss chart

We’re told that we might start at point A. The gradient is quite high and negative,
so we multiply it by our learning rate and subtract it from our parameter. That
gets us to point B. This time around, the gradient is smaller as the curve is flatter
there, so when we do the same — multiply by LR and subtract — we take a smaller step, and
wind up at C. Rinse and repeat and we’ll wind up near the minimum.
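To make those steps concrete, here is a minimal sketch of plain gradient descent on a made-up one-parameter loss, L(θ) = (θ - 3)², chosen only because its gradient is easy to write down. Starting to the left of the minimum, the gradient is negative, so subtracting it (scaled by the learning rate) moves the parameter right, and the steps shrink as the curve flattens out.

```python
def loss(theta):
    # Hypothetical smooth one-parameter loss with its minimum at theta == 3.
    return (theta - 3.0) ** 2


def grad(theta):
    # Derivative of the loss above with respect to theta.
    return 2.0 * (theta - 3.0)


def gradient_descent(theta, lr, steps):
    """Return every parameter value visited, starting from theta."""
    history = [theta]
    for _ in range(steps):
        # Multiply the gradient by the learning rate and subtract it.
        theta = theta - lr * grad(theta)
        history.append(theta)
    return history


history = gradient_descent(0.0, lr=0.1, steps=50)
step_sizes = [abs(b - a) for a, b in zip(history, history[1:])]
print(history[:4])     # A, B, C, ... moving right towards 3
print(step_sizes[:3])  # each step smaller than the one before
print(history[-1], loss(history[-1]))
```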

The problem is, what if the loss curve actually looks like this:

A more complex loss chart

…?

We start at A, with a small gradient, move a little to the right, and now we’re at
B halfway down a cliff! The gradient is massive, and when we subtract it, even scaled
by the learning rate, we can zoom off somewhere to the right — maybe not even on the
chart. Indeed, you can imagine a cliff that is so steep that
it would have vertical portions — negative infinite gradients in this case — and no matter what your learning
rate is, you’ll wind up with an infinite parameter update and everything will break.
It’s hard to see how a model can continue training in a case like that.
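The cliff scenario can be reproduced with a toy loss. This sketch uses a hypothetical one-parameter "cliff" built from a very steep sigmoid (the steepness, edge position and learning rate are arbitrary choices for illustration). On the plateau at A the gradient is effectively zero; halfway down the cliff at B it is enormous, and a single ordinary update flings the parameter far off to the right.

```python
import math


def cliff_loss(theta, steepness=1000.0, edge=2.0):
    # Hypothetical one-parameter loss: about 1 to the left of `edge`, about 0
    # to the right, with a near-vertical drop (a "cliff") around theta == edge.
    z = steepness * (theta - edge)
    if z >= 0:
        e = math.exp(-z)  # written this way so exp() never overflows
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(z))


def cliff_grad(theta, steepness=1000.0, edge=2.0):
    # Derivative of the sigmoid-shaped loss above.
    s = cliff_loss(theta, steepness, edge)
    return -steepness * s * (1.0 - s)


lr = 0.1
point_a, point_b = 1.5, 2.0  # A: on the flat plateau; B: halfway down the cliff
print(cliff_grad(point_a))   # effectively zero
print(cliff_grad(point_b))   # huge: -steepness / 4
new_theta = point_b - lr * cliff_grad(point_b)
print(new_theta)             # thrown far to the right of the cliff
```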

Now, what can cause steep cliffs like that? The book says “strongly nonlinear functions,
such as those computed by a recurrent neural net over many time steps”.

If you know about RNNs (I wrote about them
if you’d like a summary), you’ll remember that a single RNN might be quite
shallow — maybe three or four layers — but when you’re doing backpropagation,
you run a number of inputs through, one after the other, work out the overall loss, and then “unroll” it to
something similar to a “vanilla” neural net to do the backward pass. To put that in
concrete terms, a 3-layer neural network trained with a 100-element sequence would
unroll to a 300-layer deep network. Every one of those layers has several operations, including
(in the implementation I was looking at in my post above), a tanh. It’s not surprising that there are cliffs in the loss landscape — it’s
more surprising that there are any smooth bits!
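A deliberately simplified sketch shows why the extra depth from unrolling matters so much. It uses a scalar, linear recurrence h_t = w · h_{t-1} (an assumption for illustration, not the tanh RNN from that post): the backward pass multiplies one local derivative per unrolled layer, so a per-layer factor only slightly above 1 becomes enormous over 300 layers, and one slightly below 1 shrinks towards zero.

```python
def unrolled_gradient(w, num_layers):
    """d h_T / d h_0 for the toy linear recurrence h_t = w * h_{t-1}.

    Each unrolled "layer" contributes a local derivative of w, and the chain
    rule multiplies them together during the backward pass.
    """
    grad = 1.0
    for _ in range(num_layers):
        grad *= w
    return grad


# A 3-layer net on its own versus unrolled over a 100-element sequence.
print(unrolled_gradient(1.1, 3))    # modest
print(unrolled_gradient(1.1, 300))  # explodes
print(unrolled_gradient(0.9, 300))  # vanishes
```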

Now in LLMs, we don’t have that unrolling through time — but our network is deep enough
as it is. For the GPT-2 small model, disregarding the embeddings and the final output head, we have 12 Transformer layers,
each of which is multiple matrix multiplications for attention, then a softmax, then another layer, and
then a feed-forward… mapping precisely to the equivalent vanilla NN is hard, but I think
you can treat each one as at least four layers, so we’ve got 48. And there are GELUs and logs and
exps dotted around, so again — we should expect cliffs.

So if sometimes we’ll get crazy gradients, what can we do about them? We clip them.

How?

Clipping gradients simply means that if they get larger than a particular number — v,
which we define — we reduce them to that number. In other words, we have a cap on how
big they can get.

“Deep Learning” (“DL” from now on) suggests two ways to do it. Remember that while in the
example above we only had one parameter (on the X axis), for the GPT-2 small
LLM we’re training we have 163 million of them. So the gradients, instead of
being one number, will be a 163M-long vector, one per parameter. The two ways to clip are:

  • We clip element-wise. If any one of the gradients in the vector is larger than v
    in magnitude, we reduce it to v (or to -v if it’s negative).
  • We clip based on the norm: the length of the gradient vector in — in our
    case — 163M-dimensional space. That sounds harder than it is — it’s really
    just an extension of the Pythagorean equation $a^2 + b^2 = c^2$ to multiple
    dimensions. If you want to work out the length of a vector $(a, b)$, then you
    can use Pythagoras to work out $c = \sqrt{a^2 + b^2}$, and that generalises
    to any number of dimensions. So for our model we’d just square all 163M
    elements of the vector, sum those, and take the square root of the result, and that’s the norm.
    If the norm is greater than v, we just divide every element of the gradient vector by the norm
    and multiply the result by v, to produce
    a new gradient vector whose norm is v.
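Written compactly, with g as the gradient vector of N elements and v as the threshold (notation chosen here for the sketch), the two options are:

$$
\text{element-wise:}\quad g_i \leftarrow \max\bigl(-v, \min(g_i, v)\bigr) \quad \text{for } i = 1, \dots, N
$$

$$
\text{norm-based:}\quad \|g\| = \sqrt{\sum_{i=1}^{N} g_i^2}, \qquad g \leftarrow \frac{v}{\|g\|}\, g \quad \text{if } \|g\| > v
$$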

The second feels more elegant — we’re scaling all of the elements of the gradient
vector by the same amount, so it still points in the same direction. Interestingly, though,
DL says that the two methods “work similarly”, which I’ll read as “are pretty much
the same in practice”.
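A small pure-Python sketch (plain lists standing in for the 163M-element gradient vector; the function names are made up here, not taken from any library) makes the difference in direction visible. Norm clipping leaves the gradient pointing exactly the same way, while element-wise clipping can tilt it noticeably when one element is much bigger than the others.

```python
import math


def clip_elementwise(grads, v):
    # Clamp each element independently into [-v, v].
    return [max(-v, min(g, v)) for g in grads]


def clip_by_norm(grads, v):
    # Rescale the whole vector so its norm is at most v.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > v:
        return [g * (v / norm) for g in grads]
    return list(grads)


def cosine(a, b):
    # Cosine of the angle between two vectors: 1.0 means same direction.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


g = [30.0, 0.4]
elementwise = clip_elementwise(g, 1.0)
by_norm = clip_by_norm(g, 1.0)
print(elementwise, cosine(g, elementwise))  # direction has changed
print(by_norm, cosine(g, by_norm))          # same direction, norm 1
```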

DL then goes on to say how infinite or not-a-number gradients should be handled.
With the first way, doing it naively would clamp every infinite element in the gradient
vector to v; if lots of elements blew up at once, that would make the total size (norm) of the update very large. With the
second, it would be even worse — we’d still wind up with completely junk gradients, because
the norm would be infinite, and in Python math.inf / math.inf is math.nan, so
we’d be applying gradients with NaNs in them at best. That would be likely to
knock our model into unrecoverable territory, as any parameter that had that applied
to it would be NaN forever.
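That failure mode is easy to reproduce with plain Python floats. In this toy sketch (the function name is made up here), one infinite element makes the norm infinite, the division turns that element into NaN, and once a NaN has been added to a parameter, no later update can fix it.

```python
import math


def clip_by_norm(grads, v):
    # Naive norm clipping with no special handling for non-finite values.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > v:
        return [g / norm * v for g in grads]
    return list(grads)


print(math.inf / math.inf)  # nan

grads = [0.5, math.inf, -2.0]
clipped = clip_by_norm(grads, 1.0)
print(clipped)  # the finite elements collapse to zero, the infinite one becomes NaN

param = 1.0
param -= 0.1 * clipped[1]  # apply the poisoned update once...
for _ in range(3):
    param -= 0.1 * 0.25    # ...and perfectly normal later updates can't repair it
print(param)               # still nan
```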

Their suggested solution is that if you get garbage gradients like that, you can take
a random step — that is, create a new gradient to apply that has the norm v
but just points in a random direction. The idea is that this will move you away from
the cliff-ridden part of the loss landscape where you’ve found yourself (more about that later), and things will
continue nicely.
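DL’s fallback can be sketched in a few lines of plain Python. This is one minimal interpretation, not code from the book or from PyTorch: if the gradient’s norm is not finite, replace it with a step of norm v in a random direction (normalising a vector of standard-normal samples gives a uniformly random direction); otherwise, clip by norm as usual.

```python
import math
import random


def random_step(num_params, v, rng):
    # A random direction (normalised Gaussian sample) scaled to have norm v.
    direction = [rng.gauss(0.0, 1.0) for _ in range(num_params)]
    length = math.sqrt(sum(d * d for d in direction))
    return [d / length * v for d in direction]


def clip_or_random_step(grads, v, rng):
    norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(norm):
        # Garbage gradient (inf or NaN somewhere): take a random step of size v.
        return random_step(len(grads), v, rng)
    if norm > v:
        return [g / norm * v for g in grads]
    return list(grads)


rng = random.Random(0)
update = clip_or_random_step([0.5, math.inf, math.nan, -2.0], v=1.0, rng=rng)
print(update)  # finite, with norm 1.0, pointing somewhere random
```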

So, anyway, how to do this in practice?

PyTorch has a function, clip_grad_norm_,
and that’s what’s referenced in almost every bit of writing I’ve found about how
to clip gradients. So I decided to use that, assuming it would do what was described
in DL’s second option and that it would do the random updates they suggest for non-finite
gradients. (I was half-correct — see later.)

As to how to use it — if we had a normal training loop, where we were just using a normal optimiser, we would go from:

    train_loss = calculate_loss(y_logits, target_y_ids)
    train_loss.backward()

    optimizer.step()

…to something like

    train_loss = calculate_loss(y_logits, target_y_ids)
    train_loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_max_norm)

    optimizer.step()

…where clipping_max_norm is the max value v from above.

However, for our training code using Automatic Mixed Precision (AMP),
it’s a little more complicated — but luckily, the AMP explainer we’ve been using
has a section explaining what to do.

Right now we have this:

    with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
        logits = model(inputs)

        train_loss = calculate_loss(logits, targets)

    scaler.scale(train_loss).backward()
    scaler.step(optimizer)
    scaler.update()

Per that explainer, we need to move to this:

    with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
        logits = model(inputs)

        train_loss = calculate_loss(logits, targets)

    scaler.scale(train_loss).backward()

    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_max_norm)

    scaler.step(optimizer)
    scaler.update()

That looks a bit weird: we’re “unscaling” the gradients,
then clipping them, then using the scaler to step the
optimiser. You’d think that you’d need to “re-scale” the gradients after clipping them,
to get back to where you started before the optimiser step.
From the help page, I gather that the scaler keeps track of whether or not the gradients it has right now are
currently scaled, and handles them appropriately based on that state in scaler.step.
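That bookkeeping can be modelled in plain Python. The ToyGradScaler here is entirely hypothetical and drastically simplified: it is not torch.amp.GradScaler, and it ignores the real class’s skipping of non-finite steps and its dynamic adjustment of the scale. It only remembers whether the gradients have already been unscaled, so that step() doesn’t divide them by the scale factor a second time. It also suggests why unscaling has to come before clipping: the threshold is meant for gradients in their real units.

```python
import math


class ToyGradScaler:
    """Hypothetical stand-in for torch.amp.GradScaler, for illustration only.

    It models only the "have these gradients been unscaled yet?" bookkeeping.
    """

    def __init__(self, scale=1024.0):
        self.scale_factor = scale
        self._already_unscaled = False

    def scale(self, loss):
        # Multiply the loss before backward() so small fp16 gradients don't underflow.
        return loss * self.scale_factor

    def unscale_(self, grads):
        # Divide the gradients by the scale factor in place, at most once per step.
        if not self._already_unscaled:
            grads[:] = [g / self.scale_factor for g in grads]
            self._already_unscaled = True

    def step(self, params, grads, lr):
        self.unscale_(grads)  # a no-op if unscale_() was already called
        params[:] = [p - lr * g for p, g in zip(params, grads)]
        self._already_unscaled = False  # the next iteration starts scaled again


def clip_by_norm_(grads, max_norm):
    # In-place norm clipping on a plain list of floats.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        grads[:] = [g / norm * max_norm for g in grads]


scaler = ToyGradScaler(scale=1024.0)
params = [1.0, 1.0]
grads = [2.0 * 1024.0, -4.0 * 1024.0]  # pretend backward() on the scaled loss gave these

scaler.unscale_(grads)     # grads are now [2.0, -4.0], in "real" units
clip_by_norm_(grads, 1.0)  # so the clipping threshold means what we intend
scaler.step(params, grads, lr=0.1)
print(params)
```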

Anyway, given that we know what the code looks like now, we need to implement it
in a way that can be easily switched on for this experiment (and potentially in
the future), but which also allows us to not use it if we don’t want to.

The best way with our setup is to make it a training option, so we can do it this way:

    with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
        logits = model(inputs)

        train_loss = calculate_loss(logits, targets)

    scaler.scale(train_loss).backward()

    if clipping_max_norm is not None:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_max_norm)

    scaler.step(optimizer)
    scaler.update()

…with clipping_max_norm read from the train.json file and passed in where we call train in
load_datasets_and_train:

    train(
        run_dir,
        ddp_model, optimizer, scaler,
        train_conf["clipping_max_norm"],
        train_ds,
        global_step, best_loss,
        checkpoint_interval=train_conf["checkpoint_interval"],
        do_checkpoints=True,
    )

…and we can just pass in None for it in our check_batch_size_works function that
we use to find the maximum micro-batch size for our current hardware, as all we’re
testing for there is memory usage — we don’t care if we’re doing good updates.

Here’s the code delta for that,
plus a bugfix to allow
for train.json files without a clipping_max_norm in them.
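The linked diff isn’t reproduced here, but one plausible way to tolerate a missing key (a sketch, not necessarily what that change actually does; the JSON contents are hypothetical) is to read the option with dict.get, so that an absent clipping_max_norm becomes None and the clipping branch in the training loop above is skipped.

```python
import json

# Two hypothetical train.json contents: one with the new option, one from before it existed.
with_clipping = json.loads('{"checkpoint_interval": 617, "clipping_max_norm": 3.5}')
without_clipping = json.loads('{"checkpoint_interval": 617}')

# dict.get returns None for a missing key instead of raising KeyError,
# which matches the "is not None" check in the training loop.
print(with_clipping.get("clipping_max_norm"))
print(without_clipping.get("clipping_max_norm"))
```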

But it would also be useful to be able to track when it “fired” — that is, when we
had to clip our gradients. Then we can see two things:

  1. Whether we actually did wind up clipping them and fixing those loss spikes
  2. Whether we were clipping at other times — we don’t want to be doing it unnecessarily.

Now, the docs for clip_grad_norm_
say that it returns the “[t]otal norm of the parameter gradients (viewed as a single vector)”.
They don’t say whether that’s before or after the clipping, but given that the return value could
never exceed clipping_max_norm if it was after, I’m going to guess that it returns
the pre-clipping norm (ChatGPT agrees).
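That guess is easy to check directly. In this minimal sketch (assuming a reasonably recent PyTorch; the single hand-made parameter and its gradient are invented for the test), a parameter gets a gradient of norm 5, is clipped to 1, and the return value is compared with the gradient’s norm afterwards.

```python
import torch

# One hand-made parameter with a hand-set gradient whose norm is 5.
param = torch.nn.Parameter(torch.zeros(2))
param.grad = torch.tensor([3.0, 4.0])

returned = torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)

print(returned.item())                              # norm before clipping
print(torch.linalg.vector_norm(param.grad).item())  # norm after clipping
```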

So we can chart that; the changes are in these diffs: 1, 2, 3, 4.
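The charting code isn’t shown here, but as a sketch of the kind of per-checkpoint summary involved (the helper is hypothetical; its keys are chosen to match the checkpoint metadata shown further down), the pre-clipping norm returned at each step can be collected and reduced over a checkpoint period like this.

```python
import math


def summarise_grad_norms(step_norms, clipping_max_norm):
    """Hypothetical helper: reduce one checkpoint period's pre-clipping norms."""
    count = len(step_norms)
    return {
        "max_grad_norms": max(step_norms),
        "avg_grad_norms": sum(step_norms) / count,
        # Fraction of steps whose norm exceeded the clipping threshold.
        "frac_clipped": sum(1 for n in step_norms if n > clipping_max_norm) / count,
    }


# 617 steps, all comfortably under the threshold except a single infinite one.
norms = [1.5] * 616 + [math.inf]
summary = summarise_grad_norms(norms, clipping_max_norm=3.5)
print(summary)
```

One bad step is enough to make both the maximum and the average infinite, and with 617 steps per period a single over-threshold step gives a frac_clipped of 1/617, about 0.00162.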

How much?

So we now have code to clip gradients to a given norm size and to chart the gradient
norms so that we know what they were before clipping. The question is, what
should that clipping norm be? Some googling around suggested that there was no standard way
of saying “for such-and-such a kind of model, gradients should be clipped at around
x”. For example, on this Reddit thread,
GLVic says “Common values are 1, 3, 5, 8, 10”, while the sample code in
this tutorial uses 1, as does this one.

So my initial thought was, let’s just use 1. But then I wondered, what actually are
the gradient norms that we’re getting in normal training? I decided to run a local short
train on 3m tokens (a thousandth of the full training set, taking just less than four minutes) with very frequent checkpointing, and
gradient clipping set to 1, and
see what happened.

Small local train, gradient clipping at 1

You can see that the “grad max” line is almost always above the “grad clip” line, so we’re
almost always clipping. This doesn’t sound right. It looked like the grad max
was generally between 1.1 and a little above 3, so I set clipping_max_norm to 3.5 and
did another train:

Small local train, gradient clipping at 3.5

Our loss is about the same, but we’re no longer clipping — and that’s what we want;
there was no evidence of exploding gradients for that short run — just big updates
near the start, as you’d expect.

I then ran the same with no gradient clipping at all, and got exactly the same shape
for the loss chart as I did with gradient clipping at 3.5, and the same final loss — that’s a good signal that clipping is
not affecting the train when we stay inside the limit, which is exactly what we want.

So, it was time to train our model!

Running the train

I kicked off the train, and after a little while, I looked at the training chart,
which is updated dynamically as the model trains:

First run of cloud train, with missing max gradients

You can see the dotted green lines, both the light one and the dark one — that is,
the “grad max” and the “grad avg” — disappear starting just before global step
4,000, only coming back at about 5,500 — that is, these were not plotted for
global steps 4,319 and 4,936, even though the loss was. What was going on?

I took a look at the checkpoint meta file for the first of those to see what the actual numbers
were, and saw this:

    {
        "min_train_loss": 3.7176883220672607,
        "max_train_loss": 5.877607822418213,
        "avg_train_loss": 4.3170230991450085,
        "max_grad_norms": Infinity,
        "avg_grad_norms": Infinity,
        "frac_clipped": 0.0016207455429497568,
        "global_step": 4319,
        "is_best": true
    }

Aha! The PyPlot code I was using could not handle infinite values, which is entirely
reasonable. That was easy enough to fix,
though — I just replaced positive infinity by 1,000,000 and negative infinity by -1,000,000,
and then (in the interest of getting a proper from-scratch run) kicked everything
off from the beginning.
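Python’s json module accepts the non-standard Infinity literal that appears in that metadata and parses it into float("inf"), so a sketch of that kind of fix (the function name and default cap are illustrative; the actual change may differ) just swaps infinities for ±1,000,000 before handing values to the plotting code.

```python
import json
import math


def cap_infinities(value, cap=1_000_000.0):
    # Replace +/-inf with a large finite stand-in that plotting code can draw.
    if value == math.inf:
        return cap
    if value == -math.inf:
        return -cap
    return value


meta = json.loads('{"max_grad_norms": Infinity, "avg_grad_norms": Infinity, "global_step": 4319}')
plottable = {key: cap_infinities(value) for key, value in meta.items()}
print(plottable)
```

NaN values would still get through this version, since NaN compares unequal to everything, including itself.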

That training run completed with this chart:

Second cloud run, showing clipping periods

That’s a little hard to read, but if you look closely at the green lines, you
can see that there are seven periods where gradients were either very large or
infinite. Strangely, though, two of the seven were two checkpoint periods long
(that is, two periods of 617 global steps each). That seemed odd at first, though of course
we’re looking at the maximum gradient norm and the average gradient norm, so
two single infinite or high-gradient steps in successive 617-step periods would produce that effect.

What was even stranger, though,
was that if you look at the training chart for the run with no gradient clipping,
we have only three loss spikes rather than seven:

Baseline training run on an 8x A100 with 40 GiB/GPU

…though it’s also very noticeable that the gradient-clipped run had only two small loss
spikes, unlike the three larger ones in the unclipped run.

The training loss the gradient-clipped run reported at the end was better, too:

    Training complete in 12,343.442 seconds
    Tokens seen: 3,260,252,160
    Throughput: 264,128 tokens/second
    Final train loss: 3.728

…versus 3.743 at the end of the baseline train. Training took only slightly longer than
the baseline’s 12,243.523 seconds (about 100 seconds more), which may well
be just noise.

So it was time to download it, upload it to Hugging Face Hub,
and run the sequence-completion smoke test:

    giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ uv run test_smoke.py runs/8xa100m40-gradient-clipping/model.json runs/8xa100m40-gradient-clipping/checkpoints/best/model.safetensors
    Every effort moves you further afield the most to get more time and space out of your pocket. With more in abundance

Coherent enough!

Next, we evaluate it against our held-back test set:

    giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ uv run test_loss.py datasets/ runs/8xa100m40-gradient-clipping/model.json runs/8xa100m40-gradient-clipping/checkpoints/best/model.safetensors
    Fetching 4 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1471.81it/s]
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3200/3200 [04:58<00:00, 10.72it/s]
    Loss against our test dataset: 3.678

So, the loss had gone down — but only from 3.692 to 3.678, a reduction of 0.014,
or about 0.4%.

That’s not actually all that bad, though!
After all, in my initial experiments on my local machine, training for a Chinchilla-optimal
number of tokens from FineWeb-Edu (rather than the regular FineWeb I’m using now)
got a loss of 4.167 on the same dataset (weirdly worse with the more-curated training set),
and training for a further Chinchilla-optimal number of tokens only brought that down to
4.135, for a difference of 0.032, or about 0.8%. That was the total effect of doubling
the training time.

It’s not really comparable due to the different training sets, but speaking very
loosely, we could say that adding gradient clipping for this train had almost half as much effect as doubling the training time for
the other one, with a negligible increase (about a minute or two) in training time. That’s pretty nifty.
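For anyone who wants to check that back-of-the-envelope comparison, the arithmetic is just relative reductions in test loss, using the numbers quoted above.

```python
def relative_reduction(before, after):
    # Fractional improvement in loss.
    return (before - after) / before


clipping = relative_reduction(3.692, 3.678)  # baseline -> gradient clipping
doubling = relative_reduction(4.167, 4.135)  # one -> two Chinchilla-optimal runs

print(f"{clipping:.2%}, {doubling:.2%}, ratio {clipping / doubling:.2f}")
```

That comes out at roughly 0.38% versus 0.77%, a ratio of about 0.49 (or about 0.44 if you compare the absolute reductions of 0.014 and 0.032 instead).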

But the question remained: why those long periods of high gradients, even with gradient
clipping? And why were there still loss spikes — in particular the one just before
global step 12,000, which lasted for two checkpoint periods?

Chasing infinity

Remember that when I started the first run of this train, and got the chart with
the missing bits, it was because the logged max_grad_norms and avg_grad_norms
were infinite.

What happens when clip_grad_norm_ gets an infinite gradient — either one that has
an infinity as one of its components, or one that (due to numerical overflow) winds up
with a norm of infinity anyway? I’d been kind of assuming that it did what the authors
described in “Deep Learning” — a random update of norm v — given that the book
stated pretty confidently that you “can” do it but then appeared to consider the topic closed.

But it doesn’t! If you check that link to the docs, you’ll see that it has a parameter
error_if_nonfinite, which is False by default. If it’s set to True, that will
raise an exception if the norm is positive or negative infinity, or if it’s not a number
— which catches both the infinite component and the norm overflow cases above. But if
it’s not set — and we weren’t setting it — and the norm or the gradients are non-finite, then clip_grad_norm_ will essentially
return garbage gradients. Depending on the exact cause, elements will either be infinities
of one sign or another, or NaNs. And if these are added to parameters, then those
parameters will become garbage too.
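A small sketch shows both behaviours, assuming a PyTorch version recent enough to have error_if_nonfinite (the parameter and its infinite gradient component are made up for the test). The exact values left in the gradient by the default path have varied between PyTorch versions, so the checks only confirm that they are no longer all finite.

```python
import torch


def make_param_with_inf_grad():
    # A hand-made parameter whose gradient contains one infinite component.
    p = torch.nn.Parameter(torch.zeros(2))
    p.grad = torch.tensor([float("inf"), 1.0])
    return p


# Default behaviour: no error, an infinite norm comes back, and the gradient is still garbage.
p = make_param_with_inf_grad()
norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(norm.item(), p.grad)

# Opting in: the same situation raises instead of silently continuing.
q = make_param_with_inf_grad()
try:
    torch.nn.utils.clip_grad_norm_([q], max_norm=1.0, error_if_nonfinite=True)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```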

Now, that leads to a question. Given that we know there was an infinite norm somewhere
between the checkpoint at global step 4,319 and the previous one at 3,702, how on earth did the model manage to continue training
after that? Loss went up at around the same time, but the model wasn’t completely broken, as
it would have been with NaNs or infinities in its parameters.

Obscurely enough, the answer turned out to be in the AMP explainer,
in a comment in one of the bits of example code. Regarding the GradScaler class we’re using:

        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.

So what was happening was that the scaler — something we introduced into our code
to get a speedup by using 16-bit floats instead of 32-bit whenever PyTorch thought
it would make sense — was protecting us against infinite and NaN gradients as a
side-effect. It was skipping updates that would have polluted our weights with
bad values from non-finite gradients.

Hmph.
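The effect of that skipping can be sketched with a toy one-parameter loop in plain Python (made-up gradients, not PyTorch’s implementation). With the check, the one bad step is dropped and training carries on; without it, a single non-finite gradient poisons the parameter for good.

```python
import math


def train(grad_sequence, lr=0.1, skip_non_finite=True):
    """Toy one-parameter training loop; returns the final parameter and skip count."""
    param, skipped = 1.0, 0
    for grad in grad_sequence:
        if skip_non_finite and not math.isfinite(grad):
            skipped += 1  # like GradScaler: don't call optimizer.step() this time
            continue
        param -= lr * grad
    return param, skipped


# Made-up gradients: ordinary values with one non-finite step in the middle.
grads = [0.5, 0.4, math.nan, 0.3, 0.2]

print(train(grads, skip_non_finite=True))   # finite parameter, one skipped step
print(train(grads, skip_non_finite=False))  # parameter is NaN from then on
```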

Grumble

If the above comes across as a little frustrated, then it’s because I am a bit!
From a software engineering viewpoint, this situation really does feel a bit like a
rather messy part of the API.

There are three things that it’s reasonable for a library to do with infinite/NaN
gradients:

  1. Blindly apply them and expect the developer to sanitise their inputs.
  2. Raise an error.
  3. Take some kind of default sane action, like skipping the update.

Now, if we look at that error_if_nonfinite, we can see that the first two of those
cases are handled there, and the developer can choose which option to follow.
It’s not where I’d personally put it (the step function on the optimiser seems more
natural) and I think I’d probably set the default to True too, but I can also imagine
good reasons for it being the way it is — backward compatibility for one.

But having “skip non-finite gradients” as a (not even optional!) behaviour of
a class designed for handling mixed-precision training just seems outright bonkers.
I would be surprised if there weren’t people out there who’ve spent days trying
to work out why their training runs failed catastrophically when they decided to
switch from mixed-precision to “full fat” 32-bit floats, not realising that a
hardly-even-documented feature of the scaler had been saving them from gradient issues
previously.

Anyway, rant over. What does this all mean?

So…?

There are three ways a gradient can explode:

  1. It can get very large, still be finite, and have a finite norm.
  2. It can get very large, still be finite, but have an infinite norm (e.g. due to numerical overflow).
  3. It can become infinite — that is, at least one of the parameters’ gradients is infinite (which
    of course means an infinite norm regardless of any numerical stuff).
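Telling those three cases apart is straightforward, at least in a toy setting. Here is a sketch in plain Python; the function and category names are made up here, and it lumps NaN in with infinities since both are non-finite. Python floats are 64-bit, so the middle case needs absurdly large values to overflow when squared; in 16-bit floats, whose largest finite value is 65,504, much smaller values would overflow.

```python
import math


def classify_gradient(grads, clipping_max_norm):
    """Hypothetical helper: which kind of (possibly exploding) gradient is this?"""
    if not all(math.isfinite(g) for g in grads):
        return "non-finite gradient"  # case 3 (inf, or NaN)
    # Multiplying (rather than using ** 2, which raises OverflowError on floats)
    # lets float overflow quietly produce inf.
    norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(norm):
        return "finite gradient, non-finite norm"  # case 2
    if norm > clipping_max_norm:
        return "large but finite"  # case 1
    return "normal"


for grads in ([0.1, 0.2], [30.0, 40.0], [1e200, 1e200], [math.inf, 1.0]):
    print(grads, "->", classify_gradient(grads, clipping_max_norm=3.5))
```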

With both the baseline code and our new code, the GradScaler was saving us from
the last two of those, by skipping the optimiser steps with non-finite gradients.

However, the baseline run was not protected against the first kind — large but finite
gradients with a finite norm — while this run was protected.

What I’m almost certain is happening here is that in all of my training runs so
far, there have been all three kinds of issues with exploding gradients. The
GradScaler, which again, we introduced for faster training, happened to be saving
us from the infinite gradients/norms. But we were still being bitten by the finite
but excessively large ones.

And that, I think, is why this training run had a positive — not huge, but certainly worthwhile
— effect on the test set loss.

If I had more time, I think I’d do another run, logging all three of those categories
of error to see how frequent they are, and charting the result. That might go some way to
explaining the final question I had here: why is it that the renowned “Deep Learning”
suggests a random update to get away from the cliff where you’ve found yourself,
while we seem to be getting away with just skipping the update, which is much simpler?
Well, the book was written in 2016, and I guess rather a lot has changed in the last 10 years 🙂

My guess is that their solution might have been
a solid default in the age of RNNs, but might not make so much sense with the kind of models
we’re training these days.

I think I can see a way in which that makes sense. Think of the illustration of a loss “cliff”
in a one-parameter world that we had at the start of this post:

A more complex loss chart

If you happen to wind up on that cliff, you’re in trouble.

But imagine a two-parameter model — the line of the loss function becomes a surface.
Just as in the real world you might be able to walk along the edge at the top of a cliff and
find a nice easy slope down next to it, you can imagine that the cliff in the two-parameter
case might be less of a problem because you don’t need to be lucky enough to jump down it —
you can walk around it.

Extrapolating examples like this to higher dimensions is
risky, but I think it should hold that the more dimensions you’re working with,
the less likely it is that a cliff is an issue — you’re more likely to be able to find
a way around it. I’ve heard a very similar argument made for why local minima are
less of an issue with lots of parameters. It’s certainly worth saying that this is
far from a mathematical proof, but I think it’s a decent grounding for intuition.

Now think about an RNN. Although you’re doing back-propagation through time over
what amounts to a very deep network, there aren’t actually all that many parameters,
certainly compared to an LLM like this. Each parameter is involved in the back-propagation
multiple times.

So, thinking of it that way, the gradient vector for the RNNs they were dealing with
was of much lower dimensionality than the ones we’re dealing with, even for this
tiny model.

They say that the random step “will typically move away from the numerically unstable
configuration”. I’m probably playing fast and loose here, but I’ll take that as something
like: if you wound up on a cliff, you were likely in a very “cliffy”
area of the loss landscape. “Teleporting” randomly to somewhere some distance away
was a sensible way to handle that.

In our situation, even if the area is “cliffy” in the direction that one particular
batch might push us, we have so many extra dimensions that it may well be that it
won’t be so bad with the next one. So just skipping the problematic update — under
all of those assumptions — seems a perfectly reasonable way to handle it.

Validation

All of this, BTW, made me think back to validation loss. In our previous training runs,
where we were measuring it just before each checkpoint, its spikes were in general correlated
with but not identical to spikes in training loss:

Loss in a run with validation

Now, of course, exploding gradients don’t have to be related to high training loss —
there’s enough non-linearity in there that we can treat them as being completely uncorrelated,
I think. But you definitely would expect them to have an effect on validation
loss if applied. Disregarding the infinite ones (which were being filtered out anyway),
the very high ones that we are now clipping would, in the unclipped baseline
train, seem very likely to have caused validation loss spikes.

So: if I hadn’t stripped that out, we would likely have been able to see a clear
difference in the validation loss line between clipped and unclipped. That would have
been useful!

I’m not going to re-introduce it, though. Best to keep the number of code changes
to a minimum if I’m trying to compare like with like over the course of these intervention
tests.

Anyway.

I think that’s enough for gradient clipping. I may come back and do the experiment
another time to see what the relative ratios of the different kinds of problematic
gradients are. Are there parts of the train where we get lots of them as a percentage (i.e.
we’re somewhere “cliffy” in the loss landscape)? How many infinite gradient vs infinite norm
vs big-but-not-infinite instances do we have relative to each other, and to normal
gradient updates? What do we see if we have validation loss? And so on.

But for now: gradient clipping definitely helps, and goes on the positive interventions list!

I’m thinking I’ll see what happens with switching off dropout next. That should at
least be a bit easier…

Stay tuned!

Here’s a link to the next post in this series.

[Update: an earlier version of this post used the wrong number for the baseline
model’s loss on the test set, which made it look like there was a 0.065 improvement
in test loss from adding gradient clipping rather than 0.014. The above has been updated
to reflect the real numbers, and the conclusion, at least, remains the same: gradient
clipping is good but not amazing.]
