diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 203534e21..86af3a774 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -6,6 +6,7 @@
 
 #include <cuda_runtime.h>
 #include <fmt/format.h>
+#include <unistd.h>
 
 #include <cassert>
 
@@ -13,24 +14,47 @@ namespace mlx::core {
 
 namespace cu {
 
-CudaAllocator::CudaAllocator() {
+CudaAllocator::CudaAllocator()
+    : buffer_cache_(
+          getpagesize(),
+          [](CudaBuffer* buf) { return buf->size; },
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.8;
+  max_pool_size_ = memory_limit_;
 }
 
 Buffer CudaAllocator::malloc(size_t size) {
-  // TODO: Check memory limit.
-  auto* buf = new CudaBuffer{nullptr, size};
-  cudaError_t err = cudaMallocManaged(&buf->data, size);
-  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-    throw std::runtime_error(
-        fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+  // Find available buffer from cache.
+  std::unique_lock lock(mutex_);
+  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
+  if (!buf) {
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache.
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
+    if (mem_required >= memory_limit_) {
+      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    }
+
+    lock.unlock();
+    buf = new CudaBuffer{nullptr, size};
+    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+      throw std::runtime_error(fmt::format(
+          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+    }
+    lock.lock();
   }
-  std::lock_guard lock(mutex_);
   active_memory_ += size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
+
+  // Maintain the cache below the requested limit.
+  if (get_cache_memory() > max_pool_size_) {
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{buf};
 }
 
@@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
     return;
   }
 
-  // If free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([buffer]() { allocator().free(buffer); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
+  std::unique_lock lock(mutex_);
+  active_memory_ -= buf->size;
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    lock.unlock();
+    cuda_free(buf);
   }
-
-  size_t size = buf->size;
-  cudaFree(buf->data);
-  delete buf;
-  std::lock_guard lock(mutex_);
-  active_memory_ -= size;
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
@@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
   return limit;
 }
 
+size_t CudaAllocator::get_cache_memory() const {
+  return buffer_cache_.cache_size();
+}
+
+size_t CudaAllocator::set_cache_limit(size_t limit) {
+  std::lock_guard lk(mutex_);
+  std::swap(limit, max_pool_size_);
+  return limit;
+}
+
+void CudaAllocator::clear_cache() {
+  std::lock_guard lk(mutex_);
+  buffer_cache_.clear();
+}
+
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the call
+  // to worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+
+  cudaFree(buf->data);
+  delete buf;
+}
+
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {
 size_t get_memory_limit() {
   return cu::allocator().get_memory_limit();
 }
-
-// TODO: Implement buffer cache.
 size_t get_cache_memory() {
-  return 0;
+  return cu::allocator().get_cache_memory();
 }
-size_t set_cache_limit(size_t) {
-  return 0;
+size_t set_cache_limit(size_t limit) {
+  return cu::allocator().set_cache_limit(limit);
 }
+void clear_cache() {
+  cu::allocator().clear_cache();
+}
+
+// Not supported in CUDA.
 size_t set_wired_limit(size_t) {
   return 0;
 }
-void clear_cache() {}
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index 6c418ee7e..fe3755121 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 
 #include <mutex>
 #include <set>
@@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
   void reset_peak_memory();
   size_t get_memory_limit();
   size_t set_memory_limit(size_t limit);
+  size_t get_cache_memory() const;
+  size_t set_cache_limit(size_t limit);
+  void clear_cache();
 
  private:
   CudaAllocator();
   friend CudaAllocator& allocator();
 
+  void cuda_free(CudaBuffer* buf);
+
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;
 
   std::mutex mutex_;
   size_t memory_limit_;
+  size_t max_pool_size_;
+  BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
 };
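
Note (not part of the diff): a minimal usage sketch of the cache controls this patch wires up on the CUDA backend. It assumes the public MLX C++ API from "mlx/mlx.h" and a build that runs on CUDA; the array shape and the 1 GiB limit are arbitrary illustration values.

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Cap the allocator's buffer cache at ~1 GiB; returns the previous limit.
  size_t old_limit = mx::set_cache_limit(size_t(1) << 30);

  auto a = mx::zeros({1024, 1024});
  mx::eval(a);

  // With this patch, freed device buffers are recycled into the cache instead
  // of going straight to cudaFree; this reports how many bytes are cached.
  size_t cached = mx::get_cache_memory();

  // Drop every cached allocation, e.g. before handing memory to another library.
  mx::clear_cache();

  (void)old_limit;
  (void)cached;
  return 0;
}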