diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst index c92b18936..589ec0a82 100644 --- a/docs/src/python/metal.rst +++ b/docs/src/python/metal.rst @@ -12,5 +12,6 @@ Metal get_cache_memory set_memory_limit set_cache_limit + clear_cache start_capture stop_capture diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 1c855b855..7795512a0 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -92,6 +92,7 @@ Operations moveaxis multiply negative + not_equal ones ones_like outer diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0e7502744..cc4945d29 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -209,6 +209,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { return Buffer{static_cast(buf)}; } +void MetalAllocator::clear_cache() { + std::unique_lock lk(mutex_); + buffer_cache_.clear(); +} + void MetalAllocator::free(Buffer buffer) { auto buf = static_cast(buffer.ptr()); std::unique_lock lk(mutex_); @@ -242,6 +247,9 @@ size_t get_peak_memory() { size_t get_cache_memory() { return allocator().get_cache_memory(); } +void clear_cache() { + return allocator().clear_cache(); +} } // namespace metal diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 9f6c0ec9b..e83008bdc 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -26,6 +26,7 @@ class BufferCache { size_t cache_size() { return pool_size_; } + void clear(); private: struct BufferHolder { @@ -37,7 +38,6 @@ class BufferCache { MTL::Buffer* buf; }; - void clear(); void add_at_head(BufferHolder* to_add); void remove_from_list(BufferHolder* to_remove); @@ -67,6 +67,7 @@ class MetalAllocator : public allocator::Allocator { }; size_t set_cache_limit(size_t limit); size_t set_memory_limit(size_t limit, bool relaxed); + void clear_cache(); private: MTL::Device* device_; diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 86a47f37d..63e4bff5e 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -54,6 +54,9 @@ size_t set_memory_limit(size_t limit, bool relaxed = true); * */ size_t set_cache_limit(size_t limit); +/* Clear the memory cache. */ +void clear_cache(); + /** Capture a GPU trace, saving it to an absolute file `path` */ void start_capture(std::string path = ""); void stop_capture(); diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 0a0b635ae..d3c011397 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -4,7 +4,6 @@ #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal_impl.h" - namespace mlx::core::metal { bool is_available() { @@ -48,5 +47,6 @@ size_t set_cache_limit(size_t) { } void start_capture(std::string path) {} void stop_capture() {} +void clear_cache() {} } // namespace mlx::core::metal diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 6c7a27655..0c806cd3e 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -90,6 +90,15 @@ void init_metal(nb::module_& m) { Returns: int: The previous cache limit in bytes. )pbdoc"); + metal.def( + "clear_cache", + &metal::clear_cache, + R"pbdoc( + Clear the memory cache. + + After calling this, :func:`get_cache_memory` should return ``0``. + )pbdoc"); + metal.def( "start_capture", &metal::start_capture, diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py index 2b3b107b1..deef2d985 100644 --- a/python/tests/test_metal.py +++ b/python/tests/test_metal.py @@ -42,6 +42,9 @@ class TestMetal(mlx_tests.MLXTestCase): cache_mem = mx.metal.get_cache_memory() self.assertTrue(cache_mem >= 4096 * 4) + mx.metal.clear_cache() + self.assertEqual(mx.metal.get_cache_memory(), 0) + if __name__ == "__main__": unittest.main() diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp index 1ce50dcd2..1185ea04f 100644 --- a/tests/metal_tests.cpp +++ b/tests/metal_tests.cpp @@ -513,4 +513,7 @@ TEST_CASE("test metal memory info") { auto cache_mem = metal::get_cache_memory(); CHECK(cache_mem >= 4096 * 4); } + + metal::clear_cache(); + CHECK_EQ(metal::get_cache_memory(), 0); }