Compare commits

...

16 Commits

Author | SHA1 | Message | Date
Ronan Collobert | 87b680766e | Gloo backend support | 2024-11-13 13:52:37 -08:00
Ronan Collobert | 70ffaa50d2 | be more relaxed on OpenMPI version | 2024-11-13 13:51:37 -08:00
Angelos Katharopoulos | d82699f0f1 | Merge branch 'distributed-layers' into socket-distributed-layers | 2024-11-05 11:36:16 -08:00
Angelos Katharopoulos | 6fc00d2c10 | Add rudimentary barrier | 2024-11-05 11:34:55 -08:00
Angelos Katharopoulos | 44f0de2854 | Fix run without distributed | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | 29ec3539ed | TCP socket distributed | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | e94f0028c3 | Change the send message size | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | e5354fcddb | Make it work even for donated inputs | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | 34dd079a64 | Start a sockets based distributed backend | 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos | 16975815e9 | Fixes in distributed layers | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | a8b3da7946 | Add distributed layers to nn top-level | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | 060e1c9f92 | Add quantized distributed layers | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | 0b04742985 | Add the distributed linear layers | 2024-11-05 11:27:26 -08:00
Angelos Katharopoulos | c3ccd4919f | Add MPI barrier | 2024-11-05 11:26:53 -08:00
Alex Barron | 26be608470 | Add split_k qvm for long context (#1564) | 2024-11-05 11:25:19 -08:00
    * Add splitk qvm
    * configurable splitk
    * tuning
    * remove extra instantiation
    * remove refactor
    * separate test
    * cpu tolerance
Angelos Katharopoulos | 248431eb3c | Reductions update (#1351) | 2024-11-04 22:25:16 -08:00
26 changed files with 1912 additions and 219 deletions

View File

@@ -168,11 +168,12 @@ endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
COMMAND zsh "-c" "${MPIEXEC_EXECUTABLE} --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
if(${MPI_VERSION} MATCHES ".*Open MPI.*" OR ${MPI_VERSION} MATCHES ".*OpenRTE.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
target_link_libraries(mlx PRIVATE ${MPI_CXX_LIBRARIES})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
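For context, the relaxed match presumably accounts for Open MPI builds whose launcher reports itself as "OpenRTE" rather than "Open MPI" in the --version output. A rough Python sketch of the equivalent check, purely for illustration (the launcher name and regex are assumptions, not part of the change):

# Illustrative sketch only: mirrors the relaxed "Open MPI or OpenRTE" check above.
import re
import subprocess

def looks_like_openmpi(mpiexec="mpirun"):          # assumed launcher name
    try:
        out = subprocess.run([mpiexec, "--version"],
                             capture_output=True, text=True).stdout
    except FileNotFoundError:
        return False
    # Accept either identification string, as the CMake change does.
    return re.search(r"Open MPI|OpenRTE", out) is not None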

View File

@@ -144,6 +144,13 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")
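The new sum_and_add case benchmarks a chain of keepdims reductions interleaved with broadcast adds. Below is a standalone, hedged sketch of the same pattern; the shapes, axis, and timing loop are illustrative choices, not the harness defaults.

# Standalone sketch of the sum_and_add pattern benchmarked above.
# Shapes, axis, and iteration counts are illustrative only.
import time
import mlx.core as mx

axis = 1
x = mx.random.normal((1024, 1024))
y = mx.random.normal((1024, 1))

def sum_and_add(axis, x, y):
    z = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)

sum_and_add(axis, x, y)                      # warm-up
tic = time.perf_counter()
sum_and_add(axis, x, y)
print(f"sum_and_add: {(time.perf_counter() - tic) * 1e3:.3f} ms")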

View File

@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
kernel_name, func_name, out_type, op);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);

View File

@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(

View File

@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const int in_vec_size,
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
// When (in_vec_size % split_k != 0) the final block needs to be smaller
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
qvm_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <
typename T,
const int group_size,

View File

@@ -51,6 +51,15 @@
D, \
batched)
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
name, \
type, \
group_size, \
bits, \
split_k)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -84,11 +93,16 @@
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits)
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -113,9 +113,12 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
@@ -123,9 +126,14 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \

View File

@@ -1,11 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
template <
typename T,
typename U,
typename Op,
int NDIMS,
int N_READS = REDUCE_N_READS>
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -20,170 +15,128 @@ template <
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
constexpr int n_reads = 4;
Op op;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
// Case 1: Small row small column
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
U totals[31];
for (int i = 0; i < 31; i++) {
totals[i] = Op::init;
U totals[n_reads];
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
if (column >= reduction_stride) {
return;
}
bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total_rows = non_col_reductions * reduction_size;
loop.next(lid.y, reduce_shape, reduce_strides);
for (size_t r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(lsize.y, reduce_shape, reduce_strides);
}
short stride = reduction_stride;
short size = reduction_size;
short blocks = stride / N_READS;
short extra = stride - blocks * N_READS;
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
for (short i = 0; i < size; i++) {
for (short j = 0; j < blocks; j++) {
for (short k = 0; k < N_READS; k++) {
totals[j * N_READS + k] =
op(totals[j * N_READS + k],
static_cast<U>(row[i * stride + j * N_READS + k]));
}
}
for (short k = 0; k < extra; k++) {
totals[blocks * N_READS + k] =
op(totals[blocks * N_READS + k],
static_cast<U>(row[i * stride + blocks * N_READS + k]));
if (lsize.y > 1) {
// lsize.y should be <= 8
threadgroup U shared_vals[32 * 8 * n_reads];
for (int i = 0; i < n_reads; i++) {
shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.y == 0) {
for (int i = 0; i < n_reads; i++) {
totals[i] = shared_vals[lid.x * n_reads + i];
}
for (uint j = 1; j < lsize.y; j++) {
for (int i = 0; i < n_reads; i++) {
totals[i] =
op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
totals[i]);
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride;
for (short j = 0; j < stride; j++) {
out[j] = totals[j];
}
}
// Case 2: Long row small column
else if (reduction_size * non_col_reductions < 32) {
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
short size = reduction_size;
size_t offset = size_t(tid.x) * N_READS;
bool safe = offset + N_READS <= reduction_stride;
short extra = reduction_stride - offset;
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (short i = 0; i < size; i++) {
for (short j = 0; j < N_READS; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
} else {
for (short i = 0; i < size; i++) {
for (short j = 0; j < extra; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride + offset;
if (lid.y == 0) {
out += out_idx * reduction_stride + column;
if (safe) {
for (short i = 0; i < N_READS; i++) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
}
} else {
for (short i = 0; i < extra; i++) {
for (int i = 0; column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
// Case 3: Long row medium column
else {
threadgroup U shared_vals[1024];
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
short stride = reduction_stride;
short lid = simd_group_id * simd_size + simd_lane_id;
short2 tile((stride + N_READS - 1) / N_READS, 32);
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
short sm_stride = tile.x * N_READS;
bool safe = offset.x + N_READS <= stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
// Read cooperatively and contiguously and aggregate the partial results.
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < N_READS; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(simd_size, reduce_shape, reduce_strides);
}
// Each thread holds N_READS partial results but the simdgroups are not
// aligned to do the reduction across the simdgroup so we write our results
// in the shared memory and read them back according to the simdgroup.
for (int i = 0; i < N_READS; i++) {
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_READS; i++) {
totals[i] = op.simd_reduce(
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
}
// Write the output.
if (simd_lane_id == 0) {
short column = simd_group_id * N_READS;
out += out_idx * reduction_stride + column;
if (column + N_READS <= stride) {
for (int i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < stride; i++) {
out[i] = totals[i];
}
}
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
Op op;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
size_t out_idx = gid.x + gsize.x * size_t(gid.y);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + lid.x;
U total = Op::init;
size_t total_rows = non_col_reductions * reduction_size;
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
r += lsize.y * gsize.z) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
total = op(static_cast<U>(*row), total);
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
}
threadgroup U shared_vals[32 * 32];
shared_vals[lid.y * lsize.x + lid.x] = total;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.y == 0) {
for (uint i = 1; i < lsize.y; i++) {
total = op(total, shared_vals[i * lsize.x + lid.x]);
}
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
}
}
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 4;
constexpr int n_simdgroups = 8;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
}
}
}
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 8;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
constexpr int n_outputs = BN / n_simdgroups;
constexpr short outer_blocks = 32;
static_assert(BM == 32, "BM should be equal to 32");
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
looped_elem_to_loc<NDIMS> loop;
const device T* row;
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
size_t block_idx = full_idx / out_size;
size_t out_idx = full_idx % out_size;
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
for (size_t r = offset.y + block_idx * BM; r < total;
r += outer_blocks * BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
}
// We can use a simd reduction to accumulate across BM so each thread writes
// the partial output to SM and then each simdgroup does BN / n_simdgroups
// accumulations.
for (int i = 0; i < n_reads; i++) {
shared_vals[offset.y * BN + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
for (int i = 0; i < n_outputs; i++) {
totals[i] =
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
}
// Write the output.
if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x;
out += full_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; out_column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
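As a side note, col_reduce_longcolumn and col_reduce_2pass above both implement a two-pass strategy: split the reduced rows into a fixed number of blocks, write one partial result per block into an intermediate with a leading block axis, then reduce that axis in a second pass. A hedged NumPy sketch of the idea (the block count and shapes are illustrative, and the real kernels interleave rows across threadgroups differently):

# Two-pass column reduction sketch: partial sums per row-block, then a
# second reduction over the block axis. Illustrative only.
import numpy as np

def two_pass_col_reduce(x, outer_blocks=32):
    # Pass 1: each block reduces a strided subset of the rows.
    rows, cols = x.shape
    partial = np.zeros((outer_blocks, cols), dtype=x.dtype)
    for b in range(outer_blocks):
        partial[b] = x[b::outer_blocks].sum(axis=0)
    # Pass 2: reduce the block axis of the intermediate.
    return partial.sum(axis=0)

x = np.random.rand(100_000, 16).astype(np.float32)
assert np.allclose(two_pass_col_reduce(x), x.sum(axis=0), rtol=1e-3)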

View File

@@ -97,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&) {
return d.get_kernel(kernel_name);
}

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@@ -148,6 +149,125 @@ void launch_qmm(
d.add_temporaries(std::move(copies), s.index);
}
void qvm_split_k(
const std::vector<array>& inputs,
array& out,
int group_size,
int bits,
int D,
int O,
int B,
int N,
const Stream& s) {
int split_k = D > 8192 ? 32 : 8;
int split_D = (D + split_k - 1) / split_k;
N *= split_k;
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape();
auto x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x.ndim() - 1] = split_D;
x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k);
w_shape[w.ndim() - 1] /= split_k;
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = D - (split_k - 1) * split_D;
auto& d = metal::device(s.device);
auto temp_shape = out.shape();
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
<< bits << "_spk_" << split_k;
auto template_def = get_template_definition(
kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
set_vector_bytes(compute_encoder, x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3;
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce,
{intermediate.shape(axis)},
{intermediate.strides(axis)});
strided_reduce_general_dispatch(
intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
}
void qmm_op(
const std::vector<array>& inputs,
array& out,
@@ -211,7 +331,9 @@ void qmm_op(
aligned = true;
}
} else {
if (B < 4) {
if (B < 4 && D >= 1024 && !gather) {
return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
} else if (B < 4) {
name += "qvm";
int bo = 64;
int bd = 32;
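In short, the split-k path evaluates the quantized vector-matrix product in split_k chunks of the reduction dimension D, writes the partial outputs into an intermediate that carries an extra split_k axis, and then sums that axis with a strided reduction. A hedged NumPy sketch of the decomposition, with a dense matrix standing in for the quantized w/scales/biases (sizes are illustrative):

# Split-k decomposition sketch (dense stand-in for the quantized kernel).
import numpy as np

D, O, split_k = 8192, 64, 8
split_D = D // split_k                       # last block is smaller when D % split_k != 0
x = np.random.rand(1, D).astype(np.float32)  # activation vector
w = np.random.rand(D, O).astype(np.float32)  # dense stand-in for the quantized weights

# Pass 1: one partial qvm per chunk -> intermediate with a leading split_k axis.
partials = np.stack([
    x[:, k * split_D:(k + 1) * split_D] @ w[k * split_D:(k + 1) * split_D]
    for k in range(split_k)
])
# Pass 2: reduce the split axis, matching the strided sum reduction above.
y = partials.sum(axis=0)
assert np.allclose(y, x @ w, rtol=1e-3)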

View File

@@ -141,6 +141,20 @@ struct ColReduceArgs {
ndim = shape.size();
}
/**
* Create the col reduce arguments for reducing the 1st axis of the row
* contiguous intermediate array.
*/
ColReduceArgs(const array& intermediate) {
assert(intermediate.flags().row_contiguous);
reduction_size = intermediate.shape(0);
reduction_stride = intermediate.size() / reduction_size;
non_col_reductions = 1;
reduce_ndim = 0;
ndim = 0;
}
void encode(CommandEncoder& compute_encoder) {
// Push 0s to avoid encoding empty vectors.
if (reduce_ndim == 0) {
@@ -231,8 +245,10 @@ void init_reduce(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto kernel = get_reduce_init_kernel(
d, "init_reduce_" + op_name + type_to_name(out), out);
std::ostringstream kname;
const std::string func_name = "init_reduce";
kname << func_name << "_" << op_name << type_to_name(out);
auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s,
std::vector<array>& copies) {
const Stream& s) {
// Set the kernel
std::ostringstream kname;
const std::string func_name = "all_reduce";
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
// Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
copies.push_back(intermediate);
d.add_temporary(intermediate, s.index);
// 1st pass
size_t row_size = (in_size + n_rows - 1) / n_rows;
@@ -469,39 +484,11 @@ void strided_reduce_small(
// Figure out the grid dims
MTL::Size grid_dims, group_dims;
// Case 1: Small row small column
if (args.reduction_size * args.non_col_reductions < 64 &&
args.reduction_stride < 32) {
grid_dims = output_grid_for_col_reduce(out, args);
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Case 2: Long row small column
else if (args.reduction_size * args.non_col_reductions < 32) {
auto out_grid_dims = output_grid_for_col_reduce(out, args);
int threads_x =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_x = std::min(threads_x, 128);
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_x, 1, 1);
}
// Case 3: Long row medium column
else {
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
int simdgroups =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_size = simdgroups * 32;
auto out_grid_dims = output_grid_for_col_reduce(out, args);
grid_dims =
MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_small";
@@ -510,10 +497,113 @@ void strided_reduce_small(
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
const int n_reads = 4;
size_t reduction_stride_blocks =
(args.reduction_stride + n_reads - 1) / n_reads;
size_t total = args.reduction_size * args.non_col_reductions;
size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
size_t threadgroup_y = std::min(
8ul,
std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
grid_dims = output_grid_for_col_reduce(out, args);
grid_dims = MTL::Size(
(reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
grid_dims.width,
grid_dims.height);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void strided_reduce_longcolumn(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
size_t outer_blocks = 32;
if (total_reduction_size >= 32768) {
outer_blocks = 128;
}
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Figure out the grid dims
size_t out_size = out.size();
size_t threadgroup_x = args.reduction_stride;
size_t threadgroup_y =
(args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
outer_blocks;
threadgroup_y = std::min(32ul, threadgroup_y);
auto out_grid_size = output_grid_for_col_reduce(out, args);
MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_longcolumn";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
second_args.reduce_shape.push_back(outer_blocks);
second_args.reduce_strides.push_back(out.size());
second_args.reduce_ndim++;
int BN = 32;
grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
group_dims = MTL::Size(256, 1, 1);
// Set the 2nd kernel
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
second_kernel,
"col_reduce_looped",
op_name,
intermediate,
out,
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
@@ -532,9 +622,9 @@ void strided_reduce_looped(
// Figure out the grid dims
auto out_grid_size = output_grid_for_col_reduce(out, args);
int BN = (args.reduction_stride <= 1024) ? 32 : 128;
int BN = 32;
int BM = 1024 / BN;
int threadgroup_size = 4 * 32;
int threadgroup_size = 8 * 32;
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width,
@@ -558,6 +648,87 @@ void strided_reduce_looped(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_2pass(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Figure out the grid dims
size_t out_size = out.size() / args.reduction_stride;
auto out_grid_size = output_grid_for_col_reduce(out, args);
int outer_blocks = 32;
int BN = 32;
int BM = 1024 / BN;
int threadgroup_size = 8 * 32;
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width * outer_blocks,
out_grid_size.height);
MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_2pass";
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
second_args.reduce_shape.push_back(outer_blocks);
second_args.reduce_strides.push_back(out.size());
second_args.reduce_ndim++;
grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
// Set the 2nd kernel
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
second_kernel,
"col_reduce_looped",
op_name,
intermediate,
out,
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_general_dispatch(
const array& in,
array& out,
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
// Prepare the arguments for the kernel
ColReduceArgs args(in, plan, axes);
if (args.reduction_stride < 32 ||
args.reduction_size * args.non_col_reductions < 32) {
// Small column
if (args.reduction_size * args.non_col_reductions < 32) {
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
}
// Long column but small row
if (args.reduction_stride < 32 &&
args.reduction_size * args.non_col_reductions >= 1024) {
return strided_reduce_longcolumn(
in, out, op_name, args, compute_encoder, d, s);
}
if (args.reduction_size * args.non_col_reductions > 256 &&
out.size() / 32 < 1024) {
return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
}
return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
}
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reduce
if (in.size() > 0) {
std::vector<array> copies;
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
copies.push_back(in_copy);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
}
// At least the last dimension is row contiguous and we are reducing over
@@ -659,8 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
d.add_temporaries(std::move(copies), s.index);
}
// Nothing to reduce just initialize the output
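Put together, the new dispatch picks among four column-reduce strategies based on the reduction geometry. A hedged Python restatement of the branch structure (thresholds are copied from the diff above; the returned names are just labels):

# Restatement of the column-reduce dispatch heuristic, for readability only.
def pick_col_reduce(reduction_size, reduction_stride, non_col_reductions, out_size):
    total = reduction_size * non_col_reductions
    if total < 32:
        return "col_reduce_small"                 # small column
    if reduction_stride < 32 and total >= 1024:
        return "col_reduce_longcolumn"            # long column but small row
    if total > 256 and out_size // 32 < 1024:
        return "col_reduce_2pass"                 # few outputs, split the reduction
    return "col_reduce_looped"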

View File

@@ -16,8 +16,7 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s,
std::vector<array>& copies);
const Stream& s);
void row_reduce_general_dispatch(
const array& in,

View File

@@ -1,8 +1,20 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
if(MPI_FOUND AND MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
if(MLX_BUILD_CPU)
if(MLX_CUSTOM_DISTRIBUTED)
if(MLX_CUSTOM_DISTRIBUTED STREQUAL "gloo")
message(STATUS "Distributed: using gloo backend")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gloo)
else()
message(STATUS "Distributed: using sockets backend")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
endif()
elseif(MPI_FOUND)
message(STATUS "Distributed: using MPI backend")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
message(STATUS "Distributed: no support")
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
endif()
endif()

View File

@@ -32,6 +32,8 @@ struct Group {
*/
Group split(int color, int key = -1);
void barrier();
const std::shared_ptr<void>& raw_group() {
return group_;
}

View File

@@ -0,0 +1,25 @@
find_path(
GLOO_INCLUDE_DIR gloo/allreduce.h
PATHS ${GLOO_INC_DIR}
PATH_SUFFIXES include)
find_library(
GLOO_LIBRARY gloo
PATHS ${GLOO_LIB_DIR}
PATH_SUFFIXES lib
HINTS GLOO)
find_library(
UV_LIBRARY uv
PATHS ${UV_LIB_DIR}
PATH_SUFFIXES lib
HINTS UV)
message(STATUS "GLOO LIB <${GLOO_LIBRARY}>")
message(STATUS "GLOO INC <${GLOO_INCLUDE_DIR}>")
message(STATUS "UV LIB <${UV_LIB_DIR}>")
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gloo.cpp)
target_link_libraries(mlx PUBLIC ${GLOO_LIBRARY})
target_link_libraries(mlx PUBLIC ${UV_LIBRARY})
target_include_directories(mlx PRIVATE ${GLOO_INCLUDE_DIR})

View File

@@ -0,0 +1,178 @@
// Copyright © 2024 Apple Inc.
#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"
#include "gloo/allreduce.h"
#include "gloo/math.h"
#include "gloo/mpi/context.h"
#include "gloo/transport/uv/device.h"
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
namespace mlx::core::distributed {
namespace {
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
} // namespace
bool is_available() {
return true;
}
int Group::rank() {
return std::static_pointer_cast<gloo::mpi::Context>(group_)->rank;
}
int Group::size() {
return std::static_pointer_cast<gloo::mpi::Context>(group_)->size;
}
Group Group::split(int color, int key) {
throw std::runtime_error("split is NYI");
}
void Group::barrier() {
throw std::runtime_error("barrier is NYI");
}
struct GlooCTX {
std::shared_ptr<gloo::mpi::Context> context;
std::shared_ptr<gloo::transport::Device> dev;
};
Group init(bool strict /* = false */) {
static std::shared_ptr<GlooCTX> gloo_ctx = nullptr;
if (gloo_ctx == nullptr) {
gloo_ctx = std::make_shared<GlooCTX>();
gloo_ctx->context = gloo::mpi::Context::createManaged();
gloo_ctx->dev = gloo::transport::uv::CreateDevice("localhost");
gloo_ctx->context->connectFullMesh(gloo_ctx->dev);
}
return Group(gloo_ctx->context);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
template <typename T>
void all_reduce_sum(
std::shared_ptr<gloo::mpi::Context> context,
T* output,
T* input,
size_t len) {
gloo::AllreduceOptions opts_(context);
opts_.setInput(input, len);
opts_.setOutput(output, len);
opts_.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
opts_.setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::sum<T>));
gloo::allreduce(opts_);
}
void all_sum(Group group_, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
auto context =
std::static_pointer_cast<gloo::mpi::Context>(group_.raw_group());
SWITCH_TYPE(
output,
all_reduce_sum<T>(
context, output.data<T>(), input.data<T>(), input.size()));
}
void all_gather(Group group_, const array& input_, array& output) {
throw std::runtime_error("all_gather NYI");
}
void send(Group group_, const array& input_, int dst) {
throw std::runtime_error("send NYI");
}
void recv(Group group_, array& out, int src) {
throw std::runtime_error("recv NYI");
}
} // namespace detail
} // namespace mlx::core::distributed

View File

@@ -71,6 +71,7 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Barrier, barrier);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
@@ -195,6 +196,7 @@ struct MPIWrapper {
int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
int (*barrier)(MPI_Comm);
// Objects
MPI_Comm comm_world_;
@@ -263,6 +265,10 @@ struct MPIGroupImpl {
return size_;
}
void barrier() {
mpi().barrier(comm_);
}
private:
MPI_Comm comm_;
bool global_;
@@ -298,6 +304,11 @@ Group Group::split(int color, int key) {
return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}
void Group::barrier() {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
mpi_group->barrier();
}
bool is_available() {
return mpi().is_available();
}

View File

@@ -17,6 +17,8 @@ Group Group::split(int color, int key) {
throw std::runtime_error("Cannot split the distributed group further");
}
void Group::barrier() {}
bool is_available() {
return false;
}

View File

@@ -0,0 +1,5 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/sockets.cpp
)

View File

@@ -0,0 +1,522 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <json.hpp>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/io/threadpool.h"
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
constexpr const size_t PACKET_SIZE = 262144;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
using json = nlohmann::json;
namespace mlx::core::distributed {
namespace {
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
void sum_inplace(const array& input, array& output) {
SWITCH_TYPE(
input, sum_inplace(input.data<T>(), output.data<T>(), input.size()));
}
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* sockaddr() {
return (struct sockaddr*)&addr;
}
};
address_t parse_address(std::string ip, std::string port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse peer address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
std::vector<address_t> load_peers() {
std::vector<address_t> peers;
std::ifstream f;
if (const char* hostfile_buf = std::getenv("MLX_HOSTFILE")) {
f.open(hostfile_buf);
} else {
return peers;
}
json hosts = json::parse(f);
for (auto& h : hosts) {
peers.push_back(std::move(parse_address(
h["ip"].template get<std::string>(),
h["port"].template get<std::string>())));
}
return peers;
}
struct GroupImpl {
GroupImpl(std::vector<address_t> peers, int rank, bool global)
: rank_(rank), global_(global), pool_(4), sockets_(peers.size(), -1) {
if (rank_ > 0 && rank_ >= peers.size()) {
throw std::runtime_error(
"Rank cannot be larger than the size of the group");
}
int success;
// If we are expecting anyone to connect to us
if (rank_ + 1 < peers.size()) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseaddr (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success =
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't enable reuseport (rank: " << rank_
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind it to the port
success = bind(sock, peers[rank_].sockaddr(), peers[rank_].len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't bind socket (rank: " << rank_ << " error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int i = 0; i < peers.size() - rank_ - 1; i++) {
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets_[peers.size() - 1 - i] = peer_socket;
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
}
// Connect to the peers with smaller rank
for (int i = 0; i < rank_; i++) {
sockets_[i] = socket(AF_INET, SOCK_STREAM, 0);
if (sockets_[i] < 0) {
std::ostringstream msg;
msg << "Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sockets_[i], peers[i].sockaddr(), peers[i].len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "Couldn't connect (rank: " << rank_ << " to: " << i
<< " error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
}
~GroupImpl() {
if (global_) {
for (int sock : sockets_) {
shutdown(sock, 2);
close(sock);
}
}
}
int rank() {
return rank_;
}
int size() {
return std::max(sockets_.size(), 1ul);
}
void send(const char* buf, size_t len, int dst) {
while (len > 0) {
ssize_t r = ::send(sockets_[dst], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Send of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
}
}
void recv(char* buf, size_t len, int src) {
while (len > 0) {
ssize_t r = ::recv(sockets_[src], buf, len, 0);
if (r <= 0) {
std::ostringstream msg;
msg << "Recv of " << len << " bytes failed (errno: " << errno << ")";
throw std::runtime_error(msg.str());
}
buf += r;
len -= r;
}
}
template <typename T>
void send_recv_sum(char* buf, size_t len, int peer) {
char recv_buffer[2 * PACKET_SIZE];
char* recv_buffers[2];
recv_buffers[0] = recv_buffer;
recv_buffers[1] = recv_buffer + PACKET_SIZE;
std::future<void> sent, received;
size_t n_blocks = (len + PACKET_SIZE - 1) / PACKET_SIZE;
for (size_t b = 0; b < n_blocks; b++) {
if (b > 0) {
sent.wait();
received.wait();
}
size_t l = std::min(len - b * PACKET_SIZE, PACKET_SIZE);
if (rank_ < peer) {
sent = send_async(buf + b * PACKET_SIZE, l, peer);
received = recv_async(recv_buffers[b % 2], l, peer);
} else {
received = recv_async(recv_buffers[b % 2], l, peer);
sent = send_async(buf + b * PACKET_SIZE, l, peer);
}
if (b > 0) {
sum_inplace(
(const T*)recv_buffers[(b - 1) % 2],
(T*)(buf + (b - 1) * PACKET_SIZE),
PACKET_SIZE / sizeof(T));
}
}
sent.wait();
received.wait();
size_t l = std::min(len - (n_blocks - 1) * PACKET_SIZE, PACKET_SIZE);
sum_inplace(
(const T*)recv_buffers[(n_blocks - 1) % 2],
(T*)(buf + (n_blocks - 1) * PACKET_SIZE),
l / sizeof(T));
}
void send_recv_sum(array& out, int peer) {
SWITCH_TYPE(out, send_recv_sum<T>(out.data<char>(), out.nbytes(), peer));
}
std::future<void> send_async(const char* buf, size_t len, int dst) {
return pool_.enqueue(
[this, buf, len, dst]() { this->send(buf, len, dst); });
}
std::future<void> recv_async(char* buf, size_t len, int src) {
return pool_.enqueue(
[this, buf, len, src]() { this->recv(buf, len, src); });
}
private:
int rank_;
bool global_;
ThreadPool pool_;
std::vector<int> sockets_;
};
} // namespace
bool is_available() {
return true;
}
int Group::rank() {
return std::static_pointer_cast<GroupImpl>(group_)->rank();
}
int Group::size() {
return std::static_pointer_cast<GroupImpl>(group_)->size();
}
Group Group::split(int color, int key) {
throw std::runtime_error("Splitting not supported yet");
}
void Group::barrier() {
char buff[128];
std::memset(buff, 1, 128);
auto group = std::static_pointer_cast<GroupImpl>(raw_group());
int size = group->size();
int rank = group->rank();
for (int distance = 1; distance <= size / 2; distance *= 2) {
group->send_recv_sum<char>(buff, 128, rank ^ distance);
}
}
Group init(bool strict /* = false */) {
static std::shared_ptr<GroupImpl> global_group = nullptr;
if (global_group == nullptr) {
auto peers = load_peers();
int rank = 0;
if (const char* rank_buf = std::getenv("MLX_RANK")) {
rank = std::atoi(rank_buf);
}
if (peers.size() == 0) {
if (strict) {
throw std::runtime_error("Can't initialize distributed");
}
}
global_group = std::make_shared<GroupImpl>(std::move(peers), rank, true);
}
return Group(global_group);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_sum(Group group_, const array& input_, array& output) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
array input = ensure_row_contiguous(input_);
int size = group->size();
int rank = group->rank();
if ((size & (size - 1)) != 0) {
throw std::runtime_error("Only powers of 2 are currently supported");
}
// If not inplace all reduce then copy the input to the output first.
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
// Butterfly all reduce
for (int distance = 1; distance <= size / 2; distance *= 2) {
group->send_recv_sum(output, rank ^ distance);
}
}
void all_gather(Group group_, const array& input_, array& output) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
array input = ensure_row_contiguous(input_);
std::future<void> sent;
std::future<void> received;
int rank = group->rank();
int size = group->size();
if ((size & (size - 1)) != 0) {
throw std::runtime_error("Only powers of 2 are currently supported");
}
// Butterfly all gather
int peer = rank ^ 1;
if (peer < rank) {
received = group->recv_async(
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
} else {
sent = group->send_async(input.data<char>(), input.nbytes(), peer);
received = group->recv_async(
output.data<char>() + peer * input.nbytes(), input.nbytes(), peer);
}
std::memcpy(
output.data<char>() + rank * input.nbytes(),
input.data<char>(),
input.nbytes());
for (int distance = 2; distance <= size / 2; distance *= 2) {
sent.wait();
received.wait();
int peer = rank ^ distance;
int their_offset = peer & ~(distance - 1);
int our_offset = rank & ~(distance - 1);
if (peer < rank) {
received = group->recv_async(
output.data<char>() + their_offset * input.nbytes(),
distance * input.nbytes(),
peer);
sent = group->send_async(
output.data<char>() + our_offset * input.nbytes(),
distance * input.nbytes(),
peer);
} else {
sent = group->send_async(
output.data<char>() + our_offset * input.nbytes(),
distance * input.nbytes(),
peer);
received = group->recv_async(
output.data<char>() + their_offset * input.nbytes(),
distance * input.nbytes(),
peer);
}
}
sent.wait();
received.wait();
}
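// Note: this implements a recursive-doubling all-gather. The initial exchange
// with rank ^ 1 leaves every rank with a contiguous 2-shard block; each later
// round swaps the currently owned block with peer = rank ^ distance (whose
// block starts at peer & ~(distance - 1)), doubling the owned region until the
// whole output is assembled after log2(size) rounds.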
void send(Group group_, const array& input_, int dst) {
array input = ensure_row_contiguous(input_);
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
group->send(input.data<char>(), input.nbytes(), dst);
}
void recv(Group group_, array& out, int src) {
auto group = std::static_pointer_cast<GroupImpl>(group_.raw_group());
group->recv(out.data<char>(), out.nbytes(), src);
}
} // namespace detail
} // namespace mlx::core::distributed
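For orientation, here is a minimal sketch of driving these collectives from the Python side once the socket backend is up (assuming the standard mx.distributed API; how the peer list consumed by load_peers is supplied to each process is backend-specific and not shown here):

import mlx.core as mx

# Each process is launched with its own MLX_RANK (see init() above); with
# strict=True, init() raises if the backend cannot be brought up.
group = mx.distributed.init()

x = mx.ones((4, 4))

# Butterfly all-reduce: every rank ends up with group.size() * ones((4, 4)).
total = mx.distributed.all_sum(x, group=group)

# Recursive-doubling all-gather: concatenates the per-rank shards.
gathered = mx.distributed.all_gather(x, group=group)

mx.eval(total, gathered)
print(group.rank(), group.size(), total[0, 0].item(), gathered.shape)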

View File

@@ -1,5 +1,5 @@
// Copyright © 2023 Apple Inc.
//
#include <json.hpp>
#include <stack>

View File

@@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
ConvTranspose2d,
ConvTranspose3d,
)
from mlx.nn.layers.distributed import (
AllToShardedLinear,
QuantizedAllToShardedLinear,
QuantizedShardedToAllLinear,
ShardedToAllLinear,
)
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear

View File

@@ -0,0 +1,456 @@
# Copyright © 2024 Apple Inc.
import math
from functools import lru_cache
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
@lru_cache
def sum_gradients(group):
if group.size() == 1:
return lambda x: x
@mx.custom_function
def f(x):
return x
@f.vjp
def f(x, dx, _):
return mx.distributed.all_sum(dx, group=group)
return f
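# Note: sum_gradients is an identity in the forward pass; its custom VJP
# all-sums the incoming cotangent across the group, so the gradients computed
# on each shard are aggregated before they flow back into the replicated input.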
class AllToShardedLinear(Module):
"""Each member of the group applies part of the affine transformation such
that the result is sharded across the group.
The gradients are automatically aggregated from each member of the group.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool, optional): If set to ``False`` then the layer will not use a
bias. Default is ``True``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Initialize the parameters
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (output_dims % N) != 0:
raise ValueError(
f"Cannot shard the output of size {output_dims} across {N} devices."
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims // N, input_dims),
)
if bias:
self.bias = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims // N,),
)
def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
N = self.group.size()
out_dims *= N
return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
# Aggregate the gradients coming from each shard
if self.group.size() > 1:
x = sum_gradients(self.group)(x)
# Compute the affine projection
if "bias" in self:
x = mx.addmm(self["bias"], x, self["weight"].T)
else:
x = x @ self["weight"].T
return x
@classmethod
def from_linear(
cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = linear_layer.weight.shape
step = output_dims // N
sl = cls(input_dims, output_dims, False, group)
# The multiplication with 1.0 forces a copy, perhaps change to
# something better when available.
sl.weight = linear_layer.weight[r * step : (r + 1) * step] * 1
if "bias" in linear_layer:
sl.bias = linear_layer.bias[r * step : (r + 1) * step] * 1
return sl
class ShardedToAllLinear(Module):
"""Each member of the group applies part of the affine transformation and
then aggregates the results.
All nodes will have exactly the same result after this layer.
:class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to
convert linear layers to sharded :obj:`ShardedToAllLinear` layers.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool, optional): If set to ``False`` then the layer will not use a
bias. Default is ``True``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Initialize the parameters
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (input_dims % N) != 0:
raise ValueError(
f"The input of size {input_dims} cannot be sharded across {N} devices."
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims // N),
)
if bias:
self.bias = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims,),
)
def _extra_repr(self) -> str:
N = self.group.size()
out_dims, in_dims = self.weight.shape
in_dims *= N
return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
if self.group.size() > 1:
# Perform the local projection and aggregate the results
x = x @ self["weight"].T
x = mx.distributed.all_sum(x, group=self.group)
# Add the bias if we have one
if "bias" in self:
x = x + self["bias"]
else:
# Normal linear layer as we are not in a distributed setting.
if "bias" in self:
x = mx.addmm(self["bias"], x, self["weight"].T)
else:
x = x @ self["weight"].T
return x
@classmethod
def from_linear(
cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = linear_layer.weight.shape
step = input_dims // N
sl = cls(input_dims, output_dims, False, group)
# The multiplication with 1.0 forces a copy, perhaps change to
# something better when available.
sl.weight = linear_layer.weight[:, r * step : (r + 1) * step] * 1
if "bias" in linear_layer:
sl.bias = linear_layer.bias
return sl
class QuantizedAllToShardedLinear(Module):
"""Each member of the group applies part of the affine transformation with
a quantized matrix such that the result is sharded across the group.
It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
will not be included in any gradient computation.
Args:
input_dims (int): The dimensionality of the input features.
output_dims (int): The dimensionality of the output features.
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. Default: ``True``.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (output_dims % N) != 0:
raise ValueError(
f"Cannot shard the output of size {output_dims} across {N} devices."
)
weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims // N, input_dims),
)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
# And bias if needed
if bias:
self.bias = mx.zeros((output_dims // N,))
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.bits
out_dims *= self.group.size()
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)
def __call__(self, x: mx.array) -> mx.array:
# Aggregate the gradients coming from each shard
if self.group.size() > 1:
x = sum_gradients(self.group)(x)
x = mx.quantized_matmul(
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_quantized_linear(
cls,
quantized_linear_layer: Module,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = quantized_linear_layer.weight.shape
input_dims *= 32 // quantized_linear_layer.bits
step = output_dims // N
sl = cls(
input_dims,
output_dims,
False,
group_size=quantized_linear_layer.group_size,
bits=quantized_linear_layer.bits,
group=group,
)
sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
if "bias" in quantized_linear_layer:
sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
return sl
class QuantizedShardedToAllLinear(Module):
"""Each member of the group applies part of the affine transformation using
the quantized matrix and then aggregates the results.
All nodes will have exactly the same result after this layer.
It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
will not be included in any gradient computation.
Args:
input_dims (int): The dimensionality of the input features.
output_dims (int): The dimensionality of the output features.
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. Default: ``True``.
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. Default: ``64``.
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. Default: ``4``.
group (mx.distributed.Group, optional): The sharding will happen across
this group. If not set then the global group is used. Default is
``None``.
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
group: Optional[mx.distributed.Group] = None,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1.0 / input_dims)
self.group = group or mx.distributed.init()
N = self.group.size()
if (input_dims % N) != 0:
raise ValueError(
f"The input of size {input_dims} cannot be sharded across {N} devices."
)
weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims // N),
)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
# And bias if needed
if bias:
self.bias = mx.zeros((output_dims,))
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
in_dims *= (32 // self.bits) * self.group.size()
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)
def __call__(self, x: mx.array) -> mx.array:
x = mx.quantized_matmul(
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if self.group.size() > 1:
x = mx.distributed.all_sum(x, group=self.group)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_quantized_linear(
cls,
quantized_linear_layer: Module,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
output_dims, input_dims = quantized_linear_layer.weight.shape
step = input_dims // N
step_grouped = quantized_linear_layer.scales.shape[1] // N
input_dims *= (32 // quantized_linear_layer.bits) * N
sl = cls(
input_dims,
output_dims,
False,
group_size=quantized_linear_layer.group_size,
bits=quantized_linear_layer.bits,
group=group,
)
sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
sl.scales = (
quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
* 1
)
sl.biases = (
quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
* 1
)
if "bias" in quantized_linear_layer:
sl.bias = quantized_linear_layer.bias
return sl
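As a usage sketch (illustrative only: the two-layer MLP, its dimensions, and the seeding are made up for this example), a common tensor-parallel pattern pairs an AllToShardedLinear with a ShardedToAllLinear so that the intermediate activations stay sharded and only the final projection is all-reduced:

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import AllToShardedLinear, ShardedToAllLinear

group = mx.distributed.init()

# Same seed on every rank so the toy dense layers are identical everywhere;
# in practice they would be loaded from a shared checkpoint instead.
mx.random.seed(0)
up = nn.Linear(512, 2048)    # hidden size must be divisible by group.size()
down = nn.Linear(2048, 512)

# Shard the up-projection over its outputs and the down-projection over its
# inputs; each rank keeps 1 / group.size() of every weight matrix.
up_sharded = AllToShardedLinear.from_linear(up, group=group)
down_sharded = ShardedToAllLinear.from_linear(down, group=group)

x = mx.ones((1, 512))
y = down_sharded(nn.gelu(up_sharded(x)))  # identical result on every rank
mx.eval(y)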

View File

@@ -197,7 +197,7 @@ class QuantizedLinear(Module):
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.bits
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)

View File

@@ -44,7 +44,8 @@ void init_distributed(nb::module_& parent_module) {
color (int): A value to group processes into subgroups.
key (int, optional): A key to optionally change the rank ordering
of the processes.
)pbdoc");
)pbdoc")
.def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group");
m.def(
"is_available",

View File

@@ -163,6 +163,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm_splitk(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
[128], # M
[16384], # N
[1, 3], # B
)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))