diff --git a/mlx/array.h b/mlx/array.h
index b7ae8996c..40a7aa87c 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -349,7 +349,10 @@ class array {
     return array_desc_->data;
   }
 
-  // Return a raw pointer to the arrays data
+  // Return a raw pointer to the array's data. This function may make a copy
+  // if the underlying buffer is not accessible on the CPU. When accessing the
+  // data for GPU kernels, be sure to use the correct method / function for
+  // the given backend to access the GPU pointer.
   template <typename T>
   T* data() {
     return reinterpret_cast<T*>(
diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 918d85741..67ff40e1e 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -68,7 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
   next_free_ = next_free_->next;
   b->buf.data = static_cast<char*>(data_) + i * small_block_size;
   b->buf.size = small_block_size;
-  b->buf.managed = true;
+  b->buf.device = -1;
   return &b->buf;
 }
 
@@ -94,10 +94,15 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
-  int loc = 0;
-  CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&cuda_pool_, loc));
-  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
+
+  int device_count = 0;
+  CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
+  int curr_device = 0;
+  CHECK_CUDA_ERROR(cudaGetDevice(&curr_device));
+  for (int i = 0; i < device_count; ++i) {
+    free_streams_.emplace_back(
+        cu::device(mlx::core::Device{mlx::core::Device::gpu, i}));
+  }
 }
 
 Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
@@ -127,13 +132,16 @@
   }
   lock.unlock();
   if (!buf) {
-    bool managed = stream == nullptr;
-    buf = new CudaBuffer{nullptr, size, managed};
+    int device = -1;
+    if (stream != nullptr) {
+      cudaStreamGetDevice(stream, &device);
+    }
+    buf = new CudaBuffer{nullptr, size, device};
     cudaError_t err;
-    if (managed) {
+    if (device == -1) {
       err = cudaMallocManaged(&buf->data, size);
     } else {
-      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
+      err = cudaMallocAsync(&buf->data, size, stream);
     }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(
@@ -188,7 +196,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf->data);
+    if (buf->device >= 0) {
+      cudaFreeAsync(buf->data, free_streams_[buf->device]);
+    } else {
+      cudaFree(buf->data);
+    }
     delete buf;
   }
 }
@@ -213,9 +225,6 @@ size_t CudaAllocator::get_memory_limit() {
 size_t CudaAllocator::set_memory_limit(size_t limit) {
   std::lock_guard lock(mutex_);
   std::swap(limit, memory_limit_);
-  CHECK_CUDA_ERROR(cudaMemPoolTrimTo(cuda_pool_, memory_limit_));
-  CHECK_CUDA_ERROR(cudaMemPoolSetAttribute(
-      cuda_pool_, cudaMemPoolAttrReleaseThreshold, &memory_limit_));
   return limit;
 }
 
@@ -265,12 +274,12 @@ void* Buffer::raw_ptr() {
     return nullptr;
   }
   auto& cbuf = *static_cast<CudaBuffer*>(ptr_);
-  if (!cbuf.managed) {
+  if (cbuf.device != -1) {
     // TODO maybe make this async on a i/o stream to avoid synchronizing the
     // device on malloc/and free
     void* new_data;
     CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
-    cbuf.managed = true;
+    cbuf.device = -1;
     CHECK_CUDA_ERROR(
         cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
     CHECK_CUDA_ERROR(cudaFree(cbuf.data));
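Note on the allocation scheme above: cudaMallocAsync and cudaFreeAsync are CUDA's stream-ordered allocation APIs (CUDA 11.2+). An allocation is ordered on its stream, and a free may be issued on a different stream, such as the per-device free_streams_ entries used in cuda_free, as long as the work using the buffer is ordered before the free. Below is a minimal standalone sketch of that pattern, independent of MLX; the stream names and sizes are illustrative, and MLX establishes the equivalent ordering through its own scheduling rather than the explicit event shown here.

// Sketch: stream-ordered allocation with the free issued on another stream.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t work, free_stream;
  cudaStreamCreate(&work);
  cudaStreamCreate(&free_stream);

  void* buf = nullptr;
  cudaMallocAsync(&buf, 1 << 20, work);   // allocation is ordered on `work`
  cudaMemsetAsync(buf, 0, 1 << 20, work); // use the buffer on the same stream

  // Order the free after the producing work, then issue it on another
  // stream, mirroring cuda_free() handing buffers to free_streams_[device].
  cudaEvent_t done;
  cudaEventCreate(&done);
  cudaEventRecord(done, work);
  cudaStreamWaitEvent(free_stream, done, 0);
  cudaFreeAsync(buf, free_stream);

  cudaStreamSynchronize(free_stream);
  cudaEventDestroy(done);
  cudaStreamDestroy(work);
  cudaStreamDestroy(free_stream);
  std::printf("done\n");
  return 0;
}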
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index a0f7d38db..df89d0982 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -4,6 +4,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/buffer_cache.h"
+#include "mlx/backend/cuda/cuda_utils.h"
 
 #include <cuda_runtime.h>
 #include <mutex>
@@ -18,11 +19,11 @@ using allocator::Buffer;
 
 struct CudaBuffer {
   void* data;
   size_t size;
-  bool managed;
+  int device; // -1 for managed
 };
 
 template <typename T>
-T* gpu_ptr(Buffer buf) {
+inline T* gpu_ptr(Buffer buf) {
   return static_cast<T*>(static_cast<CudaBuffer*>(buf.ptr())->data);
 }
 
@@ -78,8 +79,8 @@ class CudaAllocator : public allocator::Allocator {
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
+  std::vector<CudaStream> free_streams_;
   SmallSizePool scalar_pool_;
-  cudaMemPool_t cuda_pool_{nullptr};
 };
 
 CudaAllocator& allocator();
diff --git a/mlx/backend/cuda/cuda_utils.h b/mlx/backend/cuda/cuda_utils.h
new file mode 100644
index 000000000..538b4ca9c
--- /dev/null
+++ b/mlx/backend/cuda/cuda_utils.h
@@ -0,0 +1,82 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+namespace mlx::core {
+
+// Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
+void check_cuda_error(const char* name, cudaError_t err);
+void check_cuda_error(const char* name, CUresult err);
+
+// The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
+#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
+
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+namespace cu {
+class Device;
+}; // namespace cu
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
+} // namespace mlx::core
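For orientation, CudaHandle above parameterizes both the handle type and its destroy function, which makes every wrapper move-only, null-safe, destroy-once, and implicitly convertible to the raw handle. Below is a hypothetical instantiation in the same style, purely to show how the template parameters compose; this event wrapper is not part of the patch, and MLX's real event types live elsewhere.

// Hypothetical wrapper, for illustration only -- not part of this patch.
#include <cuda_runtime.h>
#include "mlx/backend/cuda/cuda_utils.h"

namespace mlx::core {

class ExampleEvent : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 public:
  ExampleEvent() {
    // handle_ starts as nullptr via the base default constructor.
    CHECK_CUDA_ERROR(
        cudaEventCreateWithFlags(&handle_, cudaEventDisableTiming));
  }
};

} // namespace mlx::core

// Usage: operator cudaEvent_t lets the wrapper be passed straight to CUDA
// APIs; the destructor runs cudaEventDestroy exactly once, and a moved-from
// object is left null so there is no double destroy.
void order_after(cudaStream_t producer, cudaStream_t consumer) {
  mlx::core::ExampleEvent done;
  CHECK_CUDA_ERROR(cudaEventRecord(done, producer));
  CHECK_CUDA_ERROR(cudaStreamWaitEvent(consumer, done, 0));
}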
diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h
index 2d5ae59ef..d52299fcf 100644
--- a/mlx/backend/cuda/utils.h
+++ b/mlx/backend/cuda/utils.h
@@ -4,101 +4,12 @@
 
 #pragma once
 
-#include <cublas_v2.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
 #include "mlx/array.h"
 #include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/cuda_utils.h"
 
 namespace mlx::core {
 
-namespace cu {
-class Device;
-
-}
-
-template <typename T>
-T* gpu_ptr(array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
-}
-
-template <typename T>
-const T* gpu_ptr(const array& arr) {
-  return cu::gpu_ptr<T>(arr.buffer());
-}
-
-struct Dtype;
-
-// Throw exception if the cuda API does not succeed.
-void check_cublas_error(const char* name, cublasStatus_t err);
-void check_cuda_error(const char* name, cudaError_t err);
-void check_cuda_error(const char* name, CUresult err);
-
-// The macro version that prints the command that failed.
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
-
-// Convert Dtype to CUDA C++ types.
-const char* dtype_to_cuda_type(const Dtype& dtype);
-
-// Base class for RAII managed CUDA resources.
-template <typename Handle, cudaError_t (*Destroy)(Handle)>
-class CudaHandle {
- public:
-  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
-
-  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
-    assert(this != &other);
-    other.handle_ = nullptr;
-  }
-
-  ~CudaHandle() {
-    reset();
-  }
-
-  CudaHandle(const CudaHandle&) = delete;
-  CudaHandle& operator=(const CudaHandle&) = delete;
-
-  CudaHandle& operator=(CudaHandle&& other) {
-    assert(this != &other);
-    reset();
-    std::swap(handle_, other.handle_);
-    return *this;
-  }
-
-  void reset() {
-    if (handle_ != nullptr) {
-      CHECK_CUDA_ERROR(Destroy(handle_));
-      handle_ = nullptr;
-    }
-  }
-
-  operator Handle() const {
-    return handle_;
-  }
-
- protected:
-  Handle handle_;
-};
-
-// Wrappers of CUDA resources.
-class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
- public:
-  using CudaHandle::CudaHandle;
-  explicit CudaGraph(cu::Device& device);
-  void end_capture(cudaStream_t stream);
-};
-
-class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
- public:
-  void instantiate(cudaGraph_t graph);
-};
-
-class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
- public:
-  explicit CudaStream(cu::Device& device);
-};
-
 template <typename T>
 inline uint max_occupancy_block_dim(T kernel) {
   int _, block_dim;
@@ -112,4 +23,19 @@
   return block_dim;
 }
 
+template <typename T>
+inline T* gpu_ptr(array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+template <typename T>
+inline const T* gpu_ptr(const array& arr) {
+  return cu::gpu_ptr<T>(arr.buffer());
+}
+
+struct Dtype;
+
+// Convert Dtype to CUDA C++ types.
+const char* dtype_to_cuda_type(const Dtype& dtype);
+
 } // namespace mlx::core
diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp
index 8a5376242..2dbb9a2a2 100644
--- a/mlx/distributed/nccl/nccl.cpp
+++ b/mlx/distributed/nccl/nccl.cpp
@@ -326,8 +326,8 @@ class NCCLGroup : public GroupImpl {
   auto& encoder = cu::get_command_encoder(stream);
 
   CHECK_NCCL(ncclAllReduce(
-      input.data<T>(),
-      output.data<T>(),
+      gpu_ptr<T>(input),
+      gpu_ptr<T>(output),
       input.size(),
       dt,
       op,
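The NCCL change is a consequence of the revised array::data() contract at the top of this diff: data<T>() may now copy a device-only buffer back to managed, CPU-accessible memory, so anything feeding a GPU kernel or collective must go through gpu_ptr<T>. A sketch of the intended calling pattern follows; the kernel and launcher are illustrative, not MLX code, and the arrays are assumed to be contiguous float32 with device-resident buffers.

// Illustrative kernel + launcher, not MLX code.
#include <cuda_runtime.h>
#include "mlx/backend/cuda/utils.h"

using mlx::core::array;

__global__ void scale_kernel(float* out, const float* in, float s, size_t n) {
  size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * s;
  }
}

// Assumes both arrays are contiguous float32 and `stream` is caller-owned.
void launch_scale(array& out, const array& in, float s, cudaStream_t stream) {
  size_t n = in.size();
  unsigned int block = 256;
  unsigned int grid = (unsigned int)((n + block - 1) / block);
  // gpu_ptr<T> hands back the raw device pointer with no host copy;
  // calling in.data<float>() here could instead migrate the buffer to
  // managed memory, defeating the purpose of the async allocation path.
  scale_kernel<<<grid, block, 0, stream>>>(
      mlx::core::gpu_ptr<float>(out), mlx::core::gpu_ptr<float>(in), s, n);
}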