diff --git a/.github/actions/build-cuda/action.yml b/.github/actions/build-cuda/action.yml index 08257cc64..d6068247a 100644 --- a/.github/actions/build-cuda/action.yml +++ b/.github/actions/build-cuda/action.yml @@ -26,10 +26,6 @@ runs: CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} run: pip install -e ".[dev]" -v - - name: Check if build actually worked - shell: bash - run: python -c "import mlx.core" - - name: Run Python tests - CPU if: inputs.run-tests == 'true' shell: bash diff --git a/mlx/allocator.h b/mlx/allocator.h index 362f4f08a..67a9245c4 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -14,7 +14,7 @@ class Buffer { void* ptr_; public: - Buffer(void* ptr) : ptr_(ptr) {}; + explicit Buffer(void* ptr) : ptr_(ptr) {}; // Get the raw data pointer from the buffer void* raw_ptr(); diff --git a/mlx/array.cpp b/mlx/array.cpp index a05e8dfa7..a8c77d150 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) { other.strides(), other.flags(), [](auto) {}); - cpy.array_desc_->data_ptr = other.array_desc_->data_ptr; + cpy.array_desc_->offset = other.array_desc_->offset; return cpy; } @@ -141,7 +141,7 @@ bool array::is_tracer() const { void array::set_data(allocator::Buffer buffer, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->data_ptr = buffer.raw_ptr(); + array_desc_->offset = 0; array_desc_->data_size = size(); array_desc_->flags.contiguous = true; array_desc_->flags.row_contiguous = true; @@ -156,7 +156,7 @@ void array::set_data( Flags flags, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->data_ptr = buffer.raw_ptr(); + array_desc_->offset = 0; array_desc_->data_size = data_size; array_desc_->strides = std::move(strides); array_desc_->flags = flags; @@ -172,9 +172,8 @@ void array::copy_shared_buffer( array_desc_->strides = strides; array_desc_->flags = flags; array_desc_->data_size = data_size; - auto char_offset = sizeof(char) * itemsize() * offset; - array_desc_->data_ptr = static_cast( - static_cast(other.array_desc_->data_ptr) + char_offset); + array_desc_->offset = + sizeof(char) * itemsize() * offset + other.array_desc_->offset; } void array::copy_shared_buffer(const array& other) { diff --git a/mlx/array.h b/mlx/array.h index 279d70a5e..c8a529d7d 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -354,15 +354,23 @@ class array { return array_desc_->data; } - // Return a raw pointer to the arrays data + // Return a raw pointer to the arrays data. This function may do a copy if + // the underlying buffer is not accessible on the CPU. When accessing the + // data for GPU kernels, be sure to use the correct method / function for the + // given backend to access the GPU pointer. template T* data() { - return static_cast(array_desc_->data_ptr); + return reinterpret_cast( + (static_cast(buffer().raw_ptr()) + array_desc_->offset)); } template const T* data() const { - return static_cast(array_desc_->data_ptr); + return const_cast(*this).data(); + } + + int64_t offset() const { + return array_desc_->offset; } enum Status { @@ -466,8 +474,8 @@ class array { // can share the underlying data buffer. 
std::shared_ptr data; - // Properly offset data pointer - void* data_ptr{nullptr}; + // Offset from beginning of data pointer + int64_t offset{0}; // The size in elements of the data buffer the array accesses size_t data_size; diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index ac6e1891d..78607ef07 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -38,20 +38,20 @@ inline void set_binary_op_output_data( const array& a, const array& b, array& out, - BinaryOpType bopt) { + BinaryOpType bopt, + std::function mallocfn = allocator::malloc) { bool b_donatable = is_donatable(b, out); bool a_donatable = is_donatable(a, out); switch (bopt) { case BinaryOpType::ScalarScalar: - out.set_data( - allocator::malloc(out.itemsize()), 1, a.strides(), a.flags()); + out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags()); break; case BinaryOpType::ScalarVector: if (b_donatable) { out.copy_shared_buffer(b); } else { out.set_data( - allocator::malloc(b.data_size() * out.itemsize()), + mallocfn(b.data_size() * out.itemsize()), b.data_size(), b.strides(), b.flags()); @@ -62,7 +62,7 @@ inline void set_binary_op_output_data( out.copy_shared_buffer(a); } else { out.set_data( - allocator::malloc(a.data_size() * out.itemsize()), + mallocfn(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); @@ -75,7 +75,7 @@ inline void set_binary_op_output_data( out.copy_shared_buffer(b); } else { out.set_data( - allocator::malloc(a.data_size() * out.itemsize()), + mallocfn(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); @@ -88,7 +88,7 @@ inline void set_binary_op_output_data( b_donatable && b.flags().row_contiguous && b.size() == out.size()) { out.copy_shared_buffer(b); } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); } break; } diff --git a/mlx/backend/common/broadcasting.cpp b/mlx/backend/common/broadcasting.cpp index 49bc75b8f..0bb52e096 100644 --- a/mlx/backend/common/broadcasting.cpp +++ b/mlx/backend/common/broadcasting.cpp @@ -6,7 +6,7 @@ namespace mlx::core { void broadcast(const array& in, array& out) { if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } Strides strides(out.ndim(), 0); diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 44e2a432b..e5e1d4350 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -114,7 +114,9 @@ void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, const std::function& is_constant, - bool contiguous) { + bool contiguous, + const std::function& + mallocfn /* = allocator::malloc */) { if (contiguous) { int o = 0; Strides strides; @@ -140,7 +142,7 @@ void compiled_allocate_outputs( } for (; o < outputs.size(); ++o) { outputs[o].set_data( - allocator::malloc(data_size * outputs[o].itemsize()), + mallocfn(data_size * outputs[o].itemsize()), data_size, strides, flags); @@ -163,7 +165,7 @@ void compiled_allocate_outputs( } } for (; o < outputs.size(); ++o) { - outputs[o].set_data(allocator::malloc(outputs[o].nbytes())); + outputs[o].set_data(mallocfn(outputs[o].nbytes())); } } } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index e92a6d0ad..3be371333 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -58,7 +58,9 @@ void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, const std::function& is_constant, - bool contiguous); + 
bool contiguous, + const std::function& mallocfn = + allocator::malloc); // Collapse contiguous dims ignoring scalars and constants. std::tuple> compiled_collapse_contiguous_dims( diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index c23d2e79a..859ce0410 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -22,7 +22,11 @@ enum class CopyType { GeneralGeneral }; -inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { +inline bool set_copy_output_data( + const array& in, + array& out, + CopyType ctype, + std::function mallocfn = allocator::malloc) { if (ctype == CopyType::Vector) { // If the input is donateable, we are doing a vector copy and the types // have the same size, then the input buffer can hold the output. @@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { return true; } else { out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), + mallocfn(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); return false; } } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); return false; } } diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 6f5736d63..38f3c1ba0 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -45,7 +45,7 @@ void slice( const Shape& start_indices, const Shape& strides) { if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index 233708ec3..c63a57261 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -46,7 +46,8 @@ inline void set_ternary_op_output_data( const array& b, const array& c, array& out, - TernaryOpType topt) { + TernaryOpType topt, + std::function mallocfn = allocator::malloc) { auto maybe_donate = [&out](const array& x) { if (is_donatable(x, out)) { out.copy_shared_buffer(x); @@ -57,13 +58,12 @@ inline void set_ternary_op_output_data( switch (topt) { case TernaryOpType::ScalarScalarScalar: - out.set_data( - allocator::malloc(out.itemsize()), 1, b.strides(), b.flags()); + out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags()); break; case TernaryOpType::VectorVectorVector: if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { out.set_data( - allocator::malloc(out.itemsize() * b.data_size()), + mallocfn(out.itemsize() * b.data_size()), b.data_size(), b.strides(), b.flags()); @@ -76,7 +76,7 @@ inline void set_ternary_op_output_data( if (!((a.flags().row_contiguous && maybe_donate(a)) || (b.flags().row_contiguous && maybe_donate(b)) || (c.flags().row_contiguous && maybe_donate(c)))) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); } break; } diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index a27a1f45c..b19fc98ed 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -7,19 +7,22 @@ namespace mlx::core { -inline void set_unary_output_data(const array& in, array& out) { +inline void set_unary_output_data( + const array& in, + array& out, + std::function mallocfn = allocator::malloc) { if (in.flags().contiguous) { if (is_donatable(in, out)) { out.copy_shared_buffer(in); } else { out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), + mallocfn(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); } } else { - 
out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); } } diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index f2cb12fdd..d5b917b84 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector& inputs, array& out) { void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) { if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } auto& in = inputs[0]; @@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu( const std::vector& inputs, array& out) { if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } @@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu( void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7f8f1aade..543e9fd58 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -32,6 +32,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 75bd60a68..4b244d91e 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/utils.h" #include "mlx/utils.h" @@ -67,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() { next_free_ = next_free_->next; b->buf.data = static_cast(data_) + i * small_block_size; b->buf.size = small_block_size; + b->buf.device = -1; return &b->buf; } @@ -88,14 +90,40 @@ CudaAllocator::CudaAllocator() page_size, [](CudaBuffer* buf) { return buf->size; }, [this](CudaBuffer* buf) { cuda_free(buf); }) { - // TODO: Set memory limit for multi-device. size_t free, total; CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); memory_limit_ = total * 0.95; max_pool_size_ = memory_limit_; + + int device_count = 0; + CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count)); + int curr; + CHECK_CUDA_ERROR(cudaGetDevice(&curr)); + for (int i = 0; i < device_count; ++i) { + CHECK_CUDA_ERROR(cudaSetDevice(i)); + cudaStream_t s; + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking)); + free_streams_.push_back(s); + } + CHECK_CUDA_ERROR(cudaSetDevice(curr)); } -Buffer CudaAllocator::malloc(size_t size) { +void copy_to_managed(CudaBuffer& buf) { + // TODO maybe make this async on a i/o stream to avoid synchronizing the + // device on malloc/and free + void* new_data; + CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size)); + buf.device = -1; + CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault)); + CHECK_CUDA_ERROR(cudaFree(buf.data)); + buf.data = new_data; +} + +Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) { + if (size == 0) { + return Buffer{new CudaBuffer{nullptr, 0, -1}}; + } + // Find available buffer from cache. 
std::unique_lock lock(mutex_); if (size <= small_block_size) { @@ -106,6 +134,11 @@ Buffer CudaAllocator::malloc(size_t size) { size = page_size * ((size + page_size - 1) / page_size); } + int device = -1; + if (size > small_block_size && stream != nullptr) { + CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device)); + } + CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); if (!buf) { // If we have a lot of memory pressure try to reclaim memory from the cache. @@ -121,8 +154,13 @@ Buffer CudaAllocator::malloc(size_t size) { } lock.unlock(); if (!buf) { - buf = new CudaBuffer{nullptr, size}; - cudaError_t err = cudaMallocManaged(&buf->data, size); + buf = new CudaBuffer{nullptr, size, device}; + cudaError_t err; + if (device == -1) { + err = cudaMallocManaged(&buf->data, size); + } else { + err = cudaMallocAsync(&buf->data, size, stream); + } if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { throw std::runtime_error(fmt::format( "cudaMallocManaged failed: {}.", cudaGetErrorString(err))); @@ -137,14 +175,30 @@ Buffer CudaAllocator::malloc(size_t size) { if (get_cache_memory() > max_pool_size_) { buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); } + // Copy to managed here if the buffer is not on the right device + if (buf->device != device) { + copy_to_managed(*buf); + } return Buffer{buf}; } +Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) { + return malloc_impl(size, stream); +} + +Buffer CudaAllocator::malloc(size_t size) { + return malloc_impl(size, nullptr); +} + void CudaAllocator::free(Buffer buffer) { auto* buf = static_cast(buffer.ptr()); if (!buf) { return; } + if (buf->size == 0) { + delete buf; + return; + } std::unique_lock lock(mutex_); active_memory_ -= buf->size; @@ -168,7 +222,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) { if (scalar_pool_.in_pool(buf)) { scalar_pool_.free(buf); } else { - cudaFree(buf->data); + if (buf->device >= 0) { + cudaFreeAsync(buf->data, free_streams_[buf->device]); + } else { + cudaFree(buf->data); + } delete buf; } } @@ -219,6 +277,16 @@ CudaAllocator& allocator() { return *allocator_; } +Buffer malloc_async(size_t size, cudaStream_t stream) { + auto buffer = allocator().malloc_async(size, stream); + if (size && !buffer.ptr()) { + std::ostringstream msg; + msg << "[malloc_async] Unable to allocate " << size << " bytes."; + throw std::runtime_error(msg.str()); + } + return buffer; +} + } // namespace cu namespace allocator { @@ -231,7 +299,11 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - return static_cast(ptr_)->data; + auto& cbuf = *static_cast(ptr_); + if (cbuf.device != -1) { + copy_to_managed(cbuf); + } + return cbuf.data; } } // namespace allocator diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h index 81b3dde59..0b62d8529 100644 --- a/mlx/backend/cuda/allocator.h +++ b/mlx/backend/cuda/allocator.h @@ -4,7 +4,9 @@ #include "mlx/allocator.h" #include "mlx/backend/common/buffer_cache.h" +#include "mlx/backend/cuda/cuda_utils.h" +#include #include #include #include @@ -17,6 +19,7 @@ using allocator::Buffer; struct CudaBuffer { void* data; size_t size; + int device; // -1 for managed }; class SmallSizePool { @@ -45,6 +48,7 @@ class SmallSizePool { class CudaAllocator : public allocator::Allocator { public: Buffer malloc(size_t size) override; + Buffer malloc_async(size_t size, cudaStream_t stream); void free(Buffer buffer) override; size_t size(Buffer buffer) const override; @@ -58,6 +62,7 @@ class CudaAllocator : public 
allocator::Allocator { void clear_cache(); private: + Buffer malloc_impl(size_t size, cudaStream_t stream); void cuda_free(CudaBuffer* buf); CudaAllocator(); @@ -69,9 +74,12 @@ class CudaAllocator : public allocator::Allocator { BufferCache buffer_cache_; size_t active_memory_{0}; size_t peak_memory_{0}; + std::vector free_streams_; SmallSizePool scalar_pool_; }; CudaAllocator& allocator(); +Buffer malloc_async(size_t size, cudaStream_t stream); + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/arange.cu b/mlx/backend/cuda/arange.cu index a28a245db..a8e406ba1 100644 --- a/mlx/backend/cuda/arange.cu +++ b/mlx/backend/cuda/arange.cu @@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { return; } - out.set_data(allocator::malloc(out.nbytes())); - auto& encoder = cu::get_command_encoder(stream()); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); encoder.set_output_array(out); dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { @@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { num_blocks, block_dims, 0, - out.data(), + gpu_ptr(out), out.data_size(), static_cast(start_), static_cast(start_ + step_) - static_cast(start_)); diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 21cd677a8..8dd77cc73 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgReduce::eval_gpu"); assert(inputs.size() == 1); auto& in = inputs[0]; - out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); // Prepare the shapes, strides and axis arguments. Shape shape = remove_index(in.shape(), axis_); @@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { int32_t ndim = shape.size(); // ArgReduce. 
- auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { @@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { num_blocks, block_dim(), 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), out.size(), const_param(shape), const_param(in_strides), diff --git a/mlx/backend/cuda/binary/binary.cuh b/mlx/backend/cuda/binary/binary.cuh index 20bb199ec..47847e4d9 100644 --- a/mlx/backend/cuda/binary/binary.cuh +++ b/mlx/backend/cuda/binary/binary.cuh @@ -292,9 +292,9 @@ void binary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -310,9 +310,9 @@ void binary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -339,9 +339,9 @@ void binary_op_gpu_inplace( num_blocks, block_dims, 0, - a.data(), - b.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), out.data_size()); }); } @@ -365,7 +365,11 @@ void binary_op_gpu( auto& a = inputs[0]; auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); + auto& encoder = cu::get_command_encoder(s); + + set_binary_op_output_data(a, b, out, bopt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); binary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index cd0fe2c46..edec2678f 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace( auto& out_a = outputs[0]; auto& out_b = outputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out_a, bopt); - set_binary_op_output_data(a, b, out_b, bopt); + auto& encoder = cu::get_command_encoder(s); + set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); + set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); if (out_a.size() == 0) { return; } - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out_a); @@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out_a.data(), - out_b.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out_a), + gpu_ptr(out_b), rest, const_param(shape), const_param(a_strides), @@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out_a.data(), - out_b.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out_a), + gpu_ptr(out_b), rest, const_param(shape), const_param(a_strides), @@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace( num_blocks, block_dims, 0, - a.data(), - b.data(), - out_a.data(), - out_b.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out_a), + gpu_ptr(out_b), out_a.data_size()); }); } diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 302ca2f99..3c5681019 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -293,8 +293,13 @@ void Compiled::eval_gpu( } } + auto& encoder = cu::get_command_encoder(s); + // Put outputs. 
- compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); + compiled_allocate_outputs( + inputs, outputs, is_constant_, contiguous, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); for (auto& x : outputs) { args.append(x); } @@ -324,7 +329,6 @@ void Compiled::eval_gpu( kernel_name += fmt::format( "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread); } - auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index e65de63e0..4f7a971ca 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { if (out_.size() == 0) { return; } + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 2); array in = inputs[0]; array wt = inputs[1]; array out = out_; - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); Dtype dtype = out.dtype(); - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); - // Search cache. ConvCacheKey cache_key{ encoder.device().cuda_device(), diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu index 11a78a7ab..b2c151706 100644 --- a/mlx/backend/cuda/conv/gemm_conv.cu +++ b/mlx/backend/cuda/conv/gemm_conv.cu @@ -86,7 +86,7 @@ array unfold_inputs_nd( int mat_N, ConvParams& params) { array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {}); - unfolded.set_data(allocator::malloc(unfolded.nbytes())); + unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream())); encoder.add_temporary(unfolded); int filter_size = params.C; @@ -118,8 +118,8 @@ array unfold_inputs_nd( num_blocks, block_dims, 0, - in.data(), - unfolded.data(), + gpu_ptr(in), + gpu_ptr(unfolded), filter_size, out_pixels, params); diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu index 7ceb58166..788c640ff 100644 --- a/mlx/backend/cuda/conv/gemm_grouped_conv.cu +++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu @@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd( int mat_N, ConvParams& params) { array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {}); - unfolded.set_data(allocator::malloc(unfolded.nbytes())); + unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream())); encoder.add_temporary(unfolded); int filter_size = params.C; @@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd( num_blocks, block_dims, 0, - in.data(), - unfolded.data(), + gpu_ptr(in), + gpu_ptr(unfolded), filter_size, out_pixels, params); diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 158d3de6e..f559076d2 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -5,6 +5,22 @@ namespace mlx::core { +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + auto& encoder = cu::get_command_encoder(s); + bool donated = set_copy_output_data(in, out, ctype, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. 
+ return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + void copy_gpu_inplace( const array& in, array& out, @@ -87,11 +103,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) { if (out.size() == 0) { return; } - out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); encoder.set_input_array(in); encoder.set_output_array(out); copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); } +void reshape_gpu(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 9c2aa9838..6f51734b1 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -77,8 +77,8 @@ void copy_contiguous( num_blocks, block_dims, 0, - in.data() + in_offset, - out.data() + out_offset, + gpu_ptr(in) + in_offset, + gpu_ptr(out) + out_offset, out.data_size()); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 6ac42751a..fec110a86 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -106,8 +106,8 @@ void copy_general( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; + const InType* in_ptr = gpu_ptr(in) + offset_in; + OutType* out_ptr = gpu_ptr(out) + offset_out; int ndim = shape.size(); size_t data_size = 1; for (auto& s : shape) diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 7a7f0dca5..96cf6fb5c 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -69,8 +69,8 @@ void copy_general_dynamic( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; + const InType* in_ptr = gpu_ptr(in) + offset_in; + OutType* out_ptr = gpu_ptr(out) + offset_out; int ndim = shape.size(); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { @@ -90,8 +90,8 @@ void copy_general_dynamic( const_param(shape), const_param(strides_in), const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); + gpu_ptr(dynamic_offset_in), + gpu_ptr(dynamic_offset_out)); }); } else { // ndim >= 4 auto [num_blocks, block_dims] = get_launch_args(out, large()); @@ -107,8 +107,8 @@ void copy_general_dynamic( const_param(strides_in), const_param(strides_out), ndim, - dynamic_offset_in.data(), - dynamic_offset_out.data()); + gpu_ptr(dynamic_offset_in), + gpu_ptr(dynamic_offset_out)); } }); }); diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index ce8bb1b78..42a027ec5 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ 
b/mlx/backend/cuda/copy/copy_general_input.cu @@ -92,8 +92,8 @@ void copy_general_input( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; + const InType* in_ptr = gpu_ptr(in) + offset_in; + OutType* out_ptr = gpu_ptr(out) + offset_out; int ndim = shape.size(); int work_per_thread = 1; auto dim0 = ndim > 0 ? shape.back() : 1; diff --git a/mlx/backend/cuda/cuda_utils.h b/mlx/backend/cuda/cuda_utils.h new file mode 100644 index 000000000..538b4ca9c --- /dev/null +++ b/mlx/backend/cuda/cuda_utils.h @@ -0,0 +1,82 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +// Throw exception if the cuda API does not succeed. +void check_cublas_error(const char* name, cublasStatus_t err); +void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); + +// The macro version that prints the command that failed. +#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) + +// Base class for RAII managed CUDA resources. +template +class CudaHandle { + public: + CudaHandle(Handle handle = nullptr) : handle_(handle) {} + + CudaHandle(CudaHandle&& other) : handle_(other.handle_) { + assert(this != &other); + other.handle_ = nullptr; + } + + ~CudaHandle() { + reset(); + } + + CudaHandle(const CudaHandle&) = delete; + CudaHandle& operator=(const CudaHandle&) = delete; + + CudaHandle& operator=(CudaHandle&& other) { + assert(this != &other); + reset(); + std::swap(handle_, other.handle_); + return *this; + } + + void reset() { + if (handle_ != nullptr) { + CHECK_CUDA_ERROR(Destroy(handle_)); + handle_ = nullptr; + } + } + + operator Handle() const { + return handle_; + } + + protected: + Handle handle_; +}; + +namespace cu { +class Device; +}; // namespace cu + +// Wrappers of CUDA resources. +class CudaGraph : public CudaHandle { + public: + using CudaHandle::CudaHandle; + explicit CudaGraph(cu::Device& device); + void end_capture(cudaStream_t stream); +}; + +class CudaGraphExec : public CudaHandle { + public: + void instantiate(cudaGraph_t graph); +}; + +class CudaStream : public CudaHandle { + public: + explicit CudaStream(cu::Device& device); +}; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp index 20280f2be..2728414be 100644 --- a/mlx/backend/cuda/cudnn_utils.cpp +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -132,14 +132,18 @@ bool prepare_cudnn_plan( void** data_ptrs, F&& execute) { int workspace_size = plan.getWorkspaceSize(); - array workspace( - workspace_size > 0 ? 
allocator::malloc(workspace_size) - : allocator::Buffer(nullptr), - {workspace_size}, - uint8); + void* workspace_ptr = nullptr; + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder.stream()), + {workspace_size}, + uint8); + encoder.add_temporary(workspace); + workspace_ptr = gpu_ptr(workspace); + } auto args = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace.data()) + .setWorkspacePointer(workspace_ptr) .setDataPointers(num_args, data_ptrs) .setUids(num_args, uids) .build(); @@ -151,7 +155,6 @@ bool prepare_cudnn_plan( return false; } - encoder.add_temporary(workspace); return true; } diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h index c35c5cac9..b28249678 100644 --- a/mlx/backend/cuda/cudnn_utils.h +++ b/mlx/backend/cuda/cudnn_utils.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/utils.h" #include "mlx/dtype_utils.h" @@ -23,7 +24,7 @@ class CommandEncoder; // Return pointer alignment of |x|'s data. inline uint8_t get_alignment(const array& x) { uint8_t alignment = 1; - uintptr_t address = reinterpret_cast(x.data()); + uintptr_t address = reinterpret_cast(gpu_ptr(x)); for (; alignment < 32; alignment *= 2) { if (address % (alignment * 2)) { return alignment; @@ -56,7 +57,7 @@ inline std::array vector_key(const Vec& vec) { // Helpers used by get_data_ptrs to get pointers. inline void* get_data_ptr(const array& arr) { - return const_cast(arr.data()); + return const_cast(gpu_ptr(arr)); } template >> diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index ee1778fd8..958516746 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -279,6 +279,7 @@ void CustomKernel::eval_gpu( std::vector& outputs) { nvtx3::scoped_range r("CustomKernel::eval_gpu"); auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); std::vector copies; @@ -288,7 +289,7 @@ void CustomKernel::eval_gpu( copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } } @@ -356,7 +357,6 @@ void CustomKernel::eval_gpu( dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); // Call the kernel - auto& encoder = cu::get_command_encoder(s); for (const auto& in : checked_inputs) { encoder.set_input_array(in); } diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index d18092328..c04973484 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/cuda/worker.h" #include "mlx/stream.h" diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu index 07a5ed10f..33363866f 100644 --- a/mlx/backend/cuda/distributed.cu +++ b/mlx/backend/cuda/distributed.cu @@ -15,8 +15,10 @@ void AllReduce::eval_gpu( assert(inputs.size() == 1); assert(outputs.size() == 1); - auto set_input_output = - [s = stream()](const array& in, array& out) -> std::pair { + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + auto set_input_output = [&](const array& in, + array& out) -> std::pair { if (!in.flags().row_contiguous) { copy_gpu(in, out, CopyType::General, s); return {out, out}; @@ -24,19 +26,17 @@ 
void AllReduce::eval_gpu( out.copy_shared_buffer(in); return {in, out}; } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); return {in, out}; } }; auto [input, output] = set_input_output(inputs[0], outputs[0]); - auto& encoder = cu::get_command_encoder(stream()); encoder.set_input_array(input); encoder.set_output_array(output); auto capture = encoder.capture_context(); - auto& s = stream(); switch (reduce_type_) { case Sum: @@ -74,7 +74,7 @@ void AllGather::eval_gpu( }; auto input = ensure_contiguous(inputs[0]); - outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream())); encoder.set_input_array(input); encoder.set_output_array(outputs[0]); @@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu( }; auto input = ensure_contiguous(inputs[0]); - outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream())); encoder.set_input_array(input); encoder.set_output_array(outputs[0]); diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp index 9121c7f4e..a50c9111f 100644 --- a/mlx/backend/cuda/fence.cpp +++ b/mlx/backend/cuda/fence.cpp @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/fence.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" namespace mlx::core { @@ -20,8 +22,24 @@ void Fence::wait(Stream s, const array&) { fence->event.wait(fence->count); } -void Fence::update(Stream s, const array&) { +void Fence::update(Stream s, const array& a, bool cross_device) { auto* fence = static_cast(fence_.get()); + if (cross_device) { + // Move to managed memory if there is a device switch + auto& cbuf = + *static_cast(const_cast(a).buffer().ptr()); + if (cbuf.device != -1) { + void* new_data; + CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size)); + cbuf.device = -1; + auto& encoder = cu::device(s.device).get_command_encoder(s); + encoder.commit(); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream())); + CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream())); + cbuf.data = new_data; + } + } fence->count++; fence->event.signal(s, fence->count); } diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index fc2380fc3..60ca2ccae 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) { CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); - auto* bias_ptr = bias.data(); + auto* bias_ptr = gpu_ptr(bias); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, @@ -278,9 +278,9 @@ void CublasGemm::run( execute( encoder, - out.data(), - a.data(), - b.data(), + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), nullptr, alpha); } @@ -321,10 +321,10 @@ void CublasGemm::run( execute( encoder, - out.data(), - a.data(), - b.data(), - c.data(), + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), alpha, beta); } @@ -370,11 +370,11 @@ void CublasGemm::execute( // Ensure workspace is 256-byte aligned int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256; array workspace( - allocator::malloc(nbytes), + cu::malloc_async(nbytes, encoder.stream()), {static_cast(heuristic_.workspaceSize)}, int8); 
encoder.add_temporary(workspace); - workspace_ptr = workspace.data(); + workspace_ptr = gpu_ptr(workspace); } auto capture = encoder.capture_context(); diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp index 70df21fda..67b8c2754 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp @@ -25,9 +25,10 @@ void CublasGemm::run_batched( for (size_t i = 0; i < nbatch; ++i) { execute( encoder, - out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_, - a.data() + a.itemsize() * a_it.loc, - b.data() + b.itemsize() * b_it.loc, + gpu_ptr(out) + + out.itemsize() * i * batch_shape.back() * M_ * N_, + gpu_ptr(a) + a.itemsize() * a_it.loc, + gpu_ptr(b) + b.itemsize() * b_it.loc, nullptr, alpha); a_it.step(); @@ -60,10 +61,11 @@ void CublasGemm::run_batched( for (size_t i = 0; i < nbatch; ++i) { execute( encoder, - out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_, - a.data() + a.itemsize() * a_it.loc, - b.data() + b.itemsize() * b_it.loc, - c.data() + c.itemsize() * c_it.loc, + gpu_ptr(out) + + out.itemsize() * i * batch_shape.back() * M_ * N_, + gpu_ptr(a) + a.itemsize() * a_it.loc, + gpu_ptr(b) + b.itemsize() * b_it.loc, + gpu_ptr(c) + c.itemsize() * c_it.loc, alpha, beta); a_it.step(); diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu index 41ab9c8bd..c3b6fd379 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu @@ -163,7 +163,7 @@ void CublasGemm::run_batched( // Launch kernel to set device offsets auto pointers = array( - allocator::malloc(batch_count * sizeof(void*) * 3), + cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()), {batch_count * 3}, uint64); @@ -183,10 +183,10 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -200,10 +200,10 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -219,7 +219,7 @@ void CublasGemm::run_batched( encoder.set_input_array(b); encoder.set_output_array(out); - auto a_pointers = pointers.data(); + auto a_pointers = gpu_ptr(pointers); auto b_pointers = a_pointers + batch_count; auto out_pointers = b_pointers + batch_count; execute( @@ -251,7 +251,7 @@ void CublasGemm::run_batched( // Launch kernel to set device offsets auto pointers = array( - allocator::malloc(batch_count * sizeof(uint64_t) * 4), + cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()), {batch_count * 4}, uint64); @@ -271,11 +271,11 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -290,11 +290,11 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), item_size, const_param(batch_shape), 
const_param(a_batch_strides), @@ -312,7 +312,7 @@ void CublasGemm::run_batched( encoder.set_input_array(c); encoder.set_output_array(out); - auto a_pointers = pointers.data(); + auto a_pointers = gpu_ptr(pointers); auto b_pointers = a_pointers + batch_count; auto c_pointers = b_pointers + batch_count; auto out_pointers = c_pointers + batch_count; diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu index e1c755039..6966a583b 100644 --- a/mlx/backend/cuda/gemms/gemv.cu +++ b/mlx/backend/cuda/gemms/gemv.cu @@ -149,13 +149,13 @@ void gemv( auto vec_strides = const_param(b_batch_strides); if (M == 1) { - mat = b.data(); - vec = a.data(); + mat = gpu_ptr(b); + vec = gpu_ptr(a); rows = N; std::swap(mat_strides, vec_strides); } else { - mat = a.data(); - vec = b.data(); + mat = gpu_ptr(a); + vec = gpu_ptr(b); rows = M; } uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; @@ -177,7 +177,7 @@ void gemv( 0, mat, vec, - out.data(), + gpu_ptr(out), rows, cols); } else { @@ -189,7 +189,7 @@ void gemv( 0, mat, vec, - out.data(), + gpu_ptr(out), rows, cols, const_param(batch_shape), diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 1d867e063..5c468d2ba 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -31,7 +31,7 @@ void append_indices_arg( int idx_ndim) { SmallVector indices(nidx); for (int i = 0; i < nidx; ++i) { - indices[i] = inputs[i + 1].data(); + indices[i] = gpu_ptr(inputs[i + 1]); } args.append(std::move(indices)); SmallVector indices_shape(nidx * idx_ndim); @@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() > 0); const auto& src = inputs[0]; - out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); if (out.size() == 0) { return; } @@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_string(idx_dtype), nidx); - auto& s = stream(); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { @@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { idx_ndim, large ? "int64_t" : "int32_t"); - auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } @@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { const auto& src = inputs[0]; const auto& idx = inputs[1]; - out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); if (out.size() == 0) { return; } @@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { dtype_to_string(out.dtype()), dtype_to_string(idx.dtype())); - auto& s = stream(); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { @@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { idx.flags().row_contiguous, large ? 
"int64_t" : "int32_t"); - auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index e2fd0c8b8..ac2886306 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -31,7 +31,7 @@ struct KernelArgs { } void append(const array& a) { - append(reinterpret_cast(a.data())); + append(reinterpret_cast(gpu_ptr(a))); } template diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index d889cd590..bd5f7ef80 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -9,6 +9,7 @@ #include #include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device/utils.cuh" #include diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 369b2547e..30cc6b837 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -230,9 +230,10 @@ void LayerNorm::eval_gpu( nvtx3::scoped_range r("LayerNorm::eval_gpu"); auto& s = stream(); auto& out = outputs[0]; + auto& encoder = cu::get_command_encoder(s); // Make sure that the last dimension is contiguous. - auto set_output = [&s, &out](const array& x) { + auto set_output = [&s, &out, &encoder](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; @@ -243,7 +244,7 @@ void LayerNorm::eval_gpu( out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc(x.data_size() * x.itemsize()), + cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), x.data_size(), x.strides(), x.flags()); @@ -265,7 +266,6 @@ void LayerNorm::eval_gpu( int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(b); @@ -280,10 +280,10 @@ void LayerNorm::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - b.data(), - out.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(b), + gpu_ptr(out), eps_, axis_size, w_stride, @@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu( gx.copy_shared_buffer(g); g_in_gx = true; } else { - gx.set_data(allocator::malloc(gx.nbytes())); + gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream())); } if (g_copied && !g_in_gx) { encoder.add_temporary(g); @@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu( g_in_gw = true; gw_temp.copy_shared_buffer(g); } else { - gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream())); encoder.add_temporary(gw_temp); } } @@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(g), + gpu_ptr(gx), + gpu_ptr(gw_temp), eps_, axis_size, w_stride); diff --git a/mlx/backend/cuda/load.cpp b/mlx/backend/cuda/load.cpp new file mode 100644 index 000000000..a5687addb --- /dev/null +++ b/mlx/backend/cuda/load.cpp @@ -0,0 +1,60 @@ +// Copyright © 2023 Apple Inc. 
+ +#include +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/primitives.h" + +namespace { + +template +void swap_endianness(uint8_t* data_bytes, size_t N) { + struct Elem { + uint8_t bytes[scalar_size]; + }; + + Elem* data = reinterpret_cast(data_bytes); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < (scalar_size / 2); j++) { + std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); + } + } +} + +} // namespace + +namespace mlx::core { + +void Load::eval_gpu(const std::vector& inputs, array& out) { + auto& encoder = cu::get_command_encoder(stream()); + auto size = out.size(); + auto nbytes = size * out.itemsize(); + out.set_data(cu::malloc_async(nbytes, encoder.stream())); + auto out_ptr = malloc(nbytes); + reader_->read(static_cast(out_ptr), nbytes, offset_); + if (swap_endianness_) { + switch (out.itemsize()) { + case 2: + swap_endianness<2>(reinterpret_cast(out_ptr), size); + break; + case 4: + swap_endianness<4>(reinterpret_cast(out_ptr), size); + break; + case 8: + swap_endianness<8>(reinterpret_cast(out_ptr), size); + break; + } + } + CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(out), + out_ptr, + nbytes, + cudaMemcpyDefault, + encoder.stream())); + CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr)); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index b90c300d0..a3841a957 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { auto in = ensure_contiguous(inputs[0]); if (in.flags().row_contiguous) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } else { auto n = in.shape(-1); auto flags = in.flags(); @@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { } flags.col_contiguous = col_contig; out.set_data( - allocator::malloc(in.nbytes() / n), + cu::malloc_async(in.nbytes() / n, encoder.stream()), in.data_size() / n, std::move(strides), flags); @@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { n_rows, block_dim(), 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size); }); }); diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 8a79b172b..8ccf3c466 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { return; } - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); int M = a_pre.shape(-2); int N = b_pre.shape(-1); @@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); gemm_and_bias( encoder, M, @@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { auto sty = c.strides()[c.ndim() - 1]; if (sty == 1 && stx == c.shape(-1)) { ldc = stx; - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } else if (sty == 1 && stx == 0) { ldc = 0; - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } else { // Copy C into 
out and set C to out ldc = c.shape(-1); diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 43a60eedb..48995b097 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -28,7 +28,6 @@ NO_GPU(FFT) NO_GPU(GatherMM) NO_GPU(GatherQMM) NO_GPU(Hadamard) -NO_GPU(Load) NO_GPU_MULTI(LUF) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) diff --git a/mlx/backend/cuda/quantized/affine_quantize.cu b/mlx/backend/cuda/quantized/affine_quantize.cu index a64597a88..da4b34135 100644 --- a/mlx/backend/cuda/quantized/affine_quantize.cu +++ b/mlx/backend/cuda/quantized/affine_quantize.cu @@ -262,10 +262,10 @@ void affine_quantize( num_blocks, block_dims, 0, - w.data(), - wq.data(), - scales.data(), - biases.data(), + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + gpu_ptr(biases), w.size()); }); }); @@ -318,10 +318,10 @@ void affine_dequantize( num_blocks, block_dims, 0, - wq.data(), - scales.data(), - biases.data(), - w.data(), + gpu_ptr(wq), + gpu_ptr(scales), + gpu_ptr(biases), + gpu_ptr(w), w.size()); }); }); diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 0f979dfb0..45c61baf5 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -156,9 +156,9 @@ void fp_quantize( num_blocks, block_dims, 0, - w.data(), - wq.data(), - scales.data(), + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), w.size()); } else { throw std::runtime_error( @@ -202,9 +202,9 @@ void fp_dequantize( num_blocks, block_dims, 0, - wq.data(), - scales.data(), - w.data(), + gpu_ptr(wq), + gpu_ptr(scales), + gpu_ptr(w), w.size()); } else { throw std::runtime_error( diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 58710834f..f75064d4e 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu( auto scales = ensure_row_contiguous(inputs[1], enc, s); auto& w = outputs[0]; - w.set_data(allocator::malloc(w.nbytes())); + w.set_data(cu::malloc_async(w.nbytes(), enc.stream())); if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous(inputs[2], enc, s); @@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu( auto& wq = outputs[0]; auto& scales = outputs[1]; - wq.set_data(allocator::malloc(wq.nbytes())); - scales.set_data(allocator::malloc(scales.nbytes())); + wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream())); + scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream())); if (mode_ == QuantizationMode::Affine) { auto& biases = outputs[2]; - biases.set_data(allocator::malloc(biases.nbytes())); + biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream())); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); } else { fp_quantize(w, wq, scales, group_size_, bits_, enc, s); diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 26a3eb8b7..68dcacdc7 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { uint32_t elems_per_key = out.size() / num_keys; uint32_t bytes_per_key = out.itemsize() * elems_per_key; - out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); if (out.size() == 0) { return; } @@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, 
array& out) { uint32_t half_size = out_per_key / 2; bool odd = out_per_key % 2; - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(keys); encoder.set_output_array(out); dim3 grid_dims{num_keys, half_size + odd}; @@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { grid, block, 0, - keys.data(), - out.data(), + gpu_ptr(keys), + gpu_ptr(out), grid_dims, odd, bytes_per_key); @@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { grid, block, 0, - keys.data(), - out.data(), + gpu_ptr(keys), + gpu_ptr(out), grid_dims, odd, bytes_per_key, diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index b815597bd..685bbefa8 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -66,7 +66,7 @@ void all_reduce( Reduce::ReduceType reduce_type) { constexpr int N_READS = 8; - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); auto get_args = [](size_t size, int N) { int threads = std::min(512UL, (size + N - 1) / N); @@ -100,14 +100,15 @@ void all_reduce( Dtype dt = in.dtype(); // Cub doesn't like const pointers for load (sigh). - void* indata = const_cast(in.data()); + void* indata = const_cast(gpu_ptr(in)); // Large array so allocate an intermediate and accumulate there std::tie(blocks, threads, block_step) = get_args(insize, N_READS); encoder.set_input_array(in); if (blocks > 1) { array intermediate({blocks}, out.dtype(), nullptr, {}); - intermediate.set_data(allocator::malloc(intermediate.nbytes())); + intermediate.set_data( + cu::malloc_async(intermediate.nbytes(), encoder.stream())); encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); dispatch_all_types(dt, [&](auto type_tag) { @@ -122,14 +123,14 @@ void all_reduce( threads, 0, static_cast(indata), - intermediate.data(), + gpu_ptr(intermediate), block_step, insize); }); }); // Set the input for the next step and recalculate the blocks - indata = intermediate.data(); + indata = gpu_ptr(intermediate); dt = intermediate.dtype(); insize = intermediate.size(); std::tie(blocks, threads, block_step) = get_args(insize, N_READS); @@ -149,7 +150,7 @@ void all_reduce( threads, 0, static_cast(indata), - out.data(), + gpu_ptr(out), block_step, insize); }); diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index feafc2fb0..fb81079a5 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -250,7 +250,7 @@ void col_reduce_looped( const cu::ColReduceArgs& args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. - allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); @@ -261,7 +261,7 @@ void col_reduce_looped( using T = cuda_type_t; using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); + T* indata = const_cast(gpu_ptr(in)); constexpr int N_READS = 4; constexpr int BM = 32; @@ -276,7 +276,7 @@ void col_reduce_looped( blocks, 0, indata, - out.data(), + gpu_ptr(out), static_cast(args)); }); }); @@ -293,7 +293,7 @@ void col_reduce_small( const cu::ColReduceArgs& args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. 
- allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); @@ -312,8 +312,8 @@ void col_reduce_small( grid, block, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), static_cast(args), out.size()); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 8c0d380f5..40e3d6859 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -28,7 +28,7 @@ void init_reduce( Reduce::ReduceType reduce_type) { // Allocate if needed if (out.data_shared_ptr() == nullptr) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } encoder.set_output_array(out); @@ -42,7 +42,7 @@ void init_reduce( dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); grid.x = (grid.x + 1023) / 1024; encoder.add_kernel_node( - kernel, grid, block, 0, out.data(), out.size()); + kernel, grid, block, 0, gpu_ptr(out), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index d993bacbb..0323c3184 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -5,6 +5,7 @@ #include #include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/utils.cuh" #include @@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { inline void allocate_same_layout( array& out, const array& in, - const std::vector& axes) { + const std::vector& axes, + cu::CommandEncoder& encoder) { if (in.flags().row_contiguous) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); return; } @@ -133,7 +135,7 @@ inline void allocate_same_layout( fl.col_contiguous = cc; fl.contiguous = true; out.set_data( - allocator::malloc(out.nbytes()), + cu::malloc_async(out.nbytes(), encoder.stream()), data_size, final_strides, fl, diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 1ae46d0a3..ea99e1132 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -238,7 +238,7 @@ void row_reduce_simple( const ReductionPlan& plan) { // Allocate data for the output using in's layout to avoid elem_to_loc in the // kernel. - allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); // TODO: If out.size() < 1024 which will be a common case then write this in // 2 passes. Something like 32 * out.size() and then do a warp reduce. @@ -268,10 +268,10 @@ void row_reduce_simple( kernel = cu::row_reduce_simple; } - T* indata = const_cast(in.data()); + T* indata = const_cast(gpu_ptr(in)); int size = plan.shape.back(); encoder.add_kernel_node( - kernel, grid, block, 0, indata, out.data(), out.size(), size); + kernel, grid, block, 0, indata, gpu_ptr(out), out.size(), size); }); }); } @@ -286,7 +286,7 @@ void row_reduce_looped( cu::RowReduceArgs args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. 
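allocate_same_layout above is documented as allocating the reduction output "using in's layout" so it can be written as contiguously as possible. A simplified, standalone illustration of that idea (not the actual helper, which also handles reduced axes and flag updates) is to order the output axes by the input's strides:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Build strides for out_shape that follow the same axis ordering as
// in_strides: the input's fastest-varying axis stays fastest in the output.
std::vector<int64_t> same_order_strides(const std::vector<int64_t>& out_shape,
                                        const std::vector<int64_t>& in_strides) {
  size_t ndim = out_shape.size();
  std::vector<size_t> perm(ndim);
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(), [&](size_t a, size_t b) {
    return in_strides[a] > in_strides[b];  // larger stride varies slower
  });
  std::vector<int64_t> strides(ndim);
  int64_t acc = 1;
  for (auto it = perm.rbegin(); it != perm.rend(); ++it) {
    strides[*it] = acc;
    acc *= out_shape[*it];
  }
  return strides;
}

int main() {
  // A "transposed" 2x3 input with strides {1, 2} keeps axis 0 fastest.
  auto s = same_order_strides({2, 3}, {1, 2});
  std::printf("%lld %lld\n", (long long)s[0], (long long)s[1]);  // 1 2
}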
- allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); @@ -315,7 +315,7 @@ void row_reduce_looped( }); encoder.add_kernel_node( - kernel, grid, block, 0, in.data(), out.data(), args); + kernel, grid, block, 0, gpu_ptr(in), gpu_ptr(out), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index bc879c6f8..8d5d02238 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -176,9 +176,10 @@ void RMSNorm::eval_gpu( nvtx3::scoped_range r("RMSNorm::eval_gpu"); auto& s = stream(); auto& out = outputs[0]; + auto& encoder = cu::get_command_encoder(s); // Make sure that the last dimension is contiguous. - auto set_output = [&s, &out](const array& x) { + auto set_output = [&s, &out, &encoder](const array& x) { bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; @@ -189,7 +190,7 @@ void RMSNorm::eval_gpu( out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc(x.data_size() * x.itemsize()), + cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), x.data_size(), x.strides(), x.flags()); @@ -209,7 +210,6 @@ void RMSNorm::eval_gpu( int32_t n_rows = x.data_size() / axis_size; int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_output_array(out); @@ -223,9 +223,9 @@ void RMSNorm::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - out.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(out), eps_, axis_size, w_stride); @@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu( gx.copy_shared_buffer(g); g_in_gx = true; } else { - gx.set_data(allocator::malloc(gx.nbytes())); + gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream())); } if (g_copied && !g_in_gx) { encoder.add_temporary(g); @@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu( if (!g_in_gx && donate_g) { gw_temp.copy_shared_buffer(g); } else { - gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream())); encoder.add_temporary(gw_temp); } } @@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(g), + gpu_ptr(gx), + gpu_ptr(gw_temp), eps_, axis_size, w_stride); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index bac67cf90..67285e884 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -250,6 +250,7 @@ void RoPE::eval_gpu( nvtx3::scoped_range r("RoPE::eval_gpu"); auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); auto& in = inputs[0]; auto& offset = inputs[1]; auto& out = outputs[0]; @@ -291,14 +292,14 @@ void RoPE::eval_gpu( donated = true; out.copy_shared_buffer(in); } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } strides[0] = mat_size; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; } else if (dispatch_ndim == 3) { // Handle non-contiguous 3D inputs - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; @@ -319,7 +320,6 @@ void RoPE::eval_gpu( bool single = 
in.flags().row_contiguous && B == 1 && T == 1; bool with_freqs = inputs.size() == 3; - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(donated ? out : in); encoder.set_input_array(offset); if (with_freqs) { @@ -340,9 +340,9 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), scale_, std::log2(base_), mat_size, @@ -357,10 +357,10 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), scale_, mat_size, dims, @@ -381,10 +381,10 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), scale_, std::log2(base_), strides, @@ -408,9 +408,9 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), scale_, std::log2(base_), strides, diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 825da5cd3..151fca041 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback( grid_dim, block_dim, 0, - q.data(), - k.data(), - v.data(), - o.data(), - sinks ? (*sinks).data() : nullptr, + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + gpu_ptr(o), + sinks ? gpu_ptr(*sinks) : nullptr, params); }); }); @@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback( array sums(intermediate_shape, float32, nullptr, {}); array maxs(std::move(intermediate_shape), float32, nullptr, {}); - intermediate.set_data(allocator::malloc(intermediate.nbytes())); - sums.set_data(allocator::malloc(sums.nbytes())); - maxs.set_data(allocator::malloc(maxs.nbytes())); + intermediate.set_data( + cu::malloc_async(intermediate.nbytes(), encoder.stream())); + sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream())); + maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream())); encoder.add_temporary(intermediate); encoder.add_temporary(sums); @@ -601,13 +602,13 @@ void sdpa_vector_2pass_fallback( grid_dim, block_dim, 0, - q.data(), - k.data(), - v.data(), - sinks ? (*sinks).data() : nullptr, - intermediate.data(), - sums.data(), - maxs.data(), + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + sinks ? 
gpu_ptr(*sinks) : nullptr, + gpu_ptr(intermediate), + gpu_ptr(sums), + gpu_ptr(maxs), params); } @@ -628,10 +629,10 @@ void sdpa_vector_2pass_fallback( grid_dim, block_dim, 0, - intermediate.data(), - sums.data(), - maxs.data(), - o.data(), + gpu_ptr(intermediate), + gpu_ptr(sums), + gpu_ptr(maxs), + gpu_ptr(o), params); } }); @@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu( }; o.set_data( - allocator::malloc(o.nbytes()), + cu::malloc_async(o.nbytes(), encoder.stream()), o.size(), {str_oB, str_oH, str_oL, str_oD}, flags); diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 56d4ae275..d05cdb94a 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto in = inputs[0]; auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); if (in.flags().contiguous && in.strides()[axis_] != 0) { if (in.is_donatable() && in.itemsize() == out.itemsize()) { out.copy_shared_buffer(in); } else { out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), + cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()), in.data_size(), in.strides(), in.flags()); @@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { int32_t axis_size = in.shape(axis_); bool contiguous = in.strides()[axis_] == 1; - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); @@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { in.data_size() / axis_size, block_dim, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size); } else { constexpr int BM = WARP_SIZE; @@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { num_blocks, block_dim, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size, stride, stride_blocks); diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index 93241936b..18cf1e02c 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -23,14 +23,15 @@ void concatenate_gpu( } std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); - out.set_data(allocator::malloc(out.nbytes())); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); auto strides = out.strides(); auto flags = out.flags(); flags.row_contiguous = false; flags.col_contiguous = false; flags.contiguous = false; - auto concurrent = cu::get_command_encoder(s).concurrent_context(); + auto concurrent = encoder.concurrent_context(); for (int i = 0; i < inputs.size(); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); size_t data_offset = strides[axis] * sizes[i]; @@ -80,6 +81,7 @@ array compute_dynamic_offset( return std::make_tuple(false, std::move(source), std::vector{kernel_name}); }); + auto& encoder = cu::get_command_encoder(s); // Prepare output. 
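Several kernels in this change (Scan above, and RMSNormVJP and RoPE earlier) keep the existing pattern of donating the input buffer to the output when possible and only falling back to a fresh, now stream-ordered, allocation otherwise. A toy stand-in for that donation check, using plain shared_ptr buffers rather than MLX arrays:

#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

using Buffer = std::shared_ptr<std::vector<char>>;

// If the caller holds the only reference to the input storage and the element
// sizes match, alias it for the output (the "copy_shared_buffer" branch);
// otherwise allocate new storage. Purely illustrative of the control flow.
Buffer make_output(const Buffer& in, size_t in_itemsize, size_t out_itemsize,
                   size_t out_nbytes) {
  bool donatable = in.use_count() == 1 && in_itemsize == out_itemsize;
  if (donatable) {
    return in;
  }
  return std::make_shared<std::vector<char>>(out_nbytes);
}

int main() {
  auto in = std::make_shared<std::vector<char>>(64);
  auto out = make_output(in, 4, 4, 64);
  std::printf("donated: %d\n", out.get() == in.get());  // 1
}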
array offset({1}, int64, nullptr, {}); bool donate = indices.is_donatable() && @@ -87,10 +89,9 @@ array compute_dynamic_offset( if (donate) { offset.copy_shared_buffer(indices); } else { - offset.set_data(allocator::malloc(offset.itemsize())); + offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream())); } - auto& encoder = cu::get_command_encoder(s); encoder.add_temporary(offset); encoder.set_input_array(indices); encoder.set_output_array(offset); diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index d808bce38..e065d0b89 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Softmax::eval_gpu"); assert(inputs.size() == 1); auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); // Make sure that the last dimension is contiguous. - auto set_output = [&s, &out](const array& x) { + auto set_output = [&s, &out, &encoder](const array& x) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.is_donatable()) { out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc(x.data_size() * x.itemsize()), + cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), x.data_size(), x.strides(), x.flags()); @@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { int axis_size = in.shape().back(); int n_rows = in.data_size() / axis_size; - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { @@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { n_rows, block_dim(), 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size); }); }); diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 95a2d263c..5becbd496 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { array trans = swapaxes_in_eval(in, axis, last_dim); in = contiguous_copy_gpu(trans, s); encoder.add_temporary(in); - out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + out = array( + cu::malloc_async(out.nbytes(), encoder.stream()), + in.shape(), + out.dtype()); encoder.add_temporary(out); } else { out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), + cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()), in.data_size(), in.strides(), in.flags()); @@ -70,22 +73,28 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { thrust::make_counting_iterator(0), OffsetTransform{nsort}); if (argsort) { // Indices in the sorted dimension. - array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + array indices( + cu::malloc_async(out.nbytes(), encoder.stream()), + in.shape(), + out.dtype()); encoder.add_temporary(indices); // In argsort though we don't need the result of sorted values, the // API requires us to provide an array to store it. 
- array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + array discard( + cu::malloc_async(in.nbytes(), encoder.stream()), + in.shape(), + in.dtype()); encoder.add_temporary(discard); size_t size; CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs( nullptr, size, - in.data(), - discard.data(), - indices.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(discard), + gpu_ptr(indices), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, @@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { sizeof(Type) * 8, stream)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); + array temp( + cu::malloc_async(size, encoder.stream()), + {static_cast(size)}, + uint8); encoder.add_temporary(temp); // Start capturing after allocations @@ -103,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { cu::thrust_policy(stream), thrust::counting_iterator(0), thrust::counting_iterator(indices.data_size()), - thrust::device_pointer_cast(indices.data()), + thrust::device_pointer_cast(gpu_ptr(indices)), ModOp{static_cast(nsort)}); CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs( - temp.data(), + gpu_ptr(temp), size, - in.data(), - discard.data(), - indices.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(discard), + gpu_ptr(indices), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, @@ -125,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys( nullptr, size, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, @@ -135,16 +147,19 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { sizeof(Type) * 8, stream)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); + array temp( + cu::malloc_async(size, encoder.stream()), + {static_cast(size)}, + uint8); encoder.add_temporary(temp); // Start capturing after allocations auto capture = encoder.capture_context(); CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys( - temp.data(), + gpu_ptr(temp), size, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 84ae996aa..b1cd99f7d 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -168,10 +168,10 @@ void ternary_op_gpu_inplace( num_blocks, block_dims, 0, - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), out.data_size()); }); } else { @@ -211,10 +211,10 @@ void ternary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -231,10 +231,10 @@ void ternary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -256,7 +256,10 @@ void ternary_op_gpu( auto& b = inputs[1]; auto& c = inputs[2]; auto topt = get_ternary_op_type(a, b, c); - set_ternary_op_output_data(a, b, c, out, topt); + auto& encoder = cu::get_command_encoder(s); + set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); 
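The SortPairs/SortKeys calls above follow CUB's standard two-call protocol: the first call passes a null temporary-storage pointer so CUB only reports how many scratch bytes it needs, those bytes are then allocated (here as a temporary uint8 array from the stream-ordered allocator), and the same call is repeated to do the actual sort. A standalone sketch of that protocol with the simpler, non-segmented radix sort:

#include <cstdio>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

int main() {
  const int n = 8;
  int h_in[n] = {7, 3, 5, 1, 8, 2, 6, 4};
  int *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call: null scratch pointer, CUB just fills in temp_bytes.
  cub::DeviceRadixSort::SortKeys(d_temp, temp_bytes, d_in, d_out, n);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call: identical arguments, now with real scratch space.
  cub::DeviceRadixSort::SortKeys(d_temp, temp_bytes, d_in, d_out, n);

  int h_out[n];
  cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);
  for (int v : h_out) std::printf("%d ", v);
  std::printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp);
  return 0;
}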
ternary_op_gpu_inplace(inputs, out, s); } diff --git a/mlx/backend/cuda/unary/unary.cuh b/mlx/backend/cuda/unary/unary.cuh index 8f4a02d50..2ffbf1eea 100644 --- a/mlx/backend/cuda/unary/unary.cuh +++ b/mlx/backend/cuda/unary/unary.cuh @@ -158,8 +158,8 @@ void unary_op_gpu_inplace( num_blocks, block_dims, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), out.data_size()); } else { using IdxT = std::conditional_t; @@ -182,8 +182,8 @@ void unary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), rest, const_param(shape), const_param(strides), @@ -207,7 +207,10 @@ void unary_op_gpu( array& out, const char* op, const Stream& s) { - set_unary_output_data(inputs[0], out); + auto& encoder = cu::get_command_encoder(s); + set_unary_output_data(inputs[0], out, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); unary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 9e95a84ef..417f7c8aa 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -4,89 +4,12 @@ #pragma once -#include -#include -#include +#include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/cuda_utils.h" namespace mlx::core { -namespace cu { -class Device; - -} - -struct Dtype; - -// Throw exception if the cuda API does not succeed. -void check_cublas_error(const char* name, cublasStatus_t err); -void check_cuda_error(const char* name, cudaError_t err); -void check_cuda_error(const char* name, CUresult err); - -// The macro version that prints the command that failed. -#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) -#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) - -// Convert Dtype to CUDA C++ types. -const char* dtype_to_cuda_type(const Dtype& dtype); - -// Base class for RAII managed CUDA resources. -template -class CudaHandle { - public: - CudaHandle(Handle handle = nullptr) : handle_(handle) {} - - CudaHandle(CudaHandle&& other) : handle_(other.handle_) { - assert(this != &other); - other.handle_ = nullptr; - } - - ~CudaHandle() { - reset(); - } - - CudaHandle(const CudaHandle&) = delete; - CudaHandle& operator=(const CudaHandle&) = delete; - - CudaHandle& operator=(CudaHandle&& other) { - assert(this != &other); - reset(); - std::swap(handle_, other.handle_); - return *this; - } - - void reset() { - if (handle_ != nullptr) { - CHECK_CUDA_ERROR(Destroy(handle_)); - handle_ = nullptr; - } - } - - operator Handle() const { - return handle_; - } - - protected: - Handle handle_; -}; - -// Wrappers of CUDA resources. -class CudaGraph : public CudaHandle { - public: - using CudaHandle::CudaHandle; - explicit CudaGraph(cu::Device& device); - void end_capture(cudaStream_t stream); -}; - -class CudaGraphExec : public CudaHandle { - public: - void instantiate(cudaGraph_t graph); -}; - -class CudaStream : public CudaHandle { - public: - explicit CudaStream(cu::Device& device); -}; - template inline uint max_occupancy_block_dim(T kernel) { int _, block_dim; @@ -100,4 +23,22 @@ inline uint max_occupancy_block_dim(T kernel) { return block_dim; } +template +inline T* gpu_ptr(array& arr) { + return reinterpret_cast( + static_cast( + static_cast(arr.buffer().ptr())->data) + + arr.offset()); +} + +template +inline const T* gpu_ptr(const array& arr) { + return gpu_ptr(const_cast(arr)); +} + +struct Dtype; + +// Convert Dtype to CUDA C++ types. 
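The gpu_ptr helper introduced in utils.h above resolves a typed device pointer from the array's underlying buffer plus its byte offset, replacing the old cached data pointer. A minimal sketch of the same pointer arithmetic with a stand-in buffer type (the real CudaBuffer layout is backend-internal and elided here):

#include <cstdint>
#include <cstdio>

struct FakeDeviceBuffer {  // stand-in for the backend's buffer record
  void* data;
};

template <typename T>
T* typed_ptr(const FakeDeviceBuffer& buf, int64_t byte_offset) {
  // The offset is in bytes, so do the arithmetic on a char* before casting.
  return reinterpret_cast<T*>(static_cast<char*>(buf.data) + byte_offset);
}

int main() {
  float storage[4] = {0.f, 1.f, 2.f, 3.f};
  FakeDeviceBuffer buf{storage};
  float* p = typed_ptr<float>(buf, 2 * sizeof(float));  // view two elements in
  std::printf("%.1f\n", *p);  // 2.0
}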
+const char* dtype_to_cuda_type(const Dtype& dtype); + } // namespace mlx::core diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp index 472ee486b..1ed6e2345 100644 --- a/mlx/backend/gpu/copy.cpp +++ b/mlx/backend/gpu/copy.cpp @@ -7,18 +7,7 @@ namespace mlx::core { -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& in, array& out, CopyType ctype) { copy_gpu(in, out, ctype, out.primitive().stream()); @@ -52,25 +41,6 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) { return arr_copy; } -void reshape_gpu(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) { int ndim = x.ndim(); if (start_axis < 0) { diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index ee40799df..0138928c0 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -83,7 +83,7 @@ void Depends::eval_gpu( void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("DynamicSlice::eval_gpu"); if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } @@ -112,7 +112,7 @@ void DynamicSliceUpdate::eval_gpu( array& out) { MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu"); if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } @@ -209,7 +209,7 @@ void Slice::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Slice::eval_gpu"); assert(inputs.size() == 1); if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } @@ -220,7 +220,7 @@ void Slice::eval_gpu(const std::vector& inputs, array& out) { void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { - out.set_data(nullptr); + out.set_data(allocator::malloc(0)); return; } diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index d9c568e52..6b791289c 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -10,6 +10,19 @@ namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. 
+ return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + void copy_gpu_inplace( const array& in, array& out, @@ -201,4 +214,23 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.dispatch_threads(grid_dims, group_dims); } +void reshape_gpu(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + } // namespace mlx::core diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 6d4d2841d..f9c8b8052 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -261,10 +261,7 @@ void CommandEncoder::set_input_array( needs_barrier_ = needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); auto a_buf = static_cast(a.buffer().ptr()); - auto base_offset = a.data() - - static_cast(const_cast(a_buf)->contents()); - base_offset += offset; - enc_->setBuffer(a_buf, base_offset, idx); + enc_->setBuffer(a_buf, a.offset() + offset, idx); } void CommandEncoder::set_output_array( @@ -448,10 +445,8 @@ void Device::end_encoding(int index) { auto& enc = *stream.encoder; // Remove temporaries from inputs and outputs for (auto& t : stream.temporaries) { - if (t.data() != nullptr) { - enc.outputs().erase(t.buffer().ptr()); - enc.inputs().erase(t.buffer().ptr()); - } + enc.outputs().erase(t.buffer().ptr()); + enc.inputs().erase(t.buffer().ptr()); } // Keep references to the fences we waited on and put them diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 5abdf7309..da73d5d91 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -31,7 +31,7 @@ struct FenceImpl { auto p = metal::new_scoped_memory_pool(); static_cast(fence)->release(); } else { - allocator::free(static_cast(fence)); + allocator::free(allocator::Buffer{static_cast(fence)}); } } bool use_fast{false}; @@ -99,7 +99,7 @@ void Fence::wait(Stream stream, const array& x) { [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } -void Fence::update(Stream stream, const array& x) { +void Fence::update(Stream stream, const array& x, bool cross_device) { auto& f = *static_cast(fence_.get()); f.count++; @@ -130,21 +130,23 @@ void Fence::update(Stream stream, const array& x) { // Launch input visibility kernels auto& compute_encoder = d.get_command_encoder(idx); - auto kernel = d.get_kernel("input_coherent"); - uint32_t nthreads = - (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t); - MTL::Size group_dims = MTL::Size(1024, 1, 1); - MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1); - compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(x, 0); - compute_encoder.set_bytes(nthreads, 1); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + if (cross_device) { + auto kernel = d.get_kernel("input_coherent"); + uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / + sizeof(uint32_t); + MTL::Size group_dims = MTL::Size(1024, 1, 1); + MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_bytes(nthreads, 1); + 
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + } // Barrier on previous kernels compute_encoder.barrier(); // Launch value update kernel - kernel = d.get_kernel("fence_update"); + auto kernel = d.get_kernel("fence_update"); MTL::Size kernel_dims = MTL::Size(1, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index d329a4685..d99e1badb 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -768,9 +768,12 @@ void nd_fft_op( const Stream& s) { // Perform ND FFT on GPU as a series of 1D FFTs auto temp_shape = inverse ? in.shape() : out.shape(); - array temp1(temp_shape, complex64, nullptr, {}); - array temp2(temp_shape, complex64, nullptr, {}); - std::vector temp_arrs = {temp1, temp2}; + std::vector temp_arrs; + temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector{}); + if (axes.size() > 2) { + temp_arrs.emplace_back( + temp_shape, complex64, nullptr, std::vector{}); + } for (int i = axes.size() - 1; i >= 0; i--) { int reverse_index = axes.size() - i - 1; // For 5D and above, we don't want to reallocate our two temporary arrays @@ -781,8 +784,8 @@ void nd_fft_op( // Mirror np.fft.(i)rfftn and perform a real transform // only on the final axis. bool step_real = (real && index == axes.size() - 1); - const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; - array& out_arr = i == 0 ? out : temp_arrs[i % 2]; + const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2]; + array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2]; fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); } diff --git a/mlx/backend/no_gpu/fence.cpp b/mlx/backend/no_gpu/fence.cpp index af22108a8..cd66d23cf 100644 --- a/mlx/backend/no_gpu/fence.cpp +++ b/mlx/backend/no_gpu/fence.cpp @@ -36,7 +36,7 @@ void Fence::wait(Stream stream, const array&) { } } -void Fence::update(Stream stream, const array&) { +void Fence::update(Stream stream, const array&, bool) { auto& f = *static_cast(fence_.get()); f.count++; if (stream.device == Device::cpu) { diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 71fc8b3bd..bb9340ba5 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -300,8 +300,8 @@ class NCCLGroup : public GroupImpl { using T = typename decltype(type_tag)::type; auto& encoder = cu::get_command_encoder(stream); CHECK_NCCL(ncclAllGather( - input.data(), - output.data(), + gpu_ptr(input), + gpu_ptr(output), input.size(), dt, comm_, @@ -348,8 +348,8 @@ class NCCLGroup : public GroupImpl { auto& encoder = cu::get_command_encoder(stream); CHECK_NCCL(ncclAllReduce( - input.data(), - output.data(), + gpu_ptr(input), + gpu_ptr(output), input.size(), dt, op, @@ -367,8 +367,8 @@ class NCCLGroup : public GroupImpl { auto& encoder = cu::get_command_encoder(stream); CHECK_NCCL(ncclReduceScatter( - input.data(), - output.data(), + gpu_ptr(input), + gpu_ptr(output), output.size(), dt, op, diff --git a/mlx/fence.h b/mlx/fence.h index 66c2b5f49..0ececdb6d 100644 --- a/mlx/fence.h +++ b/mlx/fence.h @@ -15,7 +15,7 @@ namespace mlx::core { * `wait` returns. The array passed to `wait` will not be read until all * previous calls to `update` have completed. * - * Note, calls to `update` should always from the same thread or explicitly + * Note, calls to `update` should always be from the same thread or explicitly * synchronized so that they occur in sequence. Calls to `wait` can be on any * thread. 
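nd_fft_op above now allocates only as many complex scratch buffers as the pass structure needs (one for a 2D transform, two for 3D and above) and ping-pongs between them, so each 1D pass reads the buffer the previous pass wrote. A small standalone sketch of that indexing, with a trivial element-wise op standing in for the per-axis FFT:

#include <cstdio>
#include <vector>

int main() {
  int n_axes = 3;
  std::vector<int> in{1, 2, 3}, out(3);
  // One scratch buffer suffices for two axes; three or more need two.
  std::vector<std::vector<int>> temps(n_axes > 2 ? 2 : 1, std::vector<int>(3));
  for (int i = n_axes - 1; i >= 0; --i) {
    auto& src = (i == n_axes - 1) ? in : temps[i % 2];
    auto& dst = (i == 0) ? out : temps[1 - i % 2];
    for (size_t j = 0; j < src.size(); ++j) {
      dst[j] = src[j] + 1;  // stand-in for a 1D FFT along axis i
    }
  }
  std::printf("%d %d %d\n", out[0], out[1], out[2]);  // 4 5 6
}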
* @@ -29,7 +29,7 @@ class Fence { Fence() {}; explicit Fence(Stream stream); - void update(Stream stream, const array& x); + void update(Stream stream, const array& x, bool cross_device); void wait(Stream stream, const array& x); private: diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index f3d71e860..92366be78 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -13,6 +13,7 @@ #include #endif // _WIN32 +#include "mlx/backend/cuda/cuda.h" #include "mlx/io/load.h" #include "mlx/ops.h" #include "mlx/primitives.h" @@ -226,10 +227,7 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { throw std::runtime_error("[load] Failed to open " + in_stream->label()); } - auto stream = to_stream(s, Device::cpu); - if (stream.device != Device::cpu) { - throw std::runtime_error("[load] Must run on a CPU stream."); - } + auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu); //////////////////////////////////////////////////////// // Read header and prepare array details diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index d9b9e9e40..738370f35 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -4,6 +4,7 @@ #include #include +#include "mlx/backend/cuda/cuda.h" #include "mlx/io.h" #include "mlx/io/load.h" #include "mlx/ops.h" @@ -113,10 +114,7 @@ SafetensorsLoad load_safetensors( "[load_safetensors] Failed to open " + in_stream->label()); } - auto stream = to_stream(s, Device::cpu); - if (stream.device != Device::cpu) { - throw std::runtime_error("[load_safetensors] Must run on a CPU stream."); - } + auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu); uint64_t jsonHeaderLength = 0; // This is the same limit as in the original Rust Safetensors code. diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 3ec64feea..4967c50a8 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -62,7 +62,7 @@ array eval_impl(std::vector outputs, bool async) { } // Map of array id that needs fence and stream it's computed on - std::unordered_map needs_fence; + std::unordered_map> needs_fence; auto synchronizer = array( {}, bool_, std::make_shared(stream), std::move(outputs)); @@ -114,7 +114,14 @@ array eval_impl(std::vector outputs, bool async) { "https://github.com/ml-explore/mlx/issues."); } if (a.primitive().stream() != in.primitive().stream()) { - needs_fence.emplace(in.id(), in.primitive().stream().index); + bool device_switch = + a.primitive().stream().device != in.primitive().stream().device; + auto [it, inserted] = needs_fence.emplace( + in.id(), + std::make_pair(in.primitive().stream().index, device_switch)); + if (!inserted) { + it->second.second |= device_switch; + } } } @@ -190,7 +197,6 @@ array eval_impl(std::vector outputs, bool async) { } std::unordered_set open_streams; - while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); @@ -216,7 +222,7 @@ array eval_impl(std::vector outputs, bool async) { // Use fence to wait within a single eval // Get the input array's stream fence and wait on the // output arrays stream - fences[it->second].wait(stream, in); + fences[it->second.first].wait(stream, in); } else if (in.event().valid()) { if (in.event().is_signaled()) { in.detach_event(); @@ -251,12 +257,12 @@ array eval_impl(std::vector outputs, bool async) { } auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) { - if (needs_fence.find(a.id()) != needs_fence.end()) { + if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) { auto it = fences.find(stream.index); if (it == 
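In transforms.cpp above, needs_fence now maps each array id to a pair of (producing stream index, cross-device flag), and repeated insertions OR the flag in so it stays set once any consumer lives on a different device. A standalone sketch of that emplace-or-merge pattern:

#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <utility>

int main() {
  std::unordered_map<std::uintptr_t, std::pair<int, bool>> needs_fence;
  auto note = [&](std::uintptr_t id, int stream_index, bool device_switch) {
    auto [it, inserted] =
        needs_fence.emplace(id, std::make_pair(stream_index, device_switch));
    if (!inserted) {
      it->second.second |= device_switch;  // flag is sticky across consumers
    }
  };
  note(1, 0, false);  // first consumer: same device
  note(1, 0, true);   // later consumer on another device: flag sticks
  std::printf("cross-device: %d\n", needs_fence.at(1).second);  // 1
}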
fences.end()) { it = fences.emplace(stream.index, Fence{stream}).first; } - it->second.update(stream, a); + it->second.update(stream, a, nf->second.second); } };