mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00

Compare commits: split_logs...socket-dis (14 commits)
Commits:
  87b680766e
  70ffaa50d2
  d82699f0f1
  6fc00d2c10
  44f0de2854
  29ec3539ed
  e94f0028c3
  e5354fcddb
  34dd079a64
  16975815e9
  a8b3da7946
  060e1c9f92
  0b04742985
  c3ccd4919f
@@ -168,11 +168,12 @@ endif()
 find_package(MPI)
 if(MPI_FOUND)
   execute_process(
-    COMMAND zsh "-c" "mpirun --version"
+    COMMAND zsh "-c" "${MPIEXEC_EXECUTABLE} --version"
     OUTPUT_VARIABLE MPI_VERSION
     ERROR_QUIET)
-  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
+  if(${MPI_VERSION} MATCHES ".*Open MPI.*" OR ${MPI_VERSION} MATCHES ".*OpenRTE.*")
     target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
     target_link_libraries(mlx PRIVATE ${MPI_CXX_LIBRARIES})
   elseif(MPI_VERSION STREQUAL "")
     set(MPI_FOUND FALSE)
     message(
@@ -1,8 +1,20 @@
 target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
                            ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
 
-if(MPI_FOUND AND MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
-else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
+if(MLX_BUILD_CPU)
+  if(MLX_CUSTOM_DISTRIBUTED)
+    if(MLX_CUSTOM_DISTRIBUTED STREQUAL "gloo")
+      message(STATUS "Distributed: using gloo backend")
+      add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gloo)
+    else()
+      message(STATUS "Distributed: using sockets backend")
+      add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
+    endif()
+  elseif(MPI_FOUND)
+    message(STATUS "Distributed: using MPI backend")
+    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
+  else()
+    message(STATUS "Distributed: no support")
+    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
+  endif()
 endif()
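Whichever backend is compiled in (MPI by default, gloo via -DMLX_CUSTOM_DISTRIBUTED=gloo, or the sockets backend for any other value of that flag), the user-facing distributed API stays the same. A minimal smoke test from Python, assuming a build of this branch:

import mlx.core as mx

# is_available() returns False only when the no_distributed stub was built.
if mx.distributed.is_available():
    group = mx.distributed.init()
    print(f"rank {group.rank()} of {group.size()}")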
@@ -32,6 +32,8 @@ struct Group {
    */
   Group split(int color, int key = -1);
 
+  void barrier();
+
   const std::shared_ptr<void>& raw_group() {
     return group_;
   }
mlx/distributed/gloo/CMakeLists.txt  (new file, 25 lines)
@@ -0,0 +1,25 @@
find_path(
  GLOO_INCLUDE_DIR gloo/allreduce.h
  PATHS ${GLOO_INC_DIR}
  PATH_SUFFIXES include)

find_library(
  GLOO_LIBRARY gloo
  PATHS ${GLOO_LIB_DIR}
  PATH_SUFFIXES lib
  HINTS GLOO)

find_library(
  UV_LIBRARY uv
  PATHS ${UV_LIB_DIR}
  PATH_SUFFIXES lib
  HINTS UV)

message(STATUS "GLOO LIB <${GLOO_LIBRARY}>")
message(STATUS "GLOO INC <${GLOO_INCLUDE_DIR}>")
message(STATUS "UV LIB <${UV_LIB_DIR}>")

target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gloo.cpp)
target_link_libraries(mlx PUBLIC ${GLOO_LIBRARY})
target_link_libraries(mlx PUBLIC ${UV_LIBRARY})
target_include_directories(mlx PRIVATE ${GLOO_INCLUDE_DIR})
mlx/distributed/gloo/gloo.cpp  (new file, 178 lines)
@@ -0,0 +1,178 @@
// Copyright © 2024 Apple Inc.

#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"

#include "gloo/allreduce.h"
#include "gloo/math.h"
#include "gloo/mpi/context.h"
#include "gloo/transport/uv/device.h"

#define SWITCH_TYPE(x, ...)  \
  switch ((x).dtype()) {     \
    case bool_: {            \
      using T = bool;        \
      __VA_ARGS__;           \
    } break;                 \
    case int8: {             \
      using T = int8_t;      \
      __VA_ARGS__;           \
    } break;                 \
    case int16: {            \
      using T = int16_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int32: {            \
      using T = int32_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int64: {            \
      using T = int64_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint8: {            \
      using T = uint8_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint16: {           \
      using T = uint16_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint32: {           \
      using T = uint32_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint64: {           \
      using T = uint64_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case bfloat16: {         \
      using T = bfloat16_t;  \
      __VA_ARGS__;           \
    } break;                 \
    case float16: {          \
      using T = float16_t;   \
      __VA_ARGS__;           \
    } break;                 \
    case float32: {          \
      using T = float;       \
      __VA_ARGS__;           \
    } break;                 \
    case complex64: {        \
      using T = complex64_t; \
      __VA_ARGS__;           \
    } break;                 \
  }

namespace mlx::core::distributed {

namespace {
array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}
} // namespace

bool is_available() {
  return true;
}

int Group::rank() {
  return std::static_pointer_cast<gloo::mpi::Context>(group_)->rank;
}

int Group::size() {
  return std::static_pointer_cast<gloo::mpi::Context>(group_)->size;
}

Group Group::split(int color, int key) {
  throw std::runtime_error("split is NYI");
}

void Group::barrier() {
  throw std::runtime_error("barrier is NYI");
}

struct GlooCTX {
  std::shared_ptr<gloo::mpi::Context> context;
  std::shared_ptr<gloo::transport::Device> dev;
};

Group init(bool strict /* = false */) {
  static std::shared_ptr<GlooCTX> gloo_ctx = nullptr;

  if (gloo_ctx == nullptr) {
    gloo_ctx = std::make_shared<GlooCTX>();
    gloo_ctx->context = gloo::mpi::Context::createManaged();
    gloo_ctx->dev = gloo::transport::uv::CreateDevice("localhost");
    gloo_ctx->context->connectFullMesh(gloo_ctx->dev);
  }
  return Group(gloo_ctx->context);
}

namespace detail {

Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}

template <typename T>
void all_reduce_sum(
    std::shared_ptr<gloo::mpi::Context> context,
    T* output,
    T* input,
    size_t len) {
  gloo::AllreduceOptions opts_(context);
  opts_.setInput(input, len);
  opts_.setOutput(output, len);
  opts_.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
  opts_.setReduceFunction(
      static_cast<void (*)(void*, const void*, const void*, size_t)>(
          &gloo::sum<T>));
  gloo::allreduce(opts_);
}

void all_sum(Group group_, const array& input_, array& output) {
  array input = ensure_row_contiguous(input_);
  if (input.data<void>() != output.data<void>()) {
    std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
  }
  auto context =
      std::static_pointer_cast<gloo::mpi::Context>(group_.raw_group());
  SWITCH_TYPE(
      output,
      all_reduce_sum<T>(
          context, output.data<T>(), input.data<T>(), input.size()));
}

void all_gather(Group group_, const array& input_, array& output) {
  throw std::runtime_error("all_gather NYI");
}

void send(Group group_, const array& input_, int dst) {
  throw std::runtime_error("send NYI");
}

void recv(Group group_, array& out, int src) {
  throw std::runtime_error("recv NYI");
}

} // namespace detail

} // namespace mlx::core::distributed
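Of the collectives, this gloo backend implements only all_sum; split, barrier, all_gather, send, and recv all throw NYI errors. A quick end-to-end check of the one supported op, assuming the usual Python bindings are built on top:

import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((4,))
y = mx.distributed.all_sum(x, group=group)
print(y)  # every entry equals group.size()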
@@ -71,6 +71,7 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Allgather, all_gather);
     LOAD_SYMBOL(MPI_Send, send);
     LOAD_SYMBOL(MPI_Recv, recv);
+    LOAD_SYMBOL(MPI_Barrier, barrier);
     LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
     LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
     LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
@@ -195,6 +196,7 @@ struct MPIWrapper {
   int (*comm_free)(MPI_Comm*);
   int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
   int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
+  int (*barrier)(MPI_Comm);
 
   // Objects
   MPI_Comm comm_world_;
@@ -263,6 +265,10 @@ struct MPIGroupImpl {
     return size_;
   }
 
+  void barrier() {
+    mpi().barrier(comm_);
+  }
+
  private:
   MPI_Comm comm_;
   bool global_;
@@ -298,6 +304,11 @@ Group Group::split(int color, int key) {
   return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
 }
 
+void Group::barrier() {
+  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
+  mpi_group->barrier();
+}
+
 bool is_available() {
   return mpi().is_available();
 }
@@ -17,6 +17,8 @@ Group Group::split(int color, int key) {
   throw std::runtime_error("Cannot split the distributed group further");
 }
 
+void Group::barrier() {}
+
 bool is_available() {
   return false;
 }
mlx/distributed/sockets/CMakeLists.txt  (new file, 5 lines)
@@ -0,0 +1,5 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/sockets.cpp
)
mlx/distributed/sockets/sockets.cpp  (new file, 522 lines)
@@ -0,0 +1,522 @@
// Copyright © 2024 Apple Inc.

#include <arpa/inet.h>
#include <json.hpp>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"

#define SWITCH_TYPE(x, ...)  \
  switch ((x).dtype()) {     \
    case bool_: {            \
      using T = bool;        \
      __VA_ARGS__;           \
    } break;                 \
    case int8: {             \
      using T = int8_t;      \
      __VA_ARGS__;           \
    } break;                 \
    case int16: {            \
      using T = int16_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int32: {            \
      using T = int32_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int64: {            \
      using T = int64_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint8: {            \
      using T = uint8_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint16: {           \
      using T = uint16_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint32: {           \
      using T = uint32_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint64: {           \
      using T = uint64_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case bfloat16: {         \
      using T = bfloat16_t;  \
      __VA_ARGS__;           \
    } break;                 \
    case float16: {          \
      using T = float16_t;   \
      __VA_ARGS__;           \
    } break;                 \
    case float32: {          \
      using T = float;       \
      __VA_ARGS__;           \
    } break;                 \
    case complex64: {        \
      using T = complex64_t; \
      __VA_ARGS__;           \
    } break;                 \
  }

constexpr const size_t PACKET_SIZE = 262144;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;

using json = nlohmann::json;

namespace mlx::core::distributed {

namespace {

template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
  while (N-- > 0) {
    *output += *input;
    input++;
    output++;
  }
}

void sum_inplace(const array& input, array& output) {
  SWITCH_TYPE(
      input, sum_inplace(input.data<T>(), output.data<T>(), input.size()));
}

array ensure_row_contiguous(const array& arr) {
  if (arr.flags().row_contiguous) {
    return arr;
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    return arr_copy;
  }
}

struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  const sockaddr* sockaddr() {
    return (struct sockaddr*)&addr;
  }
};

address_t parse_address(std::string ip, std::string port) {
  struct addrinfo hints, *res;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
  if (status != 0) {
    std::ostringstream msg;
    msg << "Can't parse peer address " << ip << ":" << port;
    throw std::runtime_error(msg.str());
  }

  address_t result;
  memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
  result.len = res->ai_addrlen;
  freeaddrinfo(res);

  return result;
}

std::vector<address_t> load_peers() {
  std::vector<address_t> peers;
  std::ifstream f;

  if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
    f.open(hostfile_buf);
  } else {
    return peers;
  }

  json hosts = json::parse(f);
  for (auto& h : hosts) {
    peers.push_back(std::move(parse_address(
        h["ip"].template get<std::string>(),
        h["port"].template get<std::string>())));
  }

  return peers;
}

struct GroupImpl {
  GroupImpl(std::vector<address_t> peers, int rank, bool global)
      : rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
    if (rank_ > 0 && rank_ >= peers.size()) {
      throw std::runtime_error(
          "Rank cannot be larger than the size of the group");
    }

    int success;

    // If we are expecting anyone to connect to us
    if (rank_ + 1 < peers.size()) {
      // Create the socket to wait for connections from the peers
      int sock = socket(AF_INET, SOCK_STREAM, 0);
      if (sock < 0) {
        std::ostringstream msg;
        msg << "Couldn't create socket (error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }

      // Make sure we can launch immediately after shutdown by setting the
      // reuseaddr option so that we don't get address already in use errors
      int enable = 1;
      success =
          setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
      if (success < 0) {
        shutdown(sock, 2);
        close(sock);
        std::ostringstream msg;
        msg << "Couldn't enable reuseaddr (rank: " << rank_
            << " error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      success =
          setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
      if (success < 0) {
        shutdown(sock, 2);
        close(sock);
        std::ostringstream msg;
        msg << "Couldn't enable reuseport (rank: " << rank_
            << " error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }

      // Bind it to the port
      success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
      if (success < 0) {
        shutdown(sock, 2);
        close(sock);
        std::ostringstream msg;
        msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
            << ")";
        throw std::runtime_error(msg.str());
      }

      // Wait for connections
      success = listen(sock, 0);
      if (success < 0) {
        shutdown(sock, 2);
        close(sock);
        std::ostringstream msg;
        msg << "Couldn't listen (error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      for (int i = 0; i < peers.size() - rank_ - 1; i++) {
        int peer_socket = accept(sock, nullptr, nullptr);
        if (peer_socket < 0) {
          shutdown(sock, 2);
          close(sock);
          std::ostringstream msg;
          msg << "Accept failed (error: " << errno << ")";
          throw std::runtime_error(msg.str());
        }
        sockets_[peers.size() - 1 - i] = peer_socket;
      }

      // Close the listening socket
      shutdown(sock, 2);
      close(sock);
    }

    // Connect to the peers with smaller rank
    for (int i = 0; i < rank_; i++) {
      sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
      if (sockets_[i] < 0) {
        std::ostringstream msg;
        msg << "Couldn't create socket (error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
        if (attempt > 0) {
          int wait = (1 << (attempt - 1)) * CONN_WAIT;
          std::this_thread::sleep_for(std::chrono::milliseconds(wait));
        }
        success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
        if (success == 0) {
          break;
        }
      }
      if (success < 0) {
        std::ostringstream msg;
        msg << "Couldn't connect (rank: " << rank_ << " to: " << i
            << " error: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
    }
  }

  ~GroupImpl() {
    if (global_) {
      for (int sock : sockets_) {
        shutdown(sock, 2);
        close(sock);
      }
    }
  }

  int rank() {
    return rank_;
  }

  int size() {
    return std::max(sockets_.size(), 1ul);
  }

  void send(const char* buf, size_t len, int dst) {
    while (len > 0) {
      ssize_t r = ::send(sockets_[dst], buf, len, 0);
      if (r <= 0) {
        std::ostringstream msg;
        msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      buf += r;
      len -= r;
    }
  }

  void recv(char* buf, size_t len, int src) {
    while (len > 0) {
      ssize_t r = ::recv(sockets_[src], buf, len, 0);
      if (r <= 0) {
        std::ostringstream msg;
        msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
        throw std::runtime_error(msg.str());
      }
      buf += r;
      len -= r;
    }
  }

  template <typename T>
  void send_recv_sum(char* buf, size_t len, int peer) {
    char recv_buffer[2 * PACKET_SIZE];
    char* recv_buffers[2];
    recv_buffers[0] = recv_buffer;
    recv_buffers[1] = recv_buffer + PACKET_SIZE;
    std::future<void> sent, received;
    size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE;

    for (size_t b = 0; b < n_blocks; b++) {
      if (b > 0) {
        sent.wait();
        received.wait();
      }
      size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE);
      if (rank_ < peer) {
        sent = send_async(buf + b * PACKET_SIZE, l, peer);
        received = recv_async(recv_buffers[b % 2], l, peer);
      } else {
        received = recv_async(recv_buffers[b % 2], l, peer);
        sent = send_async(buf + b * PACKET_SIZE, l, peer);
      }
      if (b > 0) {
        sum_inplace(
            (const T*)recv_buffers[(b - 1) % 2],
            (T*)(buf + (b - 1) * PACKET_SIZE),
            PACKET_SIZE / sizeof(T));
      }
    }
    sent.wait();
    received.wait();
    size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE);
    sum_inplace(
        (const T*)recv_buffers[(n_blocks - 1) % 2],
        (T*)(buf + (n_blocks - 1) * PACKET_SIZE),
        l / sizeof(T));
  }

  void send_recv_sum(array& out, int peer) {
    SWITCH_TYPE(out, send_recv_sum<T>(out.data<char>(), out.nbytes(), peer));
  }

  std::future<void> send_async(const char* buf, size_t len, int dst) {
    return pool_.enqueue(
        [this, buf, len, dst]() { this->send(buf, len, dst); });
  }

  std::future<void> recv_async(char* buf, size_t len, int src) {
    return pool_.enqueue(
        [this, buf, len, src]() { this->recv(buf, len, src); });
  }

 private:
  int rank_;
  bool global_;
  ThreadPool pool_;
  std::vector<int> sockets_;
};

} // namespace

bool is_available() {
  return true;
}

int Group::rank() {
  return std::static_pointer_cast<GroupImpl>(group_)->rank();
}

int Group::size() {
  return std::static_pointer_cast<GroupImpl>(group_)->size();
}

Group Group::split(int color, int key) {
  throw std::runtime_error("Splitting not supported yet");
}

void Group::barrier() {
  char buff[128];
  std::memset(buff, 1, 128);

  auto group = std::static_pointer_cast<GroupImpl>(raw_group());
  int size = group->size();
  int rank = group->rank();

  for (int distance = 1; distance <= size / 2; distance *= 2) {
    group->send_recv_sum<char>(buff, 128, rank ^ distance);
  }
}
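Both barrier() above and the collectives below pair ranks as rank ^ distance with the distance doubling every round — a recursive-doubling butterfly, which is why the group size is checked to be a power of two. A plain-Python sketch of the pairing schedule, purely for illustration:

size = 8  # must be a power of two
for rank in range(size):
    # one exchange partner per round, log2(size) rounds in total
    partners = [rank ^ distance for distance in (1, 2, 4)]
    print(rank, partners)

After the last round every rank has folded in every other rank's contribution.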
Group init(bool strict /* = false */) {
  static std::shared_ptr<GroupImpl> global_group = nullptr;

  if (global_group == nullptr) {
    auto peers = load_peers();
    int rank = 0;
    if (const char* rank_buf = std::getenv("MLX_RANK")) {
      rank = std::atoi(rank_buf);
    }
    if (peers.size() == 0) {
      if (strict) {
        throw std::runtime_error("Can't initialize distributed");
      }
    }
    global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
  }
  return Group(global_group);
}

namespace detail {

Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}

void all_sum(Group group_, const array& input_, array& output) {
  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
  array input = ensure_row_contiguous(input_);

  int size = group->size();
  int rank = group->rank();

  if ((size & (size - 1)) != 0) {
    throw std::runtime_error("Only powers of 2 are currently supported");
  }

  // If not inplace all reduce then copy the input to the output first.
  if (input.data<void>() != output.data<void>()) {
    std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
  }

  // Butterfly all reduce
  for (int distance = 1; distance <= size / 2; distance *= 2) {
    group->send_recv_sum(output, rank ^ distance);
  }
}

void all_gather(Group group_, const array& input_, array& output) {
  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
  array input = ensure_row_contiguous(input_);
  std::future<void> sent;
  std::future<void> received;

  int rank = group->rank();
  int size = group->size();

  if ((size & (size - 1)) != 0) {
    throw std::runtime_error("Only powers of 2 are currently supported");
  }

  // Butterfly all gather
  int peer = rank ^ 1;
  if (peer < rank) {
    received = group->recv_async(
        output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
    sent = group->send_async(input.data<char>(), input.nbytes(), peer);
  } else {
    sent = group->send_async(input.data<char>(), input.nbytes(), peer);
    received = group->recv_async(
        output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
  }
  std::memcpy(
      output.data<char>() + rank * input.nbytes(),
      input.data<char>(),
      input.nbytes());

  for (int distance = 2; distance <= size / 2; distance *= 2) {
    sent.wait();
    received.wait();
    int peer = rank ^ distance;
    int their_offset = peer & ~(distance - 1);
    int our_offset = rank & ~(distance - 1);

    if (peer < rank) {
      received = group->recv_async(
          output.data<char>() + their_offset * input.nbytes(),
          distance * input.nbytes(),
          peer);
      sent = group->send_async(
          output.data<char>() + our_offset * input.nbytes(),
          distance * input.nbytes(),
          peer);
    } else {
      sent = group->send_async(
          output.data<char>() + our_offset * input.nbytes(),
          distance * input.nbytes(),
          peer);
      received = group->recv_async(
          output.data<char>() + their_offset * input.nbytes(),
          distance * input.nbytes(),
          peer);
    }
  }
  sent.wait();
  received.wait();
}

void send(Group group_, const array& input_, int dst) {
  array input = ensure_row_contiguous(input_);
  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
  group->send(input.data<char>(), input.nbytes(), dst);
}

void recv(Group group_, array& out, int src) {
  auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
  group->recv(out.data<char>(), out.nbytes(), src);
}

} // namespace detail

} // namespace mlx::core::distributed
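Going by the environment variables read in load_peers() and init() above, the sockets backend is configured entirely through MLX_HOSTFILE (a JSON list with one {"ip", "port"} entry per rank) and MLX_RANK. A sketch of a two-rank setup on one machine — the addresses and ports are made up for illustration, and both processes must be launched for the mesh to connect:

import json, os

hosts = [
    {"ip": "127.0.0.1", "port": "5000"},
    {"ip": "127.0.0.1", "port": "5001"},
]
with open("hosts.json", "w") as f:
    json.dump(hosts, f)

os.environ["MLX_HOSTFILE"] = "hosts.json"
os.environ["MLX_RANK"] = "0"  # the second process sets "1"

import mlx.core as mx

group = mx.distributed.init()  # blocks until the full mesh is connected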
@@ -1,5 +1,5 @@
// Copyright © 2023 Apple Inc.
//

#include <json.hpp>
#include <stack>
@@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
     ConvTranspose2d,
     ConvTranspose3d,
 )
+from mlx.nn.layers.distributed import (
+    AllToShardedLinear,
+    QuantizedAllToShardedLinear,
+    QuantizedShardedToAllLinear,
+    ShardedToAllLinear,
+)
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
python/mlx/nn/layers/distributed.py  (new file, 456 lines)
@@ -0,0 +1,456 @@
# Copyright © 2024 Apple Inc.

import math
from functools import lru_cache
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


@lru_cache
def sum_gradients(group):
    if group.size() == 1:
        return lambda x: x

    @mx.custom_function
    def f(x):
        return x

    @f.vjp
    def f(x, dx, _):
        return mx.distributed.all_sum(dx, group=group)

    return f


class AllToShardedLinear(Module):
    """Each member of the group applies part of the affine transformation such
    that the result is sharded across the group.

    The gradients are automatically aggregated from each member of the group.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will not use a
            bias. Default is ``True``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Initialize the parameters
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (output_dims % N) != 0:
            raise ValueError(
                f"Cannot shard the output of size {output_dims} across {N} devices."
            )

        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims // N, input_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims // N,),
            )

    def _extra_repr(self) -> str:
        out_dims, in_dims = self.weight.shape
        N = self.group.size()
        out_dims *= N
        return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        # Aggregate the gradients coming from each shard
        if self.group.size() > 1:
            x = sum_gradients(self.group)(x)

        # Compute the affine projection
        if "bias" in self:
            x = mx.addmm(self["bias"], x, self["weight"].T)
        else:
            x = x @ self["weight"].T
        return x

    @classmethod
    def from_linear(
        cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
    ):
        group = group or mx.distributed.init()
        N = group.size()
        r = group.rank()
        output_dims, input_dims = linear_layer.weight.shape
        step = output_dims // N

        sl = cls(input_dims, output_dims, False, group)
        # The multiplication with 1.0 forces a copy, perhaps change to
        # something better when available.
        sl.weight = linear_layer.weight[r * step : (r + 1) * step] * 1
        if "bias" in linear_layer:
            sl.bias = linear_layer.bias[r * step : (r + 1) * step] * 1

        return sl


class ShardedToAllLinear(Module):
    """Each member of the group applies part of the affine transformation and
    then aggregates the results.

    All nodes will have the same exact result after this layer.

    :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to
    convert linear layers to sharded :obj:`ShardedToAllLinear` layers.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will not use a
            bias. Default is ``True``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Initialize the parameters
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (input_dims % N) != 0:
            raise ValueError(
                f"The input of size {input_dims} cannot be sharded across {N} devices."
            )

        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims // N),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        N = self.group.size()
        out_dims, in_dims = self.weight.shape
        in_dims *= N
        return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if self.group.size() > 1:
            # Perform the local projection and aggregate the results
            x = x @ self["weight"].T
            x = mx.distributed.all_sum(x, group=self.group)

            # Add the bias if we have one
            if "bias" in self:
                x = x + self["bias"]
        else:
            # Normal linear layer as we are not in a distributed setting.
            if "bias" in self:
                x = mx.addmm(self["bias"], x, self["weight"].T)
            else:
                x = x @ self["weight"].T
        return x

    @classmethod
    def from_linear(
        cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
    ):
        group = group or mx.distributed.init()
        N = group.size()
        r = group.rank()
        output_dims, input_dims = linear_layer.weight.shape
        step = input_dims // N

        sl = cls(input_dims, output_dims, False, group)
        # The multiplication with 1.0 forces a copy, perhaps change to
        # something better when available.
        sl.weight = linear_layer.weight[:, r * step : (r + 1) * step] * 1
        if "bias" in linear_layer:
            sl.bias = linear_layer.bias

        return sl


class QuantizedAllToShardedLinear(Module):
    """Each member of the group applies part of the affine transformation with
    a quantized matrix such that the result is sharded across the group.

    It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
    will not be included in any gradient computation.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize the quantized weight
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (output_dims % N) != 0:
            raise ValueError(
                f"Cannot shard the output of size {output_dims} across {N} devices."
            )

        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims // N, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims // N,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self) -> str:
        out_dims, in_dims = self.weight.shape
        in_dims *= 32 // self.bits
        out_dims *= self.group.size()
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        # Aggregate the gradients coming from each shard
        if self.group.size() > 1:
            x = sum_gradients(self.group)(x)

        x = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_quantized_linear(
        cls,
        quantized_linear_layer: Module,
        group: Optional[mx.distributed.Group] = None,
    ):
        group = group or mx.distributed.init()
        N = group.size()
        r = group.rank()
        output_dims, input_dims = quantized_linear_layer.weight.shape
        input_dims *= 32 // quantized_linear_layer.bits
        step = output_dims // N

        sl = cls(
            input_dims,
            output_dims,
            False,
            group_size=quantized_linear_layer.group_size,
            bits=quantized_linear_layer.bits,
            group=group,
        )
        sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
        sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
        sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
        if "bias" in quantized_linear_layer:
            sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1

        return sl


class QuantizedShardedToAllLinear(Module):
    """Each member of the group applies part of the affine transformation using
    the quantized matrix and then aggregates the results.

    All nodes will have the same exact result after this layer.

    It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
    will not be included in any gradient computation.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize the quantized weight
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (input_dims % N) != 0:
            raise ValueError(
                f"The input of size {input_dims} cannot be sharded across {N} devices."
            )

        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims // N),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self) -> str:
        out_dims, in_dims = self.weight.shape
        in_dims *= (32 // self.bits) * self.group.size()
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if self.group.size() > 1:
            x = mx.distributed.all_sum(x, group=self.group)
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_quantized_linear(
        cls,
        quantized_linear_layer: Module,
        group: Optional[mx.distributed.Group] = None,
    ):
        group = group or mx.distributed.init()
        N = group.size()
        r = group.rank()
        output_dims, input_dims = quantized_linear_layer.weight.shape
        step = input_dims // N
        step_grouped = quantized_linear_layer.scales.shape[1] // N
        input_dims *= (32 // quantized_linear_layer.bits) * N

        sl = cls(
            input_dims,
            output_dims,
            False,
            group_size=quantized_linear_layer.group_size,
            bits=quantized_linear_layer.bits,
            group=group,
        )
        sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
        sl.scales = (
            quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
            * 1
        )
        sl.biases = (
            quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
            * 1
        )
        if "bias" in quantized_linear_layer:
            sl.bias = quantized_linear_layer.bias

        return sl
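The from_linear and from_quantized_linear classmethods are the intended migration path: build the model as usual, then swap selected layers for their sharded equivalents. A sketch, assuming these layers are re-exported through mlx.nn and that the layer sizes (arbitrary here) divide evenly by the group size:

import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()
linear = nn.Linear(512, 1024)
sharded = nn.AllToShardedLinear.from_linear(linear, group=group)

x = mx.random.normal((8, 512))
y = sharded(x)  # shape (8, 1024 // group.size()) on each rank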
@@ -197,7 +197,7 @@ class QuantizedLinear(Module):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
         return (
-            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
             f"group_size={self.group_size}, bits={self.bits}"
         )
 
@@ -44,7 +44,8 @@ void init_distributed(nb::module_& parent_module) {
           color (int): A value to group processes into subgroups.
           key (int, optional): A key to optionally change the rank ordering
             of the processes.
-      )pbdoc");
+      )pbdoc")
+      .def("barrier", &distributed::Group::barrier, "Make a synchronization point for all nodes in the group");
 
   m.def(
       "is_available",
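With this binding, a barrier is reachable from Python on any Group: the MPI backend maps it to MPI_Barrier and the sockets backend to its butterfly exchange, while the gloo backend still throws. For example:

import mlx.core as mx

group = mx.distributed.init()
# ... rank-local work ...
group.barrier()  # every rank blocks here until all have arrived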