diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 3a17686fc..bc4cee56f 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -632,106 +632,6 @@ void steel_matmul_axpby(
       /* float beta = */ beta);
 }
 
-// void steel_matmul(
-//     const Stream& s,
-//     metal::Device& d,
-//     const array& a,
-//     const array& b,
-//     array& out,
-//     int M,
-//     int N,
-//     int K,
-//     int batch_size_out,
-//     int lda,
-//     int ldb,
-//     bool transpose_a,
-//     bool transpose_b,
-//     std::vector<array>& copies,
-//     Shape batch_shape /* = {} */,
-//     Strides A_batch_stride /* = {} */,
-//     Strides B_batch_stride /* = {} */) {
-
-//   return
-
-//   using namespace mlx::steel;
-
-//   if (batch_shape.empty()) {
-//     /////////////////////////////////////////////////////////////////////////
-//     // Check and collapse batch dimensions
-//     auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
-
-//     batch_shape = batch_shape_;
-//     A_batch_stride = A_bstride_;
-//     B_batch_stride = B_bstride_;
-//     // Collapse batches into M if needed
-//     if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
-//         a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
-//         B_batch_stride.back() == 0) {
-//       M *= batch_shape.back();
-//       batch_size_out = 1;
-
-//       A_batch_stride = {0};
-//       B_batch_stride = {0};
-//       batch_shape = {1};
-//     }
-//   }
-
-//   size_t matrix_stride_out = size_t(M) * N;
-
-//   ///////////////////////////////////////////////////////////////////////////
-//   // Split K specialization
-
-//   int _tm = M / 16;
-//   int _tn = N / 16;
-//   int _tk = K / 16;
-
-//   if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
-//     return steel_gemm_splitk(
-//         /* const Stream& s = */ s,
-//         /* metal::Device& d = */ d,
-//         /* const array& a = */ a,
-//         /* const array& b = */ b,
-//         /* array& out = */ out,
-//         /* int M = */ M,
-//         /* int N = */ N,
-//         /* int K = */ K,
-//         /* int batch_size_out = */ batch_size_out,
-//         /* int lda = */ lda,
-//         /* int ldb = */ ldb,
-//         /* bool transpose_a = */ transpose_a,
-//         /* bool transpose_b = */ transpose_b,
-//         /* std::vector<array>& copies = */ copies);
-//   }
-
-//   ///////////////////////////////////////////////////////////////////////////
-//   // Regular kernel dispatch
-//   auto batch_strides = A_batch_stride;
-//   batch_strides.insert(
-//       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-
-//   return steel_matmul_regular(
-//       /* const Stream& s = */ s,
-//       /* metal::Device& d = */ d,
-//       /* const array& a = */ a,
-//       /* const array& b = */ b,
-//       /* array& out = */ out,
-//       /* int M = */ M,
-//       /* int N = */ N,
-//       /* int K = */ K,
-//       /* int batch_size_out = */ batch_size_out,
-//       /* int lda = */ lda,
-//       /* int ldb = */ ldb,
-//       /* int ldd = */ N,
-//       /* bool transpose_a = */ transpose_a,
-//       /* bool transpose_b = */ transpose_b,
-//       /* std::vector<array>& copies = */ copies,
-//       /* Shape batch_shape = */ std::move(batch_shape),
-//       /* Strides batch_strides = */ std::move(batch_strides),
-//       /* int64_t A_batch_stride = */ A_batch_stride.back(),
-//       /* int64_t B_batch_stride = */ B_batch_stride.back(),
-//       /* int64_t matrix_stride_out = */ matrix_stride_out);
-// }
-
 template <bool CHECK_AB = true>
 void gemv_axbpy(
     const Stream& s,
@@ -1108,46 +1008,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         /* float beta = */ beta_);
   }
 
-  using namespace mlx::steel;
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Split K specialization
-
-  int _tm = M / 16;
-  int _tn = N / 16;
-  int _tk = K / 16;
-
-  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
-    return steel_gemm_splitk_axpby(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& a = */ a,
-        /* const array& b = */ b,
-        /* const array& c = */ c,
-        /* array& out = */ out,
-        /* int M = */ M,
-        /* int N = */ N,
-        /* int K = */ K,
-        /* int batch_size_out = */ batch_size_out,
-        /* int lda = */ lda,
-        /* int ldb = */ ldb,
-        /* bool transpose_a = */ transpose_a,
-        /* bool transpose_b = */ transpose_b,
-        /* std::vector<array>& copies = */ copies,
-        /* float alpha = */ alpha_,
-        /* float beta = */ beta_);
-  }
-
   /////////////////////////////////////////////////////////////////////////////
   // Regular addmm dispatch
 
-  Strides batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-  batch_strides.insert(
-      batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
-
-  return steel_matmul_regular_axpby(
+  return steel_matmul_axpby(
       /* const Stream& s = */ s,
       /* metal::Device& d = */ d,
       /* const array& a = */ a,
@@ -1160,16 +1024,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* int batch_size_out = */ batch_size_out,
       /* int lda = */ lda,
       /* int ldb = */ ldb,
-      /* int ldd = */ ldd,
       /* bool transpose_a = */ transpose_a,
       /* bool transpose_b = */ transpose_b,
       /* std::vector<array>& copies = */ copies,
-      /* Shape batch_shape = */ std::move(batch_shape),
-      /* Strides batch_strides = */ std::move(batch_strides),
-      /* int64_t A_batch_stride = */ A_batch_stride.back(),
-      /* int64_t B_batch_stride = */ B_batch_stride.back(),
-      /* int64_t matrix_stride_out = */ int64_t(M) * ldd,
-      /* int64_t C_batch_stride = */ C_batch_stride.back(),
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride,
+      /* Strides C_batch_stride = */ C_batch_stride,
       /* float alpha = */ alpha_,
       /* float beta = */ beta_);
 }
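For context, the split-K branch deleted from `AddMM::eval_gpu` is not lost: `steel_matmul_axpby` applies the same tile-count heuristic internally, so this patch only removes the duplicated dispatch logic. Below is a minimal standalone sketch of that heuristic, assuming the 16-wide tiles and thresholds visible in the removed lines; the helper name `should_use_splitk` is hypothetical and for illustration only.

```cpp
#include <cstdio>

// Sketch of the split-K dispatch heuristic from the removed code.
// Tile counts assume 16x16 output tiles and 16-deep K tiles, matching
// the `M / 16`, `N / 16`, `K / 16` divisions in the deleted branch.
bool should_use_splitk(int M, int N, int K, int batch_size_out) {
  int tm = M / 16; // output tiles along M
  int tn = N / 16; // output tiles along N
  int tk = K / 16; // reduction tiles along K
  // Split-K pays off when the output is small (few threadgroups would
  // otherwise be launched) but the K reduction is deep.
  return batch_size_out == 1 && (tm * tn) <= 32 && tk >= 8;
}

int main() {
  // Small output, deep reduction: split-K kernel is selected.
  std::printf("%d\n", should_use_splitk(32, 32, 4096, 1)); // prints 1
  // Large square GEMM: the regular tiled kernel is kept.
  std::printf("%d\n", should_use_splitk(1024, 1024, 1024, 1)); // prints 0
  return 0;
}
```

The intuition behind the thresholds: with at most 32 output tiles, a regular GEMM launch cannot occupy the GPU, so splitting the K dimension across additional threadgroups (followed by a reduction of partial products) recovers parallelism.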