mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Cross platform libmpi loading (#1975)
This commit is contained in:
parent
4e1994e9d7
commit
d343782c8b
@ -89,6 +89,7 @@ jobs:
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
@ -110,6 +111,8 @@ jobs:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
|
@ -212,24 +212,6 @@ else()
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
# Optional MPI support. MPI stays enabled only when the `mpirun` found on PATH
# reports itself as Open MPI; in every other case MPI_FOUND is cleared and a
# warning is emitted so the build continues without MPI.
find_package(MPI)
if(MPI_FOUND)
  # Run the launcher directly instead of through `zsh -c`, so the probe does
  # not depend on zsh being installed on the build host.
  execute_process(
    COMMAND mpirun --version
    OUTPUT_VARIABLE MPI_VERSION
    ERROR_QUIET)
  # The expansion must be quoted: when mpirun produced no output an unquoted
  # ${MPI_VERSION} disappears from the argument list and if() fails to parse,
  # so the empty-string branch below could never be reached.
  if("${MPI_VERSION}" MATCHES "Open MPI")
    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
  elseif("${MPI_VERSION}" STREQUAL "")
    set(MPI_FOUND FALSE)
    message(
      WARNING "MPI found but mpirun is not available. Building without MPI.")
  else()
    set(MPI_FOUND FALSE)
    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
  endif()
endif()
|
||||
|
||||
message(STATUS "Downloading json")
|
||||
FetchContent_Declare(
|
||||
json
|
||||
|
@ -1,4 +1,4 @@
|
||||
if(MPI_FOUND AND MLX_BUILD_CPU)
|
||||
if(MLX_BUILD_CPU)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
|
||||
|
@ -1,12 +1,13 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <mpi.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
#include "mlx/distributed/mpi/mpi_declarations.h"
|
||||
|
||||
#define LOAD_SYMBOL(symbol, variable) \
|
||||
{ \
|
||||
@ -18,6 +19,12 @@
|
||||
} \
|
||||
}
|
||||
|
||||
// File name of the MPI shared library handed to dlopen(): the ".dylib" name
// when compiling for Apple platforms, the ".so" name everywhere else.
static constexpr const char* libmpi_name =
#if defined(__APPLE__)
    "libmpi.dylib";
#else
    "libmpi.so";
#endif
|
||||
|
||||
namespace mlx::core::distributed::mpi {
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
@ -47,11 +54,24 @@ struct MPIWrapper {
|
||||
MPIWrapper() {
|
||||
initialized_ = false;
|
||||
|
||||
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
|
||||
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
|
||||
if (libmpi_handle_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check library version and warn if it isn't Open MPI
|
||||
int (*get_version)(char*, int*);
|
||||
LOAD_SYMBOL(MPI_Get_library_version, get_version);
|
||||
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
|
||||
int version_length = 0;
|
||||
get_version(version_ptr, &version_length);
|
||||
std::string_view version(version_ptr, version_length);
|
||||
if (version.find("Open MPI") == std::string::npos) {
|
||||
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
|
||||
<< "MLX requires Open MPI but this is " << version << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
// API
|
||||
LOAD_SYMBOL(MPI_Init, init);
|
||||
LOAD_SYMBOL(MPI_Finalize, finalize);
|
||||
|
28
mlx/distributed/mpi/mpi_declarations.h
Normal file
28
mlx/distributed/mpi/mpi_declarations.h
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright © 2024 Apple Inc.

// Minimal hand-written subset of the MPI declarations used by the MPI
// backend, so that <mpi.h> does not have to be included.

#pragma once

// size_t is used by MPI_Status below; include it here so this header does not
// depend on what the including file happened to include first.
#include <stddef.h>

// Constants
//
// Negative values are parenthesized so the macros expand safely inside any
// surrounding expression.

#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE (-1)
#define MPI_ANY_TAG (-1)
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256

// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.

// Handles are declared as opaque pointers.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;

// Signature of a user-supplied reduction callback.
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);

// Status record; the field order and types must stay as declared here.
// NOTE(review): presumably mirrors Open MPI's public status struct -- verify
// against the installed <mpi.h> when upgrading Open MPI.
typedef struct ompi_status_public_t {
  int MPI_SOURCE;
  int MPI_TAG;
  int MPI_ERROR;
  int _cancelled;
  size_t _ucount;
} MPI_Status;
|
Loading…
Reference in New Issue
Block a user