diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 6cc7145b5..66e8c5c66 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -17,6 +17,52 @@ namespace cu {
 
 constexpr int page_size = 16384;
 
+// Any allocations smaller than this will try to use the small pool
+constexpr int small_block_size = 8;
+
+// The small pool size in bytes. This should be a multiple of the host page
+// size and small_block_size.
+constexpr int small_pool_size = 4 * page_size;
+
+SmallSizePool::SmallSizePool() {
+  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
+  end_ = reinterpret_cast<void*>(
+      reinterpret_cast<char*>(buffer_) + small_pool_size);
+  next_free_ = reinterpret_cast<Block*>(buffer_);
+
+  auto num_blocks = small_pool_size / small_block_size;
+  auto curr = next_free_;
+  for (size_t i = 0; i < num_blocks - 1; ++i) {
+    curr->next = reinterpret_cast<Block*>(
+        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+    curr = curr->next;
+  }
+  curr->next = nullptr;
+}
+
+SmallSizePool::~SmallSizePool() {
+  CHECK_CUDA_ERROR(cudaFree(buffer_));
+}
+
+void* SmallSizePool::malloc() {
+  if (next_free_ == nullptr) {
+    return nullptr;
+  }
+  Block* b = next_free_;
+  next_free_ = next_free_->next;
+  return static_cast<void*>(b);
+}
+
+void SmallSizePool::free(void* p) {
+  auto b = static_cast<Block*>(p);
+  b->next = next_free_;
+  next_free_ = b;
+}
+
+bool SmallSizePool::in_pool(void* p) {
+  return (p >= buffer_) && (p < end_);
+}
+
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
@@ -36,7 +82,9 @@ Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
   auto orig_size = size;
   std::unique_lock lock(mutex_);
-  if (size < page_size) {
+  if (size <= small_block_size) {
+    size = 8;
+  } else if (size < page_size) {
     size = next_power_of_2(size);
   } else {
     size = page_size * ((size + page_size - 1) / page_size);
@@ -53,11 +101,19 @@ Buffer CudaAllocator::malloc(size_t size) {
 
     lock.unlock();
     buf = new CudaBuffer{nullptr, size};
-    cudaError_t err = cudaMallocManaged(&buf->data, size);
-    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-      throw std::runtime_error(fmt::format(
-          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+
+    // Try the scalar pool first
+    if (size <= small_block_size) {
+      buf->data = scalar_pool_.malloc();
     }
+    if (!buf->data) {
+      cudaError_t err = cudaMallocManaged(&buf->data, size);
+      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+        throw std::runtime_error(fmt::format(
+            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+      }
+    }
+
     lock.lock();
   }
   active_memory_ += size;
@@ -116,7 +172,11 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
-  cudaFree(buf);
+  if (scalar_pool_.in_pool(buf)) {
+    scalar_pool_.free(buf);
+  } else {
+    cudaFree(buf);
+  }
 }
 
 size_t CudaAllocator::get_active_memory() const {
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index e268c6334..f7474dda6 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -22,6 +22,28 @@ struct CudaBuffer {
   size_t size;
 };
 
+class SmallSizePool {
+ private:
+  struct Block {
+    Block* next;
+  };
+
+  void* buffer_{nullptr};
+  Block* next_free_{nullptr};
+  void* end_{nullptr};
+
+ public:
+  SmallSizePool();
+  ~SmallSizePool();
+
+  SmallSizePool(const SmallSizePool&) = delete;
+  SmallSizePool& operator=(const SmallSizePool&) = delete;
+
+  void* malloc();
+  void free(void* p);
+  bool in_pool(void* p);
+};
+
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
@@ -60,6 +82,7 @@ class CudaAllocator : public allocator::Allocator {
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
+  SmallSizePool scalar_pool_;
 };
 
 CudaAllocator& allocator();
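
For readers unfamiliar with the pattern, below is a minimal standalone sketch of the intrusive free-list technique SmallSizePool uses. It is not MLX code: cudaMallocManaged is swapped for std::malloc so it compiles and runs without CUDA, and FixedPool, block_size, and pool_size are illustrative stand-ins for the names in the diff. Each free 8-byte block stores a pointer to the next free block in its own storage, so no side metadata is needed and both malloc and free are O(1) pointer swaps.

// Sketch only: same free-list idea as SmallSizePool above, minus CUDA.
#include <cassert>
#include <cstdlib>

class FixedPool {
  struct Block {
    Block* next;
  };

  static constexpr size_t block_size = 8;        // illustrative, mirrors small_block_size
  static constexpr size_t pool_size = 4 * 16384; // illustrative, mirrors small_pool_size

  void* buffer_{nullptr};
  Block* next_free_{nullptr};
  void* end_{nullptr};

 public:
  FixedPool() {
    buffer_ = std::malloc(pool_size); // stand-in for cudaMallocManaged
    end_ = static_cast<char*>(buffer_) + pool_size;
    next_free_ = static_cast<Block*>(buffer_);
    // Thread every block onto the free list; each free block holds the
    // pointer to the next free block in its own first bytes.
    auto num_blocks = pool_size / block_size;
    auto curr = next_free_;
    for (size_t i = 0; i + 1 < num_blocks; ++i) {
      curr->next = reinterpret_cast<Block*>(
          static_cast<char*>(buffer_) + (i + 1) * block_size);
      curr = curr->next;
    }
    curr->next = nullptr;
  }

  ~FixedPool() {
    std::free(buffer_);
  }

  // O(1): pop the head of the free list; nullptr means the pool is exhausted
  // and the caller should fall back to the general allocator.
  void* malloc() {
    if (!next_free_) {
      return nullptr;
    }
    Block* b = next_free_;
    next_free_ = next_free_->next;
    return b;
  }

  // O(1): push the block back onto the head of the free list.
  void free(void* p) {
    auto b = static_cast<Block*>(p);
    b->next = next_free_;
    next_free_ = b;
  }

  // Bounds check used to route a pointer back to the pool or to the
  // general allocator on free.
  bool in_pool(void* p) const {
    return p >= buffer_ && p < end_;
  }
};

int main() {
  FixedPool pool;
  void* a = pool.malloc();
  void* b = pool.malloc();
  assert(a && b && pool.in_pool(a) && pool.in_pool(b));
  pool.free(a);
  pool.free(b);
  return 0;
}

The in_pool bounds check is the piece that makes the scheme work without per-allocation tagging: because the pool is one contiguous buffer, a simple address-range test tells the caller whether to return a pointer to the free list or hand it to cudaFree, which is exactly the routing CudaAllocator::cuda_free performs in the diff.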