mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Expose function to clear memory cache (#1032)
* expose function to clear memory cache * fix linux build * fix metal tests
This commit is contained in:
parent
20a01bbd9f
commit
771575d27b
@ -12,5 +12,6 @@ Metal
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@ -92,6 +92,7 @@ Operations
|
||||
moveaxis
|
||||
multiply
|
||||
negative
|
||||
not_equal
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
|
@ -209,6 +209,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
void MetalAllocator::clear_cache() {
|
||||
std::unique_lock lk(mutex_);
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
std::unique_lock lk(mutex_);
|
||||
@ -242,6 +247,9 @@ size_t get_peak_memory() {
|
||||
size_t get_cache_memory() {
|
||||
return allocator().get_cache_memory();
|
||||
}
|
||||
void clear_cache() {
|
||||
return allocator().clear_cache();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
|
@ -26,6 +26,7 @@ class BufferCache {
|
||||
size_t cache_size() {
|
||||
return pool_size_;
|
||||
}
|
||||
void clear();
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
@ -37,7 +38,6 @@ class BufferCache {
|
||||
MTL::Buffer* buf;
|
||||
};
|
||||
|
||||
void clear();
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
@ -67,6 +67,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
MTL::Device* device_;
|
||||
|
@ -54,6 +54,9 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
|
||||
* */
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
/* Clear the memory cache. */
|
||||
void clear_cache();
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
void start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
@ -4,7 +4,6 @@
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
bool is_available() {
|
||||
@ -48,5 +47,6 @@ size_t set_cache_limit(size_t) {
|
||||
}
|
||||
void start_capture(std::string path) {}
|
||||
void stop_capture() {}
|
||||
void clear_cache() {}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -90,6 +90,15 @@ void init_metal(nb::module_& m) {
|
||||
Returns:
|
||||
int: The previous cache limit in bytes.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"clear_cache",
|
||||
&metal::clear_cache,
|
||||
R"pbdoc(
|
||||
Clear the memory cache.
|
||||
|
||||
After calling this, :func:`get_cache_memory` should return ``0``.
|
||||
)pbdoc");
|
||||
|
||||
metal.def(
|
||||
"start_capture",
|
||||
&metal::start_capture,
|
||||
|
@ -42,6 +42,9 @@ class TestMetal(mlx_tests.MLXTestCase):
|
||||
cache_mem = mx.metal.get_cache_memory()
|
||||
self.assertTrue(cache_mem >= 4096 * 4)
|
||||
|
||||
mx.metal.clear_cache()
|
||||
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -513,4 +513,7 @@ TEST_CASE("test metal memory info") {
|
||||
auto cache_mem = metal::get_cache_memory();
|
||||
CHECK(cache_mem >= 4096 * 4);
|
||||
}
|
||||
|
||||
metal::clear_cache();
|
||||
CHECK_EQ(metal::get_cache_memory(), 0);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user