From a749a91c755d07c6eb174a382a7fa0aa20354b9a Mon Sep 17 00:00:00 2001
From: Ethan
Date: Fri, 19 Jan 2024 00:33:34 +0800
Subject: [PATCH] Support disabling the Metal buffer cache to prevent
 performance degradation caused by large memory caching (#390)

* Support disabling the Metal buffer cache, since large amounts of unused
  memory stay buffered when an LLM generates long-context tokens
* Run the formatter and add "cache_enabled" feature tests
---
 mlx/backend/metal/allocator.cpp | 16 ++++++++++++-
 mlx/backend/metal/metal.h       |  3 +++
 python/src/metal.cpp            |  8 ++++++++
 tests/metal_tests.cpp           | 41 +++++++++++++++++++++++++++++++++
 4 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 3e856ac15..cab27b715 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -23,6 +23,16 @@ void* Buffer::raw_ptr() {
 
 namespace metal {
 
+static bool cache_enabled_ = true;
+
+bool cache_enabled() {
+  return cache_enabled_;
+}
+
+void set_cache_enabled(bool enabled) {
+  cache_enabled_ = enabled;
+}
+
 namespace {
 
 BufferCache::BufferCache(MTL::Device* device)
@@ -196,7 +206,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
 
 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
-  buffer_cache_.recycle_to_cache(buf);
+  if (cache_enabled()) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    buf->release();
+  }
 }
 
 MetalAllocator& allocator() {
diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h
index ad12ef467..d249daac0 100644
--- a/mlx/backend/metal/metal.h
+++ b/mlx/backend/metal/metal.h
@@ -19,6 +19,9 @@ constexpr bool is_available() {
 #endif
 }
 
+bool cache_enabled(void);
+void set_cache_enabled(bool enabled);
+
 void new_stream(Stream stream);
 std::shared_ptr<void> new_scoped_memory_pool();
 
diff --git a/python/src/metal.cpp b/python/src/metal.cpp
index 41b90d974..5331c8870 100644
--- a/python/src/metal.cpp
+++ b/python/src/metal.cpp
@@ -11,4 +11,12 @@ using namespace mlx::core;
 void init_metal(py::module_& m) {
   py::module_ metal = m.def_submodule("metal", "mlx.metal");
   metal.def("is_available", &metal::is_available);
+  metal.def(
+      "cache_enabled",
+      &metal::cache_enabled,
+      "Check whether the Metal buffer cache is enabled (default: true).");
+  metal.def(
+      "set_cache_enabled",
+      &metal::set_cache_enabled,
+      "Enable or disable the Metal buffer cache.");
 }
diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp
index 38cf47a45..9e6264d7c 100644
--- a/tests/metal_tests.cpp
+++ b/tests/metal_tests.cpp
@@ -5,6 +5,7 @@
 
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/metal.h"
+#include "mlx/backend/metal/allocator.h"
 #include "mlx/mlx.h"
 
 using namespace mlx::core;
@@ -471,3 +472,43 @@ TEST_CASE("test metal validation") {
 
   eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
 }
+
+TEST_CASE("test metal enable/disable cache") {
+  // Test with the Metal buffer cache enabled
+  {
+    metal::set_cache_enabled(true);
+    CHECK(metal::cache_enabled());
+
+    auto& a = metal::allocator();
+    auto size = 100;
+    auto buf = a.malloc(size, false);
+
+    // Release the buffer; it is recycled into the cache
+    a.free(buf);
+
+    // The cached buffer is still alive and keeps the requested size
+    CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
+  }
+
+  // Test with the Metal buffer cache disabled
+  {
+    metal::set_cache_enabled(false);
+    CHECK(!metal::cache_enabled());
+
+    auto& a = metal::allocator();
+    auto size = 100;
+    auto buf = a.malloc(size, false);
+    auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());
+    unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
+    printf("first byte: %d\n", first_byte);
+
+    // Release the buffer; with the cache disabled it is released immediately
+    a.free(buf);
+
+    // If the release succeeded, the first byte should differ from the byte read before the release
+    unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
+    printf("new first byte: %d\n", new_first_byte);
+
+    CHECK_NE(new_first_byte, first_byte);
+  }
+}
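A minimal usage sketch of the new bindings from Python, assuming init_metal is registered under mlx.core so the submodule is reachable as mx.metal (as with the existing is_available binding); the generation step shown in the comment is a hypothetical placeholder, not part of this patch:

    import mlx.core as mx

    # Turn the Metal buffer cache off before a long generation run so that
    # freed buffers are released immediately instead of accumulating in the cache.
    mx.metal.set_cache_enabled(False)
    assert not mx.metal.cache_enabled()

    # ... run long-context generation here ...

    # Restore caching afterwards to regain allocation reuse for normal workloads.
    mx.metal.set_cache_enabled(True)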