mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Wired (#1510)
* expose residency sets as wire/unwire * returns wired size * fix * runtime support check * fix os check * fix test * fix no metal build * docs * nit * nits in docs * nits
This commit is contained in:
parent
f70764a162
commit
0eb56d5be0
@ -14,6 +14,7 @@ Metal
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@ -99,6 +99,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
||||
if(NOT MLX_METAL_PATH)
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
||||
block_limit_ =
|
||||
@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
|
||||
block_limit_);
|
||||
max_pool_size_ = block_limit_;
|
||||
device(mlx::core::Device::gpu)
|
||||
.set_residency_set(residency_set_.mtl_residency_set());
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_wired_limit(size_t limit) {
|
||||
std::swap(limit, wired_limit_);
|
||||
residency_set_.resize(wired_limit_);
|
||||
return limit;
|
||||
};
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// Metal doesn't like empty buffers
|
||||
if (size == 0) {
|
||||
@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
residency_set_.insert(buf);
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
@ -231,6 +243,7 @@ void MetalAllocator::clear_cache() {
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
std::unique_lock lk(mutex_);
|
||||
residency_set_.erase(buf);
|
||||
active_memory_ -= buf->length();
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
@ -246,15 +259,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
|
||||
}
|
||||
|
||||
MetalAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
|
||||
// not be called on exit and all the buffers will be leaked. This is necessary
|
||||
// because releasing buffers can take more than 30sec when the program holds a
|
||||
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
|
||||
// users when exiting.
|
||||
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
|
||||
// when applying this pattern to more places, or when introducing sanitizers
|
||||
// to MLX.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
|
||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static MetalAllocator* allocator_ = new MetalAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
@ -265,6 +272,15 @@ size_t set_cache_limit(size_t limit) {
|
||||
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
||||
return allocator().set_memory_limit(limit, relaxed);
|
||||
}
|
||||
size_t set_wired_limit(size_t limit) {
|
||||
if (limit >
|
||||
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
|
||||
throw std::invalid_argument(
|
||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||
"the maximum working set size is not allowed.");
|
||||
}
|
||||
return allocator().set_wired_limit(limit);
|
||||
}
|
||||
size_t get_active_memory() {
|
||||
return allocator().get_active_memory();
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
size_t set_wired_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
|
||||
// Caching allocator
|
||||
BufferCache buffer_cache_;
|
||||
|
||||
ResidencySet residency_set_;
|
||||
|
||||
// Allocation stats
|
||||
size_t block_limit_;
|
||||
size_t gc_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
size_t max_pool_size_;
|
||||
size_t wired_limit_{0};
|
||||
bool relaxed_{true};
|
||||
|
||||
std::mutex mutex_;
|
||||
|
@ -206,6 +206,9 @@ void Device::new_queue(int index) {
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
}
|
||||
stream_map_.emplace(index, q);
|
||||
if (residency_set_ != nullptr) {
|
||||
q->addResidencySet(residency_set_);
|
||||
}
|
||||
}
|
||||
|
||||
int Device::get_command_buffer_ops(int index) {
|
||||
@ -351,7 +354,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to build metal library from source" << "\n";
|
||||
msg << "[metal::Device] Unable to build metal library from source\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
@ -593,6 +596,21 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
|
||||
if (residency_set_ != nullptr) {
|
||||
throw std::runtime_error(
|
||||
"[Device::set_residency_set] Can only be set once.");
|
||||
}
|
||||
if (residency_set == nullptr) {
|
||||
return;
|
||||
}
|
||||
residency_set_ = residency_set;
|
||||
// Attach residency set to existing command queues
|
||||
for (auto& [_, stream] : stream_map_) {
|
||||
stream.queue->addResidencySet(residency_set_);
|
||||
}
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
static Device metal_device;
|
||||
return metal_device;
|
||||
|
@ -181,6 +181,8 @@ class Device {
|
||||
void add_temporary(array arr, int index);
|
||||
void add_temporaries(std::vector<array> arrays, int index);
|
||||
|
||||
void set_residency_set(const MTL::ResidencySet* residency_set);
|
||||
|
||||
private:
|
||||
DeviceStream& get_stream_(int index) {
|
||||
return stream_map_.find(index)->second;
|
||||
@ -225,6 +227,7 @@ class Device {
|
||||
|
||||
std::shared_mutex library_mtx_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
|
@ -1,5 +1,7 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -63,6 +63,19 @@ size_t set_cache_limit(size_t limit);
|
||||
/* Clear the memory cache. */
|
||||
void clear_cache();
|
||||
|
||||
/* Set the wired size limit.
|
||||
*
|
||||
* Note, this function is only useful for macOS 15.0 or higher.
|
||||
*
|
||||
* The wired limit is the total size in bytes of memory that will be kept
|
||||
* resident. The default value is ``0``.
|
||||
*
|
||||
* Setting a wired limit larger than system wired limit is an error.
|
||||
*
|
||||
* Returns the previous wired limit.
|
||||
* */
|
||||
size_t set_wired_limit(size_t limit);
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
void start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
96
mlx/backend/metal/resident.cpp
Normal file
96
mlx/backend/metal/resident.cpp
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
// TODO maybe worth including tvos / visionos
|
||||
#define supported __builtin_available(macOS 15, iOS 18, *)
|
||||
|
||||
ResidencySet::ResidencySet(MTL::Device* d) {
|
||||
if (supported) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto desc = MTL::ResidencySetDescriptor::alloc()->init();
|
||||
NS::Error* error;
|
||||
wired_set_ = d->newResidencySet(desc, &error);
|
||||
desc->release();
|
||||
if (!wired_set_) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to construct residency set.\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ResidencySet::insert(MTL::Allocation* buf) {
|
||||
if (supported) {
|
||||
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
|
||||
wired_set_->addAllocation(buf);
|
||||
wired_set_->commit();
|
||||
wired_set_->requestResidency();
|
||||
} else {
|
||||
unwired_set_.insert(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ResidencySet::erase(MTL::Allocation* buf) {
|
||||
if (supported) {
|
||||
if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
|
||||
unwired_set_.erase(it);
|
||||
} else {
|
||||
wired_set_->removeAllocation(buf);
|
||||
wired_set_->commit();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ResidencySet::resize(size_t size) {
|
||||
if (supported) {
|
||||
if (capacity_ == size) {
|
||||
return;
|
||||
}
|
||||
capacity_ = size;
|
||||
|
||||
size_t current_size = wired_set_->allocatedSize();
|
||||
|
||||
if (current_size < size) {
|
||||
// Add unwired allocations to the set
|
||||
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
|
||||
auto buf_size = (*it)->allocatedSize();
|
||||
if (current_size + buf_size > size) {
|
||||
it++;
|
||||
} else {
|
||||
current_size += buf_size;
|
||||
wired_set_->addAllocation(*it);
|
||||
unwired_set_.erase(it++);
|
||||
}
|
||||
}
|
||||
wired_set_->commit();
|
||||
wired_set_->requestResidency();
|
||||
} else if (current_size > size) {
|
||||
// Remove wired allocations until under capacity
|
||||
auto allocations = wired_set_->allAllocations();
|
||||
auto num_allocations = wired_set_->allocationCount();
|
||||
for (int i = 0; i < num_allocations && current_size > size; ++i) {
|
||||
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
|
||||
wired_set_->removeAllocation(buf);
|
||||
current_size -= buf->allocatedSize();
|
||||
unwired_set_.insert(buf);
|
||||
}
|
||||
wired_set_->commit();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ResidencySet::~ResidencySet() {
|
||||
if (supported) {
|
||||
wired_set_->release();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
32
mlx/backend/metal/resident.h
Normal file
32
mlx/backend/metal/resident.h
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
class ResidencySet {
|
||||
public:
|
||||
ResidencySet(MTL::Device* d);
|
||||
~ResidencySet();
|
||||
|
||||
ResidencySet(const ResidencySet&) = delete;
|
||||
ResidencySet& operator=(const ResidencySet&) = delete;
|
||||
|
||||
const MTL::ResidencySet* mtl_residency_set() {
|
||||
return wired_set_;
|
||||
}
|
||||
|
||||
void insert(MTL::Allocation* buf);
|
||||
void erase(MTL::Allocation* buf);
|
||||
|
||||
void resize(size_t size);
|
||||
|
||||
private:
|
||||
MTL::ResidencySet* wired_set_{nullptr};
|
||||
std::unordered_set<const MTL::Allocation*> unwired_set_;
|
||||
size_t capacity_{0};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::metal
|
@ -16,14 +16,14 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::function<void()> make_task(array arr, bool signal) {
|
||||
std::function<void()> make_task(array, bool) {
|
||||
throw std::runtime_error(
|
||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||
}
|
||||
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
Stream,
|
||||
std::shared_ptr<std::promise<void>>) {
|
||||
throw std::runtime_error(
|
||||
"[metal::make_synchronize_task] Cannot synchronize GPU"
|
||||
" without metal backend");
|
||||
@ -46,7 +46,11 @@ size_t set_memory_limit(size_t, bool) {
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void start_capture(std::string path) {}
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void start_capture(std::string) {}
|
||||
void stop_capture() {}
|
||||
void clear_cache() {}
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
@ -98,6 +99,38 @@ void init_metal(nb::module_& m) {
|
||||
Returns:
|
||||
int: The previous cache limit in bytes.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_wired_limit",
|
||||
&metal::set_wired_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the wired size limit.
|
||||
|
||||
.. note::
|
||||
* This function is only useful on macOS 15.0 or higher.
|
||||
* The wired limit should remain strictly less than the total
|
||||
memory size.
|
||||
|
||||
The wired limit is the total size in bytes of memory that will be kept
|
||||
resident. The default value is ``0``.
|
||||
|
||||
Setting a wired limit larger than system wired limit is an error. You can
|
||||
increase the system wired limit with:
|
||||
|
||||
.. code-block::
|
||||
|
||||
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
|
||||
|
||||
Use :func:`device_info` to query the system wired limit
|
||||
(``"max_recommended_working_set_size"``) and the total memory size
|
||||
(``"memory_size"``).
|
||||
|
||||
Args:
|
||||
limit (int): The wired limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous wired limit in bytes.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"clear_cache",
|
||||
&metal::clear_cache,
|
||||
|
@ -47,6 +47,14 @@ class TestMetal(mlx_tests.MLXTestCase):
|
||||
mx.metal.reset_peak_memory()
|
||||
self.assertEqual(mx.metal.get_peak_memory(), 0)
|
||||
|
||||
old_limit = mx.metal.set_wired_limit(1000)
|
||||
old_limit = mx.metal.set_wired_limit(0)
|
||||
self.assertEqual(old_limit, 1000)
|
||||
|
||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.metal.set_wired_limit(max_size + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user