From b0cc71ae710e81c56091d9c90fc62e79268846f1 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 2 Oct 2025 12:21:27 -0700
Subject: [PATCH] Faster triu, tril, where with scalar (#2644)

---
 mlx/backend/common/ternary.h            | 12 ++++++++
 mlx/backend/cuda/ternary.cu             | 37 +++++++++++++------------
 mlx/backend/metal/jit_kernels.cpp       | 27 ++++++++++++++----
 mlx/backend/metal/kernels/ternary.h     | 30 ++++++++++++++++----
 mlx/backend/metal/kernels/ternary.metal | 12 ++++++--
 mlx/backend/metal/ternary.cpp           | 17 ++++++++----
 mlx/ops.cpp                             |  8 +++---
 7 files changed, 101 insertions(+), 42 deletions(-)

diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h
index d98dd8d68..233708ec3 100644
--- a/mlx/backend/common/ternary.h
+++ b/mlx/backend/common/ternary.h
@@ -11,6 +11,8 @@ namespace mlx::core {
 enum class TernaryOpType {
   ScalarScalarScalar,
   VectorVectorVector,
+  VectorVectorScalar,
+  VectorScalarVector,
   General,
 };
 
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
       (a.flags().col_contiguous && b.flags().col_contiguous &&
        c.flags().col_contiguous)) {
     topt = TernaryOpType::VectorVectorVector;
+  } else if (
+      b.data_size() == 1 && a.flags().row_contiguous &&
+      c.flags().row_contiguous) {
+    topt = TernaryOpType::VectorScalarVector;
+  } else if (
+      c.data_size() == 1 && a.flags().row_contiguous &&
+      b.flags().row_contiguous) {
+    topt = TernaryOpType::VectorVectorScalar;
   } else {
     topt = TernaryOpType::General;
   }
@@ -59,6 +69,8 @@ inline void set_ternary_op_output_data(
             b.flags());
       }
       break;
+    case TernaryOpType::VectorVectorScalar:
+    case TernaryOpType::VectorScalarVector:
     case TernaryOpType::General:
       // Try to donate an input which is row_contiguous
       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 67937fc8e..84ae996aa 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
     using DType = cuda_type_t<T>;
 
     auto topt = get_ternary_op_type(a, b, c);
-    if (topt == TernaryOpType::General) {
+    if (topt == TernaryOpType::VectorVectorVector ||
+        topt == TernaryOpType::ScalarScalarScalar) {
+      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
+        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+        constexpr int N_READS = 16 / sizeof(DType);
+        auto [num_blocks, block_dims] = get_launch_args(
+            out.data_size(), out.shape(), out.strides(), large(), N_READS);
+        encoder.add_kernel_node(
+            cu::ternary_v<Op, DType, IdxT, N_READS>,
+            num_blocks,
+            block_dims,
+            0,
+            a.data<bool>(),
+            b.data<DType>(),
+            c.data<DType>(),
+            out.data<DType>(),
+            out.data_size());
+      });
+    } else {
       dispatch_bool(
           a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
               c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
                 ndim);
           }
         });
-    } else {
-      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
-        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-        constexpr int N_READS = 16 / sizeof(DType);
-        auto [num_blocks, block_dims] = get_launch_args(
-            out.data_size(), out.shape(), out.strides(), large(), N_READS);
-        encoder.add_kernel_node(
-            cu::ternary_v<Op, DType, IdxT, N_READS>,
-            num_blocks,
-            block_dims,
-            0,
-            a.data<bool>(),
-            b.data<DType>(),
-            c.data<DType>(),
-            out.data<DType>(),
-            out.data_size());
-      });
     }
   });
 }
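Note on the dispatch above: the two new TernaryOpType values cover where(cond, vec, scalar) and where(cond, scalar, vec). When one value operand has data_size() == 1 and the remaining operands are row-contiguous, the contiguous kernels can be reused with the scalar read from a single location, instead of falling through to the strided General path. A minimal CPU-side sketch of the loop the VectorScalarVector case effectively runs (illustration only, not MLX code; the helper name is made up):

    #include <cstddef>

    // where(cond, scalar_b, vec_c): b has data_size() == 1, so every
    // element reads b[0]; no broadcast copy of b is ever materialized.
    template <typename T>
    void where_sv(const bool* a, const T* b, const T* c, T* d, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) {
        d[i] = a[i] ? b[0] : c[i];
      }
    }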
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index cc8addf9f..e70420cf8 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto t_str = get_type_string(type);
   std::string kernel_source = metal::utils();
   concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
-  const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
-      {"v2", "ternary_v2"},
+  const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
       {"g1large", "ternary_g_nd1"},
       {"g2large", "ternary_g_nd2"},
       {"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
   for (auto& [name, func] : kernel_types) {
     kernel_source +=
         get_template_definition(name + "_" + lib_name, func, t_str, op);
   }
+
+  kernel_source += get_template_definition(
+      "v2_" + lib_name, "ternary_v2", t_str, op, false, false);
+  kernel_source += get_template_definition(
+      "sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
+  kernel_source += get_template_definition(
+      "vs2_" + lib_name, "ternary_v2", t_str, op, false, true);
+
   if (get_work_per_thread(type) > 1) {
-    kernel_source +=
-        get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
+    kernel_source += get_template_definition(
+        "vn_" + lib_name, "ternary_v", t_str, op, false, false);
+    kernel_source += get_template_definition(
+        "svn_" + lib_name, "ternary_v", t_str, op, true, false);
+    kernel_source += get_template_definition(
+        "vsn_" + lib_name, "ternary_v", t_str, op, false, true);
   }
-  kernel_source +=
-      get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
+  kernel_source += get_template_definition(
+      "v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
+  kernel_source += get_template_definition(
+      "sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
+  kernel_source += get_template_definition(
+      "vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
   kernel_source += get_template_definition(
       "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
   kernel_source += get_template_definition(
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
index 570f5e4d6..705b73e25 100644
--- a/mlx/backend/metal/kernels/ternary.h
+++ b/mlx/backend/metal/kernels/ternary.h
@@ -1,6 +1,11 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v(
     device const bool* a,
     device const T* b,
@@ -11,16 +16,25 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
   index *= N;
   if (N > 1 && index + N > size) {
     for (int i = 0; index + i < size; ++i) {
-      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+      auto bidx = BSCALAR ? 0 : index + i;
+      auto cidx = CSCALAR ? 0 : index + i;
+      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+      auto bidx = BSCALAR ? 0 : index + i;
+      auto cidx = CSCALAR ? 0 : index + i;
+      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
     }
   }
 }
 
-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v2(
     device const bool* a,
     device const T* b,
@@ -32,11 +46,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
   if (N > 1 && offset + N > size) {
     for (int i = 0; offset + i < size; ++i) {
-      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+      auto bidx = BSCALAR ? 0 : offset + i;
+      auto cidx = CSCALAR ? 0 : offset + i;
+      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+      auto bidx = BSCALAR ? 0 : offset + i;
+      auto cidx = CSCALAR ? 0 : offset + i;
+      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
     }
   }
 }
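The BSCALAR/CSCALAR template parameters above are compile-time constants, so an expression like BSCALAR ? 0 : index + i folds to a constant index and the scalar load can be hoisted out of the loop; each kernel name ("v", "sv", "vs", and their n/2 variants) is a separate specialization. A host-side sketch of the same trick (standard C++ only; names are illustrative, not MLX code):

    // For the "sv" instantiation (BSCALAR = true) bidx is always 0,
    // so the compiler can keep b[0] in a register across iterations.
    template <typename T, bool BSCALAR, bool CSCALAR>
    void ternary_contig(const bool* a, const T* b, const T* c, T* d, long n) {
      for (long i = 0; i < n; ++i) {
        auto bidx = BSCALAR ? 0 : i;
        auto cidx = CSCALAR ? 0 : i;
        d[i] = a[i] ? b[bidx] : c[cidx];
      }
    }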
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
index 6da258b6f..ac1cd4160 100644
--- a/mlx/backend/metal/kernels/ternary.metal
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -9,8 +9,12 @@
 #include "mlx/backend/metal/kernels/ternary.h"
 
 #define instantiate_ternary_base(op, tname, type)                            \
-  instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1)                \
-  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)                 \
+  instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1)  \
+  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false)   \
+  instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1)  \
+  instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true)   \
+  instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1)  \
+  instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false)   \
   instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int)         \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int)         \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int)         \
@@ -21,7 +25,9 @@
   instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4)
 
 #define instantiate_ternary_all(op, tname, type)                             \
-  instantiate_kernel("vn_" #op #tname, ternary_v, type, op)                  \
+  instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false)    \
+  instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true)    \
+  instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false)    \
   instantiate_ternary_base(op, tname, type)
 
 #define instantiate_ternary_types(op)                                        \
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index b2b9e3337..252815aae 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -58,12 +58,19 @@ void ternary_op_gpu_inplace(
     if (large) {
       kernel_name += "large";
     }
-  } else if (large) {
-    kernel_name = "v2";
-  } else if (work_per_thread > 1) {
-    kernel_name = "vn";
   } else {
-    kernel_name = "v";
+    if (topt == TernaryOpType::VectorScalarVector) {
+      kernel_name = "sv";
+    } else if (topt == TernaryOpType::VectorVectorScalar) {
+      kernel_name = "vs";
+    } else {
+      kernel_name = "v";
+    }
+    if (large) {
+      kernel_name += "2";
+    } else if (work_per_thread > 1) {
+      kernel_name += "n";
+    }
   }
 
   concatenate(kernel_name, "_", op, type_to_name(b));
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index c1d16ba1f..73d3a1f23 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -348,16 +348,16 @@ array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
     throw std::invalid_argument("[tril] array must be at least 2-D");
   }
-  auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
-  return where(mask, x, zeros_like(x, s), s);
+  auto mask = tri(x.shape(-2), x.shape(-1), k, bool_, s);
+  return where(mask, x, array(0, x.dtype()), s);
 }
 
 array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
     throw std::invalid_argument("[triu] array must be at least 2-D");
   }
-  auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
-  return where(mask, zeros_like(x, s), x, s);
+  auto mask = tri(x.shape(-2), x.shape(-1), k - 1, bool_, s);
+  return where(mask, array(0, x.dtype()), x, s);
 }
 
 array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
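The ops.cpp change is what lets triu and tril hit the new kernels: the mask is built as bool_ rather than x.dtype(), and the zeroed branch is the scalar array(0, x.dtype()) instead of a fully materialized zeros_like(x), so the where() call dispatches to the vs/sv fast path. A short usage sketch against the public C++ API (shapes and values are illustrative):

    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      auto x = mx::ones({512, 512});
      // Both now build a bool mask and zero with a scalar; no
      // zeros_like(x) temporary is allocated.
      auto lower = mx::tril(x);
      auto upper = mx::triu(x, 1);
      // A where() with one scalar branch takes the same fast path.
      auto clipped =
          mx::where(mx::greater(x, mx::array(0.5f)), x, mx::array(0.0f));
      mx::eval({lower, upper, clipped});
      return 0;
    }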