Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
Committed via GitHub · parent 0189ab6ab6 · commit 50dfb664db
mlx/CMakeLists.txt
@@ -25,6 +25,7 @@ else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
 endif()

+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)

mlx/distributed/CMakeLists.txt (new file, 16 lines)
@@ -0,0 +1,16 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
)

if (MPI_FOUND AND MLX_BUILD_CPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
  )
endif()

mlx/distributed/distributed.h (new file, 62 lines)
@@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <memory>

#include "mlx/array.h"

namespace mlx::core::distributed {

/* Check if a communication backend is available */
bool is_available();

/**
 * A distributed::Group represents a group of independent mlx processes that
 * can communicate. We must also be able to create sub-groups from a group in
 * order to define more granular communication.
 */
struct Group {
  Group(std::shared_ptr<void> group) : group_(group) {}

  int rank();
  int size();

  /**
   * Split the group according to the provided color. Namely, processes that
   * use the same color will go to the same group.
   *
   * The key defines the rank of the processes in the new group. The smaller
   * the key the smaller the rank. If the provided key is negative, then the
   * rank in the current group is used.
   */
  Group split(int color, int key = -1);

  const std::shared_ptr<void>& raw_group() {
    return group_;
  }

 private:
  std::shared_ptr<void> group_{nullptr};
};

/**
 * Initialize the distributed backend and return the group containing all
 * discoverable processes.
 */
Group init();

namespace detail {

/* Return the communication stream. */
Stream communication_stream();

/* Perform an all reduce sum operation */
void all_reduce_sum(Group group, const array& input, array& output);

/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);

} // namespace detail

} // namespace mlx::core::distributed
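
For orientation, a minimal usage sketch of this header (an illustration, not part of the commit; it assumes MLX was built with MPI support and the program is launched across several processes, e.g. with mpirun):

// Hypothetical example: initialize the global group, then split it into
// two halves by color.
#include "mlx/distributed/distributed.h"

using namespace mlx::core;

int main() {
  distributed::Group world = distributed::init(); // all discoverable processes
  // Processes that pass the same color land in the same sub-group; the
  // default negative key keeps each process's current rank ordering.
  int color = (world.rank() < world.size() / 2) ? 0 : 1;
  distributed::Group half = world.split(color);
  (void)half.rank(); // rank within the sub-group, in 0 .. half.size() - 1
  return 0;
}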

mlx/distributed/mpi/CMakeLists.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
)

mlx/distributed/mpi/mpi.cpp (new file, 283 lines)
@@ -0,0 +1,283 @@
// Copyright © 2024 Apple Inc.

#include <dlfcn.h>
#include <mpi.h>

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/scheduler.h"

#define LOAD_SYMBOL(symbol, variable)                              \
  {                                                                \
    variable = (decltype(variable))dlsym(libmpi_handle_, #symbol); \
    char* error = dlerror();                                       \
    if (error != nullptr) {                                        \
      libmpi_handle_ = nullptr;                                    \
      return;                                                      \
    }                                                              \
  }

namespace mlx::core::distributed {

namespace {

array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}

struct MPIWrapper {
  MPIWrapper() {
    libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
    if (libmpi_handle_ == nullptr) {
      return;
    }

    // API
    LOAD_SYMBOL(MPI_Init, init);
    LOAD_SYMBOL(MPI_Finalize, finalize);
    LOAD_SYMBOL(MPI_Comm_rank, rank);
    LOAD_SYMBOL(MPI_Comm_size, size);
    LOAD_SYMBOL(MPI_Comm_split, comm_split);
    LOAD_SYMBOL(MPI_Comm_free, comm_free);
    LOAD_SYMBOL(MPI_Allreduce, all_reduce);
    LOAD_SYMBOL(MPI_Allgather, all_gather);

    // Objects
    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);

    // Ops
    LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);

    // Datatypes
    LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
    LOAD_SYMBOL(ompi_mpi_int8_t, mpi_int8_);
    LOAD_SYMBOL(ompi_mpi_uint8_t, mpi_uint8_);
    LOAD_SYMBOL(ompi_mpi_int16_t, mpi_int16_);
    LOAD_SYMBOL(ompi_mpi_uint16_t, mpi_uint16_);
    LOAD_SYMBOL(ompi_mpi_int32_t, mpi_int32_);
    LOAD_SYMBOL(ompi_mpi_uint32_t, mpi_uint32_);
    LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
    LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
    LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
    LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
  }

  bool is_available() {
    return libmpi_handle_ != nullptr;
  }

  bool init_safe() {
    if (!is_available()) {
      return false;
    }
    return init(nullptr, nullptr) == MPI_SUCCESS;
  }

  void finalize_safe() {
    if (is_available()) {
      finalize();
    }
  }

  MPI_Comm world() {
    return comm_world_;
  }

  MPI_Datatype datatype(const array& arr) {
    switch (arr.dtype()) {
      case bool_:
        return mpi_bool_;
      case int8:
        return mpi_int8_;
      case uint8:
        return mpi_uint8_;
      case int16:
        return mpi_int16_;
      case uint16:
        return mpi_uint16_;
      case int32:
        return mpi_int32_;
      case uint32:
        return mpi_uint32_;
      case int64:
        return mpi_int64_;
      case uint64:
        return mpi_uint64_;
      case float32:
        return mpi_float_;
      case complex64:
        return mpi_complex_;
      case float16:
      case bfloat16:
        throw std::runtime_error("MPI doesn't support 16-bit floats");
    }
  }

  MPI_Op op_sum() {
    return op_sum_;
  }

  void* libmpi_handle_;

  // API
  int (*init)(int*, char***);
  int (*finalize)();
  int (*rank)(MPI_Comm, int*);
  int (*size)(MPI_Comm, int*);
  int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
  int (*all_gather)(
      const void*,
      int,
      MPI_Datatype,
      void*,
      int,
      MPI_Datatype,
      MPI_Comm);
  int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
  int (*comm_free)(MPI_Comm*);

  // Objects
  MPI_Comm comm_world_;

  // Ops
  MPI_Op op_sum_;

  // Datatypes
  MPI_Datatype mpi_bool_;
  MPI_Datatype mpi_int8_;
  MPI_Datatype mpi_uint8_;
  MPI_Datatype mpi_int16_;
  MPI_Datatype mpi_uint16_;
  MPI_Datatype mpi_int32_;
  MPI_Datatype mpi_uint32_;
  MPI_Datatype mpi_int64_;
  MPI_Datatype mpi_uint64_;
  MPI_Datatype mpi_float_;
  MPI_Datatype mpi_complex_;
};

MPIWrapper& mpi() {
  static MPIWrapper wrapper;
  return wrapper;
}

struct MPIGroupImpl {
  MPIGroupImpl(MPI_Comm comm, bool global)
      : comm_(comm), global_(global), rank_(-1), size_(-1) {}
  ~MPIGroupImpl() {
    if (global_) {
      mpi().finalize_safe();
    } else {
      mpi().comm_free(&comm_);
    }
  }

  MPI_Comm comm() {
    return comm_;
  }

  int rank() {
    if (rank_ < 0) {
      mpi().rank(comm_, &rank_);
    }
    return rank_;
  }

  int size() {
    if (size_ < 0) {
      mpi().size(comm_, &size_);
    }
    return size_;
  }

 private:
  MPI_Comm comm_;
  bool global_;
  int rank_;
  int size_;
};

MPI_Comm to_comm(Group& group) {
  return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
}

} // namespace

int Group::rank() {
  return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
}

int Group::size() {
  return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
}

Group Group::split(int color, int key) {
  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);

  key = (key < 0) ? rank() : key;

  MPI_Comm new_comm;
  int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
  if (result != MPI_SUCCESS) {
    throw std::runtime_error("MPI could not split this group");
  }

  return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}

bool is_available() {
  return mpi().is_available();
}

Group init() {
  static std::shared_ptr<MPIGroupImpl> global_group = nullptr;

  if (global_group == nullptr) {
    if (!mpi().init_safe()) {
      throw std::runtime_error("Cannot initialize MPI");
    }
    global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
  }

  return Group(global_group);
}

namespace detail {

Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}

void all_reduce_sum(Group group, const array& input_, array& output) {
  array input = ensure_row_contiguous(input_);
  mpi().all_reduce(
      input.data<void>(),
      output.data<void>(),
      input.size(),
      mpi().datatype(input),
      mpi().op_sum(),
      to_comm(group));
}

void all_gather(Group group, const array& input_, array& output) {
  array input = ensure_row_contiguous(input_);
  mpi().all_gather(
      input.data<void>(),
      input.size(),
      mpi().datatype(input),
      output.data<void>(),
      input.size(),
      mpi().datatype(output),
      to_comm(group));
}

} // namespace detail

} // namespace mlx::core::distributed
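
Note that the wrapper never links against MPI at build time: every symbol is resolved at runtime through dlopen/dlsym, so MLX still loads when no MPI runtime is installed. A minimal standalone sketch of the same pattern (illustrative only, not part of the commit; it loads libm.dylib, assuming macOS, in place of MPI):

// Illustration of the dlopen/dlsym pattern used by MPIWrapper: resolve a
// function pointer at runtime so the shared library stays an optional
// dependency.
#include <dlfcn.h>
#include <cstdio>

int main() {
  // RTLD_NOW | RTLD_GLOBAL mirrors the flags MPIWrapper passes to dlopen.
  void* handle = dlopen("libm.dylib", RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::puts("backend unavailable"); // degrade gracefully, as is_available() does
    return 0;
  }
  auto cosine = reinterpret_cast<double (*)(double)>(dlsym(handle, "cos"));
  if (dlerror() != nullptr) {
    return 1;
  }
  std::printf("cos(0) = %f\n", cosine(0.0));
  dlclose(handle);
  return 0;
}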

mlx/distributed/no_distributed.cpp (new file, 39 lines)
@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed {

int Group::rank() {
  return 0;
}

int Group::size() {
  return 1;
}

Group Group::split(int color, int key) {
  throw std::runtime_error("Cannot split the distributed group further");
}

bool is_available() {
  return false;
}

Group init() {
  return Group(nullptr);
}

namespace detail {

Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}

void all_reduce_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}

} // namespace detail

} // namespace mlx::core::distributed

mlx/distributed/ops.cpp (new file, 54 lines)
@@ -0,0 +1,54 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"

namespace mlx::core::distributed {

namespace {

Group to_group(std::optional<Group> group) {
  if (group.has_value()) {
    return group.value();
  } else {
    return distributed::init();
  }
}

} // namespace

array all_reduce_sum(const array& x, std::optional<Group> group_) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }

  return array(
      x.shape(),
      x.dtype(),
      std::make_shared<AllReduce>(group, AllReduce::Sum),
      {x});
}

array all_gather(const array& x, std::optional<Group> group_) {
  auto group = to_group(group_);

  if (group.size() == 1) {
    return x;
  }

  auto result_shape = x.shape();
  if (result_shape.size() == 0) {
    result_shape.push_back(group.size());
  } else {
    result_shape[0] *= group.size();
  }
  return array(
      std::move(result_shape),
      x.dtype(),
      std::make_shared<AllGather>(group),
      {x});
}

} // namespace mlx::core::distributed
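
The output-shape rule in all_gather above: each process's contribution is concatenated along the first axis, and scalars gather into a 1-d array of length group.size(). A standalone sketch of the rule (illustrative only, not part of the commit):

// Illustration of the all_gather output-shape rule, outside MLX.
#include <cassert>
#include <vector>

std::vector<int> gathered_shape(std::vector<int> shape, int group_size) {
  if (shape.empty()) {
    shape.push_back(group_size); // scalars gather into a 1-d array
  } else {
    shape[0] *= group_size; // otherwise concatenate along axis 0
  }
  return shape;
}

int main() {
  assert((gathered_shape({5, 3}, 4) == std::vector<int>{20, 3}));
  assert((gathered_shape({}, 4) == std::vector<int>{4}));
  return 0;
}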

mlx/distributed/ops.h (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <optional>

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed {

array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
array all_gather(const array& x, std::optional<Group> group = std::nullopt);

} // namespace mlx::core::distributed

mlx/distributed/primitives.cpp (new file, 98 lines)
@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"

namespace mlx::core::distributed {

void AllReduce::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));

  switch (reduce_type_) {
    case Sum:
      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
      break;
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  switch (reduce_type_) {
    case Sum:
      return {{all_reduce_sum(inputs[0], group())}, axes};
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

std::vector<array> AllReduce::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  switch (reduce_type_) {
    case Sum:
      return {all_reduce_sum(tangents[0], group())};
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

std::vector<array> AllReduce::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  return cotangents;
}

void AllGather::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));

  distributed::detail::all_gather(group(), inputs[0], outputs[0]);
}

std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  return {{all_gather(inputs[0], group())}, axes};
}

std::vector<array> AllGather::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  return {all_gather(tangents[0], group())};
}

std::vector<array> AllGather::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto g = group();
  std::vector<int> starts(primals[0].ndim(), 0);
  auto stops = primals[0].shape();
  starts[0] = g.rank() * stops[0];
  stops[0] += starts[0];
  return {slice(cotangents[0], starts, stops)};
}

} // namespace mlx::core::distributed
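
In AllGather::vjp, each rank keeps only the slice of the gathered cotangent corresponding to the block it contributed. A standalone sketch of the bounds arithmetic (illustrative only, with hypothetical numbers):

// Illustration of the AllGather::vjp slice bounds. With a per-rank input of
// shape {5, 3}, rank 2 takes rows [10, 15) of the gathered cotangent.
#include <cstdio>

int main() {
  int rank = 2;
  int rows_per_rank = 5;            // primals[0].shape()[0]
  int start = rank * rows_per_rank; // starts[0] = g.rank() * stops[0]
  int stop = start + rows_per_rank; // stops[0] += starts[0]
  std::printf("rank %d slices rows [%d, %d)\n", rank, start, stop);
  return 0;
}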

mlx/distributed/primitives.h (new file, 100 lines)
@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/distributed/distributed.h"
#include "mlx/primitives.h"

namespace mlx::core::distributed {

class DistPrimitive : public Primitive {
 public:
  DistPrimitive(Group group)
      : Primitive(detail::communication_stream()), group_(group) {}

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error(
        "Communication primitives cannot be run on the GPU");
  }

  const Group& group() const {
    return group_;
  }

 private:
  Group group_;
};

class AllReduce : public DistPrimitive {
 public:
  enum ReduceType { And, Or, Sum, Prod, Min, Max };

  AllReduce(Group group, ReduceType reduce_type)
      : DistPrimitive(group), reduce_type_(reduce_type) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  void print(std::ostream& os) override {
    switch (reduce_type_) {
      case And:
        os << "And";
        break;
      case Or:
        os << "Or";
        break;
      case Sum:
        os << "Sum";
        break;
      case Prod:
        os << "Prod";
        break;
      case Min:
        os << "Min";
        break;
      case Max:
        os << "Max";
        break;
    }
    os << " AllReduce";
  }

 private:
  ReduceType reduce_type_;
};

class AllGather : public DistPrimitive {
 public:
  AllGather(Group group) : DistPrimitive(group) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(AllGather);
};

} // namespace mlx::core::distributed