Cross platform libmpi loading (#1975)

This commit is contained in:
Angelos Katharopoulos 2025-03-21 11:23:10 -07:00 committed by GitHub
parent 4e1994e9d7
commit d343782c8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 21 deletions

View File

@ -89,6 +89,7 @@ jobs:
pip install numpy pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
@ -110,6 +111,8 @@ jobs:
name: Run Python tests name: Run Python tests
command: | command: |
python3 -m unittest discover python/tests -v python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |

View File

@ -212,24 +212,6 @@ else()
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)
endif() endif()
# Optional MPI support: only Open MPI is accepted. MPI_FOUND is reset to FALSE
# whenever the detected installation cannot be confirmed to be Open MPI, so
# later checks of MPI_FOUND see the final decision.
find_package(MPI)
if(MPI_FOUND)
  # Capture the `mpirun --version` banner; stderr is discarded, so a missing
  # shell or a missing mpirun simply leaves MPI_VERSION empty.
  # NOTE(review): this depends on zsh being installed -- confirm that is
  # intended on non-macOS hosts.
  execute_process(
    COMMAND zsh "-c" "mpirun --version"
    OUTPUT_VARIABLE MPI_VERSION
    ERROR_QUIET)
  # The expansion must be quoted: an empty MPI_VERSION would otherwise vanish
  # from the argument list, leaving `if(MATCHES ...)`, which is a configure
  # error rather than a fall-through to the empty-output branch below.
  if("${MPI_VERSION}" MATCHES ".*Open MPI.*")
    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
  elseif("${MPI_VERSION}" STREQUAL "")
    set(MPI_FOUND FALSE)
    message(
      WARNING "MPI found but mpirun is not available. Building without MPI.")
  else()
    set(MPI_FOUND FALSE)
    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
  endif()
endif()
message(STATUS "Downloading json") message(STATUS "Downloading json")
FetchContent_Declare( FetchContent_Declare(
json json

View File

@ -1,4 +1,4 @@
if(MPI_FOUND AND MLX_BUILD_CPU) if(MLX_BUILD_CPU)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)

View File

@ -1,12 +1,13 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <dlfcn.h> #include <dlfcn.h>
#include <mpi.h> #include <iostream>
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h" #include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/mpi/mpi_declarations.h"
#define LOAD_SYMBOL(symbol, variable) \ #define LOAD_SYMBOL(symbol, variable) \
{ \ { \
@ -18,6 +19,12 @@
} \ } \
} }
// File name of the MPI shared library, chosen per platform at compile time:
// the ".dylib" name on Apple targets and the ".so" name everywhere else.
#ifndef __APPLE__
static constexpr const char* libmpi_name = "libmpi.so";
#else
static constexpr const char* libmpi_name = "libmpi.dylib";
#endif
namespace mlx::core::distributed::mpi { namespace mlx::core::distributed::mpi {
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
@ -47,11 +54,24 @@ struct MPIWrapper {
MPIWrapper() { MPIWrapper() {
initialized_ = false; initialized_ = false;
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL); libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) { if (libmpi_handle_ == nullptr) {
return; return;
} }
// Check the library version and bail out with a warning if the loaded
// library does not identify itself as Open MPI.
// LOAD_SYMBOL (defined earlier in this file) is expected to bind
// get_version to the library's MPI_Get_library_version entry point.
int (*get_version)(char*, int*);
LOAD_SYMBOL(MPI_Get_library_version, get_version);
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
int version_length = 0;
get_version(version_ptr, &version_length);
// View only the bytes the library reported; if the call wrote nothing the
// view is empty and the check below treats it as "not Open MPI".
std::string_view version(version_ptr, version_length);
if (version.find("Open MPI") == std::string_view::npos) {
  // The second literal starts with a space so the two sentences do not run
  // together in the printed message.
  std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
            << " MLX requires Open MPI but this is " << version << std::endl;
  return;
}
// API // API
LOAD_SYMBOL(MPI_Init, init); LOAD_SYMBOL(MPI_Init, init);
LOAD_SYMBOL(MPI_Finalize, finalize); LOAD_SYMBOL(MPI_Finalize, finalize);

View File

@ -0,0 +1,28 @@
// Copyright © 2024 Apple Inc.
// Guard against multiple inclusion: the struct definition below must not be
// seen twice in one translation unit.
#pragma once

// size_t is used by MPI_Status below.
#include <stddef.h>

// Constants
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256

// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.

// Handles are declared as opaque pointers.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;

// Signature of a user-supplied reduction function.
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);

// Status record; the field order and types must match the layout the loaded
// Open MPI library uses.
// NOTE(review): layout copied from Open MPI's public header -- confirm
// against the Open MPI version in use.
typedef struct ompi_status_public_t {
  int MPI_SOURCE;
  int MPI_TAG;
  int MPI_ERROR;
  int _cancelled;
  size_t _ucount;
} MPI_Status;