mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	TCP socket distributed
This commit is contained in:
		@@ -5,14 +5,77 @@
 | 
				
			|||||||
#include <netdb.h>
 | 
					#include <netdb.h>
 | 
				
			||||||
#include <sys/socket.h>
 | 
					#include <sys/socket.h>
 | 
				
			||||||
#include <unistd.h>
 | 
					#include <unistd.h>
 | 
				
			||||||
 | 
					#include <chrono>
 | 
				
			||||||
#include <cstdlib>
 | 
					#include <cstdlib>
 | 
				
			||||||
#include <fstream>
 | 
					#include <fstream>
 | 
				
			||||||
#include <iostream>
 | 
					#include <iostream>
 | 
				
			||||||
#include <sstream>
 | 
					#include <sstream>
 | 
				
			||||||
 | 
					#include <thread>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/backend/common/copy.h"
 | 
					#include "mlx/backend/common/copy.h"
 | 
				
			||||||
#include "mlx/distributed/distributed.h"
 | 
					#include "mlx/distributed/distributed.h"
 | 
				
			||||||
#include "mlx/distributed/distributed_impl.h"
 | 
					#include "mlx/distributed/distributed_impl.h"
 | 
				
			||||||
 | 
					#include "mlx/io/threadpool.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define SWITCH_TYPE(x, ...)  \
 | 
				
			||||||
 | 
					  switch ((x).dtype()) {     \
 | 
				
			||||||
 | 
					    case bool_: {            \
 | 
				
			||||||
 | 
					      using T = bool;        \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case int8: {             \
 | 
				
			||||||
 | 
					      using T = int8_t;      \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case int16: {            \
 | 
				
			||||||
 | 
					      using T = int16_t;     \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case int32: {            \
 | 
				
			||||||
 | 
					      using T = int32_t;     \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case int64: {            \
 | 
				
			||||||
 | 
					      using T = int64_t;     \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case uint8: {            \
 | 
				
			||||||
 | 
					      using T = uint8_t;     \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case uint16: {           \
 | 
				
			||||||
 | 
					      using T = uint16_t;    \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case uint32: {           \
 | 
				
			||||||
 | 
					      using T = uint32_t;    \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case uint64: {           \
 | 
				
			||||||
 | 
					      using T = uint64_t;    \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case bfloat16: {         \
 | 
				
			||||||
 | 
					      using T = bfloat16_t;  \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case float16: {          \
 | 
				
			||||||
 | 
					      using T = float16_t;   \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case float32: {          \
 | 
				
			||||||
 | 
					      using T = float;       \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					    case complex64: {        \
 | 
				
			||||||
 | 
					      using T = complex64_t; \
 | 
				
			||||||
 | 
					      __VA_ARGS__;           \
 | 
				
			||||||
 | 
					    } break;                 \
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					constexpr const size_t PACKET_SIZE = 262144;
 | 
				
			||||||
 | 
					constexpr const int CONN_ATTEMPTS = 5;
 | 
				
			||||||
 | 
					constexpr const int CONN_WAIT = 1000;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using json = nlohmann::json;
 | 
					using json = nlohmann::json;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -30,46 +93,8 @@ void sum_inplace(const T* input, T* output, size_t N) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void sum_inplace(const array& input, array& output) {
 | 
					void sum_inplace(const array& input, array& output) {
 | 
				
			||||||
  switch (input.dtype()) {
 | 
					  SWITCH_TYPE(
 | 
				
			||||||
    case bool_:
 | 
					      input, sum_inplace(input.data<T>(), output.data<T>(), input.size()));
 | 
				
			||||||
      return sum_inplace(input.data<bool>(), output.data<bool>(), input.size());
 | 
					 | 
				
			||||||
    case int8:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<int8_t>(), output.data<int8_t>(), input.size());
 | 
					 | 
				
			||||||
    case uint8:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<uint8_t>(), output.data<uint8_t>(), input.size());
 | 
					 | 
				
			||||||
    case int16:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<int16_t>(), output.data<int16_t>(), input.size());
 | 
					 | 
				
			||||||
    case uint16:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<uint16_t>(), output.data<uint16_t>(), input.size());
 | 
					 | 
				
			||||||
    case int32:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<int32_t>(), output.data<int32_t>(), input.size());
 | 
					 | 
				
			||||||
    case uint32:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<uint32_t>(), output.data<uint32_t>(), input.size());
 | 
					 | 
				
			||||||
    case int64:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<int64_t>(), output.data<int64_t>(), input.size());
 | 
					 | 
				
			||||||
    case uint64:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<uint64_t>(), output.data<uint64_t>(), input.size());
 | 
					 | 
				
			||||||
    case float16:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<float16_t>(), output.data<float16_t>(), input.size());
 | 
					 | 
				
			||||||
    case bfloat16:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<bfloat16_t>(), output.data<bfloat16_t>(), input.size());
 | 
					 | 
				
			||||||
    case float32:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<float>(), output.data<float>(), input.size());
 | 
					 | 
				
			||||||
    case complex64:
 | 
					 | 
				
			||||||
      return sum_inplace(
 | 
					 | 
				
			||||||
          input.data<complex64_t>(), output.data<complex64_t>(), input.size());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
array ensure_row_contiguous(const array& arr) {
 | 
					array ensure_row_contiguous(const array& arr) {
 | 
				
			||||||
@@ -95,7 +120,7 @@ address_t parse_address(std::string ip, std::string port) {
 | 
				
			|||||||
  struct addrinfo hints, *res;
 | 
					  struct addrinfo hints, *res;
 | 
				
			||||||
  memset(&hints, 0, sizeof(hints));
 | 
					  memset(&hints, 0, sizeof(hints));
 | 
				
			||||||
  hints.ai_family = AF_UNSPEC;
 | 
					  hints.ai_family = AF_UNSPEC;
 | 
				
			||||||
  hints.ai_socktype = SOCK_DGRAM;
 | 
					  hints.ai_socktype = SOCK_STREAM;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
 | 
					  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
 | 
				
			||||||
  if (status != 0) {
 | 
					  if (status != 0) {
 | 
				
			||||||
@@ -134,30 +159,118 @@ std::vector<address_t> load_peers() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
struct GroupImpl {
 | 
					struct GroupImpl {
 | 
				
			||||||
  GroupImpl(std::vector<address_t> peers, int rank, bool global)
 | 
					  GroupImpl(std::vector<address_t> peers, int rank, bool global)
 | 
				
			||||||
      : peers_(std::move(peers)), rank_(rank), global_(global) {
 | 
					      : rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
 | 
				
			||||||
    if (rank_ > 0 && rank_ >= peers_.size()) {
 | 
					    if (rank_ > 0 && rank_ >= peers.size()) {
 | 
				
			||||||
      throw std::runtime_error(
 | 
					      throw std::runtime_error(
 | 
				
			||||||
          "Rank cannot be larger than the size of the group");
 | 
					          "Rank cannot be larger than the size of the group");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (global_ && rank_ < peers_.size()) {
 | 
					
 | 
				
			||||||
      socket_fd_ = socket(AF_INET, SOCK_DGRAM, 0);
 | 
					    int success;
 | 
				
			||||||
      if (socket_fd_ < 0) {
 | 
					
 | 
				
			||||||
 | 
					    // If we are expecting anyone to connect to us
 | 
				
			||||||
 | 
					    if (rank_ < peers.size() - 1) {
 | 
				
			||||||
 | 
					      // Create the socket to wait for connections from the peers
 | 
				
			||||||
 | 
					      int sock = socket(AF_INET, SOCK_STREAM, 0);
 | 
				
			||||||
 | 
					      if (sock < 0) {
 | 
				
			||||||
        std::ostringstream msg;
 | 
					        std::ostringstream msg;
 | 
				
			||||||
        msg << "Couldn't create socket (error: " << errno << ")";
 | 
					        msg << "Couldn't create socket (error: " << errno << ")";
 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      int success =
 | 
					
 | 
				
			||||||
          bind(socket_fd_, peers_[rank_].sockaddr(), peers_[rank_].len);
 | 
					      // Make sure we can launch immediately after shutdown by setting the
 | 
				
			||||||
 | 
					      // reuseaddr option so that we don't get address already in use errors
 | 
				
			||||||
 | 
					      int enable = 1;
 | 
				
			||||||
 | 
					      success =
 | 
				
			||||||
 | 
					          setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
 | 
				
			||||||
 | 
					      if (success < 0) {
 | 
				
			||||||
 | 
					        shutdown(sock, 2);
 | 
				
			||||||
 | 
					        close(sock);
 | 
				
			||||||
 | 
					        std::ostringstream msg;
 | 
				
			||||||
 | 
					        msg << "Couldn't enable reuseaddr (rank: " << rank_
 | 
				
			||||||
 | 
					            << " error: " << errno << ")";
 | 
				
			||||||
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      success =
 | 
				
			||||||
 | 
					          setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
 | 
				
			||||||
 | 
					      if (success < 0) {
 | 
				
			||||||
 | 
					        shutdown(sock, 2);
 | 
				
			||||||
 | 
					        close(sock);
 | 
				
			||||||
 | 
					        std::ostringstream msg;
 | 
				
			||||||
 | 
					        msg << "Couldn't enable reuseport (rank: " << rank_
 | 
				
			||||||
 | 
					            << " error: " << errno << ")";
 | 
				
			||||||
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Bind it to the port
 | 
				
			||||||
 | 
					      success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
 | 
				
			||||||
 | 
					      if (success < 0) {
 | 
				
			||||||
 | 
					        shutdown(sock, 2);
 | 
				
			||||||
 | 
					        close(sock);
 | 
				
			||||||
 | 
					        std::ostringstream msg;
 | 
				
			||||||
 | 
					        msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
 | 
				
			||||||
 | 
					            << ")";
 | 
				
			||||||
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Wait for connections
 | 
				
			||||||
 | 
					      success = listen(sock, 0);
 | 
				
			||||||
 | 
					      if (success < 0) {
 | 
				
			||||||
 | 
					        shutdown(sock, 2);
 | 
				
			||||||
 | 
					        close(sock);
 | 
				
			||||||
 | 
					        std::ostringstream msg;
 | 
				
			||||||
 | 
					        msg << "Couldn't listen (error: " << errno << ")";
 | 
				
			||||||
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      for (int i = 0; i < peers.size() - rank_ - 1; i++) {
 | 
				
			||||||
 | 
					        int peer_socket = accept(sock, nullptr, nullptr);
 | 
				
			||||||
 | 
					        if (peer_socket < 0) {
 | 
				
			||||||
 | 
					          shutdown(sock, 2);
 | 
				
			||||||
 | 
					          close(sock);
 | 
				
			||||||
 | 
					          std::ostringstream msg;
 | 
				
			||||||
 | 
					          msg << "Accept failed (error: " << errno << ")";
 | 
				
			||||||
 | 
					          throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        sockets_[peers.size() - 1 - i] = peer_socket;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Close the listening socket
 | 
				
			||||||
 | 
					      shutdown(sock, 2);
 | 
				
			||||||
 | 
					      close(sock);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Connect to the peers with smaller rank
 | 
				
			||||||
 | 
					    for (int i = 0; i < rank_; i++) {
 | 
				
			||||||
 | 
					      sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
 | 
				
			||||||
 | 
					      if (sockets_[i] < 0) {
 | 
				
			||||||
 | 
					        std::ostringstream msg;
 | 
				
			||||||
 | 
					        msg << "Couldn't create socket (error: " << errno << ")";
 | 
				
			||||||
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
 | 
				
			||||||
 | 
					        if (attempt > 0) {
 | 
				
			||||||
 | 
					          int wait = (1 << (attempt - 1)) * CONN_WAIT;
 | 
				
			||||||
 | 
					          std::this_thread::sleep_for(std::chrono::milliseconds(wait));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
 | 
				
			||||||
 | 
					        if (success == 0) {
 | 
				
			||||||
 | 
					          break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      if (success < 0) {
 | 
					      if (success < 0) {
 | 
				
			||||||
        std::ostringstream msg;
 | 
					        std::ostringstream msg;
 | 
				
			||||||
        msg << "Couldn't bind socket (error: " << errno << ")";
 | 
					        msg << "Couldn't connect (rank: " << rank_ << " to: " << i
 | 
				
			||||||
 | 
					            << " error: " << errno << ")";
 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  ~GroupImpl() {
 | 
					  ~GroupImpl() {
 | 
				
			||||||
    if (global_) {
 | 
					    if (global_) {
 | 
				
			||||||
      close(socket_fd_);
 | 
					      for (int sock : sockets_) {
 | 
				
			||||||
 | 
					        shutdown(sock, 2);
 | 
				
			||||||
 | 
					        close(sock);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -166,43 +279,92 @@ struct GroupImpl {
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  int size() {
 | 
					  int size() {
 | 
				
			||||||
    return std::max(peers_.size(), 1ul);
 | 
					    return std::max(sockets_.size(), 1ul);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void send(const char* buf, size_t len, int dst) {
 | 
					  void send(const char* buf, size_t len, int dst) {
 | 
				
			||||||
    while (len > 0) {
 | 
					    while (len > 0) {
 | 
				
			||||||
      size_t l = std::min(len, 8192ul);
 | 
					      ssize_t r = ::send(sockets_[dst], buf, len, 0);
 | 
				
			||||||
      ssize_t r = sendto(
 | 
					 | 
				
			||||||
          socket_fd_, buf, l, 0, peers_[dst].sockaddr(), peers_[dst].len);
 | 
					 | 
				
			||||||
      if (r <= 0) {
 | 
					      if (r <= 0) {
 | 
				
			||||||
        std::ostringstream msg;
 | 
					        std::ostringstream msg;
 | 
				
			||||||
        msg << "Send of " << l << " bytes failed (errno: " << errno << ")";
 | 
					        msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
 | 
				
			||||||
        throw std::runtime_error(msg.str());
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      len -= l;
 | 
					 | 
				
			||||||
      buf += l;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  void recv(char* buf, size_t len, int src) {
 | 
					 | 
				
			||||||
    sockaddr_storage addr;
 | 
					 | 
				
			||||||
    socklen_t addr_len;
 | 
					 | 
				
			||||||
    while (len != 0) {
 | 
					 | 
				
			||||||
      ssize_t r =
 | 
					 | 
				
			||||||
          recvfrom(socket_fd_, buf, len, 0, (struct sockaddr*)&addr, &addr_len);
 | 
					 | 
				
			||||||
      if (r <= 0) {
 | 
					 | 
				
			||||||
        throw std::runtime_error("Recv failed");
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      buf += r;
 | 
					      buf += r;
 | 
				
			||||||
      len -= r;
 | 
					      len -= r;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void recv(char* buf, size_t len, int src) {
 | 
				
			||||||
 | 
					    while (len > 0) {
 | 
				
			||||||
 | 
					      ssize_t r = ::recv(sockets_[src], buf, len, 0);
 | 
				
			||||||
 | 
					      if (r <= 0) {
 | 
				
			||||||
 | 
					        std::ostringstream msg;
 | 
				
			||||||
 | 
					        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
 | 
				
			||||||
 | 
					        throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      buf += r;
 | 
				
			||||||
 | 
					      len -= r;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  template <typename T>
 | 
				
			||||||
 | 
					  void send_recv_sum(char* buf, size_t len, int peer) {
 | 
				
			||||||
 | 
					    char recv_buffer[2 * PACKET_SIZE];
 | 
				
			||||||
 | 
					    char* recv_buffers[2];
 | 
				
			||||||
 | 
					    recv_buffers[0] = recv_buffer;
 | 
				
			||||||
 | 
					    recv_buffers[1] = recv_buffer + PACKET_SIZE;
 | 
				
			||||||
 | 
					    std::future<void> sent, received;
 | 
				
			||||||
 | 
					    size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (size_t b = 0; b < n_blocks; b++) {
 | 
				
			||||||
 | 
					      if (b > 0) {
 | 
				
			||||||
 | 
					        sent.wait();
 | 
				
			||||||
 | 
					        received.wait();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE);
 | 
				
			||||||
 | 
					      if (rank_ < peer) {
 | 
				
			||||||
 | 
					        sent = send_async(buf + b * PACKET_SIZE, l, peer);
 | 
				
			||||||
 | 
					        received = recv_async(recv_buffers[b % 2], l, peer);
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        received = recv_async(recv_buffers[b % 2], l, peer);
 | 
				
			||||||
 | 
					        sent = send_async(buf + b * PACKET_SIZE, l, peer);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if (b > 0) {
 | 
				
			||||||
 | 
					        sum_inplace(
 | 
				
			||||||
 | 
					            (const T*)recv_buffers[(b - 1) % 2],
 | 
				
			||||||
 | 
					            (T*)(buf + (b - 1) * PACKET_SIZE),
 | 
				
			||||||
 | 
					            PACKET_SIZE / sizeof(T));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    sent.wait();
 | 
				
			||||||
 | 
					    received.wait();
 | 
				
			||||||
 | 
					    size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE);
 | 
				
			||||||
 | 
					    sum_inplace(
 | 
				
			||||||
 | 
					        (const T*)recv_buffers[(n_blocks - 1) % 2],
 | 
				
			||||||
 | 
					        (T*)(buf + (n_blocks - 1) * PACKET_SIZE),
 | 
				
			||||||
 | 
					        l / sizeof(T));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void send_recv_sum(array& out, int peer) {
 | 
				
			||||||
 | 
					    SWITCH_TYPE(out, send_recv_sum<T>(out.data<char>(), out.nbytes(), peer));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::future<void> send_async(const char* buf, size_t len, int dst) {
 | 
				
			||||||
 | 
					    return pool_.enqueue(
 | 
				
			||||||
 | 
					        [this, buf, len, dst]() { this->send(buf, len, dst); });
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::future<void> recv_async(char* buf, size_t len, int src) {
 | 
				
			||||||
 | 
					    return pool_.enqueue(
 | 
				
			||||||
 | 
					        [this, buf, len, src]() { this->recv(buf, len, src); });
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  std::vector<address_t> peers_;
 | 
					 | 
				
			||||||
  int rank_;
 | 
					  int rank_;
 | 
				
			||||||
  bool global_;
 | 
					  bool global_;
 | 
				
			||||||
  int socket_fd_;
 | 
					  ThreadPool pool_;
 | 
				
			||||||
 | 
					  std::vector<int> sockets_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace
 | 
					} // namespace
 | 
				
			||||||
@@ -251,57 +413,84 @@ Stream communication_stream() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
void all_sum(Group group_, const array& input_, array& output) {
 | 
					void all_sum(Group group_, const array& input_, array& output) {
 | 
				
			||||||
  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
 | 
					  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
 | 
				
			||||||
  if (group->size() != 2) {
 | 
					 | 
				
			||||||
    throw std::runtime_error("Only pairwise communication supported for now");
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  array input = ensure_row_contiguous(input_);
 | 
					  array input = ensure_row_contiguous(input_);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Donation not supported
 | 
					  int size = group->size();
 | 
				
			||||||
  if (input.data<void>() == output.data<void>()) {
 | 
					  int rank = group->rank();
 | 
				
			||||||
    array temp(
 | 
					
 | 
				
			||||||
        allocator::malloc_or_wait(output.nbytes()),
 | 
					  if ((size & (size - 1)) != 0) {
 | 
				
			||||||
        output.shape(),
 | 
					    throw std::runtime_error("Only powers of 2 are currently supported");
 | 
				
			||||||
        output.dtype());
 | 
					  }
 | 
				
			||||||
    if (group->rank() == 0) {
 | 
					
 | 
				
			||||||
      group->send(input.data<char>(), input.nbytes(), 1);
 | 
					  // If not inplace all reduce then copy the input to the output first.
 | 
				
			||||||
      group->recv(temp.data<char>(), output.nbytes(), 1);
 | 
					  if (input.data<void>() != output.data<void>()) {
 | 
				
			||||||
      sum_inplace(temp, output);
 | 
					    std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
 | 
				
			||||||
    } else {
 | 
					  }
 | 
				
			||||||
      group->recv(temp.data<char>(), output.nbytes(), 0);
 | 
					
 | 
				
			||||||
      group->send(input.data<char>(), input.nbytes(), 0);
 | 
					  // Butterfly all reduce
 | 
				
			||||||
      sum_inplace(temp, output);
 | 
					  for (int distance = 1; distance <= size / 2; distance *= 2) {
 | 
				
			||||||
    }
 | 
					    group->send_recv_sum(output, rank ^ distance);
 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    if (group->rank() == 0) {
 | 
					 | 
				
			||||||
      group->send(input.data<char>(), input.nbytes(), 1);
 | 
					 | 
				
			||||||
      group->recv(output.data<char>(), output.nbytes(), 1);
 | 
					 | 
				
			||||||
      sum_inplace(input, output);
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      group->recv(output.data<char>(), output.nbytes(), 0);
 | 
					 | 
				
			||||||
      group->send(input.data<char>(), input.nbytes(), 0);
 | 
					 | 
				
			||||||
      sum_inplace(input, output);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void all_gather(Group group_, const array& input_, array& output) {
 | 
					void all_gather(Group group_, const array& input_, array& output) {
 | 
				
			||||||
  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
 | 
					  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
 | 
				
			||||||
  if (group->size() != 2) {
 | 
					 | 
				
			||||||
    throw std::runtime_error("Only pairwise communication supported for now");
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  array input = ensure_row_contiguous(input_);
 | 
					  array input = ensure_row_contiguous(input_);
 | 
				
			||||||
  if (group->rank() == 0) {
 | 
					  std::future<void> sent;
 | 
				
			||||||
    group->send(input.data<char>(), input.nbytes(), 1);
 | 
					  std::future<void> received;
 | 
				
			||||||
    group->recv(output.data<char>() + input.nbytes(), input.nbytes(), 1);
 | 
					
 | 
				
			||||||
    std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
 | 
					  int rank = group->rank();
 | 
				
			||||||
  } else {
 | 
					  int size = group->size();
 | 
				
			||||||
    group->recv(output.data<char>(), input.nbytes(), 0);
 | 
					
 | 
				
			||||||
    group->send(input.data<char>(), input.nbytes(), 0);
 | 
					  if ((size & (size - 1)) != 0) {
 | 
				
			||||||
    std::memcpy(
 | 
					    throw std::runtime_error("Only powers of 2 are currently supported");
 | 
				
			||||||
        output.data<char>() + input.nbytes(),
 | 
					 | 
				
			||||||
        input.data<char>(),
 | 
					 | 
				
			||||||
        input.nbytes());
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Butterfly all gather
 | 
				
			||||||
 | 
					  int peer = rank ^ 1;
 | 
				
			||||||
 | 
					  if (peer < rank) {
 | 
				
			||||||
 | 
					    received = group->recv_async(
 | 
				
			||||||
 | 
					        output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
 | 
				
			||||||
 | 
					    sent = group->send_async(input.data<char>(), input.nbytes(), peer);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    sent = group->send_async(input.data<char>(), input.nbytes(), peer);
 | 
				
			||||||
 | 
					    received = group->recv_async(
 | 
				
			||||||
 | 
					        output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  std::memcpy(
 | 
				
			||||||
 | 
					      output.data<char>() + rank * input.nbytes(),
 | 
				
			||||||
 | 
					      input.data<char>(),
 | 
				
			||||||
 | 
					      input.nbytes());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (int distance = 2; distance <= size / 2; distance *= 2) {
 | 
				
			||||||
 | 
					    sent.wait();
 | 
				
			||||||
 | 
					    received.wait();
 | 
				
			||||||
 | 
					    int peer = rank ^ distance;
 | 
				
			||||||
 | 
					    int their_offset = peer & ~(distance - 1);
 | 
				
			||||||
 | 
					    int our_offset = rank & ~(distance - 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (peer < rank) {
 | 
				
			||||||
 | 
					      received = group->recv_async(
 | 
				
			||||||
 | 
					          output.data<char>() + their_offset * input.nbytes(),
 | 
				
			||||||
 | 
					          distance * input.nbytes(),
 | 
				
			||||||
 | 
					          peer);
 | 
				
			||||||
 | 
					      sent = group->send_async(
 | 
				
			||||||
 | 
					          output.data<char>() + our_offset * input.nbytes(),
 | 
				
			||||||
 | 
					          distance * input.nbytes(),
 | 
				
			||||||
 | 
					          peer);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      sent = group->send_async(
 | 
				
			||||||
 | 
					          output.data<char>() + our_offset * input.nbytes(),
 | 
				
			||||||
 | 
					          distance * input.nbytes(),
 | 
				
			||||||
 | 
					          peer);
 | 
				
			||||||
 | 
					      received = group->recv_async(
 | 
				
			||||||
 | 
					          output.data<char>() + their_offset * input.nbytes(),
 | 
				
			||||||
 | 
					          distance * input.nbytes(),
 | 
				
			||||||
 | 
					          peer);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  sent.wait();
 | 
				
			||||||
 | 
					  received.wait();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void send(Group group_, const array& input_, int dst) {
 | 
					void send(Group group_, const array& input_, int dst) {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user