From c67a48be48f42dfba6ae49372bd1af30e3648adf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 25 Dec 2023 12:52:06 -0800 Subject: [PATCH] add forced swap --- mlx/allocator.cpp | 9 +++++++-- mlx/allocator.h | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index b591caeeb..86fa5974f 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -9,7 +9,7 @@ namespace mlx::core::allocator { Buffer malloc(size_t size) { - auto buffer = allocator().malloc(size); + auto buffer = allocator().malloc(size, /* allow_swap */ true); if (size && !buffer.ptr()) { std::ostringstream msg; msg << "[malloc] Unable to allocate " << size << " bytes."; @@ -22,7 +22,7 @@ void free(Buffer buffer) { return allocator().free(buffer); } -Buffer CommonAllocator::malloc(size_t size) { +Buffer CommonAllocator::malloc(size_t size, bool) { return Buffer{std::malloc(size)}; } @@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) { buffer = allocator().malloc(size); } + // Try swapping if needed + if (size && !buffer.ptr()) { + buffer = allocator().malloc(size, /* allow_swap = */ true); + } + if (size && !buffer.ptr()) { std::ostringstream msg; msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; diff --git a/mlx/allocator.h b/mlx/allocator.h index ce0c1cd00..1061d6cce 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -39,7 +39,7 @@ Buffer malloc_or_wait(size_t size); class Allocator { /** Abstract base class for a memory allocator. */ public: - virtual Buffer malloc(size_t size) = 0; + virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual void free(Buffer buffer) = 0; Allocator() = default; @@ -55,7 +55,7 @@ Allocator& allocator(); class CommonAllocator : public Allocator { /** A general CPU allocator. */ public: - virtual Buffer malloc(size_t size) override; + virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; private: