Introducing PyTorch Lightning Sharded: Train SOTA Models, With Half The Memory

Lightning 1.1 introduces Sharded Training: train deep learning models on multiple GPUs while saving over 50% of memory, with no loss in performance and no code changes required!

Larger Model, Better Accuracy

Comparing language model parameter counts over time; GPT-3 surpasses earlier models by orders of magnitude. Image by Microsoft

Sharded Training Powered by Lightning

Traditional distributed training vs. Sharded Training. Parameters (P) are split across GPUs to reduce the memory overhead per GPU; in Sharded Training we split the optimizer states and gradients. Image by author
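
To get a feel for the idea, here is a toy, single-process sketch of optimizer-state sharding. It is not the FairScale/Lightning implementation: the `shard_parameters` helper is hypothetical and exists purely for illustration. Parameters are greedily assigned to ranks so that each rank builds and stores optimizer state for only its own slice.

```python
import torch

def shard_parameters(params, world_size):
    """Hypothetical helper: greedily assign parameters to ranks to balance element counts."""
    shards = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for p in sorted(params, key=lambda p: p.numel(), reverse=True):
        rank = sizes.index(min(sizes))  # give the largest remaining tensor to the emptiest rank
        shards[rank].append(p)
        sizes[rank] += p.numel()
    return shards

model = torch.nn.Sequential(
    torch.nn.Linear(512, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10)
)
shards = shard_parameters(list(model.parameters()), world_size=4)

# Each rank builds an optimizer over its own shard only, so Adam's momentum and
# variance buffers exist once across the cluster instead of once per GPU.
optimizers = [torch.optim.Adam(shard, lr=1e-3) for shard in shards if shard]

for rank, shard in enumerate(shards):
    print(f"rank {rank}: {sum(p.numel() for p in shard):,} parameter elements")
```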

Enable Sharded Training with no code changes
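
As a sketch of what this looks like in practice, assuming Lightning 1.1's `plugins="ddp_sharded"` Trainer flag and a hypothetical `MyLightningModule`/`MyDataModule` from your own project, the only difference from a standard multi-GPU run is the `plugins` argument:

```python
import pytorch_lightning as pl

# Hypothetical user code: any existing LightningModule/DataModule works unchanged.
from my_project import MyLightningModule, MyDataModule

model = MyLightningModule()
datamodule = MyDataModule()

trainer = pl.Trainer(
    gpus=8,
    accelerator="ddp",      # standard multi-GPU data parallelism
    plugins="ddp_sharded",  # shard optimizer state and gradients across the 8 GPUs
    precision=16,           # optional: mixed precision combines with sharding
)
trainer.fit(model, datamodule=datamodule)
```

Sharded Training is built on top of FairScale, so it also expects `fairscale` to be installed alongside `pytorch-lightning>=1.1`.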

Average peak memory when training a Transformer LM (22 layers, hidden size 3072, trained on SST; the 2-billion-parameter variant uses 32 layers), SwAV Wide ResNet (trained on STL-10), DeepSpeech2 (trained on Librispeech100), and iGPT (trained on MNIST) on 8 A100s. Each model uses the same hyper-parameters and batch size, with capacity increased to roughly a billion parameters. Lower is better. Image by author
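
A rough back-of-the-envelope calculation shows where these savings come from. With Adam in fp32, every parameter carries four bytes of weight, four bytes of gradient, and eight bytes of optimizer state (momentum and variance); sharding the gradients and optimizer state across GPUs shrinks the replicated portion. The numbers below are illustrative only and ignore activations and framework overhead:

```python
params = 1e9              # a 1-billion-parameter model
bytes_per_value = 4       # fp32
n_gpus = 8

weights   = params * bytes_per_value        # model weights (replicated on every GPU)
grads     = params * bytes_per_value        # gradients
opt_state = params * bytes_per_value * 2    # Adam: momentum + variance buffers

ddp_per_gpu     = weights + grads + opt_state             # everything replicated
sharded_per_gpu = weights + (grads + opt_state) / n_gpus  # gradients + optimizer state split

print(f"DDP:     {ddp_per_gpu / 1e9:.1f} GB per GPU")     # ~16.0 GB
print(f"Sharded: {sharded_per_gpu / 1e9:.1f} GB per GPU") # ~5.5 GB
```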
Average epoch time on 8 A100s with the same hyper-parameters and batch sizes. Lower is better. Image by author

Use Sharded Training with Lightning today

Written by

Research Engineer at Grid AI | PyTorch Lightning
