remove use of cuda pool, use cuda free async

This commit is contained in:
Awni Hannun
2025-11-03 09:14:17 -08:00
parent c27a0647a3
commit 742033fefe
6 changed files with 132 additions and 111 deletions

View File

@@ -349,7 +349,10 @@ class array {
     return array_desc_->data;
   }

-  // Return a raw pointer to the arrays data
+  // Return a raw pointer to the arrays data. This function may do a copy if
+  // the underlying buffer is not accessible on the CPU. When accessing the
+  // data for GPU kernels, be sure to use the correct method / function for the
+  // given backend to access the GPU pointer.
   template <typename T>
   T* data() {
     return reinterpret_cast<T*>(
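
For context, a minimal sketch of the access pattern the new comment describes: reading through data<T>() may materialize a CPU-visible copy, while CUDA kernels should take the raw device pointer through the backend's gpu_ptr<T> helper (moved into mlx/backend/cuda/utils.h later in this diff). The function name example() and the commented-out kernel launch are hypothetical, not part of the change.

// Sketch only; `arr` is an mlx::core::array and `stream` a CUDA stream.
#include <cuda_runtime.h>
#include "mlx/array.h"
#include "mlx/backend/cuda/utils.h"

void example(mlx::core::array& arr, cudaStream_t stream) {
  // CPU access: may copy a device-resident buffer into managed memory first.
  float first = arr.data<float>()[0];

  // GPU access: raw device pointer, no copy; only valid to dereference on the
  // device, e.g. from a kernel launched on `stream`.
  float* dev = mlx::core::gpu_ptr<float>(arr);
  // my_kernel<<<grid, block, 0, stream>>>(dev, arr.size());  // hypothetical
  (void)first;
  (void)dev;
}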

View File

@@ -68,7 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
   next_free_ = next_free_->next;
   b->buf.data = static_cast<char*>(data_) + i * small_block_size;
   b->buf.size = small_block_size;
-  b->buf.managed = true;
+  b->buf.device = -1;
   return &b->buf;
 }

@@ -94,10 +94,15 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
-  int loc = 0;
-  CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
-  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
+  int device_count = 0;
+  CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
+  int curr_device = 0;
+  CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
+  for (int i = 0; i < device_count; ++i) {
+    free_streams_.emplace_back(
+        cu::device(mlx::core::Device{mlx::core::Device::gpu, i}));
+  }
 }

 Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {

@@ -127,13 +132,16 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   }
   lock.unlock();
   if (!buf) {
-    bool managed = stream == nullptr;
-    buf = new CudaBuffer{nullptr, size, managed};
+    int device = -1;
+    if (stream != nullptr) {
+      cudaStreamGetDevice(stream, &device);
+    }
+    buf = new CudaBuffer{nullptr, size, device};
     cudaError_t err;
-    if (managed) {
+    if (device == -1) {
       err = cudaMallocManaged(&buf->data, size);
     } else {
-      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
+      err = cudaMallocAsync(&buf->data, size, stream);
     }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(

@@ -188,7 +196,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf->data);
+    if (buf->device >= 0) {
+      cudaFreeAsync(buf->data, free_streams_[buf->device]);
+    } else {
+      cudaFree(buf->data);
+    }
     delete buf;
   }
 }

@@ -213,9 +225,6 @@ size_t CudaAllocator::get_memory_limit() {
 size_t CudaAllocator::set_memory_limit(size_t limit) {
   std::lock_guard lock(mutex_);
   std::swap(limit, memory_limit_);
-  CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
-  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
   return limit;
 }

@@ -265,12 +274,12 @@ void* Buffer::raw_ptr() {
     return nullptr;
   }
   auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
-  if (!cbuf.managed) {
+  if (cbuf.device != -1) {
     // TODO maybe make this async on a i/o stream to avoid synchronizing the
     // device on malloc/and free
     void* new_data;
     CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
-    cbuf.managed = true;
+    cbuf.device = -1;
     CHECK_CUDA_ERROR(
         cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
     CHECK_CUDA_ERROR(cudaFree(cbuf.data));
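
The hunks above drop the explicitly configured cudaMemPool_t: managed (CPU-reachable) buffers still come from cudaMallocManaged, while device buffers now use plain cudaMallocAsync on the requesting stream and are released with cudaFreeAsync on a per-device free stream. As background, a minimal self-contained sketch of that stream-ordered allocation pattern (generic CUDA, not MLX code):

// Standalone sketch of stream-ordered allocation (CUDA 11.2+). cudaMallocAsync
// draws from the device's default memory pool and cudaFreeAsync returns the
// block once prior work on the stream is done, so neither call forces a
// device-wide synchronization.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  if (cudaStreamCreate(&stream) != cudaSuccess) {
    return 1;
  }

  void* ptr = nullptr;
  cudaError_t err = cudaMallocAsync(&ptr, 1 << 20, stream);  // 1 MiB
  if (err != cudaSuccess) {
    std::printf("cudaMallocAsync failed: %s\n", cudaGetErrorString(err));
    return 1;
  }

  // ... enqueue kernels that use `ptr` on `stream` here ...

  // The free is ordered after the work above; the allocation is reclaimed
  // only when that work has completed on the stream.
  cudaFreeAsync(ptr, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}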

View File

@@ -4,6 +4,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/buffer_cache.h"
+#include "mlx/backend/cuda/cuda_utils.h"

 #include <cuda_runtime.h>
 #include <mutex>

@@ -18,11 +19,11 @@ using allocator::Buffer;
 struct CudaBuffer {
   void* data;
   size_t size;
-  bool managed;
+  int device; // -1 for managed
 };

 template <typename T>
-T* gpu_ptr(Buffer buf) {
+inline T* gpu_ptr(Buffer buf) {
   return static_cast<T*>(static_cast<cu::CudaBuffer*>(buf.ptr())->data);
 }

@@ -78,8 +79,8 @@ class CudaAllocator : public allocator::Allocator {
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
+  std::vector<CudaStream> free_streams_;
   SmallSizePool scalar_pool_;
-  cudaMemPool_t cuda_pool_{nullptr};
 };

 CudaAllocator& allocator();
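
The device field replaces the old managed flag: -1 still means a cudaMallocManaged buffer the CPU can touch directly, while a value >= 0 records which device's free stream (free_streams_[device]) the buffer must be released on. A hedged sketch of a helper one might write against this struct; cpu_accessible is hypothetical and not part of the header:

// Hypothetical helper (not in the header), assuming the CudaBuffer above.
#include "mlx/backend/cuda/allocator.h"

inline bool cpu_accessible(const mlx::core::cu::CudaBuffer& b) {
  // device == -1 -> managed memory (cudaMallocManaged), directly host-visible.
  // device >= 0  -> device-resident memory from cudaMallocAsync; the host must
  //                 copy it first (as Buffer::raw_ptr now does).
  return b.device == -1;
}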

View File

@@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>

namespace mlx::core {

// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);

// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))

// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
 public:
  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
    assert(this != &other);
    other.handle_ = nullptr;
  }
  ~CudaHandle() {
    reset();
  }

  CudaHandle(const CudaHandle&) = delete;
  CudaHandle& operator=(const CudaHandle&) = delete;
  CudaHandle& operator=(CudaHandle&& other) {
    assert(this != &other);
    reset();
    std::swap(handle_, other.handle_);
    return *this;
  }

  void reset() {
    if (handle_ != nullptr) {
      CHECK_CUDA_ERROR(Destroy(handle_));
      handle_ = nullptr;
    }
  }

  operator Handle() const {
    return handle_;
  }

 protected:
  Handle handle_;
};

namespace cu {
class Device;
}; // namespace cu

// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
 public:
  using CudaHandle::CudaHandle;
  explicit CudaGraph(cu::Device& device);
  void end_capture(cudaStream_t stream);
};

class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
 public:
  void instantiate(cudaGraph_t graph);
};

class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
 public:
  explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core
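
The CudaHandle template works for any resource whose destroy function has the shape cudaError_t(Handle). As an illustration, here is how a cudaEvent_t wrapper could be built on top of it; CudaEvent is a hypothetical example and is not added by this commit:

// Hypothetical wrapper built on CudaHandle (not part of the commit).
#include <cuda_runtime.h>
#include "mlx/backend/cuda/cuda_utils.h"

class CudaEvent : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 public:
  CudaEvent() {
    CHECK_CUDA_ERROR(cudaEventCreate(&handle_));
  }
};

// Usage: the event is destroyed automatically when `ev` leaves scope, and the
// implicit conversion operator lets it be passed straight to CUDA calls.
//
//   CudaEvent ev;
//   CHECK_CUDA_ERROR(cudaEventRecord(ev, stream));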

View File

@@ -4,101 +4,12 @@
 #pragma once

-#include <cublasLt.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-
 #include "mlx/array.h"
 #include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/cuda_utils.h"

 namespace mlx::core {

-namespace cu {
-class Device;
-}
-
-template <typename T>
-T* gpu_ptr(array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
-}
-
-template <typename T>
-const T* gpu_ptr(const array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
-}
-
-struct Dtype;
-
-// Throw exception if the cuda API does not succeed.
-void check_cublas_error(const char* name, cublasStatus_t err);
-void check_cuda_error(const char* name, cudaError_t err);
-void check_cuda_error(const char* name, CUresult err);
-
-// The macro version that prints the command that failed.
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
-
-// Convert Dtype to CUDA C++ types.
-const char* dtype_to_cuda_type(const Dtype& dtype);
-
-// Base class for RAII managed CUDA resources.
-template <typename Handle, cudaError_t (*Destroy)(Handle)>
-class CudaHandle {
- public:
-  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
-  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
-    assert(this != &other);
-    other.handle_ = nullptr;
-  }
-  ~CudaHandle() {
-    reset();
-  }
-
-  CudaHandle(const CudaHandle&) = delete;
-  CudaHandle& operator=(const CudaHandle&) = delete;
-  CudaHandle& operator=(CudaHandle&& other) {
-    assert(this != &other);
-    reset();
-    std::swap(handle_, other.handle_);
-    return *this;
-  }
-
-  void reset() {
-    if (handle_ != nullptr) {
-      CHECK_CUDA_ERROR(Destroy(handle_));
-      handle_ = nullptr;
-    }
-  }
-
-  operator Handle() const {
-    return handle_;
-  }
-
- protected:
-  Handle handle_;
-};
-
-// Wrappers of CUDA resources.
-class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
- public:
-  using CudaHandle::CudaHandle;
-  explicit CudaGraph(cu::Device& device);
-  void end_capture(cudaStream_t stream);
-};
-
-class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
- public:
-  void instantiate(cudaGraph_t graph);
-};
-
-class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
- public:
-  explicit CudaStream(cu::Device& device);
-};
-
 template <typename T>
 inline uint max_occupancy_block_dim(T kernel) {
   int _, block_dim;
@@ -112,4 +23,19 @@ inline uint max_occupancy_block_dim(T kernel) {
   return block_dim;
 }

+template <typename T>
+inline T* gpu_ptr(array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+template <typename T>
+inline const T* gpu_ptr(const array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+struct Dtype;
+
+// Convert Dtype to CUDA C++ types.
+const char* dtype_to_cuda_type(const Dtype& dtype);
+
 } // namespace mlx::core

View File

@@ -326,8 +326,8 @@ class NCCLGroup : public GroupImpl {
   auto& encoder = cu::get_command_encoder(stream);

   CHECK_NCCL(ncclAllReduce(
-      input.data<T>(),
-      output.data<T>(),
+      gpu_ptr<T>(input),
+      gpu_ptr<T>(output),
       input.size(),
       dt,
       op,
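
Since data<T>() can now trigger a copy into managed memory (per the array.h comment above), collectives must take the device pointer through gpu_ptr<T> instead. Other calls in the same class would follow the identical pattern; the ncclBroadcast fragment below is a hypothetical illustration, with root, comm_, and stream assumed from the surrounding code rather than shown in this diff:

// Hypothetical fragment (not in this diff): the same pointer-access pattern
// applied to another collective; arguments follow the ncclBroadcast signature.
CHECK_NCCL(ncclBroadcast(
    gpu_ptr<T>(input),   // send buffer: raw device pointer, no copy
    gpu_ptr<T>(output),  // receive buffer
    input.size(),
    dt,
    root,
    comm_,
    stream));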