mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
[CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda * comment
This commit is contained in:
parent
63f663d9c6
commit
93d70419e7
@ -17,6 +17,52 @@ namespace cu {
|
|||||||
|
|
||||||
// Allocation granularity for the buffer cache; large requests are rounded up
// to a multiple of this.
constexpr int page_size = 16384;

// Any allocation of at most this many bytes will try the small pool first.
constexpr int small_block_size = 8;

// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size. The block-size invariant is enforced below so
// carving the pool into blocks never leaves a partial block.
constexpr int small_pool_size = 4 * page_size;
static_assert(
    small_pool_size % small_block_size == 0,
    "small_pool_size must be a multiple of small_block_size");
|
// Backs tiny (scalar-sized) allocations with one CUDA managed buffer carved
// into fixed-size blocks chained through an intrusive free list.
SmallSizePool::SmallSizePool() {
  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
  char* base = reinterpret_cast<char*>(buffer_);
  end_ = reinterpret_cast<void*>(base + small_pool_size);
  next_free_ = reinterpret_cast<Block*>(buffer_);

  // Link each block to its successor; the final block terminates the list.
  size_t num_blocks = small_pool_size / small_block_size;
  Block* curr = next_free_;
  for (size_t i = 1; i < num_blocks; ++i) {
    curr->next = reinterpret_cast<Block*>(base + i * small_block_size);
    curr = curr->next;
  }
  curr->next = nullptr;
}
|
||||||
|
|
||||||
|
// Release the single managed buffer backing the whole pool.
SmallSizePool::~SmallSizePool() {
  CHECK_CUDA_ERROR(cudaFree(buffer_));
}
|
||||||
|
|
||||||
|
void* SmallSizePool::malloc() {
|
||||||
|
if (next_free_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Block* b = next_free_;
|
||||||
|
next_free_ = next_free_->next;
|
||||||
|
return static_cast<void*>(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SmallSizePool::free(void* p) {
|
||||||
|
auto b = static_cast<Block*>(p);
|
||||||
|
b->next = next_free_;
|
||||||
|
next_free_ = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SmallSizePool::in_pool(void* p) {
|
||||||
|
return (p >= buffer_) && (p < end_);
|
||||||
|
}
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
page_size,
|
page_size,
|
||||||
@ -36,7 +82,9 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
auto orig_size = size;
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
if (size < page_size) {
|
if (size <= small_block_size) {
|
||||||
|
size = 8;
|
||||||
|
} else if (size < page_size) {
|
||||||
size = next_power_of_2(size);
|
size = next_power_of_2(size);
|
||||||
} else {
|
} else {
|
||||||
size = page_size * ((size + page_size - 1) / page_size);
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
@ -53,11 +101,19 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
|
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
buf = new CudaBuffer{nullptr, size};
|
buf = new CudaBuffer{nullptr, size};
|
||||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
// Try the scalar pool first
|
||||||
throw std::runtime_error(fmt::format(
|
if (size <= small_block_size) {
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
buf->data = scalar_pool_.malloc();
|
||||||
}
|
}
|
||||||
|
if (!buf->data) {
|
||||||
|
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||||
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
active_memory_ += size;
|
active_memory_ += size;
|
||||||
@ -116,7 +172,11 @@ void CudaAllocator::cuda_free(void* buf) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cudaFree(buf);
|
if (scalar_pool_.in_pool(buf)) {
|
||||||
|
scalar_pool_.free(buf);
|
||||||
|
} else {
|
||||||
|
cudaFree(buf);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t CudaAllocator::get_active_memory() const {
|
size_t CudaAllocator::get_active_memory() const {
|
||||||
|
@ -22,6 +22,28 @@ struct CudaBuffer {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Free-list allocator for very small (scalar-sized) allocations, backed by a
// single CUDA managed-memory buffer carved into fixed-size blocks.
class SmallSizePool {
 private:
  // Intrusive free-list node stored inside each unallocated block.
  struct Block {
    Block* next;
  };

  void* buffer_{nullptr}; // start of the managed buffer
  Block* next_free_{nullptr}; // head of the free list; nullptr when exhausted
  void* end_{nullptr}; // one past the end of the buffer

 public:
  SmallSizePool();
  ~SmallSizePool();

  // Non-copyable: owns the underlying CUDA allocation.
  SmallSizePool(const SmallSizePool&) = delete;
  SmallSizePool& operator=(const SmallSizePool&) = delete;

  // Returns a fixed-size block, or nullptr when the pool is full.
  void* malloc();
  // Returns a block obtained from malloc() to the free list.
  void free(void* p);
  // Whether p lies inside this pool's buffer.
  bool in_pool(void* p);
};
|
||||||
|
|
||||||
class CudaAllocator : public allocator::Allocator {
|
class CudaAllocator : public allocator::Allocator {
|
||||||
public:
|
public:
|
||||||
Buffer malloc(size_t size) override;
|
Buffer malloc(size_t size) override;
|
||||||
@ -60,6 +82,7 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
BufferCache<CudaBuffer> buffer_cache_;
|
BufferCache<CudaBuffer> buffer_cache_;
|
||||||
size_t active_memory_{0};
|
size_t active_memory_{0};
|
||||||
size_t peak_memory_{0};
|
size_t peak_memory_{0};
|
||||||
|
SmallSizePool scalar_pool_;
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaAllocator& allocator();
|
CudaAllocator& allocator();
|
||||||
|
Loading…
Reference in New Issue
Block a user