Introducing PyTorch Lightning Sharded: Train SOTA Models, With Half The Memory

Lightning 1.1 introduces Sharded Training: train deep learning models on multiple GPUs while saving over 50% of memory, with no performance loss and no code changes required!


Larger Model, Better Accuracy

Recent advances in language modelling show a clear trend: larger pre-trained models perform better on downstream tasks. This was demonstrated most famously by OpenAI's GPT-3, their largest model at 175 billion parameters, which requires massive amounts of compute and optimization tricks to train.

Comparing language model parameter counts over time. GPT-3 surpasses earlier models by orders of magnitude. Image by Microsoft

Sharded Training Powered by Lightning

Traditional Distributed Training vs Sharded Training. Parameters (P) are split across GPUs to reduce memory overhead per GPU. For Sharded Training, we split the optimizer states and gradients. Image by author
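
Lightning's sharded plugin is built on FairScale's optimizer state sharding, which follows the ZeRO approach. As a rough sketch of the underlying idea, and not Lightning's exact integration, the raw FairScale setup looks roughly like the snippet below. It assumes a torch.distributed process group has already been initialized with one process per GPU, and the toy linear model is an illustrative stand-in.

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

# Assumes torch.distributed.init_process_group(...) has already been called
# with one process per GPU (Lightning handles this launch step for you).
model = torch.nn.Linear(1024, 1024).cuda()  # toy stand-in for a large model

# OSS shards the optimizer state (e.g. Adam's moment buffers) across ranks,
# so each GPU only stores its own slice instead of a full replica.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-4)

# ShardedDataParallel reduces each gradient only to the rank that owns the
# matching optimizer shard, trimming gradient memory as well.
model = ShardedDDP(model, optimizer)

# The training loop itself is unchanged from standard PyTorch.
loss = model(torch.randn(32, 1024).cuda()).sum()
loss.backward()
optimizer.step()
```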

Enable Sharded Training with no code changes

To demonstrate how easy it is to use Sharded Training in Lightning, we use NeMo, a popular NVIDIA library for training conversational AI models that is backed by Lightning. We'll train a vanilla Transformer LM provided in NeMo at a 1.2-billion-parameter scale, which has a high memory requirement. When training large language models, memory is a valuable resource for increasing model size or improving GPU saturation. To train the model we'll use the WikiText dataset.
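
Concretely, the switch is a single Trainer argument. Below is a minimal, self-contained sketch: the toy LightningModule and random dataset are illustrative stand-ins for the NeMo Transformer LM, while plugins='ddp_sharded' is the Lightning 1.1 flag that enables Sharded Training.

```python
import torch
import pytorch_lightning as pl

class ToyLM(pl.LightningModule):
    """Toy stand-in for a real model such as the NeMo Transformer LM."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1024, 1024)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()  # dummy loss for illustration

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

train_loader = torch.utils.data.DataLoader(torch.randn(64, 1024), batch_size=8)

trainer = pl.Trainer(
    gpus=8,                  # one process per GPU
    accelerator='ddp',       # standard distributed data parallel launcher
    precision=16,            # optional: mixed precision compounds the memory savings
    plugins='ddp_sharded',   # the single flag that enables Sharded Training
)
trainer.fit(ToyLM(), train_loader)
```

Everything else in the script stays exactly as it was; removing the plugins argument falls back to regular DDP training.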

Average peak memory when training a Transformer LM (22 layers, hidden size 3072, trained on SST; 2-billion-parameter variant with 32 layers), SwAV Wide ResNet (trained on STL-10), DeepSpeech2 (trained on Librispeech100), and iGPT (trained on MNIST) on 8 A100s, using the same hyper-parameters and batch size per model. We increase model capacity to roughly a billion parameters. Lower is better. Image by Author
Average epoch time on 8 A100s with the same hyper-parameters and batch sizes. Lower is better. Image by Author

Use Sharded Training with Lightning today

In this article we described how to use Sharded Training with PyTorch Lightning to reduce the memory requirements of your research with no code changes. We also showed that a wide range of models from a variety of domains can be trained with large memory savings, and no performance loss, by simply adding a single flag to your Lightning Trainer.

