Compatibility with pip-installed openmpi (#2741)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled

This commit is contained in:
Pedro Cuenca
2025-11-08 01:58:31 +01:00
committed by GitHub
parent be9e2aebd6
commit eba6a9d163
3 changed files with 52 additions and 17 deletions

View File

@@ -7,12 +7,13 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
moment we support three different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
* A **ring** backend of our own that uses native TCP sockets. It should be
faster for thunderbolt connections, but it also works over Ethernet.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
@@ -84,9 +85,8 @@ Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
@@ -220,7 +220,7 @@ print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
@@ -228,14 +228,16 @@ with the Anaconda package manager as follows:
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply
$ mlx.launch -n 2 test.py

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
#include <iostream>
#include "mlx/backend/cpu/encoder.h"
@@ -19,11 +20,17 @@
} \
}
static const char* get_libmpi_name() {
const char* libname = std::getenv("MLX_MPI_LIBNAME");
if (libname != nullptr) {
return libname;
}
#ifdef __APPLE__
static constexpr const char* libmpi_name = "libmpi.dylib";
return "libmpi.dylib";
#else
static constexpr const char* libmpi_name = "libmpi.so";
return "libmpi.so";
#endif
}
namespace mlx::core::distributed::mpi {
@@ -94,7 +101,7 @@ struct MPIWrapper {
MPIWrapper() {
initialized_ = false;
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
libmpi_handle_ = dlopen(get_libmpi_name(), RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
}

View File

@@ -5,6 +5,7 @@ import base64
import ipaddress
import json
import os
import platform
import shlex
import shutil
import sys
@@ -386,15 +387,40 @@ def launch_ring(parser, hosts, args, command):
t.join()
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Homebrew libmpi doesn't work with anaconda python out of the box.
# TODO: Check if we should do this with every mpirun
if "homebrew" in mpirun:
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}"] + args.env
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f: