Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
This commit is contained in:
parent 0189ab6ab6
commit 50dfb664db
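In short: this commit adds an MPI-backed mlx.core.distributed namespace with a process Group abstraction, all_reduce_sum and all_gather ops, and Python bindings. As a hedged end-to-end sketch of the new Python surface (names taken from the bindings and tests below; launch the script under mpirun the way the CI invocation below does):

import mlx.core as mx

world = mx.distributed.init()          # global group of all launched processes
print(world.rank(), "/", world.size())

x = mx.ones((10,))
y = mx.distributed.all_reduce_sum(x)   # elementwise sum across all ranks
z = mx.distributed.all_gather(x)       # concatenation along the first axis
print(y, z.shape)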
@ -71,6 +71,7 @@ jobs:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
@ -96,6 +97,7 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
@ -167,6 +167,11 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()

find_package(MPI)
if (MPI_FOUND)
  target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
endif()

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

target_include_directories(
@ -9,3 +9,4 @@ build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)
22
examples/cpp/distributed.cpp
Normal file
@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  if (!distributed::is_available()) {
    std::cout << "No communication backend found" << std::endl;
    return 1;
  }

  auto global_group = distributed::init();
  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

  array x = ones({10});
  array out = distributed::all_reduce_sum(x, global_group);

  std::cout << out << std::endl;
}
@ -25,6 +25,7 @@ else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
16
mlx/distributed/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
)

if (MPI_FOUND AND MLX_BUILD_CPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
  )
endif()
62
mlx/distributed/distributed.h
Normal file
@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <memory>

#include "mlx/array.h"

namespace mlx::core::distributed {

/* Check if a communication backend is available */
bool is_available();

/**
 * A distributed::Group represents a group of independent mlx processes that
 * can communicate. We must also be able to create sub-groups from a group in
 * order to define more granular communication.
 */
struct Group {
  Group(std::shared_ptr<void> group) : group_(group) {}

  int rank();
  int size();

  /**
   * Split the group according to the provided color. Namely, processes that
   * use the same color will go to the same group.
   *
   * The key defines the rank of the processes in the new group. The smaller
   * the key the smaller the rank. If the provided key is negative, then the
   * rank in the current group is used.
   */
  Group split(int color, int key = -1);

  const std::shared_ptr<void>& raw_group() {
    return group_;
  }

 private:
  std::shared_ptr<void> group_{nullptr};
};

/**
 * Initialize the distributed backend and return the group containing all
 * discoverable processes.
 */
Group init();

namespace detail {

/* Return the communication stream. */
Stream communication_stream();

/* Perform an all reduce sum operation */
void all_reduce_sum(Group group, const array& input, array& output);

/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);

} // namespace detail

} // namespace mlx::core::distributed
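The color/key semantics of Group::split are easiest to see concretely. A small sketch through the Python bindings added later in this diff, assuming 8 launched processes (the same setup the tests below use):

import mlx.core as mx

world = mx.distributed.init()
# Ranks with equal color land in the same subgroup:
row = world.split(world.rank() // 2)  # colors 0 0 1 1 2 2 3 3 -> four groups of 2
col = world.split(world.rank() % 2)   # colors 0 1 0 1 0 1 0 1 -> two groups of 4
# key < 0 (the default) keeps each process's relative order from the old group.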
5
mlx/distributed/mpi/CMakeLists.txt
Normal file
@ -0,0 +1,5 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
)
283
mlx/distributed/mpi/mpi.cpp
Normal file
@ -0,0 +1,283 @@
// Copyright © 2024 Apple Inc.

#include <dlfcn.h>
#include <mpi.h>

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/scheduler.h"

#define LOAD_SYMBOL(symbol, variable)                              \
  {                                                                \
    variable = (decltype(variable))dlsym(libmpi_handle_, #symbol); \
    char* error = dlerror();                                       \
    if (error != nullptr) {                                        \
      libmpi_handle_ = nullptr;                                    \
      return;                                                      \
    }                                                              \
  }

namespace mlx::core::distributed {

namespace {

array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}

struct MPIWrapper {
  MPIWrapper() {
    libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
    if (libmpi_handle_ == nullptr) {
      return;
    }

    // API
    LOAD_SYMBOL(MPI_Init, init);
    LOAD_SYMBOL(MPI_Finalize, finalize);
    LOAD_SYMBOL(MPI_Comm_rank, rank);
    LOAD_SYMBOL(MPI_Comm_size, size);
    LOAD_SYMBOL(MPI_Comm_split, comm_split);
    LOAD_SYMBOL(MPI_Comm_free, comm_free);
    LOAD_SYMBOL(MPI_Allreduce, all_reduce);
    LOAD_SYMBOL(MPI_Allgather, all_gather);

    // Objects
    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);

    // Ops
    LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);

    // Datatypes
    LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
    LOAD_SYMBOL(ompi_mpi_int8_t, mpi_int8_);
    LOAD_SYMBOL(ompi_mpi_uint8_t, mpi_uint8_);
    LOAD_SYMBOL(ompi_mpi_int16_t, mpi_int16_);
    LOAD_SYMBOL(ompi_mpi_uint16_t, mpi_uint16_);
    LOAD_SYMBOL(ompi_mpi_int32_t, mpi_int32_);
    LOAD_SYMBOL(ompi_mpi_uint32_t, mpi_uint32_);
    LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
    LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
    LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
    LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
  }

  bool is_available() {
    return libmpi_handle_ != nullptr;
  }

  bool init_safe() {
    if (!is_available()) {
      return false;
    }
    return init(nullptr, nullptr) == MPI_SUCCESS;
  }

  void finalize_safe() {
    if (is_available()) {
      finalize();
    }
  }

  MPI_Comm world() {
    return comm_world_;
  }

  MPI_Datatype datatype(const array& arr) {
    switch (arr.dtype()) {
      case bool_:
        return mpi_bool_;
      case int8:
        return mpi_int8_;
      case uint8:
        return mpi_uint8_;
      case int16:
        return mpi_int16_;
      case uint16:
        return mpi_uint16_;
      case int32:
        return mpi_int32_;
      case uint32:
        return mpi_uint32_;
      case int64:
        return mpi_int64_;
      case uint64:
        return mpi_uint64_;
      case float32:
        return mpi_float_;
      case complex64:
        return mpi_complex_;
      case float16:
      case bfloat16:
        throw std::runtime_error("MPI doesn't support 16-bit floats");
    }
  }

  MPI_Op op_sum() {
    return op_sum_;
  }

  void* libmpi_handle_;

  // API
  int (*init)(int*, char***);
  int (*finalize)();
  int (*rank)(MPI_Comm, int*);
  int (*size)(MPI_Comm, int*);
  int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
  int (*all_gather)(
      const void*,
      int,
      MPI_Datatype,
      void*,
      int,
      MPI_Datatype,
      MPI_Comm);
  int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
  int (*comm_free)(MPI_Comm*);

  // Objects
  MPI_Comm comm_world_;

  // Ops
  MPI_Op op_sum_;

  // Datatypes
  MPI_Datatype mpi_bool_;
  MPI_Datatype mpi_int8_;
  MPI_Datatype mpi_uint8_;
  MPI_Datatype mpi_int16_;
  MPI_Datatype mpi_uint16_;
  MPI_Datatype mpi_int32_;
  MPI_Datatype mpi_uint32_;
  MPI_Datatype mpi_int64_;
  MPI_Datatype mpi_uint64_;
  MPI_Datatype mpi_float_;
  MPI_Datatype mpi_complex_;
};

MPIWrapper& mpi() {
  static MPIWrapper wrapper;
  return wrapper;
}

struct MPIGroupImpl {
  MPIGroupImpl(MPI_Comm comm, bool global)
      : comm_(comm), global_(global), rank_(-1), size_(-1) {}
  ~MPIGroupImpl() {
    if (global_) {
      mpi().finalize_safe();
    } else {
      mpi().comm_free(&comm_);
    }
  }

  MPI_Comm comm() {
    return comm_;
  }

  int rank() {
    if (rank_ < 0) {
      mpi().rank(comm_, &rank_);
    }
    return rank_;
  }

  int size() {
    if (size_ < 0) {
      mpi().size(comm_, &size_);
    }
    return size_;
  }

 private:
  MPI_Comm comm_;
  bool global_;
  int rank_;
  int size_;
};

MPI_Comm to_comm(Group& group) {
  return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
}

} // namespace

int Group::rank() {
  return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
}

int Group::size() {
  return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
}

Group Group::split(int color, int key) {
  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);

  key = (key < 0) ? rank() : key;

  MPI_Comm new_comm;
  int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
  if (result != MPI_SUCCESS) {
    throw std::runtime_error("MPI could not split this group");
  }

  return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}

bool is_available() {
  return mpi().is_available();
}

Group init() {
  static std::shared_ptr<MPIGroupImpl> global_group = nullptr;

  if (global_group == nullptr) {
    if (!mpi().init_safe()) {
      throw std::runtime_error("Cannot initialize MPI");
    }
    global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
  }

  return Group(global_group);
}

namespace detail {

Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}

void all_reduce_sum(Group group, const array& input_, array& output) {
  array input = ensure_row_contiguous(input_);
  mpi().all_reduce(
      input.data<void>(),
      output.data<void>(),
      input.size(),
      mpi().datatype(input),
      mpi().op_sum(),
      to_comm(group));
}

void all_gather(Group group, const array& input_, array& output) {
  array input = ensure_row_contiguous(input_);
  mpi().all_gather(
      input.data<void>(),
      input.size(),
      mpi().datatype(input),
      output.data<void>(),
      input.size(),
      mpi().datatype(output),
      to_comm(group));
}

} // namespace detail

} // namespace mlx::core::distributed
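Design note: the wrapper dlopens libmpi.dylib and pulls every symbol with dlsym instead of linking MPI at build time, so an MLX build with MPI support still loads cleanly on machines without Open MPI; the backend just reports itself unavailable. A quick check of that behavior through the Python bindings (a sketch):

import mlx.core as mx

# False when libmpi.dylib could not be dlopen'd; no import-time failure.
print(mx.distributed.is_available())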
39
mlx/distributed/no_distributed.cpp
Normal file
@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed {

int Group::rank() {
  return 0;
}

int Group::size() {
  return 1;
}

Group Group::split(int color, int key) {
  throw std::runtime_error("Cannot split the distributed group further");
}

bool is_available() {
  return false;
}

Group init() {
  return Group(nullptr);
}

namespace detail {

Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}

void all_reduce_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}

} // namespace detail

} // namespace mlx::core::distributed
54
mlx/distributed/ops.cpp
Normal file
@ -0,0 +1,54 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"

namespace mlx::core::distributed {

namespace {

Group to_group(std::optional<Group> group) {
  if (group.has_value()) {
    return group.value();
  } else {
    return distributed::init();
  }
}

} // namespace

array all_reduce_sum(const array& x, std::optional<Group> group_) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }

  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<AllReduce>(group, AllReduce::Sum),
      {x});
}

array all_gather(const array& x, std::optional<Group> group_) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }

  auto result_shape = x.shape();
  if (result_shape.size() == 0) {
    result_shape.push_back(group.size());
  } else {
    result_shape[0] *= group.size();
  }
  return array(
      std::move(result_shape),
      x.dtype(),
      std::make_shared<AllGather>(group),
      {x});
}

} // namespace mlx::core::distributed
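The all_gather shape bookkeeping above is worth spelling out: the first axis is multiplied by the group size, and a scalar input grows a new first axis of length group.size(). A sketch assuming a launch with 4 ranks:

import mlx.core as mx

world = mx.distributed.init()      # assume 4 ranks for this illustration
x = mx.ones((2, 3))
y = mx.distributed.all_gather(x)   # (2, 3) -> (8, 3)
assert y.shape == (2 * world.size(), 3)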
14
mlx/distributed/ops.h
Normal file
@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <optional>

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed {

array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
array all_gather(const array& x, std::optional<Group> group = std::nullopt);

} // namespace mlx::core::distributed
98
mlx/distributed/primitives.cpp
Normal file
@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"

namespace mlx::core::distributed {

void AllReduce::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));

  switch (reduce_type_) {
    case Sum:
      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
      break;
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  switch (reduce_type_) {
    case Sum:
      return {{all_reduce_sum(inputs[0], group())}, axes};
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

std::vector<array> AllReduce::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  switch (reduce_type_) {
    case Sum:
      return {all_reduce_sum(tangents[0], group())};
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

std::vector<array> AllReduce::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  return cotangents;
}

void AllGather::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));

  distributed::detail::all_gather(group(), inputs[0], outputs[0]);
}

std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  return {{all_gather(inputs[0], group())}, axes};
}

std::vector<array> AllGather::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  return {all_gather(tangents[0], group())};
}

std::vector<array> AllGather::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto g = group();
  std::vector<int> starts(primals[0].ndim(), 0);
  auto stops = primals[0].shape();
  starts[0] = g.rank() * stops[0];
  stops[0] += starts[0];
  return {slice(cotangents[0], starts, stops)};
}

} // namespace mlx::core::distributed
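AllGather::vjp above says: the gradient that flows back to each rank is just its own slice of the gathered cotangent. Written out in Python (a sketch; all_gather_vjp and local_rows are illustrative names, not part of the API):

def all_gather_vjp(cotangent, rank, local_rows):
    # Rank r contributed rows [r * local_rows, (r + 1) * local_rows) of the
    # gathered output, so it keeps exactly those rows of the cotangent.
    start = rank * local_rows
    return cotangent[start : start + local_rows]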
100
mlx/distributed/primitives.h
Normal file
@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/distributed/distributed.h"
#include "mlx/primitives.h"

namespace mlx::core::distributed {

class DistPrimitive : public Primitive {
 public:
  DistPrimitive(Group group)
      : Primitive(detail::communication_stream()), group_(group) {}

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error(
        "Communication primitives cannot be run on the GPU");
  }

  const Group& group() const {
    return group_;
  }

 private:
  Group group_;
};

class AllReduce : public DistPrimitive {
 public:
  enum ReduceType { And, Or, Sum, Prod, Min, Max };

  AllReduce(Group group, ReduceType reduce_type)
      : DistPrimitive(group), reduce_type_(reduce_type) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  void print(std::ostream& os) override {
    switch (reduce_type_) {
      case And:
        os << "And";
        break;
      case Or:
        os << "Or";
        break;
      case Sum:
        os << "Sum";
        break;
      case Prod:
        os << "Prod";
        break;
      case Min:
        os << "Min";
        break;
      case Max:
        os << "Max";
        break;
    }
    os << " AllReduce";
  }

 private:
  ReduceType reduce_type_;
};

class AllGather : public DistPrimitive {
 public:
  AllGather(Group group) : DistPrimitive(group) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(AllGather);
};

} // namespace mlx::core::distributed
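Note the design choice here: every communication primitive is constructed on detail::communication_stream(), a dedicated CPU stream, and eval_gpu always throws. So, assuming the cross-stream scheduling behaves as described above, distributed ops should remain usable even when the default device is the GPU (a sketch, not verified in this diff):

import mlx.core as mx

mx.set_default_device(mx.gpu)  # communication still runs on the CPU comm stream
world = mx.distributed.init()
y = mx.distributed.all_reduce_sum(mx.ones((4,)))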
@ -6,6 +6,8 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/compile.h"
#include "mlx/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "mlx/fast.h"
#include "mlx/fft.h"
#include "mlx/io.h"
@ -6,6 +6,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
107
python/src/distributed.cpp
Normal file
@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"

namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;

void init_distributed(nb::module_& parent_module) {
  auto m = parent_module.def_submodule(
      "distributed", "mlx.core.distributed: Communication operations");

  nb::class_<distributed::Group>(
      m,
      "Group",
      R"pbdoc(
      An :class:`mlx.core.distributed.Group` represents a group of independent mlx
      processes that can communicate.
      )pbdoc")
      .def("rank", &distributed::Group::rank, "Get the rank of this process")
      .def("size", &distributed::Group::size, "Get the size of the group")
      .def(
          "split",
          &distributed::Group::split,
          "color"_a,
          "key"_a = -1,
          nb::sig("def split(self, color: int, key: int = -1) -> Group"),
          R"pbdoc(
          Split the group into subgroups based on the provided color.

          Processes that use the same color go to the same group. The ``key``
          argument defines the rank in the new group. The smaller the key the
          smaller the rank. If the key is negative then the rank in the
          current group is used.

          Args:
            color (int): A value to group processes into subgroups.
            key (int, optional): A key to optionally change the rank ordering
              of the processes.
          )pbdoc");

  m.def(
      "is_available",
      &distributed::is_available,
      R"pbdoc(
      Check if a communication backend is available.
      )pbdoc");

  m.def(
      "init",
      &distributed::init,
      R"pbdoc(
      Initialize the communication backend and create the global communication group.
      )pbdoc");

  m.def(
      "all_reduce_sum",
      &distributed::all_reduce_sum,
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      nb::sig(
          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
      R"pbdoc(
      All reduce sum.

      Sum the ``x`` arrays from all processes in the group.

      Args:
        x (array): Input array.
        group (Group): The group of processes that will participate in the
          reduction. If set to ``None`` the global group is used. Default:
          ``None``.

      Returns:
        array: The sum of all ``x`` arrays.
      )pbdoc");

  m.def(
      "all_gather",
      &distributed::all_gather,
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
      nb::sig(
          "def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
      R"pbdoc(
      Gather arrays from all processes.

      Gather the ``x`` arrays from all processes in the group and concatenate
      them along the first axis. The arrays should all have the same shape.

      Args:
        x (array): Input array.
        group (Group): The group of processes that will participate in the
          gather. If set to ``None`` the global group is used. Default:
          ``None``.

      Returns:
        array: The concatenation of all ``x`` arrays.
      )pbdoc");
}
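Putting the bindings together with split for subgroup communication, mirroring the tests below (a sketch, assuming 8 processes):

import mlx.core as mx

world = mx.distributed.init()
sub = world.split(world.rank() % 2)  # even ranks in one group, odd in the other
y = mx.distributed.all_reduce_sum(mx.ones((2, 2)), group=sub)
# every entry of y equals sub.size(), here 4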
@ -1,4 +1,4 @@
// Conbright © 2023-2024 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <nanobind/nanobind.h>

@ -18,6 +18,7 @@ void init_fft(nb::module_&);
void init_linalg(nb::module_&);
void init_constants(nb::module_&);
void init_fast(nb::module_&);
void init_distributed(nb::module_&);

NB_MODULE(core, m) {
  m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@ -37,6 +38,7 @@ NB_MODULE(core, m) {
  init_linalg(m);
  init_constants(m);
  init_fast(m);
  init_distributed(m);

  m.attr("__version__") = TOSTRING(_VERSION_);
}
98
python/tests/mpi_test_distributed.py
Normal file
@ -0,0 +1,98 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


class TestDistributed(mlx_tests.MLXTestCase):
    def test_groups(self):
        world = mx.distributed.init()
        self.assertEqual(world.size(), 8)
        self.assertTrue(0 <= world.rank() < 8)

        world2 = mx.distributed.init()
        self.assertEqual(world.size(), world2.size())
        self.assertEqual(world.rank(), world2.rank())

        sub = world.split(world.rank() % 2)
        self.assertEqual(sub.size(), 4)
        self.assertEqual(sub.rank(), world.rank() // 2)

        sub = world.split(world.rank() // 2)
        self.assertEqual(sub.size(), 2)

    def test_all_reduce(self):
        world = mx.distributed.init()
        dtypes = [
            mx.int8,
            mx.uint8,
            mx.int16,
            mx.uint16,
            mx.int32,
            mx.uint32,
            mx.float32,
            mx.complex64,
        ]
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_reduce_sum(x)
            self.assertTrue(mx.all(y == world.size()))

        sub = world.split(world.rank() % 2)
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_reduce_sum(x, group=sub)
            self.assertTrue(mx.all(y == sub.size()))

    def test_all_gather(self):
        world = mx.distributed.init()
        dtypes = [
            mx.int8,
            mx.uint8,
            mx.int16,
            mx.uint16,
            mx.int32,
            mx.uint32,
            mx.float32,
            mx.complex64,
        ]
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_gather(x)
            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
            self.assertTrue(mx.all(y == 1))

        sub = world.split(world.rank() % 2)
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_gather(x, group=sub)
            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
            self.assertTrue(mx.all(y == 1))

    def test_mixed(self):
        # Make the following groups:
        # - world: 0 1 2 3 4 5 6 7
        # - sub_1: 0 1 0 1 0 1 0 1
        # - sub_2: 0 0 1 1 2 2 3 3
        #
        # The corresponding colors to make them are
        # - world: N/A
        # - sub_1: 0 0 1 1 2 2 3 3
        # - sub_2: 0 1 0 1 0 1 0 1

        world = mx.distributed.init()
        sub_1 = world.split(world.rank() // 2)
        sub_2 = world.split(world.rank() % 2)

        x = mx.ones((1, 8)) * world.rank()
        y = mx.distributed.all_reduce_sum(x, group=sub_1)
        z = mx.distributed.all_gather(y, group=sub_2)
        z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)

        self.assertTrue(mx.all(z == z_target))


if __name__ == "__main__":
    unittest.main()