diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst
index cb2cdb38e..4d6fb91d9 100644
--- a/docs/src/python/metal.rst
+++ b/docs/src/python/metal.rst
@@ -14,6 +14,7 @@ Metal
    get_cache_memory
    set_memory_limit
    set_cache_limit
+   set_wired_limit
    clear_cache
    start_capture
    stop_capture
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 7b2949ade..3e88e18d1 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -99,6 +99,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
 
 if(NOT MLX_METAL_PATH)
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 8c1f80291..0453aaf01 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -2,6 +2,7 @@
 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/backend/metal/metal_impl.h"
+#include "mlx/backend/metal/resident.h"
 
 #include
 #include
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
 
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
+      residency_set_(device_),
       buffer_cache_(device_) {
   auto memsize = std::get(device_info()["memory_size"]);
   block_limit_ =
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
       static_cast(0.95 * device_->recommendedMaxWorkingSetSize()),
       block_limit_);
   max_pool_size_ = block_limit_;
+  device(mlx::core::Device::gpu)
+      .set_residency_set(residency_set_.mtl_residency_set());
 }
 
 size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -164,6 +168,17 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
   return limit;
 };
 
+// Update the wired limit and resize the residency set to match. Returns the
+// previous limit.
+size_t MetalAllocator::set_wired_limit(size_t limit) {
+  // free() updates the residency set while holding mutex_, so take it here as
+  // well to avoid racing with concurrent allocator calls.
+  std::unique_lock lk(mutex_);
+  std::swap(limit, wired_limit_);
+  residency_set_.resize(wired_limit_);
+  return limit;
+};
+
 Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   // Metal doesn't like empty buffers
   if (size == 0) {
@@ -220,6 +235,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
 
+  residency_set_.insert(buf);
+
   return Buffer{static_cast(buf)};
 }
 
@@ -231,6 +248,7 @@ void MetalAllocator::clear_cache() {
 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast(buffer.ptr());
   std::unique_lock lk(mutex_);
+  residency_set_.erase(buf);
   active_memory_ -= buf->length();
   if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
@@ -246,15 +264,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
 }
 
 MetalAllocator& allocator() {
-  // By creating the |allocator_| on heap, the destructor of MetalAllocator will
-  // not be called on exit and all the buffers will be leaked. This is necessary
-  // because releasing buffers can take more than 30sec when the program holds a
-  // lot of RAM (for example inferencing a LLM), and it would feel frozen to
-  // users when exiting.
-  // TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
-  // when applying this pattern to more places, or when introducing sanitizers
-  // to MLX.
-  // https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
+  // By creating the |allocator_| on heap, the destructor of MetalAllocator
+  // will not be called on exit and buffers in the cache will be leaked. This
+  // can save some time at program exit.
   static MetalAllocator* allocator_ = new MetalAllocator;
   return *allocator_;
 }
@@ -265,6 +277,15 @@ size_t set_cache_limit(size_t limit) {
 size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
   return allocator().set_memory_limit(limit, relaxed);
 }
+size_t set_wired_limit(size_t limit) {
+  if (limit >
+      std::get(device_info()["max_recommended_working_set_size"])) {
+    throw std::invalid_argument(
+        "[metal::set_wired_limit] Setting a wired limit larger than "
+        "the maximum working set size is not allowed.");
+  }
+  return allocator().set_wired_limit(limit);
+}
 size_t get_active_memory() {
   return allocator().get_active_memory();
 }
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index f2cf8feb8..997638db5 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -8,6 +8,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/resident.h"
 
 namespace mlx::core::metal {
 
@@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
   };
   size_t set_cache_limit(size_t limit);
   size_t set_memory_limit(size_t limit, bool relaxed);
+  size_t set_wired_limit(size_t limit);
   void clear_cache();
 
  private:
@@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
   // Caching allocator
   BufferCache buffer_cache_;
 
+  ResidencySet residency_set_;
+
   // Allocation stats
   size_t block_limit_;
   size_t gc_limit_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
   size_t max_pool_size_;
+  size_t wired_limit_{0};
   bool relaxed_{true};
 
   std::mutex mutex_;
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 37254a8a0..5106dfa21 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -206,6 +206,9 @@ void Device::new_queue(int index) {
         "[metal::Device] Failed to make new command queue.");
   }
   stream_map_.emplace(index, q);
