mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Cross platform libmpi loading (#1975)
This commit is contained in:

committed by
GitHub

parent
4e1994e9d7
commit
d343782c8b
@@ -1,4 +1,4 @@
|
||||
if(MPI_FOUND AND MLX_BUILD_CPU)
|
||||
if(MLX_BUILD_CPU)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
|
||||
|
@@ -1,12 +1,13 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <mpi.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/mpi/mpi_declarations.h"
|
||||
|
||||
#define LOAD_SYMBOL(symbol, variable) \
|
||||
{ \
|
||||
@@ -18,6 +19,12 @@
|
||||
} \
|
||||
}
|
||||
|
||||
// File name of the MPI shared library that the MPIWrapper constructor passes
// to dlopen(). Only a bare file name is given (no directory component), and
// the name is chosen at compile time based on the target platform.
#ifdef __APPLE__
|
||||
// Apple platforms: use the .dylib name.
static constexpr const char* libmpi_name = "libmpi.dylib";
|
||||
#else
|
||||
// All other platforms: use the .so name.
static constexpr const char* libmpi_name = "libmpi.so";
|
||||
#endif
|
||||
|
||||
namespace mlx::core::distributed::mpi {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
@@ -47,11 +54,24 @@ struct MPIWrapper {
|
||||
MPIWrapper() {
|
||||
initialized_ = false;
|
||||
|
||||
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
|
||||
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
|
||||
if (libmpi_handle_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check library version and warn if it isn't Open MPI
|
||||
int (*get_version)(char*, int*);
|
||||
LOAD_SYMBOL(MPI_Get_library_version, get_version);
|
||||
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
|
||||
int version_length = 0;
|
||||
get_version(version_ptr, &version_length);
|
||||
std::string_view version(version_ptr, version_length);
|
||||
if (version.find("Open MPI") == std::string::npos) {
|
||||
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
|
||||
<< "MLX requires Open MPI but this is " << version << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
// API
|
||||
LOAD_SYMBOL(MPI_Init, init);
|
||||
LOAD_SYMBOL(MPI_Finalize, finalize);
|
||||
|
28
mlx/distributed/mpi/mpi_declarations.h
Normal file
28
mlx/distributed/mpi/mpi_declarations.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Constants
|
||||
|
||||
// Local definitions of the MPI constants used by the MPI backend, provided
// here instead of coming from <mpi.h>.
// NOTE(review): the values presumably mirror the ones in Open MPI's <mpi.h>;
// confirm against the Open MPI release that is loaded at runtime.
#define MPI_SUCCESS 0
|
||||
#define MPI_ANY_SOURCE -1
|
||||
#define MPI_ANY_TAG -1
|
||||
// Sentinel pointer value (address 1), not a real buffer.
#define MPI_IN_PLACE ((void*)1)
|
||||
// Size of the char buffer handed to MPI_Get_library_version in mpi.cpp.
#define MPI_MAX_LIBRARY_VERSION_STRING 256
|
||||
|
||||
// Define all the types that we use so that we don't include <mpi.h> which
|
||||
// causes linker errors on some platforms.
|
||||
//
|
||||
// NOTE: Everything here is defined to match Open MPI.
|
||||
|
||||
// MPI handle types, declared here as plain void* pointers.
// NOTE(review): this presumably matches Open MPI, where handles are pointers;
// other MPI implementations may use a different representation — confirm
// before loading anything other than Open MPI.
typedef void* MPI_Comm;
|
||||
typedef void* MPI_Datatype;
|
||||
typedef void* MPI_Op;
|
||||
|
||||
// Function type taking two void* buffers, an int* and an MPI_Datatype*.
// NOTE(review): presumably the signature of a user-defined reduction callback
// — confirm against the MPI specification.
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
|
||||
|
||||
// Status record exposed under the name MPI_Status.
// NOTE(review): the struct name and fields presumably mirror Open MPI's
// ompi_status_public_t; field order and types should be verified against the
// Open MPI headers, since the real <mpi.h> is not included here.
// NOTE(review): size_t is used below but this header includes nothing itself;
// it relies on the including file having already pulled in a header that
// declares size_t.
typedef struct ompi_status_public_t {
|
||||
// The three MPI_-prefixed members are the publicly named status fields.
int MPI_SOURCE;
|
||||
int MPI_TAG;
|
||||
int MPI_ERROR;
|
||||
// Underscore-prefixed members: presumably implementation-internal fields kept
// only so the struct has the expected size and layout — TODO confirm.
int _cancelled;
|
||||
size_t _ucount;
|
||||
} MPI_Status;
|
Reference in New Issue
Block a user