This commit is contained in:
Awni Hannun 2023-11-29 12:41:56 -08:00 committed by CircleCI Docs
parent aede70e81d
commit fbd10a48d4
440 changed files with 66856 additions and 0 deletions

4
docs/build/html/.buildinfo vendored Normal file
View File

@ -0,0 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 5b1e2769941ea186df86318a9fdcef20
tags: 645f666f9bcd5a90fca523b33c5a78b7

View File

@ -0,0 +1,6 @@
.. _cpp_ops:
Operations
==========

View File

@ -0,0 +1,948 @@
Developer Documentation
=======================
MLX provides a open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
.. code-block:: python
import mlx.core as mx
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementations and
differentiation to MLX.
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
Operations and Primitives
-------------------------
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations
^^^^^^^^^^^
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
.. code-block:: C++
/**
* Scale and sum two vectors elementwise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be do so in terms
of existing operations.
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy to use interface that build
on top of the building blocks we call :class:`Primitive`.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` . Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so lets take a step
back and go to our example to give ourselves a more concrete image.
.. code-block:: C++
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself accross
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
This operation now handles the following:
#. Upcast inputs and resolve the the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed
Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the elementwise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additonal mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while contructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only direct to it for ``float32`` types
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call the :meth:`catlas_saxpby` from accelerate.
.. code-block:: C++
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
Great! But what about the inputs that do not fit the criteria for accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.
.. note::
Here are some helpful resources if you are new to metal!
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
.. code-block:: C++
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify the right kernel for
each data type.
.. code-block:: C++
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bflot16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/ shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel decelaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divded among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and commiting the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
.. code-block:: C++
/** The Jacobian-vector product. */
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the the primitive can built with ops
// that are scheduled on the same stream as the primtive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
.. code-block:: C++
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitve along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
| extensions
| ├── axpby
| │ ├── axpby.cpp
| │ ├── axpby.h
| │ └── axpby.metal
| ├── mlx_sample_extensions
| │ └── __init__.py
| ├── bindings.cpp
| ├── CMakeLists.txt
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
.. code-block:: C++
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors elementwise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)pbdoc");
}
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library itself is simple, it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
.. code-block:: cmake
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice!
.. code-block:: cmake
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
Finally, we build the Pybind11_ bindings
.. code-block:: cmake
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension` for a simple build process.
.. code-block:: python
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.7",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This will result in a directory structure as follows:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same strucutre as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
Let's looks at a simple script and it's results!
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
Output:
.. code-block::
c shape: [3, 4]
c dtype: float32
c correctness: True
Results
^^^^^^^^^^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
import time
mx.set_default_device(mx.cpu)
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 256
N = 512
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0
mx.eval((x, y))
def bench(f):
# Warm up
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(5000):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return e - s
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
Results:
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
.. code: `TODO_LINK/extensions`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@ -0,0 +1,77 @@
.. _linear_regression:
Linear Regression
-----------------
Let's implement a basic linear regression model as a starting point to
learn MLX. First import the core package and setup some problem metadata:
.. code-block:: python
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000 # iterations of SGD
lr = 0.01 # learning rate for SGD
We'll generate a synthetic dataset by:
1. Sampling the design matrix ``X``.
2. Sampling a ground truth parameter vector ``w_star``.
3. Compute the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.
.. code-block:: python
# True parameters
w_star = mx.random.normal((num_features,))
# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))
# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps
We will use SGD to find the optimal weights. To start, define the squared loss
and get the gradient function of the loss with respect to the parameters.
.. code-block:: python
def loss_fn(w):
return 0.5 * mx.mean(mx.square(X @ w - y))
grad_fn = mx.grad(loss_fn)
Start the optimization by initializing the parameters ``w`` randomly. Then
repeatedly update the parameters for ``num_iters`` iterations.
.. code-block:: python
w = 1e-2 * mx.random.normal((num_features,))
for _ in range(num_iters):
grad = grad_fn(w)
w = w - lr * grad
mx.eval(w)
Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.
.. code-block:: python
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
print(
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
)
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364
Complete `linear regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
and `logistic regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
examples are available in the MLX GitHub repo.

View File

@ -0,0 +1,382 @@
LLM inference
==============
MLX enables efficient inference of large-ish transformers on Apple silicon
without compromising on ease of use. In this example we will create an
inference script for the Llama family of transformer models in which the model
is defined in less than 200 lines of python.
Implementing the model
----------------------
We will use the neural network building blocks defined in the :mod:`mlx.nn`
module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
Our implementation uses :class:`mlx.nn.Linear` for all the projections and
:class:`mlx.nn.RoPE` for the positional encoding.
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
Encoder layer
^^^^^^^^^^^^^
The other component of the Llama model is the encoder layer which uses RMS
normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
.. code-block:: python
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
Full model
^^^^^^^^^^
To implement any Llama model we simply have to combine ``LlamaEncoderLayer``
instances with an :class:`mlx.nn.Embedding` to embed the input tokens.
.. code-block:: python
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
Note that in the implementation above we use a simple list to hold the encoder
layers but using ``model.parameters()`` will still consider these layers.
Generation
^^^^^^^^^^^
Our ``Llama`` module can be used for training but not inference as the
``__call__`` method above processes one input, completely ignores the cache and
performs no sampling whatsoever. In the rest of this subsection, we will
implement the inference function as a python generator that processes the
prompt and then autoregressively yields tokens one at a time.
.. code-block:: python
class Llama(nn.Module):
...
def generate(self, x, temp=1.0):
cache = []
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
# First we process the prompt x the same way as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
cache.append(c) # <--- we store the per layer cache in a
# simple python list
x = self.norm(x)
y = self.out_proj(x[:, -1]) # <--- we only care about the last logits
# that generate the next token
y = mx.random.categorical(y * (1/temp))
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
# Calling y.item() would force the computation to happen at
# this point but we can also choose not to do that and let the
# user choose when to start the computation.
yield y
# Now we parsed the prompt and generated the first token we
# need to feed it back into the model and loop to generate the
# rest.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When
# the computation will happen, MLX will be discarding the
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1/temp))
yield y
Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^
We now have everything we need to create a Llama model and sample tokens from
it. In the following code, we randomly initialize a small Llama model, process
6 tokens of prompt and generate 10 tokens.
.. code-block:: python
model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
# Since MLX is lazily evaluated nothing has actually been materialized yet.
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
# code above would still run. Let's actually materialize the model.
mx.eval(model.parameters())
prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we
# have a batch dimension even
# though it is 1 in this case
generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
# Since we haven't evaluated anything, nothing is computed yet. The list
# `generated` contains the arrays that hold the computation graph for the
# full processing of the prompt and the generation of 10 tokens.
#
# We can evaluate them one at a time, or all together. Concatenate them or
# print them. They would all result in very similar runtimes and give exactly
# the same results.
mx.eval(generated)
Converting the weights
----------------------
This section assumes that you have access to the original Llama weights and the
SentencePiece model that comes with them. We will write a small script to
convert the PyTorch weights to MLX compatible ones and write them in a NPZ file
that can be loaded directly by MLX.
.. code-block:: python
import argparse
from itertools import starmap
import numpy as np
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return key, value.numpy()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
args = parser.parse_args()
state = torch.load(args.torch_weights)
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
)
Weight loading and benchmarking
-------------------------------
After converting the weights to be compatible to our implementation, all that is
left is to load them from disk and we can finally use the LLM to generate text.
We can load numpy format files using the :func:`mlx.core.load` operation.
To create a parameter dictionary from the key/value representation of NPZ files
we will use the :func:`mlx.utils.tree_unflatten` helper method as follows:
.. code-block:: python
from mlx.utils import tree_unflatten
model.update(tree_unflatten(list(mx.load(weight_file).items())))
:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look
like ``layers.2.attention.query_proj.weight`` and will transform them to
.. code-block:: python
{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
which can then be used to update the model. Note that the method above incurs
several unnecessary copies from disk to numpy and then from numpy to MLX. It
will be replaced in the future with direct loading to MLX.
You can download the full example code in `mlx-examples <code>`_. Assuming, the
existence of ``weights.pth`` and ``tokenizer.model`` in the current working
directory we can play around with our inference script as follows (the timings
are representative of an M1 Ultra and the 7B parameter Llama model):
.. code-block:: bash
$ python convert.py weights.pth llama-7B.mlx.npz
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
------
[INFO] Prompt processing: 0.437 s
[INFO] Full generation: 4.330 s
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds
of those are spent processing the prompt. This amounts to a little over **39 ms
per token**.
By running with a much bigger prompt we can see that the per token generation
time as well as the prompt processing time remains almost constant.
.. code-block:: bash
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
------
[INFO] Prompt processing: 0.579 s
[INFO] Full generation: 4.690 s
$ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
[INFO] Loading model from disk: 5.628 s
Press enter to start generation
------
take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
------
[INFO] Prompt processing: 0.633 s
[INFO] Full generation: 21.475 s
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv
preprint arXiv:2104.09864.
.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization.
Advances in Neural Information Processing Systems, 32.
.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint
arXiv:2002.05202.

View File

@ -0,0 +1,131 @@
.. _mlp:
Multi-Layer Perceptron
----------------------
In this example we'll learn to use ``mlx.nn`` by implementing a simple
multi-layer perceptron to classify MNIST.
As a first step import the MLX packages we need:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
The model is defined as the ``MLP`` class which inherits from
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:
1. Define an ``__init__`` where the parameters and/or submodules are setup. See
the :ref:`Module class docs<module_class>` for more information on how
:class:`mlx.nn.Module` registers parameters.
2. Define a ``__call__`` where the computation is implemented.
.. code-block:: python
class MLP(nn.Module):
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
return self.layers[-1](x)
We define the loss function which takes the mean of the per-example cross
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
commonly used loss functions.
.. code-block:: python
def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))
We also need a function to compute the accuracy of the model on the validation
set:
.. code-block:: python
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, setup the problem parameters and load the data:
.. code-block:: python
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
# Load the data
import mnist
train_images, train_labels, test_images, test_labels = map(
mx.array, mnist.mnist()
)
Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:
.. code-block:: python
def batch_iterate(batch_size, X, y):
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
Finally, we put it all together by instantiating the model, the
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:
.. code-block:: python
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())
# Get a function which gives the loss and gradient of the
# loss with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Instantiate the optimizer
optimizer = optim.SGD(learning_rate=learning_rate)
for e in range(num_epochs):
for X, y in batch_iterate(batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(model, X, y)
# Update the optimizer state and model parameters
# in a single call
optimizer.update(model, grads)
# Force a graph evaluation
mx.eval(model.parameters(), optimizer.state)
accuracy = eval_fn(model, test_images, test_labels)
print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
.. note::
The :func:`mlx.nn.value_and_grad` function is a convenience function to get
the gradient of a loss with respect to the trainable parameters of a model.
This should not be confused with :func:`mlx.core.value_and_grad`.
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
is available in the MLX GitHub repo.

49
docs/build/html/_sources/index.rst.txt vendored Normal file
View File

@ -0,0 +1,49 @@
MLX
===
.. toctree::
:caption: Install
:maxdepth: 1
install
.. toctree::
:caption: Usage
:maxdepth: 1
quick_start
using_streams
.. toctree::
:caption: Examples
:maxdepth: 1
examples/linear_regression
examples/mlp
examples/llama-inference
.. toctree::
:caption: Further Reading
:maxdepth: 1
dev/extensions
.. toctree::
:caption: Python API Reference
:maxdepth: 1
python/array
python/devices_and_streams
python/ops
python/random
python/transforms
python/fft
python/nn
python/optimizers
python/tree_utils
.. toctree::
:caption: C++ API Reference
:maxdepth: 1
cpp/ops

102
docs/build/html/_sources/install.rst.txt vendored Normal file
View File

@ -0,0 +1,102 @@
Build and Install
=================
Install from PyPI
-----------------
MLX is available at Apple's internal PyPI repository. All you have to do to use
MLX with your own Apple silicon computer is
.. code-block:: shell
pip install apple-mlx -i https://pypi.apple.com/simple
Build from source
-----------------
Build Requirements
^^^^^^^^^^^^^^^^^^
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
Python API
^^^^^^^^^^
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
Then simply build and install it using pip:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
C++ API
^^^^^^^
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
by cloning MLX from `its GitHub repo
<https://github.com/ml-explore/mlx>`_:
.. code-block:: shell
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
Run tests with:
.. code-block:: shell
make test
Install with:
.. code-block:: shell
make install
Note that the built ``mlx.metallib`` file should be either at the same
directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
* - Option
- Default
* - MLX_BUILD_TESTS
- ON
* - MLX_BUILD_EXAMPLES
- OFF
* - MLX_BUILD_BENCHMARKS
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF

