diff --git a/mlx/allocator.h b/mlx/allocator.h index 37835a3c7..cd6a78e74 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -34,6 +34,10 @@ class Allocator { virtual Buffer malloc(size_t size) = 0; virtual void free(Buffer buffer) = 0; virtual size_t size(Buffer buffer) const = 0; + virtual Buffer make_buffer(void* ptr, size_t size) { + return Buffer{nullptr}; + }; + virtual void release(Buffer buffer) {} Allocator() = default; Allocator(const Allocator& other) = delete; @@ -53,4 +57,17 @@ inline void free(Buffer buffer) { allocator().free(buffer); } +// Make a Buffer from a raw pointer of the given size without a copy. If a +// no-copy conversion is not possible then the returned buffer.ptr() will be +// nullptr. Any buffer created with this function must be released with +// release(buffer) +inline Buffer make_buffer(void* ptr, size_t size) { + return allocator().make_buffer(ptr, size); +}; + +// Release a buffer from the allocator made with make_buffer +inline void release(Buffer buffer) { + allocator().release(buffer); +} + } // namespace mlx::core::allocator diff --git a/mlx/array.cpp b/mlx/array.cpp index c43d6a104..7c8e8191e 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -82,6 +82,28 @@ array::array(std::initializer_list data, Dtype dtype) init(data.begin()); } +array::array( + void* data, + Shape shape, + Dtype dtype, + const std::function& deleter) + : array_desc_(std::make_shared(std::move(shape), dtype)) { + auto buffer = allocator::make_buffer(data, nbytes()); + if (buffer.ptr() == nullptr) { + set_data(allocator::malloc(nbytes())); + auto ptr = static_cast(data); + std::copy(ptr, ptr + nbytes(), this->data()); + deleter(data); + } else { + auto wrapped_deleter = [deleter](allocator::Buffer buffer) { + auto ptr = buffer.ptr(); + allocator::release(buffer); + return deleter(ptr); + }; + set_data(buffer, std::move(wrapped_deleter)); + } +} + /* Build an array from a shared buffer */ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter) : array_desc_(std::make_shared(std::move(shape), dtype)) { diff --git a/mlx/array.h b/mlx/array.h index 25b1f5766..645fa68b5 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -57,6 +57,16 @@ class array { Shape shape, Dtype dtype = TypeToDtype()); + /* Build an array from a raw pointer. The constructor will attempt to use the + * input data without a copy. The deleter will be called when the array no + * longer needs the underlying memory - after the array is destroyed in the + * no-copy case and after the copy otherwise. */ + explicit array( + void* data, + Shape shape, + Dtype dtype, + const std::function& deleter); + /* Build an array from a buffer */ explicit array( allocator::Buffer data, diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 2cecce358..2e7fea2af 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -203,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const { return static_cast(buffer.ptr())->length(); } +Buffer MetalAllocator::make_buffer(void* ptr, size_t size) { + auto buf = device_->newBuffer(ptr, size, resource_options, nullptr); + if (!buf) { + return Buffer{nullptr}; + } + std::unique_lock lk(mutex_); + residency_set_.insert(buf); + active_memory_ += buf->length(); + peak_memory_ = std::max(peak_memory_, active_memory_); + num_resources_++; + return Buffer{static_cast(buf)}; +} + +void MetalAllocator::release(Buffer buffer) { + auto buf = static_cast(buffer.ptr()); + if (buf == nullptr) { + return; + } + std::unique_lock lk(mutex_); + active_memory_ -= buf->length(); + num_resources_--; + lk.unlock(); + auto pool = metal::new_scoped_memory_pool(); + buf->release(); +} + MetalAllocator& allocator() { // By creating the |allocator_| on heap, the destructor of MetalAllocator // will not be called on exit and buffers in the cache will be leaked. This diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 216735ad3..5e177b3d3 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator { virtual Buffer malloc(size_t size) override; virtual void free(Buffer buffer) override; virtual size_t size(Buffer buffer) const override; + virtual Buffer make_buffer(void* ptr, size_t size) override; + virtual void release(Buffer buffer) override; + size_t get_active_memory() { return active_memory_; }; diff --git a/mlx/backend/no_gpu/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp index 320d1a267..76d008e61 100644 --- a/mlx/backend/no_gpu/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -25,6 +25,7 @@ class CommonAllocator : public Allocator { virtual Buffer malloc(size_t size) override; virtual void free(Buffer buffer) override; virtual size_t size(Buffer buffer) const override; + size_t get_active_memory() const { return active_memory_; }; diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index b31da9899..b624dff79 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include "doctest/doctest.h" @@ -608,3 +607,24 @@ TEST_CASE("test make empty array") { CHECK_EQ(a.size(), 0); CHECK_EQ(a.dtype(), bool_); } + +TEST_CASE("test make array from user buffer") { + int size = 4096; + std::vector buffer(size, 0); + + int count = 0; + auto deleter = [&count](void*) { count++; }; + + { + auto a = array(buffer.data(), Shape{size}, int32, deleter); + if (metal::is_available()) { + CHECK_EQ(buffer.data(), a.data()); + } + auto b = a + array(1); + eval(b); + auto expected = ones({4096}); + CHECK(array_equal(b, expected).item()); + } + // deleter should always get called + CHECK_EQ(count, 1); +}