bindings for memory info (#761)

* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
This commit is contained in:
Awni Hannun 2024-03-01 19:51:58 -08:00 committed by GitHub
parent cf3eb87e52
commit d5964a2710
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 300 additions and 69 deletions

View File

@ -64,6 +64,7 @@ are the CPU and GPU.
python/transforms python/transforms
python/fft python/fft
python/linalg python/linalg
python/metal
python/nn python/nn
python/optimizers python/optimizers
python/tree_utils python/tree_utils

14
docs/src/python/metal.rst Normal file
View File

@ -0,0 +1,14 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
get_active_memory
get_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit

View File

@ -57,6 +57,7 @@ Operations
greater_equal greater_equal
identity identity
inner inner
isclose
isnan isnan
isposinf isposinf
isneginf isneginf
@ -121,6 +122,7 @@ Operations
tan tan
tanh tanh
tensordot tensordot
tile
transpose transpose
tri tri
tril tril

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
@ -23,16 +23,6 @@ void* Buffer::raw_ptr() {
namespace metal { namespace metal {
static bool cache_enabled_ = true;
bool cache_enabled() {
return cache_enabled_;
}
void set_cache_enabled(bool enabled) {
cache_enabled_ = enabled;
}
namespace { namespace {
BufferCache::BufferCache(MTL::Device* device) BufferCache::BufferCache(MTL::Device* device)
@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator() MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_), buffer_cache_(device_),
peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()), block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {} gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::swap(limit, max_pool_size_);
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers // Metal doesn't like empty buffers
@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Try the cache // Try the cache
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
size_t pool_size = get_cache_memory();
if (!buf) { if (!buf) {
size_t mem_required = get_active_memory() + pool_size + size;
// If there is too much memory pressure, fail (likely causes a wait). // If there is too much memory pressure, fail (likely causes a wait).
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) { if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr}; return Buffer{nullptr};
} }
@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// If we have a lot of memory pressure, check if we can reclaim some memory // If we have a lot of memory pressure, check if we can reclaim some memory
// from the cache // from the cache
if (device_->currentAllocatedSize() + size >= gc_limit_) { if (mem_required >= gc_limit_) {
size_t min_bytes_to_free = buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
size + device_->currentAllocatedSize() - gc_limit_;
buffer_cache_.release_cached_buffers(min_bytes_to_free);
} }
// Allocate new buffer if needed // Allocate new buffer if needed
@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buf = device_->newBuffer(size, res_opt); buf = device_->newBuffer(size, res_opt);
} }
peak_allocated_size_ = // Maintain the cache below the requested limit
std::max(peak_allocated_size_, device_->currentAllocatedSize()); if (pool_size >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
}
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
return Buffer{static_cast<void*>(buf)}; return Buffer{static_cast<void*>(buf)};
} }
void MetalAllocator::free(Buffer buffer) { void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr()); auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (cache_enabled()) { active_memory_ -= buf->length();
if (max_pool_size_ > 0) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
buf->release(); buf->release();
@ -218,6 +229,22 @@ MetalAllocator& allocator() {
return allocator_; return allocator_;
} }
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t get_active_memory() {
return allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
}
} // namespace metal } // namespace metal
} // namespace mlx::core } // namespace mlx::core

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
@ -24,6 +24,9 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size); MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf); void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free); void release_cached_buffers(size_t min_bytes_to_free);
size_t pool_size() {
return pool_size_;
}
private: private:
struct BufferHolder { struct BufferHolder {
@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};
size_t get_peak_memory() {
return peak_memory_;
};
size_t get_cache_memory() {
return buffer_cache_.pool_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
private: private:
MTL::Device* device_; MTL::Device* device_;
@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator {
BufferCache buffer_cache_; BufferCache buffer_cache_;
// Allocation stats // Allocation stats
size_t peak_allocated_size_;
size_t block_limit_; size_t block_limit_;
size_t gc_limit_; size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
bool relaxed_{true};
}; };
MetalAllocator& allocator(); MetalAllocator& allocator();

View File

@ -12,8 +12,51 @@
namespace mlx::core::metal { namespace mlx::core::metal {
bool is_available(); bool is_available();
bool cache_enabled(void);
void set_cache_enabled(bool enabled); /* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used is recorded from the beginning of the program
* execution.
* */
size_t get_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
* there are no more scheduled tasks an error will be raised if relaxed
* is false or memory will be allocated (including the potential for
* swap) if relaxed is true.
*
* The memory limit defaults to 1.5 times the maximum recommended working set
* size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit, bool relaxed = true);
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
void new_stream(Stream stream); void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool(); std::shared_ptr<void> new_scoped_memory_pool();

View File

@ -23,10 +23,21 @@ std::function<void()> make_task(
"[metal::make_task] Cannot make GPU task without metal backend"); "[metal::make_task] Cannot make GPU task without metal backend");
} }
// No cache for CPU only // No-ops when Metal is not available.
bool cache_enabled(void) { size_t get_active_memory() {
return false; return 0;
}
size_t get_peak_memory() {
return 0;
}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t, bool) {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
} }
void set_cache_enabled(bool) {}
} // namespace mlx::core::metal } // namespace mlx::core::metal

View File

