mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
bindings for memory info (#761)
* bindings for memory info * update api * keep cache low if requested * fix default * nit in ops error
This commit is contained in:
@@ -5,18 +5,88 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_metal(py::module_& m) {
|
||||
py::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
metal.def("is_available", &metal::is_available);
|
||||
metal.def(
|
||||
"cache_enabled",
|
||||
&metal::cache_enabled,
|
||||
"check if metal buffer cache is enabled, default is true");
|
||||
"is_available",
|
||||
&metal::is_available,
|
||||
R"pbdoc(
|
||||
Check if the Metal back-end is available.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_cache_enabled",
|
||||
&metal::set_cache_enabled,
|
||||
"enable or disable metal buffer cache");
|
||||
"get_active_memory",
|
||||
&metal::get_active_memory,
|
||||
R"pbdoc(
|
||||
Get the actively used memory in bytes.
|
||||
|
||||
Note, this will not always match memory use reported by the system because
|
||||
it does not include cached memory buffers.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_peak_memory",
|
||||
&metal::get_peak_memory,
|
||||
R"pbdoc(
|
||||
Get the peak amount of used memory in bytes.
|
||||
|
||||
The maximum memory used is recorded from the beginning of the program
|
||||
execution.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_cache_memory",
|
||||
&metal::get_cache_memory,
|
||||
R"pbdoc(
|
||||
Get the cache size in bytes.
|
||||
|
||||
The cache includes memory not currently used that has not been returned
|
||||
to the system allocator.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_memory_limit",
|
||||
&metal::set_memory_limit,
|
||||
"limit"_a,
|
||||
py::kw_only(),
|
||||
"relaxed"_a = true,
|
||||
R"pbdoc(
|
||||
Set the memory limit.
|
||||
|
||||
Memory allocations will wait on scheduled tasks to complete if the limit
|
||||
is exceeded. If there are no more scheduled tasks an error will be raised
|
||||
if ``relaxed`` is ``False``. Otherwise memory will be allocated
|
||||
(including the potential for swap) if ``relaxed`` is ``True``.
|
||||
|
||||
The memory limit defaults to 1.5 times the maximum recommended working set
|
||||
size reported by the device.
|
||||
|
||||
Args:
|
||||
limit (int): Memory limit in bytes.
|
||||
relaxed (bool, optional): If ``False`` an error is raised if the limit
|
||||
is exceeded. Default: ``True``
|
||||
|
||||
Returns:
|
||||
int: The previous memory limit in bytes.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_cache_limit",
|
||||
&metal::set_cache_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the free cache limit.
|
||||
|
||||
If using more than the given limit, free memory will be reclaimed
|
||||
from the cache on the next allocation. To disable the cache, set
|
||||
the limit to ``0``.
|
||||
|
||||
The cache limit defaults to the memory limit. See
|
||||
:func:`set_memory_limit` for more details.
|
||||
|
||||
Args:
|
||||
limit (int): The cache limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous cache limit in bytes.
|
||||
)pbdoc");
|
||||
}
|
||||
|
45
python/tests/test_metal.py
Normal file
45
python/tests/test_metal.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMetal(mlx_tests.MLXTestCase):
    """Tests for the memory-info bindings exposed under ``mx.metal``."""

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        """Check the cache/memory limit setters and the memory queries.

        The limit setters are expected to return the previously configured
        limit, so setting a value and then restoring the old one must hand
        back the value that was just set.
        """
        # A cache limit of 0 disables the buffer cache, so memory released
        # by a deleted array must not linger in it.
        old_limit = mx.metal.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.metal.get_cache_memory(), 0)
        # Restoring returns the limit we just set (0); a second call then
        # returns the restored limit.
        self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)

        old_limit = mx.metal.set_memory_limit(10)
        # Use assertEqual here: assertTrue(value, 10) would treat ``10`` as
        # the failure message and only check that ``value`` is truthy.
        self.assertEqual(mx.metal.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.metal.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory
        a = mx.zeros((4096,))
        mx.eval(a)
        active_mem = mx.metal.get_active_memory()
        self.assertGreaterEqual(active_mem, 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b

        # Deleting ``b`` must bring active memory back to where it was,
        # while the peak and the cache still reflect the extra allocation.
        new_active_mem = mx.metal.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        peak_mem = mx.metal.get_peak_memory()
        self.assertGreaterEqual(peak_mem, 4096 * 8)
        cache_mem = mx.metal.get_cache_memory()
        self.assertGreaterEqual(cache_mem, 4096 * 4)


if __name__ == "__main__":
    unittest.main()
|
Reference in New Issue
Block a user