From 4a9b29a8753ad65e2156bfe0d99d305fb48c4fcc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 7 Jul 2025 17:59:53 -0700 Subject: [PATCH] MoE backward improvements (#2335) --- mlx/backend/cpu/masked_mm.cpp | 170 +++++++++++ mlx/backend/cuda/primitives.cu | 1 + mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/indexing.cpp | 14 +- mlx/backend/metal/jit/includes.h | 1 + mlx/backend/metal/jit_kernels.cpp | 37 +++ mlx/backend/metal/kernels.h | 14 + mlx/backend/metal/kernels/CMakeLists.txt | 2 + .../steel/gemm/kernels/steel_gemm_segmented.h | 266 ++++++++++++++++++ .../gemm/kernels/steel_gemm_segmented.metal | 43 +++ mlx/backend/metal/matmul.cpp | 162 +++++++++++ mlx/backend/metal/nojit_kernels.cpp | 16 ++ mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_gpu/primitives.cpp | 1 + mlx/ops.cpp | 48 ++++ mlx/ops.h | 6 + mlx/primitives.cpp | 235 ++++++++++++---- mlx/primitives.h | 10 + python/src/ops.cpp | 22 ++ python/tests/cuda_skip.py | 4 + python/tests/test_blas.py | 93 ++++++ python/tests/test_quantized.py | 43 +++ 22 files changed, 1130 insertions(+), 60 deletions(-) create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 0be7c79ce..fbee6118f 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" @@ -52,6 +53,58 @@ inline void mask_matrix( } } +template +inline void segmented_mm( + const T* a, + const T* b, + const uint32_t* segments, + T* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides, + size_t num_segments, + const Shape& segments_shape, + const Strides& segments_strides) { + int ndim = a_shape.size(); + Shape a_copy = a_shape; + Shape b_copy = b_shape; + int32_t M = a_copy[ndim - 2]; + int32_t N = b_copy[ndim - 1]; + for (int i = 0; i < num_segments; i++) { + uint32_t k_start = + segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; + uint32_t k_end = + segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; + if (k_end <= k_start) { + std::fill_n(out + i * M * N, M * N, T(0)); + continue; + } + a_copy[ndim - 1] = k_end - k_start; + b_copy[ndim - 2] = k_end - k_start; + matmul( + a + k_start * a_strides[ndim - 1], + b + k_start * b_strides[ndim - 2], + out + i * M * N, + a_transposed, + b_transposed, + lda, + ldb, + N, + 1.0, + 0.0, + 1, + a_copy, + a_strides, + b_copy, + b_strides); + } +} + } // namespace void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { @@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { encoder.add_temporaries(std::move(temps)); } +void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& encoder = cpu::get_command_encoder(stream()); + auto check_transpose = [&s, &encoder](const array& x) { + auto stx = x.strides()[x.ndim() - 2]; + auto sty = x.strides()[x.ndim() - 1]; + if (stx == x.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, x); + } else if (stx == 1 && sty == x.shape(-2)) { + return 
std::make_tuple(true, sty, x); + } else { + array xc(x.shape(), x.dtype(), nullptr, {}); + copy(x, xc, CopyType::General, s); + encoder.add_temporary(xc); + int64_t stx = x.shape(-1); + return std::make_tuple(false, stx, xc); + } + }; + + auto [a_transposed, lda, a] = check_transpose(inputs[0]); + auto [b_transposed, ldb, b] = check_transpose(inputs[1]); + auto& segments = inputs[2]; + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(segments); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + segments = array::unsafe_weak_copy(segments), + out_ptr = out.data(), + a_transposed = a_transposed, + b_transposed = b_transposed, + lda = lda, + ldb = ldb]() { + switch (a.dtype()) { + case float64: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case float32: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case float16: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case bfloat16: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + default: + throw std::invalid_argument( + "Segmented mm supports only real float types."); + } + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 18fa45a33..a8496b958 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -83,6 +83,7 @@ NO_GPU_MULTI(LUF) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(Scan) +NO_GPU(SegmentedMM) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index d0c872451..ccdd83202 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -63,6 +63,7 @@ if(MLX_METAL_JIT) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) + make_jit_source(steel/gemm/kernels/steel_gemm_segmented) make_jit_source( steel/conv/conv kernels/steel/utils.h diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a601b1e..13ce88a62 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -575,9 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); - compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); - compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + if (ndim > 1) { + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 
3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + } else { + // The following will be ignored in the kernel but we still have to set + // some value so that metal validation passes. + compute_encoder.set_vector_bytes(idx.shape(), 3); + compute_encoder.set_vector_bytes(upd.strides(), 4); + compute_encoder.set_vector_bytes(idx.strides(), 5); + } compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8); diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 27ae22d05..b380a8374 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -34,6 +34,7 @@ const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); const char* steel_gemm_gather(); +const char* steel_gemm_segmented(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 467380c3a..fd0e0db09 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_segmented(), + get_template_definition( + lib_name, + "segmented_mm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 1de5fa47c..794c67bdc 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int wn, bool rhs); +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 3ee88ca46..4069d8c21 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -71,6 +71,7 @@ set(STEEL_HEADERS steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h + steel/gemm/kernels/steel_gemm_segmented.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h steel/utils/integral_constant.h) @@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) 
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) endif() diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h new file mode 100644 index 000000000..b915eb343 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h @@ -0,0 +1,266 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Move the pointers to the output tile + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Move the pointers to the start of the segment + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + // We accept either contiguous (above) or weird strides where the beginning + // of the next one is the previous one. Basically the last two strides are + // both 1! + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? 
k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Matrix level alignment so only check K + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + // Tile aligned do the same as above + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal new file mode 100644 index 000000000..a7515c359 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal @@ -0,0 +1,43 @@ +// Copyright © 2024 Apple Inc. 
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h" + +#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + segmented_mm, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) +// clang-format on + +instantiate_segmented_mm_shapes_helper(float16, half, float16, half); +instantiate_segmented_mm_shapes_helper( + bfloat16, + bfloat16_t, + bfloat16, + bfloat16_t); +instantiate_segmented_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index be7f3e2f8..55b8be3a9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1864,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } +void segmented_mm( + const array& a_, + const array& b_, + const array& segments_, + array& out, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + auto check_segments_layout = [&d, &s](const array& x) { + // Contiguous so return early + if (x.flags().row_contiguous) { + return std::make_tuple(true, x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 2; i++) { + rc &= + (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1); + } + rc &= x.strides(x.ndim() - 1) == 1; + if (x.ndim() > 1) { + rc &= x.strides(x.ndim() - 2) == 1; + } + + if (rc) { + return std::make_tuple(false, x); + } + + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return std::make_tuple(true, x_copy); + }; + + // Copy if needed + std::vector copies; + auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); + auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); + auto [segments_contiguous, segments] = check_segments_layout(segments_); + d.add_temporaries(std::move(copies), s.index); + + // Determine dispatch kernel + int bm = 64, bn = 64, bk = 16; + int wm = 2, wn = 2; + size_t batch_size_out = out.size() / M / N; + + char devc = d.get_architecture().back(); + GEMM_TPARAM_MACRO(devc) + + const bool align_M = (M % bm) == 0; + const bool align_N = (N 
% bn) == 0; + + // Define the kernel name + std::string base_name; + base_name.reserve(128); + concatenate( + base_name, + "steel_segmented_mm_", + transpose_a ? 't' : 'n', + transpose_b ? 't' : 'n', + "_", + type_to_name(a), + "_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + + metal::MTLFCList func_consts = { + {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_segments_contiguous_", + segments_contiguous ? 't' : 'n', + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_segmented_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); + compute_encoder.set_compute_pipeline_state(kernel); + + // Prepare the matmul params + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ static_cast(lda), + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ 0, + /* const int64_t batch_stride_d = */ M * N, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ 0, + /* const int batch_ndim = */ 0}; + + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = + MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(segments, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void SegmentedMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& segments = inputs[2]; + + out.set_data(allocator::malloc(out.nbytes())); + + // Extract shapes from inputs. 
+ int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + segmented_mm(a, b, segments, out, M, N, K, d, s); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index b0375e37f..32d3e75f7 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 1a180bfe0..09e6c4ef3 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -105,6 +105,7 @@ NO_CPU(Scan) NO_CPU(Scatter) NO_CPU(ScatterAxis) NO_CPU(Select) +NO_CPU(SegmentedMM) NO_CPU(Sigmoid) NO_CPU(Sign) NO_CPU(Sin) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 409aa2c89..dfe5b57f1 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -121,6 +121,7 @@ NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) +NO_GPU(SegmentedMM) NO_GPU(Sigmoid) NO_GPU(Sign) NO_GPU(Sin) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2b861428f..7161a39b2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4649,6 +4649,54 @@ array gather_mm( return axes.empty() ? out : squeeze(out, axes, s); } +array segmented_mm( + array a, + array b, + array segments, + StreamOrDevice s /* = {} */) { + if (a.ndim() != 2 || b.ndim() != 2) { + throw std::invalid_argument("[segmented_mm] Batched matmul not supported"); + } + + if (segments.ndim() < 1 || segments.shape().back() != 2) { + std::ostringstream msg; + msg << "[segmented_mm] The segments should have shape (..., 2) but " + << segments.shape() << " was provided."; + throw std::invalid_argument(msg.str()); + } + + // Type promotion + auto out_type = result_type(a, b); + if (!issubdtype(out_type, floating)) { + std::ostringstream msg; + msg << "[segmented_mm] Only real floating point types are supported but " + << a.dtype() << " and " << b.dtype() + << " were provided which results in " << out_type + << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } + + if (!issubdtype(segments.dtype(), integer)) { + throw std::invalid_argument( + "[segmented_mm] Got segments with invalid dtype. 
Segments must be integral."); + } + + a = astype(a, out_type, s); + b = astype(b, out_type, s); + segments = astype(segments, uint32, s); + + Shape out_shape = segments.shape(); + out_shape.pop_back(); + out_shape.push_back(a.shape(0)); + out_shape.push_back(b.shape(1)); + + return array( + std::move(out_shape), + out_type, + std::make_shared(to_stream(s)), + {std::move(a), std::move(b), std::move(segments)}); +} + array diagonal( const array& a, int offset /* = 0 */, diff --git a/mlx/ops.h b/mlx/ops.h index af3cdb5bd..596d6d287 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1406,6 +1406,12 @@ array gather_mm( bool sorted_indices = false, StreamOrDevice s = {}); +/** + * Compute a matrix product but segment the inner dimension and write the + * result separately for each segment. + */ +array segmented_mm(array a, array b, array segments, StreamOrDevice s = {}); + /** Extract a diagonal or construct a diagonal array */ array diagonal( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 5f2bfdda4..b2b7306dd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -109,6 +109,70 @@ std::tuple vmap_ternary_op( return {a, b, c, to_ax}; } +// Calculate the gradient wrt to the weights of the following calculation +// +// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted) +// +// Note the transpose above. This function returns the gradient for w.T so if w +// was used instead then one needs to transpose the returned gradient. +// +// We define it as a separate function to reuse it for gather_mm and +// gather_qmm. +array gather_mm_grad( + const array& x, + const array& dy, + const array& lhs_indices, + const array& rhs_indices, + bool sorted, + Shape batch_shape, + const Stream& s) { + int M = x.shape(-2); + int K = x.shape(-1); + int N = dy.shape(-1); + int num_segments = std::accumulate( + batch_shape.begin(), batch_shape.end(), 1, std::multiplies()); + batch_shape.push_back(N); + batch_shape.push_back(K); + + // If the indices are sorted then it means that we can do the whole gradient + // computation via a segmented matmul. We just need to calculate the segments + // using the indices. + if (sorted) { + auto segments = zeros({num_segments}, uint32, s); + segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s); + segments = cumsum(segments, 0, false, true, s); + segments = concatenate({array({0}, {1}, uint32), segments}, 0, s); + segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s); + + return reshape( + segmented_mm( + swapaxes(flatten(dy, 0, -2, s), 0, 1, s), + flatten(x, 0, -2, s), + segments, + s), + std::move(batch_shape), + s); + } + + // Otherwise we need to gather matmul the dy and then scatter add it to the + // correct locations. + else { + // TODO: If the lhs indices wasn't provided, this is always a sorted matmul + // so we should add that check. 
+ auto dw = gather_mm( + swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s); + return reshape( + scatter_add( + zeros({num_segments, N, K}, dw.dtype(), s), + rhs_indices, + expand_dims(dw, -3, s), + 0, + s), + std::move(batch_shape), + s); + } +} + } // namespace std::vector Primitive::jvp( @@ -3169,8 +3233,9 @@ std::vector QuantizedMatmul::vjp( "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); } else { if (!dsb) { - auto fc = flatten(cotangents[0], 0, -2, stream()); - auto fx = flatten(primals[0], 0, -2, stream()); + int ndim = primals[1].ndim(); + auto fc = flatten(cotangents[0], 0, -ndim, stream()); + auto fx = flatten(primals[0], 0, -ndim, stream()); auto dw = transpose_ ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream()) : matmul(swapaxes(fx, -1, -2, stream()), fc, stream()); @@ -3181,7 +3246,6 @@ std::vector QuantizedMatmul::vjp( vjps.push_back(sum(*dsb, -1, false, stream())); } else { // scales - auto s = stream(); auto wq = dequantize( primals[1], ones_like(primals[2], stream()), @@ -3253,34 +3317,42 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + int M = cotan.shape(-2); + int N = cotan.shape(-1); + int K = x.shape(-1); + bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == x.size(); + std::optional dsb = std::nullopt; for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { - vjps.push_back(reshape( - scatter_add( - flatten(zeros_like(x, stream()), 0, -3, stream()), - lhs_indices, - expand_dims( - gather_qmm( - cotan, - w, - scales, - biases, - std::nullopt, - rhs_indices, - !transpose_, - group_size_, - bits_, - sorted, - stream()), - -3, - stream()), - 0, - stream()), - x.shape(), - stream())); + auto g = gather_qmm( + cotan, + w, + scales, + biases, + std::nullopt, + rhs_indices, + !transpose_, + group_size_, + bits_, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(x, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + x.shape(), + stream())); + } } // gradient wrt to the indices is undefined @@ -3290,9 +3362,49 @@ std::vector GatherQMM::vjp( } // gradient wrt to w_q, scales or biases - else { + else if (arg == 1) { throw std::runtime_error( - "GatherQMM::vjp no gradient wrt the quantized matrix yet."); + "GatherQMM::vjp no gradient wrt the quantized weights."); + } else { + if (!dsb) { + auto shape = w.shape(); + shape.pop_back(); + shape.pop_back(); + dsb = unflatten( + gather_mm_grad( + x, + cotan, + lhs_indices, + rhs_indices, + sorted, + std::move(shape), + stream()), + -1, + {-1, group_size_}, + stream()); + } + if (arg == 3) { + vjps.push_back(sum(*dsb, -1, false, stream())); + } else { + vjps.push_back( + sum(multiply( + *dsb, + unflatten( + dequantize( + w, + ones_like(scales, stream()), + zeros_like(biases, stream()), + group_size_, + bits_, + stream()), + -1, + {-1, group_size_}, + stream()), + stream()), + -1, + false, + stream())); + } } } return vjps; @@ -5064,6 +5176,8 @@ std::vector GatherMM::vjp( std::vector vjps; auto& cotan = cotangents[0]; + auto& a = primals[0]; + auto& b = primals[1]; auto& lhs_indices = primals[2]; auto& rhs_indices = primals[3]; @@ -5072,39 +5186,46 @@ std::vector GatherMM::vjp( int K = primals[0].shape(-1); bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == primals[0].size(); for (auto arg : argnums) { if (arg == 0) { 
- // M X N * (K X N).T -> M X K - auto base = zeros_like(primals[0], stream()); - auto bt = swapaxes(primals[1], -1, -2, stream()); - - auto base_shape = base.shape(); - base = reshape(base, {-1, M, K}, stream()); - - // g : (out_batch_shape) + (M, K) - auto g = - gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); - + auto g = gather_mm( + cotan, + swapaxes(b, -1, -2, stream()), + std::nullopt, + rhs_indices, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(a, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + a.shape(), + stream())); + } } else if (arg == 1) { - // (M X K).T * M X N -> K X N - auto base = zeros_like(primals[1], stream()); - auto at = swapaxes(primals[0], -1, -2, stream()); - - auto base_shape = base.shape(); - base = reshape(base, {-1, K, N}, stream()); - - // g : (out_batch_shape) + (K, N) - auto g = - gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); + auto shape = b.shape(); + shape.pop_back(); + shape.pop_back(); + vjps.push_back(swapaxes( + gather_mm_grad( + a, + cotan, + lhs_indices, + rhs_indices, + sorted, + std::move(shape), + stream()), + -1, + -2, + stream())); } else { throw std::invalid_argument( "[GatherMM] Cannot calculate VJP with respect to indices."); diff --git a/mlx/primitives.h b/mlx/primitives.h index 4b18430ca..f4f157298 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive { bool right_sorted_; }; +class SegmentedMM : public UnaryPrimitive { + public: + explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(SegmentedMM) +}; + class BroadcastAxes : public UnaryPrimitive { public: explicit BroadcastAxes(Stream stream, std::vector ignore_axes = {}) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a1e77d681..d047f64cb 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) { array: The result of the multiplication of ``x`` with ``w`` after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); + m.def( + "segmented_mm", + &mx::segmented_mm, + nb::arg(), + nb::arg(), + "segments"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Perform a matrix multiplication but segment the inner dimension and + save the result for each segment separately. + + Args: + a (array): Input array of shape ``MxK``. + b (array): Input array of shape ``KxN``. + segments (array): The offsets into the inner dimension for each segment. + + Returns: + array: The result per segment of shape ``MxN``. 
+ )pbdoc"); m.def( "tensordot", [](const mx::array& a, diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index fce92bacb..17eb80eee 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -8,6 +8,9 @@ cuda_skip = { # Gather matmul NYI "TestBlas.test_gather_matmul", "TestBlas.test_gather_matmul_grad", + "TestBlas.test_gather_mm_sorted", + # Segmented matmul NYI + "TestBlas.test_segmented_mm", # Scan NYI "TestArray.test_api", "TestAutograd.test_cumprod_grad", @@ -76,6 +79,7 @@ cuda_skip = { "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", + "TestQuantized.test_gather_qmm_grad", "TestQuantized.test_non_multiples", "TestQuantized.test_qmm", "TestQuantized.test_qmm_jvp", diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index eb45df124..5e096d9c5 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(r.shape, t.shape) self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def test_gather_mm_sorted(self): + def gather_mm_ref(a, b, rhs): + b = b[rhs] + return a @ b + + def gather_mm_test(a, b, rhs): + return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True) + + a = mx.random.normal((100, 1, 100)) + b = mx.random.normal((8, 100, 100)) + rhs = mx.sort(mx.random.randint(0, 8, shape=(100,))) + + c1 = gather_mm_ref(a, b, rhs) + c2 = gather_mm_test(a, b, rhs) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + cotan = mx.random.normal(c1.shape) + c1, dc1 = mx.vjp( + lambda a, b: gather_mm_ref(a, b, rhs), + [a, b], + [cotan], + ) + c2, dc2 = mx.vjp( + lambda a, b: gather_mm_test(a, b, rhs), + [a, b], + [cotan], + ) + self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4)) + + def test_segmented_mm(self): + def segmented_mm_ref(a, b, s): + s = s.tolist() + c = [] + for s1, s2 in s: + c.append(a[:, s1:s2] @ b[s1:s2, :]) + return mx.stack(c, axis=0) + + shapes = [ + (10, 10, 10), + (10, 10, 1000), + (1000, 1000, 1000), + ] + all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] + + for M, N, K in shapes: + for s in all_segments: + segments = [] + for i in range(len(s) - 1): + segments.append([s[i], s[i + 1]]) + segments = mx.array(segments) + segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32)) + a = mx.random.normal((M, K)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a, b, segments) + c2 = mx.segmented_mm(a, b, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a.T, b, segments) + c2 = mx.segmented_mm(a.T, b, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((M, K)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a, b.T, segments) + c2 = mx.segmented_mm(a, b.T, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a.T, b.T, segments) + c2 = mx.segmented_mm(a.T, b.T, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + with self.assertRaises(ValueError): + a = mx.ones((2, 10, 10)) + s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32) + mx.segmented_mm(a, a, s) + + a = mx.ones((10, 1000)) + s = mx.random.randint(0, 16, shape=(1000,)) + s = mx.zeros(16, dtype=s.dtype).at[s].add(1) + s = mx.sort(s) + s = 
mx.cumsum(s) + s = mx.concatenate([mx.array([0]), s]) + s = mx.as_strided(s, (16, 2), (1, 1)) + s = mx.reshape(s, (2, 2, 4, 2)) + c = mx.segmented_mm(a, a.T, s) + self.assertEqual(c.shape, (2, 2, 4, 10, 10)) + def test_gemv_gemm_same_precision(self): mx.random.seed(0) N = 256 diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index f402bd444..2c62c6307 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + def test_gather_qmm_grad(self): + def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort): + if lhs is not None: + x = x[lhs] + if rhs is not None: + w = w[rhs] + s = s[rhs] + b = b[rhs] + return mx.quantized_matmul(x, w, s, b, transpose=trans) + + def gather_qmm(x, w, s, b, lhs, rhs, trans, sort): + return mx.gather_qmm( + x, + w, + s, + b, + transpose=trans, + lhs_indices=lhs, + rhs_indices=rhs, + sorted_indices=sort, + ) + + x = mx.random.normal((16, 1, 256)) + w, s, b = mx.quantize(mx.random.normal((4, 256, 256))) + indices = mx.sort(mx.random.randint(0, 4, shape=(16,))) + cotan = mx.random.normal((16, 1, 256)) + + (o1,), (dx1, ds1, db1) = mx.vjp( + lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + (o2,), (dx2, ds2, db2) = mx.vjp( + lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + + self.assertTrue(mx.allclose(o1, o2, atol=1e-4)) + self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4)) + self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3)) + self.assertTrue(mx.allclose(db1, db2, atol=1e-3)) + def test_vjp_scales_biases(self): mx.random.seed(0) x = mx.random.normal(shape=(2, 2, 512))
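
The tests above already exercise the new op end to end. For readers skimming the patch, here is a minimal usage sketch (not part of the diff) showing how sorted expert assignments can be turned into (start, end) segments for mx.segmented_mm, in the same way gather_mm_grad in this patch builds them for the MoE weight gradient. The names num_experts, expert_ids, counts and offsets are illustrative, and the snippet assumes an MLX build that contains this change.

    # Hypothetical sketch mirroring gather_mm_grad and test_segmented_mm above.
    import mlx.core as mx

    num_experts, M, N, K = 4, 8, 16, 64

    # Token-to-expert assignments, sorted as gather_mm(..., sorted_indices=True) expects.
    expert_ids = mx.sort(mx.random.randint(0, num_experts, shape=(K,)))

    # Per-expert token counts -> cumulative offsets -> (start, end) pairs per expert.
    counts = mx.zeros(num_experts, dtype=expert_ids.dtype).at[expert_ids].add(1)
    offsets = mx.concatenate([mx.array([0]), mx.cumsum(counts)])
    segments = mx.as_strided(offsets, (num_experts, 2), (1, 1))

    a = mx.random.normal((M, K))  # e.g. dy transposed so tokens lie on the inner dimension
    b = mx.random.normal((K, N))  # e.g. x flattened with the same token ordering
    out = mx.segmented_mm(a, b, segments)
    print(out.shape)  # (num_experts, M, N): one M x N product per expert's token slice

Each output slice multiplies only the columns of a and rows of b that fall inside one expert's contiguous token range, which is exactly how the sorted branch of gather_mm_grad assembles the per-expert weight gradient without a scatter_add.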