mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Expose function to clear memory cache (#1032)
* expose function to clear memory cache * fix linux build * fix metal tests
This commit is contained in:
parent
20a01bbd9f
commit
771575d27b
@ -12,5 +12,6 @@ Metal
|
|||||||
get_cache_memory
|
get_cache_memory
|
||||||
set_memory_limit
|
set_memory_limit
|
||||||
set_cache_limit
|
set_cache_limit
|
||||||
|
clear_cache
|
||||||
start_capture
|
start_capture
|
||||||
stop_capture
|
stop_capture
|
||||||
|
@ -92,6 +92,7 @@ Operations
|
|||||||
moveaxis
|
moveaxis
|
||||||
multiply
|
multiply
|
||||||
negative
|
negative
|
||||||
|
not_equal
|
||||||
ones
|
ones
|
||||||
ones_like
|
ones_like
|
||||||
outer
|
outer
|
||||||
|
@ -209,6 +209,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
return Buffer{static_cast<void*>(buf)};
|
return Buffer{static_cast<void*>(buf)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MetalAllocator::clear_cache() {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
|
buffer_cache_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
void MetalAllocator::free(Buffer buffer) {
|
void MetalAllocator::free(Buffer buffer) {
|
||||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||||
std::unique_lock lk(mutex_);
|
std::unique_lock lk(mutex_);
|
||||||
@ -242,6 +247,9 @@ size_t get_peak_memory() {
|
|||||||
size_t get_cache_memory() {
|
size_t get_cache_memory() {
|
||||||
return allocator().get_cache_memory();
|
return allocator().get_cache_memory();
|
||||||
}
|
}
|
||||||
|
void clear_cache() {
|
||||||
|
return allocator().clear_cache();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ class BufferCache {
|
|||||||
size_t cache_size() {
|
size_t cache_size() {
|
||||||
return pool_size_;
|
return pool_size_;
|
||||||
}
|
}
|
||||||
|
void clear();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct BufferHolder {
|
struct BufferHolder {
|
||||||
@ -37,7 +38,6 @@ class BufferCache {
|
|||||||
MTL::Buffer* buf;
|
MTL::Buffer* buf;
|
||||||
};
|
};
|
||||||
|
|
||||||
void clear();
|
|
||||||
void add_at_head(BufferHolder* to_add);
|
void add_at_head(BufferHolder* to_add);
|
||||||
void remove_from_list(BufferHolder* to_remove);
|
void remove_from_list(BufferHolder* to_remove);
|
||||||
|
|
||||||
@ -67,6 +67,7 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
};
|
};
|
||||||
size_t set_cache_limit(size_t limit);
|
size_t set_cache_limit(size_t limit);
|
||||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||||
|
void clear_cache();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
|
@ -54,6 +54,9 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
|
|||||||
* */
|
* */
|
||||||
size_t set_cache_limit(size_t limit);
|
size_t set_cache_limit(size_t limit);
|
||||||
|
|
||||||
|
/* Clear the memory cache. */
|
||||||
|
void clear_cache();
|
||||||
|
|
||||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||||
void start_capture(std::string path = "");
|
void start_capture(std::string path = "");
|
||||||
void stop_capture();
|
void stop_capture();
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
@ -48,5 +47,6 @@ size_t set_cache_limit(size_t) {
|
|||||||
}
|
}
|
||||||
void start_capture(std::string path) {}
|
void start_capture(std::string path) {}
|
||||||
void stop_capture() {}
|
void stop_capture() {}
|
||||||
|
void clear_cache() {}
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -90,6 +90,15 @@ void init_metal(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
int: The previous cache limit in bytes.
|
int: The previous cache limit in bytes.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
metal.def(
|
||||||
|
"clear_cache",
|
||||||
|
&metal::clear_cache,
|
||||||
|
R"pbdoc(
|
||||||
|
Clear the memory cache.
|
||||||
|
|
||||||
|
After calling this, :func:`get_cache_memory` should return ``0``.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
metal.def(
|
metal.def(
|
||||||
"start_capture",
|
"start_capture",
|
||||||
&metal::start_capture,
|
&metal::start_capture,
|
||||||
|
@ -42,6 +42,9 @@ class TestMetal(mlx_tests.MLXTestCase):
|
|||||||
cache_mem = mx.metal.get_cache_memory()
|
cache_mem = mx.metal.get_cache_memory()
|
||||||
self.assertTrue(cache_mem >= 4096 * 4)
|
self.assertTrue(cache_mem >= 4096 * 4)
|
||||||
|
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
self.assertEqual(mx.metal.get_cache_memory(), 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -513,4 +513,7 @@ TEST_CASE("test metal memory info") {
|
|||||||
auto cache_mem = metal::get_cache_memory();
|
auto cache_mem = metal::get_cache_memory();
|
||||||
CHECK(cache_mem >= 4096 * 4);
|
CHECK(cache_mem >= 4096 * 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metal::clear_cache();
|
||||||
|
CHECK_EQ(metal::get_cache_memory(), 0);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user