From 2fdf9eb5355892283489b62819512957633d8ab2 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 26 Aug 2024 11:22:27 -0700
Subject: [PATCH] Fix ternary for large arrays (#1359)

* fix ternary for large arrays

* fix
---
 mlx/backend/common/ternary.h  | 28 ++++++++++++++++++++++++++++
 mlx/backend/metal/ternary.cpp | 14 +++++++++-----
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h
index 52d202df72..48f7d16a63 100644
--- a/mlx/backend/common/ternary.h
+++ b/mlx/backend/common/ternary.h
@@ -12,6 +12,7 @@ namespace {
 // TODO: Add support for more combinations of input types.
 enum class TernaryOpType {
   ScalarScalarScalar,
+  VectorVectorVector,
   General,
 };
 
@@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
   TernaryOpType topt;
   if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
     topt = TernaryOpType::ScalarScalarScalar;
+  } else if (
+      (a.flags().row_contiguous && b.flags().row_contiguous &&
+       c.flags().row_contiguous) ||
+      (a.flags().col_contiguous && b.flags().col_contiguous &&
+       c.flags().col_contiguous)) {
+    topt = TernaryOpType::VectorVectorVector;
   } else {
     topt = TernaryOpType::General;
   }
@@ -33,11 +40,32 @@ void set_ternary_op_output_data(
     array& out,
     TernaryOpType topt,
     bool donate_with_move = false) {
+  auto maybe_donate = [&out, donate_with_move](const array& x) {
+    if (x.is_donatable() && x.itemsize() == out.itemsize()) {
+      if (donate_with_move) {
+        out.move_shared_buffer(x);
+      } else {
+        out.copy_shared_buffer(x);
+      }
+      return true;
+    }
+    return false;
+  };
+
   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
       break;
+    case TernaryOpType::VectorVectorVector:
+      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
+        out.set_data(
+            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
+            b.data_size(),
+            b.strides(),
+            b.flags());
+      }
+      break;
     case TernaryOpType::General:
       out.set_data(allocator::malloc_or_wait(out.nbytes()));
       break;
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index 8f4371939a..c214db267d 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -56,9 +56,12 @@ void ternary_op_gpu_inplace(
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
-  compute_encoder.set_input_array(a, 0);
-  compute_encoder.set_input_array(b, 1);
-  compute_encoder.set_input_array(c, 2);
+  bool donate_a = a.data_shared_ptr() == nullptr;
+  bool donate_b = b.data_shared_ptr() == nullptr;
+  bool donate_c = c.data_shared_ptr() == nullptr;
+  compute_encoder.set_input_array(donate_a ? out : a, 0);
+  compute_encoder.set_input_array(donate_b ? out : b, 1);
+  compute_encoder.set_input_array(donate_c ? out : c, 2);
   compute_encoder.set_output_array(out, 3);
 
   if (topt == TernaryOpType::General) {
@@ -91,9 +94,10 @@
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
-    // Launch a 1D grid of threads
+    // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;