Compare commits

..

2 Commits

Author SHA1 Message Date
Alex Barron
726dbd9267 v0.20.0 (#1565) 2024-11-05 12:37:57 -08:00
Awni Hannun
54f05e7195 Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
12 changed files with 91 additions and 588 deletions

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.19.3)
set(MLX_VERSION 0.20.0)
endif()
# --------------------- Processor tests -------------------------

View File

@@ -1,16 +1,8 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
if (MLX_BUILD_CPU)
if (MLX_CUSTOM_DISTRIBUTED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
elseif (MPI_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
)
endif()
if(MPI_FOUND AND MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
endif()

View File

@@ -32,8 +32,6 @@ struct Group {
*/
Group split(int color, int key = -1);
void barrier();
const std::shared_ptr<void>& raw_group() {
return group_;
}

View File

@@ -71,7 +71,6 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Barrier, barrier);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
@@ -196,7 +195,6 @@ struct MPIWrapper {
int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
int (*barrier)(MPI_Comm);
// Objects
MPI_Comm comm_world_;
@@ -265,10 +263,6 @@ struct MPIGroupImpl {
return size_;
}
void barrier() {
mpi().barrier(comm_);
}
private:
MPI_Comm comm_;
bool global_;
@@ -304,11 +298,6 @@ Group Group::split(int color, int key) {
return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}
void Group::barrier() {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
mpi_group->barrier();
}
bool is_available() {
return mpi().is_available();
}

View File

@@ -17,8 +17,6 @@ Group Group::split(int color, int key) {
throw std::runtime_error("Cannot split the distributed group further");
}
void Group::barrier() {}
bool is_available() {
return false;
}

View File

@@ -1,5 +0,0 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/sockets.cpp
)

View File

@@ -1,522 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <json.hpp>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
constexpr const size_t PACKET_SIZE = 262144;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
using json = nlohmann::json;
namespace mlx::core::distributed {
namespace {
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
void sum_inplace(const array& input, array& output) {
SWITCH_TYPE(
input, sum_inplace(input.data<T>(), output.data<T>(), input.size()));
}
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* sockaddr() {
return (struct sockaddr*)&addr;
}
};
address_t parse_address(std::string ip, std::string port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse peer address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
std::vector<address_t> load_peers() {
std::vector<address_t> peers;
std::ifstream f;
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
f.open(hostfile_buf);
} else {
return peers;
}
json hosts = json::parse(f);
for (auto& h : hosts) {
peers.push_back(std::move(parse_address(
h["ip"].template get<std::string>(),
h["port"].template get<std::string>())));
}
return peers;
}
struct GroupImpl {
GroupImpl(std::vector<address_t> peers, int rank, bool global)
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
if (rank_ > 0 && rank_ >= peers.size()) {
throw std::runtime_error(
"Rank cannot be larger than the size of the group");
}
int success;
// If we are expecting anyone to connect to us
if (rank_ + 1 < peers.size()) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseaddr (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseport (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind it to the port
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets_[peers.size() - 1 - i] = peer_socket;
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
}
// Connect to the peers with smaller rank
for (int i = 0; i < rank_; i++) {
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
if (sockets_[i] < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
}
~GroupImpl() {
if (global_) {
for (int sock : sockets_) {
shutdown(sock, 2);
close(sock);
}
}
}
int rank() {
return rank_;
}
int size() {
return std::max(sockets_.size(), 1ul);
}
void send(const char* buf, size_t len, int dst) {
while (len > 0) {
ssize_t r = ::send(sockets_[dst], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
}
}
void recv(char* buf, size_t len, int src) {
while (len > 0) {
ssize_t r = ::recv(sockets_[src], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
}
}
template <typename T>
void send_recv_sum(char* buf, size_t len, int peer) {
char recv_buffer[2 * PACKET_SIZE];
char* recv_buffers[2];
recv_buffers[0] = recv_buffer;
recv_buffers[1] = recv_buffer + PACKET_SIZE;
std::future<void> sent, received;
size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE;
for (size_t b = 0; b < n_blocks; b++) {
if (b > 0) {
sent.wait();
received.wait();
}
size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE);
if (rank_ < peer) {
sent = send_async(buf + b * PACKET_SIZE, l, peer);
received = recv_async(recv_buffers[b % 2], l, peer);
} else {
received = recv_async(recv_buffers[b % 2], l, peer);
sent = send_async(buf + b * PACKET_SIZE, l, peer);
}
if (b > 0) {
sum_inplace(
(const T*)recv_buffers[(b - 1) % 2],
(T*)(buf + (b - 1) * PACKET_SIZE),
PACKET_SIZE / sizeof(T));
}
}
sent.wait();
received.wait();
size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE);
sum_inplace(
(const T*)recv_buffers[(n_blocks - 1) % 2],
(T*)(buf + (n_blocks - 1) * PACKET_SIZE),
l / sizeof(T));
}
void send_recv_sum(array& out, int peer) {
SWITCH_TYPE(out, send_recv_sum<T>(out.data<char>(), out.nbytes(), peer));
}
std::future<void> send_async(const char* buf, size_t len, int dst) {
return pool_.enqueue(
[this, buf, len, dst]() { this->send(buf, len, dst); });
}
std::future<void> recv_async(char* buf, size_t len, int src) {
return pool_.enqueue(
[this, buf, len, src]() { this->recv(buf, len, src); });
}
private:
int rank_;
bool global_;
ThreadPool pool_;
std::vector<int> sockets_;
};
} // namespace
bool is_available() {
return true;
}
int Group::rank() {
return std::static_pointer_cast<GroupImpl>(group_)->rank();
}
int Group::size() {
return std::static_pointer_cast<GroupImpl>(group_)->size();
}
Group Group::split(int color, int key) {
throw std::runtime_error("Splitting not supported yet");
}
void Group::barrier() {
char buff[128];
std::memset(buff, 1, 128);
auto group = std::static_pointer_cast<GroupImpl>(raw_group());
int size = group->size();
int rank = group->rank();
for (int distance = 1; distance <= size / 2; distance *= 2) {
group->send_recv_sum<char>(buff, 128, rank ^ distance);
}
}
Group init(bool strict /* = false */) {
static std::shared_ptr<GroupImpl> global_group = nullptr;
if (global_group == nullptr) {
auto peers = load_peers();
int rank = 0;
if (const char* rank_buf = std::getenv("MLX_RANK")) {
rank = std::atoi(rank_buf);
}
if (peers.size() == 0) {
if (strict) {
throw std::runtime_error("Can't initialize distributed");
}
}
global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
}
return Group(global_group);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_sum(Group group_, const array& input_, array& output) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
array input = ensure_row_contiguous(input_);
int size = group->size();
int rank = group->rank();
if ((size & (size - 1)) != 0) {
throw std::runtime_error("Only powers of 2 are currently supported");
}
// If not inplace all reduce then copy the input to the output first.
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
// Butterfly all reduce
for (int distance = 1; distance <= size / 2; distance *= 2) {
group->send_recv_sum(output, rank ^ distance);
}
}
void all_gather(Group group_, const array& input_, array& output) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
array input = ensure_row_contiguous(input_);
std::future<void> sent;
std::future<void> received;
int rank = group->rank();
int size = group->size();
if ((size & (size - 1)) != 0) {
throw std::runtime_error("Only powers of 2 are currently supported");
}
// Butterfly all gather
int peer = rank ^ 1;
if (peer < rank) {
received = group->recv_async(
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
} else {
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
received = group->recv_async(
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
}
std::memcpy(
output.data<char>() + rank * input.nbytes(),
input.data<char>(),
input.nbytes());
for (int distance = 2; distance <= size / 2; distance *= 2) {
sent.wait();
received.wait();
int peer = rank ^ distance;
int their_offset = peer & ~(distance - 1);
int our_offset = rank & ~(distance - 1);
if (peer < rank) {
received = group->recv_async(
output.data<char>() + their_offset * input.nbytes(),
distance * input.nbytes(),
peer);
sent = group->send_async(
output.data<char>() + our_offset * input.nbytes(),
distance * input.nbytes(),
peer);
} else {
sent = group->send_async(
output.data<char>() + our_offset * input.nbytes(),
distance * input.nbytes(),
peer);
received = group->recv_async(
output.data<char>() + their_offset * input.nbytes(),
distance * input.nbytes(),
peer);
}
}
sent.wait();
received.wait();
}
void send(Group group_, const array& input_, int dst) {
array input = ensure_row_contiguous(input_);
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
group->send(input.data<char>(), input.nbytes(), dst);
}
void recv(Group group_, array& out, int src) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
group->recv(out.data<char>(), out.nbytes(), src);
}
} // namespace detail
} // namespace mlx::core::distributed

View File

@@ -1,5 +1,5 @@
// Copyright © 2023 Apple Inc.
//
#include <json.hpp>
#include <stack>

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -1683,48 +1682,58 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
auto gather_axes = axes_;
auto slice_sizes = slice_sizes_;
auto src_vmapped = axes[0] >= 0;
auto indices_vmapped =
std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
auto out_ax =
*std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; });
auto ind_vmap_ax_ptr =
std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; });
int out_ax = -1;
bool indices_vmapped = (ind_vmap_ax_ptr != axes.end());
if (indices_vmapped) {
out_ax = *ind_vmap_ax_ptr;
} else if (src_vmapped) {
out_ax = axes[0];
}
// Reorder all the index arrays so the vmap axis is in the same spot.
for (int i = 1; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
if (indices_vmapped) {
for (int i = 1; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream());
} else if (axes[i] < 0) {
indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream());
}
}
}
int idx_dims = indices.empty() ? 0 : indices[0].ndim();
if (src_vmapped) {
int max_dims = 0;
for (auto& idx : indices) {
max_dims = std::max(static_cast<int>(idx.ndim()), max_dims);
}
auto new_ax_loc =
std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) {
return a >= out_ax;
});
for (; new_ax_loc < gather_axes.end(); new_ax_loc++) {
(*new_ax_loc)++;
for (auto& ax : gather_axes) {
if (ax >= axes[0]) {
ax++;
}
}
if (indices_vmapped) {
// Make a new index array for the vmapped dimension
auto vmap_inds = arange(0, src.shape(axes[0]), stream());
// Reshape it so it broadcasts with other index arrays
{
auto shape = std::vector<int>(idx_dims, 1);
shape[out_ax] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(shape), stream());
}
// Update gather axes and slice sizes accordingly
auto shape = std::vector<int>(max_dims - out_ax, 1);
auto vmap_inds = arange(0, src.shape(out_ax), stream());
shape[0] = vmap_inds.shape(0);
vmap_inds = reshape(vmap_inds, shape, stream());
slice_sizes.insert(slice_sizes.begin() + out_ax, 1);
auto new_ax_idx = new_ax_loc - gather_axes.begin();
gather_axes.insert(new_ax_loc, out_ax);
indices.insert(indices.begin() + new_ax_idx, vmap_inds);
slice_sizes.insert(slice_sizes.begin() + axes[0], 1);
gather_axes.push_back(axes[0]);
indices.push_back(vmap_inds);
} else {
slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0]));
out_ax = max_dims + axes[0];
slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax));
out_ax += idx_dims;
}
}
return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}};
auto out = gather(src, indices, gather_axes, slice_sizes, stream());
if (src_vmapped && indices_vmapped) {
out = squeeze(out, idx_dims + axes[0], stream());
}
return {{out}, {out_ax}};
}
std::vector<array> Gather::vjp(

View File

@@ -44,8 +44,7 @@ void init_distributed(nb::module_& parent_module) {
color (int): A value to group processes into subgroups.
key (int, optional): A key to optionally change the rank ordering
of the processes.
)pbdoc")
.def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group");
)pbdoc");
m.def(
"is_available",

View File

@@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
)
def test_vmap_gather(self):
def gather(a, idx):
return a[idx]
a = mx.array([[1, 2], [3, 4]])
idx = mx.array(0)
out = mx.vmap(gather, (0, None))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 3])))
out = mx.vmap(gather, (1, None))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 2])))
idx = mx.array([0, 1])
out = mx.vmap(gather, (0, 0))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 4])))
a = mx.ones((2, 3, 4))
idx = mx.zeros(4, mx.int32)
out = mx.vmap(gather, (2, 0))(a, idx)
self.assertEqual(out.shape, (4, 3))
f = mx.vmap(gather, (0, None))
f = mx.vmap(gather, (0, 0))
out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))
self.assertEqual(out.shape, (2, 4))
def gather(a, idxa, idxb):
return a[idxa, idxb]
a = mx.ones((2, 3, 4))
idxa = mx.zeros((2, 3), mx.int32)
idxb = mx.zeros(3, mx.int32)
out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3))
idxa = mx.zeros((3, 1, 2), mx.int32)
idxb = mx.zeros((2, 3, 1, 2), mx.int32)
out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3, 1, 2))
idxa = mx.zeros((3, 1, 2), mx.int32)
idxb = mx.zeros((3, 1, 2, 2), mx.int32)
out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3, 1, 2))
def test_vmap_scatter(self):
def scatter(a):
a[mx.array(0)] = mx.array(0.0)

View File

@@ -165,7 +165,7 @@ if __name__ == "__main__":
setup(
name="mlx",
version=get_version("0.19.3"),
version=get_version("0.20.0"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",