Compare commits

...

14 Commits

Author                  SHA1        Message                                                            Date
Ronan Collobert         87b680766e  Gloo backend support                                               2024-11-13 13:52:37 -08:00
Ronan Collobert         70ffaa50d2  be more relaxed on OpenMPI version                                 2024-11-13 13:51:37 -08:00
Angelos Katharopoulos   d82699f0f1  Merge branch 'distributed-layers' into socket-distributed-layers  2024-11-05 11:36:16 -08:00
Angelos Katharopoulos   6fc00d2c10  Add rudimentary barrier                                            2024-11-05 11:34:55 -08:00
Angelos Katharopoulos   44f0de2854  Fix run without distributed                                        2024-11-05 11:27:41 -08:00
Angelos Katharopoulos   29ec3539ed  TCP socket distributed                                             2024-11-05 11:27:41 -08:00
Angelos Katharopoulos   e94f0028c3  Change the send message size                                       2024-11-05 11:27:41 -08:00
Angelos Katharopoulos   e5354fcddb  Make it work even for donated inputs                               2024-11-05 11:27:41 -08:00
Angelos Katharopoulos   34dd079a64  Start a sockets based distributed backend                          2024-11-05 11:27:41 -08:00
Angelos Katharopoulos   16975815e9  Fixes in distributed layers                                        2024-11-05 11:27:26 -08:00
Angelos Katharopoulos   a8b3da7946  Add distributed layers to nn top-level                             2024-11-05 11:27:26 -08:00
Angelos Katharopoulos   060e1c9f92  Add quantized distributed layers                                   2024-11-05 11:27:26 -08:00
Angelos Katharopoulos   0b04742985  Add the distributed linear layers                                  2024-11-05 11:27:26 -08:00
Angelos Katharopoulos   c3ccd4919f  Add MPI barrier                                                    2024-11-05 11:26:53 -08:00
14 changed files with 1230 additions and 9 deletions

View File

@@ -168,11 +168,12 @@ endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
COMMAND zsh "-c" "${MPIEXEC_EXECUTABLE} --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
if(${MPI_VERSION} MATCHES ".*Open MPI.*" OR ${MPI_VERSION} MATCHES ".*OpenRTE.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
target_link_libraries(mlx PRIVATE ${MPI_CXX_LIBRARIES})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(

View File

@@ -1,8 +1,20 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
if(MPI_FOUND AND MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
if(MLX_BUILD_CPU)
if(MLX_CUSTOM_DISTRIBUTED)
if(MLX_CUSTOM_DISTRIBUTED STREQUAL "gloo")
message(STATUS "Distributed: using gloo backend")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gloo)
else()
message(STATUS "Distributed: using sockets backend")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
endif()
elseif(MPI_FOUND)
message(STATUS "Distributed: using MPI backend")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
message(STATUS "Distributed: no support")
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
endif()
endif()
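
The backend is a build-time choice: configuring with -DMLX_CUSTOM_DISTRIBUTED=gloo selects the Gloo backend, any other non-empty value selects the sockets backend, and otherwise MPI is used when found. Whichever backend is compiled in, the Python-facing API stays the same, so a quick smoke test could look like the sketch below (a minimal, hedged example; the exact fallback behavior of init() when no backend is configured is an assumption here).

import mlx.core as mx

if mx.distributed.is_available():
    group = mx.distributed.init()  # picks up whichever backend was compiled in
    print(f"rank {group.rank()} of {group.size()}")
    # all_sum is implemented by all three backends in this diff
    y = mx.distributed.all_sum(mx.ones((8,)), group=group)
    print(y)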

View File

@@ -32,6 +32,8 @@ struct Group {
*/
Group split(int color, int key = -1);
void barrier();
const std::shared_ptr<void>& raw_group() {
return group_;
}

View File

@@ -0,0 +1,25 @@
find_path(
GLOO_INCLUDE_DIR gloo/allreduce.h
PATHS ${GLOO_INC_DIR}
PATH_SUFFIXES include)
find_library(
GLOO_LIBRARY gloo
PATHS ${GLOO_LIB_DIR}
PATH_SUFFIXES lib
HINTS GLOO)
find_library(
UV_LIBRARY uv
PATHS ${UV_LIB_DIR}
PATH_SUFFIXES lib
HINTS UV)
message(STATUS "GLOO LIB <${GLOO_LIBRARY}>")
message(STATUS "GLOO INC <${GLOO_INCLUDE_DIR}>")
message(STATUS "UV LIB <${UV_LIB_DIR}>")
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gloo.cpp)
target_link_libraries(mlx PUBLIC ${GLOO_LIBRARY})
target_link_libraries(mlx PUBLIC ${UV_LIBRARY})
target_include_directories(mlx PRIVATE ${GLOO_INCLUDE_DIR})

View File

@@ -0,0 +1,178 @@
// Copyright © 2024 Apple Inc.
#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"
#include "gloo/allreduce.h"
#include "gloo/math.h"
#include "gloo/mpi/context.h"
#include "gloo/transport/uv/device.h"
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
namespace mlx::core::distributed {
namespace {
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
} // namespace
bool is_available() {
return true;
}
int Group::rank() {
return std::static_pointer_cast<gloo::mpi::Context>(group_)->rank;
}
int Group::size() {
return std::static_pointer_cast<gloo::mpi::Context>(group_)->size;
}
Group Group::split(int color, int key) {
throw std::runtime_error("split is NYI");
}
void Group::barrier() {
throw std::runtime_error("barrier is NYI");
}
struct GlooCTX {
std::shared_ptr<gloo::mpi::Context> context;
std::shared_ptr<gloo::transport::Device> dev;
};
Group init(bool strict /* = false */) {
static std::shared_ptr<GlooCTX> gloo_ctx = nullptr;
if (gloo_ctx == nullptr) {
gloo_ctx = std::make_shared<GlooCTX>();
gloo_ctx->context = gloo::mpi::Context::createManaged();
gloo_ctx->dev = gloo::transport::uv::CreateDevice("localhost");
gloo_ctx->context->connectFullMesh(gloo_ctx->dev);
}
return Group(gloo_ctx->context);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
template <typename T>
void all_reduce_sum(
std::shared_ptr<gloo::mpi::Context> context,
T* output,
T* input,
size_t len) {
gloo::AllreduceOptions opts_(context);
opts_.setInput(input, len);
opts_.setOutput(output, len);
opts_.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
opts_.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::sum<T>));
gloo::allreduce(opts_);
}
void all_sum(Group group_, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
auto context =
std::static_pointer_cast<gloo::mpi::Context>(group_.raw_group());
SWITCH_TYPE(
output,
all_reduce_sum<T>(
context, output.data<T>(), input.data<T>(), input.size()));
}
void all_gather(Group group_, const array& input_, array& output) {
throw std::runtime_error("all_gather NYI");
}
void send(Group group_, const array& input_, int dst) {
throw std::runtime_error("send NYI");
}
void recv(Group group_, array& out, int src) {
throw std::runtime_error("recv NYI");
}
} // namespace detail
} // namespace mlx::core::distributed

View File

@@ -71,6 +71,7 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Barrier, barrier);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
@@ -195,6 +196,7 @@ struct MPIWrapper {
int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
int (*barrier)(MPI_Comm);
// Objects
MPI_Comm comm_world_;
@@ -263,6 +265,10 @@ struct MPIGroupImpl {
return size_;
}
void barrier() {
mpi().barrier(comm_);
}
private:
MPI_Comm comm_;
bool global_;
@@ -298,6 +304,11 @@ Group Group::split(int color, int key) {
return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}
void Group::barrier() {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
mpi_group->barrier();
}
bool is_available() {
return mpi().is_available();
}
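
The MPI barrier added above is also exposed through the Python bindings at the end of this diff, so a hedged sketch of using it could be:

import mlx.core as mx

group = mx.distributed.init()
# ... per-rank work goes here ...
group.barrier()  # newly added synchronization point across all ranks in the group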

View File

@@ -17,6 +17,8 @@ Group Group::split(int color, int key) {
throw std::runtime_error("Cannot split the distributed group further");
}
void Group::barrier() {}
bool is_available() {
return false;
}

View File

@@ -0,0 +1,5 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/sockets.cpp
)

View File

@@ -0,0 +1,522 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <json.hpp>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
constexpr const size_t PACKET_SIZE = 262144;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
using json = nlohmann::json;
namespace mlx::core::distributed {
namespace {
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
void sum_inplace(const array& input, array& output) {
SWITCH_TYPE(
input, sum_inplace(input.data<T>(), output.data<T>(), input.size()));
}
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* sockaddr() {
return (struct sockaddr*)&addr;
}
};
address_t parse_address(std::string ip, std::string port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse peer address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
std::vector<address_t> load_peers() {
std::vector<address_t> peers;
std::ifstream f;
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
f.open(hostfile_buf);
} else {
return peers;
}
json hosts = json::parse(f);
for (auto& h : hosts) {
peers.push_back(std::move(parse_address(
h["ip"].template get<std::string>(),
h["port"].template get<std::string>())));
}
return peers;
}
struct GroupImpl {
GroupImpl(std::vector<address_t> peers, int rank, bool global)
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
if (rank_ > 0 && rank_ >= peers.size()) {
throw std::runtime_error(
"Rank cannot be larger than the size of the group");
}
int success;
// If we are expecting anyone to connect to us
if (rank_ + 1 < peers.size()) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseaddr (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseport (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind it to the port
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets_[peers.size() - 1 - i] = peer_socket;
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
}
// Connect to the peers with smaller rank
for (int i = 0; i < rank_; i++) {
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
if (sockets_[i] < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
}
~GroupImpl() {
if (global_) {
for (int sock : sockets_) {
shutdown(sock, 2);
close(sock);
}
}
}
int rank() {
return rank_;
}
int size() {
return std::max(sockets_.size(), 1ul);
}
void send(const char* buf, size_t len, int dst) {
while (len > 0) {
ssize_t r = ::send(sockets_[dst], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
}
}
void recv(char* buf, size_t len, int src) {
while (len > 0) {
ssize_t r = ::recv(sockets_[src], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
}
}
template <typename T>
void send_recv_sum(char* buf, size_t len, int peer) {
char recv_buffer[2 * PACKET_SIZE];
char* recv_buffers[2];
recv_buffers[0] = recv_buffer;
recv_buffers[1] = recv_buffer + PACKET_SIZE;
std::future<void> sent, received;
size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE;
for (size_t b = 0; b < n_blocks; b++) {
if (b > 0) {
sent.wait();
received.wait();
}
size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE);
if (rank_ < peer) {
sent = send_async(buf + b * PACKET_SIZE, l, peer);
received = recv_async(recv_buffers[b % 2], l, peer);
} else {
received = recv_async(recv_buffers[b % 2], l, peer);
sent = send_async(buf + b * PACKET_SIZE, l, peer);
}
if (b > 0) {
sum_inplace(
(const T*)recv_buffers[(b - 1) % 2],
(T*)(buf + (b - 1) * PACKET_SIZE),
PACKET_SIZE / sizeof(T));
}
}
sent.wait();
received.wait();
size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE);
sum_inplace(
(const T*)recv_buffers[(n_blocks - 1) % 2],
(T*)(buf + (n_blocks - 1) * PACKET_SIZE),
l / sizeof(T));
}
void send_recv_sum(array& out, int peer) {
SWITCH_TYPE(out, send_recv_sum<T>(out.data<char>(), out.nbytes(), peer));
}
std::future<void> send_async(const char* buf, size_t len, int dst) {
return pool_.enqueue(
[this, buf, len, dst]() { this->send(buf, len, dst); });
}
std::future<void> recv_async(char* buf, size_t len, int src) {
return pool_.enqueue(
[this, buf, len, src]() { this->recv(buf, len, src); });
}
private:
int rank_;
bool global_;
ThreadPool pool_;
std::vector<int> sockets_;
};
} // namespace
bool is_available() {
return true;
}
int Group::rank() {
return std::static_pointer_cast<GroupImpl>(group_)->rank();
}
int Group::size() {
return std::static_pointer_cast<GroupImpl>(group_)->size();
}
Group Group::split(int color, int key) {
throw std::runtime_error("Splitting not supported yet");
}
void Group::barrier() {
char buff[128];
std::memset(buff, 1, 128);
auto group = std::static_pointer_cast<GroupImpl>(raw_group());
int size = group->size();
int rank = group->rank();
for (int distance = 1; distance <= size / 2; distance *= 2) {
group->send_recv_sum<char>(buff, 128, rank ^ distance);
}
}
Group init(bool strict /* = false */) {
static std::shared_ptr<GroupImpl> global_group = nullptr;
if (global_group == nullptr) {
auto peers = load_peers();
int rank = 0;
if (const char* rank_buf = std::getenv("MLX_RANK")) {
rank = std::atoi(rank_buf);
}
if (peers.size() == 0) {
if (strict) {
throw std::runtime_error("Can't initialize distributed");
}
}
global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
}
return Group(global_group);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_sum(Group group_, const array& input_, array& output) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
array input = ensure_row_contiguous(input_);
int size = group->size();
int rank = group->rank();
if ((size & (size - 1)) != 0) {
throw std::runtime_error("Only powers of 2 are currently supported");
}
// If this is not an in-place all-reduce, copy the input to the output first.
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
// Butterfly all reduce
for (int distance = 1; distance <= size / 2; distance *= 2) {
group->send_recv_sum(output, rank ^ distance);
}
}
void all_gather(Group group_, const array& input_, array& output) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
array input = ensure_row_contiguous(input_);
std::future<void> sent;
std::future<void> received;
int rank = group->rank();
int size = group->size();
if ((size & (size - 1)) != 0) {
throw std::runtime_error("Only powers of 2 are currently supported");
}
// Butterfly all gather
int peer = rank ^ 1;
if (peer < rank) {
received = group->recv_async(
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
} else {
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
received = group->recv_async(
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
}
std::memcpy(
output.data<char>() + rank * input.nbytes(),
input.data<char>(),
input.nbytes());
for (int distance = 2; distance <= size / 2; distance *= 2) {
sent.wait();
received.wait();
int peer = rank ^ distance;
int their_offset = peer & ~(distance - 1);
int our_offset = rank & ~(distance - 1);
if (peer < rank) {
received = group->recv_async(
output.data<char>() + their_offset * input.nbytes(),
distance * input.nbytes(),
peer);
sent = group->send_async(
output.data<char>() + our_offset * input.nbytes(),
distance * input.nbytes(),
peer);
} else {
sent = group->send_async(
output.data<char>() + our_offset * input.nbytes(),
distance * input.nbytes(),
peer);
received = group->recv_async(
output.data<char>() + their_offset * input.nbytes(),
distance * input.nbytes(),
peer);
}
}
sent.wait();
received.wait();
}
void send(Group group_, const array& input_, int dst) {
array input = ensure_row_contiguous(input_);
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
group->send(input.data<char>(), input.nbytes(), dst);
}
void recv(Group group_, array& out, int src) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
group->recv(out.data<char>(), out.nbytes(), src);
}
} // namespace detail
} // namespace mlx::core::distributed
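
For the sockets backend, the peer list comes from a JSON hostfile pointed to by MLX_HOSTFILE and each process reads its rank from MLX_RANK (see load_peers() and init() above). A minimal, hedged sketch of wiring this up from Python follows; the addresses are hypothetical, and the butterfly all-reduce requires the group size to be a power of two.

import json
import os

import mlx.core as mx

# Hypothetical two-rank setup on one machine; entry i is the address rank i listens on.
hosts = [
    {"ip": "127.0.0.1", "port": "5000"},
    {"ip": "127.0.0.1", "port": "5001"},
]
with open("hostfile.json", "w") as f:
    json.dump(hosts, f)

os.environ["MLX_HOSTFILE"] = "hostfile.json"
os.environ["MLX_RANK"] = "0"  # the second process would set "1"

group = mx.distributed.init()
# Every butterfly step exchanges with rank ^ distance, hence the power-of-two requirement.
y = mx.distributed.all_sum(mx.ones((4,)), group=group)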

View File

@@ -1,5 +1,5 @@
// Copyright © 2023 Apple Inc.
//
#include <json.hpp>
#include <stack>

View File

@@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
ConvTranspose2d,
ConvTranspose3d,
)
from mlx.nn.layers.distributed import (
AllToShardedLinear,
QuantizedAllToShardedLinear,
QuantizedShardedToAllLinear,
ShardedToAllLinear,
)
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear

View File

@@ -0,0 +1,456 @@
# Copyright © 2024 Apple Inc.
import math
from functools import lru_cache
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
@lru_cache
def sum_gradients(group):
if group.size() == 1:
return lambda x: x
@mx.custom_function
def f(x):
return x
@f.vjp
def f(x, dx, _):
return mx.distributed.all_sum(dx, group=group)
return f
class AllToShardedLinear(Module):
"""Each member of the group applies part of the affine transformation such
that the result is sharded across the group.
The gradients are automatically aggregated from each member of the group.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool, optional): If set to ``False`` the layer will not use a
bias. Default is ``True``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Initialize the parameters
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (output_dims % N) != 0:
raise ValueError(
f"Cannot shard the output of size {output_dims} across {N} devices."
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims // N, input_dims),
)
if bias:
self.bias = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims // N,),
)
def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
N = self.group.size()
out_dims *= N
return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
# Aggregate the gradients coming from each shard
if self.group.size() > 1:
x = sum_gradients(self.group)(x)
# Compute the affine projection
if "bias" in self:
x = mx.addmm(self["bias"], x, self["weight"].T)
else:
x = x @ self["weight"].T
return x
@classmethod
def from_linear(
cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = linear_layer.weight.shape
step = output_dims // N
sl = cls(input_dims, output_dims, False, group)
# The multiplication with 1.0 forces a copy, perhaps change to
# something better when available.
sl.weight = linear_layer.weight[r * step : (r + 1) * step] * 1
if "bias" in linear_layer:
sl.bias = linear_layer.bias[r * step : (r + 1) * step] * 1
return sl
class ShardedToAllLinear(Module):
"""Each member of the group applies part of the affine transformation and
then aggregates the results.
All nodes will have exactly the same result after this layer.
:class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to
convert linear layers to sharded :obj:`ShardedToAllLinear` layers.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool, optional): If set to ``False`` the layer will not use a
bias. Default is ``True``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Initialize the parameters
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (input_dims % N) != 0:
raise ValueError(
f"The input of size {input_dims} cannot be sharded across {N} devices."
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims // N),
)
if bias:
self.bias = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims,),
)
def _extra_repr(self) -> str:
N = self.group.size()
out_dims, in_dims = self.weight.shape
in_dims *= N
return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
if self.group.size() > 1:
# Perform the local projection and aggregate the results
x = x @ self["weight"].T
x = mx.distributed.all_sum(x, group=self.group)
# Add the bias if we have one
if "bias" in self:
x = x + self["bias"]
else:
# Normal linear layer as we are not in a distributed setting.
if "bias" in self:
x = mx.addmm(self["bias"], x, self["weight"].T)
else:
x = x @ self["weight"].T
return x
@classmethod
def from_linear(
cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = linear_layer.weight.shape
step = input_dims // N
sl = cls(input_dims, output_dims, False, group)
# The multiplication with 1.0 forces a copy, perhaps change to
# something better when available.
sl.weight = linear_layer.weight[:, r * step : (r + 1) * step] * 1
if "bias" in linear_layer:
sl.bias = linear_layer.bias
return sl
class QuantizedAllToShardedLinear(Module):
"""Each member of the group applies part of the affine transformation with
a quantized matrix such that the result is sharded across the group.
It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
will not be included in any gradient computation.
Args:
input_dims (int): The dimensionality of the input features.
output_dims (int): The dimensionality of the output features.
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. Default: ``True``.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (output_dims % N) != 0:
raise ValueError(
f"Cannot shard the output of size {output_dims} across {N} devices."
)
weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims // N, input_dims),
)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
# And bias if needed
if bias:
self.bias = mx.zeros((output_dims // N,))
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.bits
out_dims *= self.group.size()
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)
def __call__(self, x: mx.array) -> mx.array:
# Aggregate the gradients coming from each shard
if self.group.size() > 1:
x = sum_gradients(self.group)(x)
x = mx.quantized_matmul(
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_quantized_linear(
cls,
quantized_linear_layer: Module,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = quantized_linear_layer.weight.shape
input_dims *= 32 // quantized_linear_layer.bits
step = output_dims // N
sl = cls(
input_dims,
output_dims,
False,
group_size=quantized_linear_layer.group_size,
bits=quantized_linear_layer.bits,
group=group,
)
sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
if "bias" in quantized_linear_layer:
sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
return sl
class QuantizedShardedToAllLinear(Module):
"""Each member of the group applies part of the affine transformation using
the quantized matrix and then aggregates the results.
All nodes will have exactly the same result after this layer.
It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
will not be included in any gradient computation.
Args:
input_dims (int): The dimensionality of the input features.
output_dims (int): The dimensionality of the output features.
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. Default: ``True``.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (input_dims % N) != 0:
raise ValueError(
f"The input of size {input_dims} cannot be sharded across {N} devices."
)
weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims // N),
)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
# And bias if needed
if bias:
self.bias = mx.zeros((output_dims,))
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
in_dims *= (32 // self.bits) * self.group.size()
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)
def __call__(self, x: mx.array) -> mx.array:
x = mx.quantized_matmul(
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if self.group.size() > 1:
x = mx.distributed.all_sum(x, group=self.group)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_quantized_linear(
cls,
quantized_linear_layer: Module,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = quantized_linear_layer.weight.shape
step = input_dims // N
step_grouped = quantized_linear_layer.scales.shape[1] // N
input_dims *= (32 // quantized_linear_layer.bits) * N
sl = cls(
input_dims,
output_dims,
False,
group_size=quantized_linear_layer.group_size,
bits=quantized_linear_layer.bits,
group=group,
)
sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
sl.scales = (
quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
* 1
)
sl.biases = (
quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
* 1
)
if "bias" in quantized_linear_layer:
sl.bias = quantized_linear_layer.bias
return sl
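
A short, hedged usage sketch for the layers above: existing dense layers can be converted to their sharded counterparts with the from_linear / from_quantized_linear classmethods. The dimensions below are hypothetical; the sharded dimension must be divisible by the group size.

import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()

# A two-layer block where the hidden dimension is sharded across the group.
fc1 = nn.Linear(512, 1024)
fc2 = nn.Linear(1024, 512)
up = nn.AllToShardedLinear.from_linear(fc1, group=group)    # shards the 1024 output features
down = nn.ShardedToAllLinear.from_linear(fc2, group=group)  # consumes the local shard, then all-sums

x = mx.random.normal(shape=(2, 512))
h = up(x)    # shape (2, 1024 // group.size()) on each rank
y = down(h)  # shape (2, 512), identical on every rank

# The quantized variants are converted the same way.
qlinear = nn.QuantizedLinear(512, 1024)
qup = nn.QuantizedAllToShardedLinear.from_quantized_linear(qlinear, group=group)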

View File

@@ -197,7 +197,7 @@ class QuantizedLinear(Module):
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.bits
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)

View File

@@ -44,7 +44,8 @@ void init_distributed(nb::module_& parent_module) {
color (int): A value to group processes into subgroups.
key (int, optional): A key to optionally change the rank ordering
of the processes.
)pbdoc");
)pbdoc")
.def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group");
m.def(
"is_available",