bindings for memory info (#761)

* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
This commit is contained in:
Awni Hannun 2024-03-01 19:51:58 -08:00 committed by GitHub
parent cf3eb87e52
commit d5964a2710
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 300 additions and 69 deletions

View File

@ -64,6 +64,7 @@ are the CPU and GPU.
python/transforms
python/fft
python/linalg
python/metal
python/nn
python/optimizers
python/tree_utils

14
docs/src/python/metal.rst Normal file
View File

@ -0,0 +1,14 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
get_active_memory
get_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit

View File

@ -57,6 +57,7 @@ Operations
greater_equal
identity
inner
isclose
isnan
isposinf
isneginf
@ -121,6 +122,7 @@ Operations
tan
tanh
tensordot
tile
transpose
tri
tril

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
@ -23,16 +23,6 @@ void* Buffer::raw_ptr() {
namespace metal {
static bool cache_enabled_ = true;
bool cache_enabled() {
return cache_enabled_;
}
void set_cache_enabled(bool enabled) {
cache_enabled_ = enabled;
}
namespace {
BufferCache::BufferCache(MTL::Device* device)
@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
// Install a new cache-size limit and return the previous one.
// The cache itself is only trimmed lazily, on the next call to malloc.
size_t MetalAllocator::set_cache_limit(size_t limit) {
  // std::swap leaves the old limit in `limit` so it can be returned.
  std::swap(limit, max_pool_size_);
  return limit;
}
// Install a new memory limit and return the previous one.
// When `relaxed` is true, malloc with allow_swap may exceed the limit;
// otherwise allocations that would cross the limit fail (return nullptr).
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
  // std::swap leaves the old limit in `limit` so it can be returned.
  std::swap(limit, block_limit_);
  relaxed_ = relaxed;
  // Start reclaiming cached buffers near the device's recommended working
  // set, but never above the user-provided block limit.
  gc_limit_ = std::min(
      block_limit_,
      static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
  return limit;
}
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Try the cache
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
size_t pool_size = get_cache_memory();
if (!buf) {
size_t mem_required = get_active_memory() + pool_size + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// If we have a lot of memory pressure, check if we can reclaim some memory
// from the cache
if (device_->currentAllocatedSize() + size >= gc_limit_) {
size_t min_bytes_to_free =
size + device_->currentAllocatedSize() - gc_limit_;
buffer_cache_.release_cached_buffers(min_bytes_to_free);
if (mem_required >= gc_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
}
// Allocate new buffer if needed
@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buf = device_->newBuffer(size, res_opt);
}
peak_allocated_size_ =
std::max(peak_allocated_size_, device_->currentAllocatedSize());
// Maintain the cache below the requested limit
if (pool_size >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
}
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (cache_enabled()) {
active_memory_ -= buf->length();
if (max_pool_size_ > 0) {
buffer_cache_.recycle_to_cache(buf);
} else {
buf->release();
@ -218,6 +229,22 @@ MetalAllocator& allocator() {
return allocator_;
}
// Process-wide convenience API: each function simply forwards to the
// singleton Metal allocator.

size_t set_cache_limit(size_t limit) {
  auto& alloc = allocator();
  return alloc.set_cache_limit(limit);
}

size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
  auto& alloc = allocator();
  return alloc.set_memory_limit(limit, relaxed);
}

size_t get_active_memory() {
  auto& alloc = allocator();
  return alloc.get_active_memory();
}

size_t get_peak_memory() {
  auto& alloc = allocator();
  return alloc.get_peak_memory();
}

size_t get_cache_memory() {
  auto& alloc = allocator();
  return alloc.get_cache_memory();
}
} // namespace metal
} // namespace mlx::core

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@ -24,6 +24,9 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free);
size_t pool_size() {
return pool_size_;
}
private:
struct BufferHolder {
@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
// Bytes currently owned by live (allocated, not yet freed) buffers.
size_t get_active_memory() {
  return active_memory_;
}
// High-water mark of active memory since program start.
size_t get_peak_memory() {
  return peak_memory_;
}
// Bytes sitting in the buffer cache, available for reuse.
size_t get_cache_memory() {
  return buffer_cache_.pool_size();
}
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
private:
MTL::Device* device_;
@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator {
BufferCache buffer_cache_;
// Allocation stats
size_t peak_allocated_size_;
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
bool relaxed_{true};
};
MetalAllocator& allocator();

View File

@ -12,8 +12,51 @@
namespace mlx::core::metal {
bool is_available();
bool cache_enabled(void);
void set_cache_enabled(bool enabled);
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used is recorded from the beginning of the program
* execution.
* */
size_t get_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
* there are no more scheduled tasks an error will be raised if relaxed
* is false or memory will be allocated (including the potential for
* swap) if relaxed is true.
*
* The memory limit defaults to 1.5 times the maximum recommended working set
* size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit, bool relaxed = true);
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();

View File

@ -23,10 +23,21 @@ std::function<void()> make_task(
"[metal::make_task] Cannot make GPU task without metal backend");
}
// No cache for CPU only
bool cache_enabled(void) {
return false;
// No-ops when Metal is not available.
// No-ops when Metal is not available: all queries report zero and the
// limit setters accept any value but have no effect.
size_t get_active_memory() {
  return 0;
}
size_t get_peak_memory() {
  return 0;
}
size_t get_cache_memory() {
  return 0;
}
size_t set_memory_limit(size_t, bool) {
  return 0;
}
size_t set_cache_limit(size_t) {
  return 0;
}
// NOTE(review): set_cache_enabled appears to no longer be declared in
// metal.h after this change — presumably dead code; confirm there are no
// remaining callers before removing.
void set_cache_enabled(bool) {}
} // namespace mlx::core::metal

View File

@ -5,18 +5,88 @@
#include "mlx/backend/metal/metal.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
// Bind the mlx.core.metal submodule: back-end availability plus the
// memory-introspection and memory-limit APIs.
// (Reconstructed: the rendered diff interleaved removed and added lines;
// the obsolete cache_enabled/set_cache_enabled bindings are gone.)
void init_metal(py::module_& m) {
  py::module_ metal = m.def_submodule("metal", "mlx.metal");
  metal.def(
      "is_available",
      &metal::is_available,
      R"pbdoc(
      Check if the Metal back-end is available.
      )pbdoc");
  metal.def(
      "get_active_memory",
      &metal::get_active_memory,
      R"pbdoc(
      Get the actively used memory in bytes.

      Note, this will not always match memory use reported by the system because
      it does not include cached memory buffers.
      )pbdoc");
  metal.def(
      "get_peak_memory",
      &metal::get_peak_memory,
      R"pbdoc(
      Get the peak amount of used memory in bytes.

      The maximum memory used is recorded from the beginning of the program
      execution.
      )pbdoc");
  metal.def(
      "get_cache_memory",
      &metal::get_cache_memory,
      R"pbdoc(
      Get the cache size in bytes.

      The cache includes memory not currently used that has not been returned
      to the system allocator.
      )pbdoc");
  metal.def(
      "set_memory_limit",
      &metal::set_memory_limit,
      "limit"_a,
      py::kw_only(),
      "relaxed"_a = true,
      R"pbdoc(
      Set the memory limit.

      Memory allocations will wait on scheduled tasks to complete if the limit
      is exceeded. If there are no more scheduled tasks an error will be raised
      if ``relaxed`` is ``False``. Otherwise memory will be allocated
      (including the potential for swap) if ``relaxed`` is ``True``.

      The memory limit defaults to 1.5 times the maximum recommended working set
      size reported by the device.

      Args:
        limit (int): Memory limit in bytes.
        relaxed (bool, optional): If ``False`` an error is raised if the limit
          is exceeded. Default: ``True``

      Returns:
        int: The previous memory limit in bytes.
      )pbdoc");
  metal.def(
      "set_cache_limit",
      &metal::set_cache_limit,
      "limit"_a,
      R"pbdoc(
      Set the free cache limit.

      If using more than the given limit, free memory will be reclaimed
      from the cache on the next allocation. To disable the cache, set
      the limit to ``0``.

      The cache limit defaults to the memory limit. See
      :func:`set_memory_limit` for more details.

      Args:
        limit (int): The cache limit in bytes.

      Returns:
        int: The previous cache limit in bytes.
      )pbdoc");
}

View File

@ -0,0 +1,45 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
    """Tests for the mlx.core.metal memory-info and limit APIs."""

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        # Disabling the cache (limit 0) should keep it empty after frees.
        old_limit = mx.metal.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.metal.get_cache_memory(), 0)
        # set_cache_limit returns the previously installed limit.
        self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)

        # set_memory_limit also returns the previous limit.
        # BUG FIX: these used assertTrue(x, y), where y is interpreted as
        # the failure *message* — the assertion passed for any truthy x.
        old_limit = mx.metal.set_memory_limit(10)
        self.assertEqual(mx.metal.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.metal.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory.
        a = mx.zeros((4096,))
        mx.eval(a)
        active_mem = mx.metal.get_active_memory()
        self.assertTrue(active_mem >= 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b

        # b was freed, so active memory returns to its previous level...
        new_active_mem = mx.metal.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        # ...but the peak saw both arrays live at once,
        peak_mem = mx.metal.get_peak_memory()
        self.assertTrue(peak_mem >= 4096 * 8)
        # ...and b's buffer went into the cache.
        cache_mem = mx.metal.get_cache_memory()
        self.assertTrue(cache_mem >= 4096 * 4)
# Allow invoking this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <array>
#include "doctest/doctest.h"
@ -473,41 +473,42 @@ TEST_CASE("test metal validation") {
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}
TEST_CASE("test metal enable/disable cache") {
// Test enable metal cache
TEST_CASE("test metal memory info") {
// Test cache limits
{
metal::set_cache_enabled(true);
CHECK(metal::cache_enabled());
auto& a = metal::allocator();
auto size = 100;
auto buf = a.malloc(size, false);
// Release a
a.free(buf);
// Check size should equals to size
CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
}
// Test disable metal cache
auto old_limit = metal::set_cache_limit(0);
{
metal::set_cache_enabled(false);
CHECK(!metal::cache_enabled());
auto a = zeros({4096});
eval(a);
}
CHECK_EQ(metal::get_cache_memory(), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
}
auto& a = metal::allocator();
auto size = 100;
auto buf = a.malloc(size, false);
auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());
unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
// Test memory limits
{
auto old_limit = metal::set_memory_limit(10);
CHECK_EQ(metal::set_memory_limit(old_limit), 10);
CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
}
// Release a
a.free(buf);
// Query active and peak memory
{
auto a = zeros({4096});
eval(a);
auto active_mem = metal::get_active_memory();
CHECK(active_mem >= 4096 * 4);
{
auto b = zeros({4096});
eval(b);
}
auto new_active_mem = metal::get_active_memory();
CHECK_EQ(new_active_mem, active_mem);
auto peak_mem = metal::get_peak_memory();
CHECK(peak_mem >= 4096 * 8);
// If release successfully, the first byte should be different from the
// first byte before release
unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
CHECK_NE(new_first_byte, first_byte);
auto cache_mem = metal::get_cache_memory();
CHECK(cache_mem >= 4096 * 4);
}
}