Quick Start Guide
Basics
Import mlx.core and make an array:
>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
(4,)
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32
Operations in MLX are lazy: the outputs of MLX operations are not computed
until they are needed. To force an array to be evaluated, use
eval(). Arrays are also evaluated automatically in a few cases. For
example, inspecting a scalar with array.item(), printing an array,
or converting an MLX array to a numpy.ndarray all
automatically evaluate the array.
>> c = a + b    # c not yet evaluated
>> mx.eval(c)  # evaluates c
>> c = a + b
>> print(c)     # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c)   # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)
Function and Graph Transformations
MLX has standard function transformations like grad() and vmap().
Transformations can be composed arbitrarily. For example,
grad(vmap(grad(fn))) (or any other composition) is allowed.
>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)
Other gradient transformations include vjp() for vector-Jacobian products
and jvp() for Jacobian-vector products.
Use value_and_grad() to compute both a function’s output and its
gradient with respect to the function’s input in a single call, which
avoids a separate forward pass.