View File

@ -0,0 +1,28 @@
mlx.core.Device
===============
.. currentmodule:: mlx.core
.. autoclass:: Device
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Device.__init__
.. rubric:: Attributes
.. autosummary::
~Device.type

View File

@ -0,0 +1,28 @@
mlx.core.Dtype
==============
.. currentmodule:: mlx.core
.. autoclass:: Dtype
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Dtype.__init__
.. rubric:: Attributes
.. autosummary::
~Dtype.size

View File

@ -0,0 +1,28 @@
mlx.core.Stream
===============
.. currentmodule:: mlx.core
.. autoclass:: Stream
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Stream.__init__
.. rubric:: Attributes
.. autosummary::
~Stream.device

View File

@ -0,0 +1,6 @@
mlx.core.abs
============
.. currentmodule:: mlx.core
.. autofunction:: abs

View File

@ -0,0 +1,6 @@
mlx.core.add
============
.. currentmodule:: mlx.core
.. autofunction:: add

View File

@ -0,0 +1,6 @@
mlx.core.all
============
.. currentmodule:: mlx.core
.. autofunction:: all

View File

@ -0,0 +1,6 @@
mlx.core.allclose
=================
.. currentmodule:: mlx.core
.. autofunction:: allclose

View File

@ -0,0 +1,6 @@
mlx.core.any
============
.. currentmodule:: mlx.core
.. autofunction:: any

