mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
use cuda buffer in small pool
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -25,55 +24,58 @@ constexpr int small_block_size = 8;
|
||||
constexpr int small_pool_size = 4 * page_size;
|
||||
|
||||
SmallSizePool::SmallSizePool() {
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
|
||||
end_ = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(buffer_) + small_pool_size);
|
||||
next_free_ = reinterpret_cast<Block*>(buffer_);
|
||||
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(buffer_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||
|
||||
auto num_blocks = small_pool_size / small_block_size;
|
||||
buffer_ = new Block[num_blocks];
|
||||
|
||||
next_free_ = buffer_;
|
||||
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||
|
||||
auto curr = next_free_;
|
||||
for (size_t i = 0; i < num_blocks - 1; ++i) {
|
||||
curr->next = reinterpret_cast<Block*>(
|
||||
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
|
||||
for (size_t i = 1; i < num_blocks; ++i) {
|
||||
curr->next = buffer_ + i;
|
||||
curr = curr->next;
|
||||
}
|
||||
curr->next = nullptr;
|
||||
}
|
||||
|
||||
SmallSizePool::~SmallSizePool() {
|
||||
CHECK_CUDA_ERROR(cudaFree(buffer_));
|
||||
CHECK_CUDA_ERROR(cudaFree(data_));
|
||||
delete[] buffer_;
|
||||
}
|
||||
|
||||
void* SmallSizePool::malloc() {
|
||||
CudaBuffer* SmallSizePool::malloc() {
|
||||
if (next_free_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
Block* b = next_free_;
|
||||
uint64_t i = next_free_ - buffer_;
|
||||
next_free_ = next_free_->next;
|
||||
return static_cast<void*>(b);
|
||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||
b->buf.size = small_block_size;
|
||||
return &b->buf;
|
||||
}
|
||||
|
||||
void SmallSizePool::free(void* p) {
|
||||
auto b = static_cast<Block*>(p);
|
||||
void SmallSizePool::free(CudaBuffer* buf) {
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
b->next = next_free_;
|
||||
next_free_ = b;
|
||||
}
|
||||
|
||||
bool SmallSizePool::in_pool(void* p) {
|
||||
return (p >= buffer_) && (p < end_);
|
||||
bool SmallSizePool::in_pool(CudaBuffer* buf) {
|
||||
constexpr int num_blocks = (small_pool_size / small_block_size);
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
int64_t block_num = b - buffer_;
|
||||
return block_num >= 0 && block_num < num_blocks;
|
||||
}
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
@@ -95,24 +97,20 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||
int64_t mem_to_free =
|
||||
get_active_memory() + get_cache_memory() + size - memory_limit_;
|
||||
mem_to_free = std::max(
|
||||
static_cast<int64_t>(get_cache_memory() - max_pool_size_), mem_to_free);
|
||||
if (mem_to_free > 0) {
|
||||
buffer_cache_.release_cached_buffers(mem_to_free);
|
||||
}
|
||||
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
|
||||
// Try the scalar pool first
|
||||
if (size <= small_block_size) {
|
||||
buf->data = scalar_pool_.malloc();
|
||||
buf = scalar_pool_.malloc();
|
||||
}
|
||||
lock.unlock();
|
||||
if (!buf->data) {
|
||||
if (!buf) {
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
@@ -139,7 +137,11 @@ void CudaAllocator::free(Buffer buffer) {
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
cuda_free(buf);
|
||||
}
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
@@ -151,11 +153,12 @@ size_t CudaAllocator::size(Buffer buffer) const {
|
||||
}
|
||||
|
||||
// This must be called with mutex_ aquired
|
||||
void CudaAllocator::cuda_free(void* buf) {
|
||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
if (scalar_pool_.in_pool(buf)) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
cudaFree(buf);
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,13 +21,14 @@ struct CudaBuffer {
|
||||
|
||||
class SmallSizePool {
|
||||
private:
|
||||
struct Block {
|
||||
union Block {
|
||||
Block* next;
|
||||
CudaBuffer buf;
|
||||
};
|
||||
|
||||
void* buffer_{nullptr};
|
||||
Block* buffer_{nullptr};
|
||||
void* data_{nullptr};
|
||||
Block* next_free_{nullptr};
|
||||
void* end_{nullptr};
|
||||
|
||||
public:
|
||||
SmallSizePool();
|
||||
@@ -36,9 +37,9 @@ class SmallSizePool {
|
||||
SmallSizePool(const SmallSizePool&) = delete;
|
||||
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||
|
||||
void* malloc();
|
||||
void free(void* p);
|
||||
bool in_pool(void* p);
|
||||
CudaBuffer* malloc();
|
||||
void free(CudaBuffer* buf);
|
||||
bool in_pool(CudaBuffer* buf);
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
@@ -57,7 +58,7 @@ class CudaAllocator : public allocator::Allocator {
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
void cuda_free(void* buf);
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache
|
||||
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
||||
num_resources_ -=
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
|
||||
Reference in New Issue
Block a user