From d5964a271007a53937e63908bfab853e0253178c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Mar 2024 19:51:58 -0800 Subject: [PATCH] bindings for memory info (#761) * bindings for memory info * update api * keep cache low if requested * fix default * nit in ops error --- docs/src/index.rst | 1 + docs/src/python/metal.rst | 14 ++++++ docs/src/python/ops.rst | 2 + mlx/backend/metal/allocator.cpp | 71 +++++++++++++++++++--------- mlx/backend/metal/allocator.h | 21 ++++++++- mlx/backend/metal/metal.h | 47 +++++++++++++++++- mlx/backend/no_metal/metal.cpp | 19 ++++++-- mlx/ops.cpp | 2 +- python/src/metal.cpp | 84 ++++++++++++++++++++++++++++++--- python/tests/test_metal.py | 45 ++++++++++++++++++ tests/metal_tests.cpp | 63 +++++++++++++------------ 11 files changed, 300 insertions(+), 69 deletions(-) create mode 100644 docs/src/python/metal.rst create mode 100644 python/tests/test_metal.py diff --git a/docs/src/index.rst b/docs/src/index.rst index 50dfe9083..e54a55b7a 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -64,6 +64,7 @@ are the CPU and GPU. python/transforms python/fft python/linalg + python/metal python/nn python/optimizers python/tree_utils diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst new file mode 100644 index 000000000..c11deb4fa --- /dev/null +++ b/docs/src/python/metal.rst @@ -0,0 +1,14 @@ +Metal +===== + +.. currentmodule:: mlx.core.metal + +.. autosummary:: + :toctree: _autosummary + + is_available + get_active_memory + get_peak_memory + get_cache_memory + set_memory_limit + set_cache_limit diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 2cc2b6d6b..a7809ead2 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -57,6 +57,7 @@ Operations greater_equal identity inner + isclose isnan isposinf isneginf @@ -121,6 +122,7 @@ Operations tan tanh tensordot + tile transpose tri tril diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index cab27b715..d8e4538ae 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" @@ -23,16 +23,6 @@ void* Buffer::raw_ptr() { namespace metal { -static bool cache_enabled_ = true; - -bool cache_enabled() { - return cache_enabled_; -} - -void set_cache_enabled(bool enabled) { - cache_enabled_ = enabled; -} - namespace { BufferCache::BufferCache(MTL::Device* device) @@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), buffer_cache_(device_), - peak_allocated_size_(0), block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()), - gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {} + gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()), + max_pool_size_(block_limit_) {} + +size_t MetalAllocator::set_cache_limit(size_t limit) { + std::swap(limit, max_pool_size_); + return limit; +}; + +size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) { + std::swap(limit, block_limit_); + relaxed_ = relaxed; + gc_limit_ = std::min( + block_limit_, + static_cast(0.95 * device_->recommendedMaxWorkingSetSize())); + return limit; +}; Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Metal doesn't like empty buffers @@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Try the cache MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); - + size_t pool_size = get_cache_memory(); if (!buf) { + size_t mem_required = get_active_memory() + pool_size + size; + // If there is too much memory pressure, fail (likely causes a wait). - if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) { + if (!(allow_swap && relaxed_) && mem_required >= block_limit_) { return Buffer{nullptr}; } @@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // If we have a lot of memory pressure, check if we can reclaim some memory // from the cache - if (device_->currentAllocatedSize() + size >= gc_limit_) { - size_t min_bytes_to_free = - size + device_->currentAllocatedSize() - gc_limit_; - buffer_cache_.release_cached_buffers(min_bytes_to_free); + if (mem_required >= gc_limit_) { + buffer_cache_.release_cached_buffers(mem_required - gc_limit_); } // Allocate new buffer if needed @@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { buf = device_->newBuffer(size, res_opt); } - peak_allocated_size_ = - std::max(peak_allocated_size_, device_->currentAllocatedSize()); + // Maintain the cache below the requested limit + if (pool_size >= max_pool_size_) { + auto thread_pool = metal::new_scoped_memory_pool(); + buffer_cache_.release_cached_buffers(pool_size - max_pool_size_); + } + + active_memory_ += buf->length(); + peak_memory_ = std::max(peak_memory_, active_memory_); return Buffer{static_cast(buf)}; } void MetalAllocator::free(Buffer buffer) { auto buf = static_cast(buffer.ptr()); - if (cache_enabled()) { + active_memory_ -= buf->length(); + if (max_pool_size_ > 0) { buffer_cache_.recycle_to_cache(buf); } else { buf->release(); @@ -218,6 +229,22 @@ MetalAllocator& allocator() { return allocator_; } +size_t set_cache_limit(size_t limit) { + return allocator().set_cache_limit(limit); +} +size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { + return allocator().set_memory_limit(limit, relaxed); +} +size_t get_active_memory() { + return allocator().get_active_memory(); +} +size_t get_peak_memory() { + return allocator().get_peak_memory(); +} +size_t get_cache_memory() { + return allocator().get_cache_memory(); +} + } // namespace metal } // namespace mlx::core diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 45a58bc13..a31cb5fb4 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -24,6 +24,9 @@ class BufferCache { MTL::Buffer* reuse_from_cache(size_t size); void recycle_to_cache(MTL::Buffer* buf); void release_cached_buffers(size_t min_bytes_to_free); + size_t pool_size() { + return pool_size_; + } private: struct BufferHolder { @@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator { public: virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; + size_t get_active_memory() { + return active_memory_; + }; + size_t get_peak_memory() { + return peak_memory_; + }; + size_t get_cache_memory() { + return buffer_cache_.pool_size(); + }; + size_t set_cache_limit(size_t limit); + size_t set_memory_limit(size_t limit, bool relaxed); private: MTL::Device* device_; @@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator { BufferCache buffer_cache_; // Allocation stats - size_t peak_allocated_size_; size_t block_limit_; size_t gc_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + size_t max_pool_size_; + bool relaxed_{true}; }; MetalAllocator& allocator(); diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 11ac510fd..360481f81 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -12,8 +12,51 @@ namespace mlx::core::metal { bool is_available(); -bool cache_enabled(void); -void set_cache_enabled(bool enabled); + +/* Get the actively used memory in bytes. + * + * Note, this will not always match memory use reported by the system because + * it does not include cached memory buffers. + * */ +size_t get_active_memory(); + +/* Get the peak amount of used memory in bytes. + * + * The maximum memory used is recorded from the beginning of the program + * execution. + * */ +size_t get_peak_memory(); + +/* Get the cache size in bytes. + * + * The cache includes memory not currently used that has not been returned + * to the system allocator. + * */ +size_t get_cache_memory(); + +/* Set the memory limit. + * Calls to malloc will wait on scheduled tasks if the limit is exceeded. If + * there are no more scheduled tasks an error will be raised if relaxed + * is false or memory will be allocated (including the potential for + * swap) if relaxed is true. + * + * The memory limit defaults to 1.5 times the maximum recommended working set + * size reported by the device. + * + * Returns the previous memory limit. + * */ +size_t set_memory_limit(size_t limit, bool relaxed = true); + +/* Set the free cache limit. + * If using more than the given limit, free memory will be reclaimed + * from the cache on the next allocation. To disable the cache, + * set the limit to 0. + * + * The cache limit defaults to the memory limit. + * + * Returns the previous cache limit. + * */ +size_t set_cache_limit(size_t limit); void new_stream(Stream stream); std::shared_ptr new_scoped_memory_pool(); diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index e62c05c63..240e00c41 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -23,10 +23,21 @@ std::function make_task( "[metal::make_task] Cannot make GPU task without metal backend"); } -// No cache for CPU only -bool cache_enabled(void) { - return false; +// No-ops when Metal is not available. +size_t get_active_memory() { + return 0; +} +size_t get_peak_memory() { + return 0; +} +size_t get_cache_memory() { + return 0; +} +size_t set_memory_limit(size_t, bool) { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; } -void set_cache_enabled(bool) {} } // namespace mlx::core::metal diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 3d7e57946..774bb2285 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1700,7 +1700,7 @@ array argpartition( int kth_ = kth < 0 ? kth + a.shape(axis) : kth; if (kth_ < 0 || kth_ >= a.shape(axis_)) { std::ostringstream msg; - msg << "[argpartition] Received invalid kth " << kth << "along axis " + msg << "[argpartition] Received invalid kth " << kth << " along axis " << axis << " for array with shape: " << a.shape(); throw std::invalid_argument(msg.str()); } diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 5331c8870..4263f6bff 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -5,18 +5,88 @@ #include "mlx/backend/metal/metal.h" namespace py = pybind11; +using namespace py::literals; using namespace mlx::core; void init_metal(py::module_& m) { py::module_ metal = m.def_submodule("metal", "mlx.metal"); - metal.def("is_available", &metal::is_available); metal.def( - "cache_enabled", - &metal::cache_enabled, - "check if metal buffer cache is enabled, default is true"); + "is_available", + &metal::is_available, + R"pbdoc( + Check if the Metal back-end is available. + )pbdoc"); metal.def( - "set_cache_enabled", - &metal::set_cache_enabled, - "enable or disable metal buffer cache"); + "get_active_memory", + &metal::get_active_memory, + R"pbdoc( + Get the actively used memory in bytes. + + Note, this will not always match memory use reported by the system because + it does not include cached memory buffers. + )pbdoc"); + metal.def( + "get_peak_memory", + &metal::get_peak_memory, + R"pbdoc( + Get the peak amount of used memory in bytes. + + The maximum memory used is recorded from the beginning of the program + execution. + )pbdoc"); + metal.def( + "get_cache_memory", + &metal::get_cache_memory, + R"pbdoc( + Get the cache size in bytes. + + The cache includes memory not currently used that has not been returned + to the system allocator. + )pbdoc"); + metal.def( + "set_memory_limit", + &metal::set_memory_limit, + "limit"_a, + py::kw_only(), + "relaxed"_a = true, + R"pbdoc( + Set the memory limit. + + Memory allocations will wait on scheduled tasks to complete if the limit + is exceeded. If there are no more scheduled tasks an error will be raised + if ``relaxed`` is ``False``. Otherwise memory will be allocated + (including the potential for swap) if ``relaxed`` is ``True``. + + The memory limit defaults to 1.5 times the maximum recommended working set + size reported by the device. + + Args: + limit (int): Memory limit in bytes. + relaxed (bool, optional): If `False`` an error is raised if the limit + is exceeded. Default: ``True`` + + Returns: + int: The previous memory limit in bytes. + )pbdoc"); + metal.def( + "set_cache_limit", + &metal::set_cache_limit, + "limit"_a, + R"pbdoc( + Set the free cache limit. + + If using more than the given limit, free memory will be reclaimed + from the cache on the next allocation. To disable the cache, set + the limit to ``0``. + + The cache limit defaults to the memory limit. See + :func:`set_memory_limit` for more details. + + Args: + limit (int): The cache limit in bytes. + + Returns: + int: The previous cache limit in bytes. + )pbdoc"); } diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py new file mode 100644 index 000000000..53b269772 --- /dev/null +++ b/python/tests/test_metal.py @@ -0,0 +1,45 @@ +# Copyright © 2023-2024 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests + + +class TestMetal(mlx_tests.MLXTestCase): + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_memory_info(self): + old_limit = mx.metal.set_cache_limit(0) + + a = mx.zeros((4096,)) + mx.eval(a) + del a + self.assertEqual(mx.metal.get_cache_memory(), 0) + self.assertEqual(mx.metal.set_cache_limit(old_limit), 0) + self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit) + + old_limit = mx.metal.set_memory_limit(10) + self.assertTrue(mx.metal.set_memory_limit(old_limit), 10) + self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit) + + # Query active and peak memory + a = mx.zeros((4096,)) + mx.eval(a) + active_mem = mx.metal.get_active_memory() + self.assertTrue(active_mem >= 4096 * 4) + + b = mx.zeros((4096,)) + mx.eval(b) + del b + + new_active_mem = mx.metal.get_active_memory() + self.assertEqual(new_active_mem, active_mem) + peak_mem = mx.metal.get_peak_memory() + self.assertTrue(peak_mem >= 4096 * 8) + cache_mem = mx.metal.get_cache_memory() + self.assertTrue(cache_mem >= 4096 * 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp index ff4e3bb0f..c7a0c8c14 100644 --- a/tests/metal_tests.cpp +++ b/tests/metal_tests.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include "doctest/doctest.h" @@ -473,41 +473,42 @@ TEST_CASE("test metal validation") { eval(scatter_max(array(1), {}, array(2), std::vector{})); } -TEST_CASE("test metal enable/disable cache") { - // Test enable metal cache +TEST_CASE("test metal memory info") { + // Test cache limits { - metal::set_cache_enabled(true); - CHECK(metal::cache_enabled()); - - auto& a = metal::allocator(); - auto size = 100; - auto buf = a.malloc(size, false); - - // Release a - a.free(buf); - - // Check size should equals to size - CHECK_EQ(static_cast(buf.ptr())->length(), size); + auto old_limit = metal::set_cache_limit(0); + { + auto a = zeros({4096}); + eval(a); + } + CHECK_EQ(metal::get_cache_memory(), 0); + CHECK_EQ(metal::set_cache_limit(old_limit), 0); + CHECK_EQ(metal::set_cache_limit(old_limit), old_limit); } - // Test disable metal cache + // Test memory limits { - metal::set_cache_enabled(false); - CHECK(!metal::cache_enabled()); + auto old_limit = metal::set_memory_limit(10); + CHECK_EQ(metal::set_memory_limit(old_limit), 10); + CHECK_EQ(metal::set_memory_limit(old_limit), old_limit); + } - auto& a = metal::allocator(); - auto size = 100; - auto buf = a.malloc(size, false); - auto buf_ptr = static_cast(buf.ptr()); - unsigned char first_byte = *reinterpret_cast(buf_ptr); + // Query active and peak memory + { + auto a = zeros({4096}); + eval(a); + auto active_mem = metal::get_active_memory(); + CHECK(active_mem >= 4096 * 4); + { + auto b = zeros({4096}); + eval(b); + } + auto new_active_mem = metal::get_active_memory(); + CHECK_EQ(new_active_mem, active_mem); + auto peak_mem = metal::get_peak_memory(); + CHECK(peak_mem >= 4096 * 8); - // Release a - a.free(buf); - - // If release successfully, the first byte should be different from the - // first byte before release - unsigned char new_first_byte = *reinterpret_cast(buf_ptr); - - CHECK_NE(new_first_byte, first_byte); + auto cache_mem = metal::get_cache_memory(); + CHECK(cache_mem >= 4096 * 4); } }