diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 6ca225a5f..93bf48542 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -2,7 +2,6 @@
 
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 
 #include <cuda_runtime.h>
@@ -25,55 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;
 
 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
-
-  CHECK_CUDA_ERROR(
-      cudaMemAdvise(buffer_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
-
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+
+  next_free_ = buffer_;
+
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
+
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }
 
 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }
 
-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }
 
-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }
 
-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }
 
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -95,24 +97,20 @@ Buffer CudaAllocator::malloc(size_t size) {
 
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
     int64_t mem_to_free =
         get_active_memory() + get_cache_memory() + size - memory_limit_;
-    mem_to_free = std::max(
-        static_cast<int64_t>(get_cache_memory() - max_pool_size_), mem_to_free);
     if (mem_to_free > 0) {
      buffer_cache_.release_cached_buffers(mem_to_free);
     }
-    buf = new CudaBuffer{nullptr, size};
-
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
     }
     lock.unlock();
-    if (!buf->data) {
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
@@ -139,7 +137,11 @@ void CudaAllocator::free(Buffer buffer) {
 
   std::unique_lock lock(mutex_);
   active_memory_ -= buf->size;
-  buffer_cache_.recycle_to_cache(buf);
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    cuda_free(buf);
+  }
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
@@ -151,11 +153,12 @@ size_t CudaAllocator::size(Buffer buffer) const {
 }
 
 // This must be called with mutex_ aquired
-void CudaAllocator::cuda_free(void* buf) {
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }
 
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index 1fb9bfe95..81b3dde59 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -21,13 +21,14 @@ struct CudaBuffer {
 
 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };
 
-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};
 
  public:
   SmallSizePool();
@@ -36,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;
 
-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };
 
 class CudaAllocator : public allocator::Allocator {
@@ -57,7 +58,7 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
-  void cuda_free(void* buf);
+  void cuda_free(CudaBuffer* buf);
 
   CudaAllocator();
   friend CudaAllocator& allocator();
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index dd6189732..8eb70bcbc 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
 
   auto pool = metal::new_scoped_memory_pool();
 
-  // If we have a lot of memory pressure or are over the maximum cache size,
-  // try to reclaim memory from the cache
+  // If we have a lot of memory pressure try to reclaim memory from the cache
   if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
     num_resources_ -=
         buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
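
Note (not part of the patch): below is a minimal host-only sketch of the free-list scheme the reworked SmallSizePool uses, assuming cudaMallocManaged is swapped for std::malloc; the PoolSketch name and the block/pool sizes are illustrative, not the real constants. Each union Block either links into the free list or carries the CudaBuffer handed to callers, and a block's index selects the matching slice of the separately allocated data region, mirroring how the patch splits bookkeeping (buffer_) from backing memory (data_).

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

struct CudaBuffer {
  void* data;
  size_t size;
};

class PoolSketch {
  // Either a free-list link or the buffer metadata returned to the caller.
  union Block {
    Block* next;
    CudaBuffer buf;
  };

  static constexpr size_t block_size = 8;   // illustrative
  static constexpr size_t num_blocks = 64;  // illustrative

  Block* buffer_{nullptr};  // bookkeeping nodes (host heap)
  void* data_{nullptr};     // backing storage (cudaMallocManaged in the real pool)
  Block* next_free_{nullptr};

 public:
  PoolSketch()
      : buffer_(new Block[num_blocks]),
        data_(std::malloc(block_size * num_blocks)) {
    // Thread every bookkeeping node onto a singly linked free list.
    next_free_ = buffer_;
    for (size_t i = 1; i < num_blocks; ++i) {
      buffer_[i - 1].next = buffer_ + i;
    }
    buffer_[num_blocks - 1].next = nullptr;
  }

  ~PoolSketch() {
    std::free(data_);
    delete[] buffer_;
  }

  CudaBuffer* malloc() {
    if (next_free_ == nullptr) {
      return nullptr;  // pool exhausted; caller falls back to the general path
    }
    Block* b = next_free_;
    next_free_ = next_free_->next;
    // The node's index picks the matching slice of the data region.
    size_t i = b - buffer_;
    b->buf.data = static_cast<char*>(data_) + i * block_size;
    b->buf.size = block_size;
    return &b->buf;
  }

  void free(CudaBuffer* buf) {
    // A CudaBuffer returned by malloc() is a member of its Block union,
    // so the pointer can be converted back and pushed onto the free list.
    auto b = reinterpret_cast<Block*>(buf);
    b->next = next_free_;
    next_free_ = b;
  }

  bool in_pool(CudaBuffer* buf) const {
    auto b = reinterpret_cast<Block*>(buf);
    int64_t n = b - buffer_;
    return n >= 0 && n < static_cast<int64_t>(num_blocks);
  }
};

int main() {
  PoolSketch pool;
  CudaBuffer* a = pool.malloc();
  assert(a != nullptr && a->size == 8 && pool.in_pool(a));
  pool.free(a);
  CudaBuffer* b = pool.malloc();
  assert(b == a);  // LIFO free list hands the same slot straight back
  pool.free(b);
  return 0;
}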