17/06/2024
Quantizing a model is not enough: on its own, quantization won't save you much memory when training a model!
In a typical mixed-precision training loop, the model weights and the input tensors are stored in memory in Float16 or BFloat16. The forward pass only needs those. During the backward pass, we compute the gradients, again in Float16 or BFloat16.
Once we have the gradients, we can update the model parameters. During the optimization step, all the operations are done in Float32! Take the Adam optimizer, for example: we first convert the gradients from Float16 to Float32. From those gradients, we compute the momentum and the variance, which must be stored in memory in Float32. From the momentum and variance, we compute the updated weight values, also in Float32. Finally, we convert the updated model weights from Float32 back to Float16 for the next training iteration.
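To make the sequence above concrete, here is a minimal NumPy sketch of one mixed-precision Adam step. The function name `adam_step_mixed_precision` and the default hyperparameters (`lr`, `beta1`, `beta2`, `eps`) are illustrative assumptions, not any framework's API; the point is only to show which tensors live in Float16 and which live in Float32.

```python
import numpy as np

def adam_step_mixed_precision(w16, g16, w32, m32, v32, t,
                              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one Adam update with Float16 weights/gradients
    and Float32 optimizer math.

    w16, g16 : Float16 weights and gradients from the forward/backward pass
    w32      : Float32 master copy of the weights
    m32, v32 : Float32 momentum and variance (Adam's first and second moments)
    t        : 1-based step counter used for bias correction
    All arrays are updated in place.
    """
    # Upcast the Float16 gradients to Float32.
    g32 = g16.astype(np.float32)
    # Update momentum and variance in Float32.
    m32[:] = beta1 * m32 + (1 - beta1) * g32
    v32[:] = beta2 * v32 + (1 - beta2) * g32 * g32
    # Bias-corrected moments, then the Float32 weight update.
    m_hat = m32 / (1 - beta1 ** t)
    v_hat = v32 / (1 - beta2 ** t)
    w32[:] = w32 - lr * m_hat / (np.sqrt(v_hat) + eps)
    # Cast the updated weights back to Float16 for the next iteration.
    w16[:] = w32.astype(np.float16)

# Usage on a tiny random "layer" of 4 parameters.
rng = np.random.default_rng(0)
w32 = rng.standard_normal(4).astype(np.float32)   # Float32 master weights
w16 = w32.astype(np.float16)                      # Float16 working copy
g16 = rng.standard_normal(4).astype(np.float16)   # Float16 gradients
m32 = np.zeros(4, dtype=np.float32)               # Float32 momentum
v32 = np.zeros(4, dtype=np.float32)               # Float32 variance
w32_before = w32.copy()
adam_step_mixed_precision(w16, g16, w32, m32, v32, t=1)
```

In a real framework this bookkeeping happens internally; the sketch only mirrors the dtype conversions described in the text.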
So, in memory, during the optimization step, we need:
- The model parameters in Float16
- The gradients in Float16
- The gradients in Float32
- The momentum in Float32
- The variance in Float32
- The model parameters in Float32
Because Float32 takes twice as much memory as Float16, the four Float32 tensors (gradients, momentum, variance, and parameters) require 8x more memory than the Float16 model parameters themselves: 16 bytes versus 2 bytes per parameter.
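The 8x figure follows directly from per-parameter byte counts. This short Python sketch tallies the tensors listed above for a hypothetical 7-billion-parameter model; it ignores activations and temporary buffers, so real usage would be higher.

```python
# Bytes per parameter for each tensor held in memory during the optimizer step.
BYTES_PER_PARAM = {
    "params_fp16": 2,
    "grads_fp16": 2,
    "grads_fp32": 4,
    "momentum_fp32": 4,
    "variance_fp32": 4,
    "params_fp32": 4,
}

FP32_TENSORS = ("grads_fp32", "momentum_fp32", "variance_fp32", "params_fp32")

def memory_gb(num_params: int) -> dict:
    """Memory in GB (1e9 bytes) used by each tensor, activations excluded."""
    return {name: num_params * b / 1e9 for name, b in BYTES_PER_PARAM.items()}

# Float32 tensors relative to the Float16 parameters: 16 / 2.
ratio = sum(BYTES_PER_PARAM[k] for k in FP32_TENSORS) / BYTES_PER_PARAM["params_fp16"]
print(ratio)  # 8.0

mem = memory_gb(7_000_000_000)  # hypothetical 7B-parameter model
print(round(sum(mem.values()), 1))  # 140.0
```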
Even when we quantize the model parameters, the optimizer memory requirements stay the same. Let's say we quantize to 4-bit floating-point numbers. During the forward pass, the input tensors still come in BFloat16 precision, so we need to dequantize the model parameters to perform the computations. The same problem occurs during the backward pass, where we again need to dequantize the model parameters. And the optimizer computations still happen in Float32 precision, by converting the dequantized weights and gradients to Float32.
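The dequantization step can be illustrated with a simplified sketch. It assumes plain symmetric absmax quantization to 4-bit integer codes, a stand-in for the 4-bit floating-point formats used in practice, and uses Float16 in place of BFloat16 because NumPy has no BFloat16 type. The helper names `quantize_4bit` and `dequantize_4bit` are hypothetical.

```python
import numpy as np

def quantize_4bit(w):
    """Hypothetical helper: symmetric absmax quantization to 4-bit integer codes in [-7, 7].

    Codes are stored in an int8 array for simplicity; real kernels pack two
    4-bit codes into each byte.
    """
    scale = float(np.abs(w).max()) / 7.0 or 1.0   # fall back to 1.0 for an all-zero tensor
    codes = np.clip(np.round(w / scale), -7, 7).astype(np.int8)
    return codes, scale

def dequantize_4bit(codes, scale, dtype=np.float16):
    """Hypothetical helper: map the codes back to a 16-bit float tensor."""
    return (codes.astype(np.float32) * scale).astype(dtype)

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4)).astype(np.float32)   # original weights
x = rng.standard_normal((2, 8)).astype(np.float16)   # 16-bit input activations

codes, scale = quantize_4bit(w)

# Forward pass: the 4-bit codes cannot meet the 16-bit inputs directly,
# so the weights are dequantized first (and again in the backward pass).
y = x @ dequantize_4bit(codes, scale)
print(y.dtype)  # float16
```

Only the stored weights shrink here; the matrix multiplication still runs in 16-bit precision, and nothing about the Float32 optimizer state changes.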
QLoRA is a good solution in that situation: gradient updates only happen on the LoRA adapters, which keeps the optimizer state small and minimizes the optimizer memory spike, and its paged optimizers can move the optimizer state to CPU RAM when GPU memory runs short!
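A back-of-the-envelope comparison shows why restricting updates to the adapters helps. This sketch reuses the 16 bytes per trainable parameter from the optimizer-step list above and assumes a hypothetical 4096x4096 linear layer with LoRA rank 16, where the adapters are two low-rank matrices of shapes (r x d_in) and (d_out x r). The helper `adam_fp32_bytes` is illustrative only.

```python
def adam_fp32_bytes(num_trainable: int) -> int:
    """Float32 gradients + momentum + variance + master weights: 4 * 4 = 16 bytes per trainable parameter."""
    return 16 * num_trainable

d_in, d_out, r = 4096, 4096, 16   # hypothetical layer shape and LoRA rank

full_trainable = d_in * d_out          # full fine-tuning: every weight is updated
lora_trainable = r * (d_in + d_out)    # LoRA: only A (r x d_in) and B (d_out x r) are updated

MiB = 2**20
print(adam_fp32_bytes(full_trainable) / MiB)  # 256.0
print(adam_fp32_bytes(lora_trainable) / MiB)  # 2.0
```

For this layer, the optimizer memory drops by a factor of 128, and the frozen 4-bit base weights carry no optimizer state at all.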