* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
This commit is contained in:
Awni Hannun 2024-10-25 09:35:33 -07:00 committed by GitHub
parent f70764a162
commit 0eb56d5be0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 246 additions and 14 deletions

View File

@ -14,6 +14,7 @@ Metal
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@ -99,6 +99,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
if(NOT MLX_METAL_PATH)

View File

@ -2,6 +2,7 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
return limit;
};
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
if (size == 0) {
@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)};
}
@ -231,6 +243,7 @@ void MetalAllocator::clear_cache() {
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
@ -246,15 +259,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
@ -265,6 +272,15 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
}

View File

@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal {
@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache();
private:
@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
// Caching allocator
BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
std::mutex mutex_;

View File

@ -206,6 +206,9 @@ void Device::new_queue(int index) {
"[metal::Device] Failed to make new command queue.");
}
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}
int Device::get_command_buffer_ops(int index) {
@ -351,7 +354,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n";
msg << "[metal::Device] Unable to build metal library from source\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@ -593,6 +596,21 @@ MTL::ComputePipelineState* Device::get_kernel(
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
}
Device& device(mlx::core::Device) {
static Device metal_device;
return metal_device;

View File

@ -181,6 +181,8 @@ class Device {
void add_temporary(array arr, int index);
void add_temporaries(std::vector<array> arrays, int index);
void set_residency_set(const MTL::ResidencySet* residency_set);
private:
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
@ -225,6 +227,7 @@ class Device {
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
const MTL::ResidencySet* residency_set_{nullptr};
};
Device& device(mlx::core::Device);

View File

@ -1,5 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/backend/metal/device.h"
namespace mlx::core {

View File

@ -63,6 +63,19 @@ size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

View File

@ -0,0 +1,96 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/resident.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {
// TODO maybe worth including tvos / visionos
#define supported __builtin_available(macOS 15, iOS 18, *)
ResidencySet::ResidencySet(MTL::Device* d) {
if (supported) {
auto pool = new_scoped_memory_pool();
auto desc = MTL::ResidencySetDescriptor::alloc()->init();
NS::Error* error;
wired_set_ = d->newResidencySet(desc, &error);
desc->release();
if (!wired_set_) {
std::ostringstream msg;
msg << "[metal::Device] Unable to construct residency set.\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
}
}
void ResidencySet::insert(MTL::Allocation* buf) {
if (supported) {
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf);
wired_set_->commit();
wired_set_->requestResidency();
} else {
unwired_set_.insert(buf);
}
}
}
void ResidencySet::erase(MTL::Allocation* buf) {
if (supported) {
if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
unwired_set_.erase(it);
} else {
wired_set_->removeAllocation(buf);
wired_set_->commit();
}
}
}
void ResidencySet::resize(size_t size) {
if (supported) {
if (capacity_ == size) {
return;
}
capacity_ = size;
size_t current_size = wired_set_->allocatedSize();
if (current_size < size) {
// Add unwired allocations to the set
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
auto buf_size = (*it)->allocatedSize();
if (current_size + buf_size > size) {
it++;
} else {
current_size += buf_size;
wired_set_->addAllocation(*it);
unwired_set_.erase(it++);
}
}
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
// Remove wired allocations until under capacity
auto allocations = wired_set_->allAllocations();
auto num_allocations = wired_set_->allocationCount();
for (int i = 0; i < num_allocations && current_size > size; ++i) {
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
wired_set_->removeAllocation(buf);
current_size -= buf->allocatedSize();
unwired_set_.insert(buf);
}
wired_set_->commit();
}
}
}
ResidencySet::~ResidencySet() {
if (supported) {
wired_set_->release();
}
}
} // namespace mlx::core::metal

View File

@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/device.h"
namespace mlx::core::metal {
class ResidencySet {
public:
ResidencySet(MTL::Device* d);
~ResidencySet();
ResidencySet(const ResidencySet&) = delete;
ResidencySet& operator=(const ResidencySet&) = delete;
const MTL::ResidencySet* mtl_residency_set() {
return wired_set_;
}
void insert(MTL::Allocation* buf);
void erase(MTL::Allocation* buf);
void resize(size_t size);
private:
MTL::ResidencySet* wired_set_{nullptr};
std::unordered_set<const MTL::Allocation*> unwired_set_;
size_t capacity_{0};
};
} // namespace mlx::core::metal

View File

@ -16,14 +16,14 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
return nullptr;
}
std::function<void()> make_task(array arr, bool signal) {
std::function<void()> make_task(array, bool) {
throw std::runtime_error(
"[metal::make_task] Cannot make GPU task without metal backend");
}
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p) {
Stream,
std::shared_ptr<std::promise<void>>) {
throw std::runtime_error(
"[metal::make_synchronize_task] Cannot synchronize GPU"
" without metal backend");
@ -46,7 +46,11 @@ size_t set_memory_limit(size_t, bool) {
size_t set_cache_limit(size_t) {
return 0;
}
void start_capture(std::string path) {}
size_t set_wired_limit(size_t) {
return 0;
}
void start_capture(std::string) {}
void stop_capture() {}
void clear_cache() {}

View File

@ -6,6 +6,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
using namespace nb::literals;
@ -98,6 +99,38 @@ void init_metal(nb::module_& m) {
Returns:
int: The previous cache limit in bytes.
)pbdoc");
metal.def(
"set_wired_limit",
&metal::set_wired_limit,
"limit"_a,
R"pbdoc(
Set the wired size limit.
.. note::
* This function is only useful on macOS 15.0 or higher.
* The wired limit should remain strictly less than the total
memory size.
The wired limit is the total size in bytes of memory that will be kept
resident. The default value is ``0``.
Setting a wired limit larger than system wired limit is an error. You can
increase the system wired limit with:
.. code-block::
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
Use :func:`device_info` to query the system wired limit
(``"max_recommended_working_set_size"``) and the total memory size
(``"memory_size"``).
Args:
limit (int): The wired limit in bytes.
Returns:
int: The previous wired limit in bytes.
)pbdoc");
metal.def(
"clear_cache",
&metal::clear_cache,

View File

@ -47,6 +47,14 @@ class TestMetal(mlx_tests.MLXTestCase):
mx.metal.reset_peak_memory()
self.assertEqual(mx.metal.get_peak_memory(), 0)
old_limit = mx.metal.set_wired_limit(1000)
old_limit = mx.metal.set_wired_limit(0)
self.assertEqual(old_limit, 1000)
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
with self.assertRaises(ValueError):
mx.metal.set_wired_limit(max_size + 10)
if __name__ == "__main__":
unittest.main()