View File

@ -0,0 +1,6 @@
mlx.core.arange
===============
.. currentmodule:: mlx.core
.. autofunction:: arange

View File

@ -0,0 +1,6 @@
mlx.core.arccos
===============
.. currentmodule:: mlx.core
.. autofunction:: arccos

View File

@ -0,0 +1,6 @@
mlx.core.arccosh
================
.. currentmodule:: mlx.core
.. autofunction:: arccosh

View File

@ -0,0 +1,6 @@
mlx.core.arcsin
===============
.. currentmodule:: mlx.core
.. autofunction:: arcsin

View File

@ -0,0 +1,6 @@
mlx.core.arcsinh
================
.. currentmodule:: mlx.core
.. autofunction:: arcsinh

View File

@ -0,0 +1,6 @@
mlx.core.arctan
===============
.. currentmodule:: mlx.core
.. autofunction:: arctan

View File

@ -0,0 +1,6 @@
mlx.core.arctanh
================
.. currentmodule:: mlx.core
.. autofunction:: arctanh

View File

@ -0,0 +1,6 @@
mlx.core.argmax
===============
.. currentmodule:: mlx.core
.. autofunction:: argmax

View File

@ -0,0 +1,6 @@
mlx.core.argmin
===============
.. currentmodule:: mlx.core
.. autofunction:: argmin

