mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Cross platform libmpi loading (#1975)
This commit is contained in:
parent
4e1994e9d7
commit
d343782c8b
@ -89,6 +89,7 @@ jobs:
|
|||||||
pip install numpy
|
pip install numpy
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
@ -110,6 +111,8 @@ jobs:
|
|||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
python3 -m unittest discover python/tests -v
|
python3 -m unittest discover python/tests -v
|
||||||
|
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||||
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
|
@ -212,24 +212,6 @@ else()
|
|||||||
set(MLX_BUILD_ACCELERATE OFF)
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_package(MPI)
|
|
||||||
if(MPI_FOUND)
|
|
||||||
execute_process(
|
|
||||||
COMMAND zsh "-c" "mpirun --version"
|
|
||||||
OUTPUT_VARIABLE MPI_VERSION
|
|
||||||
ERROR_QUIET)
|
|
||||||
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
|
||||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
|
||||||
elseif(MPI_VERSION STREQUAL "")
|
|
||||||
set(MPI_FOUND FALSE)
|
|
||||||
message(
|
|
||||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
|
||||||
else()
|
|
||||||
set(MPI_FOUND FALSE)
|
|
||||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
message(STATUS "Downloading json")
|
message(STATUS "Downloading json")
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
json
|
json
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
if(MPI_FOUND AND MLX_BUILD_CPU)
|
if(MLX_BUILD_CPU)
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
|
||||||
else()
|
else()
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
#include <mpi.h>
|
#include <iostream>
|
||||||
|
|
||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
|
#include "mlx/distributed/mpi/mpi_declarations.h"
|
||||||
|
|
||||||
#define LOAD_SYMBOL(symbol, variable) \
|
#define LOAD_SYMBOL(symbol, variable) \
|
||||||
{ \
|
{ \
|
||||||
@ -18,6 +19,12 @@
|
|||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
static constexpr const char* libmpi_name = "libmpi.dylib";
|
||||||
|
#else
|
||||||
|
static constexpr const char* libmpi_name = "libmpi.so";
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mlx::core::distributed::mpi {
|
namespace mlx::core::distributed::mpi {
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
@ -47,11 +54,24 @@ struct MPIWrapper {
|
|||||||
MPIWrapper() {
|
MPIWrapper() {
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
|
|
||||||
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
|
libmpi_handle_ = dlopen(libmpi_name, RTLD_NOW | RTLD_GLOBAL);
|
||||||
if (libmpi_handle_ == nullptr) {
|
if (libmpi_handle_ == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check library version and warn if it isn't Open MPI
|
||||||
|
int (*get_version)(char*, int*);
|
||||||
|
LOAD_SYMBOL(MPI_Get_library_version, get_version);
|
||||||
|
char version_ptr[MPI_MAX_LIBRARY_VERSION_STRING];
|
||||||
|
int version_length = 0;
|
||||||
|
get_version(version_ptr, &version_length);
|
||||||
|
std::string_view version(version_ptr, version_length);
|
||||||
|
if (version.find("Open MPI") == std::string::npos) {
|
||||||
|
std::cerr << "[mpi] MPI found but it does not appear to be Open MPI."
|
||||||
|
<< "MLX requires Open MPI but this is " << version << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// API
|
// API
|
||||||
LOAD_SYMBOL(MPI_Init, init);
|
LOAD_SYMBOL(MPI_Init, init);
|
||||||
LOAD_SYMBOL(MPI_Finalize, finalize);
|
LOAD_SYMBOL(MPI_Finalize, finalize);
|
||||||
|
28
mlx/distributed/mpi/mpi_declarations.h
Normal file
28
mlx/distributed/mpi/mpi_declarations.h
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// Constants
|
||||||
|
|
||||||
|
#define MPI_SUCCESS 0
|
||||||
|
#define MPI_ANY_SOURCE -1
|
||||||
|
#define MPI_ANY_TAG -1
|
||||||
|
#define MPI_IN_PLACE ((void*)1)
|
||||||
|
#define MPI_MAX_LIBRARY_VERSION_STRING 256
|
||||||
|
|
||||||
|
// Define all the types that we use so that we don't include <mpi.h> which
|
||||||
|
// causes linker errors on some platforms.
|
||||||
|
//
|
||||||
|
// NOTE: We define everything for openmpi.
|
||||||
|
|
||||||
|
typedef void* MPI_Comm;
|
||||||
|
typedef void* MPI_Datatype;
|
||||||
|
typedef void* MPI_Op;
|
||||||
|
|
||||||
|
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
|
||||||
|
|
||||||
|
typedef struct ompi_status_public_t {
|
||||||
|
int MPI_SOURCE;
|
||||||
|
int MPI_TAG;
|
||||||
|
int MPI_ERROR;
|
||||||
|
int _cancelled;
|
||||||
|
size_t _ucount;
|
||||||
|
} MPI_Status;
|
Loading…
Reference in New Issue
Block a user