From 99eefd2ec00f0222c1bcded623c52e2cf1fc0b6a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 14 Apr 2025 16:37:36 -0700 Subject: [PATCH] Gather mm new kernel and small refactoring (#2040) --- benchmarks/python/gather_mm_bench.py | 74 ++ mlx/backend/common/CMakeLists.txt | 3 +- mlx/backend/common/broadcasting.cpp | 24 + mlx/backend/common/broadcasting.h | 11 + mlx/backend/common/common.cpp | 18 +- mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/jit/includes.h | 1 + mlx/backend/metal/jit_kernels.cpp | 38 + mlx/backend/metal/kernels.h | 15 + mlx/backend/metal/kernels/CMakeLists.txt | 2 + .../steel/gemm/kernels/steel_gemm_fused.h | 96 +-- .../steel/gemm/kernels/steel_gemm_gather.h | 459 ++++++++++++ .../gemm/kernels/steel_gemm_gather.metal | 59 ++ mlx/backend/metal/kernels/steel/gemm/mma.h | 81 +++ mlx/backend/metal/matmul.cpp | 665 +++++++++++------- mlx/backend/metal/nojit_kernels.cpp | 17 + mlx/backend/metal/utils.h | 19 +- mlx/ops.cpp | 13 +- mlx/ops.h | 1 + mlx/primitives.cpp | 14 +- mlx/primitives.h | 17 +- python/src/ops.cpp | 8 +- python/tests/test_blas.py | 2 +- 23 files changed, 1260 insertions(+), 378 deletions(-) create mode 100644 benchmarks/python/gather_mm_bench.py create mode 100644 mlx/backend/common/broadcasting.cpp create mode 100644 mlx/backend/common/broadcasting.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal diff --git a/benchmarks/python/gather_mm_bench.py b/benchmarks/python/gather_mm_bench.py new file mode 100644 index 000000000..85ddb08a6 --- /dev/null +++ b/benchmarks/python/gather_mm_bench.py @@ -0,0 +1,74 @@ +# Copyright © 2023-2024 Apple Inc. + +import mlx.core as mx +from time_utils import time_fn + +N = 1024 +D = 1024 +M = 1024 +E = 32 +I = 4 + + +def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + +def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + +def gather_mm_simulate(x, w, indices): + x, idx, inv_order = gather_sort(x, indices) + for i in range(2): + y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0) + x = y[:, None] + x = scatter_unsort(x, inv_order, indices.shape) + return x + + +def time_gather_mm(): + x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 + w1 = mx.random.normal((E, M, D)) / 1024**0.5 + w2 = mx.random.normal((E, D, M)) / 1024**0.5 + indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) + sorted_indices = mx.sort(indices.flatten()).reshape(N, I) + mx.eval(x, w1, w2, indices, sorted_indices) + + def gather_mm(x, w1, w2, indices, sort): + idx = indices + inv_order = None + if sort: + x, idx, inv_order = gather_sort(x, indices) + x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort) + x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort) + if sort: + x = scatter_unsort(x, inv_order, indices.shape) + return x + + time_fn(gather_mm, x, w1, w2, indices, False) + time_fn(gather_mm, x, w1, w2, sorted_indices, False) + time_fn(gather_mm, x, w1, w2, indices, True) + + x = mx.random.normal((N * I, D)) / 1024**0.5 + w1 = mx.random.normal((M, D)) / 1024**0.5 + w2 = mx.random.normal((D, M)) / 1024**0.5 + mx.eval(x, w1, w2) + + def equivalent_matmul(x, w1, w2): + x = x @ w1.T + 
x = x @ w2.T + return x + + time_fn(equivalent_matmul, x, w1, w2) + + +if __name__ == "__main__": + time_gather_mm() diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 82e6eef84..6c4e25067 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,6 +1,7 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp diff --git a/mlx/backend/common/broadcasting.cpp b/mlx/backend/common/broadcasting.cpp new file mode 100644 index 000000000..49bc75b8f --- /dev/null +++ b/mlx/backend/common/broadcasting.cpp @@ -0,0 +1,24 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +void broadcast(const array& in, array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + Strides strides(out.ndim(), 0); + int diff = out.ndim() - in.ndim(); + for (int i = in.ndim() - 1; i >= 0; --i) { + strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i]; + } + auto flags = in.flags(); + if (out.size() > in.size()) { + flags.row_contiguous = flags.col_contiguous = false; + } + out.copy_shared_buffer(in, strides, flags, in.data_size()); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/broadcasting.h b/mlx/backend/common/broadcasting.h new file mode 100644 index 000000000..29651e909 --- /dev/null +++ b/mlx/backend/common/broadcasting.h @@ -0,0 +1,11 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void broadcast(const array& in, array& out); + +} // namespace mlx::core diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 57813e062..2cda88a31 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" @@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector& inputs, array& out) { return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); } -void broadcast(const array& in, array& out) { - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - Strides strides(out.ndim(), 0); - int diff = out.ndim() - in.ndim(); - for (int i = in.ndim() - 1; i >= 0; --i) { - strides[i + diff] = (in.shape()[i] == 1) ? 
0 : in.strides()[i]; - } - auto flags = in.flags(); - if (out.size() > in.size()) { - flags.row_contiguous = flags.col_contiguous = false; - } - out.copy_shared_buffer(in, strides, flags, in.data_size()); -} - void Broadcast::eval(const std::vector& inputs, array& out) { broadcast(inputs[0], out); } diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 7985396c4..332c560f8 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -61,6 +61,7 @@ if(MLX_METAL_JIT) kernels/steel/gemm/transforms.h) make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) + make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source( steel/conv/conv diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 921ce50ce..27ae22d05 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -33,6 +33,7 @@ const char* gemm(); const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); +const char* steel_gemm_gather(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 204bb14e7..c0a698a86 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -584,6 +584,44 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_gather(), + get_template_definition( + lib_name, + rhs ? 
"gather_mm_rhs" : "gather_mm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 1638a4496..ba5914140 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -160,6 +160,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( bool mn_aligned, bool k_aligned); +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn, + bool rhs); + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 309a840f8..3ee88ca46 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -69,6 +69,7 @@ set(STEEL_HEADERS steel/gemm/loader.h steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h + steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h @@ -116,6 +117,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index bcc585bbe..add495d93 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -15,10 +15,6 @@ constant bool align_M [[function_constant(200)]]; constant bool align_N [[function_constant(201)]]; constant bool align_K [[function_constant(202)]]; -constant bool do_gather [[function_constant(300)]]; - -constant bool gather_bias = do_gather && use_out_source; - // clang-format off template < typename T, @@ -39,12 +35,6 @@ template < const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant int* batch_shape [[buffer(6)]], const constant int64_t* batch_strides [[buffer(7)]], - const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], - const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], - const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], - const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]], - const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -81,84 +71,26 @@ 
template < } // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; - // Handle gather - if (do_gather) { - // Read indices - uint32_t indx_A, indx_B, indx_C; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - if (has_batch) { - const constant auto* indx_A_bstrides = batch_strides; - const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim; - - ulong2 indx_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - indx_A_bstrides, - indx_B_bstrides, - params->batch_ndim); - indx_A = lhs_indices[indx_offsets.x]; - indx_B = rhs_indices[indx_offsets.y]; - - if (use_out_source) { - const constant auto* indx_C_bstrides = - indx_B_bstrides + params->batch_ndim; - auto indx_offset_C = elem_to_loc( - tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); - indx_C = C_indices[indx_offset_C]; - } - } else { - indx_A = lhs_indices[params->batch_stride_a * tid.z]; - indx_B = rhs_indices[params->batch_stride_b * tid.z]; - - if (use_out_source) { - indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; - } - } - - // Translate indices to offsets - int batch_ndim_A = operand_batch_ndim.x; - const constant int* batch_shape_A = operand_shape; - const constant auto* batch_strides_A = operand_strides; - A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); - - int batch_ndim_B = operand_batch_ndim.y; - const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; - const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A; - B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + A += batch_offsets.x; + B += batch_offsets.y; if (use_out_source) { - int batch_ndim_C = operand_batch_ndim.z; - const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; - const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B; - C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; - } - - // Handle regular batch - else { - if (has_batch) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - if (use_out_source) { - const constant auto* C_bstrides = B_bstrides + params->batch_ndim; - C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); - } - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - - if (use_out_source) { - C += addmm_params->batch_stride_c * tid.z; - } + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h new file mode 100644 index 000000000..4493375c1 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h @@ -0,0 +1,459 @@ +// Copyright © 2024 Apple Inc. 
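+//
+// Steel GEMM kernels whose operands are gathered through index arrays.
+// gather_mm_rhs assumes rhs_indices are sorted, so rows of A that share an
+// index sit next to each other and can reuse the same B matrix within a
+// tile; gather_mm handles the general batched case where both lhs_indices
+// and rhs_indices may be provided and unsorted.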
+ +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[c_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (rhs_indices[c_row + n] != index) { + offset_next = n; + index_next = rhs_indices[c_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b( + B + index * params->batch_stride_b, + params->ldb, + Bs, + simd_group_id, + simd_lane_id); + + // Prepare iterations + const int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* lhs_indices [[buffer(2)]], + const device uint32_t* rhs_indices [[buffer(3)]], + device T* C [[buffer(4)]], + const constant GEMMParams* params [[buffer(5)]], + const constant int* indices_shape [[buffer(6)]], + const constant int64_t* lhs_strides [[buffer(7)]], + const constant int64_t* rhs_strides [[buffer(8)]], + const constant int& batch_ndim_a [[buffer(9)]], + const constant int* batch_shape_a [[buffer(10)]], + const constant int64_t* batch_strides_a [[buffer(11)]], + const constant int& batch_ndim_b [[buffer(12)]], + const constant int* batch_shape_b [[buffer(13)]], + const constant int64_t* batch_strides_b [[buffer(14)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, 
+ BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Move A and B to the locations pointed by lhs_indices and rhs_indices. + uint32_t indx_A, indx_B; + if (has_batch) { + ulong2 indices_offsets = elem_to_loc_broadcast( + tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); + indx_A = lhs_indices[indices_offsets.x]; + indx_B = rhs_indices[indices_offsets.y]; + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + } + A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); + B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); + C += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Just make sure everybody's finished with the indexing math above. + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + mma_op.store_result(C, params->ldd); + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal new file mode 100644 index 000000000..f8e5a2a37 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. 
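+//
+// Explicit instantiations of the gather GEMM kernels for half, bfloat16_t
+// and float over a handful of tile shapes. The rhs variant is only built
+// for a non-transposed A (nn/nt), matching the row-contiguous A prepared by
+// the host-side dispatch in matmul.cpp.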
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h" + +#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_rhs_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm_rhs, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_gather_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) +// clang-format on + +instantiate_gather_mm_shapes_helper(float16, half, float16, half); +instantiate_gather_mm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); +instantiate_gather_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index aea235abb..64b87655e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -142,6 +142,42 @@ struct BaseMMAFrag { } } + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_slice( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < stop_x && (off_x + i) >= start_x && + (off_y + j) < stop_y && (off_y + j) >= start_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * 
kElemCols + j]); + } + } + } + } + METAL_FUNC static constexpr void mma( thread frag_type& D, thread frag_type& A, @@ -335,6 +371,31 @@ struct MMATile { } } } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_slice( + frag_at(i, j), + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } }; template @@ -474,6 +535,26 @@ struct BlockMMA { Ctile.template store(D, ldd); } + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= short2(sn, sm); + + // TODO: Check the start as well + if (stop.y <= 0 || stop.x <= 0) { + return; + } + + Ctile.template store_slice(D, ldd, start, stop); + } + METAL_FUNC void store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { // Apply epilogue diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3f736505f..27369ad07 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -102,6 +103,47 @@ std::tuple check_transpose( } }; +inline array +ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } else { + return x; + } +} + +inline std::tuple +ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (x.flags().row_contiguous) { + return std::make_tuple(false, x.strides()[x.ndim() - 2], x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 3; i++) { + rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i]; + } + if (rc) { + auto stx = x.strides()[x.ndim() - 2]; + auto sty = x.strides()[x.ndim() - 1]; + auto K = x.shape(-2); + auto N = x.shape(-1); + if (sty == 1 && (N != 1 || stx == N)) { + return std::make_tuple(false, stx, x); + } + if (stx == 1 && (N != 1 || sty == K)) { + return std::make_tuple(true, sty, x); + } + } + + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); +} + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -230,7 +272,6 @@ void steel_matmul_regular( const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = false; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, @@ -239,7 +280,6 @@ void steel_matmul_regular( {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; // clang-format off @@ -248,8 +288,7 @@ void steel_matmul_regular( << "_do_axpby_" << (do_axpby ? 
't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -975,7 +1014,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = false; metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, @@ -984,7 +1022,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; // clang-format off @@ -993,8 +1030,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { << "_do_axpby_" << (do_axpby ? 't' : 'n') << "_align_M_" << (align_M ? 't' : 'n') << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -1464,267 +1500,337 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } -void GatherMM::eval_gpu(const std::vector& inputs, array& out) { - using namespace mlx::steel; - // assert(inputs.size() == 2); - if (!issubdtype(out.dtype(), floating)) { - throw std::runtime_error( - "[GatherMM] Does not yet support non-floating point types."); - } - auto& s = stream(); - auto& d = metal::device(s.device); +void gather_mm_rhs( + const array& a_, + const array& b_, + const array& indices_, + array& out, + metal::Device& d, + const Stream& s) { + array indices = ensure_row_contiguous(indices_, d, s); + auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); - auto& a_pre = inputs[0]; - auto& b_pre = inputs[1]; - // Return 0s if either input is empty - if (a_pre.size() == 0 || b_pre.size() == 0) { - array zero = array(0, a_pre.dtype()); - fill_gpu(zero, out, s); - d.add_temporary(std::move(zero), s.index); - return; - } + // Broadcast a with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of a broadcasted + // with rhs_indices. We need only broadcast a and copy it as if applying the + // lhs_indices. 
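+  // For example, an a of shape (N, 1, 1, K) used with rhs_indices of shape
+  // (N, I) is materialized as a row-contiguous (N, I, 1, K) array, so every
+  // output row gets its own copy of the corresponding row of a.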
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } - out.set_data(allocator::malloc(out.nbytes())); + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + array a = broadcast_with_indices(a_); - ///////////////////////////////////////////////////////////////////////////// - // Init checks and prep + // Extract the matmul shapes + int K = a.shape(-1); + int M = a.size() / K; + int N = b.shape(-1); + int lda = a.strides()[a.ndim() - 2]; // should be K - int M = a_pre.shape(-2); - int N = b_pre.shape(-1); - int K = a_pre.shape(-1); + // Define the dispatch blocks + int bm = 16, bn = 64, bk = 16; + int wm = 1, wn = 2; - // Keep a vector with copies to be cleared in the completed buffer to release - // the arrays - std::vector copies; - auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); - auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; - int lda = a_cols; - int ldb = b_cols; + // Define the kernel name + std::string base_name; + base_name.reserve(64); + concatenate( + base_name, + "steel_gather_mm_rhs_n", + transpose_b ? 't' : 'n', + '_', + type_to_name(a), + '_', + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - - auto get_batch_dims = [](const auto& v) { - return decltype(v){v.begin(), v.end() - 2}; + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, }; - auto& lhs_indices = inputs[2]; - auto& rhs_indices = inputs[3]; + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); - Shape batch_shape = get_batch_dims(out.shape()); - Strides batch_strides; + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_gather_kernel( + d, + base_name, + hash_name, + func_consts, + out, + false, + transpose_b, + bm, + bn, + bk, + wm, + wn, + true); + compute_encoder.set_compute_pipeline_state(kernel); - batch_strides.insert( - batch_strides.end(), - lhs_indices.strides().begin(), - lhs_indices.strides().end()); - auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + // Prepare the matmul params + auto batch_stride_b = b.ndim() > 2 ? 
b.strides()[b.ndim() - 3] : b.size(); + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), + /* const int64_t batch_stride_d = */ 0, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ 0}; - batch_strides.insert( - batch_strides.end(), - rhs_indices.strides().begin(), - rhs_indices.strides().end()); - auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); - int batch_ndim = batch_shape.size(); + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(indices, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); - if (batch_ndim == 0) { - batch_shape = {1}; - batch_strides = {0}; - } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} - int batch_ndim_A = a.ndim() - 2; - int batch_ndim_B = b.ndim() - 2; - std::vector operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; +void gather_mv( + const array& mat_, + const array& vec_, + const array& mat_indices_, + const array& vec_indices_, + array& out, + int N, + int K, + bool is_mv, + metal::Device& d, + const Stream& s) { + // Copy if needed + std::vector copies; + auto [transpose_mat, mat_cols, mat] = + check_transpose(copies, s, mat_, N == 1); + auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true); + d.add_temporaries(std::move(copies), s.index); - Shape batch_shape_A = get_batch_dims(a.shape()); - Strides batch_strides_A = get_batch_dims(a.strides()); - Shape batch_shape_B = get_batch_dims(b.shape()); - Strides batch_strides_B = get_batch_dims(b.strides()); + // If we are doing vector matrix instead of matrix vector we need to flip the + // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated + // as a one dimensional array. + transpose_mat = (!is_mv) ^ transpose_mat; - if (batch_ndim_A == 0) { - batch_shape_A = {1}; - batch_strides_A = {0}; - } + // Define some shapes + int in_vector_len = K; + int out_vector_len = N; + int mat_ld = mat_cols; - if (batch_ndim_B == 0) { - batch_shape_B = {1}; - batch_strides_B = {0}; - } + int batch_size_out = out.size() / N; + int batch_ndim = out.ndim() - 2; + int batch_ndim_mat = mat.ndim() - 2; + int batch_ndim_vec = vec.ndim() - 2; + Strides index_strides = vec_indices_.strides(); + index_strides.insert( + index_strides.end(), + mat_indices_.strides().begin(), + mat_indices_.strides().end()); - auto matrix_stride_out = static_cast(M) * N; - auto batch_size_out = out.size() / matrix_stride_out; - - ///////////////////////////////////////////////////////////////////////////// - // Gemv specialization - - // Route to gemv if needed - if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? 
out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A; - auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B; - - auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A; - auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B; - - if (!is_b_matrix) { - batch_strides = rhs_indices.strides(); - batch_strides.insert( - batch_strides.end(), - lhs_indices.strides().begin(), - lhs_indices.strides().end()); - } - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_gather_" << type_to_name(out); + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_gather_" << type_to_name(out); + sm = 8; + sn = 4; } - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_gather_" << type_to_name(out); - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides, 11); - - int batch_ndim_vec = batch_shape_vec.size(); - compute_encoder.set_bytes(batch_ndim_vec, 12); - compute_encoder.set_vector_bytes(batch_shape_vec, 13); - compute_encoder.set_vector_bytes(batch_strides_vec, 14); - - int batch_ndim_mat = batch_shape_mat.size(); - compute_encoder.set_bytes(batch_ndim_mat, 15); - compute_encoder.set_vector_bytes(batch_shape_mat, 16); - compute_encoder.set_vector_bytes(batch_strides_mat, 17); - - compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); - compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + n_out_per_tgp = bm * sm * tm; + kname << "gemv_gather_" << type_to_name(out); } - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(out.shape(), 10); + compute_encoder.set_vector_bytes(index_strides, 11); + + compute_encoder.set_bytes(batch_ndim_vec, 12); + compute_encoder.set_vector_bytes(vec.shape(), 13); + compute_encoder.set_vector_bytes(vec.strides(), 14); + + compute_encoder.set_bytes(batch_ndim_mat, 15); + compute_encoder.set_vector_bytes(mat.shape(), 16); + compute_encoder.set_vector_bytes(mat.strides(), 17); + + compute_encoder.set_input_array(vec_indices_, 18); + compute_encoder.set_input_array(mat_indices_, 19); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_mm( + const array& a_, + const array& b_, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + // Copy if needed + std::vector copies; + auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); + auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); + d.add_temporaries(std::move(copies), s.index); // Determine dispatch kernel int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; + size_t batch_size_out = out.size() / M / N; + int batch_ndim = out.ndim() - 2; + int batch_ndim_a = a.ndim() - 2; + int batch_ndim_b = b.ndim() - 2; char devc = d.get_architecture().back(); GEMM_TPARAM_MACRO(devc) - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - const bool has_batch = batch_ndim > 1; - const bool use_out_source = false; - const bool do_axpby = false; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; - const bool do_gather = true; + + // Define the kernel name + std::string base_name; + base_name.reserve(128); + concatenate( + base_name, + "steel_gather_mm_", + transpose_a ? 't' : 'n', + transpose_b ? 't' : 'n', + "_", + type_to_name(a), + "_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); metal::MTLFCList func_consts = { {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, {&align_M, MTL::DataType::DataTypeBool, 200}, {&align_N, MTL::DataType::DataTypeBool, 201}, {&align_K, MTL::DataType::DataTypeBool, 202}, - {&do_gather, MTL::DataType::DataTypeBool, 300}, }; - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') - << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_has_batch_", + has_batch ? 't' : 'n', + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); - std::string hash_name = kname.str(); - - // Encode and dispatch kernel + // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( + auto kernel = get_steel_gemm_gather_kernel( d, base_name, hash_name, @@ -1736,72 +1842,97 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { bn, bk, wm, - wn); - + wn, + false); compute_encoder.set_compute_pipeline_state(kernel); - // Use problem size to determine threadblock swizzle - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams params{ + // Prepare the matmul params + steel::GEMMParams params{ /* const int M = */ M, /* const int N = */ N, /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, + /* const int lda = */ static_cast(lda), + /* const int ldb = */ static_cast(ldb), /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ lhs_indices_str, - /* const int64_t batch_stride_b = */ rhs_indices_str, - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ + (batch_ndim > 0) ? lhs_indices.strides()[0] : 0, + /* const int64_t batch_stride_b = */ + (batch_ndim > 0) ? 
rhs_indices.strides()[0] : 0, + /* const int64_t batch_stride_d = */ M * N, + /* const int swizzle_log = */ 0, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ batch_ndim}; - // Prepare launch grid params - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - + // Prepare the grid MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + MTL::Size grid_dims = + MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); // Launch kernel compute_encoder.set_input_array(a, 0); compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(params, 4); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.set_input_array(lhs_indices, 10); - compute_encoder.set_input_array(rhs_indices, 11); - - std::vector operand_shape = batch_shape_A; - operand_shape.insert( - operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end()); - - std::vector operand_strides = batch_strides_A; - operand_strides.insert( - operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end()); - - operand_batch_ndim.push_back(0); - - compute_encoder.set_vector_bytes(operand_shape, 13); - compute_encoder.set_vector_bytes(operand_strides, 14); - compute_encoder.set_vector_bytes(operand_batch_ndim, 15); - + compute_encoder.set_input_array(lhs_indices, 2); + compute_encoder.set_input_array(rhs_indices, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(params, 5); + compute_encoder.set_vector_bytes(lhs_indices.shape(), 6); + compute_encoder.set_vector_bytes(lhs_indices.strides(), 7); + compute_encoder.set_vector_bytes(rhs_indices.strides(), 8); + compute_encoder.set_bytes(batch_ndim_a, 9); + compute_encoder.set_vector_bytes(a.shape(), 10); + compute_encoder.set_vector_bytes(a.strides(), 11); + compute_encoder.set_bytes(batch_ndim_b, 12); + compute_encoder.set_vector_bytes(b.shape(), 13); + compute_encoder.set_vector_bytes(b.strides(), 14); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} - d.add_temporaries(std::move(copies), s.index); +void GatherMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& lhs_indices = inputs[2]; + auto& rhs_indices = inputs[3]; + + // Return 0s if either input is empty + if (a.size() == 0 || b.size() == 0) { + array zero = array(0, a.dtype()); + fill_gpu(zero, out, s); + d.add_temporary(std::move(zero), s.index); + return; + } + + out.set_data(allocator::malloc(out.nbytes())); + + // Extract shapes strides from inputs and copy in case of non-contiguous + // vectors. + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + // We are walking a in order and b is also in order so we can batch up the + // matmuls and reuse reading a and b. 
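+  // The fast path below requires a single row per gathered matmul (M == 1)
+  // and rhs_indices flagged as sorted, so runs of equal indices are adjacent
+  // and each run can reuse the same B matrix.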
+ if (M == 1 && right_sorted_ == true) { + gather_mm_rhs(a, b, rhs_indices, out, d, s); + return; + } + + // Route to gather gemv if any of a or b are vectors + if (M == 1) { + gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s); + return; + } + if (N == 1) { + gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s); + return; + } + + // Route to non specialized gather mm + gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } } // namespace mlx::core diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 2d6077ed1..292af6919 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -193,6 +193,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_steel_gemm_gather_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index cc56bab32..079d15f17 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/array.h" #include "mlx/backend/metal/device.h" #include "mlx/primitives.h" @@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label( std::string get_primitive_string(Primitive* primitive); +template +constexpr bool is_numeric_except_char = std::is_arithmetic_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v && !std::is_same_v; + template void concatenate(std::string& acc, T first) { - acc += first; + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } } template void concatenate(std::string& acc, T first, Args... 
args) { - acc += first; + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } concatenate(acc, args...); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6d1116905..1946a43fa 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4499,6 +4499,7 @@ array gather_mm( array b, std::optional lhs_indices_ /* = std::nullopt */, std::optional rhs_indices_ /* = std::nullopt */, + bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { // If no indices, fall back to full matmul if (!lhs_indices_ && !rhs_indices_) { @@ -4574,12 +4575,18 @@ array gather_mm( out_shape.push_back(M); out_shape.push_back(N); - // Caculate array + // Make the output array auto out = array( std::move(out_shape), out_type, - std::make_shared(to_stream(s)), - {a, b, lhs_indices, rhs_indices}); + std::make_shared( + to_stream(s), + sorted_indices && !rhs_indices_, + sorted_indices && !lhs_indices_), + {std::move(a), + std::move(b), + std::move(lhs_indices), + std::move(rhs_indices)}); // Remove the possibly inserted singleton dimensions std::vector axes; diff --git a/mlx/ops.h b/mlx/ops.h index ce3d4ff44..f6fd958b3 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1399,6 +1399,7 @@ array gather_mm( array b, std::optional lhs_indices = std::nullopt, std::optional rhs_indices = std::nullopt, + bool sorted_indices = false, StreamOrDevice s = {}); /** Extract a diagonal or construct a diagonal array */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8328d96da..9b34fe657 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4895,6 +4895,8 @@ std::vector GatherMM::vjp( int N = cotan.shape(-1); int K = primals[0].shape(-1); + bool sorted = left_sorted_ || right_sorted_; + for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K @@ -4905,7 +4907,8 @@ std::vector GatherMM::vjp( base = reshape(base, {-1, M, K}, stream()); // g : (out_batch_shape) + (M, K) - auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream()); + auto g = + gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); g = expand_dims(g, -3, stream()); auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); @@ -4920,7 +4923,8 @@ std::vector GatherMM::vjp( base = reshape(base, {-1, K, N}, stream()); // g : (out_batch_shape) + (K, N) - auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream()); + auto g = + gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); g = expand_dims(g, -3, stream()); auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); @@ -4933,6 +4937,12 @@ std::vector GatherMM::vjp( return vjps; } +bool GatherMM::is_equivalent(const Primitive& other) const { + const GatherMM& g_other = static_cast(other); + return left_sorted_ == g_other.left_sorted_ && + right_sorted_ == g_other.right_sorted_; +} + bool BlockMaskedMM::is_equivalent(const Primitive& other) const { const BlockMaskedMM& a_other = static_cast(other); return (block_size_ == a_other.block_size_); diff --git a/mlx/primitives.h b/mlx/primitives.h index 7738b273b..1902a562d 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -498,7 +498,13 @@ class BlockMaskedMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive { public: - explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {} + explicit GatherMM( + Stream stream, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& 
inputs, array& out) override; @@ -510,7 +516,14 @@ class GatherMM : public UnaryPrimitive { const std::vector& outputs) override; DEFINE_PRINT(GatherMM) - DEFINE_DEFAULT_IS_EQUIVALENT() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(left_sorted_, right_sorted_); + } + + private: + bool left_sorted_; + bool right_sorted_; }; class BroadcastAxes : public UnaryPrimitive { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7f06a4ddf..8798ba482 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4464,9 +4464,10 @@ void init_ops(nb::module_& m) { "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), nb::kw_only(), + "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Matrix multiplication with matrix-level gather. @@ -4485,11 +4486,16 @@ void init_ops(nb::module_& m) { For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)`` + If only one index is passed and it is sorted, the ``sorted_indices`` + flag can be passed for a possible faster implementation. + Args: a (array): Input array. b (array): Input array. lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. Returns: array: The output array. diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 8b7fb462d..6fca4885b 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase): lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2)) rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2)) M = a.shape[-2] - N = b.shape[-2] + N = b.shape[-1] K = a.shape[-1] a = a.reshape((-1, M, K))