docs update

This commit is contained in:
Awni Hannun
2024-06-06 20:28:06 -07:00
committed by CircleCI Docs
parent d7a78fbe2b
commit 85f70be0e6
2046 changed files with 323202 additions and 2319 deletions

View File

@@ -43,6 +43,7 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
.. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.
python/metal
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::

View File

@@ -186,8 +186,8 @@ should point to the path to the built metal library.
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
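For instance, a minimal configure step from a build directory might look like
the following (a sketch; the generator and any other options depend on your
setup):

.. code:: shell

    $ cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON
    $ make -j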
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
@@ -203,7 +203,7 @@ GGUF, you can do:
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can

View File

@@ -55,6 +55,7 @@
~array.tolist
~array.transpose
~array.var
~array.view

View File

@@ -1,6 +0,0 @@
mlx.core.block\_sparse\_mm
==========================
.. currentmodule:: mlx.core
.. autofunction:: block_sparse_mm

View File

@@ -0,0 +1,25 @@
mlx.core.distributed.Group
==========================
.. currentmodule:: mlx.core.distributed
.. autoclass:: Group
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~Group.__init__
~Group.rank
~Group.size
~Group.split

View File

@@ -0,0 +1,6 @@
mlx.core.distributed.all\_gather
================================
.. currentmodule:: mlx.core.distributed
.. autofunction:: all_gather

View File

@@ -0,0 +1,6 @@
mlx.core.distributed.all\_sum
=============================
.. currentmodule:: mlx.core.distributed
.. autofunction:: all_sum

View File

@@ -0,0 +1,6 @@
mlx.core.distributed.init
=========================
.. currentmodule:: mlx.core.distributed
.. autofunction:: init

View File

@@ -0,0 +1,6 @@
mlx.core.distributed.is\_available
==================================
.. currentmodule:: mlx.core.distributed
.. autofunction:: is_available

View File

@@ -0,0 +1,6 @@
mlx.core.view
=============
.. currentmodule:: mlx.core
.. autofunction:: view

View File

@@ -0,0 +1,19 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather
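
A minimal sketch of how these pieces fit together (run it with ``mpirun``; it
assumes an MPI installation that MLX can load at runtime):

.. code:: python

    import mlx.core as mx

    # Only attempt distributed operations if an MPI library was found at runtime.
    if mx.distributed.is_available():
        group = mx.distributed.init()            # the global communication group
        x = mx.ones((4,)) * group.rank()         # a different array on each process
        total = mx.distributed.all_sum(x)        # element-wise sum over all processes
        gathered = mx.distributed.all_gather(x)  # concatenated along the first axis
        print(group.rank(), group.size(), total, gathered.shape)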

View File

@@ -0,0 +1,16 @@
mlx.nn.GLU
==========
.. currentmodule:: mlx.nn
.. autoclass:: GLU
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.HardShrink
=================
.. currentmodule:: mlx.nn
.. autoclass:: HardShrink
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.HardTanh
===============
.. currentmodule:: mlx.nn
.. autoclass:: HardTanh
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.Hardswish
================
.. currentmodule:: mlx.nn
.. autoclass:: Hardswish
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.LeakyReLU
================
.. currentmodule:: mlx.nn
.. autoclass:: LeakyReLU
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.ReLU6
============
.. currentmodule:: mlx.nn
.. autoclass:: ReLU6
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.Softmax
==============
.. currentmodule:: mlx.nn
.. autoclass:: Softmax
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.Softmin
==============
.. currentmodule:: mlx.nn
.. autoclass:: Softmin
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.Softplus
===============
.. currentmodule:: mlx.nn
.. autoclass:: Softplus
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.Softsign
===============
.. currentmodule:: mlx.nn
.. autoclass:: Softsign
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,16 @@
mlx.nn.Tanh
===========
.. currentmodule:: mlx.nn
.. autoclass:: Tanh
.. rubric:: Methods
.. autosummary::

View File

@@ -0,0 +1,11 @@
mlx.nn.hard\_shrink
===================
.. currentmodule:: mlx.nn
.. autoclass:: hard_shrink

View File

@@ -0,0 +1,11 @@
mlx.nn.hard\_tanh
=================
.. currentmodule:: mlx.nn
.. autoclass:: hard_tanh

View File

@@ -0,0 +1,11 @@
mlx.nn.softmin
==============
.. currentmodule:: mlx.nn
.. autoclass:: softmin

View File

@@ -17,6 +17,8 @@ simple functions.
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@@ -21,10 +21,15 @@ Layers
Dropout3d
Embedding
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LSTM
MaxPool1d
@@ -36,13 +41,19 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample

View File

@@ -156,6 +156,7 @@ Operations
tril
triu
var
view
where
zeros
zeros_like

View File

@@ -0,0 +1,166 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported yet or may not be as fast as they should
be. We are adding more and tuning the existing ones as we figure out the best
way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum, which is printed. Launching with ``mpirun -np 4 ...`` would
print arrays of 4s, and so on.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can be done by passing the
``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
.. note::
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
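As a sketch, a matching entry in ``.ssh/config`` on each machine could look
like the following (``foo`` and ``foo.bar.com`` follow the example above; the
user name is hypothetical):

.. code::

    Host foo
        HostName foo.bar.com
        User some_user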
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
.. code::
host1 slots=1
host2 slots=1
When using MLX, it is very likely that you want to use 1 slot per host, i.e. one
process per host. The host file also needs to contain the current host if you
want to run on the local host. Pass the host file to ``mpirun`` with the
``--hostfile`` command line argument, as shown in the example below.
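For example, assuming the host file above is saved as ``hosts.txt`` (the name
is arbitrary), the launch from the Getting Started section becomes:

.. code:: shell

    $ mpirun -np 2 --hostfile hosts.txt python test.py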
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the
initialization of the model, dataset, optimizer, and loss-and-gradient function
(``loss_grad_fn``).
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together, our training loop step looks as follows, with
everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init().size()  # number of processes in the global group
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Tuning All Reduce
-----------------
We are working on improving the performance of all-reduce in MLX, but for now
the two main things you can do to get the most out of distributed training with MLX are:
1. Perform a few large reductions instead of many small ones to improve
bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
connections between each host to improve bandwidth (see the example below)
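As a sketch, the second option simply adds the flag to the launch command (the
host file and script names are hypothetical):

.. code:: shell

    $ mpirun --mca btl_tcp_links 4 -np 2 --hostfile hosts.txt python train.py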

View File

@@ -3,7 +3,11 @@
Conversion to NumPy and Other Frameworks
========================================
MLX arrays support conversion to and from other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
.. code-block:: python