From f3909576859855367a45d6557bbbbbe304c8590d Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 2 May 2024 14:03:58 -0700 Subject: [PATCH] Block sparse mm (#1058) --- docs/src/python/ops.rst | 1 + mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/masked_mm.cpp | 87 +++++ mlx/backend/metal/kernels/gemv.metal | 345 ++++++++++++++--- .../gemm/kernels/steel_gemm_gather.metal | 168 +++++++++ mlx/backend/metal/kernels/steel/gemm/params.h | 8 +- mlx/backend/metal/matmul.cpp | 357 ++++++++++++++++-- mlx/backend/no_metal/primitives.cpp | 3 +- mlx/ops.cpp | 118 ++++++ mlx/ops.h | 8 + mlx/primitives.cpp | 56 ++- mlx/primitives.h | 20 + python/src/ops.cpp | 32 ++ python/tests/test_blas.py | 193 ++++++++++ 15 files changed, 1323 insertions(+), 75 deletions(-) create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 7bae81069..abd5d1997 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -32,6 +32,7 @@ Operations bitwise_or bitwise_xor block_masked_mm + block_sparse_mm broadcast_to ceil clip diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 7a3610717..7b48e62f7 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -32,6 +32,7 @@ DEFAULT(ArgReduce) DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(BlockMaskedMM) +DEFAULT(BlockSparseMM) DEFAULT(Broadcast) DEFAULT(Ceil) DEFAULT(Concatenate) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index a0cc42984..d8ec303f1 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -42,6 +42,7 @@ DEFAULT(AsType) DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT(BlockMaskedMM) +DEFAULT(BlockSparseMM) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) DEFAULT(Concatenate) diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp index 52711d512..fa2dd32af 100644 --- a/mlx/backend/common/masked_mm.cpp +++ b/mlx/backend/common/masked_mm.cpp @@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { } } +void BlockSparseMM::eval(const std::vector& inputs, array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[BlockSparseMM::eval] Currently only supports float32."); + } + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + + auto check_transpose = [](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + auto [a_transposed, lda, a] = check_transpose(a_pre); + auto [b_transposed, ldb, b] = check_transpose(b_pre); + + size_t M = a.shape(-2); + size_t N = b.shape(-1); + size_t K = a.shape(-1); + + if (M == 0 || N == 0) { + return; + } + + if (K == 0) { + std::memset(static_cast(out.data()), 0, out.nbytes()); + return; + } + + // Get batch dims + auto batch_size_out = out.size() / (M * N); + size_t matrix_stride_out = M * N; + + auto get_batch_dims = [](const auto& v) { + return decltype(v){v.begin(), v.end() - 2}; + }; + + auto& lhs_indices = inputs[2]; + auto& rhs_indices = inputs[3]; + + std::vector batch_shape = get_batch_dims(out.shape()); + int batch_ndim = batch_shape.size(); + + std::vector batch_shape_A = get_batch_dims(a.shape()); + std::vector batch_strides_A = get_batch_dims(a.strides()); + std::vector batch_shape_B = get_batch_dims(b.shape()); + std::vector batch_strides_B = get_batch_dims(b.strides()); + + const uint32_t* lhs_indices_ptr = lhs_indices.data(); + const uint32_t* rhs_indices_ptr = rhs_indices.data(); + + for (int i = 0; i < batch_size_out; i++) { + // Get index + uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)]; + uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)]; + + cblas_sgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? CblasTrans : CblasNoTrans, // transB + M, + N, + K, + 1.0f, // alpha + a.data() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A), + lda, + b.data() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B), + ldb, + 0.0f, // beta + out.data() + matrix_stride_out * i, + out.shape(-1) // ldc + ); + } +} + } // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index 3f9025358..0398c2cd1 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -7,6 +7,8 @@ #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + using namespace metal; /////////////////////////////////////////////////////////////////////////////// @@ -33,22 +35,19 @@ struct GEMVKernel { // - We assume each thead group is launched with (BN, BM, 1) threads // // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector + // and the corresponding scalar from the vector // 2. The thread then multiplies and adds to accumulate its local result for - // the block + // the block // 3. At the end, each thread has accumulated results over all blocks across - // the rows - // These are then summed up across the threadgroup + // the rows. These are then summed up across the threadgroup // 4. Each threadgroup writes its accumulated BN * TN outputs // // Edge case handling: - // - The threadgroup with the largest tid will have blocks that exceed the - // matrix + // - The threadgroup with the largest tid has blocks that exceed the matrix // * The blocks that start outside the matrix are never read (thread results - // remain zero) + // remain zero) // * The last thread that partially overlaps with the matrix is shifted - // inwards - // such that the thread block fits exactly in the matrix + // inwards such that the thread block fits exactly in the matrix MLX_MTL_CONST short tgp_mem_size = BN * TN * 2; @@ -100,14 +99,14 @@ struct GEMVKernel { if (simd_gid == 0) { // Main load loop if (bn + TN <= in_vec_size) { -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { in_vec_block[tn] = in_vec[bn + tn]; } } else { // Edgecase -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0; } @@ -116,24 +115,24 @@ struct GEMVKernel { threadgroup_barrier(mem_flags::mem_threadgroup); -// Load for all rows -#pragma clang loop unroll(full) + // Load for all rows + MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { v_coeff[tn] = in_vec_block[tn]; } -// Per thread work loop -#pragma clang loop unroll(full) + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { // Load for the row if (bn + TN <= in_vec_size) { -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[tm * marix_ld + bn + tn]; } } else { // Edgecase -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); @@ -142,21 +141,22 @@ struct GEMVKernel { } // Accumulate results + MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { result[tm] += inter[tn] * v_coeff[tn]; } } } -// Simdgroup accumulations -#pragma clang loop unroll(full) + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { result[tm] = simd_sum(result[tm]); } // Write outputs if (simd_lid == 0) { -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { if (kDoAxpby) { out_vec[out_row + tm] = static_cast(alpha) * result[tm] + @@ -187,22 +187,18 @@ struct GEMVTKernel { // - We assume each thead group is launched with (BN, BM, 1) threads // // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for - // the block + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block // 3. At the end, each thread has accumulated results over all blocks across - // the rows - // These are then summed up across the threadgroup + // the rows. These are then summed up across the threadgroup // 4. Each threadgroup writes its accumulated BN * TN outputs // // Edge case handling: - // - The threadgroup with the largest tid will have blocks that exceed the - // matrix + // - The threadgroup with the largest tid has blocks that exceed the matrix // * The blocks that start outside the matrix are never read (thread results - // remain zero) + // remain zero) // * The last thread that partially overlaps with the matrix is shifted - // inwards - // such that the thread block fits exactly in the matrix + // inwards such that the thread block fits exactly in the matrix MLX_MTL_CONST short tgp_mem_size = BN * BM * TN; @@ -249,12 +245,12 @@ struct GEMVTKernel { threadgroup_barrier(mem_flags::mem_none); if (bm + TM <= in_vec_size) { -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { v_coeff[tm] = in_vec[bm + tm]; } -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; @@ -281,7 +277,7 @@ struct GEMVTKernel { // Threadgroup collection -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int i = 0; i < TN; i++) { tgp_results[lid.y * TN + i] = result[i]; } @@ -290,15 +286,15 @@ struct GEMVTKernel { // Threadgroup accumulation and writing out results if (lid.y == 0 && out_col < out_vec_size) { -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int i = 1; i < BM; i++) { -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { result[j] += tgp_results[i * TN + j]; } } -#pragma clang loop unroll(full) + MLX_MTL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { if (kDoAxpby) { out_vec[out_col + j] = static_cast(alpha) * result[j] + @@ -408,20 +404,149 @@ template < uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) +// clang-format off +#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on -#define instantiate_gemv_blocks(name, itype) \ - instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \ - name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4) +// clang-format off +#define instantiate_gemv_blocks(name, itype) \ + instantiate_gemv(name, itype, 4, 32, 1, 4) \ + instantiate_gemv(name, itype, 4, 32, 4, 4) \ + instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on instantiate_gemv_blocks(float32, float); instantiate_gemv_blocks(float16, half); instantiate_gemv_blocks(bfloat16, bfloat16_t); +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* index_batch_strides [[buffer(11)]], + const constant int& vector_batch_ndim [[buffer(12)]], + const constant int* vector_batch_shape [[buffer(13)]], + const constant size_t* vector_batch_stride [[buffer(14)]], + const constant int& matrix_batch_ndim [[buffer(15)]], + const constant int* matrix_batch_shape [[buffer(16)]], + const constant size_t* matrix_batch_stride [[buffer(17)]], + const constant uint32_t* vec_indices [[buffer(18)]], + const constant uint32_t* mat_indices [[buffer(19)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + + uint32_t indx_vec; + uint32_t indx_mat; + + // Update batch offsets + if (batch_ndim > 1) { + const constant size_t* veci_bstrides = index_batch_strides; + const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); + + indx_vec = vec_indices[batch_offsets.x]; + indx_mat = mat_indices[batch_offsets.y]; + + } else { + indx_vec = vec_indices[index_batch_strides[0] * tid.z]; + indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; + } + + if (vector_batch_ndim > 1) { + in_vec += elem_to_loc( + indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); + } else { + in_vec += indx_vec * vector_batch_stride[0]; + } + + if (matrix_batch_ndim > 1) { + mat += elem_to_loc( + indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); + } else { + mat += indx_mat * matrix_batch_stride[0]; + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + batch_ndim, // Not used + tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \ + template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \ + "_tn" #tn)]] [[kernel]] void \ + gemv_bs( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* index_batch_strides [[buffer(11)]], \ + const constant int& vector_batch_ndim [[buffer(12)]], \ + const constant int* vector_batch_shape [[buffer(13)]], \ + const constant size_t* vector_batch_stride [[buffer(14)]], \ + const constant int& matrix_batch_ndim [[buffer(15)]], \ + const constant int* matrix_batch_shape [[buffer(16)]], \ + const constant size_t* matrix_batch_stride [[buffer(17)]], \ + const constant uint32_t* vec_indices [[buffer(18)]], \ + const constant uint32_t* mat_indices [[buffer(19)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_bs_blocks(name, itype) \ + instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \ + instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \ + instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on + +instantiate_gemv_bs_blocks(float32, float); +instantiate_gemv_bs_blocks(float16, half); +instantiate_gemv_bs_blocks(bfloat16, bfloat16_t); + /////////////////////////////////////////////////////////////////////////////// /// Vector matrix multiplication /////////////////////////////////////////////////////////////////////////////// @@ -538,4 +663,134 @@ template < // clang-format off instantiate_gemv_t_blocks(float32, float); instantiate_gemv_t_blocks(float16, half); -instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file +instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on + +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* index_batch_strides [[buffer(11)]], + const constant int& vector_batch_ndim [[buffer(12)]], + const constant int* vector_batch_shape [[buffer(13)]], + const constant size_t* vector_batch_stride [[buffer(14)]], + const constant int& matrix_batch_ndim [[buffer(15)]], + const constant int* matrix_batch_shape [[buffer(16)]], + const constant size_t* matrix_batch_stride [[buffer(17)]], + const constant uint32_t* vec_indices [[buffer(18)]], + const constant uint32_t* mat_indices [[buffer(19)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = GEMVTKernel; + threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + + uint32_t indx_vec; + uint32_t indx_mat; + + // Update batch offsets + if (batch_ndim > 1) { + const constant size_t* veci_bstrides = index_batch_strides; + const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); + + indx_vec = vec_indices[batch_offsets.x]; + indx_mat = mat_indices[batch_offsets.y]; + + } else { + indx_vec = vec_indices[index_batch_strides[0] * tid.z]; + indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; + } + + if (vector_batch_ndim > 1) { + in_vec += elem_to_loc( + indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); + } else { + in_vec += indx_vec * vector_batch_stride[0]; + } + + if (matrix_batch_ndim > 1) { + mat += elem_to_loc( + indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); + } else { + mat += indx_mat * matrix_batch_stride[0]; + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + bias, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + alpha, + beta, + batch_ndim, // Not used, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \ + template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \ + "_tn" #tn)]] [[kernel]] void \ + gemv_t_bs( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* index_batch_strides [[buffer(11)]], \ + const constant int& vector_batch_ndim [[buffer(12)]], \ + const constant int* vector_batch_shape [[buffer(13)]], \ + const constant size_t* vector_batch_stride [[buffer(14)]], \ + const constant int& matrix_batch_ndim [[buffer(15)]], \ + const constant int* matrix_batch_shape [[buffer(16)]], \ + const constant size_t* matrix_batch_stride [[buffer(17)]], \ + const constant uint32_t* vec_indices [[buffer(18)]], \ + const constant uint32_t* mat_indices [[buffer(19)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_t_bs_blocks(name, itype) \ + instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \ + instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on + +// clang-format off +instantiate_gemv_t_bs_blocks(float32, float); +instantiate_gemv_t_bs_blocks(float16, half); +instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal new file mode 100644 index 000000000..d93417e81 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal @@ -0,0 +1,168 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10)]], + const constant uint32_t* rhs_indices [[buffer(11)]], + const constant int* batch_shape_A [[buffer(12)]], + const constant size_t* batch_strides_A [[buffer(13)]], + const constant int* batch_shape_B [[buffer(14)]], + const constant size_t* batch_strides_B [[buffer(15)]], + const constant int2& operand_batch_ndim [[buffer(16)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + uint32_t indx_A; + uint32_t indx_B; + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + indx_A = lhs_indices[batch_offsets.x]; + indx_B = rhs_indices[batch_offsets.y]; + + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + } + + int batch_ndim_A = operand_batch_ndim.x; + int batch_ndim_B = operand_batch_ndim.y; + + if (batch_ndim_A > 1) { + A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); + } else { + A += indx_A * batch_strides_A[0]; + } + + if (batch_ndim_B > 1) { + B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + } else { + B += indx_B * batch_strides_B[0]; + } + + D += params->batch_stride_d * tid.z; + + gemm_kernel::run( + A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm( \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned) \ + template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ + "_MN_" #aname "_K_" #kname)]] [[kernel]] void \ + bs_gemm( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + const constant uint32_t* lhs_indices [[buffer(10)]], \ + const constant uint32_t* rhs_indices [[buffer(11)]], \ + const constant int* batch_shape_A [[buffer(12)]], \ + const constant size_t* batch_strides_A [[buffer(13)]], \ + const constant int* batch_shape_B [[buffer(14)]], \ + const constant size_t* batch_strides_B [[buffer(15)]], \ + const constant int2& operand_batch_ndim [[buffer(16)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +// clang-format off +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on + +// clang-format off +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on + +// clang-format off +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on + +// clang-format off +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); + +instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/params.h b/mlx/backend/metal/kernels/steel/gemm/params.h index 642d7f31f..e8bcb2217 100644 --- a/mlx/backend/metal/kernels/steel/gemm/params.h +++ b/mlx/backend/metal/kernels/steel/gemm/params.h @@ -21,9 +21,9 @@ struct GEMMParams { const int tiles_n; const int tiles_m; - const int batch_stride_a; - const int batch_stride_b; - const int batch_stride_d; + const size_t batch_stride_a; + const size_t batch_stride_b; + const size_t batch_stride_d; const int swizzle_log; const int gemm_k_iterations_aligned; @@ -54,7 +54,7 @@ struct GEMMAddMMParams { const int ldc; const int fdc; - const int batch_stride_c; + const size_t batch_stride_c; const float alpha; const float beta; diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3e550cd3d..dc99ead61 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -327,9 +327,9 @@ void steel_matmul_conv_groups( /* const int ldd = */ ldd, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const int batch_stride_a = */ K, - /* const int batch_stride_b = */ N * K, - /* const int batch_stride_d = */ N, + /* const size_t batch_stride_a = */ size_t(K), + /* const size_t batch_stride_b = */ size_t(N) * K, + /* const size_t batch_stride_d = */ size_t(N), /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ 1}; @@ -405,7 +405,7 @@ void steel_matmul( } } - int matrix_stride_out = M * N; + size_t matrix_stride_out = size_t(M) * N; ///////////////////////////////////////////////////////////////////////////// // Split K specialization @@ -550,9 +550,9 @@ void steel_matmul( /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const int batch_stride_a = */ int(A_batch_stride.back()), - /* const int batch_stride_b = */ int(B_batch_stride.back()), - /* const int batch_stride_d = */ matrix_stride_out, + /* const size_t batch_stride_a = */ A_batch_stride.back(), + /* const size_t batch_stride_b = */ B_batch_stride.back(), + /* const size_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; @@ -645,7 +645,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b); - auto batch_size_out = out.size() / (M * N); + auto batch_size_out = out.size() / (size_t(M) * size_t(N)); // Collapse batches into M if needed if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 && @@ -853,7 +853,8 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] = collapse_batches(a, b, c); - auto batch_size_out = out.size() / (M * N); + size_t matrix_stride_out = size_t(M) * size_t(N); + auto batch_size_out = out.size() / (matrix_stride_out); // Collapse batches into M if needed if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && @@ -869,8 +870,6 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { batch_shape = {1}; } - int matrix_stride_out = M * N; - ///////////////////////////////////////////////////////////////////////////// // Gemv specialization @@ -1120,9 +1119,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const int batch_stride_a = */ int(A_batch_stride.back()), - /* const int batch_stride_b = */ int(B_batch_stride.back()), - /* const int batch_stride_d = */ matrix_stride_out, + /* const size_t batch_stride_a = */ A_batch_stride.back(), + /* const size_t batch_stride_b = */ B_batch_stride.back(), + /* const size_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; @@ -1130,7 +1129,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { GEMMAddMMParams params{ /* const int ldc = */ ldc, /* const int fdc = */ fdc, - /* const int batch_stride_c = */ int(C_batch_stride.back()), + /* const size_t batch_stride_c = */ C_batch_stride.back(), /* const float alpha = */ alpha_, /* const float beta = */ beta_}; @@ -1230,8 +1229,8 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { auto& out_mask = inputs[2]; std::vector batch_shape{1}; - int A_batch_str = 0; - int B_batch_str = 0; + size_t A_batch_str = 0; + size_t B_batch_str = 0; std::vector batch_strides; @@ -1249,8 +1248,8 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides); batch_shape = bshape_c; - A_batch_str = int(bstrides_c[0].back()); - B_batch_str = int(bstrides_c[1].back()); + A_batch_str = bstrides_c[0].back(); + B_batch_str = bstrides_c[1].back(); for (auto& bstr : bstrides_c) { batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end()); @@ -1259,8 +1258,8 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { batch_strides = std::vector(inputs.size(), 0); } - auto batch_size_out = out.size() / (M * N); - int matrix_stride_out = M * N; + size_t matrix_stride_out = size_t(M) * N; + size_t batch_size_out = out.size() / (matrix_stride_out); ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch @@ -1301,9 +1300,9 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const int batch_stride_a = */ A_batch_str, - /* const int batch_stride_b = */ B_batch_str, - /* const int batch_stride_d = */ matrix_stride_out, + /* const size_t batch_stride_a = */ A_batch_str, + /* const size_t batch_stride_b = */ B_batch_str, + /* const size_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; @@ -1355,4 +1354,314 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { return; } +void BlockSparseMM::eval_gpu(const std::vector& inputs, array& out) { + using namespace mlx::steel; + // assert(inputs.size() == 2); + if (!issubdtype(out.dtype(), floating)) { + throw std::runtime_error( + "[matmul] Does not yet support non-floating point types."); + } + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + // Return 0s if either input is empty + if (a_pre.size() == 0 || b_pre.size() == 0) { + array zero = array(0, a_pre.dtype()); + copy_gpu(zero, out, CopyType::Scalar, s); + auto command_buffer = d.get_command_buffer(s.index); + command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {}); + return; + } + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector copies; + auto check_transpose = [&copies, &s](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + auto [transpose_a, a_cols, a] = check_transpose(a_pre); + auto [transpose_b, b_cols, b] = check_transpose(b_pre); + + int lda = a_cols; + int ldb = b_cols; + + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto get_batch_dims = [](const auto& v) { + return decltype(v){v.begin(), v.end() - 2}; + }; + + auto& lhs_indices = inputs[2]; + auto& rhs_indices = inputs[3]; + + std::vector batch_shape = get_batch_dims(out.shape()); + std::vector batch_strides; + + batch_strides.insert( + batch_strides.end(), + lhs_indices.strides().begin(), + lhs_indices.strides().end()); + size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + + batch_strides.insert( + batch_strides.end(), + rhs_indices.strides().begin(), + rhs_indices.strides().end()); + size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + + int batch_ndim = batch_shape.size(); + + if (batch_ndim == 0) { + batch_shape = {1}; + batch_strides = {0}; + } + + int batch_ndim_A = a.ndim() - 2; + int batch_ndim_B = b.ndim() - 2; + std::vector operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; + + std::vector batch_shape_A = get_batch_dims(a.shape()); + std::vector batch_strides_A = get_batch_dims(a.strides()); + std::vector batch_shape_B = get_batch_dims(b.shape()); + std::vector batch_strides_B = get_batch_dims(b.strides()); + + if (batch_ndim_A == 0) { + batch_shape_A = {1}; + batch_strides_A = {0}; + } + + if (batch_ndim_B == 0) { + batch_shape_B = {1}; + batch_strides_B = {0}; + } + + size_t matrix_stride_out = size_t(M) * N; + auto batch_size_out = out.size() / matrix_stride_out; + + ///////////////////////////////////////////////////////////////////////////// + // Gemv specialization + + // Route to gemv if needed + if (std::min(M, N) == 1) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? b_cols : a_cols; + + auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A; + auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B; + + auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A; + auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B; + + if (!is_b_matrix) { + batch_strides = rhs_indices.strides(); + batch_strides.insert( + batch_strides.end(), + lhs_indices.strides().begin(), + lhs_indices.strides().end()); + } + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int bm, bn, n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + bm = 8; + bn = 8; + if (out_vector_len >= 24576) { + bn = 128; + } else if (out_vector_len >= 16384) { + bn = 64; + } else if (out_vector_len >= 8192) { + bn = 16; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * tn; + kname << "gemv_t_bs_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + bn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 1 : tm; + + n_out_per_tgp = bm * tm; + kname << "gemv_bs_" << type_to_name(out); + } + + kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(bn, bm, 1); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); + compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); + compute_encoder->setBytes(&mat_ld, sizeof(int), 6); + + compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, batch_strides, 11); + + int batch_ndim_vec = batch_shape_vec.size(); + compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12); + set_vector_bytes(compute_encoder, batch_shape_vec, 13); + set_vector_bytes(compute_encoder, batch_strides_vec, 14); + + int batch_ndim_mat = batch_shape_mat.size(); + compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15); + set_vector_bytes(compute_encoder, batch_shape_mat, 16); + set_vector_bytes(compute_encoder, batch_strides_mat, 17); + + compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); + compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + + // Determine dispatch kernel + int bm = 32, bn = 32, bk = 16; + int wm = 2, wn = 2; + + if ((size_t)batch_size_out * M * N >= 1ul << 20) { + if (!transpose_a && transpose_b) { + bm = 64; + bn = (out.dtype() == float32) ? 64 : 32; + bk = (out.dtype() == float32) ? 16 : 32; + } else { + bm = 64; + bn = 64; + } + } + + // Prepare kernel name + std::ostringstream kname; + kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" + << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn << "_MN_" + << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" + << ((K % bk == 0) ? "t" : "n") << "aligned"; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + // Use problem size to determine threadblock swizzle + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + // TODO: Explore device-based tuning for swizzle + int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); + + // Prepare steel matmul params + GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldd = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const size_t batch_stride_a = */ lhs_indices_str, + /* const size_t batch_stride_b = */ rhs_indices_str, + /* const size_t batch_stride_d = */ matrix_stride_out, + /* const int swizzle_log = */ swizzle_log, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ batch_ndim}; + + // Prepare launch grid params + int tile = 1 << swizzle_log; + tm = (tm + tile - 1) / tile; + tn = tn * tile; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); + + set_vector_bytes(compute_encoder, batch_shape, 6); + set_vector_bytes(compute_encoder, batch_strides, 7); + + compute_encoder.set_input_array(lhs_indices, 10); + compute_encoder.set_input_array(rhs_indices, 11); + + set_vector_bytes(compute_encoder, batch_shape_A, 12); + set_vector_bytes(compute_encoder, batch_strides_A, 13); + set_vector_bytes(compute_encoder, batch_shape_B, 14); + set_vector_bytes(compute_encoder, batch_strides_B, 15); + set_vector_bytes(compute_encoder, operand_batch_ndim, 16); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + // Clear copies + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; +} + } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index f9248cf12..2b10e416a 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -32,6 +32,8 @@ NO_GPU(ArgSort) NO_GPU(AsType) NO_GPU(AsStrided) NO_GPU(BitwiseBinary) +NO_GPU(BlockMaskedMM) +NO_GPU(BlockSparseMM) NO_GPU(Broadcast) NO_GPU(Ceil) NO_GPU_MULTI(Compiled) @@ -100,7 +102,6 @@ NO_GPU(Subtract) NO_GPU_MULTI(SVD) NO_GPU(Tan) NO_GPU(Tanh) -NO_GPU(BlockMaskedMM) NO_GPU(Transpose) NO_GPU(Inverse) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 3ad012932..92c6137ef 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3785,6 +3785,124 @@ array block_masked_mm( return out; } +/** Compute matrix product with matrix-level gather */ +array block_sparse_mm( + array a, + array b, + std::optional lhs_indices_ /* = std::nullopt */, + std::optional rhs_indices_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + // If no indices, fall back to full matmul + if (!lhs_indices_ && !rhs_indices_) { + return matmul(a, b, s); + } + + // Do shape checks for operands + int in_a_ndim = a.ndim(); + int in_b_ndim = b.ndim(); + + if (a.ndim() == 0 || b.ndim() == 0) { + throw std::invalid_argument( + "[block_sparse_mm] Got 0 dimension input. Inputs must " + "have at least one dimension."); + } + + if (a.ndim() == 1) { + // Insert a singleton dim in the beginning + a = reshape(a, {1, -1}, s); + } + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = reshape(b, {-1, 1}, s); + } + + if (a.shape(-1) != b.shape(-2)) { + std::ostringstream msg; + msg << "[block_sparse_mm] Last dimension of first input with shape " + << a.shape() << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + // Type promotion + auto out_type = result_type(a, b); + if (!issubdtype(out_type, floating)) { + std::ostringstream msg; + msg << "[block_sparse_mm] Only real floating point types are supported but " + << a.dtype() << " and " << b.dtype() + << " were provided which results in " << out_type + << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } + + a = astype(a, out_type, s); + b = astype(b, out_type, s); + + // Handle broadcasting + std::vector bsx_a(a.shape().begin(), a.shape().end() - 2); + std::vector bsx_b(b.shape().begin(), b.shape().end() - 2); + + auto indices_or_default = [&](const std::optional& indices, + const std::vector& bsx_shape) { + if (indices.has_value()) { + return indices.value(); + } else { + int n_batch = 1; + for (auto& i : bsx_shape) + n_batch *= i; + return reshape(arange(n_batch, uint32, s), bsx_shape, s); + } + }; + + // Pull and broadcast indices + array lhs_indices = indices_or_default(lhs_indices_, bsx_a); + array rhs_indices = indices_or_default(rhs_indices_, bsx_b); + + if (!issubdtype(lhs_indices.dtype(), integer)) { + throw std::invalid_argument( + "[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral."); + } + + if (!issubdtype(rhs_indices.dtype(), integer)) { + throw std::invalid_argument( + "[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral."); + } + + lhs_indices = astype(lhs_indices, uint32, s); + rhs_indices = astype(rhs_indices, uint32, s); + + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + auto out_bsx_shape = + broadcast_shapes(lhs_indices.shape(), rhs_indices.shape()); + + lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s); + rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s); + + auto out_shape = out_bsx_shape; + out_shape.push_back(M); + out_shape.push_back(N); + + // Caculate array + auto out = array( + out_shape, + out_type, + std::make_shared(to_stream(s)), + {a, b, lhs_indices, rhs_indices}); + + // Remove the possibly inserted singleton dimensions + if (in_a_ndim == 1 || in_b_ndim == 1) { + out_shape.erase( + out_shape.end() - ((in_a_ndim == 1) ? 2 : 1), + out_shape.end() - ((in_b_ndim == 1) ? 0 : 1)); + out = reshape(out, out_shape, s); + } + + return out; +} + array diagonal( const array& a, int offset /* = 0 */, diff --git a/mlx/ops.h b/mlx/ops.h index a909726f7..c75dc1846 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1183,6 +1183,14 @@ array block_masked_mm( std::optional mask_rhs = std::nullopt, StreamOrDevice s = {}); +/** Compute matrix product with matrix-level gather */ +array block_sparse_mm( + array a, + array b, + std::optional lhs_indices = std::nullopt, + std::optional rhs_indices = std::nullopt, + StreamOrDevice s = {}); + /** Extract a diagonal or construct a diagonal array */ array diagonal( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 85e56ac54..1f50d1e9c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3357,7 +3357,61 @@ std::vector BlockMaskedMM::vjp( vjps.push_back(grad); } else { - vjps.push_back(zeros_like(primals[arg], stream())); + throw std::invalid_argument( + "[BlockMaskedMM] Cannot calculate VJP with respect to masks."); + } + } + return vjps; +} + +std::vector BlockSparseMM::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + auto& cotan = cotangents[0]; + + auto& lhs_indices = primals[2]; + auto& rhs_indices = primals[3]; + + int M = cotan.shape(-2); + int N = cotan.shape(-1); + int K = primals[0].shape(-1); + + for (auto arg : argnums) { + if (arg == 0) { + // M X N * (K X N).T -> M X K + auto base = zeros_like(primals[0], stream()); + auto bt = swapaxes(primals[1], -1, -2, stream()); + + auto base_shape = base.shape(); + base = reshape(base, {-1, M, K}, stream()); + + // g : (out_batch_shape) + (M, K) + auto g = block_sparse_mm(cotan, bt, std::nullopt, rhs_indices, stream()); + g = expand_dims(g, -3, stream()); + auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); + + vjps.push_back(reshape(gacc, base_shape, stream())); + + } else if (arg == 1) { + // (M X K).T * M X N -> K X N + auto base = zeros_like(primals[1], stream()); + auto at = swapaxes(primals[0], -1, -2, stream()); + + auto base_shape = base.shape(); + base = reshape(base, {-1, K, N}, stream()); + + // g : (out_batch_shape) + (K, N) + auto g = block_sparse_mm(at, cotan, lhs_indices, std::nullopt, stream()); + g = expand_dims(g, -3, stream()); + auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); + + vjps.push_back(reshape(gacc, base_shape, stream())); + } else { + throw std::invalid_argument( + "[BlockSparseMM] Cannot calculate VJP with respect to indices."); } } return vjps; diff --git a/mlx/primitives.h b/mlx/primitives.h index 1e61bde5b..7d0aca52b 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -485,6 +485,26 @@ class BlockMaskedMM : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class BlockSparseMM : public UnaryPrimitive { + public: + explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_PRINT(BlockSparseMM) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + class Broadcast : public UnaryPrimitive { public: explicit Broadcast(Stream stream, const std::vector& shape) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index fb2bebdb6..a33ed822d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3716,6 +3716,38 @@ void init_ops(nb::module_& m) { mask_lhs (array, optional): Boolean mask for a (default: ``None``) mask_rhs (array, optional): Boolean mask for b (default: ``None``) + )pbdoc"); + m.def( + "block_sparse_mm", + &block_sparse_mm, + nb::arg(), + nb::arg(), + "lhs_indices"_a = nb::none(), + "rhs_indices"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Matrix multiplication with matrix-level gather. + + Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. + This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``. + + The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively. + + For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, + ``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)`` + + For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, + ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)`` + + Args: + a (array): Input array. + b (array): Input array. + lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``) + rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``) + )pbdoc"); m.def( "diagonal", diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 9f24d294d..884e8d73b 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -810,6 +810,199 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5)) + def test_block_sparse_matmul(self): + def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None): + a = a.reshape((-1, a.shape[-2], a.shape[-1])) + b = b.reshape((-1, b.shape[-2], b.shape[-1])) + lhs_indices = lhs_indices or np.arange(a.shape[0]) + rhs_indices = rhs_indices or np.arange(b.shape[0]) + a = a[lhs_indices, :, :] + b = b[rhs_indices, :, :] + out = a @ b + return out + + def test_shape( + M, + N, + K, + np_dtype=np.float32, + batch_A=(), + batch_B=(), + lhs_indices=None, + rhs_indices=None, + ): + with self.subTest( + M=M, + N=N, + K=K, + np_dtype=np_dtype, + batch_A=batch_A, + batch_B=batch_B, + lhs_indices=lhs_indices, + rhs_indices=rhs_indices, + ): + + a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype) + b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices) + + lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices) + rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices) + + out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) + + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) + + inputs = ( + { + "batch_A": (1,), + "lhs_indices": (0,), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (1,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (2,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (3,), + "lhs_indices": (0, 2), + "batch_B": (1,), + "rhs_indices": (0,), + }, + { + "batch_A": (5,), + "lhs_indices": (0, 2), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (4, 2), + "lhs_indices": ( + (7, 6), + (5, 4), + (1, 2), + ), + "batch_B": (4, 1), + "rhs_indices": ((2,), (0,), (1,)), + }, + ) + + for kwargs in inputs: + test_shape(32, 32, 32, **kwargs) + test_shape(16, 1, 16, **kwargs) + + # Add tests for broadcasting + a_np = np.random.normal(size=(5, 32, 32)).astype(np.float32) + b_np = np.random.normal(size=(3, 32, 32)).astype(np.float32) + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + # Numpy + a_np = a_np.reshape((5, 1, 32, 32)) + b_np = b_np.reshape((1, 3, 32, 32)) + + a_np = np.broadcast_to(a_np, (5, 4, 32, 32)) + b_np = np.broadcast_to(b_np, (2, 3, 32, 32)).swapaxes(1, 0) + + lhs_indices = [0, 13, 12] + rhs_indices = [0, 3, 5] + + out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices) + + # MLX + a_mx = a_mx.reshape((5, 1, 32, 32)) + b_mx = b_mx.reshape((1, 3, 32, 32)) + + a_mx = mx.broadcast_to(a_mx, (5, 4, 32, 32)) + b_mx = mx.broadcast_to(b_mx, (2, 3, 32, 32)).swapaxes(1, 0) + + lhs_indices_mx = mx.array(lhs_indices) + rhs_indices_mx = mx.array(rhs_indices) + + out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx) + + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) + + # Gemv test + a_np = np.random.normal(size=(5, 1, 32)).astype(np.float32) + b_np = np.random.normal(size=(3, 16, 32)).astype(np.float32) + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + lhs_indices = [3, 1] + rhs_indices = [0, 2] + + b_np_t = np.swapaxes(b_np, -1, -2) + out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices) + + lhs_indices_mx = mx.array(lhs_indices) + rhs_indices_mx = mx.array(rhs_indices) + + b_mx_t = mx.swapaxes(b_mx, -1, -2) + out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx) + + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5)) + + def test_block_sparse_matmul_grad(self): + + lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32) + rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32) + + def f_ref(a, b): + lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2)) + rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2)) + M = a.shape[-2] + N = b.shape[-2] + K = a.shape[-1] + + a = a.reshape((-1, M, K)) + b = b.reshape((-1, K, N)) + + a = mx.take(a, lhs_indices_, 0) + b = mx.take(b, rhs_indices_, 0) + + return a @ b + + def f_test(a, b): + return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices) + + a_mx = mx.random.normal((4, 2, 32, 32)) + b_mx = mx.random.normal((4, 1, 32, 32)) + + out_test = f_test(a_mx, b_mx) + out_ref = f_ref(a_mx, b_mx) + + self.assertTrue(mx.allclose(out_test, out_ref, atol=1e-5)) + + cotan = mx.ones_like(out_test) + out_ref, dout_ref = mx.vjp( + f_ref, + [a_mx, b_mx], + [cotan], + ) + out_test, dout_test = mx.vjp( + f_test, + [a_mx, b_mx], + [cotan], + ) + + for r, t in zip(dout_ref, dout_test): + self.assertEqual(r.shape, t.shape) + self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + if __name__ == "__main__": unittest.main()