Appendice B. Precisione mista e quantizzazione
Per impostazione predefinita su, PyTorch utilizza float a 32 bit per rappresentare i parametri del modello: si tratta di 4 byte per parametro. Se il tuo modello ha 1 miliardo di parametri, hai bisogno di almeno 4 GB di RAM solo per contenere il modello. Al momento dell'inferenza hai bisogno di una quantità di RAM sufficiente per memorizzare le attivazioni, mentre al momento dell'addestramento hai bisogno di una quantità di RAM sufficiente per memorizzare anche tutte le attivazioni intermedie (per il backward pass) e per memorizzare i parametri dell'ottimizzatore (ad esempio, Adam ha bisogno di due parametri aggiuntivi per ogni parametro del modello: si tratta di 8 GB in più). Si tratta di molta RAM e di molto tempo speso per trasferire i dati tra la CPU e la GPU, per non parlare dello spazio di archiviazione, del tempo di download e del consumo energetico.
Come possiamo quindi ridurre le dimensioni del modello? Un'opzione semplice è quella di utilizzare una rappresentazione dei dati a precisione ridotta, tipicamente a 16 bit anziché a 32 bit. Se addestri un modello a 32 bit e poi lo riduci a 16 bit dopo l'addestramento, le sue dimensioni si dimezzeranno, con un impatto minimo sulla sua qualità. Ottimo!
Tuttavia, se provi ad addestrare il modello utilizzando i float a 16 bit, potresti incorrere in problemi di convergenza, come vedremo. Una strategia comune è quindi l'addestramento a precisione mista (MPT), in cui durante l'addestramento ...