Gradient checkpointing is a technique that trades extra compute for reduced VRAM usage during backpropagation. Normally the forward pass keeps the activations of every layer in VRAM, because the backward pass needs them to compute gradients. We can cut VRAM usage by discarding these activations and recomputing them during the backward pass, when they are actually needed. A middle ground between recomputing everything and keeping everything in VRAM is to keep only certain checkpoints in VRAM and recompute the rest from the nearest checkpoint. The linked repo has a great animation showing the whole process. PyTorch implements this as activation checkpointing (which is a more reasonable name). In their blog they also mention that they offer an automatic Pareto-optimal tradeoff for a user-specified memory limit! (although the config seems to have a different name in the code than mentioned in the blog)
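
A minimal sketch of how this looks with `torch.utils.checkpoint.checkpoint` (the toy MLP, its dimensions, and the class name are my own illustration, not from the linked post or blog):

```python
# Sketch: wrapping each block in checkpoint() means its intermediate
# activations are not stored during the forward pass; they are recomputed
# during backward, trading compute for lower peak VRAM.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):  # hypothetical toy model for illustration
    def __init__(self, dim: int = 1024, depth: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(depth)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Only the block's input is saved; its activations are rebuilt in backward.
            x = checkpoint(block, x, use_reentrant=False)
        return x


model = CheckpointedMLP()
x = torch.randn(32, 1024, requires_grad=True)
model(x).sum().backward()  # blocks are re-run here to recover discarded activations
```

Checkpointing every block is the extreme end of the tradeoff; in practice you would wrap only the layers whose activations dominate memory, which is exactly the "keep some checkpoints" middle ground described above.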