View File

@ -0,0 +1,6 @@
mlx.core.argpartition
=====================
.. currentmodule:: mlx.core
.. autofunction:: argpartition

View File

@ -0,0 +1,6 @@
mlx.core.argsort
================
.. currentmodule:: mlx.core
.. autofunction:: argsort

View File

@ -0,0 +1,6 @@
mlx.core.array.T
================
.. currentmodule:: mlx.core
.. autoproperty:: array.T

View File

@ -0,0 +1,6 @@
mlx.core.array.abs
==================
.. currentmodule:: mlx.core
.. automethod:: array.abs

View File

@ -0,0 +1,6 @@
mlx.core.array.all
==================
.. currentmodule:: mlx.core
.. automethod:: array.all

View File

@ -0,0 +1,6 @@
mlx.core.array.any
==================
.. currentmodule:: mlx.core
.. automethod:: array.any

View File

@ -0,0 +1,6 @@
mlx.core.array.argmax
=====================
.. currentmodule:: mlx.core
.. automethod:: array.argmax

View File

@ -0,0 +1,6 @@
mlx.core.array.argmin
=====================
.. currentmodule:: mlx.core
.. automethod:: array.argmin

View File

@ -0,0 +1,6 @@
mlx.core.array.astype
=====================
.. currentmodule:: mlx.core
.. automethod:: array.astype

