From 2a980a76ce5363f8b80da376436947d2691e9c8d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 21 Mar 2025 12:28:36 -0700 Subject: [PATCH] Add stats and limit to common allocator and enable tests (#1988) * add stats to common allocator and enable tests * linux memory and default * fix --- mlx/allocator.cpp | 20 ------ mlx/allocator.h | 12 ---- mlx/backend/no_metal/allocator.cpp | 101 ++++++++++++++++++++++++--- mlx/backend/no_metal/apple_memory.h | 16 +++++ mlx/backend/no_metal/linux_memory.h | 22 ++++++ python/src/metal.cpp | 2 +- python/tests/mpi_test_distributed.py | 6 +- python/tests/test_array.py | 1 - python/tests/test_autograd.py | 14 +--- python/tests/test_eval.py | 8 +-- python/tests/test_load.py | 2 +- python/tests/test_memory.py | 9 ++- python/tests/test_ops.py | 6 +- 13 files changed, 151 insertions(+), 68 deletions(-) create mode 100644 mlx/backend/no_metal/apple_memory.h create mode 100644 mlx/backend/no_metal/linux_memory.h diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index 2d97a6db3..fbfca7551 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -4,7 +4,6 @@ #include #include "mlx/allocator.h" -#include "mlx/scheduler.h" namespace mlx::core::allocator { @@ -22,23 +21,4 @@ void free(Buffer buffer) { allocator().free(buffer); } -Buffer CommonAllocator::malloc(size_t size) { - void* ptr = std::malloc(size + sizeof(size_t)); - if (ptr != nullptr) { - *static_cast(ptr) = size; - } - return Buffer{ptr}; -} - -void CommonAllocator::free(Buffer buffer) { - std::free(buffer.ptr()); -} - -size_t CommonAllocator::size(Buffer buffer) const { - if (buffer.ptr() == nullptr) { - return 0; - } - return *static_cast(buffer.ptr()); -} - } // namespace mlx::core::allocator diff --git a/mlx/allocator.h b/mlx/allocator.h index d4e3e1d6e..362f4f08a 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -49,16 +49,4 @@ class Allocator { Allocator& allocator(); -class CommonAllocator : public Allocator { - /** A general CPU allocator. */ - public: - virtual Buffer malloc(size_t size) override; - virtual void free(Buffer buffer) override; - virtual size_t size(Buffer buffer) const override; - - private: - CommonAllocator() = default; - friend Allocator& allocator(); -}; - } // namespace mlx::core::allocator diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_metal/allocator.cpp index 750bcc539..b73a53484 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_metal/allocator.cpp @@ -1,38 +1,123 @@ // Copyright © 2023 Apple Inc. +#include + #include "mlx/allocator.h" +#ifdef __APPLE__ +#include "mlx/backend/no_metal/apple_memory.h" +#elif defined(__linux__) +#include "mlx/backend/no_metal/linux_memory.h" +#else +size_t get_memory_size() { + return 0; +} +#endif + namespace mlx::core { namespace allocator { -Allocator& allocator() { +class CommonAllocator : public Allocator { + /** A general CPU allocator. */ + public: + virtual Buffer malloc(size_t size) override; + virtual void free(Buffer buffer) override; + virtual size_t size(Buffer buffer) const override; + size_t get_active_memory() const { + return active_memory_; + }; + size_t get_peak_memory() const { + return peak_memory_; + }; + void reset_peak_memory() { + std::unique_lock lk(mutex_); + peak_memory_ = 0; + }; + size_t get_memory_limit() { + return memory_limit_; + } + size_t set_memory_limit(size_t limit) { + std::unique_lock lk(mutex_); + std::swap(memory_limit_, limit); + return limit; + } + + private: + size_t memory_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + std::mutex mutex_; + CommonAllocator() : memory_limit_(0.8 * get_memory_size()) { + if (memory_limit_ == 0) { + memory_limit_ = 1UL << 33; + } + }; + + friend CommonAllocator& common_allocator(); +}; + +CommonAllocator& common_allocator() { static CommonAllocator allocator_; return allocator_; } +Allocator& allocator() { + return common_allocator(); +} + void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } return static_cast(ptr_) + 1; } + +Buffer CommonAllocator::malloc(size_t size) { + void* ptr = std::malloc(size + sizeof(size_t)); + if (ptr != nullptr) { + *static_cast(ptr) = size; + } + std::unique_lock lk(mutex_); + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{ptr}; +} + +void CommonAllocator::free(Buffer buffer) { + auto sz = size(buffer); + std::free(buffer.ptr()); + std::unique_lock lk(mutex_); + active_memory_ -= sz; +} + +size_t CommonAllocator::size(Buffer buffer) const { + if (buffer.ptr() == nullptr) { + return 0; + } + return *static_cast(buffer.ptr()); +} + } // namespace allocator size_t get_active_memory() { - return 0; + return allocator::common_allocator().get_active_memory(); } size_t get_peak_memory() { - return 0; + return allocator::common_allocator().get_peak_memory(); } -void reset_peak_memory() {} -size_t get_cache_memory() { - return 0; +void reset_peak_memory() { + return allocator::common_allocator().reset_peak_memory(); } -size_t set_memory_limit(size_t) { - return 0; +size_t set_memory_limit(size_t limit) { + return allocator::common_allocator().set_memory_limit(limit); } size_t get_memory_limit() { + return allocator::common_allocator().get_memory_limit(); +} + +// No-ops for common allocator +size_t get_cache_memory() { return 0; } size_t set_cache_limit(size_t) { diff --git a/mlx/backend/no_metal/apple_memory.h b/mlx/backend/no_metal/apple_memory.h new file mode 100644 index 000000000..7fdc53014 --- /dev/null +++ b/mlx/backend/no_metal/apple_memory.h @@ -0,0 +1,16 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + return memsize; +} + +} // namespace diff --git a/mlx/backend/no_metal/linux_memory.h b/mlx/backend/no_metal/linux_memory.h new file mode 100644 index 000000000..f909edcd7 --- /dev/null +++ b/mlx/backend/no_metal/linux_memory.h @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + struct sysinfo info; + + if (sysinfo(&info) != 0) { + return 0; + } + + size_t total_ram = info.totalram; + total_ram *= info.mem_unit; + + return total_ram; +} + +} // namespace diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 09c69687c..a13dd2a03 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -36,7 +36,7 @@ void init_metal(nb::module_& m) { }); metal.def("get_peak_memory", []() { DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory"); - return mx::get_active_memory(); + return mx::get_peak_memory(); }); metal.def("reset_peak_memory", []() { DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory"); diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index f2c1c25b1..a6467568e 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -177,17 +177,17 @@ class TestDistributed(mlx_tests.MLXTestCase): def test_donation(self): x = mx.random.normal((1024,)) mx.eval(x) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() mx.reset_peak_memory() scale = mx.array(2.0) y = mx.distributed.all_sum(x) mx.eval(y) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() all_sum_only = mx.get_peak_memory() y = mx.distributed.all_sum(x) * scale mx.eval(y) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() all_sum_with_binary = mx.get_peak_memory() self.assertEqual(all_sum_only, all_sum_with_binary) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index c6ecde8cb..f22d7ced0 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1803,7 +1803,6 @@ class TestArray(mlx_tests.MLXTestCase): b = pickle.loads(pickle.dumps(a)) self.assertTrue(mx.array_equal(mx.array(a), mx.array(b))) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_multi_output_leak(self): def fun(): a = mx.zeros((2**20)) diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 82513a825..ec9d957ea 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -745,11 +745,8 @@ class TestAutograd(mlx_tests.MLXTestCase): mx.custom_function, mx.checkpoint, ]: - if mx.metal.is_available(): - mx.synchronize(mx.default_stream(mx.default_device())) - mem_pre = mx.get_active_memory() - else: - mem_pre = 0 + mx.synchronize() + mem_pre = mx.get_active_memory() def outer(): d = {} @@ -763,12 +760,7 @@ class TestAutograd(mlx_tests.MLXTestCase): for _ in range(5): outer() gc.collect() - - if mx.metal.is_available(): - mem_post = mx.get_active_memory() - else: - mem_post = 0 - + mem_post = mx.get_active_memory() self.assertEqual(mem_pre, mem_post) def test_grad_with_copies(self): diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 1b0a7a268..fcd424343 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -117,10 +117,9 @@ class TestEval(mlx_tests.MLXTestCase): out = mx.vjp(fn, (x,), (y,)) out = mx.vjp(fn, (x,), (y,)) - if mx.metal.is_available(): - peak_mem = mx.get_peak_memory() - out = mx.vjp(fn, (x,), (y,)) - self.assertEqual(peak_mem, mx.get_peak_memory()) + peak_mem = mx.get_peak_memory() + out = mx.vjp(fn, (x,), (y,)) + self.assertEqual(peak_mem, mx.get_peak_memory()) def test_async_eval_with_multiple_streams(self): x = mx.array([1.0]) @@ -137,7 +136,6 @@ class TestEval(mlx_tests.MLXTestCase): mx.async_eval(x) mx.eval(a + b) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_donation_for_noops(self): def fun(x): s = x.shape diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 67c3f4768..341564dae 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -385,7 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase): mx.eval(x) save_file = os.path.join(self.test_dir, "donation.npy") mx.save(save_file, x) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() mx.reset_peak_memory() scale = mx.array(2.0) diff --git a/python/tests/test_memory.py b/python/tests/test_memory.py index cf7e8d1ce..7343bdc91 100644 --- a/python/tests/test_memory.py +++ b/python/tests/test_memory.py @@ -7,7 +7,6 @@ import mlx_tests class TestMemory(mlx_tests.MLXTestCase): - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_memory_info(self): old_limit = mx.set_cache_limit(0) @@ -38,8 +37,10 @@ class TestMemory(mlx_tests.MLXTestCase): self.assertEqual(new_active_mem, active_mem) peak_mem = mx.get_peak_memory() self.assertTrue(peak_mem >= 4096 * 8) - cache_mem = mx.get_cache_memory() - self.assertTrue(cache_mem >= 4096 * 4) + + if mx.metal.is_available(): + cache_mem = mx.get_cache_memory() + self.assertTrue(cache_mem >= 4096 * 4) mx.clear_cache() self.assertEqual(mx.get_cache_memory(), 0) @@ -47,6 +48,8 @@ class TestMemory(mlx_tests.MLXTestCase): mx.reset_peak_memory() self.assertEqual(mx.get_peak_memory(), 0) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_wired_memory(self): old_limit = mx.set_wired_limit(1000) old_limit = mx.set_wired_limit(0) self.assertEqual(old_limit, 1000) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2ba098f7b..302c017a0 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1901,12 +1901,12 @@ class TestOps(mlx_tests.MLXTestCase): x = mx.cumsum(x) return x - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() mx.eval(fn(2)) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() mem2 = mx.get_peak_memory() mx.eval(fn(4)) - mx.synchronize(mx.default_stream(mx.default_device())) + mx.synchronize() mem4 = mx.get_peak_memory() self.assertEqual(mem2, mem4)