mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Reduce use of managed memory (#2725)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Use async cuda malloc managed with cuda 13 * add pool threshold * refactor for regular cuda malloc * load eval gpu for cuda * remove use of cuda pool, use cuda free async * fix * fix * fix * fix * fix + comment
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -67,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
|
||||
next_free_ = next_free_->next;
|
||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||
b->buf.size = small_block_size;
|
||||
b->buf.device = -1;
|
||||
return &b->buf;
|
||||
}
|
||||
|
||||
@@ -88,14 +90,40 @@ CudaAllocator::CudaAllocator()
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.95;
|
||||
max_pool_size_ = memory_limit_;
|
||||
|
||||
int device_count = 0;
|
||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||
int curr;
|
||||
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(i));
|
||||
cudaStream_t s;
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
|
||||
free_streams_.push_back(s);
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(curr));
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
void copy_to_managed(CudaBuffer& buf) {
|
||||
// TODO maybe make this async on a i/o stream to avoid synchronizing the
|
||||
// device on malloc/and free
|
||||
void* new_data;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
|
||||
buf.device = -1;
|
||||
CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
|
||||
CHECK_CUDA_ERROR(cudaFree(buf.data));
|
||||
buf.data = new_data;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
|
||||
if (size == 0) {
|
||||
return Buffer{new CudaBuffer{nullptr, 0, -1}};
|
||||
}
|
||||
|
||||
// Find available buffer from cache.
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size <= small_block_size) {
|
||||
@@ -106,6 +134,11 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
}
|
||||
|
||||
int device = -1;
|
||||
if (size > small_block_size && stream != nullptr) {
|
||||
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
|
||||
}
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||
@@ -121,8 +154,13 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
}
|
||||
lock.unlock();
|
||||
if (!buf) {
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
buf = new CudaBuffer{nullptr, size, device};
|
||||
cudaError_t err;
|
||||
if (device == -1) {
|
||||
err = cudaMallocManaged(&buf->data, size);
|
||||
} else {
|
||||
err = cudaMallocAsync(&buf->data, size, stream);
|
||||
}
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
@@ -137,14 +175,30 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
// Copy to managed here if the buffer is not on the right device
|
||||
if (buf->device != device) {
|
||||
copy_to_managed(*buf);
|
||||
}
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
|
||||
return malloc_impl(size, stream);
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
return malloc_impl(size, nullptr);
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
if (buf->size == 0) {
|
||||
delete buf;
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
@@ -168,7 +222,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
if (scalar_pool_.in_pool(buf)) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
if (buf->device >= 0) {
|
||||
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
}
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
@@ -219,6 +277,16 @@ CudaAllocator& allocator() {
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
Buffer malloc_async(size_t size, cudaStream_t stream) {
|
||||
auto buffer = allocator().malloc_async(size, stream);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
@@ -231,7 +299,11 @@ void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
|
||||
if (cbuf.device != -1) {
|
||||
copy_to_managed(cbuf);
|
||||
}
|
||||
return cbuf.data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
Reference in New Issue
Block a user