From b1d95a3880223dce788bae9941afbb5f3a626ae5 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:43:34 -0700 Subject: [PATCH] Some cleanup --- .../steel/gemm/kernels/steel_gemm_fused.h | 4 +- mlx/backend/metal/matmul.cpp | 69 ++++++++++--------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index add495d93..85830872d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index c8c933223..2a7a8dd94 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -164,6 +164,10 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } +/////////////////////////////////////////////////////////////////////////////// +// Regular steel matmul dispatch +/////////////////////////////////////////////////////////////////////////////// + template void steel_matmul_regular_axpby( const Stream& s, @@ -296,8 +300,10 @@ void steel_matmul_regular_axpby( compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } if (use_out_source) { int ldc = c.strides()[c.ndim() - 2]; @@ -320,6 +326,10 @@ void steel_matmul_regular_axpby( d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// Split k steel matmul +/////////////////////////////////////////////////////////////////////////////// + template void steel_gemm_splitk_axpby( const Stream& s, @@ -466,38 +476,9 @@ void steel_gemm_splitk_axpby( d.add_temporaries(std::move(copies), s.index); } -inline void steel_gemm_splitk( - const Stream& s, - metal::Device& d, - const array& a, - const array& b, - array& out, - int M, - int N, - int K, - int batch_size_out, - int lda, - int ldb, - bool transpose_a, - bool transpose_b, - std::vector& copies) { - return steel_gemm_splitk_axpby( - /* const Stream& s = */ s, - /* metal::Device& d = */ d, - /* const array& a = */ a, - /* const array& b = */ b, - /* const array& c = */ b, - /* array& out = */ out, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* int batch_size_out = */ batch_size_out, - /* int lda = */ lda, - /* int ldb = */ ldb, - /* bool transpose_a = */ transpose_a, - /* bool transpose_b = */ transpose_b, - /* std::vector& copies = */ copies); -} +/////////////////////////////////////////////////////////////////////////////// +// Split matmul routing +/////////////////////////////////////////////////////////////////////////////// template void steel_matmul_axpby( @@ -637,6 +618,10 @@ void steel_matmul_axpby( /* float beta = */ beta); } +/////////////////////////////////////////////////////////////////////////////// +// GEMV dispatch +/////////////////////////////////////////////////////////////////////////////// + template void gemv_axbpy( const Stream& s, @@ -812,6 +797,10 @@ inline void gemv( /* Strides B_batch_stride = */ B_batch_stride); } +/////////////////////////////////////////////////////////////////////////////// +// Matmul implementation +/////////////////////////////////////////////////////////////////////////////// + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -913,6 +902,10 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* Strides B_batch_stride = */ std::move(B_batch_stride)); } +/////////////////////////////////////////////////////////////////////////////// +// AddMM implementation +/////////////////////////////////////////////////////////////////////////////// + void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { @@ -1043,6 +1036,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { /* float beta = */ beta_); } +/////////////////////////////////////////////////////////////////////////////// +// BlockMaskedMM implementation +/////////////////////////////////////////////////////////////////////////////// + void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); @@ -1431,6 +1428,10 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// GatherMM implementation +/////////////////////////////////////////////////////////////////////////////// + void gather_mm_rhs( const array& a_, const array& b_,