What is Autograd?

Autograd is an automatic differentiation library for native Python and NumPy code. It handles a large subset of Python features, including loops, conditionals, recursion, and closures, and it can take derivatives of derivatives. Autograd supports reverse-mode differentiation (backpropagation), which efficiently computes gradients of scalar-valued functions with respect to array-valued arguments, as well as forward-mode differentiation. This makes it well suited to gradient-based optimization.

What is XLA?

Accelerated Linear Algebra (XLA) is a compiler for linear algebra that can accelerate TensorFlow models, often without any changes to the source code. It compiles a model's operations into kernels optimized for the target hardware, which can reduce both running time and memory use. Models that run through XLA can therefore see significant performance improvements.

What is Google JAX?

Google JAX (sometimes expanded as “Just After eXecution”) is a machine learning framework for transforming numerical functions, and it is becoming especially popular in research. At its core, JAX combines Autograd and XLA. As described above, both are powerful systems that let developers work with highly complex networks. JAX uses an updated version of Autograd that can differentiate through native Python features such as loops, conditionals, recursion, and closures, in both forward and reverse mode.

Similarly, XLA compiles and runs the code on accelerators such as Graphics Processing Units (GPUs) and Tensor Processing Units (TPUs). The developers can use a one-function API, jit(), to compile their Python functions into XLA-optimized kernels. As a result, performance-critical algorithms can be written in plain Python and still run with high performance.
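The one-function API mentioned above can be sketched as follows. This is a minimal illustration rather than a benchmark; the function name scaled_tanh and its constants are made up for the example.

```python
import jax
import jax.numpy as jnp

def scaled_tanh(x):
    # Hypothetical example function built from jax.numpy operations.
    return 2.0 * jnp.tanh(x) + 1.0

# Wrapping the function with jax.jit asks XLA to compile it on the first call.
fast_scaled_tanh = jax.jit(scaled_tanh)

x = jnp.linspace(-1.0, 1.0, 5)
eager_result = scaled_tanh(x)          # runs operation by operation
compiled_result = fast_scaled_tanh(x)  # runs the XLA-compiled version

# Both paths should agree up to floating-point tolerance.
print(jnp.allclose(eager_result, compiled_result))
```

The compiled function is called exactly like the original one, so jit can usually be added or removed without touching the rest of the program.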

Features of Google JAX

JAX provides the developers with a JIT (Just In Time) compiler that is designed to make good use of the hardware's floating point operations per second (FLOPS). Following are some key features of Google JAX:

  • JAX provides a Just In Time (JIT) compiler, which compiles selected functions at runtime so that they execute faster.
  • The developers can run NumPy-style code on the CPU, GPU, and TPU.
  • JAX provides an explicit, key-based approach to pseudo-random number generation.
  • JAX helps reduce the runtime overhead of workloads such as differentially private stochastic gradient descent by combining automatic vectorization, just-in-time compilation, and static graph optimization.
  • JAX offers composable transformations of numerical programs, which makes it well suited to research.

Working with JAX

Google JAX is a library for transforming and manipulating arrays so that solutions to complex problems can be built efficiently. JAX works like NumPy most of the time, although some of its behavior differs from NumPy. Following are some basic commands to use JAX in any program:

Importing JAX

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

The above code imports the modules and functions that the following examples rely on. JAX can run the code on accelerators, which makes it efficient in both time and cost. Following is another simple example of importing JAX and creating a vector using it:

Using Arrays

import jax
import jax.numpy as jnp
x = jnp.arange(10)

The above code shows that working with JAX is much like working with NumPy, so the developers do not need to learn an entirely new API. In many cases, the users can replace “np” with “jnp” in NumPy code to work with JAX. One difference is the type of the array: JAX represents arrays as DeviceArray objects (jax.Array in newer releases). The most helpful feature here is that the same code can run on different backends (CPU, GPU, or TPU) without modification.
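A short sketch of where the drop-in replacement holds and where it does not is given below. It assumes NumPy is installed alongside JAX; the variable names are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(10)
x_jnp = jnp.arange(10)

print(type(x_np))     # NumPy's ndarray
print(type(x_jnp))    # JAX array type; the exact class name depends on the JAX version
print(jax.devices())  # devices of the backend JAX selected (CPU, GPU, or TPU)

# NumPy arrays can be modified in place.
x_np[0] = 100

# JAX arrays are immutable: .at[...].set(...) returns an updated copy
# and leaves the original array unchanged.
y_jnp = x_jnp.at[0].set(100)
print(x_jnp[0], y_jnp[0])
```

Immutability is one of the places where replacing “np” with “jnp” is not enough on its own, because in-place assignment has to be rewritten with the .at[...] syntax.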

Random Numbers and Matrix Multiplication

Another difference between NumPy and JAX is the way random numbers are generated. Following is a simple example of using JAX to generate random numbers:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10,))

The above code generates random numbers with JAX instead of NumPy. JAX uses an explicit PRNG: the state is a key that is passed to every random function and is never updated in place, so the production and consumption of entropy are handled by explicitly passing and splitting keys. By default, JAX uses the counter-based Threefry PRNG, which lets users fork (split) a key into new independent keys for parallel stochastic generation. Similarly, JAX can multiply two large matrices using the following commands:

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
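The explicit key handling described above can be sketched as follows. The example is a minimal illustration of reusing and splitting keys; the variable names are arbitrary.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Reusing the same key reproduces exactly the same numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(jnp.array_equal(a, b))

# To draw fresh numbers, split the key and consume the new subkey.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

# A key can also be forked into several subkeys, e.g. one per parallel worker.
subkeys = random.split(key, 3)
samples = [random.normal(k, (2,)) for k in subkeys]
print(len(samples))
```

Because no hidden global state is involved, the same sequence of splits always produces the same numbers, which keeps parallel and compiled code reproducible.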

JAX Program Transformations

JAX provides several program transformations for numerical computations. These transformations help the code run faster and more efficiently than equivalent code written with plain NumPy. Following are the main program transformations that JAX provides:

  • JAX uses jit() to compile functions so that they typically run faster than the equivalent NumPy code.
  • JAX uses grad() for taking derivatives in numerical computations.
  • JAX uses vmap() for automatic vectorization and batching.
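The three transformations listed above can also be composed with one another. The following sketch computes per-example gradients, assuming a made-up squared_error loss and toy data chosen only for illustration:

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def squared_error(w, x, y):
    # Hypothetical loss for a single example: (w . x - y)^2
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates with respect to the first argument (w),
# vmap maps over the examples in xs and ys while sharing w,
# and jit compiles the combined function with XLA.
per_example_grads = jit(vmap(grad(squared_error), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

grads = per_example_grads(w, xs, ys)
print(grads.shape)  # one gradient row per example
```

Per-example gradients of this kind are the building block of differentially private stochastic gradient descent, which is mentioned among the features above.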

Following is a brief introduction to using these interesting program transformations.

Using jit() Program Transformation

As mentioned previously, JAX can run the code on GPU and TPU. Without further instructions, however, JAX dispatches operations to the device one at a time. When a function consists of a sequence of operations, the jit decorator can compile them together using XLA. Following is a simple function, timed without jit as a baseline:

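import jax
import jax.numpy as jnp

# Scaled Exponential Linear Unit (SELU) activation, applied elementwise
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000000,))
# %timeit is an IPython/Jupyter magic; block_until_ready() waits for the asynchronous result
%timeit selu(x).block_until_ready()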

The developers can speed up the computation by using jit here. The function is compiled on its first call, and the compiled version is cached for later calls with inputs of the same shape and type. Following is an example of using jit:

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
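The %timeit lines above only work in IPython or Jupyter. Outside a notebook, a comparable measurement can be sketched with the standard library. The helper time_call below is hypothetical and written only for this example; actual timings depend on the hardware and are not shown.

```python
import time

import jax.numpy as jnp
from jax import jit, random

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def time_call(fn, arg, repeats=10):
    # Hypothetical helper: average wall-clock seconds per call.
    # The warm-up call keeps one-time compilation out of the measurement.
    fn(arg).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready() waits for JAX's asynchronous dispatch to finish.
        fn(arg).block_until_ready()
    return (time.perf_counter() - start) / repeats

x = random.normal(random.PRNGKey(0), (1_000_000,))
selu_jit = jit(selu)

print("without jit:", time_call(selu, x))
print("with jit:   ", time_call(selu_jit, x))
```

The warm-up call matters because the first call to a jitted function includes compilation time, which would otherwise distort the average.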

Using grad() for Derivatives

As mentioned above, the JAX library helps evaluate complex numerical functions efficiently. The developers can also use JAX to transform these functions, and one such transformation is automatic differentiation. For this purpose, JAX builds on an updated version of Autograd and exposes it through the grad() function, which takes a function and returns a new function that computes its gradient. Following is a simple example of computing a gradient using the grad() function:

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

The next step is to verify the output using finite differences. The following piece of code uses central finite differences to confirm the result:

def first_finite_differences(f, x):
    eps = 1e-3
    # Perturb one coordinate at a time along the rows of the identity matrix
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
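Another way to check the gradient is to compare it with the closed-form derivative of the logistic function, s'(x) = s(x)(1 - s(x)). The following sketch assumes the same sum_logistic function as above; the helper logistic_derivative is introduced only for this comparison.

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def logistic_derivative(x):
    # Closed-form derivative of the logistic function: s(x) * (1 - s(x)).
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(3.)
autodiff_grad = grad(sum_logistic)(x_small)
analytic_grad = logistic_derivative(x_small)

print(autodiff_grad)  # roughly [0.25, 0.197, 0.105]
print(analytic_grad)
print(jnp.allclose(autodiff_grad, analytic_grad, atol=1e-6))
```

Unlike finite differences, this comparison does not depend on a step size, but it is only available when the derivative is known analytically.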

Similarly, the developers can apply grad() repeatedly to take higher-order derivatives. For more advanced use, jax.vjp() computes reverse-mode vector-Jacobian products and jax.jvp() computes forward-mode Jacobian-vector products. These building blocks can be composed, for example, into a function that efficiently computes full Hessian matrices:

from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
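A small usage sketch for the Hessian helper, together with jax.jvp() and jax.vjp(), is given below. The test function cubic_sum is a made-up example chosen because its derivatives are easy to verify by hand: the gradient is 3x² and the Hessian is diag(6x).

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

def cubic_sum(x):
    # Hypothetical test function: f(x) = sum(x^3)
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])

# Full Hessian matrix; for this function it equals diag(6 * x).
H = hessian(cubic_sum)(x)
print(H)

# Forward mode: Jacobian-vector product along the direction v.
v = jnp.array([1.0, 0.0, 0.0])
value, directional = jax.jvp(cubic_sum, (x,), (v,))

# Reverse mode: vector-Jacobian product. For a scalar output,
# a cotangent of one recovers the full gradient.
value_again, vjp_fn = jax.vjp(cubic_sum, x)
(gradient,) = vjp_fn(jnp.ones_like(value_again))
print(directional, gradient)
```

Forward mode evaluates one input direction per pass, while reverse mode evaluates one output direction per pass. Composing jacfwd over jacrev is a common choice for Hessians of scalar functions.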

Using vmap() for Automatic Vectorization

The JAX API provides the vmap() transformation, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into the function's primitive operations, which improves performance. When combined with jit(), it can be as fast as adding the batch dimensions by hand. Following is a simple setup for using vmap() to promote matrix-vector products into matrix-matrix products:

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)
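With this setup, the batching itself can be sketched as follows. The example repeats the setup so that it runs on its own, and it uses separate keys for the two arrays. The function names naively_batched and vmap_batched are chosen for illustration.

```python
import jax.numpy as jnp
from jax import jit, random, vmap

# Separate keys so that the matrix and the batch are drawn independently.
key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

def apply_matrix(v):
    # Works on a single vector of shape (100,).
    return jnp.dot(mat, v)

def naively_batched(v_batched):
    # Python-level loop over the batch: correct but slow.
    return jnp.stack([apply_matrix(v) for v in v_batched])

@jit
def vmap_batched(v_batched):
    # vmap maps apply_matrix over the leading axis without a Python loop.
    return vmap(apply_matrix)(v_batched)

out_loop = naively_batched(batched_x)
out_vmap = vmap_batched(batched_x)
print(out_vmap.shape)  # (10, 150)
```

The function apply_matrix itself is written for a single vector and stays unchanged; vmap adds the batch dimension around it.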

All the above examples show the power of JAX. The library helps save time and cost while still producing optimized results, and the JAX transformations can considerably increase the performance of a function. Many other examples highlight the value of JAX for complex numerical computations.