
Gradient Checkpointing is a technique that trades speed for reduced VRAM usage during backprop. Normally, the forward activations of every layer that has not yet been backpropagated through are kept in VRAM, since they are needed in later steps of backpropagation. We can reduce VRAM usage by discarding these activations and recomputing them later, when they are actually required. A middle ground between recomputing everything and keeping everything in VRAM is keeping only certain checkpoints. The linked repo has a great animation showing the whole process. PyTorch implements this as activation checkpointing (a more descriptive name). In their blog they also mention offering an automatic Pareto-optimal tradeoff for a user-specified memory limit! (although the config seems to have a different name in the code than mentioned in the blog)
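
For concreteness, here is a minimal sketch of what this looks like with PyTorch's `torch.utils.checkpoint` API (the `Block` and `CheckpointedStack` modules are made-up stand-ins for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class CheckpointedStack(nn.Module):
    def __init__(self, n_blocks=12, dim=1024):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(n_blocks))

    def forward(self, x):
        for block in self.blocks:
            # Only the block's input is stored; the activations inside the
            # block are recomputed during backprop instead of staying in VRAM.
            x = checkpoint(block, x, use_reentrant=False)
        return x

model = CheckpointedStack()
loss = model(torch.randn(8, 1024)).sum()
loss.backward()  # each block's forward is re-run here to recover its activations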

With Fast Forward Computer Vision (ffcv) you can train a CIFAR-10 classifier on an H100 in ~14 seconds. Their CIFAR-10 example reports:

92.6% accuracy in 36 seconds on a single NVIDIA A100 GPU.

ffcv achieves this by speeding up data loading with various techniques, so you can reuse most of your training code and only replace the data loading, as this example from the quickstart shows:

from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder

# Random resized crop
decoder = RandomResizedCropRGBImageDecoder((224, 224))

# Data decoding and augmentation
image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage(), ToDevice(0)]
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(0)]

# Pipeline for each data field
pipelines = {
    'image': image_pipeline,
    'label': label_pipeline
}

# Replaces PyTorch data loader (`torch.utils.data.DataLoader`)
loader = Loader(write_path, batch_size=bs, num_workers=num_workers,
                order=OrderOption.RANDOM, pipelines=pipelines)

# rest of training / validation proceeds identically
for epoch in range(epochs):
    ...
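
The `write_path` above points to a dataset that has already been converted into ffcv's own file format. As a rough sketch of that one-time conversion step (based on ffcv's `DatasetWriter` API; `my_dataset` stands in for any indexed `torch.utils.data.Dataset` of (image, label) pairs):

from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField

my_dataset = ...   # e.g. a torchvision.datasets.CIFAR10 instance
write_path = ...   # output path for the converted dataset file

# Declare a field type for each element of a sample
writer = DatasetWriter(write_path, {
    'image': RGBImageField(max_resolution=256),
    'label': IntField()
})

# One-time conversion; the Loader above then reads this file
writer.from_indexed_dataset(my_dataset)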