Appendix B. Mixed Precision and Quantization
By default, PyTorch uses 32-bit floats to represent model parameters: that's 4 bytes per parameter. If your model has 1 billion parameters, then you need at least 4 GB of RAM just to hold the model. At inference time you also need enough RAM to store the activations, and at training time you need to keep all the intermediate activations as well (for the backward pass), plus the optimizer state (e.g., Adam keeps two extra 32-bit values per model parameter, which adds another 8 GB). That's a lot of RAM, and it also means a lot of time spent transferring data between the CPU and the GPU, not to mention storage space, download time, and energy consumption.
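To make these numbers concrete, here is a quick back-of-the-envelope sketch in PyTorch (the toy model is just an illustration; any nn.Module works the same way):

```python
import torch.nn as nn

# A toy model, purely for illustration of the arithmetic.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 10))

n_params = sum(p.numel() for p in model.parameters())
bytes_per_param = 4  # float32
print(f"{n_params:,} parameters ≈ {n_params * bytes_per_param / 2**20:.1f} MiB")

# Adam keeps two extra float32 values per parameter (running estimates of
# the gradients' mean and variance), roughly tripling the footprint:
print(f"Adam state adds ≈ {2 * n_params * bytes_per_param / 2**20:.1f} MiB")
```

For a 1-billion-parameter model, the same arithmetic gives the 4 GB of weights plus 8 GB of Adam state mentioned above.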
So how can we reduce the model's size? A simple option is to use a reduced-precision float representation: typically 16-bit floats instead of 32-bit floats. If you train a 32-bit model and then shrink it to 16 bits after training, its size will be halved, with little impact on its quality. Great!
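For example, converting a trained PyTorch model to 16-bit floats is a one-liner (this sketch assumes a trained `model` is already in memory):

```python
import torch

# Convert all floating-point parameters and buffers to float16.
model_fp16 = model.half()

# The saved state dict is roughly half the size of the 32-bit version.
torch.save(model_fp16.state_dict(), "model_fp16.pt")
```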
However, if you try to train the model using 16-bit floats, you may run into convergence issues, as we will see. So a common strategy is mixed-precision training (MPT): keep the weights and the weight updates at 32-bit precision during training, but run the rest of the computations at 16-bit precision. After training, shrink the weights down to 16 bits.
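Here is a minimal sketch of such a training loop using PyTorch's automatic mixed precision (AMP) API; it assumes that `model`, `optimizer`, and `train_loader` are already defined and that a CUDA device is available:

```python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in train_loader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    # Run the forward pass in 16-bit precision where it is safe to do so;
    # the master weights themselves stay at 32-bit precision.
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)
    # Scale the loss to keep small 16-bit gradients from underflowing to
    # zero, then unscale before the 32-bit weight update.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

The loss scaling step is what works around the convergence issues mentioned above: without it, many gradients are too small to be representable as 16-bit floats.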
Finally, to shrink the model even further, you can use quantization: the parameters are discretized and represented as integers (typically 8-bit) instead of floats, along with a scale factor used to map the integers back to approximate float values.
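For example, PyTorch offers post-training dynamic quantization, which stores the weights of selected layer types as 8-bit integers and quantizes the activations on the fly (this sketch assumes a trained `model` containing nn.Linear layers, running on the CPU):

```python
import torch
import torch.nn as nn

# Replace the model's Linear layers with dynamically quantized versions
# whose weights are stored as 8-bit signed integers.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
```

Compared to 16-bit floats, this halves the size of the quantized layers again, at the cost of some additional approximation error.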