use cuda buffer in small pool

Awni Hannun
2025-07-20 07:14:57 -07:00
parent 60e20bedb6
commit 4fd39d662d
3 changed files with 46 additions and 43 deletions
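Summary of the change: SmallSizePool previously returned raw 8-byte slices of its cudaMallocManaged slab as void*, and CudaAllocator wrapped each slice in a separately heap-allocated CudaBuffer. The pool's free-list nodes now live in a host-side array of union { Block* next; CudaBuffer buf; } alongside the managed slab, so SmallSizePool::malloc hands back a ready-made CudaBuffer* and cuda_free(CudaBuffer*) routes a buffer either back to the pool or to cudaFree. Enforcement of max_pool_size_ also moves from the malloc path (a clamp on how much cache to reclaim) to the free path (recycle into the cache only while it is under the limit).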

mlx/backend/cuda/allocator.cpp

@@ -2,7 +2,6 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 #include <cuda_runtime.h>
@@ -25,55 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;
 
 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
-  CHECK_CUDA_ERROR(
-      cudaMemAdvise(buffer_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+  next_free_ = buffer_;
+
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
+
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }
 
 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }
 
-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }
 
-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }
 
-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }
 
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
          page_size,
          [](CudaBuffer* buf) { return buf->size; },
-         [this](CudaBuffer* buf) {
-           cuda_free(buf->data);
-           delete buf;
-         }) {
+         [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -95,24 +97,20 @@ Buffer CudaAllocator::malloc(size_t size) {
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
     int64_t mem_to_free =
         get_active_memory() + get_cache_memory() + size - memory_limit_;
-    mem_to_free = std::max(
-        static_cast<int64_t>(get_cache_memory() - max_pool_size_), mem_to_free);
     if (mem_to_free > 0) {
       buffer_cache_.release_cached_buffers(mem_to_free);
     }
-    buf = new CudaBuffer{nullptr, size};
 
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
    }
    lock.unlock();
-    if (!buf->data) {
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
      cudaError_t err = cudaMallocManaged(&buf->data, size);
      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
        throw std::runtime_error(fmt::format(
@@ -139,7 +137,11 @@ void CudaAllocator::free(Buffer buffer) {
   std::unique_lock lock(mutex_);
   active_memory_ -= buf->size;
-  buffer_cache_.recycle_to_cache(buf);
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    cuda_free(buf);
+  }
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
@@ -151,11 +153,12 @@ size_t CudaAllocator::size(Buffer buffer) const {
 }
 
 // This must be called with mutex_ acquired
-void CudaAllocator::cuda_free(void* buf) {
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }
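Aside on the technique, for readers skimming the hunks above: the new pool is a classic intrusive free list whose nodes double as the allocator's buffer records. The sketch below is a minimal self-contained rendition of that design, not MLX code; DemoSmallPool, PoolBuffer, and the 4096/8-byte sizes are illustrative stand-ins for SmallSizePool, CudaBuffer, 4 * page_size, and small_block_size, and error checking is omitted.

#include <cstddef>
#include <cuda_runtime.h>

// Stand-in for CudaBuffer: a pointer into managed memory plus a size.
struct PoolBuffer {
  void* data;
  size_t size;
};

class DemoSmallPool {
  // A slot is either on the free list (next is active) or handed out
  // (buf is active), never both, so a union holds both roles in one record.
  union Block {
    Block* next;
    PoolBuffer buf;
  };

  static constexpr size_t block_size = 8;
  static constexpr size_t pool_bytes = 4096;
  static constexpr size_t num_blocks = pool_bytes / block_size;

  Block* blocks_;     // host-side metadata array
  void* data_;        // managed slab carved into 8-byte blocks
  Block* next_free_;  // head of the intrusive free list

 public:
  DemoSmallPool() {
    blocks_ = new Block[num_blocks];
    cudaMallocManaged(&data_, pool_bytes);
    // Thread every slot onto the free list.
    for (size_t i = 0; i + 1 < num_blocks; ++i) {
      blocks_[i].next = &blocks_[i + 1];
    }
    blocks_[num_blocks - 1].next = nullptr;
    next_free_ = blocks_;
  }

  ~DemoSmallPool() {
    cudaFree(data_);
    delete[] blocks_;
  }

  PoolBuffer* malloc() {
    if (next_free_ == nullptr) {
      return nullptr;  // exhausted: the caller falls back to cudaMallocManaged
    }
    Block* b = next_free_;
    next_free_ = b->next;
    // The slot's index doubles as its offset into the slab.
    size_t i = b - blocks_;
    b->buf = {static_cast<char*>(data_) + i * block_size, block_size};
    return &b->buf;
  }

  void free(PoolBuffer* buf) {
    // buf is the union's first byte, so the cast recovers the Block
    // without storing a back-pointer.
    Block* b = reinterpret_cast<Block*>(buf);
    b->next = next_free_;
    next_free_ = b;
  }

  bool in_pool(PoolBuffer* buf) const {
    auto d = reinterpret_cast<Block*>(buf) - blocks_;
    return d >= 0 && static_cast<size_t>(d) < num_blocks;
  }
};

Worth noting on the malloc hunk: the std::max clamp against get_cache_memory() - max_pool_size_ can be dropped because free now stops recycling once the cache reaches max_pool_size_, so cache memory can no longer exceed the limit that the clamp used to enforce.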

mlx/backend/cuda/allocator.h

@@ -21,13 +21,14 @@ struct CudaBuffer {
 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };
 
-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};
 
  public:
   SmallSizePool();
@@ -36,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;
 
-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };
 
 class CudaAllocator : public allocator::Allocator {
@@ -57,7 +58,7 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
-  void cuda_free(void* buf);
+  void cuda_free(CudaBuffer* buf);
   CudaAllocator();
   friend CudaAllocator& allocator();
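The key header change is Block becoming a union: a slot needs its free-list link only while it is free and its CudaBuffer only while it is live, so the two can share storage, and because every union member sits at offset zero, SmallSizePool::free can cast a returned CudaBuffer* straight back to its Block*. A couple of hypothetical sanity checks (not in the commit, written with stand-in types) make those assumptions explicit:

#include <cstddef>

struct CudaBuffer {
  void* data;
  size_t size;
};

union Block {
  Block* next;
  CudaBuffer buf;
};

// free() relies on buf starting at the union's first byte.
static_assert(offsetof(Block, buf) == 0, "buf must sit at offset 0");
// A slot costs no more than the buffer record alone.
static_assert(sizeof(Block) == sizeof(CudaBuffer), "no extra per-slot overhead");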

mlx/backend/metal/allocator.cpp

@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
   auto pool = metal::new_scoped_memory_pool();
 
-  // If we have a lot of memory pressure or are over the maximum cache size,
-  // try to reclaim memory from the cache
+  // If we have a lot of memory pressure try to reclaim memory from the cache
   if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
     num_resources_ -=
         buffer_cache_.release_cached_buffers(mem_required - gc_limit_);