import torch
w = torch.randn(3)
x = torch.randn(6, 3)

def forward(x, weights):
    return x @ weights
Gradient accumulation
Gradient accumulation is one technique for training bigger models with less GPU memory. Gradients computed over several mini-batches are accumulated before a single weight update is performed, which simulates a larger batch size without its memory requirements.
Given a model and a loss function, the forward pass computes the final loss value as the average of the loss values for all samples in a batch. In the backward pass, gradients are computed for each input and added together. Below is a diagram showing forward and backward passes for a batch of 6 items:
Computation without gradient accumulation
Let’s create a simple toy example with a model that performs only a dot product and an identity loss function:
weights = w.clone()
weights.requires_grad_(True)
# forward
batch = x
result = forward(batch, weights).mean()
# backward
result.backward()
weights.grad
tensor([-0.3536, 0.0207, -0.6125])
Gradient accumulation without normalization
If we split the computation into 3 batches of 2 items and simply accumulate the gradients, the result differs from the full-batch gradient.
The diagram below shows this scenario:
The accumulated gradient is three times the full-batch gradient:
weights = w.clone()
weights.requires_grad_(True)
batch_size = 2
# each backward() call adds to weights.grad
for batch_n in range(0, len(x), batch_size):
    batch = x[batch_n:batch_n + batch_size]
    result = forward(batch, weights).mean()
    result.backward()
weights.grad
tensor([-1.0608, 0.0620, -1.8375])
Gradient accumulation with proper normalization
The difference comes from the normalization factors in the loss function. When we average the loss values of each batch in isolation, we don’t take into account the total number of items we accumulate gradients for. To correct the computation, we divide each batch's loss by a normalization factor: the ratio between the total number of items and the batch size. In our example it equals \(6/2 = 3\).
The proper normalization produces the correct gradient vector:
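The factor can be derived in a few lines. The notation here is an assumption introduced only for this sketch: \(N\) is the total number of items, \(B\) the batch size (with \(N\) divisible by \(B\)), \(\ell_i\) the loss of item \(i\), and \(\mathcal{B}_k\) the set of items in mini-batch \(k\).

\[
L = \frac{1}{N}\sum_{i=1}^{N} \ell_i,
\qquad
L_k = \frac{1}{B}\sum_{i \in \mathcal{B}_k} \ell_i
\]

\[
\sum_{k=1}^{N/B} \nabla_w L_k
= \frac{1}{B}\sum_{i=1}^{N} \nabla_w \ell_i
= \frac{N}{B}\,\nabla_w L
\qquad\Longrightarrow\qquad
\sum_{k=1}^{N/B} \nabla_w \frac{L_k}{N/B} = \nabla_w L
\]

With \(N = 6\) and \(B = 2\), the unnormalized sum is \(3\,\nabla_w L\), which matches the threefold difference between the two gradient vectors above.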
weights = w.clone()
weights.requires_grad_(True)
batch_size = 2
for batch_n in range(0, len(x), batch_size):
    batch = x[batch_n:batch_n + batch_size]
    # divide by the number of batches: 6 / 2 = 3
    result = forward(batch, weights).mean() / 3
    result.backward()
weights.grad
tensor([-0.3536, 0.0207, -0.6125])
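The same check can be written without hard-coding the factor 3. The sketch below is a minimal illustration: the helper names `full_batch_grad` and `accumulated_grad` and the fixed seed are assumptions made for this example, and it assumes the number of items is divisible by the batch size.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

w = torch.randn(3)
x = torch.randn(6, 3)

def forward(x, weights):
    return x @ weights

def full_batch_grad(x, w):
    # Gradient of the mean loss over all items at once.
    weights = w.clone().requires_grad_(True)
    forward(x, weights).mean().backward()
    return weights.grad

def accumulated_grad(x, w, batch_size):
    # Gradient accumulated over mini-batches, with each batch loss
    # divided by the number of batches (total items / batch size).
    weights = w.clone().requires_grad_(True)
    num_batches = len(x) // batch_size  # assumes an even split
    for start in range(0, len(x), batch_size):
        batch = x[start:start + batch_size]
        (forward(batch, weights).mean() / num_batches).backward()
    return weights.grad

reference = full_batch_grad(x, w)
accumulated = accumulated_grad(x, w, batch_size=2)
matches = torch.allclose(reference, accumulated, atol=1e-6)
print(matches)
```

Comparing with `torch.allclose` rather than exact equality allows for small floating-point differences, since the two paths sum the same terms in a different order.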
A practical implementation of this technique is well explained in the fast.ai course lecture.
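For orientation, here is a minimal sketch of how the same idea might look in an ordinary PyTorch training loop. It is not the fast.ai implementation; the toy data, the `nn.Linear` model, the learning rate, and the name `accumulation_steps` are all assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy data and model, used only for illustration.
inputs = torch.randn(12, 3)
targets = torch.randn(12, 1)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()  # averages over the items in each mini-batch

batch_size = 2
accumulation_steps = 3  # effective batch size = 2 * 3 = 6
updates = 0

optimizer.zero_grad()
for step, start in enumerate(range(0, len(inputs), batch_size), start=1):
    xb = inputs[start:start + batch_size]
    yb = targets[start:start + batch_size]
    # Scale the loss so the accumulated gradient matches one large batch.
    loss = loss_fn(model(xb), yb) / accumulation_steps
    loss.backward()  # gradients are added to each parameter's .grad
    if step % accumulation_steps == 0:
        optimizer.step()       # one weight update per accumulated group
        optimizer.zero_grad()  # reset before the next group
        updates += 1

print(updates)
```

Only one mini-batch of activations is held in memory at a time, while the optimizer sees gradients equivalent to a batch of 6 items.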