Cross platform libmpi loading (#1975)

This commit is contained in:
Angelos Katharopoulos
2025-03-21 11:23:10 -07:00
committed by GitHub
parent 4e1994e9d7
commit d343782c8b
5 changed files with 54 additions and 21 deletions

View File

@@ -1,4 +1,4 @@
if(MPI_FOUND AND MLX_BUILD_CPU)
if(MLX_BUILD_CPU)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)

View File

@@ -1,12 +1,13 @@
// Copyright © 2024 Apple Inc.
#include <dlfcn.h>
#include <mpi.h>
#include <iostream>
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/mpi/mpi_declarations.h"
#define LOAD_SYMBOL(symbol, variable) \
{ \
@@ -18,6 +19,12 @@
} \
}
// File name of the MPI shared library that is handed to dlopen(). Apple
// platforms get the .dylib name, every other platform gets the .so name.
static constexpr const char* libmpi_name =
#if defined(__APPLE__)
    "libmpi.dylib";
#else
    "libmpi.so";
#endif
namespace mlx::core::distributed::mpi {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
@@ -47,11 +54,24 @@ struct MPIWrapper {
MPIWrapper() {
initialized_ = false;
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
}
// Check library version and warn if it isn't Open MPI
int (*get_version)(char*, int*);
LOAD_SYMBOL(MPI_Get_library_version, get_version);
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
int version_length = 0;
get_version(version_ptr, &version_length);
std::string_view version(version_ptr, version_length);
if (version.find("Open MPI") == std::string::npos) {
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
<< "MLX requires Open MPI but this is " << version << std::endl;
return;
}
// API
LOAD_SYMBOL(MPI_Init, init);
LOAD_SYMBOL(MPI_Finalize, finalize);

View File

@@ -0,0 +1,28 @@
// Copyright © 2024 Apple Inc.
// Constants
//
// Hand-written copies of the few MPI constants this code relies on, kept
// here so that <mpi.h> does not have to be included.
// NOTE(review): the values are assumed to match the ones in Open MPI's
// <mpi.h>; confirm against the Open MPI release that gets loaded at runtime.
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
// Size of the char buffer that receives the MPI library version string.
#define MPI_MAX_LIBRARY_VERSION_STRING 256
// Minimal stand-ins for the MPI types we use. They are declared by hand
// instead of including <mpi.h>, which causes linker errors on some platforms.
//
// NOTE: Everything here follows the openmpi definitions.
using MPI_Comm = void*;
using MPI_Datatype = void*;
using MPI_Op = void*;
// Function type (not a pointer type): takes (void*, void*, int*,
// MPI_Datatype*) and returns void.
using MPI_User_function = void(void*, void*, int*, MPI_Datatype*);
// Status record declared under openmpi's struct name, ompi_status_public_t.
// NOTE(review): presumably libmpi fills this struct in directly, so the field
// order and types must stay exactly as they are -- confirm against Open MPI's
// <mpi.h> before changing anything here.
struct ompi_status_public_t {
  int MPI_SOURCE;
  int MPI_TAG;
  int MPI_ERROR;
  int _cancelled;
  size_t _ucount;
};
using MPI_Status = ompi_status_public_t;