@ -1700,7 +1700,7 @@ array argpartition(
int kth_ = kth < 0 ? kth + a.shape(axis) : kth; int kth_ = kth < 0 ? kth + a.shape(axis) : kth;
if (kth_ < 0 || kth_ >= a.shape(axis_)) { if (kth_ < 0 || kth_ >= a.shape(axis_)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[argpartition] Received invalid kth " << kth << "along axis " msg << "[argpartition] Received invalid kth " << kth << " along axis "
<< axis << " for array with shape: " << a.shape(); << axis << " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@ -5,18 +5,88 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
namespace py = pybind11; namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core; using namespace mlx::core;
void init_metal(py::module_& m) { void init_metal(py::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal"); py::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def("is_available", &metal::is_available);
metal.def( metal.def(
"cache_enabled", "is_available",
&metal::cache_enabled, &metal::is_available,
"check if metal buffer cache is enabled, default is true"); R"pbdoc(
Check if the Metal back-end is available.
)pbdoc");
metal.def( metal.def(
"set_cache_enabled", "get_active_memory",
&metal::set_cache_enabled, &metal::get_active_memory,
"enable or disable metal buffer cache"); R"pbdoc(
Get the actively used memory in bytes.
Note, this will not always match memory use reported by the system because
it does not include cached memory buffers.
)pbdoc");
metal.def(
"get_peak_memory",
&metal::get_peak_memory,
R"pbdoc(
Get the peak amount of used memory in bytes.
The maximum memory used is recorded from the beginning of the program
execution.
)pbdoc");
metal.def(
"get_cache_memory",
&metal::get_cache_memory,
R"pbdoc(
Get the cache size in bytes.
The cache includes memory not currently used that has not been returned
to the system allocator.
)pbdoc");
metal.def(
"set_memory_limit",
&metal::set_memory_limit,
"limit"_a,
py::kw_only(),
"relaxed"_a = true,
R"pbdoc(
Set the memory limit.
Memory allocations will wait on scheduled tasks to complete if the limit
is exceeded. If there are no more scheduled tasks an error will be raised
if ``relaxed`` is ``False``. Otherwise memory will be allocated
(including the potential for swap) if ``relaxed`` is ``True``.
The memory limit defaults to 1.5 times the maximum recommended working set
size reported by the device.
Args:
limit (int): Memory limit in bytes.
      relaxed (bool, optional): If ``False`` an error is raised if the limit
is exceeded. Default: ``True``
Returns:
int: The previous memory limit in bytes.
)pbdoc");
metal.def(
"set_cache_limit",
&metal::set_cache_limit,
"limit"_a,
R"pbdoc(
Set the free cache limit.
If using more than the given limit, free memory will be reclaimed
from the cache on the next allocation. To disable the cache, set
the limit to ``0``.
The cache limit defaults to the memory limit. See
:func:`set_memory_limit` for more details.
Args:
limit (int): The cache limit in bytes.
Returns:
int: The previous cache limit in bytes.
)pbdoc");
} }

View File

@ -0,0 +1,45 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.metal.set_cache_limit(0)
a = mx.zeros((4096,))
mx.eval(a)
del a
self.assertEqual(mx.metal.get_cache_memory(), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
old_limit = mx.metal.set_memory_limit(10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,))
mx.eval(a)
active_mem = mx.metal.get_active_memory()
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,))
mx.eval(b)
del b
new_active_mem = mx.metal.get_active_memory()
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.metal.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
cache_mem = mx.metal.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
if __name__ == "__main__":
unittest.main()

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <array> #include <array>
#include "doctest/doctest.h" #include "doctest/doctest.h"
@ -473,41 +473,42 @@ TEST_CASE("test metal validation") {
eval(scatter_max(array(1), {}, array(2), std::vector<int>{})); eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
} }
TEST_CASE("test metal enable/disable cache") { TEST_CASE("test metal memory info") {
// Test enable metal cache // Test cache limits
{ {
metal::set_cache_enabled(true); auto old_limit = metal::set_cache_limit(0);
CHECK(metal::cache_enabled()); {
auto a = zeros({4096});
auto& a = metal::allocator(); eval(a);
auto size = 100; }
auto buf = a.malloc(size, false); CHECK_EQ(metal::get_cache_memory(), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), 0);
// Release a CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
a.free(buf);
// Check size should equals to size
CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
} }
// Test disable metal cache // Test memory limits
{ {
metal::set_cache_enabled(false); auto old_limit = metal::set_memory_limit(10);
CHECK(!metal::cache_enabled()); CHECK_EQ(metal::set_memory_limit(old_limit), 10);
CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
}
auto& a = metal::allocator(); // Query active and peak memory
auto size = 100; {
auto buf = a.malloc(size, false); auto a = zeros({4096});
auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr()); eval(a);
unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr); auto active_mem = metal::get_active_memory();
CHECK(active_mem >= 4096 * 4);
{
auto b = zeros({4096});
eval(b);
}
auto new_active_mem = metal::get_active_memory();
CHECK_EQ(new_active_mem, active_mem);
auto peak_mem = metal::get_peak_memory();
CHECK(peak_mem >= 4096 * 8);
// Release a auto cache_mem = metal::get_cache_memory();
a.free(buf); CHECK(cache_mem >= 4096 * 4);
// If release successfully, the first byte should be different from the
// first byte before release
unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
CHECK_NE(new_first_byte, first_byte);
} }
} }