mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
move memory APIs into top level mlx.core (#1982)
This commit is contained in:
parent
65a38c452b
commit
4e1994e9d7
@ -70,6 +70,7 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/memory
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/distributed
|
||||
|
16
docs/src/python/memory.rst
Normal file
16
docs/src/python/memory.rst
Normal file
@ -0,0 +1,16 @@
|
||||
Memory Management
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
@ -323,40 +324,40 @@ MetalAllocator& allocator() {
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return allocator().set_cache_limit(limit);
|
||||
return metal::allocator().set_cache_limit(limit);
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return allocator().set_memory_limit(limit);
|
||||
return metal::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return allocator().get_memory_limit();
|
||||
return metal::allocator().get_memory_limit();
|
||||
}
|
||||
size_t set_wired_limit(size_t limit) {
|
||||
if (limit >
|
||||
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
|
||||
if (limit > std::get<size_t>(metal::device_info().at(
|
||||
"max_recommended_working_set_size"))) {
|
||||
throw std::invalid_argument(
|
||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||
"the maximum working set size is not allowed.");
|
||||
}
|
||||
return allocator().set_wired_limit(limit);
|
||||
return metal::allocator().set_wired_limit(limit);
|
||||
}
|
||||
size_t get_active_memory() {
|
||||
return allocator().get_active_memory();
|
||||
return metal::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return allocator().get_peak_memory();
|
||||
return metal::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
allocator().reset_peak_memory();
|
||||
metal::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return allocator().get_cache_memory();
|
||||
return metal::allocator().get_cache_memory();
|
||||
}
|
||||
void clear_cache() {
|
||||
return allocator().clear_cache();
|
||||
return metal::allocator().clear_cache();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -12,74 +12,6 @@ namespace mlx::core::metal {
|
||||
/* Check if the Metal backend is available. */
|
||||
bool is_available();
|
||||
|
||||
/* Get the actively used memory in bytes.
|
||||
*
|
||||
* Note, this will not always match memory use reported by the system because
|
||||
* it does not include cached memory buffers.
|
||||
* */
|
||||
size_t get_active_memory();
|
||||
|
||||
/* Get the peak amount of used memory in bytes.
|
||||
*
|
||||
* The maximum memory used recorded from the beginning of the program
|
||||
* execution or since the last call to reset_peak_memory.
|
||||
* */
|
||||
size_t get_peak_memory();
|
||||
|
||||
/* Reset the peak memory to zero.
|
||||
* */
|
||||
void reset_peak_memory();
|
||||
|
||||
/* Get the cache size in bytes.
|
||||
*
|
||||
* The cache includes memory not currently used that has not been returned
|
||||
* to the system allocator.
|
||||
* */
|
||||
size_t get_cache_memory();
|
||||
|
||||
/* Set the memory limit.
|
||||
* The memory limit is a guideline for the maximum amount of memory to use
|
||||
* during graph evaluation. If the memory limit is exceeded and there is no
|
||||
* more RAM (including swap when available) allocations will result in an
|
||||
* exception.
|
||||
*
|
||||
* When metal is available the memory limit defaults to 1.5 times the maximum
|
||||
* recommended working set size reported by the device.
|
||||
*
|
||||
* Returns the previous memory limit.
|
||||
* */
|
||||
size_t set_memory_limit(size_t limit);
|
||||
|
||||
/* Get the current memory limit. */
|
||||
size_t get_memory_limit();
|
||||
|
||||
/* Set the free cache limit.
|
||||
* If using more than the given limit, free memory will be reclaimed
|
||||
* from the cache on the next allocation. To disable the cache,
|
||||
* set the limit to 0.
|
||||
*
|
||||
* The cache limit defaults to the memory limit.
|
||||
*
|
||||
* Returns the previous cache limit.
|
||||
* */
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
/* Clear the memory cache. */
|
||||
void clear_cache();
|
||||
|
||||
/* Set the wired size limit.
|
||||
*
|
||||
* Note, this function is only useful for macOS 15.0 or higher.
|
||||
*
|
||||
* The wired limit is the total size in bytes of memory that will be kept
|
||||
* resident. The default value is ``0``.
|
||||
*
|
||||
* Setting a wired limit larger than system wired limit is an error.
|
||||
*
|
||||
* Returns the previous wired limit.
|
||||
* */
|
||||
size_t set_wired_limit(size_t limit);
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
void start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
@ -2,7 +2,9 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
namespace mlx::core {
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
static CommonAllocator allocator_;
|
||||
@ -15,5 +17,30 @@ void* Buffer::raw_ptr() {
|
||||
}
|
||||
return static_cast<size_t*>(ptr_) + 1;
|
||||
}
|
||||
} // namespace allocator
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
size_t get_active_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return 0;
|
||||
}
|
||||
void reset_peak_memory() {}
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_memory_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void clear_cache() {}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -31,33 +31,8 @@ void synchronize(Stream) {
|
||||
"[metal::synchronize] Cannot synchronize GPU without metal backend");
|
||||
}
|
||||
|
||||
// No-ops when Metal is not available.
|
||||
size_t get_active_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return 0;
|
||||
}
|
||||
void reset_peak_memory() {}
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_memory_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void start_capture(std::string) {}
|
||||
void stop_capture() {}
|
||||
void clear_cache() {}
|
||||
|
||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||
device_info() {
|
||||
|
78
mlx/memory.h
Normal file
78
mlx/memory.h
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* Get the actively used memory in bytes.
|
||||
*
|
||||
* Note, this will not always match memory use reported by the system because
|
||||
* it does not include cached memory buffers.
|
||||
* */
|
||||
size_t get_active_memory();
|
||||
|
||||
/* Get the peak amount of used memory in bytes.
|
||||
*
|
||||
* The maximum memory used recorded from the beginning of the program
|
||||
* execution or since the last call to reset_peak_memory.
|
||||
* */
|
||||
size_t get_peak_memory();
|
||||
|
||||
/* Reset the peak memory to zero.
|
||||
* */
|
||||
void reset_peak_memory();
|
||||
|
||||
/* Get the cache size in bytes.
|
||||
*
|
||||
* The cache includes memory not currently used that has not been returned
|
||||
* to the system allocator.
|
||||
* */
|
||||
size_t get_cache_memory();
|
||||
|
||||
/* Set the memory limit.
|
||||
* The memory limit is a guideline for the maximum amount of memory to use
|
||||
* during graph evaluation. If the memory limit is exceeded and there is no
|
||||
* more RAM (including swap when available) allocations will result in an
|
||||
* exception.
|
||||
*
|
||||
* When Metal is available the memory limit defaults to 1.5 times the maximum
|
||||
* recommended working set size reported by the device.
|
||||
*
|
||||
* Returns the previous memory limit.
|
||||
* */
|
||||
size_t set_memory_limit(size_t limit);
|
||||
|
||||
/* Get the current memory limit. */
|
||||
size_t get_memory_limit();
|
||||
|
||||
/* Set the cache limit.
|
||||
* If using more than the given limit, free memory will be reclaimed
|
||||
* from the cache on the next allocation. To disable the cache,
|
||||
* set the limit to 0.
|
||||
*
|
||||
* The cache limit defaults to the memory limit.
|
||||
*
|
||||
* Returns the previous cache limit.
|
||||
* */
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
/* Clear the memory cache. */
|
||||
void clear_cache();
|
||||
|
||||
/* Set the wired size limit.
|
||||
*
|
||||
* Note, this function is only useful when using the Metal backend with
|
||||
* macOS 15.0 or higher.
|
||||
*
|
||||
* The wired limit is the total size in bytes of memory that will be kept
|
||||
* resident. The default value is ``0``.
|
||||
*
|
||||
* Setting a wired limit larger than system wired limit is an error.
|
||||
*
|
||||
* Returns the previous wired limit.
|
||||
* */
|
||||
size_t set_wired_limit(size_t limit);
|
||||
|
||||
} // namespace mlx::core
|
@ -14,6 +14,7 @@
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/io.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
#include "mlx/stream.h"
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "mlx/backend/cpu/eval.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@ -219,7 +220,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
|
||||
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
|
||||
(metal::get_active_memory() > metal::get_memory_limit() &&
|
||||
(get_active_memory() > get_memory_limit() &&
|
||||
scheduler::n_active_tasks() > 0)) {
|
||||
// Commit any open streams
|
||||
for (auto& [_, e] : events) {
|
||||
@ -228,8 +229,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
}
|
||||
scheduler::wait_for_one();
|
||||
// TODO memory api should be moved out of metal
|
||||
while (metal::get_active_memory() > metal::get_memory_limit() &&
|
||||
while (get_active_memory() > get_memory_limit() &&
|
||||
scheduler::n_active_tasks() > 0) {
|
||||
scheduler::wait_for_one();
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
||||
|
125
python/src/memory.cpp
Normal file
125
python/src/memory.cpp
Normal file
@ -0,0 +1,125 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/memory.h"
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
void init_memory(nb::module_& m) {
|
||||
m.def(
|
||||
"get_active_memory",
|
||||
&mx::get_active_memory,
|
||||
R"pbdoc(
|
||||
Get the actively used memory in bytes.
|
||||
|
||||
Note, this will not always match memory use reported by the system because
|
||||
it does not include cached memory buffers.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"get_peak_memory",
|
||||
&mx::get_peak_memory,
|
||||
R"pbdoc(
|
||||
Get the peak amount of used memory in bytes.
|
||||
|
||||
The maximum memory used recorded from the beginning of the program
|
||||
execution or since the last call to :func:`reset_peak_memory`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"reset_peak_memory",
|
||||
&mx::reset_peak_memory,
|
||||
R"pbdoc(
|
||||
Reset the peak memory to zero.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"get_cache_memory",
|
||||
&mx::get_cache_memory,
|
||||
R"pbdoc(
|
||||
Get the cache size in bytes.
|
||||
|
||||
The cache includes memory not currently used that has not been returned
|
||||
to the system allocator.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"set_memory_limit",
|
||||
&mx::set_memory_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the memory limit.
|
||||
|
||||
The memory limit is a guideline for the maximum amount of memory to use
|
||||
during graph evaluation. If the memory limit is exceeded and there is no
|
||||
more RAM (including swap when available) allocations will result in an
|
||||
exception.
|
||||
|
||||
When metal is available the memory limit defaults to 1.5 times the
|
||||
maximum recommended working set size reported by the device.
|
||||
|
||||
Args:
|
||||
limit (int): Memory limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous memory limit in bytes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"set_cache_limit",
|
||||
&mx::set_cache_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the free cache limit.
|
||||
|
||||
If using more than the given limit, free memory will be reclaimed
|
||||
from the cache on the next allocation. To disable the cache, set
|
||||
the limit to ``0``.
|
||||
|
||||
The cache limit defaults to the memory limit. See
|
||||
:func:`set_memory_limit` for more details.
|
||||
|
||||
Args:
|
||||
limit (int): The cache limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous cache limit in bytes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"set_wired_limit",
|
||||
&mx::set_wired_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the wired size limit.
|
||||
|
||||
.. note::
|
||||
* This function is only useful on macOS 15.0 or higher.
|
||||
* The wired limit should remain strictly less than the total
|
||||
memory size.
|
||||
|
||||
The wired limit is the total size in bytes of memory that will be kept
|
||||
resident. The default value is ``0``.
|
||||
|
||||
Setting a wired limit larger than system wired limit is an error. You can
|
||||
increase the system wired limit with:
|
||||
|
||||
.. code-block::
|
||||
|
||||
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
|
||||
|
||||
Use :func:`device_info` to query the system wired limit
|
||||
(``"max_recommended_working_set_size"``) and the total memory size
|
||||
(``"memory_size"``).
|
||||
|
||||
Args:
|
||||
limit (int): The wired limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous wired limit in bytes.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"clear_cache",
|
||||
&mx::clear_cache,
|
||||
R"pbdoc(
|
||||
Clear the memory cache.
|
||||
|
||||
After calling this, :func:`get_cache_memory` should return ``0``.
|
||||
)pbdoc");
|
||||
}
|
@ -1,17 +1,27 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/memory.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
bool DEPRECATE(const std::string& old_fn, const std::string new_fn) {
|
||||
std::cerr << old_fn << " is deprecated and will be removed in a future "
|
||||
<< "version. Use " << new_fn << " instead." << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
|
||||
|
||||
void init_metal(nb::module_& m) {
|
||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
metal.def(
|
||||
@ -20,121 +30,47 @@ void init_metal(nb::module_& m) {
|
||||
R"pbdoc(
|
||||
Check if the Metal back-end is available.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_active_memory",
|
||||
&mx::metal::get_active_memory,
|
||||
R"pbdoc(
|
||||
Get the actively used memory in bytes.
|
||||
|
||||
Note, this will not always match memory use reported by the system because
|
||||
it does not include cached memory buffers.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_peak_memory",
|
||||
&mx::metal::get_peak_memory,
|
||||
R"pbdoc(
|
||||
Get the peak amount of used memory in bytes.
|
||||
|
||||
The maximum memory used recorded from the beginning of the program
|
||||
execution or since the last call to :func:`reset_peak_memory`.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"reset_peak_memory",
|
||||
&mx::metal::reset_peak_memory,
|
||||
R"pbdoc(
|
||||
Reset the peak memory to zero.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_cache_memory",
|
||||
&mx::metal::get_cache_memory,
|
||||
R"pbdoc(
|
||||
Get the cache size in bytes.
|
||||
|
||||
The cache includes memory not currently used that has not been returned
|
||||
to the system allocator.
|
||||
)pbdoc");
|
||||
metal.def("get_active_memory", []() {
|
||||
DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory");
|
||||
return mx::get_active_memory();
|
||||
});
|
||||
metal.def("get_peak_memory", []() {
|
||||
DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
|
||||
return mx::get_active_memory();
|
||||
});
|
||||
metal.def("reset_peak_memory", []() {
|
||||
DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
|
||||
mx::reset_peak_memory();
|
||||
});
|
||||
metal.def("get_cache_memory", []() {
|
||||
DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory");
|
||||
return mx::get_cache_memory();
|
||||
});
|
||||
metal.def(
|
||||
"set_memory_limit",
|
||||
&mx::metal::set_memory_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the memory limit.
|
||||
|
||||
The memory limit is a guideline for the maximum amount of memory to use
|
||||
during graph evaluation. If the memory limit is exceeded and there is no
|
||||
more RAM (including swap when available) allocations will result in an
|
||||
exception.
|
||||
|
||||
When metal is available the memory limit defaults to 1.5 times the
|
||||
maximum recommended working set size reported by the device.
|
||||
|
||||
Args:
|
||||
limit (int): Memory limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous memory limit in bytes.
|
||||
)pbdoc");
|
||||
[](size_t limit) {
|
||||
DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit");
|
||||
return mx::set_memory_limit(limit);
|
||||
},
|
||||
"limit"_a);
|
||||
metal.def(
|
||||
"set_cache_limit",
|
||||
&mx::metal::set_cache_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the free cache limit.
|
||||
|
||||
If using more than the given limit, free memory will be reclaimed
|
||||
from the cache on the next allocation. To disable the cache, set
|
||||
the limit to ``0``.
|
||||
|
||||
The cache limit defaults to the memory limit. See
|
||||
:func:`set_memory_limit` for more details.
|
||||
|
||||
Args:
|
||||
limit (int): The cache limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous cache limit in bytes.
|
||||
)pbdoc");
|
||||
[](size_t limit) {
|
||||
DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit");
|
||||
return mx::set_cache_limit(limit);
|
||||
},
|
||||
"limit"_a);
|
||||
metal.def(
|
||||
"set_wired_limit",
|
||||
&mx::metal::set_wired_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the wired size limit.
|
||||
|
||||
.. note::
|
||||
* This function is only useful on macOS 15.0 or higher.
|
||||
* The wired limit should remain strictly less than the total
|
||||
memory size.
|
||||
|
||||
The wired limit is the total size in bytes of memory that will be kept
|
||||
resident. The default value is ``0``.
|
||||
|
||||
Setting a wired limit larger than system wired limit is an error. You can
|
||||
increase the system wired limit with:
|
||||
|
||||
.. code-block::
|
||||
|
||||
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
|
||||
|
||||
Use :func:`device_info` to query the system wired limit
|
||||
(``"max_recommended_working_set_size"``) and the total memory size
|
||||
(``"memory_size"``).
|
||||
|
||||
Args:
|
||||
limit (int): The wired limit in bytes.
|
||||
|
||||
Returns:
|
||||
int: The previous wired limit in bytes.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"clear_cache",
|
||||
&mx::metal::clear_cache,
|
||||
R"pbdoc(
|
||||
Clear the memory cache.
|
||||
|
||||
After calling this, :func:`get_cache_memory` should return ``0``.
|
||||
)pbdoc");
|
||||
|
||||
[](size_t limit) {
|
||||
DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit");
|
||||
return mx::set_wired_limit(limit);
|
||||
},
|
||||
"limit"_a);
|
||||
metal.def("clear_cache", []() {
|
||||
DEPRECATE("mx.metal.clear_cache", "mx.clear_cache");
|
||||
mx::clear_cache();
|
||||
});
|
||||
metal.def(
|
||||
"start_capture",
|
||||
&mx::metal::start_capture,
|
||||
|
@ -12,6 +12,7 @@ void init_array(nb::module_&);
|
||||
void init_device(nb::module_&);
|
||||
void init_stream(nb::module_&);
|
||||
void init_metal(nb::module_&);
|
||||
void init_memory(nb::module_&);
|
||||
void init_ops(nb::module_&);
|
||||
void init_transforms(nb::module_&);
|
||||
void init_random(nb::module_&);
|
||||
@ -34,6 +35,7 @@ NB_MODULE(core, m) {
|
||||
init_stream(m);
|
||||
init_array(m);
|
||||
init_metal(m);
|
||||
init_memory(m);
|
||||
init_ops(m);
|
||||
init_transforms(m);
|
||||
init_random(m);
|
||||
|
@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
mx.eval(x)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
mx.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.distributed.all_sum(x)
|
||||
mx.eval(y)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
all_sum_only = mx.metal.get_peak_memory()
|
||||
all_sum_only = mx.get_peak_memory()
|
||||
y = mx.distributed.all_sum(x) * scale
|
||||
mx.eval(y)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
all_sum_with_binary = mx.metal.get_peak_memory()
|
||||
all_sum_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||
|
||||
|
@ -1813,10 +1813,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
peak_1 = mx.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
peak_2 = mx.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def fun():
|
||||
@ -1826,10 +1826,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
peak_1 = mx.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
peak_2 = mx.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def test_add_numpy(self):
|
||||
|
@ -747,7 +747,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
]:
|
||||
if mx.metal.is_available():
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@ -765,7 +765,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
def test_leaks(self):
|
||||
gc.collect()
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
if mx.metal.is_available():
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
peak_mem = mx.get_peak_memory()
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||
self.assertEqual(peak_mem, mx.get_peak_memory())
|
||||
|
||||
def test_async_eval_with_multiple_streams(self):
|
||||
x = mx.array([1.0])
|
||||
@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.zeros((4096, 4096))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
pre = mx.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
def fun(x):
|
||||
@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.zeros((4096 * 4096,))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
pre = mx.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
old_limit = mx.metal.set_memory_limit(1000)
|
||||
old_limit = mx.set_memory_limit(1000)
|
||||
|
||||
x = mx.ones((512, 512), stream=s2)
|
||||
for _ in range(80):
|
||||
@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
y = mx.abs(x, stream=s2)
|
||||
z = mx.abs(y, stream=s2)
|
||||
mx.eval(z)
|
||||
mx.metal.set_memory_limit(old_limit)
|
||||
mx.set_memory_limit(old_limit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
def test_leaks(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@ -387,14 +387,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.save(save_file, x)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
mx.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.load(save_file)
|
||||
mx.eval(y)
|
||||
load_only = mx.metal.get_peak_memory()
|
||||
load_only = mx.get_peak_memory()
|
||||
y = mx.load(save_file) * scale
|
||||
mx.eval(y)
|
||||
load_with_binary = mx.metal.get_peak_memory()
|
||||
load_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(load_only, load_with_binary)
|
||||
|
||||
|
60
python/tests/test_memory.py
Normal file
60
python/tests/test_memory.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMemory(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_memory_info(self):
|
||||
old_limit = mx.set_cache_limit(0)
|
||||
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
del a
|
||||
self.assertEqual(mx.get_cache_memory(), 0)
|
||||
self.assertEqual(mx.set_cache_limit(old_limit), 0)
|
||||
self.assertEqual(mx.set_cache_limit(old_limit), old_limit)
|
||||
|
||||
old_limit = mx.set_memory_limit(10)
|
||||
self.assertTrue(mx.set_memory_limit(old_limit), 10)
|
||||
self.assertTrue(mx.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
mx.synchronize()
|
||||
active_mem = mx.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
mx.eval(b)
|
||||
del b
|
||||
mx.synchronize()
|
||||
|
||||
new_active_mem = mx.get_active_memory()
|
||||
self.assertEqual(new_active_mem, active_mem)
|
||||
peak_mem = mx.get_peak_memory()
|
||||
self.assertTrue(peak_mem >= 4096 * 8)
|
||||
cache_mem = mx.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
mx.clear_cache()
|
||||
self.assertEqual(mx.get_cache_memory(), 0)
|
||||
|
||||
mx.reset_peak_memory()
|
||||
self.assertEqual(mx.get_peak_memory(), 0)
|
||||
|
||||
old_limit = mx.set_wired_limit(1000)
|
||||
old_limit = mx.set_wired_limit(0)
|
||||
self.assertEqual(old_limit, 1000)
|
||||
|
||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.set_wired_limit(max_size + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1,60 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestMetal(mlx_tests.MLXTestCase):
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_memory_info(self):
|
||||
old_limit = mx.metal.set_cache_limit(0)
|
||||
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
del a
|
||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
|
||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
|
||||
|
||||
old_limit = mx.metal.set_memory_limit(10)
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
mx.eval(a)
|
||||
mx.synchronize()
|
||||
active_mem = mx.metal.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
mx.eval(b)
|
||||
del b
|
||||
mx.synchronize()
|
||||
|
||||
new_active_mem = mx.metal.get_active_memory()
|
||||
self.assertEqual(new_active_mem, active_mem)
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
self.assertTrue(peak_mem >= 4096 * 8)
|
||||
cache_mem = mx.metal.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
mx.metal.clear_cache()
|
||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
self.assertEqual(mx.metal.get_peak_memory(), 0)
|
||||
|
||||
old_limit = mx.metal.set_wired_limit(1000)
|
||||
old_limit = mx.metal.set_wired_limit(0)
|
||||
self.assertEqual(old_limit, 1000)
|
||||
|
||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
with self.assertRaises(ValueError):
|
||||
mx.metal.set_wired_limit(max_size + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1904,10 +1904,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mx.eval(fn(2))
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem2 = mx.metal.get_peak_memory()
|
||||
mem2 = mx.get_peak_memory()
|
||||
mx.eval(fn(4))
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
mem4 = mx.metal.get_peak_memory()
|
||||
mem4 = mx.get_peak_memory()
|
||||
self.assertEqual(mem2, mem4)
|
||||
|
||||
def test_squeeze_expand(self):
|
||||
|
@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
mem_post = mx.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
|
@ -473,24 +473,24 @@ TEST_CASE("test metal validation") {
|
||||
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
|
||||
}
|
||||
|
||||
TEST_CASE("test metal memory info") {
|
||||
TEST_CASE("test memory info") {
|
||||
// Test cache limits
|
||||
{
|
||||
auto old_limit = metal::set_cache_limit(0);
|
||||
auto old_limit = set_cache_limit(0);
|
||||
{
|
||||
auto a = zeros({4096});
|
||||
eval(a);
|
||||
}
|
||||
CHECK_EQ(metal::get_cache_memory(), 0);
|
||||
CHECK_EQ(metal::set_cache_limit(old_limit), 0);
|
||||
CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
|
||||
CHECK_EQ(get_cache_memory(), 0);
|
||||
CHECK_EQ(set_cache_limit(old_limit), 0);
|
||||
CHECK_EQ(set_cache_limit(old_limit), old_limit);
|
||||
}
|
||||
|
||||
// Test memory limits
|
||||
{
|
||||
auto old_limit = metal::set_memory_limit(10);
|
||||
CHECK_EQ(metal::set_memory_limit(old_limit), 10);
|
||||
CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
|
||||
auto old_limit = set_memory_limit(10);
|
||||
CHECK_EQ(set_memory_limit(old_limit), 10);
|
||||
CHECK_EQ(set_memory_limit(old_limit), old_limit);
|
||||
}
|
||||
|
||||
// Query active and peak memory
|
||||
@ -498,22 +498,22 @@ TEST_CASE("test metal memory info") {
|
||||
auto a = zeros({4096});
|
||||
eval(a);
|
||||
synchronize();
|
||||
auto active_mem = metal::get_active_memory();
|
||||
auto active_mem = get_active_memory();
|
||||
CHECK(active_mem >= 4096 * 4);
|
||||
{
|
||||
auto b = zeros({4096});
|
||||
eval(b);
|
||||
}
|
||||
synchronize();
|
||||
auto new_active_mem = metal::get_active_memory();
|
||||
auto new_active_mem = get_active_memory();
|
||||
CHECK_EQ(new_active_mem, active_mem);
|
||||
auto peak_mem = metal::get_peak_memory();
|
||||
auto peak_mem = get_peak_memory();
|
||||
CHECK(peak_mem >= 4096 * 8);
|
||||
|
||||
auto cache_mem = metal::get_cache_memory();
|
||||
auto cache_mem = get_cache_memory();
|
||||
CHECK(cache_mem >= 4096 * 4);
|
||||
}
|
||||
|
||||
metal::clear_cache();
|
||||
CHECK_EQ(metal::get_cache_memory(), 0);
|
||||
clear_cache();
|
||||
CHECK_EQ(get_cache_memory(), 0);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user