From d343782c8b9d5670b1db10c7297ef2b1d126e5f6 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 21 Mar 2025 11:23:10 -0700 Subject: [PATCH] Cross platform libmpi loading (#1975) --- .circleci/config.yml | 3 +++ CMakeLists.txt | 18 ----------------- mlx/distributed/mpi/CMakeLists.txt | 2 +- mlx/distributed/mpi/mpi.cpp | 24 ++++++++++++++++++++-- mlx/distributed/mpi/mpi_declarations.h | 28 ++++++++++++++++++++++++++ 5 files changed, 54 insertions(+), 21 deletions(-) create mode 100644 mlx/distributed/mpi/mpi_declarations.h diff --git a/.circleci/config.yml b/.circleci/config.yml index 9c8cb31a3..64871aaa0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -89,6 +89,7 @@ jobs: pip install numpy sudo apt-get update sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev - run: name: Install Python package command: | @@ -110,6 +111,8 @@ jobs: name: Run Python tests command: | python3 -m unittest discover python/tests -v + mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py + mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py - run: name: Build CPP only command: | diff --git a/CMakeLists.txt b/CMakeLists.txt index 672b9810c..e2002fc94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,24 +212,6 @@ else() set(MLX_BUILD_ACCELERATE OFF) endif() -find_package(MPI) -if(MPI_FOUND) - execute_process( - COMMAND zsh "-c" "mpirun --version" - OUTPUT_VARIABLE MPI_VERSION - ERROR_QUIET) - if(${MPI_VERSION} MATCHES ".*Open MPI.*") - target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH}) - elseif(MPI_VERSION STREQUAL "") - set(MPI_FOUND FALSE) - message( - WARNING "MPI found but mpirun is not available. Building without MPI.") - else() - set(MPI_FOUND FALSE) - message(WARNING "MPI which is not OpenMPI found. Building without MPI.") - endif() -endif() - message(STATUS "Downloading json") FetchContent_Declare( json diff --git a/mlx/distributed/mpi/CMakeLists.txt b/mlx/distributed/mpi/CMakeLists.txt index 7063a101f..842f70b55 100644 --- a/mlx/distributed/mpi/CMakeLists.txt +++ b/mlx/distributed/mpi/CMakeLists.txt @@ -1,4 +1,4 @@ -if(MPI_FOUND AND MLX_BUILD_CPU) +if(MLX_BUILD_CPU) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index b9136f701..77b346037 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -1,12 +1,13 @@ // Copyright © 2024 Apple Inc. #include -#include +#include #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" +#include "mlx/distributed/mpi/mpi_declarations.h" #define LOAD_SYMBOL(symbol, variable) \ { \ @@ -18,6 +19,12 @@ } \ } +#ifdef __APPLE__ +static constexpr const char* libmpi_name = "libmpi.dylib"; +#else +static constexpr const char* libmpi_name = "libmpi.so"; +#endif + namespace mlx::core::distributed::mpi { using GroupImpl = mlx::core::distributed::detail::GroupImpl; @@ -47,11 +54,24 @@ struct MPIWrapper { MPIWrapper() { initialized_ = false; - libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL); + libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL); if (libmpi_handle_ == nullptr) { return; } + // Check library version and warn if it isn't Open MPI + int (*get_version)(char*, int*); + LOAD_SYMBOL(MPI_Get_library_version, get_version); + char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING]; + int version_length = 0; + get_version(version_ptr, &version_length); + std::string_view version(version_ptr, version_length); + if (version.find("Open MPI") == std::string::npos) { + std::cerr << "[mpi] MPI found but it does not appear to be Open MPI." + << "MLX requires Open MPI but this is " << version << std::endl; + return; + } + // API LOAD_SYMBOL(MPI_Init, init); LOAD_SYMBOL(MPI_Finalize, finalize); diff --git a/mlx/distributed/mpi/mpi_declarations.h b/mlx/distributed/mpi/mpi_declarations.h new file mode 100644 index 000000000..99c1a9cbb --- /dev/null +++ b/mlx/distributed/mpi/mpi_declarations.h @@ -0,0 +1,28 @@ +// Copyright © 2024 Apple Inc. + +// Constants + +#define MPI_SUCCESS 0 +#define MPI_ANY_SOURCE -1 +#define MPI_ANY_TAG -1 +#define MPI_IN_PLACE ((void*)1) +#define MPI_MAX_LIBRARY_VERSION_STRING 256 + +// Define all the types that we use so that we don't include which +// causes linker errors on some platforms. +// +// NOTE: We define everything for openmpi. + +typedef void* MPI_Comm; +typedef void* MPI_Datatype; +typedef void* MPI_Op; + +typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*); + +typedef struct ompi_status_public_t { + int MPI_SOURCE; + int MPI_TAG; + int MPI_ERROR; + int _cancelled; + size_t _ucount; +} MPI_Status;