March 2026
Intermediate
534 pages
12h 51m
English
In this chapter, we will focus on JAX, a Python library designed for scientific computing, to solve optimization problems on the GPU. We'll use JAX to develop both a linear regression model and a neural network from scratch, and we'll construct a physics-informed neural network by incorporating physical laws into the training process. All of these tasks are made possible by JAX's automatic differentiation capabilities. We'll also showcase the benefits of JAX's automatic vectorization and Just-In-Time (JIT) compilation features. By the end of this chapter, you will have gained the knowledge and skills to build gradient-based solutions that can be applied to various machine ...
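As a taste of what the chapter covers, the sketch below combines the three JAX features mentioned above on a tiny linear regression problem: `jax.grad` for automatic differentiation, `jax.vmap` for automatic vectorization, and `jax.jit` for compilation. This is an illustrative sketch, not the book's own code; the function names (`predict`, `loss`, `step`) and the synthetic data are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

# Synthetic, noise-free data: y = 3x + 1.
xs = jnp.linspace(0.0, 1.0, 32)
ys = 3.0 * xs + 1.0

def predict(params, x):
    # Single-sample prediction for a line w*x + b.
    w, b = params
    return w * x + b

def loss(params, xs, ys):
    # vmap vectorizes the single-sample prediction over the batch axis.
    preds = jax.vmap(predict, in_axes=(None, 0))(params, xs)
    return jnp.mean((preds - ys) ** 2)

# grad differentiates the loss with respect to the parameters;
# jit compiles the whole update step for the GPU (or CPU).
@jax.jit
def step(params, xs, ys, lr=0.1):
    grads = jax.grad(loss)(params, xs, ys)
    return tuple(p - lr * g for p, g in zip(params, grads))

params = (0.0, 0.0)
for _ in range(1000):
    params = step(params, xs, ys)
# After training, params approaches the true values (w=3, b=1).
```

The same pattern — define a scalar loss, differentiate it, and JIT-compile the update — carries over directly to the neural-network and physics-informed examples later in the chapter; only the `predict` and `loss` functions change.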