From 4e1994e9d702dc208afda4f608f4d1d3b53ae5f5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 21 Mar 2025 07:25:12 -0700 Subject: [PATCH] move memory APIs into top level mlx.core (#1982) --- docs/src/index.rst | 1 + docs/src/python/memory.rst | 16 +++ mlx/backend/metal/allocator.cpp | 27 ++--- mlx/backend/metal/metal.h | 68 ------------ mlx/backend/no_metal/allocator.cpp | 31 +++++- mlx/backend/no_metal/metal.cpp | 25 ----- mlx/memory.h | 78 ++++++++++++++ mlx/mlx.h | 1 + mlx/transforms.cpp | 6 +- python/src/CMakeLists.txt | 1 + python/src/memory.cpp | 125 +++++++++++++++++++++ python/src/metal.cpp | 156 ++++++++------------------- python/src/mlx.cpp | 2 + python/tests/mpi_test_distributed.py | 6 +- python/tests/test_array.py | 8 +- python/tests/test_autograd.py | 4 +- python/tests/test_compile.py | 4 +- python/tests/test_eval.py | 16 +-- python/tests/test_export_import.py | 4 +- python/tests/test_load.py | 6 +- python/tests/test_memory.py | 60 +++++++++++ python/tests/test_metal.py | 60 ----------- python/tests/test_ops.py | 4 +- python/tests/test_vmap.py | 4 +- tests/metal_tests.cpp | 28 ++--- 25 files changed, 418 insertions(+), 323 deletions(-) create mode 100644 docs/src/python/memory.rst create mode 100644 mlx/memory.h create mode 100644 python/src/memory.cpp create mode 100644 python/tests/test_memory.py delete mode 100644 python/tests/test_metal.py diff --git a/docs/src/index.rst b/docs/src/index.rst index 075861e88..e216ed5ce 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -70,6 +70,7 @@ are the CPU and GPU. python/fft python/linalg python/metal + python/memory python/nn python/optimizers python/distributed diff --git a/docs/src/python/memory.rst b/docs/src/python/memory.rst new file mode 100644 index 000000000..f708efbfd --- /dev/null +++ b/docs/src/python/memory.rst @@ -0,0 +1,16 @@ +Memory Management +================= + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + get_active_memory + get_peak_memory + reset_peak_memory + get_cache_memory + set_memory_limit + set_cache_limit + set_wired_limit + clear_cache diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index d7b84a165..0eec44bfa 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/resident.h" +#include "mlx/memory.h" #include #include @@ -323,40 +324,40 @@ MetalAllocator& allocator() { return *allocator_; } +} // namespace metal + size_t set_cache_limit(size_t limit) { - return allocator().set_cache_limit(limit); + return metal::allocator().set_cache_limit(limit); } size_t set_memory_limit(size_t limit) { - return allocator().set_memory_limit(limit); + return metal::allocator().set_memory_limit(limit); } size_t get_memory_limit() { - return allocator().get_memory_limit(); + return metal::allocator().get_memory_limit(); } size_t set_wired_limit(size_t limit) { - if (limit > - std::get(device_info().at("max_recommended_working_set_size"))) { + if (limit > std::get(metal::device_info().at( + "max_recommended_working_set_size"))) { throw std::invalid_argument( "[metal::set_wired_limit] Setting a wired limit larger than " "the maximum working set size is not allowed."); } - return allocator().set_wired_limit(limit); + return metal::allocator().set_wired_limit(limit); } size_t get_active_memory() { - return allocator().get_active_memory(); + return metal::allocator().get_active_memory(); } size_t get_peak_memory() { - return allocator().get_peak_memory(); + return metal::allocator().get_peak_memory(); } void reset_peak_memory() { - allocator().reset_peak_memory(); + metal::allocator().reset_peak_memory(); } size_t get_cache_memory() { - return allocator().get_cache_memory(); + return metal::allocator().get_cache_memory(); } void clear_cache() { - return allocator().clear_cache(); + return metal::allocator().clear_cache(); } -} // namespace metal - } // namespace mlx::core diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 82151c538..d162007d1 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -12,74 +12,6 @@ namespace mlx::core::metal { /* Check if the Metal backend is available. */ bool is_available(); -/* Get the actively used memory in bytes. - * - * Note, this will not always match memory use reported by the system because - * it does not include cached memory buffers. - * */ -size_t get_active_memory(); - -/* Get the peak amount of used memory in bytes. - * - * The maximum memory used recorded from the beginning of the program - * execution or since the last call to reset_peak_memory. - * */ -size_t get_peak_memory(); - -/* Reset the peak memory to zero. - * */ -void reset_peak_memory(); - -/* Get the cache size in bytes. - * - * The cache includes memory not currently used that has not been returned - * to the system allocator. - * */ -size_t get_cache_memory(); - -/* Set the memory limit. - * The memory limit is a guideline for the maximum amount of memory to use - * during graph evaluation. If the memory limit is exceeded and there is no - * more RAM (including swap when available) allocations will result in an - * exception. - * - * When metal is available the memory limit defaults to 1.5 times the maximum - * recommended working set size reported by the device. - * - * Returns the previous memory limit. - * */ -size_t set_memory_limit(size_t limit); - -/* Get the current memory limit. */ -size_t get_memory_limit(); - -/* Set the free cache limit. - * If using more than the given limit, free memory will be reclaimed - * from the cache on the next allocation. To disable the cache, - * set the limit to 0. - * - * The cache limit defaults to the memory limit. - * - * Returns the previous cache limit. - * */ -size_t set_cache_limit(size_t limit); - -/* Clear the memory cache. */ -void clear_cache(); - -/* Set the wired size limit. - * - * Note, this function is only useful for macOS 15.0 or higher. - * - * The wired limit is the total size in bytes of memory that will be kept - * resident. The default value is ``0``. - * - * Setting a wired limit larger than system wired limit is an error. - * - * Returns the previous wired limit. - * */ -size_t set_wired_limit(size_t limit); - /** Capture a GPU trace, saving it to an absolute file `path` */ void start_capture(std::string path = ""); void stop_capture(); diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_metal/allocator.cpp index 0429ea53a..750bcc539 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_metal/allocator.cpp @@ -2,7 +2,9 @@ #include "mlx/allocator.h" -namespace mlx::core::allocator { +namespace mlx::core { + +namespace allocator { Allocator& allocator() { static CommonAllocator allocator_; @@ -15,5 +17,30 @@ void* Buffer::raw_ptr() { } return static_cast(ptr_) + 1; } +} // namespace allocator -} // namespace mlx::core::allocator +size_t get_active_memory() { + return 0; +} +size_t get_peak_memory() { + return 0; +} +void reset_peak_memory() {} +size_t get_cache_memory() { + return 0; +} +size_t set_memory_limit(size_t) { + return 0; +} +size_t get_memory_limit() { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} +void clear_cache() {} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 03c68c734..ef9af8800 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -31,33 +31,8 @@ void synchronize(Stream) { "[metal::synchronize] Cannot synchronize GPU without metal backend"); } -// No-ops when Metal is not available. -size_t get_active_memory() { - return 0; -} -size_t get_peak_memory() { - return 0; -} -void reset_peak_memory() {} -size_t get_cache_memory() { - return 0; -} -size_t set_memory_limit(size_t) { - return 0; -} -size_t get_memory_limit() { - return 0; -} -size_t set_cache_limit(size_t) { - return 0; -} -size_t set_wired_limit(size_t) { - return 0; -} - void start_capture(std::string) {} void stop_capture() {} -void clear_cache() {} const std::unordered_map>& device_info() { diff --git a/mlx/memory.h b/mlx/memory.h new file mode 100644 index 000000000..8a264734c --- /dev/null +++ b/mlx/memory.h @@ -0,0 +1,78 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +/* Get the actively used memory in bytes. + * + * Note, this will not always match memory use reported by the system because + * it does not include cached memory buffers. + * */ +size_t get_active_memory(); + +/* Get the peak amount of used memory in bytes. + * + * The maximum memory used recorded from the beginning of the program + * execution or since the last call to reset_peak_memory. + * */ +size_t get_peak_memory(); + +/* Reset the peak memory to zero. + * */ +void reset_peak_memory(); + +/* Get the cache size in bytes. + * + * The cache includes memory not currently used that has not been returned + * to the system allocator. + * */ +size_t get_cache_memory(); + +/* Set the memory limit. + * The memory limit is a guideline for the maximum amount of memory to use + * during graph evaluation. If the memory limit is exceeded and there is no + * more RAM (including swap when available) allocations will result in an + * exception. + * + * When Metal is available the memory limit defaults to 1.5 times the maximum + * recommended working set size reported by the device. + * + * Returns the previous memory limit. + * */ +size_t set_memory_limit(size_t limit); + +/* Get the current memory limit. */ +size_t get_memory_limit(); + +/* Set the cache limit. + * If using more than the given limit, free memory will be reclaimed + * from the cache on the next allocation. To disable the cache, + * set the limit to 0. + * + * The cache limit defaults to the memory limit. + * + * Returns the previous cache limit. + * */ +size_t set_cache_limit(size_t limit); + +/* Clear the memory cache. */ +void clear_cache(); + +/* Set the wired size limit. + * + * Note, this function is only useful when using the Metal backend with + * macOS 15.0 or higher. + * + * The wired limit is the total size in bytes of memory that will be kept + * resident. The default value is ``0``. + * + * Setting a wired limit larger than system wired limit is an error. + * + * Returns the previous wired limit. + * */ +size_t set_wired_limit(size_t limit); + +} // namespace mlx::core diff --git a/mlx/mlx.h b/mlx/mlx.h index 0fc657ca4..cef8d806d 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -14,6 +14,7 @@ #include "mlx/fft.h" #include "mlx/io.h" #include "mlx/linalg.h" +#include "mlx/memory.h" #include "mlx/ops.h" #include "mlx/random.h" #include "mlx/stream.h" diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 105a0fa28..54f3b302b 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -12,6 +12,7 @@ #include "mlx/backend/cpu/eval.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/fence.h" +#include "mlx/memory.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -219,7 +220,7 @@ array eval_impl(std::vector outputs, bool async) { } if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS || - (metal::get_active_memory() > metal::get_memory_limit() && + (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0)) { // Commit any open streams for (auto& [_, e] : events) { @@ -228,8 +229,7 @@ array eval_impl(std::vector outputs, bool async) { } } scheduler::wait_for_one(); - // TODO memory api should be moved out of metal - while (metal::get_active_memory() > metal::get_memory_limit() && + while (get_active_memory() > get_memory_limit() && scheduler::n_active_tasks() > 0) { scheduler::wait_for_one(); } diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index caaa478a3..7ea302cf9 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -17,6 +17,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp diff --git a/python/src/memory.cpp b/python/src/memory.cpp new file mode 100644 index 000000000..5ce9a765b --- /dev/null +++ b/python/src/memory.cpp @@ -0,0 +1,125 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/memory.h" +#include + +namespace mx = mlx::core; +namespace nb = nanobind; +using namespace nb::literals; + +void init_memory(nb::module_& m) { + m.def( + "get_active_memory", + &mx::get_active_memory, + R"pbdoc( + Get the actively used memory in bytes. + + Note, this will not always match memory use reported by the system because + it does not include cached memory buffers. + )pbdoc"); + m.def( + "get_peak_memory", + &mx::get_peak_memory, + R"pbdoc( + Get the peak amount of used memory in bytes. + + The maximum memory used recorded from the beginning of the program + execution or since the last call to :func:`reset_peak_memory`. + )pbdoc"); + m.def( + "reset_peak_memory", + &mx::reset_peak_memory, + R"pbdoc( + Reset the peak memory to zero. + )pbdoc"); + m.def( + "get_cache_memory", + &mx::get_cache_memory, + R"pbdoc( + Get the cache size in bytes. + + The cache includes memory not currently used that has not been returned + to the system allocator. + )pbdoc"); + m.def( + "set_memory_limit", + &mx::set_memory_limit, + "limit"_a, + R"pbdoc( + Set the memory limit. + + The memory limit is a guideline for the maximum amount of memory to use + during graph evaluation. If the memory limit is exceeded and there is no + more RAM (including swap when available) allocations will result in an + exception. + + When metal is available the memory limit defaults to 1.5 times the + maximum recommended working set size reported by the device. + + Args: + limit (int): Memory limit in bytes. + + Returns: + int: The previous memory limit in bytes. + )pbdoc"); + m.def( + "set_cache_limit", + &mx::set_cache_limit, + "limit"_a, + R"pbdoc( + Set the free cache limit. + + If using more than the given limit, free memory will be reclaimed + from the cache on the next allocation. To disable the cache, set + the limit to ``0``. + + The cache limit defaults to the memory limit. See + :func:`set_memory_limit` for more details. + + Args: + limit (int): The cache limit in bytes. + + Returns: + int: The previous cache limit in bytes. + )pbdoc"); + m.def( + "set_wired_limit", + &mx::set_wired_limit, + "limit"_a, + R"pbdoc( + Set the wired size limit. + + .. note:: + * This function is only useful on macOS 15.0 or higher. + * The wired limit should remain strictly less than the total + memory size. + + The wired limit is the total size in bytes of memory that will be kept + resident. The default value is ``0``. + + Setting a wired limit larger than system wired limit is an error. You can + increase the system wired limit with: + + .. code-block:: + + sudo sysctl iogpu.wired_limit_mb= + + Use :func:`device_info` to query the system wired limit + (``"max_recommended_working_set_size"``) and the total memory size + (``"memory_size"``). + + Args: + limit (int): The wired limit in bytes. + + Returns: + int: The previous wired limit in bytes. + )pbdoc"); + m.def( + "clear_cache", + &mx::clear_cache, + R"pbdoc( + Clear the memory cache. + + After calling this, :func:`get_cache_memory` should return ``0``. + )pbdoc"); +} diff --git a/python/src/metal.cpp b/python/src/metal.cpp index fef856dd9..09c69687c 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -1,17 +1,27 @@ // Copyright © 2023-2024 Apple Inc. +#include -#include "mlx/backend/metal/metal.h" #include #include #include #include #include #include +#include "mlx/backend/metal/metal.h" +#include "mlx/memory.h" namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; +bool DEPRECATE(const std::string& old_fn, const std::string new_fn) { + std::cerr << old_fn << " is deprecated and will be removed in a future " + << "version. Use " << new_fn << " instead." << std::endl; + return true; +} + +#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn) + void init_metal(nb::module_& m) { nb::module_ metal = m.def_submodule("metal", "mlx.metal"); metal.def( @@ -20,121 +30,47 @@ void init_metal(nb::module_& m) { R"pbdoc( Check if the Metal back-end is available. )pbdoc"); - metal.def( - "get_active_memory", - &mx::metal::get_active_memory, - R"pbdoc( - Get the actively used memory in bytes. - - Note, this will not always match memory use reported by the system because - it does not include cached memory buffers. - )pbdoc"); - metal.def( - "get_peak_memory", - &mx::metal::get_peak_memory, - R"pbdoc( - Get the peak amount of used memory in bytes. - - The maximum memory used recorded from the beginning of the program - execution or since the last call to :func:`reset_peak_memory`. - )pbdoc"); - metal.def( - "reset_peak_memory", - &mx::metal::reset_peak_memory, - R"pbdoc( - Reset the peak memory to zero. - )pbdoc"); - metal.def( - "get_cache_memory", - &mx::metal::get_cache_memory, - R"pbdoc( - Get the cache size in bytes. - - The cache includes memory not currently used that has not been returned - to the system allocator. - )pbdoc"); + metal.def("get_active_memory", []() { + DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory"); + return mx::get_active_memory(); + }); + metal.def("get_peak_memory", []() { + DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory"); + return mx::get_active_memory(); + }); + metal.def("reset_peak_memory", []() { + DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory"); + mx::reset_peak_memory(); + }); + metal.def("get_cache_memory", []() { + DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory"); + return mx::get_cache_memory(); + }); metal.def( "set_memory_limit", - &mx::metal::set_memory_limit, - "limit"_a, - R"pbdoc( - Set the memory limit. - - The memory limit is a guideline for the maximum amount of memory to use - during graph evaluation. If the memory limit is exceeded and there is no - more RAM (including swap when available) allocations will result in an - exception. - - When metal is available the memory limit defaults to 1.5 times the - maximum recommended working set size reported by the device. - - Args: - limit (int): Memory limit in bytes. - - Returns: - int: The previous memory limit in bytes. - )pbdoc"); + [](size_t limit) { + DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit"); + return mx::set_memory_limit(limit); + }, + "limit"_a); metal.def( "set_cache_limit", - &mx::metal::set_cache_limit, - "limit"_a, - R"pbdoc( - Set the free cache limit. - - If using more than the given limit, free memory will be reclaimed - from the cache on the next allocation. To disable the cache, set - the limit to ``0``. - - The cache limit defaults to the memory limit. See - :func:`set_memory_limit` for more details. - - Args: - limit (int): The cache limit in bytes. - - Returns: - int: The previous cache limit in bytes. - )pbdoc"); + [](size_t limit) { + DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit"); + return mx::set_cache_limit(limit); + }, + "limit"_a); metal.def( "set_wired_limit", - &mx::metal::set_wired_limit, - "limit"_a, - R"pbdoc( - Set the wired size limit. - - .. note:: - * This function is only useful on macOS 15.0 or higher. - * The wired limit should remain strictly less than the total - memory size. - - The wired limit is the total size in bytes of memory that will be kept - resident. The default value is ``0``. - - Setting a wired limit larger than system wired limit is an error. You can - increase the system wired limit with: - - .. code-block:: - - sudo sysctl iogpu.wired_limit_mb= - - Use :func:`device_info` to query the system wired limit - (``"max_recommended_working_set_size"``) and the total memory size - (``"memory_size"``). - - Args: - limit (int): The wired limit in bytes. - - Returns: - int: The previous wired limit in bytes. - )pbdoc"); - metal.def( - "clear_cache", - &mx::metal::clear_cache, - R"pbdoc( - Clear the memory cache. - - After calling this, :func:`get_cache_memory` should return ``0``. - )pbdoc"); - + [](size_t limit) { + DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit"); + return mx::set_wired_limit(limit); + }, + "limit"_a); + metal.def("clear_cache", []() { + DEPRECATE("mx.metal.clear_cache", "mx.clear_cache"); + mx::clear_cache(); + }); metal.def( "start_capture", &mx::metal::start_capture, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index ecf9a3a13..eaddecb26 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -12,6 +12,7 @@ void init_array(nb::module_&); void init_device(nb::module_&); void init_stream(nb::module_&); void init_metal(nb::module_&); +void init_memory(nb::module_&); void init_ops(nb::module_&); void init_transforms(nb::module_&); void init_random(nb::module_&); @@ -34,6 +35,7 @@ NB_MODULE(core, m) { init_stream(m); init_array(m); init_metal(m); + init_memory(m); init_ops(m); init_transforms(m); init_random(m); diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index 0d172cee4..f2c1c25b1 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase): mx.eval(x) mx.synchronize(mx.default_stream(mx.default_device())) - mx.metal.reset_peak_memory() + mx.reset_peak_memory() scale = mx.array(2.0) y = mx.distributed.all_sum(x) mx.eval(y) mx.synchronize(mx.default_stream(mx.default_device())) - all_sum_only = mx.metal.get_peak_memory() + all_sum_only = mx.get_peak_memory() y = mx.distributed.all_sum(x) * scale mx.eval(y) mx.synchronize(mx.default_stream(mx.default_device())) - all_sum_with_binary = mx.metal.get_peak_memory() + all_sum_with_binary = mx.get_peak_memory() self.assertEqual(all_sum_only, all_sum_with_binary) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index b8917b75c..c6ecde8cb 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1813,10 +1813,10 @@ class TestArray(mlx_tests.MLXTestCase): fun() mx.synchronize() - peak_1 = mx.metal.get_peak_memory() + peak_1 = mx.get_peak_memory() fun() mx.synchronize() - peak_2 = mx.metal.get_peak_memory() + peak_2 = mx.get_peak_memory() self.assertEqual(peak_1, peak_2) def fun(): @@ -1826,10 +1826,10 @@ class TestArray(mlx_tests.MLXTestCase): fun() mx.synchronize() - peak_1 = mx.metal.get_peak_memory() + peak_1 = mx.get_peak_memory() fun() mx.synchronize() - peak_2 = mx.metal.get_peak_memory() + peak_2 = mx.get_peak_memory() self.assertEqual(peak_1, peak_2) def test_add_numpy(self): diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 350b09837..82513a825 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -747,7 +747,7 @@ class TestAutograd(mlx_tests.MLXTestCase): ]: if mx.metal.is_available(): mx.synchronize(mx.default_stream(mx.default_device())) - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -765,7 +765,7 @@ class TestAutograd(mlx_tests.MLXTestCase): gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 8cf3b4e08..f5ce496cd 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase): def test_leaks(self): gc.collect() if mx.metal.is_available(): - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase): gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index ebcf64c7a..1b0a7a268 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase): out = mx.vjp(fn, (x,), (y,)) out = mx.vjp(fn, (x,), (y,)) if mx.metal.is_available(): - peak_mem = mx.metal.get_peak_memory() + peak_mem = mx.get_peak_memory() out = mx.vjp(fn, (x,), (y,)) - self.assertEqual(peak_mem, mx.metal.get_peak_memory()) + self.assertEqual(peak_mem, mx.get_peak_memory()) def test_async_eval_with_multiple_streams(self): x = mx.array([1.0]) @@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase): x = mx.zeros((4096, 4096)) mx.eval(x) - pre = mx.metal.get_peak_memory() + pre = mx.get_peak_memory() out = fun(x) del x mx.eval(out) - post = mx.metal.get_peak_memory() + post = mx.get_peak_memory() self.assertEqual(pre, post) def fun(x): @@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase): x = mx.zeros((4096 * 4096,)) mx.eval(x) - pre = mx.metal.get_peak_memory() + pre = mx.get_peak_memory() out = fun(x) del x mx.eval(out) - post = mx.metal.get_peak_memory() + post = mx.get_peak_memory() self.assertEqual(pre, post) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase): s1 = mx.default_stream(mx.gpu) s2 = mx.new_stream(mx.gpu) - old_limit = mx.metal.set_memory_limit(1000) + old_limit = mx.set_memory_limit(1000) x = mx.ones((512, 512), stream=s2) for _ in range(80): @@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase): y = mx.abs(x, stream=s2) z = mx.abs(y, stream=s2) mx.eval(z) - mx.metal.set_memory_limit(old_limit) + mx.set_memory_limit(old_limit) if __name__ == "__main__": diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index fd62a58f6..2b4b425ca 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") if mx.metal.is_available(): - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase): gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/python/tests/test_load.py b/python/tests/test_load.py index fbc67f3c2..67c3f4768 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -387,14 +387,14 @@ class TestLoad(mlx_tests.MLXTestCase): mx.save(save_file, x) mx.synchronize(mx.default_stream(mx.default_device())) - mx.metal.reset_peak_memory() + mx.reset_peak_memory() scale = mx.array(2.0) y = mx.load(save_file) mx.eval(y) - load_only = mx.metal.get_peak_memory() + load_only = mx.get_peak_memory() y = mx.load(save_file) * scale mx.eval(y) - load_with_binary = mx.metal.get_peak_memory() + load_with_binary = mx.get_peak_memory() self.assertEqual(load_only, load_with_binary) diff --git a/python/tests/test_memory.py b/python/tests/test_memory.py new file mode 100644 index 000000000..cf7e8d1ce --- /dev/null +++ b/python/tests/test_memory.py @@ -0,0 +1,60 @@ +# Copyright © 2023-2024 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests + + +class TestMemory(mlx_tests.MLXTestCase): + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_memory_info(self): + old_limit = mx.set_cache_limit(0) + + a = mx.zeros((4096,)) + mx.eval(a) + del a + self.assertEqual(mx.get_cache_memory(), 0) + self.assertEqual(mx.set_cache_limit(old_limit), 0) + self.assertEqual(mx.set_cache_limit(old_limit), old_limit) + + old_limit = mx.set_memory_limit(10) + self.assertTrue(mx.set_memory_limit(old_limit), 10) + self.assertTrue(mx.set_memory_limit(old_limit), old_limit) + + # Query active and peak memory + a = mx.zeros((4096,)) + mx.eval(a) + mx.synchronize() + active_mem = mx.get_active_memory() + self.assertTrue(active_mem >= 4096 * 4) + + b = mx.zeros((4096,)) + mx.eval(b) + del b + mx.synchronize() + + new_active_mem = mx.get_active_memory() + self.assertEqual(new_active_mem, active_mem) + peak_mem = mx.get_peak_memory() + self.assertTrue(peak_mem >= 4096 * 8) + cache_mem = mx.get_cache_memory() + self.assertTrue(cache_mem >= 4096 * 4) + + mx.clear_cache() + self.assertEqual(mx.get_cache_memory(), 0) + + mx.reset_peak_memory() + self.assertEqual(mx.get_peak_memory(), 0) + + old_limit = mx.set_wired_limit(1000) + old_limit = mx.set_wired_limit(0) + self.assertEqual(old_limit, 1000) + + max_size = mx.metal.device_info()["max_recommended_working_set_size"] + with self.assertRaises(ValueError): + mx.set_wired_limit(max_size + 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_metal.py b/python/tests/test_metal.py deleted file mode 100644 index 81cefabce..000000000 --- a/python/tests/test_metal.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import unittest - -import mlx.core as mx -import mlx_tests - - -class TestMetal(mlx_tests.MLXTestCase): - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") - def test_memory_info(self): - old_limit = mx.metal.set_cache_limit(0) - - a = mx.zeros((4096,)) - mx.eval(a) - del a - self.assertEqual(mx.metal.get_cache_memory(), 0) - self.assertEqual(mx.metal.set_cache_limit(old_limit), 0) - self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit) - - old_limit = mx.metal.set_memory_limit(10) - self.assertTrue(mx.metal.set_memory_limit(old_limit), 10) - self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit) - - # Query active and peak memory - a = mx.zeros((4096,)) - mx.eval(a) - mx.synchronize() - active_mem = mx.metal.get_active_memory() - self.assertTrue(active_mem >= 4096 * 4) - - b = mx.zeros((4096,)) - mx.eval(b) - del b - mx.synchronize() - - new_active_mem = mx.metal.get_active_memory() - self.assertEqual(new_active_mem, active_mem) - peak_mem = mx.metal.get_peak_memory() - self.assertTrue(peak_mem >= 4096 * 8) - cache_mem = mx.metal.get_cache_memory() - self.assertTrue(cache_mem >= 4096 * 4) - - mx.metal.clear_cache() - self.assertEqual(mx.metal.get_cache_memory(), 0) - - mx.metal.reset_peak_memory() - self.assertEqual(mx.metal.get_peak_memory(), 0) - - old_limit = mx.metal.set_wired_limit(1000) - old_limit = mx.metal.set_wired_limit(0) - self.assertEqual(old_limit, 1000) - - max_size = mx.metal.device_info()["max_recommended_working_set_size"] - with self.assertRaises(ValueError): - mx.metal.set_wired_limit(max_size + 10) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 8e1cd8efd..2ba098f7b 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1904,10 +1904,10 @@ class TestOps(mlx_tests.MLXTestCase): mx.synchronize(mx.default_stream(mx.default_device())) mx.eval(fn(2)) mx.synchronize(mx.default_stream(mx.default_device())) - mem2 = mx.metal.get_peak_memory() + mem2 = mx.get_peak_memory() mx.eval(fn(4)) mx.synchronize(mx.default_stream(mx.default_device())) - mem4 = mx.metal.get_peak_memory() + mem4 = mx.get_peak_memory() self.assertEqual(mem2, mem4) def test_squeeze_expand(self): diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index d1d4f0bd4..1a1ba23b3 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase): def test_leaks(self): if mx.metal.is_available(): - mem_pre = mx.metal.get_active_memory() + mem_pre = mx.get_active_memory() else: mem_pre = 0 @@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase): gc.collect() if mx.metal.is_available(): - mem_post = mx.metal.get_active_memory() + mem_post = mx.get_active_memory() else: mem_post = 0 diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp index 1185ea04f..7aabdf36d 100644 --- a/tests/metal_tests.cpp +++ b/tests/metal_tests.cpp @@ -473,24 +473,24 @@ TEST_CASE("test metal validation") { eval(scatter_max(array(1), {}, array(2), std::vector{})); } -TEST_CASE("test metal memory info") { +TEST_CASE("test memory info") { // Test cache limits { - auto old_limit = metal::set_cache_limit(0); + auto old_limit = set_cache_limit(0); { auto a = zeros({4096}); eval(a); } - CHECK_EQ(metal::get_cache_memory(), 0); - CHECK_EQ(metal::set_cache_limit(old_limit), 0); - CHECK_EQ(metal::set_cache_limit(old_limit), old_limit); + CHECK_EQ(get_cache_memory(), 0); + CHECK_EQ(set_cache_limit(old_limit), 0); + CHECK_EQ(set_cache_limit(old_limit), old_limit); } // Test memory limits { - auto old_limit = metal::set_memory_limit(10); - CHECK_EQ(metal::set_memory_limit(old_limit), 10); - CHECK_EQ(metal::set_memory_limit(old_limit), old_limit); + auto old_limit = set_memory_limit(10); + CHECK_EQ(set_memory_limit(old_limit), 10); + CHECK_EQ(set_memory_limit(old_limit), old_limit); } // Query active and peak memory @@ -498,22 +498,22 @@ TEST_CASE("test metal memory info") { auto a = zeros({4096}); eval(a); synchronize(); - auto active_mem = metal::get_active_memory(); + auto active_mem = get_active_memory(); CHECK(active_mem >= 4096 * 4); { auto b = zeros({4096}); eval(b); } synchronize(); - auto new_active_mem = metal::get_active_memory(); + auto new_active_mem = get_active_memory(); CHECK_EQ(new_active_mem, active_mem); - auto peak_mem = metal::get_peak_memory(); + auto peak_mem = get_peak_memory(); CHECK(peak_mem >= 4096 * 8); - auto cache_mem = metal::get_cache_memory(); + auto cache_mem = get_cache_memory(); CHECK(cache_mem >= 4096 * 4); } - metal::clear_cache(); - CHECK_EQ(metal::get_cache_memory(), 0); + clear_cache(); + CHECK_EQ(get_cache_memory(), 0); }