mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
move memory APIs into top level mlx.core (#1982)
This commit is contained in:
parent
65a38c452b
commit
4e1994e9d7
@ -70,6 +70,7 @@ are the CPU and GPU.
|
|||||||
python/fft
|
python/fft
|
||||||
python/linalg
|
python/linalg
|
||||||
python/metal
|
python/metal
|
||||||
|
python/memory
|
||||||
python/nn
|
python/nn
|
||||||
python/optimizers
|
python/optimizers
|
||||||
python/distributed
|
python/distributed
|
||||||
|
16
docs/src/python/memory.rst
Normal file
16
docs/src/python/memory.rst
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
Memory Management
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
get_active_memory
|
||||||
|
get_peak_memory
|
||||||
|
reset_peak_memory
|
||||||
|
get_cache_memory
|
||||||
|
set_memory_limit
|
||||||
|
set_cache_limit
|
||||||
|
set_wired_limit
|
||||||
|
clear_cache
|
@ -3,6 +3,7 @@
|
|||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
#include "mlx/backend/metal/resident.h"
|
#include "mlx/backend/metal/resident.h"
|
||||||
|
#include "mlx/memory.h"
|
||||||
|
|
||||||
#include <mach/vm_page_size.h>
|
#include <mach/vm_page_size.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@ -323,40 +324,40 @@ MetalAllocator& allocator() {
|
|||||||
return *allocator_;
|
return *allocator_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
|
||||||
size_t set_cache_limit(size_t limit) {
|
size_t set_cache_limit(size_t limit) {
|
||||||
return allocator().set_cache_limit(limit);
|
return metal::allocator().set_cache_limit(limit);
|
||||||
}
|
}
|
||||||
size_t set_memory_limit(size_t limit) {
|
size_t set_memory_limit(size_t limit) {
|
||||||
return allocator().set_memory_limit(limit);
|
return metal::allocator().set_memory_limit(limit);
|
||||||
}
|
}
|
||||||
size_t get_memory_limit() {
|
size_t get_memory_limit() {
|
||||||
return allocator().get_memory_limit();
|
return metal::allocator().get_memory_limit();
|
||||||
}
|
}
|
||||||
size_t set_wired_limit(size_t limit) {
|
size_t set_wired_limit(size_t limit) {
|
||||||
if (limit >
|
if (limit > std::get<size_t>(metal::device_info().at(
|
||||||
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
|
"max_recommended_working_set_size"))) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||||
"the maximum working set size is not allowed.");
|
"the maximum working set size is not allowed.");
|
||||||
}
|
}
|
||||||
return allocator().set_wired_limit(limit);
|
return metal::allocator().set_wired_limit(limit);
|
||||||
}
|
}
|
||||||
size_t get_active_memory() {
|
size_t get_active_memory() {
|
||||||
return allocator().get_active_memory();
|
return metal::allocator().get_active_memory();
|
||||||
}
|
}
|
||||||
size_t get_peak_memory() {
|
size_t get_peak_memory() {
|
||||||
return allocator().get_peak_memory();
|
return metal::allocator().get_peak_memory();
|
||||||
}
|
}
|
||||||
void reset_peak_memory() {
|
void reset_peak_memory() {
|
||||||
allocator().reset_peak_memory();
|
metal::allocator().reset_peak_memory();
|
||||||
}
|
}
|
||||||
size_t get_cache_memory() {
|
size_t get_cache_memory() {
|
||||||
return allocator().get_cache_memory();
|
return metal::allocator().get_cache_memory();
|
||||||
}
|
}
|
||||||
void clear_cache() {
|
void clear_cache() {
|
||||||
return allocator().clear_cache();
|
return metal::allocator().clear_cache();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace metal
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -12,74 +12,6 @@ namespace mlx::core::metal {
|
|||||||
/* Check if the Metal backend is available. */
|
/* Check if the Metal backend is available. */
|
||||||
bool is_available();
|
bool is_available();
|
||||||
|
|
||||||
/* Get the actively used memory in bytes.
|
|
||||||
*
|
|
||||||
* Note, this will not always match memory use reported by the system because
|
|
||||||
* it does not include cached memory buffers.
|
|
||||||
* */
|
|
||||||
size_t get_active_memory();
|
|
||||||
|
|
||||||
/* Get the peak amount of used memory in bytes.
|
|
||||||
*
|
|
||||||
* The maximum memory used recorded from the beginning of the program
|
|
||||||
* execution or since the last call to reset_peak_memory.
|
|
||||||
* */
|
|
||||||
size_t get_peak_memory();
|
|
||||||
|
|
||||||
/* Reset the peak memory to zero.
|
|
||||||
* */
|
|
||||||
void reset_peak_memory();
|
|
||||||
|
|
||||||
/* Get the cache size in bytes.
|
|
||||||
*
|
|
||||||
* The cache includes memory not currently used that has not been returned
|
|
||||||
* to the system allocator.
|
|
||||||
* */
|
|
||||||
size_t get_cache_memory();
|
|
||||||
|
|
||||||
/* Set the memory limit.
|
|
||||||
* The memory limit is a guideline for the maximum amount of memory to use
|
|
||||||
* during graph evaluation. If the memory limit is exceeded and there is no
|
|
||||||
* more RAM (including swap when available) allocations will result in an
|
|
||||||
* exception.
|
|
||||||
*
|
|
||||||
* When metal is available the memory limit defaults to 1.5 times the maximum
|
|
||||||
* recommended working set size reported by the device.
|
|
||||||
*
|
|
||||||
* Returns the previous memory limit.
|
|
||||||
* */
|
|
||||||
size_t set_memory_limit(size_t limit);
|
|
||||||
|
|
||||||
/* Get the current memory limit. */
|
|
||||||
size_t get_memory_limit();
|
|
||||||
|
|
||||||
/* Set the free cache limit.
|
|
||||||
* If using more than the given limit, free memory will be reclaimed
|
|
||||||
* from the cache on the next allocation. To disable the cache,
|
|
||||||
* set the limit to 0.
|
|
||||||
*
|
|
||||||
* The cache limit defaults to the memory limit.
|
|
||||||
*
|
|
||||||
* Returns the previous cache limit.
|
|
||||||
* */
|
|
||||||
size_t set_cache_limit(size_t limit);
|
|
||||||
|
|
||||||
/* Clear the memory cache. */
|
|
||||||
void clear_cache();
|
|
||||||
|
|
||||||
/* Set the wired size limit.
|
|
||||||
*
|
|
||||||
* Note, this function is only useful for macOS 15.0 or higher.
|
|
||||||
*
|
|
||||||
* The wired limit is the total size in bytes of memory that will be kept
|
|
||||||
* resident. The default value is ``0``.
|
|
||||||
*
|
|
||||||
* Setting a wired limit larger than system wired limit is an error.
|
|
||||||
*
|
|
||||||
* Returns the previous wired limit.
|
|
||||||
* */
|
|
||||||
size_t set_wired_limit(size_t limit);
|
|
||||||
|
|
||||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||||
void start_capture(std::string path = "");
|
void start_capture(std::string path = "");
|
||||||
void stop_capture();
|
void stop_capture();
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
|
||||||
namespace mlx::core::allocator {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace allocator {
|
||||||
|
|
||||||
Allocator& allocator() {
|
Allocator& allocator() {
|
||||||
static CommonAllocator allocator_;
|
static CommonAllocator allocator_;
|
||||||
@ -15,5 +17,30 @@ void* Buffer::raw_ptr() {
|
|||||||
}
|
}
|
||||||
return static_cast<size_t*>(ptr_) + 1;
|
return static_cast<size_t*>(ptr_) + 1;
|
||||||
}
|
}
|
||||||
|
} // namespace allocator
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
size_t get_active_memory() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t get_peak_memory() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
void reset_peak_memory() {}
|
||||||
|
size_t get_cache_memory() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t set_memory_limit(size_t) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t get_memory_limit() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t set_cache_limit(size_t) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t set_wired_limit(size_t) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
void clear_cache() {}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
|
@ -31,33 +31,8 @@ void synchronize(Stream) {
|
|||||||
"[metal::synchronize] Cannot synchronize GPU without metal backend");
|
"[metal::synchronize] Cannot synchronize GPU without metal backend");
|
||||||
}
|
}
|
||||||
|
|
||||||
// No-ops when Metal is not available.
|
|
||||||
size_t get_active_memory() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
size_t get_peak_memory() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
void reset_peak_memory() {}
|
|
||||||
size_t get_cache_memory() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
size_t set_memory_limit(size_t) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
size_t get_memory_limit() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
size_t set_cache_limit(size_t) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
size_t set_wired_limit(size_t) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void start_capture(std::string) {}
|
void start_capture(std::string) {}
|
||||||
void stop_capture() {}
|
void stop_capture() {}
|
||||||
void clear_cache() {}
|
|
||||||
|
|
||||||
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||||
device_info() {
|
device_info() {
|
||||||
|
78
mlx/memory.h
Normal file
78
mlx/memory.h
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/* Get the actively used memory in bytes.
|
||||||
|
*
|
||||||
|
* Note, this will not always match memory use reported by the system because
|
||||||
|
* it does not include cached memory buffers.
|
||||||
|
* */
|
||||||
|
size_t get_active_memory();
|
||||||
|
|
||||||
|
/* Get the peak amount of used memory in bytes.
|
||||||
|
*
|
||||||
|
* The maximum memory used recorded from the beginning of the program
|
||||||
|
* execution or since the last call to reset_peak_memory.
|
||||||
|
* */
|
||||||
|
size_t get_peak_memory();
|
||||||
|
|
||||||
|
/* Reset the peak memory to zero.
|
||||||
|
* */
|
||||||
|
void reset_peak_memory();
|
||||||
|
|
||||||
|
/* Get the cache size in bytes.
|
||||||
|
*
|
||||||
|
* The cache includes memory not currently used that has not been returned
|
||||||
|
* to the system allocator.
|
||||||
|
* */
|
||||||
|
size_t get_cache_memory();
|
||||||
|
|
||||||
|
/* Set the memory limit.
|
||||||
|
* The memory limit is a guideline for the maximum amount of memory to use
|
||||||
|
* during graph evaluation. If the memory limit is exceeded and there is no
|
||||||
|
* more RAM (including swap when available) allocations will result in an
|
||||||
|
* exception.
|
||||||
|
*
|
||||||
|
* When Metal is available the memory limit defaults to 1.5 times the maximum
|
||||||
|
* recommended working set size reported by the device.
|
||||||
|
*
|
||||||
|
* Returns the previous memory limit.
|
||||||
|
* */
|
||||||
|
size_t set_memory_limit(size_t limit);
|
||||||
|
|
||||||
|
/* Get the current memory limit. */
|
||||||
|
size_t get_memory_limit();
|
||||||
|
|
||||||
|
/* Set the cache limit.
|
||||||
|
* If using more than the given limit, free memory will be reclaimed
|
||||||
|
* from the cache on the next allocation. To disable the cache,
|
||||||
|
* set the limit to 0.
|
||||||
|
*
|
||||||
|
* The cache limit defaults to the memory limit.
|
||||||
|
*
|
||||||
|
* Returns the previous cache limit.
|
||||||
|
* */
|
||||||
|
size_t set_cache_limit(size_t limit);
|
||||||
|
|
||||||
|
/* Clear the memory cache. */
|
||||||
|
void clear_cache();
|
||||||
|
|
||||||
|
/* Set the wired size limit.
|
||||||
|
*
|
||||||
|
* Note, this function is only useful when using the Metal backend with
|
||||||
|
* macOS 15.0 or higher.
|
||||||
|
*
|
||||||
|
* The wired limit is the total size in bytes of memory that will be kept
|
||||||
|
* resident. The default value is ``0``.
|
||||||
|
*
|
||||||
|
* Setting a wired limit larger than system wired limit is an error.
|
||||||
|
*
|
||||||
|
* Returns the previous wired limit.
|
||||||
|
* */
|
||||||
|
size_t set_wired_limit(size_t limit);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -14,6 +14,7 @@
|
|||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/io.h"
|
#include "mlx/io.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
|
#include "mlx/memory.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include "mlx/backend/cpu/eval.h"
|
#include "mlx/backend/cpu/eval.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
|
#include "mlx/memory.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
@ -219,7 +220,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
|
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
|
||||||
(metal::get_active_memory() > metal::get_memory_limit() &&
|
(get_active_memory() > get_memory_limit() &&
|
||||||
scheduler::n_active_tasks() > 0)) {
|
scheduler::n_active_tasks() > 0)) {
|
||||||
// Commit any open streams
|
// Commit any open streams
|
||||||
for (auto& [_, e] : events) {
|
for (auto& [_, e] : events) {
|
||||||
@ -228,8 +229,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
scheduler::wait_for_one();
|
scheduler::wait_for_one();
|
||||||
// TODO memory api should be moved out of metal
|
while (get_active_memory() > get_memory_limit() &&
|
||||||
while (metal::get_active_memory() > metal::get_memory_limit() &&
|
|
||||||
scheduler::n_active_tasks() > 0) {
|
scheduler::n_active_tasks() > 0) {
|
||||||
scheduler::wait_for_one();
|
scheduler::wait_for_one();
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ nanobind_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
||||||
|
125
python/src/memory.cpp
Normal file
125
python/src/memory.cpp
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/memory.h"
|
||||||
|
#include <nanobind/nanobind.h>
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
|
namespace nb = nanobind;
|
||||||
|
using namespace nb::literals;
|
||||||
|
|
||||||
|
void init_memory(nb::module_& m) {
|
||||||
|
m.def(
|
||||||
|
"get_active_memory",
|
||||||
|
&mx::get_active_memory,
|
||||||
|
R"pbdoc(
|
||||||
|
Get the actively used memory in bytes.
|
||||||
|
|
||||||
|
Note, this will not always match memory use reported by the system because
|
||||||
|
it does not include cached memory buffers.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"get_peak_memory",
|
||||||
|
&mx::get_peak_memory,
|
||||||
|
R"pbdoc(
|
||||||
|
Get the peak amount of used memory in bytes.
|
||||||
|
|
||||||
|
The maximum memory used recorded from the beginning of the program
|
||||||
|
execution or since the last call to :func:`reset_peak_memory`.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"reset_peak_memory",
|
||||||
|
&mx::reset_peak_memory,
|
||||||
|
R"pbdoc(
|
||||||
|
Reset the peak memory to zero.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"get_cache_memory",
|
||||||
|
&mx::get_cache_memory,
|
||||||
|
R"pbdoc(
|
||||||
|
Get the cache size in bytes.
|
||||||
|
|
||||||
|
The cache includes memory not currently used that has not been returned
|
||||||
|
to the system allocator.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"set_memory_limit",
|
||||||
|
&mx::set_memory_limit,
|
||||||
|
"limit"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Set the memory limit.
|
||||||
|
|
||||||
|
The memory limit is a guideline for the maximum amount of memory to use
|
||||||
|
during graph evaluation. If the memory limit is exceeded and there is no
|
||||||
|
more RAM (including swap when available) allocations will result in an
|
||||||
|
exception.
|
||||||
|
|
||||||
|
When metal is available the memory limit defaults to 1.5 times the
|
||||||
|
maximum recommended working set size reported by the device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit (int): Memory limit in bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The previous memory limit in bytes.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"set_cache_limit",
|
||||||
|
&mx::set_cache_limit,
|
||||||
|
"limit"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Set the free cache limit.
|
||||||
|
|
||||||
|
If using more than the given limit, free memory will be reclaimed
|
||||||
|
from the cache on the next allocation. To disable the cache, set
|
||||||
|
the limit to ``0``.
|
||||||
|
|
||||||
|
The cache limit defaults to the memory limit. See
|
||||||
|
:func:`set_memory_limit` for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit (int): The cache limit in bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The previous cache limit in bytes.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"set_wired_limit",
|
||||||
|
&mx::set_wired_limit,
|
||||||
|
"limit"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Set the wired size limit.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
* This function is only useful on macOS 15.0 or higher.
|
||||||
|
* The wired limit should remain strictly less than the total
|
||||||
|
memory size.
|
||||||
|
|
||||||
|
The wired limit is the total size in bytes of memory that will be kept
|
||||||
|
resident. The default value is ``0``.
|
||||||
|
|
||||||
|
Setting a wired limit larger than system wired limit is an error. You can
|
||||||
|
increase the system wired limit with:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
|
||||||
|
|
||||||
|
Use :func:`device_info` to query the system wired limit
|
||||||
|
(``"max_recommended_working_set_size"``) and the total memory size
|
||||||
|
(``"memory_size"``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit (int): The wired limit in bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The previous wired limit in bytes.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"clear_cache",
|
||||||
|
&mx::clear_cache,
|
||||||
|
R"pbdoc(
|
||||||
|
Clear the memory cache.
|
||||||
|
|
||||||
|
After calling this, :func:`get_cache_memory` should return ``0``.
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -1,17 +1,27 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
#include <nanobind/stl/unordered_map.h>
|
#include <nanobind/stl/unordered_map.h>
|
||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/memory.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
|
bool DEPRECATE(const std::string& old_fn, const std::string new_fn) {
|
||||||
|
std::cerr << old_fn << " is deprecated and will be removed in a future "
|
||||||
|
<< "version. Use " << new_fn << " instead." << std::endl;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
|
||||||
|
|
||||||
void init_metal(nb::module_& m) {
|
void init_metal(nb::module_& m) {
|
||||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||||
metal.def(
|
metal.def(
|
||||||
@ -20,121 +30,47 @@ void init_metal(nb::module_& m) {
|
|||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Check if the Metal back-end is available.
|
Check if the Metal back-end is available.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def("get_active_memory", []() {
|
||||||
"get_active_memory",
|
DEPRECATE("mx.metal.get_active_memory", "mx.get_active_memory");
|
||||||
&mx::metal::get_active_memory,
|
return mx::get_active_memory();
|
||||||
R"pbdoc(
|
});
|
||||||
Get the actively used memory in bytes.
|
metal.def("get_peak_memory", []() {
|
||||||
|
DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
|
||||||
Note, this will not always match memory use reported by the system because
|
return mx::get_active_memory();
|
||||||
it does not include cached memory buffers.
|
});
|
||||||
)pbdoc");
|
metal.def("reset_peak_memory", []() {
|
||||||
metal.def(
|
DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
|
||||||
"get_peak_memory",
|
mx::reset_peak_memory();
|
||||||
&mx::metal::get_peak_memory,
|
});
|
||||||
R"pbdoc(
|
metal.def("get_cache_memory", []() {
|
||||||
Get the peak amount of used memory in bytes.
|
DEPRECATE("mx.metal.get_cache_memory", "mx.get_cache_memory");
|
||||||
|
return mx::get_cache_memory();
|
||||||
The maximum memory used recorded from the beginning of the program
|
});
|
||||||
execution or since the last call to :func:`reset_peak_memory`.
|
|
||||||
)pbdoc");
|
|
||||||
metal.def(
|
|
||||||
"reset_peak_memory",
|
|
||||||
&mx::metal::reset_peak_memory,
|
|
||||||
R"pbdoc(
|
|
||||||
Reset the peak memory to zero.
|
|
||||||
)pbdoc");
|
|
||||||
metal.def(
|
|
||||||
"get_cache_memory",
|
|
||||||
&mx::metal::get_cache_memory,
|
|
||||||
R"pbdoc(
|
|
||||||
Get the cache size in bytes.
|
|
||||||
|
|
||||||
The cache includes memory not currently used that has not been returned
|
|
||||||
to the system allocator.
|
|
||||||
)pbdoc");
|
|
||||||
metal.def(
|
metal.def(
|
||||||
"set_memory_limit",
|
"set_memory_limit",
|
||||||
&mx::metal::set_memory_limit,
|
[](size_t limit) {
|
||||||
"limit"_a,
|
DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit");
|
||||||
R"pbdoc(
|
return mx::set_memory_limit(limit);
|
||||||
Set the memory limit.
|
},
|
||||||
|
"limit"_a);
|
||||||
The memory limit is a guideline for the maximum amount of memory to use
|
|
||||||
during graph evaluation. If the memory limit is exceeded and there is no
|
|
||||||
more RAM (including swap when available) allocations will result in an
|
|
||||||
exception.
|
|
||||||
|
|
||||||
When metal is available the memory limit defaults to 1.5 times the
|
|
||||||
maximum recommended working set size reported by the device.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit (int): Memory limit in bytes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The previous memory limit in bytes.
|
|
||||||
)pbdoc");
|
|
||||||
metal.def(
|
metal.def(
|
||||||
"set_cache_limit",
|
"set_cache_limit",
|
||||||
&mx::metal::set_cache_limit,
|
[](size_t limit) {
|
||||||
"limit"_a,
|
DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit");
|
||||||
R"pbdoc(
|
return mx::set_cache_limit(limit);
|
||||||
Set the free cache limit.
|
},
|
||||||
|
"limit"_a);
|
||||||
If using more than the given limit, free memory will be reclaimed
|
|
||||||
from the cache on the next allocation. To disable the cache, set
|
|
||||||
the limit to ``0``.
|
|
||||||
|
|
||||||
The cache limit defaults to the memory limit. See
|
|
||||||
:func:`set_memory_limit` for more details.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit (int): The cache limit in bytes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The previous cache limit in bytes.
|
|
||||||
)pbdoc");
|
|
||||||
metal.def(
|
metal.def(
|
||||||
"set_wired_limit",
|
"set_wired_limit",
|
||||||
&mx::metal::set_wired_limit,
|
[](size_t limit) {
|
||||||
"limit"_a,
|
DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit");
|
||||||
R"pbdoc(
|
return mx::set_wired_limit(limit);
|
||||||
Set the wired size limit.
|
},
|
||||||
|
"limit"_a);
|
||||||
.. note::
|
metal.def("clear_cache", []() {
|
||||||
* This function is only useful on macOS 15.0 or higher.
|
DEPRECATE("mx.metal.clear_cache", "mx.clear_cache");
|
||||||
* The wired limit should remain strictly less than the total
|
mx::clear_cache();
|
||||||
memory size.
|
});
|
||||||
|
|
||||||
The wired limit is the total size in bytes of memory that will be kept
|
|
||||||
resident. The default value is ``0``.
|
|
||||||
|
|
||||||
Setting a wired limit larger than system wired limit is an error. You can
|
|
||||||
increase the system wired limit with:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
|
|
||||||
|
|
||||||
Use :func:`device_info` to query the system wired limit
|
|
||||||
(``"max_recommended_working_set_size"``) and the total memory size
|
|
||||||
(``"memory_size"``).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit (int): The wired limit in bytes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The previous wired limit in bytes.
|
|
||||||
)pbdoc");
|
|
||||||
metal.def(
|
|
||||||
"clear_cache",
|
|
||||||
&mx::metal::clear_cache,
|
|
||||||
R"pbdoc(
|
|
||||||
Clear the memory cache.
|
|
||||||
|
|
||||||
After calling this, :func:`get_cache_memory` should return ``0``.
|
|
||||||
)pbdoc");
|
|
||||||
|
|
||||||
metal.def(
|
metal.def(
|
||||||
"start_capture",
|
"start_capture",
|
||||||
&mx::metal::start_capture,
|
&mx::metal::start_capture,
|
||||||
|
@ -12,6 +12,7 @@ void init_array(nb::module_&);
|
|||||||
void init_device(nb::module_&);
|
void init_device(nb::module_&);
|
||||||
void init_stream(nb::module_&);
|
void init_stream(nb::module_&);
|
||||||
void init_metal(nb::module_&);
|
void init_metal(nb::module_&);
|
||||||
|
void init_memory(nb::module_&);
|
||||||
void init_ops(nb::module_&);
|
void init_ops(nb::module_&);
|
||||||
void init_transforms(nb::module_&);
|
void init_transforms(nb::module_&);
|
||||||
void init_random(nb::module_&);
|
void init_random(nb::module_&);
|
||||||
@ -34,6 +35,7 @@ NB_MODULE(core, m) {
|
|||||||
init_stream(m);
|
init_stream(m);
|
||||||
init_array(m);
|
init_array(m);
|
||||||
init_metal(m);
|
init_metal(m);
|
||||||
|
init_memory(m);
|
||||||
init_ops(m);
|
init_ops(m);
|
||||||
init_transforms(m);
|
init_transforms(m);
|
||||||
init_random(m);
|
init_random(m);
|
||||||
|
@ -179,16 +179,16 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
|||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
|
|
||||||
mx.metal.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
y = mx.distributed.all_sum(x)
|
y = mx.distributed.all_sum(x)
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
all_sum_only = mx.metal.get_peak_memory()
|
all_sum_only = mx.get_peak_memory()
|
||||||
y = mx.distributed.all_sum(x) * scale
|
y = mx.distributed.all_sum(x) * scale
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
all_sum_with_binary = mx.metal.get_peak_memory()
|
all_sum_with_binary = mx.get_peak_memory()
|
||||||
|
|
||||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||||
|
|
||||||
|
@ -1813,10 +1813,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
fun()
|
fun()
|
||||||
mx.synchronize()
|
mx.synchronize()
|
||||||
peak_1 = mx.metal.get_peak_memory()
|
peak_1 = mx.get_peak_memory()
|
||||||
fun()
|
fun()
|
||||||
mx.synchronize()
|
mx.synchronize()
|
||||||
peak_2 = mx.metal.get_peak_memory()
|
peak_2 = mx.get_peak_memory()
|
||||||
self.assertEqual(peak_1, peak_2)
|
self.assertEqual(peak_1, peak_2)
|
||||||
|
|
||||||
def fun():
|
def fun():
|
||||||
@ -1826,10 +1826,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
fun()
|
fun()
|
||||||
mx.synchronize()
|
mx.synchronize()
|
||||||
peak_1 = mx.metal.get_peak_memory()
|
peak_1 = mx.get_peak_memory()
|
||||||
fun()
|
fun()
|
||||||
mx.synchronize()
|
mx.synchronize()
|
||||||
peak_2 = mx.metal.get_peak_memory()
|
peak_2 = mx.get_peak_memory()
|
||||||
self.assertEqual(peak_1, peak_2)
|
self.assertEqual(peak_1, peak_2)
|
||||||
|
|
||||||
def test_add_numpy(self):
|
def test_add_numpy(self):
|
||||||
|
@ -747,7 +747,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
]:
|
]:
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
mem_pre = mx.metal.get_active_memory()
|
mem_pre = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_pre = 0
|
mem_pre = 0
|
||||||
|
|
||||||
@ -765,7 +765,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_post = mx.metal.get_active_memory()
|
mem_post = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_post = 0
|
mem_post = 0
|
||||||
|
|
||||||
|
@ -955,7 +955,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
def test_leaks(self):
|
def test_leaks(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_pre = mx.metal.get_active_memory()
|
mem_pre = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_pre = 0
|
mem_pre = 0
|
||||||
|
|
||||||
@ -973,7 +973,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_post = mx.metal.get_active_memory()
|
mem_post = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_post = 0
|
mem_post = 0
|
||||||
|
|
||||||
|
@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
peak_mem = mx.metal.get_peak_memory()
|
peak_mem = mx.get_peak_memory()
|
||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
self.assertEqual(peak_mem, mx.get_peak_memory())
|
||||||
|
|
||||||
def test_async_eval_with_multiple_streams(self):
|
def test_async_eval_with_multiple_streams(self):
|
||||||
x = mx.array([1.0])
|
x = mx.array([1.0])
|
||||||
@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
x = mx.zeros((4096, 4096))
|
x = mx.zeros((4096, 4096))
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
pre = mx.metal.get_peak_memory()
|
pre = mx.get_peak_memory()
|
||||||
out = fun(x)
|
out = fun(x)
|
||||||
del x
|
del x
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
post = mx.metal.get_peak_memory()
|
post = mx.get_peak_memory()
|
||||||
self.assertEqual(pre, post)
|
self.assertEqual(pre, post)
|
||||||
|
|
||||||
def fun(x):
|
def fun(x):
|
||||||
@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
x = mx.zeros((4096 * 4096,))
|
x = mx.zeros((4096 * 4096,))
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
pre = mx.metal.get_peak_memory()
|
pre = mx.get_peak_memory()
|
||||||
out = fun(x)
|
out = fun(x)
|
||||||
del x
|
del x
|
||||||
mx.eval(out)
|
mx.eval(out)
|
||||||
post = mx.metal.get_peak_memory()
|
post = mx.get_peak_memory()
|
||||||
self.assertEqual(pre, post)
|
self.assertEqual(pre, post)
|
||||||
|
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
s1 = mx.default_stream(mx.gpu)
|
s1 = mx.default_stream(mx.gpu)
|
||||||
s2 = mx.new_stream(mx.gpu)
|
s2 = mx.new_stream(mx.gpu)
|
||||||
old_limit = mx.metal.set_memory_limit(1000)
|
old_limit = mx.set_memory_limit(1000)
|
||||||
|
|
||||||
x = mx.ones((512, 512), stream=s2)
|
x = mx.ones((512, 512), stream=s2)
|
||||||
for _ in range(80):
|
for _ in range(80):
|
||||||
@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
y = mx.abs(x, stream=s2)
|
y = mx.abs(x, stream=s2)
|
||||||
z = mx.abs(y, stream=s2)
|
z = mx.abs(y, stream=s2)
|
||||||
mx.eval(z)
|
mx.eval(z)
|
||||||
mx.metal.set_memory_limit(old_limit)
|
mx.set_memory_limit(old_limit)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -243,7 +243,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
def test_leaks(self):
|
def test_leaks(self):
|
||||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_pre = mx.metal.get_active_memory()
|
mem_pre = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_pre = 0
|
mem_pre = 0
|
||||||
|
|
||||||
@ -261,7 +261,7 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_post = mx.metal.get_active_memory()
|
mem_post = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_post = 0
|
mem_post = 0
|
||||||
|
|
||||||
|
@ -387,14 +387,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
mx.save(save_file, x)
|
mx.save(save_file, x)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
|
|
||||||
mx.metal.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
y = mx.load(save_file)
|
y = mx.load(save_file)
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
load_only = mx.metal.get_peak_memory()
|
load_only = mx.get_peak_memory()
|
||||||
y = mx.load(save_file) * scale
|
y = mx.load(save_file) * scale
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
load_with_binary = mx.metal.get_peak_memory()
|
load_with_binary = mx.get_peak_memory()
|
||||||
|
|
||||||
self.assertEqual(load_only, load_with_binary)
|
self.assertEqual(load_only, load_with_binary)
|
||||||
|
|
||||||
|
60
python/tests/test_memory.py
Normal file
60
python/tests/test_memory.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemory(mlx_tests.MLXTestCase):
|
||||||
|
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
|
def test_memory_info(self):
|
||||||
|
old_limit = mx.set_cache_limit(0)
|
||||||
|
|
||||||
|
a = mx.zeros((4096,))
|
||||||
|
mx.eval(a)
|
||||||
|
del a
|
||||||
|
self.assertEqual(mx.get_cache_memory(), 0)
|
||||||
|
self.assertEqual(mx.set_cache_limit(old_limit), 0)
|
||||||
|
self.assertEqual(mx.set_cache_limit(old_limit), old_limit)
|
||||||
|
|
||||||
|
old_limit = mx.set_memory_limit(10)
|
||||||
|
self.assertTrue(mx.set_memory_limit(old_limit), 10)
|
||||||
|
self.assertTrue(mx.set_memory_limit(old_limit), old_limit)
|
||||||
|
|
||||||
|
# Query active and peak memory
|
||||||
|
a = mx.zeros((4096,))
|
||||||
|
mx.eval(a)
|
||||||
|
mx.synchronize()
|
||||||
|
active_mem = mx.get_active_memory()
|
||||||
|
self.assertTrue(active_mem >= 4096 * 4)
|
||||||
|
|
||||||
|
b = mx.zeros((4096,))
|
||||||
|
mx.eval(b)
|
||||||
|
del b
|
||||||
|
mx.synchronize()
|
||||||
|
|
||||||
|
new_active_mem = mx.get_active_memory()
|
||||||
|
self.assertEqual(new_active_mem, active_mem)
|
||||||
|
peak_mem = mx.get_peak_memory()
|
||||||
|
self.assertTrue(peak_mem >= 4096 * 8)
|
||||||
|
cache_mem = mx.get_cache_memory()
|
||||||
|
self.assertTrue(cache_mem >= 4096 * 4)
|
||||||
|
|
||||||
|
mx.clear_cache()
|
||||||
|
self.assertEqual(mx.get_cache_memory(), 0)
|
||||||
|
|
||||||
|
mx.reset_peak_memory()
|
||||||
|
self.assertEqual(mx.get_peak_memory(), 0)
|
||||||
|
|
||||||
|
old_limit = mx.set_wired_limit(1000)
|
||||||
|
old_limit = mx.set_wired_limit(0)
|
||||||
|
self.assertEqual(old_limit, 1000)
|
||||||
|
|
||||||
|
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.set_wired_limit(max_size + 10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -1,60 +0,0 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx_tests
|
|
||||||
|
|
||||||
|
|
||||||
class TestMetal(mlx_tests.MLXTestCase):
|
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
|
||||||
def test_memory_info(self):
|
|
||||||
old_limit = mx.metal.set_cache_limit(0)
|
|
||||||
|
|
||||||
a = mx.zeros((4096,))
|
|
||||||
mx.eval(a)
|
|
||||||
del a
|
|
||||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
|
||||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
|
|
||||||
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
|
|
||||||
|
|
||||||
old_limit = mx.metal.set_memory_limit(10)
|
|
||||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
|
|
||||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
|
||||||
|
|
||||||
# Query active and peak memory
|
|
||||||
a = mx.zeros((4096,))
|
|
||||||
mx.eval(a)
|
|
||||||
mx.synchronize()
|
|
||||||
active_mem = mx.metal.get_active_memory()
|
|
||||||
self.assertTrue(active_mem >= 4096 * 4)
|
|
||||||
|
|
||||||
b = mx.zeros((4096,))
|
|
||||||
mx.eval(b)
|
|
||||||
del b
|
|
||||||
mx.synchronize()
|
|
||||||
|
|
||||||
new_active_mem = mx.metal.get_active_memory()
|
|
||||||
self.assertEqual(new_active_mem, active_mem)
|
|
||||||
peak_mem = mx.metal.get_peak_memory()
|
|
||||||
self.assertTrue(peak_mem >= 4096 * 8)
|
|
||||||
cache_mem = mx.metal.get_cache_memory()
|
|
||||||
self.assertTrue(cache_mem >= 4096 * 4)
|
|
||||||
|
|
||||||
mx.metal.clear_cache()
|
|
||||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
|
||||||
|
|
||||||
mx.metal.reset_peak_memory()
|
|
||||||
self.assertEqual(mx.metal.get_peak_memory(), 0)
|
|
||||||
|
|
||||||
old_limit = mx.metal.set_wired_limit(1000)
|
|
||||||
old_limit = mx.metal.set_wired_limit(0)
|
|
||||||
self.assertEqual(old_limit, 1000)
|
|
||||||
|
|
||||||
max_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
mx.metal.set_wired_limit(max_size + 10)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
@ -1904,10 +1904,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
mx.eval(fn(2))
|
mx.eval(fn(2))
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
mem2 = mx.metal.get_peak_memory()
|
mem2 = mx.get_peak_memory()
|
||||||
mx.eval(fn(4))
|
mx.eval(fn(4))
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||||
mem4 = mx.metal.get_peak_memory()
|
mem4 = mx.get_peak_memory()
|
||||||
self.assertEqual(mem2, mem4)
|
self.assertEqual(mem2, mem4)
|
||||||
|
|
||||||
def test_squeeze_expand(self):
|
def test_squeeze_expand(self):
|
||||||
|
@ -635,7 +635,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
def test_leaks(self):
|
def test_leaks(self):
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_pre = mx.metal.get_active_memory()
|
mem_pre = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_pre = 0
|
mem_pre = 0
|
||||||
|
|
||||||
@ -653,7 +653,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if mx.metal.is_available():
|
if mx.metal.is_available():
|
||||||
mem_post = mx.metal.get_active_memory()
|
mem_post = mx.get_active_memory()
|
||||||
else:
|
else:
|
||||||
mem_post = 0
|
mem_post = 0
|
||||||
|
|
||||||
|
@ -473,24 +473,24 @@ TEST_CASE("test metal validation") {
|
|||||||
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
|
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal memory info") {
|
TEST_CASE("test memory info") {
|
||||||
// Test cache limits
|
// Test cache limits
|
||||||
{
|
{
|
||||||
auto old_limit = metal::set_cache_limit(0);
|
auto old_limit = set_cache_limit(0);
|
||||||
{
|
{
|
||||||
auto a = zeros({4096});
|
auto a = zeros({4096});
|
||||||
eval(a);
|
eval(a);
|
||||||
}
|
}
|
||||||
CHECK_EQ(metal::get_cache_memory(), 0);
|
CHECK_EQ(get_cache_memory(), 0);
|
||||||
CHECK_EQ(metal::set_cache_limit(old_limit), 0);
|
CHECK_EQ(set_cache_limit(old_limit), 0);
|
||||||
CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
|
CHECK_EQ(set_cache_limit(old_limit), old_limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test memory limits
|
// Test memory limits
|
||||||
{
|
{
|
||||||
auto old_limit = metal::set_memory_limit(10);
|
auto old_limit = set_memory_limit(10);
|
||||||
CHECK_EQ(metal::set_memory_limit(old_limit), 10);
|
CHECK_EQ(set_memory_limit(old_limit), 10);
|
||||||
CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
|
CHECK_EQ(set_memory_limit(old_limit), old_limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query active and peak memory
|
// Query active and peak memory
|
||||||
@ -498,22 +498,22 @@ TEST_CASE("test metal memory info") {
|
|||||||
auto a = zeros({4096});
|
auto a = zeros({4096});
|
||||||
eval(a);
|
eval(a);
|
||||||
synchronize();
|
synchronize();
|
||||||
auto active_mem = metal::get_active_memory();
|
auto active_mem = get_active_memory();
|
||||||
CHECK(active_mem >= 4096 * 4);
|
CHECK(active_mem >= 4096 * 4);
|
||||||
{
|
{
|
||||||
auto b = zeros({4096});
|
auto b = zeros({4096});
|
||||||
eval(b);
|
eval(b);
|
||||||
}
|
}
|
||||||
synchronize();
|
synchronize();
|
||||||
auto new_active_mem = metal::get_active_memory();
|
auto new_active_mem = get_active_memory();
|
||||||
CHECK_EQ(new_active_mem, active_mem);
|
CHECK_EQ(new_active_mem, active_mem);
|
||||||
auto peak_mem = metal::get_peak_memory();
|
auto peak_mem = get_peak_memory();
|
||||||
CHECK(peak_mem >= 4096 * 8);
|
CHECK(peak_mem >= 4096 * 8);
|
||||||
|
|
||||||
auto cache_mem = metal::get_cache_memory();
|
auto cache_mem = get_cache_memory();
|
||||||
CHECK(cache_mem >= 4096 * 4);
|
CHECK(cache_mem >= 4096 * 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
metal::clear_cache();
|
clear_cache();
|
||||||
CHECK_EQ(metal::get_cache_memory(), 0);
|
CHECK_EQ(get_cache_memory(), 0);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user