From 21623156a32910b9db7c913f91612dcde0282caf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 3 May 2024 17:12:51 -0700 Subject: [PATCH] Reset peak memory (#1074) * reset peak memory * fix linux * nits in docs --- docs/src/python/metal.rst | 1 + mlx/backend/metal/allocator.cpp | 3 +++ mlx/backend/metal/allocator.h | 4 ++++ mlx/backend/metal/metal.h | 8 ++++++-- mlx/backend/no_metal/metal.cpp | 1 + python/src/metal.cpp | 10 ++++++++-- python/tests/test_metal.py | 3 +++ 7 files changed, 26 insertions(+), 4 deletions(-) diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst index d333e09ca..cb2cdb38e 100644 --- a/docs/src/python/metal.rst +++ b/docs/src/python/metal.rst @@ -10,6 +10,7 @@ Metal device_info get_active_memory get_peak_memory + reset_peak_memory get_cache_memory set_memory_limit set_cache_limit diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index aad031cfa..857afbb83 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -258,6 +258,9 @@ size_t get_active_memory() { size_t get_peak_memory() { return allocator().get_peak_memory(); } +void reset_peak_memory() { + allocator().reset_peak_memory(); +} size_t get_cache_memory() { return allocator().get_cache_memory(); } diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index e83008bdc..8e34f48d2 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -62,6 +62,10 @@ class MetalAllocator : public allocator::Allocator { size_t get_peak_memory() { return peak_memory_; }; + void reset_peak_memory() { + std::unique_lock lk(mutex_); + peak_memory_ = 0; + }; size_t get_cache_memory() { return buffer_cache_.cache_size(); }; diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index c4d51dd17..c63ddda28 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -20,11 +20,15 @@ size_t get_active_memory(); /* Get the peak amount of used memory in bytes. * - * The maximum memory used is recorded from the beginning of the program - * execution. + * The maximum memory used recorded from the beginning of the program + * execution or since the last call to reset_peak_memory. * */ size_t get_peak_memory(); +/* Reset the peak memory to zero. + * */ +void reset_peak_memory(); + /* Get the cache size in bytes. * * The cache includes memory not currently used that has not been returned diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 0aeda5c9d..4cf5b00db 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -36,6 +36,7 @@ size_t get_active_memory() { size_t get_peak_memory() { return 0; } +void reset_peak_memory() {} size_t get_cache_memory() { return 0; } diff --git a/python/src/metal.cpp b/python/src/metal.cpp index fef2cc69a..4306b3915 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -35,8 +35,14 @@ void init_metal(nb::module_& m) { R"pbdoc( Get the peak amount of used memory in bytes. - The maximum memory used is recorded from the beginning of the program - execution. + The maximum memory used recorded from the beginning of the program + execution or since the last call to :func:`reset_peak_memory`. + )pbdoc"); + metal.def( + "reset_peak_memory", + &metal::reset_peak_memory, + R"pbdoc( + Reset the peak memory to zero. )pbdoc"); metal.def( "get_cache_memory", diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py index deef2d985..bd091f8e2 100644 --- a/python/tests/test_metal.py +++ b/python/tests/test_metal.py @@ -45,6 +45,9 @@ class TestMetal(mlx_tests.MLXTestCase): mx.metal.clear_cache() self.assertEqual(mx.metal.get_cache_memory(), 0) + mx.metal.reset_peak_memory() + self.assertEqual(mx.metal.get_peak_memory(), 0) + if __name__ == "__main__": unittest.main()