Neural Networks#

Writing arbitrarily complex neural networks in MLX can be done using only mlx.core.array and mlx.core.value_and_grad(). However, this requires the user to repeatedly write the same simple neural network operations and to handle all of the parameter state and initialization manually and explicitly.

The module mlx.nn solves this problem by providing an intuitive way of composing neural network layers, initializing their parameters, freezing them for fine-tuning, and more.

Quick Start with Neural Networks#

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()

        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x

# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)

# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)

# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])

# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())

# A simple loss function.
# NOTE: It doesn't matter how the loss function uses the mlp model. Here it
#       captures it from the enclosing scope, but it could also be passed as
#       a positional or keyword argument.
def l2_loss(x, y):
    y_hat = mlp(x)
    return (y_hat - y).square().mean()

# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
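
Calling loss_and_grad with the loss inputs returns both the loss value and a tree of gradients matching mlp.trainable_parameters(). The following is a minimal sketch of one plain gradient-descent step, using mlx.utils.tree_map to walk the parameter and gradient trees; in practice an optimizer (e.g. from mlx.optimizers) would typically perform this update.

from mlx.utils import tree_map

# Some dummy data matching MLP(2, 10)
x = mx.random.normal((32, 2))
y = mx.random.normal((32, 10))

# Compute the loss and the gradients w.r.t. mlp.trainable_parameters()
loss, grads = loss_and_grad(x, y)

# One step of plain gradient descent: combine each parameter with its
# gradient and write the result back into the model with Module.update()
lr = 0.1
mlp.update(tree_map(lambda p, g: p - lr * g, mlp.trainable_parameters(), grads))
mx.eval(mlp.parameters())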

The Module Class#

The workhorse of any neural network library is the Module class. In MLX the Module class is a container of mlx.core.array or Module instances. Its main function is to provide a way to recursively access and update its parameters and those of its submodules.

Parameters#

A parameter of a module is any public member of type mlx.core.array (its name should not start with _). It can be arbitrarily nested in other Module instances or lists and dictionaries.

Module.parameters() can be used to extract a nested dictionary with all the parameters of a module and its submodules.

A Module can also keep track of “frozen” parameters. Module.trainable_parameters() returns only the subset of Module.parameters() that is not frozen. When using mlx.nn.value_and_grad() the gradients returned will be with respect to these trainable parameters.
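
As a small sketch building on the quick-start MLP, and assuming the freezing API on submodules (Module.freeze()), freezing the first layer removes its parameters from the trainable set while they remain visible in Module.parameters():

from mlx.utils import tree_flatten

mlp = MLP(2, 10)

# Freeze the first linear layer; its weight and bias stay in parameters()
# but are excluded from trainable_parameters()
mlp.layers[0].freeze()

print(len(tree_flatten(mlp.parameters())))            # 6 arrays (3 layers, weight and bias each)
print(len(tree_flatten(mlp.trainable_parameters())))  # 4 arrays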

Updating the parameters#

MLX modules allow accessing and updating individual parameters. However, most times we need to update large subsets of a module’s parameters. This action is performed by Module.update().
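
For example, the following sketch (continuing with the quick-start mlp) uses mlx.utils.tree_map to build a modified copy of the parameter tree, here cast to half precision, and writes it back into the module with Module.update():

from mlx.utils import tree_map

# Cast every parameter to float16 and push the new values into the module
half_params = tree_map(lambda p: p.astype(mx.float16), mlp.parameters())
mlp.update(half_params)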

Value and grad#

Using a Module does not preclude using MLX's higher-order function transformations (mlx.core.value_and_grad(), mlx.core.grad(), etc.). However, these function transformations assume pure functions, namely that the parameters should be passed as an argument to the function being transformed.

There is an easy pattern to achieve that with MLX modules:

model = ...

def f(params, other_inputs):
    model.update(params)  # <---- Necessary to make the model use the passed parameters
    return model(other_inputs)

f(model.trainable_parameters(), mx.zeros((10,)))

However, mlx.nn.value_and_grad() provides precisely this pattern and only computes the gradients with respect to the trainable parameters of the model.

In detail:

  • it wraps the passed function with a function that calls Module.update() to make sure the model is using the provided parameters.

  • it calls mlx.core.value_and_grad() to transform the function into a function that also computes the gradients with respect to the passed parameters.

  • it wraps the returned function with a function that passes the trainable parameters as the first argument to the function returned by mlx.core.value_and_grad() (see the sketch below).
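
Put together, the following is a rough sketch of this wrapping (simplified, not the actual implementation):

def value_and_grad_sketch(model, fn):
    # The inner function is pure in its first argument: load the passed
    # parameters into the model before evaluating fn
    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    # The outer wrapper supplies the trainable parameters as the first argument
    def wrapped_fn(*args, **kwargs):
        return value_grad_fn(model.trainable_parameters(), *args, **kwargs)

    return wrapped_fn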

value_and_grad(model, fn)

Transform the passed function fn to a function that computes both the value of fn and its gradients with respect to the model's trainable parameters.

Neural Network Layers#

Embedding(num_embeddings, dims)

Implements a simple lookup table that maps each input integer to a high-dimensional vector.

ReLU()

Applies the Rectified Linear Unit.

GELU([approx])

Applies the Gaussian Error Linear Units.

SiLU()

Applies the Sigmoid Linear Unit.

Linear(input_dims, output_dims[, bias])

Applies an affine transformation to the input.

Conv1d(in_channels, out_channels, kernel_size)

Applies a 1-dimensional convolution over the multi-channel input sequence.

Conv2d(in_channels, out_channels, kernel_size)

Applies a 2-dimensional convolution over the multi-channel input image.

LayerNorm(dims[, eps, affine])

Applies layer normalization [1] on the inputs.

RMSNorm(dims[, eps])

Applies Root Mean Square normalization [1] to the inputs.

GroupNorm(num_groups, dims[, eps, affine, ...])

Applies Group Normalization [1] to the inputs.

RoPE(dims[, traditional])

Implements the rotary positional encoding [1].

MultiHeadAttention(dims, num_heads[, ...])

Implements the scaled dot product attention with multiple heads.

Sequential(*modules)

A layer that calls the passed callables in order.

Layers without parameters (e.g. activation functions) are also provided as simple functions; a short example combining layers and these functions is shown after the list.

gelu(x)

Applies the Gaussian Error Linear Units function.

gelu_approx(x)

An approximation to Gaussian Error Linear Unit.

gelu_fast_approx(x)

A fast approximation to Gaussian Error Linear Unit.

relu(x)

Applies the Rectified Linear Unit.

silu(x)

Applies the Sigmoid Linear Unit.
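
As an illustrative sketch of composing the building blocks above, a small network can be assembled with Sequential, using the functional relu between the parametrized layers (Sequential accepts any callables):

import mlx.core as mx
import mlx.nn as nn

# Two Linear layers with a parameter-free activation in between
model = nn.Sequential(
    nn.Linear(16, 32),
    nn.relu,
    nn.Linear(32, 4),
)

x = mx.random.normal((8, 16))  # a batch of 8 inputs
print(model(x).shape)          # the output has shape (8, 4)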