diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 03d5a89cb..3a17686fc 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -494,11 +494,13 @@ inline void steel_gemm_splitk(
       /* std::vector<array>& copies = */ copies);
 }
 
-void steel_matmul(
+template <bool CHECK_AB>
+void steel_matmul_axpby(
     const Stream& s,
     metal::Device& d,
     const array& a,
     const array& b,
+    const array& c,
     array& out,
     int M,
     int N,
@@ -511,32 +513,56 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape /* = {} */,
     Strides A_batch_stride /* = {} */,
-    Strides B_batch_stride /* = {} */) {
+    Strides B_batch_stride /* = {} */,
+    Strides C_batch_stride /* = {} */,
+    float alpha /* = 1.0f */,
+    float beta /* = 0.0f */) {
   using namespace mlx::steel;
 
   if (batch_shape.empty()) {
     /////////////////////////////////////////////////////////////////////////////
     // Check and collapse batch dimensions
-    auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
+    if constexpr (CHECK_AB) {
+      auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] =
+          collapse_batches(a, b, c);
 
-    batch_shape = batch_shape_;
-    A_batch_stride = A_bstride_;
-    B_batch_stride = B_bstride_;
-    // Collapse batches into M if needed
-    if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
-        a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
-        B_batch_stride.back() == 0) {
-      M *= batch_shape.back();
-      batch_size_out = 1;
+      batch_shape = batch_shape_;
+      A_batch_stride = A_bstride_;
+      B_batch_stride = B_bstride_;
+      C_batch_stride = C_bstride_;
+      // Collapse batches into M if needed
+      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+          C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
+          B_batch_stride.back() == 0) {
+        M *= batch_shape.back();
+        batch_size_out = 1;
 
-      A_batch_stride = {0};
-      B_batch_stride = {0};
-      batch_shape = {1};
+        A_batch_stride = {0};
+        B_batch_stride = {0};
+        C_batch_stride = {0};
+        batch_shape = {1};
+      }
+    } else {
+      auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
+
+      batch_shape = batch_shape_;
+      A_batch_stride = A_bstride_;
+      B_batch_stride = B_bstride_;
+      // Collapse batches into M if needed
+      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+          B_batch_stride.back() == 0) {
+        M *= batch_shape.back();
+        batch_size_out = 1;
+
+        A_batch_stride = {0};
+        B_batch_stride = {0};
+        batch_shape = {1};
+      }
     }
   }
 
-  size_t matrix_stride_out = size_t(M) * N;
-
   /////////////////////////////////////////////////////////////////////////////
   // Split K specialization
 
@@ -545,11 +571,12 @@ void steel_matmul(
   int _tk = K / 16;
 
   if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
-    return steel_gemm_splitk(
+    return steel_gemm_splitk_axpby<CHECK_AB>(
         /* const Stream& s = */ s,
         /* metal::Device& d = */ d,
         /* const array& a = */ a,
         /* const array& b = */ b,
+        /* const array& c = */ c,
         /* array& out = */ out,
         /* int M = */ M,
         /* int N = */ N,
@@ -559,7 +586,9 @@ void steel_matmul(
         /* int ldb = */ ldb,
         /* bool transpose_a = */ transpose_a,
         /* bool transpose_b = */ transpose_b,
-        /* std::vector<array>& copies = */ copies);
+        /* std::vector<array>& copies = */ copies,
+        /* float alpha = */ alpha,
+        /* float beta = */ beta);
   }
 
   /////////////////////////////////////////////////////////////////////////////
@@ -567,12 +596,21 @@ void steel_matmul(
   auto batch_strides = A_batch_stride;
   batch_strides.insert(
       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
+  if (CHECK_AB && !C_batch_stride.empty()) {
+    batch_strides.insert(
+        batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
+  }
 
-  return steel_matmul_regular(
+  int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back();
+  int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back();
+  int64_t C_batch_stride_ = C_batch_stride.empty() ? 0 : C_batch_stride.back();
+
+  return steel_matmul_regular_axpby<CHECK_AB>(
       /* const Stream& s = */ s,
       /* metal::Device& d = */ d,
       /* const array& a = */ a,
       /* const array& b = */ b,
+      /* const array& c = */ c,
       /* array& out = */ out,
       /* int M = */ M,
       /* int N = */ N,
@@ -586,11 +624,114 @@ void steel_matmul(
       /* std::vector<array>& copies = */ copies,
       /* Shape batch_shape = */ std::move(batch_shape),
       /* Strides batch_strides = */ std::move(batch_strides),
-      /* int64_t A_batch_stride = */ A_batch_stride.back(),
-      /* int64_t B_batch_stride = */ B_batch_stride.back(),
-      /* int64_t matrix_stride_out = */ matrix_stride_out);
+      /* int64_t A_batch_stride = */ A_batch_stride_,
+      /* int64_t B_batch_stride = */ B_batch_stride_,
+      /* int64_t matrix_stride_out = */ int64_t(M) * N,
+      /* int64_t C_batch_stride = */ C_batch_stride_,
+      /* float alpha = */ alpha,
+      /* float beta = */ beta);
 }
 
+// void steel_matmul(
+//     const Stream& s,
+//     metal::Device& d,
+//     const array& a,
+//     const array& b,
+//     array& out,
+//     int M,
+//     int N,
+//     int K,
+//     int batch_size_out,
+//     int lda,
+//     int ldb,
+//     bool transpose_a,
+//     bool transpose_b,
+//     std::vector<array>& copies,
+//     Shape batch_shape /* = {} */,
+//     Strides A_batch_stride /* = {} */,
+//     Strides B_batch_stride /* = {} */) {
+
+//   return
+
+//   using namespace mlx::steel;
+
+//   if (batch_shape.empty()) {
+//     /////////////////////////////////////////////////////////////////////////////
+//     // Check and collapse batch dimensions
+//     auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
+
+//     batch_shape = batch_shape_;
+//     A_batch_stride = A_bstride_;
+//     B_batch_stride = B_bstride_;
+//     // Collapse batches into M if needed
+//     if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
+//         a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
+//         B_batch_stride.back() == 0) {
+//       M *= batch_shape.back();
+//       batch_size_out = 1;
+
+//       A_batch_stride = {0};
+//       B_batch_stride = {0};
+//       batch_shape = {1};
+//     }
+//   }
+
+//   size_t matrix_stride_out = size_t(M) * N;
+
+//   /////////////////////////////////////////////////////////////////////////////
+//   // Split K specialization
+
+//   int _tm = M / 16;
+//   int _tn = N / 16;
+//   int _tk = K / 16;
+
+//   if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
+//     return steel_gemm_splitk(
+//         /* const Stream& s = */ s,
+//         /* metal::Device& d = */ d,
+//         /* const array& a = */ a,
+//         /* const array& b = */ b,
+//         /* array& out = */ out,
+//         /* int M = */ M,
+//         /* int N = */ N,
+//         /* int K = */ K,
+//         /* int batch_size_out = */ batch_size_out,
+//         /* int lda = */ lda,
+//         /* int ldb = */ ldb,
+//         /* bool transpose_a = */ transpose_a,
+//         /* bool transpose_b = */ transpose_b,
+//         /* std::vector<array>& copies = */ copies);
+//   }
+
+//   /////////////////////////////////////////////////////////////////////////////
+//   // Regular kernel dispatch
+//   auto batch_strides = A_batch_stride;
+//   batch_strides.insert(
+//       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
+
+//   return steel_matmul_regular(
+//       /* const Stream& s = */ s,
+//       /* metal::Device& d = */ d,
+//       /* const array& a = */ a,
+//       /* const array& b = */ b,
+//       /* array& out = */ out,
+//       /* int M = */ M,
+//       /* int N = */ N,
+//       /* int K = */ K,
+//       /* int batch_size_out = */ batch_size_out,
+//       /* int lda = */ lda,
+//       /* int ldb = */ ldb,
+//       /* int ldd = */ N,
+//       /* bool transpose_a = */ transpose_a,
+//       /* bool transpose_b = */ transpose_b,
+//       /* std::vector<array>& copies = */ copies,
+//       /* Shape batch_shape = */ std::move(batch_shape),
+//       /* Strides batch_strides = */ std::move(batch_strides),
+//       /* int64_t A_batch_stride = */ A_batch_stride.back(),
+//       /* int64_t B_batch_stride = */ B_batch_stride.back(),
+//       /* int64_t matrix_stride_out = */ matrix_stride_out);
+// }
+
 template <bool CHECK_AB = true>
 void gemv_axbpy(
     const Stream& s,
diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h
index 9c898b282..218664b1f 100644
--- a/mlx/backend/metal/matmul.h
+++ b/mlx/backend/metal/matmul.h
@@ -78,7 +78,31 @@ inline void steel_matmul_regular(
       /* int64_t matrix_stride_out = */ matrix_stride_out);
 }
 
-void steel_matmul(
+template <bool CHECK_AB = true>
+void steel_matmul_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -95,6 +119,26 @@
     std::vector<array>& copies,
     Shape batch_shape = {},
     Strides A_batch_stride = {},
-    Strides B_batch_stride = {});
+    Strides B_batch_stride = {}) {
+  return steel_matmul_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
+}
 
 } // namespace mlx::core
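
Note on the refactor: `steel_matmul_axpby` computes `out = alpha * (a @ b) + beta * c`, and the `CHECK_AB` template flag controls whether `c`'s batch dimensions are collapsed along with `a`'s and `b`'s. The `steel_matmul` wrapper in matmul.h instantiates `steel_matmul_axpby<false>` with `b` passed as a placeholder for `c` and the defaults `alpha = 1.0f`, `beta = 0.0f`, so existing call sites keep their behavior while all GEMM traffic goes through one axpby implementation. For reference, a minimal CPU sketch of that contract (illustrative only, not MLX's Metal dispatch; the name `gemm_axpby_reference` is made up here):

#include <cstddef>
#include <vector>

// out = alpha * (A @ B) + beta * C for row-major matrices, with no
// transposes or batching. When beta == 0, C is never read, which matches
// the plain steel_matmul path where c is only a placeholder.
void gemm_axpby_reference(
    const std::vector<float>& A, // M x K
    const std::vector<float>& B, // K x N
    const std::vector<float>& C, // M x N, ignored when beta == 0
    std::vector<float>& out, // M x N
    int M,
    int N,
    int K,
    float alpha = 1.0f,
    float beta = 0.0f) {
  out.assign(std::size_t(M) * N, 0.0f);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[std::size_t(i) * K + k] * B[std::size_t(k) * N + j];
      }
      float bias = (beta == 0.0f) ? 0.0f : beta * C[std::size_t(i) * N + j];
      out[std::size_t(i) * N + j] = alpha * acc + bias;
    }
  }
}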