+  if (residency_set_ != nullptr) {
+    q->addResidencySet(residency_set_);
+  }
 }
 
 int Device::get_command_buffer_ops(int index) {
@@ -351,7 +354,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) {
   // Throw error if unable to compile library
   if (!mtl_lib) {
     std::ostringstream msg;
-    msg << "[metal::Device] Unable to build metal library from source" << "\n";
+    msg << "[metal::Device] Unable to build metal library from source\n";
     if (error) {
       msg << error->localizedDescription()->utf8String() << "\n";
     }
@@ -593,6 +596,21 @@ MTL::ComputePipelineState* Device::get_kernel(
   return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
 }
 
+void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
+  if (residency_set_ != nullptr) {
+    throw std::runtime_error(
+        "[Device::set_residency_set] Can only be set once.");
+  }
+  if (residency_set == nullptr) {
+    return;
+  }
+  residency_set_ = residency_set;
+  // Attach residency set to existing command queues
+  for (auto& [_, stream] : stream_map_) {
+    stream.queue->addResidencySet(residency_set_);
+  }
+}
+
 Device& device(mlx::core::Device) {
   static Device metal_device;
   return metal_device;
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index a3b613d68..8737b479a 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -181,6 +181,10 @@ class Device {
   void add_temporary(array arr, int index);
   void add_temporaries(std::vector arrays, int index);
 
+  // Attach a residency set to every existing and future command queue. Throws
+  // if a set was already attached; otherwise a null set is ignored.
+  void set_residency_set(const MTL::ResidencySet* residency_set);
+
  private:
   DeviceStream& get_stream_(int index) {
     return stream_map_.find(index)->second;
@@ -225,6 +229,7 @@ class Device {
 
   std::shared_mutex library_mtx_;
   std::unordered_map library_map_;
+  const MTL::ResidencySet* residency_set_{nullptr};
 };
 
 Device& device(mlx::core::Device);
diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h
index c771bb8b4..b6160d6f1 100644
--- a/mlx/backend/metal/matmul.h
+++ b/mlx/backend/metal/matmul.h
@@ -1,5 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
+#pragma once
+
 #include "mlx/backend/metal/device.h"
 
 namespace mlx::core {
diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h
index c63ddda28..830df4f49 100644
--- a/mlx/backend/metal/metal.h
+++ b/mlx/backend/metal/metal.h
@@ -63,6 +63,19 @@ size_t set_cache_limit(size_t limit);
 /* Clear the memory cache. */
 void clear_cache();
 
+/* Set the wired size limit.
+ *
+ * Note, this function is only useful for macOS 15.0 or higher.
+ *
+ * The wired limit is the total size in bytes of memory that will be kept
+ * resident. The default value is ``0``.
+ *
+ * Setting a limit above the system wired limit throws std::invalid_argument.
+ *
+ * Returns the previous wired limit.
+ * */
+size_t set_wired_limit(size_t limit);
+
 /** Capture a GPU trace, saving it to an absolute file `path` */
 void start_capture(std::string path = "");
 void stop_capture();
diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp
new file mode 100644
index 000000000..403857c6d
--- /dev/null
+++ b/mlx/backend/metal/resident.cpp
@@ -0,0 +1,96 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/metal/resident.h"
+#include "mlx/backend/metal/metal_impl.h"
+
+namespace mlx::core::metal {
+
+// TODO maybe worth including tvos / visionos
+#define supported __builtin_available(macOS 15, iOS 18, *)
+
+ResidencySet::ResidencySet(MTL::Device* d) {
+  if (supported) {
+    auto pool = new_scoped_memory_pool();
+    auto desc = MTL::ResidencySetDescriptor::alloc()->init();
+    NS::Error* error = nullptr;
+    wired_set_ = d->newResidencySet(desc, &error);
+    desc->release();
+    if (!wired_set_) {
+      std::ostringstream msg;
+      msg << "[metal::Device] Unable to construct residency set.\n";
+      if (error) {
+        msg << error->localizedDescription()->utf8String() << "\n";
+      }
+      throw std::runtime_error(msg.str());
+    }
+  }
+}
+
+void ResidencySet::insert(MTL::Allocation* buf) {
+  if (supported) {
+    if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
+      wired_set_->addAllocation(buf);
+      wired_set_->commit();
+      wired_set_->requestResidency();
+    } else {
+      unwired_set_.insert(buf);
+    }
+  }
+}
+
+void ResidencySet::erase(MTL::Allocation* buf) {
+  if (supported) {
+    if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
+      unwired_set_.erase(it);
+    } else {
+      wired_set_->removeAllocation(buf);
+      wired_set_->commit();
+    }
+  }
+}
+
+void ResidencySet::resize(size_t size) {
+  if (supported) {
+    if (capacity_ == size) {
+      return;
+    }
+    capacity_ = size;
+
+    size_t current_size = wired_set_->allocatedSize();
+
+    if (current_size < size) {
+      // Wire as many unwired allocations as fit in the new capacity
+      for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
+        auto buf_size = (*it)->allocatedSize();
+        if (current_size + buf_size > size) {
+          ++it;
+        } else {
+          current_size += buf_size;
+          wired_set_->addAllocation(*it);
+          it = unwired_set_.erase(it);
+        }
+      }
+      wired_set_->commit();
+      wired_set_->requestResidency();
+    } else if (current_size > size) {
+      // Remove wired allocations until under capacity
+      auto allocations = wired_set_->allAllocations();
+      auto num_allocations = wired_set_->allocationCount();
+      for (NS::UInteger i = 0; i < num_allocations && current_size > size; ++i) {
+        auto buf = static_cast(allocations->object(i));
+        wired_set_->removeAllocation(buf);
+        current_size -= buf->allocatedSize();
+        unwired_set_.insert(buf);
+      }
+      wired_set_->commit();
+    }
+  }
+}
+
+ResidencySet::~ResidencySet() {
+  if (supported) {
+    wired_set_->release();
+  }
+}
+
+} // namespace mlx::core::metal
diff --git a/mlx/backend/metal/resident.h b/mlx/backend/metal/resident.h
new file mode 100644
index 000000000..5db558286
--- /dev/null
+++ b/mlx/backend/metal/resident.h
@@ -0,0 +1,41 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/device.h"
+
+namespace mlx::core::metal {
+
+// Tracks buffers that should be kept resident. Buffers are placed in a Metal
+// residency set while they fit within the capacity set by resize(); the rest
+// are remembered in an unwired set so they can be wired when capacity grows.
+// All operations are no-ops when the OS does not support residency sets.
+class ResidencySet {
+ public:
+  explicit ResidencySet(MTL::Device* d);
+  ~ResidencySet();
+
+  ResidencySet(const ResidencySet&) = delete;
+  ResidencySet& operator=(const ResidencySet&) = delete;
+
+  const MTL::ResidencySet* mtl_residency_set() {
+    return wired_set_;
+  }
+
+  // Track buf: wire it if it fits in the remaining capacity, otherwise keep it
+  // in the unwired set.
+  void insert(MTL::Allocation* buf);
+  // Stop tracking buf, whichever set it is currently in.
+  void erase(MTL::Allocation* buf);
+
+  // Set the capacity in bytes and move allocations between the wired and
+  // unwired sets to respect it.
+  void resize(size_t size);
+
+ private:
+  MTL::ResidencySet* wired_set_{nullptr};
+  std::unordered_set unwired_set_;
+  size_t capacity_{0};
+};
+
+} // namespace mlx::core::metal
diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp
index 4cf5b00db..d23f8d33a 100644
--- a/mlx/backend/no_metal/metal.cpp
+++ b/mlx/backend/no_metal/metal.cpp
@@ -16,14 +16,14 @@ std::unique_ptr> new_scoped_memory_pool() {
   return nullptr;
 }
 
-std::function make_task(array arr, bool signal) {
+std::function make_task(array, bool) {
   throw std::runtime_error(
       "[metal::make_task] Cannot make GPU task without metal backend");
 }
 
 std::function make_synchronize_task(
-    Stream s,
-    std::shared_ptr> p) {
+    Stream,
+    std::shared_ptr>) {
   throw std::runtime_error(
       "[metal::make_synchronize_task] Cannot synchronize GPU"
       " without metal backend");
@@ -46,7 +46,11 @@ size_t set_memory_limit(size_t, bool) {
 size_t set_cache_limit(size_t) {
   return 0;
 }
-void start_capture(std::string path) {}
+size_t set_wired_limit(size_t) {
+  return 0;
+}
+
+void start_capture(std::string) {}
 void stop_capture() {}
 void clear_cache() {}
 
diff --git a/python/src/metal.cpp b/python/src/metal.cpp
index 4306b3915..c08bd6c50 100644
--- a/python/src/metal.cpp
+++ b/python/src/metal.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -98,6 +99,38 @@ void init_metal(nb::module_& m) {
       Returns:
         int: The previous cache limit in bytes.
       )pbdoc");
+  metal.def(
+      "set_wired_limit",
+      &metal::set_wired_limit,
+      "limit"_a,
+      R"pbdoc(
+      Set the wired size limit.
+
+      .. note::
+         * This function is only useful on macOS 15.0 or higher.
+         * The wired limit should remain strictly less than the total
+           memory size.
+
+      The wired limit is the total size in bytes of memory that will be kept
+      resident. The default value is ``0``.
+
+      Setting a wired limit larger than system wired limit is an error. You can
+      increase the system wired limit with:
+
+      .. code-block::
+
+        sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
+
+      Use :func:`device_info` to query the system wired limit
+      (``"max_recommended_working_set_size"``) and the total memory size
+      (``"memory_size"``).
+
+      Args:
+        limit (int): The wired limit in bytes.
+
+      Returns:
+        int: The previous wired limit in bytes.
+      )pbdoc");
   metal.def(
       "clear_cache",
       &metal::clear_cache,
diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py
index be63b118e..81cefabce 100644
--- a/python/tests/test_metal.py
+++ b/python/tests/test_metal.py
@@ -47,6 +47,15 @@ class TestMetal(mlx_tests.MLXTestCase):
         mx.metal.reset_peak_memory()
         self.assertEqual(mx.metal.get_peak_memory(), 0)
 
+        old_limit = mx.metal.set_wired_limit(1000)
+        # Restore the original limit; the call returns the value we just set.
+        prev_limit = mx.metal.set_wired_limit(old_limit)
+        self.assertEqual(prev_limit, 1000)
+
+        max_size = mx.metal.device_info()["max_recommended_working_set_size"]
+        with self.assertRaises(ValueError):
+            mx.metal.set_wired_limit(max_size + 10)
+
 
 if __name__ == "__main__":
     unittest.main()