diff --git a/docs/src/usage/distributed.rst b/docs/src/usage/distributed.rst index f1a1e4e94..0b83709e2 100644 --- a/docs/src/usage/distributed.rst +++ b/docs/src/usage/distributed.rst @@ -7,12 +7,13 @@ Distributed Communication MLX supports distributed communication operations that allow the computational cost of training or inference to be shared across many physical machines. At the -moment we support two different communication backends: +moment we support three different communication backends: * `MPI `_ a full-featured and mature distributed communications library -* A **ring** backend of our own that uses native TCP sockets and should be - faster for thunderbolt connections. +* A **ring** backend of our own that uses native TCP sockets. It should be + faster for thunderbolt connections, but it also works over Ethernet. +* `nccl `_, for use in CUDA environments. The list of all currently supported operations and their documentation can be seen in the :ref:`API docs`. @@ -84,9 +85,8 @@ Selecting Backend ^^^^^^^^^^^^^^^^^ You can select the backend you want to use when calling :func:`init` by passing -one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to -initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they -both fail then a singleton group is created. +one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all +available backends. If they all fail then a singleton group is created. .. note:: After a distributed backend is successfully initialized :func:`init` will @@ -220,7 +220,7 @@ print 4 etc. Installing MPI ^^^^^^^^^^^^^^ -MPI can be installed with Homebrew, using the Anaconda package manager or +MPI can be installed with Homebrew, pip, using the Anaconda package manager, or compiled from source. Most of our testing is done using ``openmpi`` installed with the Anaconda package manager as follows: @@ -228,14 +228,16 @@ with the Anaconda package manager as follows: $ conda install conda-forge::openmpi -Installing with Homebrew may require specifying the location of ``libmpi.dyld`` +Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld`` so that MLX can find it and load it at runtime. This can simply be achieved by passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is -done automatically by ``mlx.launch``. +done automatically by ``mlx.launch``. Some environments use a non-standard +library filename that can be specified using the ``MPI_LIBNAME`` environment +variable. This is automatically taken care of by ``mlx.launch`` as well. .. code:: shell - $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py + $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py $ # or simply $ mlx.launch -n 2 test.py diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index bf87425e4..3b176e6e6 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include +#include #include #include "mlx/backend/cpu/encoder.h" @@ -19,11 +20,17 @@ } \ } +static const char* get_libmpi_name() { + const char* libname = std::getenv("MLX_MPI_LIBNAME"); + if (libname != nullptr) { + return libname; + } #ifdef __APPLE__ -static constexpr const char* libmpi_name = "libmpi.dylib"; + return "libmpi.dylib"; #else -static constexpr const char* libmpi_name = "libmpi.so"; + return "libmpi.so"; #endif +} namespace mlx::core::distributed::mpi { @@ -94,7 +101,7 @@ struct MPIWrapper { MPIWrapper() { initialized_ = false; - libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL); + libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL); if (libmpi_handle_ == nullptr) { return; } diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 448e3f954..e4b50a5ce 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -5,6 +5,7 @@ import base64 import ipaddress import json import os +import platform import shlex import shutil import sys @@ -386,15 +387,40 @@ def launch_ring(parser, hosts, args, command): t.join() +def get_mpi_libname(): + try: + ompi_info = run(["which", "ompi_info"], check=True, capture_output=True) + ompi_info = ompi_info.stdout.strip().decode() + + if platform.system() == "Darwin": + otool_output = run( + ["otool", "-L", ompi_info], check=True, capture_output=True + ) + else: + otool_output = run(["ldd", ompi_info], check=True, capture_output=True) + otool_output = otool_output.stdout.decode() + + # StopIteration if not found + libmpi_line = next( + filter(lambda line: "libmpi" in line, otool_output.splitlines()) + ) + return libmpi_line.strip().split()[0].removeprefix("@rpath/") + except: + return None + + def launch_mpi(parser, hosts, args, command): mpirun = run(["which", "mpirun"], check=True, capture_output=True) mpirun = mpirun.stdout.strip().decode() - # Homebrew libmpi doesn't work with anaconda python out of the box. - # TODO: Check if we should do this with every mpirun - if "homebrew" in mpirun: + # Compatibility with homebrew and pip installs + mpi_libname = get_mpi_libname() + if mpi_libname is not None: dyld = Path(mpirun).parent.parent / "lib" - args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}"] + args.env + args.env = [ + f"DYLD_LIBRARY_PATH={str(dyld)}", + f"MLX_MPI_LIBNAME={mpi_libname}", + ] + args.env log(args.verbose, f"Using '{mpirun}'") with tempfile.NamedTemporaryFile(mode="w") as f: