Building a neural network using JAX.

JAX is a Python library developed by Google Research for high-performance numerical computing. I got interested in it because it combines Numpy’s syntax with the benefits of automatic differentiation and hardware acceleration (GPU/TPU). I’ve been wanting to learn more about it, so I finally set aside a few hours to build neural nets from scratch using Numpy, then JAX.

Layout for this post

  • What is compilation and how does JAX work
  • Building a network from scratch using Numpy
  • Building one from scratch using JAX
  • Comparing the two

For all the code (with more comments), check out the notebook here https://github.com/HarrisonSantiago/WebsiteNotebooks

How does JAX work?

At its core, JAX speeds up numerical computation through Just-In-Time (JIT) compilation. The following section gives an overview of how compilation works in different languages, as justification for bringing JIT into Python. From there, we will also see why introducing JIT necessitates a functional paradigm when coding. However, this post is not going to get into the nitty-gritty details of JAX that are best answered by the official documentation.


When a developer writes source code, the path to execution depends on the language’s execution model. Broadly speaking, there are three execution models: pure compilation (C, C++, Rust, Go), pure interpretation (early Python, Ruby, shell scripts), and JIT compilation (Java, C#, JavaScript). Let’s discuss the pros and cons of each.

Pure compiled

Compiled languages, such as C and C++, go through a compilation process that directly translates the source code into machine code specific to the target architecture. This machine code can be executed directly by the CPU without any intermediary steps. The compilation process itself may take some time, but once compiled, the resulting executable can run very efficiently.

The compilation process begins with the source code, typically written in .c or .cpp files. This human-readable code is then fed into a compiler, which performs several steps such as lexical analysis (breaking the code into tokens), followed by syntax analysis (ensuring the code adheres to language rules). The compiler then generates an Abstract Syntax Tree (AST) and performs various optimizations before finally producing object code specific to the target architecture. Object code, usually in .o or .obj files, contains machine code instructions but is not yet executable. The linker then takes one or more object files, along with any specified libraries, and combines them into a single executable. It resolves external references, assigns final memory addresses to code and data sections, and may perform additional optimizations like dead code elimination.

Now what is the big takeaway for generalizing to other compiled languages? Modern compilers analyze the entire code base at once and perform extensive optimizations to generate highly efficient machine code tailored to specific CPU architectures. This means the generated machine code can then be executed directly by the CPU, with very little overhead.

Interpreted Languages

In interpreted languages, the computer translates and executes the source code line by line at runtime. An interpreter reads the source code, parses it, and immediately executes the corresponding operations without producing an intermediate executable file. This adds a lot of overhead when executing code, but it makes it much easier to have dynamic typing in your language (optimizing dynamic code can be an absolute nightmare).

In these languages, the instructions are not directly executed by the target machine, but are instead read and executed by some other program (the interpreter, which is itself typically native machine code). For example, the “+” operation is recognized by the interpreter at run time, which then calls its own “add(a,b)” routine with the appropriate arguments, which in turn executes the machine code “ADD” instruction.
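To make this concrete, here’s a toy sketch (purely illustrative, not how CPython is actually implemented) of interpreter-style dispatch: every operation goes through a Python-level lookup and function call before any real arithmetic happens.

# Toy sketch of interpreter-style dispatch (illustrative only, not CPython's
# actual internals): every operation is looked up and executed by another
# function at run time, which is where the per-operation overhead comes from.
def interp_add(a, b):
  # The interpreter checks types at run time before doing the work,
  # eventually reaching the CPU's native ADD instruction.
  if isinstance(a, int) and isinstance(b, int):
    return a + b
  raise TypeError("unsupported operand types")

dispatch = {"+": interp_add}

def evaluate(op, a, b):
  return dispatch[op](a, b)  # look up the handler, then call it

print(evaluate("+", 2, 3))   # 5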

JIT compilation

JIT compilation is the compilation of code during execution of a program (at run time) rather than ahead of time. For this model, we will focus on Java. As background, the Java compiler turns source code into an intermediate form called “bytecode”. This bytecode is platform agnostic (no optimizations are made for the hardware architecture like in C++). From there, Java relies on the Java Virtual Machine (JVM) at run time to convert the bytecode into machine code.

When a Java program is executed, the JVM comes into play. Initially, the JVM interprets the bytecode, translating it into machine code on the fly. However, the JVM also employs Just-In-Time (JIT) compilation to enhance performance. The JIT compiler analyzes the bytecode as it’s being interpreted and identifies “hot spots” – frequently executed portions of code. These hot spots are then compiled into native machine code and cached for later use, where they can be executed directly by the hardware, bypassing the interpretation step on subsequent calls. This eliminates the need to re-interpret the same code over and over and improves overall performance.

In short, this idea of identifying hot spots and caching a compiled version for repeated use is what defines JIT compilation.
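As a rough mental model (and only that; a real JIT compiles bytecode down to machine code rather than swapping out Python functions), the control flow looks something like this:

# Toy sketch of hot-spot detection and caching. The names and threshold are
# made up for illustration; real JITs work on bytecode, not Python callables.
HOT_THRESHOLD = 100
call_counts = {}
compiled_cache = {}

def run(name, interpret_fn, compile_fn, *args):
  if name in compiled_cache:
    return compiled_cache[name](*args)      # reuse the cached compiled version
  call_counts[name] = call_counts.get(name, 0) + 1
  if call_counts[name] >= HOT_THRESHOLD:    # hot spot detected
    compiled_cache[name] = compile_fn()     # compile once, cache for later use
    return compiled_cache[name](*args)
  return interpret_fn(*args)                # otherwise keep interpreting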

How Python works

Python’s execution model is fundamentally based on interpretation, but modern versions have moved away from “pure interpretation”. When a Python script is run, the interpreter first compiles the source code to bytecode. Similarly to Java, this bytecode is a low-level, platform-independent representation of the source code, optimized for interpretation by the Python Virtual Machine (PVM).

The bytecode compilation process in Python is pretty transparent. When a .py file is executed, Python checks for a corresponding .pyc file (compiled bytecode) in a __pycache__ directory. If the .pyc file exists and is up-to-date, Python loads the bytecode directly. If not, it compiles the source code to bytecode, saves it as a .pyc file for future use, and then proceeds with execution.
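You can poke at this from the standard library: importlib.util.cache_from_source shows where CPython would put the cached bytecode for a given source file (the file name below is just a placeholder), and py_compile can trigger the compile step explicitly.

import importlib.util
import py_compile

# Where CPython would cache the bytecode for a (hypothetical) my_module.py.
# The exact file name encodes the interpreter version, e.g. cpython-312.
print(importlib.util.cache_from_source("my_module.py"))
# __pycache__/my_module.cpython-312.pyc

# Normally the compile-and-cache step happens automatically on first import,
# but py_compile.compile("my_module.py") would force it for an existing file.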

Once the bytecode is available, the PVM begins executing it instruction by instruction. Each bytecode instruction corresponds to a specific operation, such as loading a value, performing an arithmetic operation, or calling a function.

So why JAX?

So why JAX? Based on what we have learned, it seems like modern Python is more-or-less JIT anyway. The difference is that in a true JIT language like Java, the bytecode is compiled directly to machine code, whereas the PVM only ever interprets it. Going back to simple addition: in Python, the statement y = x + 1 is executed as a sequence of operations like “load constant 1”, “load x”, “add the two values”, “store the result in y”, and each of these operations is implemented by an independent function call. Java, by contrast, would already know that x is an integer and could execute the line using a single CPU instruction.
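You can see this sequence for yourself with the standard library’s dis module. The exact opcode names vary between CPython versions, but the shape is always the same:

import dis

def add_one(x):
  y = x + 1
  return y

dis.dis(add_one)
# On a recent CPython this prints something along the lines of:
#   LOAD_FAST    x
#   LOAD_CONST   1
#   BINARY_OP    + (add)
#   STORE_FAST   y
#   LOAD_FAST    y
#   RETURN_VALUE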

Numerical computations like that one are exactly where JAX comes into play. JAX’s JIT compiler transforms mathematical operations into optimized machine code that runs directly on the hardware, specifically targeting numerical computations and array operations. This compilation happens at a lower level than Python’s bytecode compilation and includes hardware-specific optimizations, vectorization, and GPU/TPU-specific code generation when available.

The catch with this is somewhat intuitive: if you are trying to optimize a math function, you need to know exactly what that function is. This means you can only use JIT compilation on functions whose operations and output depend solely on their inputs. As such, when using JAX you end up adopting a very functional coding mindset.
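Here’s a minimal sketch of what that looks like in practice (the function and array shapes are just for illustration). The first call pays the tracing and compilation cost; later calls with the same input shapes reuse the cached compiled code. The purity requirement also shows up here: side effects like print only run while JAX traces the function, not on the cached calls.

import time

import jax
import jax.numpy as jnp

@jax.jit
def scaled_tanh(x, w):
  # A pure function: the output depends only on the inputs, so JAX can
  # trace it once and hand the whole computation to XLA.
  return jnp.tanh(x @ w) * 2.0

x = jnp.ones((1000, 1000))
w = jnp.ones((1000, 1000))

start = time.time()
scaled_tanh(x, w).block_until_ready()   # first call: trace + compile + run
print(f"first call:  {time.time() - start:.4f}s")

start = time.time()
scaled_tanh(x, w).block_until_ready()   # later calls reuse the cached machine code
print(f"cached call: {time.time() - start:.4f}s")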

Building a Neural Net from Scratch (Numpy)

When creating this network, I wanted to follow the same OOP design encouraged by PyTorch. So to start, we can create layers

from abc import ABC, abstractmethod
from typing import Callable, Optional

import numpy as np

class Layer(ABC):
  """Abstract base class for neural network layers."""
  def __init__(self) -> None:
    self.input = None
    self.output = None
  @abstractmethod
  def forward(self, input_data: np.ndarray) -> np.ndarray:
    """Forward pass computation.
    Args:
        input_data: Input tensor of shape (batch_size, input_features)
    Returns:
        Output tensor of shape (batch_size, output_features)
    """
    raise NotImplementedError

  @abstractmethod
  def backward(self, output_error: np.ndarray, learning_rate: float) -> np.ndarray:
    """Backward pass computation.
    Args:
        output_error: Gradient of the loss with respect to layer output
        learning_rate: Learning rate for parameter updates
    Returns:
        Gradient of the loss with respect to layer input
    """
    raise NotImplementedError

class DenseLayer(Layer):
  """Fully connected neural network layer."""
  def __init__(self, input_size: int, output_size: int) -> None:
    super().__init__()

    limit = np.sqrt(6 / (input_size + output_size))
    self.weights = np.random.uniform(-limit, limit, (input_size, output_size))
    self.bias = np.zeros((1, output_size))

    self.weights_grad: Optional[np.ndarray] = None
    self.bias_grad: Optional[np.ndarray] = None

  def forward(self, input_data: np.ndarray) -> np.ndarray:
    self.input = input_data
    self.output = np.dot(input_data, self.weights) + self.bias
    return self.output

  def backward(self, output_error: np.ndarray, learning_rate: float) -> np.ndarray:
    # Compute gradients
    input_error = np.dot(output_error, self.weights.T)
    self.weights_grad = np.dot(self.input.T, output_error) / self.input.shape[0]
    self.bias_grad = np.sum(output_error, axis=0, keepdims=True) / self.input.shape[0] 

    self.weights -= learning_rate * self.weights_grad
    self.bias -= learning_rate * self.bias_grad

    return input_error

class ActivationLayer(Layer):
  """Neural network activation layer."""

  def __init__(self,
                activation_fn: Callable[[np.ndarray], np.ndarray],
                activation_prime: Callable[[np.ndarray], np.ndarray]) -> None:
    super().__init__()
    self.activation_fn = activation_fn
    self.activation_prime = activation_prime

  def forward(self, input_data: np.ndarray) -> np.ndarray:
    self.input = input_data
    self.output = self.activation_fn(self.input)
    return self.output

  def backward(self, output_error: np.ndarray, learning_rate: float) -> np.ndarray:
    return self.activation_prime(self.input) * output_error
    

and then create our loss and activation functions

class LossFunctions:
  """Collection of loss functions and their derivatives."""

  @staticmethod
  def cross_entropy(y_true: np.ndarray,
                    y_pred: np.ndarray,
                    epsilon: float = 1e-15) -> float:
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.mean(np.sum(y_true * np.log(y_pred), axis=1))

  @staticmethod
  def cross_entropy_prime(y_true: np.ndarray,
                          y_pred: np.ndarray,
                          epsilon: float = 1e-15) -> np.ndarray:
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return y_pred - y_true


class Activations:
  """Collection of activation functions and their derivatives."""

  @staticmethod
  def leaky_relu(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    return np.where(x > 0, x, alpha * x)
  
  @staticmethod
  def leaky_relu_prime(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    return np.where(x > 0, 1.0, alpha)
    

and from here, create our wrapper for the network and training

from dataclasses import dataclass

@dataclass
class NetworkConfig:
    """Neural network configuration parameters."""
    learning_rate: float = 0.01 
    epochs: int = 1000
    batch_size: int = 32
    clip_value: float = 5.0  # Add gradient clipping
    epsilon: float = 1e-15   # Small constant to prevent division by zero


class NeuralNetwork:
    """Simple neural network implementation."""

    def __init__(self, config: NetworkConfig = NetworkConfig()) -> None:
        self.layers: list[Layer] = []
        self.loss = LossFunctions.cross_entropy
        self.loss_prime = LossFunctions.cross_entropy_prime
        self.config = config

    def add(self, layer: Layer) -> None:
        self.layers.append(layer)

    def predict(self, input_data: np.ndarray) -> np.ndarray:
        """Generate predictions with NaN checking."""
        output = input_data
       
        for i, layer in enumerate(self.layers):
            output = layer.forward(output)
            
        return output

    def _clip_gradients(self, grad: np.ndarray) -> np.ndarray:
        """Clip gradients to prevent explosion."""
        return np.clip(grad, -self.config.clip_value, self.config.clip_value)

    def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> list[float]:
        """Train the neural network with NaN prevention."""
        samples = len(x_train)
        losses: list[float] = []

        for epoch in range(self.config.epochs):
            epoch_loss = 0
            batch_count = 0

            # Mini-batch gradient descent
            for i in range(0, samples, self.config.batch_size):
                batch_x = x_train[i:i + self.config.batch_size]
                batch_y = y_train[i:i + self.config.batch_size]
                actual_batch_size = len(batch_x)

                # Forward propagation
                output = self.predict(batch_x)
                # Prevent division by zero in log operations
                output = np.clip(output, self.config.epsilon, 1 - self.config.epsilon)
                    
                # Backward propagation
                error = self.loss_prime(batch_y, output)
                error = self._clip_gradients(error)  # Clip initial gradients
                
                for layer in reversed(self.layers):
                    error = layer.backward(error, self.config.learning_rate)
                    error = self._clip_gradients(error)  # Clip propagated gradients

                epoch_loss += self.loss(batch_y, output)
                batch_count += 1

            losses.append(epoch_loss / batch_count)

        return losses

All of this should look familiar; these classes are direct analogues of torch.nn.Linear and the like.

Creating a Neural Net from Scratch (JAX)

We established earlier that there are limitations on the functions you can JIT with JAX. This poses issues for classes, since you can’t straightforwardly JIT a method that reads and mutates self. Similarly, Python-side conditional logic gets in the way of JIT compilation, so we can’t rely on code like “if this layer type, execute this activation function”.

So how can we create a neural network with JAX? The answer is that we need to define our network implicitly. Rather than taking a state-based approach, we generate the matrices that represent our layers and then define functions that take them as input and return updated versions.

So first, let us create our “layers”

from typing import Dict, List

import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap

def get_layer_params(input_size: int,
                     output_size: int,
                     key: random.PRNGKey) -> Dict:
  """
  Return the weights and biases for a dense layer

  Args: 
    - input_size: size of this layer
    - output_size: size of the next layer
    - key: random key for this layer

  Returns: 
    - params: a dictionary of weights and biases for this layer
  """
  limit = jnp.sqrt(6 / (input_size + output_size))
  W_key, b_key = random.split(key)
  return {
    'weights': random.uniform(W_key, (output_size, input_size),
                               minval=-limit, maxval=limit),
    'bias': random.uniform(b_key, (output_size,))
    }
    

And then let’s implicitly define our network. Notice how each learning step calculates the gradients and then returns the updated parameters. Predicting is then just running an input through all of our weights and biases.

@jit
def step(params: List[Dict],
         x: jnp.ndarray,
         y: jnp.ndarray,
         lr: float = 0.05):
  
  """Optimized training step with static learning rate"""

  grads = grad(loss)(params, x, y)
  updated_params = []

  for param, grad_param in zip(params, grads):
    updated_param = {
      'weights': param['weights'] - lr * grad_param['weights'],
      'bias': param['bias'] - lr * grad_param['bias']
    }
    updated_params.append(updated_param)
  return updated_params

@jit
def loss(params: List[Dict],
         x: jnp.ndarray,
         targets: jnp.ndarray) -> float:
  """Compute cross entropy loss"""

  predictions = batched_predict(params, x)
  log_softmax = predictions - \
              jax.scipy.special.logsumexp(predictions, axis=1, keepdims=True)
  return -jnp.mean(jnp.sum(targets * log_softmax, axis=1))

@jit
def predict(params: List[Dict], x: jnp.ndarray) -> jnp.ndarray:
  """Implicitly defines densely connected network"""
  alpha = 1e-15
  for p in params[:-1]:
    x = jnp.dot(p['weights'], x) + p['bias']
    x = jnp.where(x > 0, x, alpha * x)

  final_weight = params[-1]['weights']
  final_bias = params[-1]['bias']
  return jnp.dot(final_weight, x) + final_bias

batched_predict = vmap(predict, in_axes=(None, 0))
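
# Quick sanity check of the two entry points (assuming params has already been
# built for layer_sizes = [2, 128, 256, 128, 3], as in the training code below):
# predict handles a single sample, while batched_predict maps it over a batch.
single_out = predict(params, jnp.ones(2))                # shape (3,)
batch_out = batched_predict(params, jnp.ones((64, 2)))   # shape (64, 3)
print(single_out.shape, batch_out.shape)                 # (3,) (64, 3)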
    

and then our training loop with batches is similarly defined.

for epoch in range(num_epochs):
  for i in range(0, len(X_train), batch_size):
    batch_x = X_train[i:i + batch_size]
    batch_y = y_train[i:i + batch_size]
    params = step(params, batch_x, batch_y)
   
predictions = batched_predict(params, X_test)
    

Comparing the two

For this post I won’t be doing an exhaustive comparison; my only goal is to show how JAX improves the speed. To this end, we will be doing classification on the following test data

where we have three classes, each defined by a ring of a different radius. We’ll train each network for the same number of epochs, verify that both have found good solutions, and then compare how long each took to train. It is important to note that this is not a rigorous test by any means. The classes in the Numpy version introduce some overhead, and JAX’s default 32-bit floats will speed up its version. However, forcing a functional paradigm is part of how JAX encourages performant code, so I feel comfortable enough comparing the two approaches. After all, this is a blog post and not a research paper.

If we time the Numpy network, we see the following training times

import time

# Get data from utils
X_train, X_test, y_train, y_test = generate_circle_data_np(n_points=2000)

for _ in range(5):
  # Define our network for each trial
  network = NeuralNetwork(NetworkConfig(
      learning_rate=0.05,
      epochs=50,
      batch_size=64,
  ))

  network.add(DenseLayer(2, 128))
  network.add(ActivationLayer(Activations.leaky_relu, Activations.leaky_relu_prime))
  network.add(DenseLayer(128, 256))
  network.add(ActivationLayer(Activations.leaky_relu, Activations.leaky_relu_prime))
  network.add(DenseLayer(256, 128)) 
  network.add(ActivationLayer(Activations.leaky_relu, Activations.leaky_relu_prime))
  network.add(DenseLayer(128, 3))  

  # Time how long training takes
  start_time = time.time()
  network.fit(X_train, y_train)
  end_time = time.time()
  print(f"Training time: {end_time - start_time:.2f} seconds")

> Training time: 7.55 seconds
> Training time: 7.34 seconds
> Training time: 7.40 seconds
> Training time: 7.83 seconds
> Training time: 7.82 seconds
    

with the resulting predictions on our test set

And with our JAX version we get

PRN = random.key(0)
X_train, X_test, y_train, y_test = generate_circle_data_jax(PRN, n_points=2000)

input_dim = X_train.shape[1]
num_classes = y_train.shape[1]

layer_sizes = [input_dim, 128, 256, 128, num_classes]
keys = random.split(PRN, len(layer_sizes))
num_epochs = 50
batch_size = 64

for _ in range(5):

  # Restart our network every trial
  params = [get_layer_params(input_size, output_size, key) \
        for input_size, output_size, key \
        in zip(layer_sizes[:-1], layer_sizes[1:], keys)]
  
  #Train the network
  start_time = time.time()
  for epoch in range(num_epochs):
    epoch_loss = 0
    for i in range(0, len(X_train), batch_size):
      batch_x = X_train[i:i + batch_size]
      batch_y = y_train[i:i + batch_size]
      params = step(params, batch_x, batch_y)
      

  end_time = time.time()
  print(f"Training time: {end_time - start_time:.2f} seconds")

> Training time: 2.04 seconds
> Training time: 1.63 seconds
> Training time: 1.58 seconds
> Training time: 1.61 seconds
> Training time: 1.60 seconds
    

and with a speedup of ~4x, there is no decrease in accuracy!