Apéndice B. Precisión mixta y cuantización
Por defecto en, PyTorch utiliza floats de 32 bits para representar los parámetros del modelo: eso son 4 bytes por parámetro. Si tu modelo tiene 1.000 millones de parámetros, necesitarás al menos 4 GB de RAM sólo para contener el modelo. En el momento de la inferencia también necesitas suficiente RAM para almacenar las activaciones, y en el momento del entrenamiento necesitas suficiente RAM para almacenar también todas las activaciones intermedias (para el paso hacia atrás), y para almacenar los parámetros del optimizador (por ejemplo, Adam necesita dos parámetros adicionales por cada parámetro del modelo, lo que supone 8 GB extra). Esto es mucha RAM, y también mucho tiempo de transferencia de datos entre la CPU y la GPU, por no hablar del espacio de almacenamiento, el tiempo de descarga y el consumo de energía.
Entonces, ¿cómo podemos reducir el tamaño del modelo? Una opción sencilla es utilizar una representación flotante de precisión reducida, normalmente flotantes de 16 bits en lugar de flotantes de 32 bits. Si entrenas un modelo de 32 bits y luego lo reduces a 16 bits después del entrenamiento, su tamaño se reducirá a la mitad, con escaso impacto en su calidad. ¡Estupendo!
Sin embargo, si intentas entrenar el modelo utilizando valores flotantes de 16 bits, puedes encontrarte con problemas de convergencia, como veremos. Así que una estrategia habitual es el entrenamiento de precisión mixta (MPT), en el que mantenemos los pesos y las ...