add forced swap

This commit is contained in:
Awni Hannun 2023-12-25 12:52:06 -08:00
parent ad5036072c
commit c67a48be48
2 changed files with 9 additions and 4 deletions

View File

@ -9,7 +9,7 @@
namespace mlx::core::allocator { namespace mlx::core::allocator {
Buffer malloc(size_t size) { Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size); auto buffer = allocator().malloc(size, /* allow_swap */ true);
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes."; msg << "[malloc] Unable to allocate " << size << " bytes.";
@ -22,7 +22,7 @@ void free(Buffer buffer) {
return allocator().free(buffer); return allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size) { Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)}; return Buffer{std::malloc(size)};
} }
@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) {
buffer = allocator().malloc(size); buffer = allocator().malloc(size);
} }
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";

View File

@ -39,7 +39,7 @@ Buffer malloc_or_wait(size_t size);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size) = 0; virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
Allocator() = default; Allocator() = default;
@ -55,7 +55,7 @@ Allocator& allocator();
class CommonAllocator : public Allocator { class CommonAllocator : public Allocator {
/** A general CPU allocator. */ /** A general CPU allocator. */
public: public:
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
private: private: