move memory APIs into top level mlx.core (#1982)

This commit is contained in:
Awni Hannun 2025-03-21 07:25:12 -07:00 committed by GitHub
parent 65a38c452b
commit 4e1994e9d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 418 additions and 323 deletions

View File

@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/memory
python/nn
python/optimizers
python/distributed

View File

@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@ -323,40 +324,40 @@ MetalAllocator& allocator() {
return *allocator_;
}
} // namespace metal
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
return metal::allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit) {
return allocator().set_memory_limit(limit);
return metal::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return allocator().get_memory_limit();
return metal::allocator().get_memory_limit();
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
if (limit > std::get<size_t>(metal::device_info().at(
"max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
return metal::allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
return metal::allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
return metal::allocator().get_peak_memory();
}
void reset_peak_memory() {
allocator().reset_peak_memory();
metal::allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
return metal::allocator().get_cache_memory();
}
void clear_cache() {
return allocator().clear_cache();
return metal::allocator().clear_cache();
}
} // namespace metal
} // namespace mlx::core

View File

@ -12,74 +12,6 @@ namespace mlx::core::metal {
/* Check if the Metal backend is available. */
bool is_available();
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* The memory limit is a guideline for the maximum amount of memory to use
* during graph evaluation. If the memory limit is exceeded and there is no
* more RAM (including swap when available) allocations will result in an
* exception.
*
* When metal is available the memory limit defaults to 1.5 times the maximum
* recommended working set size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit);
/* Get the current memory limit. */
size_t get_memory_limit();
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

View File

@ -2,7 +2,9 @@
#include "mlx/allocator.h"
namespace mlx::core::allocator {
namespace mlx::core {
namespace allocator {
Allocator& allocator() {
static CommonAllocator allocator_;
@ -15,5 +17,30 @@ void* Buffer::raw_ptr() {
}
return static_cast<size_t*>(ptr_) + 1;
}
} // namespace allocator
} // namespace mlx::core::allocator
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t) {
return 0;
}
size_t get_memory_limit() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void clear_cache() {}
} // namespace mlx::core

View File

@ -31,33 +31,8 @@ void synchronize(Stream) {
"[metal::synchronize] Cannot synchronize GPU without metal backend");
}
// No-ops when Metal is not available.
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t) {
return 0;
}
size_t get_memory_limit() {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}
void start_capture(std::string) {}
void stop_capture() {}
void clear_cache() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {

78
mlx/memory.h Normal file
View File

@ -0,0 +1,78 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cstdlib>
namespace mlx::core {
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used recorded from the beginning of the program
* execution or since the last call to reset_peak_memory.
* */
size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* The memory limit is a guideline for the maximum amount of memory to use
* during graph evaluation. If the memory limit is exceeded and there is no
* more RAM (including swap when available) allocations will result in an
* exception.
*
* When Metal is available the memory limit defaults to 1.5 times the maximum
* recommended working set size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit);
/* Get the current memory limit. */
size_t get_memory_limit();
/* Set the cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful when using the Metal backend with
* macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
} // namespace mlx::core

View File

@ -14,6 +14,7 @@
#include "mlx/fft.h"
#include "mlx/io.h"
#include "mlx/linalg.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "mlx/stream.h"

View File

@ -12,6 +12,7 @@
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/fence.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@ -219,7 +220,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
(metal::get_active_memory() > metal::get_memory_limit() &&
(get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0)) {
// Commit any open streams
for (auto& [_, e] : events) {
@ -228,8 +229,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}
scheduler::wait_for_one();
// TODO memory api should be moved out of metal
while (metal::get_active_memory() > metal::get_memory_limit() &&
while (get_active_memory() > get_memory_limit() &&
scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
}

View File

@ -17,6 +17,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp

125
python/src/memory.cpp Normal file
View File

@ -0,0 +1,125 @@
// Copyright © 2025 Apple Inc.
#include "mlx/memory.h"
#include <nanobind/nanobind.h>
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
void init_memory(nb::module_& m) {
m.def(
"get_active_memory",
&mx::get_active_memory,
R"pbdoc(
Get the actively used memory in bytes.
Note, this will not always match memory use reported by the system because
it does not include cached memory buffers.
)pbdoc");
m.def(
"get_peak_memory",
&mx::get_peak_memory,
R"pbdoc(
Get the peak amount of used memory in bytes.
The maximum memory used recorded from the beginning of the program
execution or since the last call to :func:`reset_peak_memory`.
)pbdoc");
m.def(
"reset_peak_memory",
&mx::reset_peak_memory,
R"pbdoc(
Reset the peak memory to zero.
)pbdoc");
m.def(
"get_cache_memory",
&mx::get_cache_memory,
R"pbdoc(
Get the cache size in bytes.
The cache includes memory not currently used that has not been returned
to the system allocator.
)pbdoc");
m.def(
"set_memory_limit",
&mx::set_memory_limit,
"limit"_a,
R"pbdoc(
Set the memory limit.
The memory limit is a guideline for the maximum amount of memory to use
during graph evaluation. If the memory limit is exceeded and there is no
more RAM (including swap when available) allocations will result in an
exception.
When metal is available the memory limit defaults to 1.5 times the
maximum recommended working set size reported by the device.
Args:
limit (int): Memory limit in bytes.
Returns:
int: The previous memory limit in bytes.
)pbdoc");
m.def(
"set_cache_limit",
&mx::set_cache_limit,
"limit"_a,
R"pbdoc(
Set the free cache limit.
If using more than the given limit, free memory will be reclaimed
from the cache on the next allocation. To disable the cache, set
the limit to ``0``.
The cache limit defaults to the memory limit. See
:func:`set_memory_limit` for more details.
Args:
limit (int): The cache limit in bytes.
Returns:
int: The previous cache limit in bytes.
)pbdoc");
m.def(
"set_wired_limit",
&mx::set_wired_limit,
"limit"_a,
R"pbdoc(
Set the wired size limit.
.. note::
* This function is only useful on macOS 15.0 or higher.
* The wired limit should remain strictly less than the total
memory size.
The wired limit is the total size in bytes of memory that will be kept
resident. The default value is ``0``.
Setting a wired limit larger than system wired limit is an error. You can
increase the system wired limit with:
.. code-block::
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
Use :func:`device_info` to query the system wired limit
(``"max_recommended_working_set_size"``) and the total memory size
(``"memory_size"``).
Args:
limit (int): The wired limit in bytes.
Returns:
int: The previous wired limit in bytes.
)pbdoc");
m.def(
"clear_cache",
&mx::clear_cache,
R"pbdoc(
Clear the memory cache.
After calling this, :func:`get_cache_memory` should return ``0``.
)pbdoc");
}

View File

@ -1,17 +1,27 @@
// Copyright © 2023-2024 Apple Inc.
#include <iostream>
#include "mlx/backend/metal/metal.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/backend/metal/metal.h"
#include "mlx/memory.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
bool DEPRECATE(const std::string& old_fn, const std::string new_fn) {
std::cerr << old_fn << " is deprecated and will be removed in a future "
<< "version. Use " << new_fn << " instead." << std::endl;
return true;
}
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
void init_metal(nb::module_& m) {
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def(
@ -20,121 +30,47 @@ void init_metal(nb::module_& m) {
R"pbdoc(
Check if the Metal back-end is available.
)pbdoc");
metal.def(
"get_active_memory",
&mx::metal::get_active_memory,
R"pbdoc(
Get the actively used memory in bytes.
Note, this will not always match memory use reported by the system because
it does not include cached memory buffers.
)pbdoc");
metal.def(
"get_peak_memory",
&mx::metal::get_peak_memory,
R"pbdoc(
Get the peak amount of used memory in bytes.
The maximum memory used recorded from the beginning of the program
execution or since the last call to :func:`reset_peak_memory`.
)pbdoc");
metal.def(
"reset_peak_memory",
&mx::metal::reset_peak_memory,
R"pbdoc(
Reset the peak memory to zero.
)pbdoc");
metal.def(
"get_cache_memory",
&mx::metal::get_cache_memory,
R"pbdoc(
Get the cache size in bytes.
The cache includes memory not currently used that has not been returned
to the system allocator.
)pbdoc");
metal.def("get_active_memory", []() {
DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory");
return mx::get_active_memory();
});
metal.def("get_peak_memory", []() {
DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
return mx::get_active_memory();
});
metal.def("reset_peak_memory", []() {
DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
mx::reset_peak_memory();
});
metal.def("get_cache_memory", []() {
DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory");
return mx::get_cache_memory();
});
metal.def(
"set_memory_limit",
&mx::metal::set_memory_limit,
"limit"_a,
R"pbdoc(
Set the memory limit.
The memory limit is a guideline for the maximum amount of memory to use
during graph evaluation. If the memory limit is exceeded and there is no
more RAM (including swap when available) allocations will result in an
exception.
When metal is available the memory limit defaults to 1.5 times the
maximum recommended working set size reported by the device.
Args:
limit (int): Memory limit in bytes.
Returns:
int: The previous memory limit in bytes.
)pbdoc");
[](size_t limit) {
DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit");
return mx::set_memory_limit(limit);
},
"limit"_a);
metal.def(
"set_cache_limit",
&mx::metal::set_cache_limit,
"limit"_a,
R"pbdoc(
Set the free cache limit.
If using more than the given limit, free memory will be reclaimed
from the cache on the next allocation. To disable the cache, set
the limit to ``0``.
The cache limit defaults to the memory limit. See
:func:`set_memory_limit` for more details.
Args:
limit (int): The cache limit in bytes.
Returns:
int: The previous cache limit in bytes.
)pbdoc");
[](size_t limit) {
DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit");
return mx::set_cache_limit(limit);
},
"limit"_a);
metal.def(
"set_wired_limit",
&mx::metal::set_wired_limit,
"limit"_a,
R"pbdoc(
Set the wired size limit.
.. note::
* This function is only useful on macOS 15.0 or higher.
* The wired limit should remain strictly less than the total
memory size.
The wired limit is the total size in bytes of memory that will be kept
resident. The default value is ``0``.
Setting a wired limit larger than system wired limit is an error. You can
increase the system wired limit with:
.. code-block::
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
Use :func:`device_info` to query the system wired limit
(``"max_recommended_working_set_size"``) and the total memory size
(``"memory_size"``).
Args:
limit (int): The wired limit in bytes.
Returns:
int: The previous wired limit in bytes.
)pbdoc");
metal.def(
"clear_cache",
&mx::metal::clear_cache,
R"pbdoc(
Clear the memory cache.
After calling this, :func:`get_cache_memory` should return ``0``.
)pbdoc");
[](size_t limit) {
DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit");
return mx::set_wired_limit(limit);
},
"limit"_a);
metal.def("clear_cache", []() {
DEPRECATE("mx.metal.clear_cache", "mx.clear_cache");
mx::clear_cache();
});
metal.def(
"start_capture",
&mx::metal::start_capture,

View File

@ -12,6 +12,7 @@ void init_array(nb::module_&);
void init_device(nb::module_&);
void init_stream(nb::module_&);
void init_metal(nb::module_&);
void init_memory(nb::module_&);
void init_ops(nb::module_&);
void init_transforms(nb::module_&);
void init_random(nb::module_&);
@ -34,6 +35,7 @@ NB_MODULE(core, m) {
init_stream(m);
init_array(m);
init_metal(m);
init_memory(m);
init_ops(m);
init_transforms(m);
init_random(m);

View File

@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase):
mx.eval(x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_only = mx.metal.get_peak_memory()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_with_binary = mx.metal.get_peak_memory()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)

View File

@ -1813,10 +1813,10 @@ class TestArray(mlx_tests.MLXTestCase):
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
peak_1 = mx.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
peak_2 = mx.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def fun():
@ -1826,10 +1826,10 @@ class TestArray(mlx_tests.MLXTestCase):
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
peak_1 = mx.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
peak_2 = mx.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def test_add_numpy(self):

View File

@ -747,7 +747,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
]:
if mx.metal.is_available():
mx.synchronize(mx.default_stream(mx.default_device()))
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@ -765,7 +765,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase):
def test_leaks(self):
gc.collect()
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase):
out = mx.vjp(fn, (x,), (y,))
out = mx.vjp(fn, (x,), (y,))
if mx.metal.is_available():
peak_mem = mx.metal.get_peak_memory()
peak_mem = mx.get_peak_memory()
out = mx.vjp(fn, (x,), (y,))
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
self.assertEqual(peak_mem, mx.get_peak_memory())
def test_async_eval_with_multiple_streams(self):
x = mx.array([1.0])
@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase):
x = mx.zeros((4096, 4096))
mx.eval(x)
pre = mx.metal.get_peak_memory()
pre = mx.get_peak_memory()
out = fun(x)
del x
mx.eval(out)
post = mx.metal.get_peak_memory()
post = mx.get_peak_memory()
self.assertEqual(pre, post)
def fun(x):
@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase):
x = mx.zeros((4096 * 4096,))
mx.eval(x)
pre = mx.metal.get_peak_memory()
pre = mx.get_peak_memory()
out = fun(x)
del x
mx.eval(out)
post = mx.metal.get_peak_memory()
post = mx.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
old_limit = mx.metal.set_memory_limit(1000)
old_limit = mx.set_memory_limit(1000)
x = mx.ones((512, 512), stream=s2)
for _ in range(80):
@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase):
y = mx.abs(x, stream=s2)
z = mx.abs(y, stream=s2)
mx.eval(z)
mx.metal.set_memory_limit(old_limit)
mx.set_memory_limit(old_limit)
if __name__ == "__main__":

View File

@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
def test_leaks(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@ -387,14 +387,14 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.save(save_file, x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.metal.reset_peak_memory()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.load(save_file)
mx.eval(y)
load_only = mx.metal.get_peak_memory()
load_only = mx.get_peak_memory()
y = mx.load(save_file) * scale
mx.eval(y)
load_with_binary = mx.metal.get_peak_memory()
load_with_binary = mx.get_peak_memory()
self.assertEqual(load_only, load_with_binary)

View File

@ -0,0 +1,60 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMemory(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.set_cache_limit(0)
a = mx.zeros((4096,))
mx.eval(a)
del a
self.assertEqual(mx.get_cache_memory(), 0)
self.assertEqual(mx.set_cache_limit(old_limit), 0)
self.assertEqual(mx.set_cache_limit(old_limit), old_limit)
old_limit = mx.set_memory_limit(10)
self.assertTrue(mx.set_memory_limit(old_limit), 10)
self.assertTrue(mx.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,))
mx.eval(a)
mx.synchronize()
active_mem = mx.get_active_memory()
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,))
mx.eval(b)
del b
mx.synchronize()
new_active_mem = mx.get_active_memory()
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
cache_mem = mx.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
mx.clear_cache()
self.assertEqual(mx.get_cache_memory(), 0)
mx.reset_peak_memory()
self.assertEqual(mx.get_peak_memory(), 0)
old_limit = mx.set_wired_limit(1000)
old_limit = mx.set_wired_limit(0)
self.assertEqual(old_limit, 1000)
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
with self.assertRaises(ValueError):
mx.set_wired_limit(max_size + 10)
if __name__ == "__main__":
unittest.main()

View File

@ -1,60 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.metal.set_cache_limit(0)
a = mx.zeros((4096,))
mx.eval(a)
del a
self.assertEqual(mx.metal.get_cache_memory(), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
old_limit = mx.metal.set_memory_limit(10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,))
mx.eval(a)
mx.synchronize()
active_mem = mx.metal.get_active_memory()
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,))
mx.eval(b)
del b
mx.synchronize()
new_active_mem = mx.metal.get_active_memory()
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.metal.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
cache_mem = mx.metal.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
mx.metal.clear_cache()
self.assertEqual(mx.metal.get_cache_memory(), 0)
mx.metal.reset_peak_memory()
self.assertEqual(mx.metal.get_peak_memory(), 0)
old_limit = mx.metal.set_wired_limit(1000)
old_limit = mx.metal.set_wired_limit(0)
self.assertEqual(old_limit, 1000)
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
with self.assertRaises(ValueError):
mx.metal.set_wired_limit(max_size + 10)
if __name__ == "__main__":
unittest.main()

View File

@ -1904,10 +1904,10 @@ class TestOps(mlx_tests.MLXTestCase):
mx.synchronize(mx.default_stream(mx.default_device()))
mx.eval(fn(2))
mx.synchronize(mx.default_stream(mx.default_device()))
mem2 = mx.metal.get_peak_memory()
mem2 = mx.get_peak_memory()
mx.eval(fn(4))
mx.synchronize(mx.default_stream(mx.default_device()))
mem4 = mx.metal.get_peak_memory()
mem4 = mx.get_peak_memory()
self.assertEqual(mem2, mem4)
def test_squeeze_expand(self):

View File

@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase):
def test_leaks(self):
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase):
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
mem_post = mx.get_active_memory()
else:
mem_post = 0

View File

@ -473,24 +473,24 @@ TEST_CASE("test metal validation") {
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}
TEST_CASE("test metal memory info") {
TEST_CASE("test memory info") {
// Test cache limits
{
auto old_limit = metal::set_cache_limit(0);
auto old_limit = set_cache_limit(0);
{
auto a = zeros({4096});
eval(a);
}
CHECK_EQ(metal::get_cache_memory(), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
CHECK_EQ(get_cache_memory(), 0);
CHECK_EQ(set_cache_limit(old_limit), 0);
CHECK_EQ(set_cache_limit(old_limit), old_limit);
}
// Test memory limits
{
auto old_limit = metal::set_memory_limit(10);
CHECK_EQ(metal::set_memory_limit(old_limit), 10);
CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
auto old_limit = set_memory_limit(10);
CHECK_EQ(set_memory_limit(old_limit), 10);
CHECK_EQ(set_memory_limit(old_limit), old_limit);
}
// Query active and peak memory
@ -498,22 +498,22 @@ TEST_CASE("test metal memory info") {
auto a = zeros({4096});
eval(a);
synchronize();
auto active_mem = metal::get_active_memory();
auto active_mem = get_active_memory();
CHECK(active_mem >= 4096 * 4);
{
auto b = zeros({4096});
eval(b);
}
synchronize();
auto new_active_mem = metal::get_active_memory();
auto new_active_mem = get_active_memory();
CHECK_EQ(new_active_mem, active_mem);
auto peak_mem = metal::get_peak_memory();
auto peak_mem = get_peak_memory();
CHECK(peak_mem >= 4096 * 8);
auto cache_mem = metal::get_cache_memory();
auto cache_mem = get_cache_memory();
CHECK(cache_mem >= 4096 * 4);
}
metal::clear_cache();
CHECK_EQ(metal::get_cache_memory(), 0);
clear_cache();
CHECK_EQ(get_cache_memory(), 0);
}