Expose function to clear memory cache (#1032)

* expose function to clear memory cache

* fix linux build

* fix metal tests
This commit is contained in:
Awni Hannun 2024-04-24 16:48:51 -07:00 committed by GitHub
parent 20a01bbd9f
commit 771575d27b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 31 additions and 2 deletions

View File

@ -12,5 +12,6 @@ Metal
get_cache_memory
set_memory_limit
set_cache_limit
clear_cache
start_capture
stop_capture

View File

@ -92,6 +92,7 @@ Operations
moveaxis
multiply
negative
not_equal
ones
ones_like
outer

View File

@ -209,6 +209,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_);
@ -242,6 +247,9 @@ size_t get_peak_memory() {
size_t get_cache_memory() {
return allocator().get_cache_memory();
}
void clear_cache() {
return allocator().clear_cache();
}
} // namespace metal

View File

@ -26,6 +26,7 @@ class BufferCache {
size_t cache_size() {
return pool_size_;
}
void clear();
private:
struct BufferHolder {
@ -37,7 +38,6 @@ class BufferCache {
MTL::Buffer* buf;
};
void clear();
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
@ -67,6 +67,7 @@ class MetalAllocator : public allocator::Allocator {
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
void clear_cache();
private:
MTL::Device* device_;

View File

@ -54,6 +54,9 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
* */
size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

View File

@ -4,7 +4,6 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {
bool is_available() {
@ -48,5 +47,6 @@ size_t set_cache_limit(size_t) {
}
void start_capture(std::string path) {}
void stop_capture() {}
void clear_cache() {}
} // namespace mlx::core::metal

View File

@ -90,6 +90,15 @@ void init_metal(nb::module_& m) {
Returns:
int: The previous cache limit in bytes.
)pbdoc");
metal.def(
"clear_cache",
&metal::clear_cache,
R"pbdoc(
Clear the memory cache.
After calling this, :func:`get_cache_memory` should return ``0``.
)pbdoc");
metal.def(
"start_capture",
&metal::start_capture,

View File

@ -42,6 +42,9 @@ class TestMetal(mlx_tests.MLXTestCase):
cache_mem = mx.metal.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
mx.metal.clear_cache()
self.assertEqual(mx.metal.get_cache_memory(), 0)
if __name__ == "__main__":
unittest.main()

View File

@ -513,4 +513,7 @@ TEST_CASE("test metal memory info") {
auto cache_mem = metal::get_cache_memory();
CHECK(cache_mem >= 4096 * 4);
}
metal::clear_cache();
CHECK_EQ(metal::get_cache_memory(), 0);
}