View File

@ -0,0 +1,6 @@
mlx.core.array.cos
==================
.. currentmodule:: mlx.core
.. automethod:: array.cos

View File

@ -0,0 +1,6 @@
mlx.core.array.dtype
====================
.. currentmodule:: mlx.core
.. autoproperty:: array.dtype

View File

@ -0,0 +1,6 @@
mlx.core.array.exp
==================
.. currentmodule:: mlx.core
.. automethod:: array.exp

View File

@ -0,0 +1,6 @@
mlx.core.array.item
===================
.. currentmodule:: mlx.core
.. automethod:: array.item

View File

@ -0,0 +1,6 @@
mlx.core.array.log
==================
.. currentmodule:: mlx.core
.. automethod:: array.log

View File

@ -0,0 +1,6 @@
mlx.core.array.log1p
====================
.. currentmodule:: mlx.core
.. automethod:: array.log1p

View File

@ -0,0 +1,6 @@
mlx.core.array.logsumexp
========================
.. currentmodule:: mlx.core
.. automethod:: array.logsumexp

View File

@ -0,0 +1,6 @@
mlx.core.array.max
==================
.. currentmodule:: mlx.core
.. automethod:: array.max

View File

@ -0,0 +1,6 @@
mlx.core.array.mean
===================
.. currentmodule:: mlx.core
.. automethod:: array.mean

View File

@ -0,0 +1,6 @@
mlx.core.array.min
==================
.. currentmodule:: mlx.core
.. automethod:: array.min

View File

@ -0,0 +1,6 @@
mlx.core.array.ndim
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.ndim

View File

@ -0,0 +1,6 @@
mlx.core.array.prod
===================
.. currentmodule:: mlx.core
.. automethod:: array.prod

View File

@ -0,0 +1,6 @@
mlx.core.array.reciprocal
=========================
.. currentmodule:: mlx.core
.. automethod:: array.reciprocal

View File

@ -0,0 +1,6 @@
mlx.core.array.reshape
======================
.. currentmodule:: mlx.core
.. automethod:: array.reshape

View File

@ -0,0 +1,6 @@
mlx.core.array.rsqrt
====================
.. currentmodule:: mlx.core
.. automethod:: array.rsqrt

View File

@ -0,0 +1,66 @@
mlx.core.array
==============
.. currentmodule:: mlx.core
.. autoclass:: array
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~array.__init__
~array.abs
~array.all
~array.any
~array.argmax
~array.argmin
~array.astype
~array.cos
~array.cummax
~array.cummin
~array.cumprod
~array.cumsum
~array.exp
~array.item
~array.log
~array.log10
~array.log1p
~array.log2
~array.logsumexp
~array.max
~array.mean
~array.min
~array.prod
~array.reciprocal
~array.reshape
~array.rsqrt
~array.sin
~array.split
~array.sqrt
~array.square
~array.squeeze
~array.sum
~array.tolist
~array.transpose
~array.var
.. rubric:: Attributes
.. autosummary::
~array.T
~array.dtype
~array.ndim
~array.shape
~array.size

