From 8993382aaa3ea9501c3135ec7099cc3f2ce760eb Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 26 Jan 2024 16:30:33 -0800
Subject: [PATCH] Buffer Donation (#519)

* buffer donation
* fix to move shared pointer
* format
* gpu in place for copy and binary
* revert ops test
* cpu in place
* a little cleanup
* remove useless bench
---
 mlx/array.cpp                         |   8 ++
 mlx/array.h                           |  13 +++
 mlx/backend/accelerate/primitives.cpp | 152 ++++++--------------------
 mlx/backend/common/binary.h           |  79 ++++++++++---
 mlx/backend/common/copy.cpp           |  15 ++-
 mlx/backend/common/primitives.cpp     |   6 -
 mlx/backend/common/unary.h            |  19 +++-
 mlx/backend/metal/copy.cpp            |  17 ++-
 mlx/backend/metal/metal.cpp           |  22 ++--
 mlx/backend/metal/primitives.cpp      |  39 ++++---
 mlx/transforms.cpp                    |   1 +
 python/tests/mlx_tests.py             |   6 +-
 12 files changed, 199 insertions(+), 178 deletions(-)

diff --git a/mlx/array.cpp b/mlx/array.cpp
index e19ed704d..2362b586f 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -155,6 +155,14 @@ void array::copy_shared_buffer(const array& other) {
   copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
 }
 
+void array::move_shared_buffer(array other) {
+  array_desc_->data = std::move(other.array_desc_->data);
+  array_desc_->strides = other.strides();
+  array_desc_->flags = other.flags();
+  array_desc_->data_size = other.data_size();
+  array_desc_->data_ptr = other.array_desc_->data_ptr;
+}
+
 array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
     : shape(shape), dtype(dtype) {
   std::tie(size, strides) = cum_prod(shape);
diff --git a/mlx/array.h b/mlx/array.h
index 8f345bd3f..3b70153e3 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -240,6 +240,11 @@ class array {
     return array_desc_->inputs;
   }
 
+  /** True indicates the arrays buffer is safe to reuse */
+  bool is_donatable() const {
+    return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+  }
+
   /** The array's siblings. */
   const std::vector<array>& siblings() const {
     return array_desc_->siblings;
@@ -282,6 +287,12 @@ class array {
     return array_desc_->data->buffer;
   };
 
+  // Return a copy of the shared pointer
+  // to the array::Data struct
+  std::shared_ptr<Data> data_shared_ptr() const {
+    return array_desc_->data;
+  }
+
   // Return a raw pointer to the arrays data
   template <typename T>
   T* data() {
     return static_cast<T*>(array_desc_->data_ptr);
@@ -322,6 +333,8 @@ class array {
 
   void copy_shared_buffer(const array& other);
 
+  void move_shared_buffer(array other);
+
   void overwrite_descriptor(const array& other) {
     array_desc_ = other.array_desc_;
   }
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 777bb76f4..be8741a54 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -71,21 +71,11 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (in.dtype() == float32 && in.flags().contiguous) {
-    auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
-    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
+    set_unary_output_data(in, out);
+    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else if (in.dtype() == int32 && in.flags().contiguous) {
-    auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
-    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
+    set_unary_output_data(in, out);
+    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
   } else if (is_unsigned(in.dtype())) {
     // No-op for unsigned types
     out.copy_shared_buffer(in);
@@ -138,12 +128,8 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvacosf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -154,12 +140,8 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvacoshf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -170,12 +152,8 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvasinf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -186,12 +164,8 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvasinhf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -202,12 +176,8 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvatanf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -218,12 +188,8 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvatanhf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -235,30 +201,23 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
 
   if (in.flags().contiguous) {
-    auto allocfn = [&in, &out]() {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    };
     // Use accelerate functions if possible
     if (in.dtype() == float32 && out.dtype() == uint32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vfixu32(
           in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
       return;
     } else if (in.dtype() == float32 && out.dtype() == int32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
       return;
     } else if (in.dtype() == uint32 && out.dtype() == float32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vfltu32(
           in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
       return;
     } else if (in.dtype() == int32 && out.dtype() == float32) {
-      allocfn();
+      set_unary_output_data(in, out);
       vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
       return;
     }
@@ -270,12 +229,8 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvcosf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -286,12 +241,8 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvcoshf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -379,12 +330,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
   } else if (is_floating_point(out.dtype())) {
     unary_fp(in, out, [](auto x) { return std::exp(x); });
@@ -411,12 +358,8 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     switch (base_) {
       case Base::e:
         vvlogf(
@@ -440,12 +383,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvlog1pf(
         out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
   } else if (is_floating_point(out.dtype())) {
@@ -527,13 +466,8 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (in.dtype() == float32 && in.flags().contiguous) {
-    auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
-    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
+    set_unary_output_data(in, out);
+    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else {
     unary(in, out, [](auto x) { return -x; });
   }
@@ -546,7 +480,13 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (out.dtype() == float32 && a.flags().row_contiguous &&
       b.flags().row_contiguous) {
     int size = a.size();
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+      out.copy_shared_buffer(a);
+    } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+      out.copy_shared_buffer(b);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
     vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -588,12 +528,8 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvsinf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -604,12 +540,8 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvsinhf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -620,12 +552,8 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (in.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     auto size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
   } else {
     unary(in, out, [](auto x) { return x * x; });
@@ -636,12 +564,8 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (in.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     if (recip_) {
       vvrsqrtf(out.data<float>(), in.data<float>(), &size);
     } else {
@@ -696,12 +620,8 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvtanf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
@@ -712,12 +632,8 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (out.dtype() == float32 && in.flags().contiguous) {
+    set_unary_output_data(in, out);
     int size = in.data_size();
-    out.set_data(
-        allocator::malloc_or_wait(size * out.itemsize()),
-        size,
-        in.strides(),
-        in.flags());
     vvtanhf(out.data<float>(), in.data<float>(), &size);
   } else {
     eval(inputs, out);
diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h
index c37488559..fb397b669 100644
--- a/mlx/backend/common/binary.h
+++ b/mlx/backend/common/binary.h
@@ -1,7 +1,6 @@
 // Copyright © 2023 Apple Inc.
 
 #pragma once
-
 #include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/utils.h"
@@ -40,29 +39,83 @@ void set_binary_op_output_data(
     const array& a,
     const array& b,
     array& out,
-    BinaryOpType bopt) {
+    BinaryOpType bopt,
+    bool donate_with_move = false) {
   switch (bopt) {
     case ScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case ScalarVector:
-      out.set_data(
-          allocator::malloc_or_wait(b.data_size() * out.itemsize()),
-          b.data_size(),
-          b.strides(),
-          b.flags());
+      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+        if (donate_with_move) {
+          out.move_shared_buffer(b);
+        } else {
+          out.copy_shared_buffer(b);
+        }
+      } else {
+        out.set_data(
+            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
+            b.data_size(),
+            b.strides(),
+            b.flags());
+      }
       break;
     case VectorScalar:
+      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+        if (donate_with_move) {
+          out.move_shared_buffer(a);
+        } else {
+          out.copy_shared_buffer(a);
+        }
+      } else {
+        out.set_data(
+            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            a.data_size(),
+            a.strides(),
+            a.flags());
+      }
+      break;
     case VectorVector:
-      out.set_data(
-          allocator::malloc_or_wait(a.data_size() * out.itemsize()),
-          a.data_size(),
-          a.strides(),
-          a.flags());
+      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+        if (donate_with_move) {
+          out.move_shared_buffer(a);
+        } else {
+          out.copy_shared_buffer(a);
+        }
+      } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+        if (donate_with_move) {
+          out.move_shared_buffer(b);
+        } else {
+          out.copy_shared_buffer(b);
+        }
+      } else {
+        out.set_data(
+            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            a.data_size(),
+            a.strides(),
+            a.flags());
+      }
      break;
     case General:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      if (a.is_donatable() && a.flags().row_contiguous &&
+          a.itemsize() == out.itemsize() && a.size() == out.size()) {
+        if (donate_with_move) {
+          out.move_shared_buffer(a);
+        } else {
+          out.copy_shared_buffer(a);
+        }
+      } else if (
+          b.is_donatable() && b.flags().row_contiguous &&
+          b.itemsize() == out.itemsize() && b.size() == out.size()) {
+        if (donate_with_move) {
+          out.move_shared_buffer(b);
+        } else {
+          out.copy_shared_buffer(b);
+        }
+      } else {
+        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      }
       break;
   }
 }
diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp
index c8af4b548..cc37e767a 100644
--- a/mlx/backend/common/copy.cpp
+++ b/mlx/backend/common/copy.cpp
@@ -289,11 +289,16 @@ void copy(const array& src, array& dst, CopyType ctype) {
   // Allocate the output
   switch (ctype) {
     case CopyType::Vector:
-      dst.set_data(
-          allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
-          src.data_size(),
-          src.strides(),
-          src.flags());
+      if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
+        dst.copy_shared_buffer(src);
+      } else {
+        auto size = src.data_size();
+        dst.set_data(
+            allocator::malloc_or_wait(size * dst.itemsize()),
+            size,
+            src.strides(),
+            src.flags());
+      }
       break;
     case CopyType::Scalar:
     case CopyType::General:
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 6aaf59b45..ce36a8d5d 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -237,17 +237,14 @@ void Erf::eval(const std::vector<array>& inputs, array& out) {
   const auto& in = inputs[0];
   switch (out.dtype()) {
     case float32:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float>(in, out, [](auto x) { return std::erf(x); });
       break;
     case float16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float16_t>(in, out, [](auto x) {
         return static_cast<float16_t>(std::erf(static_cast<float>(x)));
       });
       break;
     case bfloat16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<bfloat16_t>(in, out, [](auto x) {
         return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
       });
@@ -264,17 +261,14 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
   const auto& in = inputs[0];
   switch (out.dtype()) {
     case float32:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float>(in, out, [](auto x) { return erfinv(x); });
       break;
     case float16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<float16_t>(in, out, [](auto x) {
         return static_cast<float16_t>(erfinv(static_cast<float>(x)));
       });
       break;
     case bfloat16:
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
       unary_op<bfloat16_t>(in, out, [](auto x) {
         return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
       });
diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h
index 918bae998..7fdcbeb77 100644
--- a/mlx/backend/common/unary.h
+++ b/mlx/backend/common/unary.h
@@ -64,15 +64,24 @@ struct RoundOp {
   }
 };
 
+void set_unary_output_data(const array& in, array& out) {
+  if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    out.copy_shared_buffer(in);
+  } else {
+    auto size = in.data_size();
+    out.set_data(
+        allocator::malloc_or_wait(size * out.itemsize()),
+        size,
+        in.strides(),
+        in.flags());
+  }
+}
+
 template <typename T, typename Op>
 void unary_op(const array& a, array& out, Op op) {
   const T* a_ptr = a.data<T>();
   if (a.flags().contiguous) {
-    out.set_data(
-        allocator::malloc_or_wait(a.data_size() * out.itemsize()),
-        a.data_size(),
-        a.strides(),
-        a.flags());
+    set_unary_output_data(a, out);
     T* dst = out.data<T>();
     for (size_t i = 0; i < a.data_size(); ++i) {
       dst[i] = op(a_ptr[i]);
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index 67e839481..e023165c0 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -12,11 +12,15 @@ namespace mlx::core {
 
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
-    out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-        in.data_size(),
-        in.strides(),
-        in.flags());
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
@@ -67,7 +71,8 @@ void copy_gpu_inplace(
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
+  bool donate_in = in.data_shared_ptr() == nullptr;
+  set_array_buffer(compute_encoder, donate_in ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
 
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index c120744b1..372b3d231 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -64,14 +64,23 @@ std::function<void()> make_task(
     auto command_buffer = increment_command_buffer(s);
     auto outputs = arr.outputs();
     arr.primitive().eval_gpu(arr.inputs(), outputs);
+    std::vector<std::shared_ptr<array::Data>> buffers;
+    for (auto& in : arr.inputs()) {
+      buffers.push_back(in.data_shared_ptr());
+    }
+    for (auto& s : arr.siblings()) {
+      buffers.push_back(s.data_shared_ptr());
+    }
+    if (!arr.is_tracer()) {
+      arr.detach();
+    }
+
     if (p) {
       metal::device(s.device).end_encoding(s.index);
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
-          [s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable {
-            if (!arr.is_tracer()) {
-              arr.detach();
-            }
+          [s, buffers = std::move(buffers), p = std::move(p)](
+              MTL::CommandBuffer* cbuf) {
             p->set_value();
             scheduler::notify_task_completion(s);
             check_error(cbuf);
@@ -79,10 +88,7 @@ std::function<void()> make_task(
       metal::device(s.device).commit_command_buffer(s.index);
     } else {
       command_buffer->addCompletedHandler(
-          [s, arr](MTL::CommandBuffer* cbuf) mutable {
-            if (!arr.is_tracer()) {
-              arr.detach();
-            }
+          [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
             check_error(cbuf);
           });
     }
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index e41368f6e..a956a6a03 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -27,8 +27,8 @@ void binary_op(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, outputs[0], bopt);
-  set_binary_op_output_data(a, b, outputs[1], bopt);
+  set_binary_op_output_data(a, b, outputs[0], bopt, true);
+  set_binary_op_output_data(a, b, outputs[1], bopt, true);
 
   auto& out = outputs[0];
   if (out.size() == 0) {
@@ -69,8 +69,14 @@ void binary_op(
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
+  // - If a is donated it goes to the first output
+  // - If b is donated it goes to the first output if a was not donated
+  //   otherwise it goes to the second output
+  bool donate_a = a.data_shared_ptr() == nullptr;
+  bool donate_b = b.data_shared_ptr() == nullptr;
+  set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
+  set_array_buffer(
+      compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
   set_array_buffer(compute_encoder, outputs[0], 2);
   set_array_buffer(compute_encoder, outputs[1], 3);
 
@@ -122,7 +128,7 @@ void binary_op(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
+  set_binary_op_output_data(a, b, out, bopt, true);
   if (out.size() == 0) {
     return;
   }
@@ -161,8 +167,10 @@ void binary_op(
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
+  bool donate_a = a.data_shared_ptr() == nullptr;
+  bool donate_b = b.data_shared_ptr() == nullptr;
+  set_array_buffer(compute_encoder, donate_a ? out : a, 0);
+  set_array_buffer(compute_encoder, donate_b ? out : b, 1);
   set_array_buffer(compute_encoder, out, 2);
 
   if (bopt == General) {
@@ -212,11 +220,15 @@ void unary_op(
   auto& in = inputs[0];
   bool contig = in.flags().contiguous;
   if (contig) {
-    out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-        in.data_size(),
-        in.strides(),
-        in.flags());
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
@@ -240,7 +252,8 @@ void unary_op(
   auto compute_encoder = d.get_command_encoder(s.index);
 
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
+  set_array_buffer(
+      compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
   if (!contig) {
     compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 86cf69989..88d5038f2 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -98,6 +98,7 @@ void eval(const std::vector<array>& outputs) {
     auto stream = arr.primitive().stream();
     std::vector<std::shared_future<void>> arr_deps;
     for (auto& in : arr.inputs()) {
+      // TODO that's a bug
      if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
        arr_deps.push_back(it->second);
      }
diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py
index 3d1e36269..477e81c8c 100644
--- a/python/tests/mlx_tests.py
+++ b/python/tests/mlx_tests.py
@@ -65,11 +65,9 @@ class MLXTestCase(unittest.TestCase):
         )
         if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
             np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
+            return
         elif not isinstance(mx_res, mx.array):
             mx_res = mx.array(mx_res)
-            self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
         elif not isinstance(expected, mx.array):
             expected = mx.array(expected)
-            self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
-        else:
-            self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
+        self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
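Illustrative note, not part of the patch above: a minimal sketch of when the donation paths added here can fire, using the project's Python bindings (mx.ones, mx.exp, mx.eval). A buffer is only donated when is_donatable() holds, i.e. nothing else references the input array or its data and the element sizes match, so dropping the last Python reference to the input (for example by rebinding the name) is what makes it eligible.

    import mlx.core as mx

    a = mx.ones((1024, 1024))
    b = mx.exp(a)   # `a` is still bound to a Python name, so its buffer is not donated
    mx.eval(b)

    c = mx.ones((1024, 1024))
    c = mx.exp(c)   # the old `c` is now only held by the graph; exp may reuse its buffer in place
    mx.eval(c)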