Commit to ml-explore/mlx (https://github.com/ml-explore/mlx.git): use cuda buffer in small pool

This commit reworks the CUDA allocator's SmallSizePool to hand out complete
CudaBuffer objects instead of raw data pointers: Block becomes a union of a
free-list link and a CudaBuffer, the pool's metadata array (buffer_) is
separated from the cudaMallocManaged slab (data_), and CudaAllocator::free
now recycles buffers to the cache only while the cache is under
max_pool_size_. A related comment in the Metal allocator is reworded to
match.
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -2,7 +2,6 @@
 
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 
 #include <cuda_runtime.h>
@@ -25,55 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;
 
 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
-
-  CHECK_CUDA_ERROR(
-      cudaMemAdvise(buffer_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
-
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+  next_free_ = buffer_;
+
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
+
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }
 
 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }
 
-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }
 
-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }
 
-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }
 
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
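This hunk is the heart of the change: free-list bookkeeping moves out of the
managed slab and into a separate Block array, and malloc() now fills in a
full CudaBuffer (its data pointer computed from the slot index) rather than
returning a raw block pointer. The following CPU-only sketch reproduces the
scheme with std::malloc standing in for cudaMallocManaged and hard-coded
sizes; the names follow the diff, but this is an illustration, not the MLX
implementation.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Stand-in for MLX's CudaBuffer (declared in mlx/backend/cuda/allocator.h).
struct CudaBuffer {
  void* data;
  size_t size;
};

// CPU-only sketch of SmallSizePool after this commit: an array of Block
// unions holds either a free-list link or a live CudaBuffer, while the
// payload bytes live in one separate slab.
class SmallSizePool {
  union Block {
    Block* next;     // valid while the block sits on the free list
    CudaBuffer buf;  // valid while the block is handed out
  };

  static constexpr size_t block_size = 8;
  static constexpr size_t pool_size = 4096;
  static constexpr size_t num_blocks = pool_size / block_size;

  Block* buffer_;    // metadata: one Block per slot
  void* data_;       // payload slab (cudaMallocManaged in MLX)
  Block* next_free_;

 public:
  SmallSizePool() {
    buffer_ = new Block[num_blocks];
    next_free_ = buffer_;
    data_ = std::malloc(pool_size);
    // Thread the free list through the metadata array.
    for (size_t i = 1; i < num_blocks; ++i) {
      buffer_[i - 1].next = buffer_ + i;
    }
    buffer_[num_blocks - 1].next = nullptr;
  }

  ~SmallSizePool() {
    std::free(data_);
    delete[] buffer_;
  }

  CudaBuffer* malloc() {
    if (next_free_ == nullptr) {
      return nullptr;  // pool exhausted; the caller falls back to cudaMallocManaged
    }
    Block* b = next_free_;
    size_t i = b - buffer_;  // slot index selects the payload slice
    next_free_ = next_free_->next;
    b->buf.data = static_cast<char*>(data_) + i * block_size;
    b->buf.size = block_size;
    return &b->buf;
  }

  void free(CudaBuffer* buf) {
    // &b->buf aliases b itself (union member), so the cast recovers the Block.
    Block* b = reinterpret_cast<Block*>(buf);
    b->next = next_free_;
    next_free_ = b;
  }

  bool in_pool(CudaBuffer* buf) {
    // Membership is index arithmetic on the metadata array, replacing the
    // old [buffer_, end_) pointer-range check on the slab.
    std::ptrdiff_t i = reinterpret_cast<Block*>(buf) - buffer_;
    return i >= 0 && static_cast<size_t>(i) < num_blocks;
  }
};

int main() {
  SmallSizePool pool;
  CudaBuffer* a = pool.malloc();
  CudaBuffer* b = pool.malloc();
  assert(a && b && pool.in_pool(a));
  std::printf("a: %p (%zu bytes), b: %p\n", a->data, a->size, b->data);
  pool.free(a);
  pool.free(b);
  return 0;
}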
@@ -95,24 +97,20 @@ Buffer CudaAllocator::malloc(size_t size) {
 
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
     int64_t mem_to_free =
         get_active_memory() + get_cache_memory() + size - memory_limit_;
-    mem_to_free = std::max(
-        static_cast<int64_t>(get_cache_memory() - max_pool_size_), mem_to_free);
     if (mem_to_free > 0) {
       buffer_cache_.release_cached_buffers(mem_to_free);
     }
 
-    buf = new CudaBuffer{nullptr, size};
-
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
     }
     lock.unlock();
-    if (!buf->data) {
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
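Note the reordering in the allocation path: the old code created a CudaBuffer
up front and asked the pool only for its data pointer, while the new code
lets the pool produce the whole CudaBuffer and heap-allocates one only on the
cudaMallocManaged fallback. A compact sketch of the resulting fallback chain,
with all cache/pool/CUDA calls stubbed out (the stub names are illustrative,
not MLX API):

#include <cstddef>
#include <cstdlib>

// Illustrative stubs for the buffer cache, the small pool, and managed
// allocation; none of these bodies are the real MLX code.
struct CudaBuffer {
  void* data;
  size_t size;
};

constexpr size_t small_block_size = 8;

CudaBuffer* reuse_from_cache(size_t) {
  return nullptr;  // pretend the cache is empty
}
CudaBuffer* small_pool_malloc() {
  return nullptr;  // pretend the pool is exhausted
}

// Fallback order in CudaAllocator::malloc after this change: cache, then
// small pool, and only then a fresh CudaBuffer plus a managed allocation.
CudaBuffer* allocate(size_t size) {
  CudaBuffer* buf = reuse_from_cache(size);
  if (!buf && size <= small_block_size) {
    buf = small_pool_malloc();
  }
  if (!buf) {
    buf = new CudaBuffer{nullptr, size};
    buf->data = std::malloc(size);  // cudaMallocManaged(&buf->data, size) in MLX
  }
  return buf;
}

int main() {
  CudaBuffer* buf = allocate(4);
  std::free(buf->data);
  delete buf;
  return 0;
}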
@@ -139,7 +137,11 @@ void CudaAllocator::free(Buffer buffer) {
 
   std::unique_lock lock(mutex_);
   active_memory_ -= buf->size;
-  buffer_cache_.recycle_to_cache(buf);
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    cuda_free(buf);
+  }
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
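With the std::max clamp gone from malloc, the cache cap is now enforced here
on the free path: a buffer is recycled only while cached memory is below
max_pool_size_, and released immediately otherwise. A minimal stub-based
sketch of that decision (stub names mirror the diff; bodies are placeholders):

#include <cstddef>

struct CudaBuffer {
  void* data;
  size_t size;
};

size_t cache_memory = 0;
constexpr size_t max_pool_size = 1 << 20;

void recycle_to_cache(CudaBuffer* buf) {
  cache_memory += buf->size;  // buffer kept around for reuse
}
void cuda_free(CudaBuffer* buf) {
  delete buf;  // in MLX: small-pool free, or cudaFree(buf->data) + delete
}

// The new decision in CudaAllocator::free: recycle only while the cache is
// under its cap, otherwise release the memory right away.
void release(CudaBuffer* buf) {
  if (cache_memory < max_pool_size) {
    recycle_to_cache(buf);
  } else {
    cuda_free(buf);
  }
}

int main() {
  release(new CudaBuffer{nullptr, 64});
  return 0;
}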
@@ -151,11 +153,12 @@ size_t CudaAllocator::size(Buffer buffer) const {
 }
 
 // This must be called with mutex_ aquired
-void CudaAllocator::cuda_free(void* buf) {
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }
 
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -21,13 +21,14 @@ struct CudaBuffer {
 
 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };
 
-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};
 
  public:
   SmallSizePool();
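Turning Block into a union is what allows SmallSizePool::free and in_pool to
reinterpret_cast a CudaBuffer* back to its Block: all union members start at
the same address, and the union adds no size overhead over CudaBuffer itself.
A small self-contained check of both properties (CudaBuffer is re-declared
here to match the struct this header defines):

#include <cassert>
#include <cstddef>

struct CudaBuffer {
  void* data;
  size_t size;
};

// Same shape as the Block union in the hunk above.
union Block {
  Block* next;
  CudaBuffer buf;
};

int main() {
  // Every member of a union starts at the union's own address, so a pointer
  // to the embedded CudaBuffer converts back to a pointer to its Block.
  Block b;
  assert(static_cast<void*>(&b) == static_cast<void*>(&b.buf));
  assert(reinterpret_cast<Block*>(&b.buf) == &b);
  // The union is exactly as large as its largest member, the CudaBuffer.
  static_assert(sizeof(Block) == sizeof(CudaBuffer), "no size overhead");
  return 0;
}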
@@ -36,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;
 
-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };
 
 class CudaAllocator : public allocator::Allocator {
@@ -57,7 +58,7 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
-  void cuda_free(void* buf);
+  void cuda_free(CudaBuffer* buf);
 
   CudaAllocator();
   friend CudaAllocator& allocator();
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
 
   auto pool = metal::new_scoped_memory_pool();
 
-  // If we have a lot of memory pressure or are over the maximum cache size,
-  // try to reclaim memory from the cache
+  // If we have a lot of memory pressure try to reclaim memory from the cache
   if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
     num_resources_ -=
         buffer_cache_.release_cached_buffers(mem_required - gc_limit_);