View File

@ -0,0 +1,6 @@
mlx.core.array.shape
====================
.. currentmodule:: mlx.core
.. autoproperty:: array.shape

View File

@ -0,0 +1,6 @@
mlx.core.array.sin
==================
.. currentmodule:: mlx.core
.. automethod:: array.sin

View File

@ -0,0 +1,6 @@
mlx.core.array.size
===================
.. currentmodule:: mlx.core
.. autoproperty:: array.size

View File

@ -0,0 +1,6 @@
mlx.core.array.split
====================
.. currentmodule:: mlx.core
.. automethod:: array.split

View File

@ -0,0 +1,6 @@
mlx.core.array.sqrt
===================
.. currentmodule:: mlx.core
.. automethod:: array.sqrt

View File

@ -0,0 +1,6 @@
mlx.core.array.square
=====================
.. currentmodule:: mlx.core
.. automethod:: array.square

View File

@ -0,0 +1,6 @@
mlx.core.array.sum
==================
.. currentmodule:: mlx.core
.. automethod:: array.sum

View File

@ -0,0 +1,6 @@
mlx.core.array.tolist
=====================
.. currentmodule:: mlx.core
.. automethod:: array.tolist

View File

@ -0,0 +1,6 @@
mlx.core.array.transpose
========================
.. currentmodule:: mlx.core
.. automethod:: array.transpose

View File

@ -0,0 +1,6 @@
mlx.core.array.var
==================
.. currentmodule:: mlx.core
.. automethod:: array.var

View File

@ -0,0 +1,6 @@
mlx.core.array\_equal
=====================
.. currentmodule:: mlx.core
.. autofunction:: array_equal

View File

@ -0,0 +1,6 @@
mlx.core.broadcast\_to
======================
.. currentmodule:: mlx.core
.. autofunction:: broadcast_to

View File

@ -0,0 +1,6 @@
mlx.core.concatenate
====================
.. currentmodule:: mlx.core
.. autofunction:: concatenate

View File

@ -0,0 +1,6 @@
mlx.core.conv1d
===============
.. currentmodule:: mlx.core
.. autofunction:: conv1d

View File

@ -0,0 +1,6 @@
mlx.core.conv2d
===============
.. currentmodule:: mlx.core
.. autofunction:: conv2d

View File

@ -0,0 +1,6 @@
mlx.core.convolve
=================
.. currentmodule:: mlx.core
.. autofunction:: convolve

View File

@ -0,0 +1,6 @@
mlx.core.cos
============
.. currentmodule:: mlx.core
.. autofunction:: cos

View File

@ -0,0 +1,6 @@
mlx.core.cosh
=============
.. currentmodule:: mlx.core
.. autofunction:: cosh

View File

@ -0,0 +1,6 @@
mlx.core.default\_device
========================
.. currentmodule:: mlx.core
.. autofunction:: default_device

View File

@ -0,0 +1,6 @@
mlx.core.default\_stream
========================
.. currentmodule:: mlx.core
.. autofunction:: default_stream

View File

@ -0,0 +1,6 @@
mlx.core.divide
===============
.. currentmodule:: mlx.core
.. autofunction:: divide

View File

@ -0,0 +1,6 @@
mlx.core.equal
==============
.. currentmodule:: mlx.core
.. autofunction:: equal

View File

@ -0,0 +1,6 @@
mlx.core.erf
============
.. currentmodule:: mlx.core
.. autofunction:: erf

View File

@ -0,0 +1,6 @@
mlx.core.erfinv
===============
.. currentmodule:: mlx.core
.. autofunction:: erfinv

View File

