MoE backward improvements (#2335)

Angelos Katharopoulos 2025-07-07 17:59:53 -07:00 committed by GitHub
parent a4fcc893cd
commit 4a9b29a875
22 changed files with 1130 additions and 60 deletions

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
@@ -52,6 +53,58 @@ inline void mask_matrix(
}
}
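// For each (k_start, k_end) pair in segments, multiply the (M, k_end - k_start)
// slice of a with the (k_end - k_start, N) slice of b and write one (M, N)
// output block; empty segments (k_end <= k_start) are filled with zeros.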
template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
Shape a_copy = a_shape;
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
for (int i = 0; i < num_segments; i++) {
uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue;
}
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
matmul<T>(
a + k_start * a_strides[ndim - 1],
b + k_start * b_strides[ndim - 2],
out + i * M * N,
a_transposed,
b_transposed,
lda,
ldb,
N,
1.0,
0.0,
1,
a_copy,
a_strides,
b_copy,
b_strides);
}
}
} // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporaries(std::move(temps));
}
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cpu::get_command_encoder(stream());
auto check_transpose = [&s, &encoder](const array& x) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
if (stx == x.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, x);
} else if (stx == 1 && sty == x.shape(-2)) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);
}
};
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
auto& segments = inputs[2];
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(segments);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
segments = array::unsafe_weak_copy(segments),
out_ptr = out.data<void>(),
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb]() {
switch (a.dtype()) {
case float64:
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float32:
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float16:
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case bfloat16:
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
default:
throw std::invalid_argument(
"Segmented mm supports only real float types.");
}
});
}
} // namespace mlx::core
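
For intuition, here is a minimal Python sketch of the semantics the CPU segmented_mm above implements (illustrative only and not part of this diff; segmented_mm_reference is a hypothetical name):

import mlx.core as mx

def segmented_mm_reference(a, b, segments):
    # segments has shape (..., 2); each (start, end) pair slices the shared
    # inner (K) dimension. Empty segments (end <= start) produce a zero block.
    outs = []
    for start, end in segments.reshape(-1, 2).tolist():
        if end <= start:
            outs.append(mx.zeros((a.shape[0], b.shape[1]), dtype=a.dtype))
        else:
            outs.append(a[:, start:end] @ b[start:end, :])
    out = mx.stack(outs, axis=0)
    return out.reshape(segments.shape[:-1] + out.shape[-2:])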

View File

@@ -83,6 +83,7 @@ NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)

View File

@@ -63,6 +63,7 @@ if(MLX_METAL_JIT)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(steel/gemm/kernels/steel_gemm_segmented)
make_jit_source(
steel/conv/conv
kernels/steel/utils.h

View File

@@ -575,9 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set source info
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
if (ndim > 1) {
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
} else {
// The following are ignored by the kernel, but we still have to set some
// value so that Metal validation passes.
compute_encoder.set_vector_bytes(idx.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4);
compute_encoder.set_vector_bytes(idx.strides(), 5);
}
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(out.shape(axis_), 8);

View File

@@ -34,6 +34,7 @@ const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* steel_gemm_gather();
const char* steel_gemm_segmented();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();

View File

@@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::steel_gemm_segmented(),
get_template_definition(
lib_name,
"segmented_mm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
int wn,
bool rhs);
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -71,6 +71,7 @@ set(STEEL_HEADERS
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_gather.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_segmented.h
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)
@@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)
endif()

View File

@@ -0,0 +1,266 @@
// Copyright © 2025 Apple Inc.
using namespace mlx::steel;
constant bool segments_contiguous [[function_constant(199)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* segments [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Move the pointers to the output tile
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
// Move the pointers to the start of the segment
uint32_t k_start, k_end;
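// segments comes in one of two layouts: contiguous (start, end) pairs
// [s0, e0, s1, e1, ...], or an overlapping view over prefix sums
// [p0, p1, p2, ...] where segment i is (p_i, p_{i+1}) (this is what
// gather_mm_grad builds via as_strided).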
if (segments_contiguous) {
k_start = segments[2 * tid.z];
k_end = segments[2 * tid.z + 1];
} else {
// Non-contiguous case: the last two strides of segments are both 1, so the
// values form prefix sums and segment i starts where segment i - 1 ends.
k_start = segments[tid.z];
k_end = segments[tid.z + 1];
}
A += transpose_a ? k_start * params->lda : k_start;
B += transpose_b ? k_start : k_start * params->ldb;
C += tid.z * params->batch_stride_d;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Matrix-level alignment, so only check K
if (align_M && align_N) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, params->ldd);
} else {
// Tile is aligned, so do the same as above
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, params->ldd);
}
// Tile partially aligned, so check rows
else if (align_N || tgp_bn == BN) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Tile partially aligned, so check cols
else if (align_M || tgp_bm == BM) {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_safe(
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
// Nothing aligned so check both rows and cols
else {
uint32_t k = k_start + BK;
for (; k <= k_end; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
loader_b.load_safe(
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
short k_remain = BK - short(k - k_end);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
if (k_remain > 0) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -0,0 +1,43 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h"
#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
segmented_mm, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format on
instantiate_segmented_mm_shapes_helper(float16, half, float16, half);
instantiate_segmented_mm_shapes_helper(
bfloat16,
bfloat16_t,
bfloat16,
bfloat16_t);
instantiate_segmented_mm_shapes_helper(float32, float, float32, float);

View File

@@ -1864,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}
void segmented_mm(
const array& a_,
const array& b_,
const array& segments_,
array& out,
int M,
int N,
int K,
metal::Device& d,
const Stream& s) {
auto check_segments_layout = [&d, &s](const array& x) {
// Contiguous so return early
if (x.flags().row_contiguous) {
return std::make_tuple(true, x);
}
bool rc = true;
for (int i = 0; i < x.ndim() - 2; i++) {
rc &=
(x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
}
rc &= x.strides(x.ndim() - 1) == 1;
if (x.ndim() > 1) {
rc &= x.strides(x.ndim() - 2) == 1;
}
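// rc is true for the overlapping prefix-sum layout where the last two strides
// of segments are both 1 (as produced by as_strided in gather_mm_grad); in
// that case the kernel is launched with segments_contiguous == false and
// reads segments[i] and segments[i + 1] directly.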
if (rc) {
return std::make_tuple(false, x);
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy);
};
// Copy if needed
std::vector<array> copies;
auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
auto [segments_contiguous, segments] = check_segments_layout(segments_);
d.add_temporaries(std::move(copies), s.index);
// Determine dispatch kernel
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
size_t batch_size_out = out.size() / M / N;
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
// Define the kernel name
std::string base_name;
base_name.reserve(128);
concatenate(
base_name,
"steel_segmented_mm_",
transpose_a ? 't' : 'n',
transpose_b ? 't' : 'n',
"_",
type_to_name(a),
"_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
metal::MTLFCList func_consts = {
{&segments_contiguous, MTL::DataType::DataTypeBool, 199},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
};
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_segments_contiguous_",
segments_contiguous ? 't' : 'n',
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n');
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_segmented_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder.set_compute_pipeline_state(kernel);
// Prepare the matmul params
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ static_cast<int>(lda),
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ 0,
/* const int64_t batch_stride_d = */ M * N,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ 0,
/* const int batch_ndim = */ 0};
// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
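// One grid z-slice per segment; the kernel offsets C by
// tid.z * batch_stride_d = tid.z * M * N so each segment writes its own
// (M, N) output block.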
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(segments, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& a = inputs[0];
auto& b = inputs[1];
auto& segments = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
// Extract shapes from inputs.
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
segmented_mm(a, b, segments, out, M, N, K, d, s);
}
} // namespace mlx::core

View File

@@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int) {
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -105,6 +105,7 @@ NO_CPU(Scan)
NO_CPU(Scatter)
NO_CPU(ScatterAxis)
NO_CPU(Select)
NO_CPU(SegmentedMM)
NO_CPU(Sigmoid)
NO_CPU(Sign)
NO_CPU(Sin)

View File

@@ -121,6 +121,7 @@ NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(SegmentedMM)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)

View File

@@ -4649,6 +4649,54 @@ array gather_mm(
return axes.empty() ? out : squeeze(out, axes, s);
}
array segmented_mm(
array a,
array b,
array segments,
StreamOrDevice s /* = {} */) {
if (a.ndim() != 2 || b.ndim() != 2) {
throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
}
if (segments.ndim() < 1 || segments.shape().back() != 2) {
std::ostringstream msg;
msg << "[segmented_mm] The segments should have shape (..., 2) but "
<< segments.shape() << " was provided.";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type(a, b);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[segmented_mm] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type
<< ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
if (!issubdtype(segments.dtype(), integer)) {
throw std::invalid_argument(
"[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
segments = astype(segments, uint32, s);
Shape out_shape = segments.shape();
out_shape.pop_back();
out_shape.push_back(a.shape(0));
out_shape.push_back(b.shape(1));
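// The output shape is segments.shape[:-1] + (M, N), e.g. segments of shape
// (2, 4, 2) with a (M, K) and b (K, N) produce an output of shape (2, 4, M, N).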
return array(
std::move(out_shape),
out_type,
std::make_shared<SegmentedMM>(to_stream(s)),
{std::move(a), std::move(b), std::move(segments)});
}
array diagonal(
const array& a,
int offset /* = 0 */,

View File

@@ -1406,6 +1406,12 @@ array gather_mm(
bool sorted_indices = false,
StreamOrDevice s = {});
/**
* Compute a matrix product but segment the inner dimension and write the
* result separately for each segment.
*/
array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,

View File

@@ -109,6 +109,70 @@ std::tuple<array, array, array, int> vmap_ternary_op(
return {a, b, c, to_ax};
}
// Calculate the gradient wrt the weights of the following computation
//
// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)
//
// Note the transpose above. This function returns the gradient for w.T, so if
// w was used instead, the returned gradient needs to be transposed.
//
// We define it as a separate function to reuse it for gather_mm and
// gather_qmm.
array gather_mm_grad(
const array& x,
const array& dy,
const array& lhs_indices,
const array& rhs_indices,
bool sorted,
Shape batch_shape,
const Stream& s) {
int M = x.shape(-2);
int K = x.shape(-1);
int N = dy.shape(-1);
int num_segments = std::accumulate(
batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
batch_shape.push_back(N);
batch_shape.push_back(K);
// If the indices are sorted then it means that we can do the whole gradient
// computation via a segmented matmul. We just need to calculate the segments
// using the indices.
if (sorted) {
auto segments = zeros({num_segments}, uint32, s);
segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);
segments = cumsum(segments, 0, false, true, s);
segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);
segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);
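// e.g. with M = 2 and rhs_indices = [0, 0, 2] (num_segments = 3):
// counts = [4, 0, 2], cumsum = [4, 4, 6], prepending 0 gives [0, 4, 4, 6],
// and the strided view yields the (start, end) pairs
// [[0, 4], [4, 4], [4, 6]]; expert 1 gets an empty segment.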
return reshape(
segmented_mm(
swapaxes(flatten(dy, 0, -2, s), 0, 1, s),
flatten(x, 0, -2, s),
segments,
s),
std::move(batch_shape),
s);
}
// Otherwise we need to gather-matmul dy and then scatter-add it into the
// correct locations.
else {
// TODO: If the lhs indices weren't provided, this is always a sorted matmul,
// so we should add that check.
auto dw = gather_mm(
swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);
return reshape(
scatter_add(
zeros({num_segments, N, K}, dw.dtype(), s),
rhs_indices,
expand_dims(dw, -3, s),
0,
s),
std::move(batch_shape),
s);
}
}
} // namespace
std::vector<array> Primitive::jvp(
@@ -3169,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
} else {
if (!dsb) {
auto fc = flatten(cotangents[0], 0, -2, stream());
auto fx = flatten(primals[0], 0, -2, stream());
int ndim = primals[1].ndim();
auto fc = flatten(cotangents[0], 0, -ndim, stream());
auto fx = flatten(primals[0], 0, -ndim, stream());
auto dw = transpose_
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
@@ -3181,7 +3246,6 @@ std::vector<array> QuantizedMatmul::vjp(
vjps.push_back(sum(*dsb, -1, false, stream()));
} else {
// scales
auto s = stream();
auto wq = dequantize(
primals[1],
ones_like(primals[2], stream()),
@@ -3253,34 +3317,42 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5];
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = x.shape(-1);
bool sorted = left_sorted_ || right_sorted_;
bool no_broadcast = rhs_indices.size() * M * K == x.size();
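// When the indices are sorted and each batch element of x is used by exactly
// one index (no broadcasting), the gathered gradient already has x's layout
// and the scatter_add can be skipped.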
std::optional<array> dsb = std::nullopt;
for (auto arg : argnums) {
// gradient wrt x
if (arg == 0) {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(
gather_qmm(
cotan,
w,
scales,
biases,
std::nullopt,
rhs_indices,
!transpose_,
group_size_,
bits_,
sorted,
stream()),
-3,
stream()),
0,
stream()),
x.shape(),
stream()));
auto g = gather_qmm(
cotan,
w,
scales,
biases,
std::nullopt,
rhs_indices,
!transpose_,
group_size_,
bits_,
sorted,
stream());
if (sorted && no_broadcast) {
vjps.push_back(g);
} else {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(g, -3, stream()),
0,
stream()),
x.shape(),
stream()));
}
}
// gradient wrt the indices is undefined
@@ -3290,9 +3362,49 @@ std::vector<array> GatherQMM::vjp(
}
// gradient wrt w_q, scales, or biases
else {
else if (arg == 1) {
throw std::runtime_error(
"GatherQMM::vjp no gradient wrt the quantized matrix yet.");
"GatherQMM::vjp no gradient wrt the quantized weights.");
} else {
if (!dsb) {
auto shape = w.shape();
shape.pop_back();
shape.pop_back();
dsb = unflatten(
gather_mm_grad(
x,
cotan,
lhs_indices,
rhs_indices,
sorted,
std::move(shape),
stream()),
-1,
{-1, group_size_},
stream());
}
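// dsb holds the gradient wrt the dequantized weights, grouped into chunks of
// group_size_ along the last axis. The bias gradient is its sum over each
// group (arg == 3); the scales gradient additionally multiplies by the weights
// dequantized with unit scales and zero biases (the raw quantized values).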
if (arg == 3) {
vjps.push_back(sum(*dsb, -1, false, stream()));
} else {
vjps.push_back(
sum(multiply(
*dsb,
unflatten(
dequantize(
w,
ones_like(scales, stream()),
zeros_like(biases, stream()),
group_size_,
bits_,
stream()),
-1,
{-1, group_size_},
stream()),
stream()),
-1,
false,
stream()));
}
}
}
return vjps;
@@ -5064,6 +5176,8 @@ std::vector<array> GatherMM::vjp(
std::vector<array> vjps;
auto& cotan = cotangents[0];
auto& a = primals[0];
auto& b = primals[1];
auto& lhs_indices = primals[2];
auto& rhs_indices = primals[3];
@@ -5072,39 +5186,46 @@ std::vector<array> GatherMM::vjp(
int K = primals[0].shape(-1);
bool sorted = left_sorted_ || right_sorted_;
bool no_broadcast = rhs_indices.size() * M * K == primals[0].size();
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
auto base = zeros_like(primals[0], stream());
auto bt = swapaxes(primals[1], -1, -2, stream());
auto base_shape = base.shape();
base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K)
auto g =
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
auto g = gather_mm(
cotan,
swapaxes(b, -1, -2, stream()),
std::nullopt,
rhs_indices,
sorted,
stream());
if (sorted && no_broadcast) {
vjps.push_back(g);
} else {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(a, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(g, -3, stream()),
0,
stream()),
a.shape(),
stream()));
}
} else if (arg == 1) {
// (M X K).T * M X N -> K X N
auto base = zeros_like(primals[1], stream());
auto at = swapaxes(primals[0], -1, -2, stream());
auto base_shape = base.shape();
base = reshape(base, {-1, K, N}, stream());
// g : (out_batch_shape) + (K, N)
auto g =
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
auto shape = b.shape();
shape.pop_back();
shape.pop_back();
vjps.push_back(swapaxes(
gather_mm_grad(
a,
cotan,
lhs_indices,
rhs_indices,
sorted,
std::move(shape),
stream()),
-1,
-2,
stream()));
} else {
throw std::invalid_argument(
"[GatherMM] Cannot calculate VJP with respect to indices.");

View File

@@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive {
bool right_sorted_;
};
class SegmentedMM : public UnaryPrimitive {
public:
explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(SegmentedMM)
};
class BroadcastAxes : public UnaryPrimitive {
public:
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})

View File

@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc");
m.def(
"segmented_mm",
&mx::segmented_mm,
nb::arg(),
nb::arg(),
"segments"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform a matrix multiplication but segment the inner dimension and
save the result for each segment separately.
Args:
a (array): Input array of shape ``MxK``.
b (array): Input array of shape ``KxN``.
segments (array): The ``(start, end)`` offsets into the inner dimension for each segment, with shape ``(..., 2)``.
Returns:
array: The per-segment results, each of shape ``MxN``.
)pbdoc");
m.def(
"tensordot",
[](const mx::array& a,

View File

@@ -8,6 +8,9 @@ cuda_skip = {
# Gather matmul NYI
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
"TestBlas.test_gather_mm_sorted",
# Segmented matmul NYI
"TestBlas.test_segmented_mm",
# Scan NYI
"TestArray.test_api",
"TestAutograd.test_cumprod_grad",
@@ -76,6 +79,7 @@ cuda_skip = {
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
"TestQuantized.test_gather_qmm_grad",
"TestQuantized.test_non_multiples",
"TestQuantized.test_qmm",
"TestQuantized.test_qmm_jvp",

View File

@@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
def test_gather_mm_sorted(self):
def gather_mm_ref(a, b, rhs):
b = b[rhs]
return a @ b
def gather_mm_test(a, b, rhs):
return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True)
a = mx.random.normal((100, 1, 100))
b = mx.random.normal((8, 100, 100))
rhs = mx.sort(mx.random.randint(0, 8, shape=(100,)))
c1 = gather_mm_ref(a, b, rhs)
c2 = gather_mm_test(a, b, rhs)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
cotan = mx.random.normal(c1.shape)
c1, dc1 = mx.vjp(
lambda a, b: gather_mm_ref(a, b, rhs),
[a, b],
[cotan],
)
c2, dc2 = mx.vjp(
lambda a, b: gather_mm_test(a, b, rhs),
[a, b],
[cotan],
)
self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4))
self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4))
self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4))
def test_segmented_mm(self):
def segmented_mm_ref(a, b, s):
s = s.tolist()
c = []
for s1, s2 in s:
c.append(a[:, s1:s2] @ b[s1:s2, :])
return mx.stack(c, axis=0)
shapes = [
(10, 10, 10),
(10, 10, 1000),
(1000, 1000, 1000),
]
all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
for M, N, K in shapes:
for s in all_segments:
segments = []
for i in range(len(s) - 1):
segments.append([s[i], s[i + 1]])
segments = mx.array(segments)
segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
a = mx.random.normal((M, K))
b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a, b, segments)
c2 = mx.segmented_mm(a, b, segments)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((K, M))
b = mx.random.normal((K, N))
c1 = segmented_mm_ref(a.T, b, segments)
c2 = mx.segmented_mm(a.T, b, segments)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((M, K))
b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a, b.T, segments)
c2 = mx.segmented_mm(a, b.T, segments)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
a = mx.random.normal((K, M))
b = mx.random.normal((N, K))
c1 = segmented_mm_ref(a.T, b.T, segments)
c2 = mx.segmented_mm(a.T, b.T, segments)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
with self.assertRaises(ValueError):
a = mx.ones((2, 10, 10))
s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32)
mx.segmented_mm(a, a, s)
a = mx.ones((10, 1000))
s = mx.random.randint(0, 16, shape=(1000,))
s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
s = mx.sort(s)
s = mx.cumsum(s)
s = mx.concatenate([mx.array([0]), s])
s = mx.as_strided(s, (16, 2), (1, 1))
s = mx.reshape(s, (2, 2, 4, 2))
c = mx.segmented_mm(a, a.T, s)
self.assertEqual(c.shape, (2, 2, 4, 10, 10))
def test_gemv_gemm_same_precision(self):
mx.random.seed(0)
N = 256

View File

@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
def test_gather_qmm_grad(self):
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
if lhs is not None:
x = x[lhs]
if rhs is not None:
w = w[rhs]
s = s[rhs]
b = b[rhs]
return mx.quantized_matmul(x, w, s, b, transpose=trans)
def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
return mx.gather_qmm(
x,
w,
s,
b,
transpose=trans,
lhs_indices=lhs,
rhs_indices=rhs,
sorted_indices=sort,
)
x = mx.random.normal((16, 1, 256))
w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
cotan = mx.random.normal((16, 1, 256))
(o1,), (dx1, ds1, db1) = mx.vjp(
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
[x, s, b],
[cotan],
)
(o2,), (dx2, ds2, db2) = mx.vjp(
lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
[x, s, b],
[cotan],
)
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
def test_vjp_scales_biases(self):
mx.random.seed(0)
x = mx.random.normal(shape=(2, 2, 512))