diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index 86fa5974f2..a111b5c22d 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -23,11 +23,22 @@ void free(Buffer buffer) { } Buffer CommonAllocator::malloc(size_t size, bool) { - return Buffer{std::malloc(size)}; + void* ptr = std::malloc(size + sizeof(size_t)); + if (ptr != nullptr) { + *static_cast(ptr) = size; + } + return Buffer{ptr}; } void CommonAllocator::free(Buffer buffer) { - std::free(buffer.raw_ptr()); + std::free(buffer.ptr()); +} + +size_t CommonAllocator::size(Buffer buffer) const { + if (buffer.ptr() == nullptr) { + return 0; + } + return *static_cast(buffer.ptr()); } Buffer malloc_or_wait(size_t size) { diff --git a/mlx/allocator.h b/mlx/allocator.h index 42dd7e180b..2809c7f7fb 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -41,6 +41,7 @@ class Allocator { public: virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual void free(Buffer buffer) = 0; + virtual size_t size(Buffer buffer) const = 0; Allocator() = default; Allocator(const Allocator& other) = delete; @@ -57,6 +58,7 @@ class CommonAllocator : public Allocator { public: virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; + virtual size_t size(Buffer buffer) const override; private: CommonAllocator() = default; diff --git a/mlx/array.h b/mlx/array.h index cce64e79ce..d6d6c921d5 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -324,6 +324,10 @@ class array { return array_desc_->data->buffer; } + size_t buffer_size() const { + return allocator::allocator().size(buffer()); + } + // Return a copy of the shared pointer // to the array::Data struct std::shared_ptr data_shared_ptr() const { diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 19f32bd6dd..3777f4bdd7 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -43,13 +43,15 @@ void set_binary_op_output_data( array& out, BinaryOpType bopt, bool donate_with_move = false) { + bool b_donatable = is_donatable(b, out); + bool a_donatable = is_donatable(a, out); switch (bopt) { case BinaryOpType::ScalarScalar: out.set_data( allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); break; case BinaryOpType::ScalarVector: - if (b.is_donatable() && b.itemsize() == out.itemsize()) { + if (b_donatable) { if (donate_with_move) { out.move_shared_buffer(b); } else { @@ -64,7 +66,7 @@ void set_binary_op_output_data( } break; case BinaryOpType::VectorScalar: - if (a.is_donatable() && a.itemsize() == out.itemsize()) { + if (a_donatable) { if (donate_with_move) { out.move_shared_buffer(a); } else { @@ -79,13 +81,13 @@ void set_binary_op_output_data( } break; case BinaryOpType::VectorVector: - if (a.is_donatable() && a.itemsize() == out.itemsize()) { + if (a_donatable) { if (donate_with_move) { out.move_shared_buffer(a); } else { out.copy_shared_buffer(a); } - } else if (b.is_donatable() && b.itemsize() == out.itemsize()) { + } else if (b_donatable) { if (donate_with_move) { out.move_shared_buffer(b); } else { @@ -100,16 +102,14 @@ void set_binary_op_output_data( } break; case BinaryOpType::General: - if (a.is_donatable() && a.flags().row_contiguous && - a.itemsize() == out.itemsize() && a.size() == out.size()) { + if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) { if (donate_with_move) { out.move_shared_buffer(a); } else { out.copy_shared_buffer(a); } } else if ( - b.is_donatable() && b.flags().row_contiguous && - b.itemsize() == out.itemsize() && b.size() == out.size()) { + b_donatable && b.flags().row_contiguous && b.size() == out.size()) { if (donate_with_move) { out.move_shared_buffer(b); } else { diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index 48f7d16a63..f3696c7a52 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -41,7 +41,7 @@ void set_ternary_op_output_data( TernaryOpType topt, bool donate_with_move = false) { auto maybe_donate = [&out, donate_with_move](const array& x) { - if (x.is_donatable() && x.itemsize() == out.itemsize()) { + if (is_donatable(x, out)) { if (donate_with_move) { out.move_shared_buffer(x); } else { diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index 28c4f0f4a5..0dbe374fea 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -12,7 +12,7 @@ namespace mlx::core { namespace { void set_unary_output_data(const array& in, array& out) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { + if (is_donatable(in, out)) { out.copy_shared_buffer(in); } else { auto size = in.data_size(); diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 0eedfceece..3c4a3d0a3f 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -155,4 +155,11 @@ inline auto check_contiguity( no_broadcast_data_size, is_row_contiguous, is_col_contiguous); } +inline bool is_donatable(const array& in, const array& out) { + constexpr size_t donation_extra = 16384; + + return in.is_donatable() && in.itemsize() == out.itemsize() && + in.buffer_size() <= out.nbytes() + donation_extra; +} + } // namespace mlx::core diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 2d2825d4ac..3c9f04dfe5 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -241,6 +241,10 @@ void MetalAllocator::free(Buffer buffer) { } } +size_t MetalAllocator::size(Buffer buffer) const { + return static_cast(buffer.ptr())->length(); +} + MetalAllocator& allocator() { // By creating the |allocator_| on heap, the destructor of MetalAllocator will // not be called on exit and all the buffers will be leaked. This is necessary diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 8e34f48d27..f2cf8feb89 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -56,6 +56,7 @@ class MetalAllocator : public allocator::Allocator { public: virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; + virtual size_t size(Buffer buffer) const override; size_t get_active_memory() { return active_memory_; }; diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_metal/allocator.cpp index e0d777f4ae..27e2ea06f8 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_metal/allocator.cpp @@ -10,7 +10,7 @@ Allocator& allocator() { } void* Buffer::raw_ptr() { - return ptr_; + return static_cast(ptr_) + 1; } } // namespace mlx::core::allocator