@ -0,0 +1,6 @@
mlx.core.eval
=============
.. currentmodule:: mlx.core
.. autofunction:: eval

View File

@ -0,0 +1,6 @@
mlx.core.exp
============
.. currentmodule:: mlx.core
.. autofunction:: exp

View File

@ -0,0 +1,6 @@
mlx.core.expand\_dims
=====================
.. currentmodule:: mlx.core
.. autofunction:: expand_dims

View File

@ -0,0 +1,6 @@
mlx.core.fft.fft
================
.. currentmodule:: mlx.core.fft
.. autofunction:: fft

View File

@ -0,0 +1,6 @@
mlx.core.fft.fft2
=================
.. currentmodule:: mlx.core.fft
.. autofunction:: fft2

View File

@ -0,0 +1,6 @@
mlx.core.fft.fftn
=================
.. currentmodule:: mlx.core.fft
.. autofunction:: fftn

View File

@ -0,0 +1,6 @@
mlx.core.fft.ifft
=================
.. currentmodule:: mlx.core.fft
.. autofunction:: ifft

View File

@ -0,0 +1,6 @@
mlx.core.fft.ifft2
==================
.. currentmodule:: mlx.core.fft
.. autofunction:: ifft2

View File

@ -0,0 +1,6 @@
mlx.core.fft.ifftn
==================
.. currentmodule:: mlx.core.fft
.. autofunction:: ifftn

View File

@ -0,0 +1,6 @@
mlx.core.fft.irfft
==================
.. currentmodule:: mlx.core.fft
.. autofunction:: irfft

View File

@ -0,0 +1,6 @@
mlx.core.fft.irfft2
===================
.. currentmodule:: mlx.core.fft
.. autofunction:: irfft2

View File

@ -0,0 +1,6 @@
mlx.core.fft.irfftn
===================
.. currentmodule:: mlx.core.fft
.. autofunction:: irfftn

View File

@ -0,0 +1,6 @@
mlx.core.fft.rfft
=================
.. currentmodule:: mlx.core.fft
.. autofunction:: rfft

View File

@ -0,0 +1,6 @@
mlx.core.fft.rfft2
==================
.. currentmodule:: mlx.core.fft
.. autofunction:: rfft2

View File

@ -0,0 +1,6 @@
mlx.core.fft.rfftn
==================
.. currentmodule:: mlx.core.fft
.. autofunction:: rfftn

View File

@ -0,0 +1,6 @@
mlx.core.full
=============
.. currentmodule:: mlx.core
.. autofunction:: full

View File

@ -0,0 +1,6 @@
mlx.core.grad
=============
.. currentmodule:: mlx.core
.. autofunction:: grad

View File

@ -0,0 +1,6 @@
mlx.core.greater
================
.. currentmodule:: mlx.core
.. autofunction:: greater

View File

@ -0,0 +1,6 @@
mlx.core.greater\_equal
=======================
.. currentmodule:: mlx.core
.. autofunction:: greater_equal

View File

@ -0,0 +1,6 @@
mlx.core.jvp
============
.. currentmodule:: mlx.core
.. autofunction:: jvp

View File

@ -0,0 +1,6 @@
mlx.core.less
=============
.. currentmodule:: mlx.core
.. autofunction:: less

View File

@ -0,0 +1,6 @@
mlx.core.less\_equal
====================
.. currentmodule:: mlx.core
.. autofunction:: less_equal

View File

@ -0,0 +1,6 @@
mlx.core.load
=============
.. currentmodule:: mlx.core
.. autofunction:: load

View File

@ -0,0 +1,6 @@
mlx.core.log
============
.. currentmodule:: mlx.core
.. autofunction:: log

View File

@ -0,0 +1,6 @@
mlx.core.log10
==============
.. currentmodule:: mlx.core
.. autofunction:: log10

View File

@ -0,0 +1,6 @@
mlx.core.log1p
==============
.. currentmodule:: mlx.core
.. autofunction:: log1p

Some files were not shown because too many files have changed in this diff Show More