From d378567cc66b7055f6625f1ea70d77c024fa1311 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 31 Oct 2025 14:12:15 -0700 Subject: [PATCH] refactor for regular cuda malloc --- mlx/array.cpp | 11 +++-- mlx/array.h | 9 +++-- mlx/backend/common/binary.h | 14 +++---- mlx/backend/common/compiled.cpp | 8 ++-- mlx/backend/common/compiled.h | 4 +- mlx/backend/common/copy.h | 10 +++-- mlx/backend/common/ternary.h | 10 ++--- mlx/backend/common/unary.h | 9 +++-- mlx/backend/cuda/allocator.cpp | 31 +++++++++----- mlx/backend/cuda/allocator.h | 6 +++ mlx/backend/cuda/arange.cu | 2 +- mlx/backend/cuda/arg_reduce.cu | 4 +- mlx/backend/cuda/binary/binary.cuh | 24 ++++++----- mlx/backend/cuda/binary_two.cu | 34 +++++++++------- mlx/backend/cuda/conv.cpp | 7 ++-- mlx/backend/cuda/conv/gemm_conv.cu | 6 +-- mlx/backend/cuda/conv/gemm_grouped_conv.cu | 6 +-- mlx/backend/cuda/copy.cu | 16 ++++++++ mlx/backend/cuda/copy/copy_contiguous.cu | 4 +- mlx/backend/cuda/copy/copy_general.cu | 4 +- mlx/backend/cuda/copy/copy_general_dynamic.cu | 12 +++--- mlx/backend/cuda/copy/copy_general_input.cu | 4 +- mlx/backend/cuda/cudnn_utils.cpp | 4 +- mlx/backend/cuda/cudnn_utils.h | 5 ++- mlx/backend/cuda/custom_kernel.cpp | 4 +- mlx/backend/cuda/distributed.cu | 10 ++--- mlx/backend/cuda/gemms/cublas_gemm.cpp | 18 ++++----- .../cuda/gemms/cublas_gemm_batched_12_0.cpp | 16 ++++---- .../cuda/gemms/cublas_gemm_batched_12_9.cu | 40 +++++++++---------- mlx/backend/cuda/gemms/gemv.cu | 12 +++--- mlx/backend/cuda/indexing.cpp | 2 +- mlx/backend/cuda/jit_module.h | 2 +- mlx/backend/cuda/kernel_utils.cuh | 1 + mlx/backend/cuda/layer_norm.cu | 18 ++++----- mlx/backend/cuda/logsumexp.cu | 4 +- mlx/backend/cuda/quantized/affine_quantize.cu | 16 ++++---- mlx/backend/cuda/quantized/fp_quantize.cu | 12 +++--- mlx/backend/cuda/quantized/quantized.cpp | 6 +-- mlx/backend/cuda/random.cu | 14 +++---- mlx/backend/cuda/reduce/all_reduce.cu | 13 +++--- mlx/backend/cuda/reduce/col_reduce.cu | 12 +++--- mlx/backend/cuda/reduce/init_reduce.cu | 4 +- mlx/backend/cuda/reduce/reduce_utils.cuh | 8 ++-- mlx/backend/cuda/reduce/row_reduce.cu | 10 ++--- mlx/backend/cuda/rms_norm.cu | 16 ++++---- mlx/backend/cuda/rope.cu | 28 ++++++------- .../cuda/scaled_dot_product_attention.cu | 32 +++++++-------- mlx/backend/cuda/scan.cu | 8 ++-- mlx/backend/cuda/softmax.cu | 4 +- mlx/backend/cuda/sort.cu | 30 +++++++------- mlx/backend/cuda/ternary.cu | 29 ++++++++------ mlx/backend/cuda/unary/unary.cuh | 13 +++--- mlx/backend/cuda/utils.h | 12 ++++++ mlx/backend/gpu/copy.cpp | 13 +----- mlx/backend/metal/copy.cpp | 13 ++++++ 55 files changed, 370 insertions(+), 294 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index a05e8dfa7..a8c77d150 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) { other.strides(), other.flags(), [](auto) {}); - cpy.array_desc_->data_ptr = other.array_desc_->data_ptr; + cpy.array_desc_->offset = other.array_desc_->offset; return cpy; } @@ -141,7 +141,7 @@ bool array::is_tracer() const { void array::set_data(allocator::Buffer buffer, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->data_ptr = buffer.raw_ptr(); + array_desc_->offset = 0; array_desc_->data_size = size(); array_desc_->flags.contiguous = true; array_desc_->flags.row_contiguous = true; @@ -156,7 +156,7 @@ void array::set_data( Flags flags, Deleter d) { array_desc_->data = std::make_shared(buffer, d); - array_desc_->data_ptr = buffer.raw_ptr(); + array_desc_->offset = 0; 
array_desc_->data_size = data_size; array_desc_->strides = std::move(strides); array_desc_->flags = flags; @@ -172,9 +172,8 @@ void array::copy_shared_buffer( array_desc_->strides = strides; array_desc_->flags = flags; array_desc_->data_size = data_size; - auto char_offset = sizeof(char) * itemsize() * offset; - array_desc_->data_ptr = static_cast( - static_cast(other.array_desc_->data_ptr) + char_offset); + array_desc_->offset = + sizeof(char) * itemsize() * offset + other.array_desc_->offset; } void array::copy_shared_buffer(const array& other) { diff --git a/mlx/array.h b/mlx/array.h index 4e9a5ae63..b7ae8996c 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -352,12 +352,13 @@ class array { // Return a raw pointer to the arrays data template T* data() { - return static_cast(array_desc_->data_ptr); + return reinterpret_cast( + (static_cast(buffer().raw_ptr()) + array_desc_->offset)); } template const T* data() const { - return static_cast(array_desc_->data_ptr); + return const_cast(*this).data(); } enum Status { @@ -461,8 +462,8 @@ class array { // can share the underlying data buffer. std::shared_ptr data; - // Properly offset data pointer - void* data_ptr{nullptr}; + // Offset from beginning of data pointer + int64_t offset{0}; // The size in elements of the data buffer the array accesses size_t data_size; diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index ac6e1891d..78607ef07 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -38,20 +38,20 @@ inline void set_binary_op_output_data( const array& a, const array& b, array& out, - BinaryOpType bopt) { + BinaryOpType bopt, + std::function mallocfn = allocator::malloc) { bool b_donatable = is_donatable(b, out); bool a_donatable = is_donatable(a, out); switch (bopt) { case BinaryOpType::ScalarScalar: - out.set_data( - allocator::malloc(out.itemsize()), 1, a.strides(), a.flags()); + out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags()); break; case BinaryOpType::ScalarVector: if (b_donatable) { out.copy_shared_buffer(b); } else { out.set_data( - allocator::malloc(b.data_size() * out.itemsize()), + mallocfn(b.data_size() * out.itemsize()), b.data_size(), b.strides(), b.flags()); @@ -62,7 +62,7 @@ inline void set_binary_op_output_data( out.copy_shared_buffer(a); } else { out.set_data( - allocator::malloc(a.data_size() * out.itemsize()), + mallocfn(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); @@ -75,7 +75,7 @@ inline void set_binary_op_output_data( out.copy_shared_buffer(b); } else { out.set_data( - allocator::malloc(a.data_size() * out.itemsize()), + mallocfn(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); @@ -88,7 +88,7 @@ inline void set_binary_op_output_data( b_donatable && b.flags().row_contiguous && b.size() == out.size()) { out.copy_shared_buffer(b); } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); } break; } diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 44e2a432b..e5e1d4350 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -114,7 +114,9 @@ void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, const std::function& is_constant, - bool contiguous) { + bool contiguous, + const std::function& + mallocfn /* = allocator::malloc */) { if (contiguous) { int o = 0; Strides strides; @@ -140,7 +142,7 @@ void compiled_allocate_outputs( } for (; o < outputs.size(); ++o) { 
outputs[o].set_data( - allocator::malloc(data_size * outputs[o].itemsize()), + mallocfn(data_size * outputs[o].itemsize()), data_size, strides, flags); @@ -163,7 +165,7 @@ void compiled_allocate_outputs( } } for (; o < outputs.size(); ++o) { - outputs[o].set_data(allocator::malloc(outputs[o].nbytes())); + outputs[o].set_data(mallocfn(outputs[o].nbytes())); } } } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index e92a6d0ad..3be371333 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -58,7 +58,9 @@ void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, const std::function& is_constant, - bool contiguous); + bool contiguous, + const std::function& mallocfn = + allocator::malloc); // Collapse contiguous dims ignoring scalars and constants. std::tuple> compiled_collapse_contiguous_dims( diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index c23d2e79a..859ce0410 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -22,7 +22,11 @@ enum class CopyType { GeneralGeneral }; -inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { +inline bool set_copy_output_data( + const array& in, + array& out, + CopyType ctype, + std::function mallocfn = allocator::malloc) { if (ctype == CopyType::Vector) { // If the input is donateable, we are doing a vector copy and the types // have the same size, then the input buffer can hold the output. @@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { return true; } else { out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), + mallocfn(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); return false; } } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); return false; } } diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index 233708ec3..c63a57261 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -46,7 +46,8 @@ inline void set_ternary_op_output_data( const array& b, const array& c, array& out, - TernaryOpType topt) { + TernaryOpType topt, + std::function mallocfn = allocator::malloc) { auto maybe_donate = [&out](const array& x) { if (is_donatable(x, out)) { out.copy_shared_buffer(x); @@ -57,13 +58,12 @@ inline void set_ternary_op_output_data( switch (topt) { case TernaryOpType::ScalarScalarScalar: - out.set_data( - allocator::malloc(out.itemsize()), 1, b.strides(), b.flags()); + out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags()); break; case TernaryOpType::VectorVectorVector: if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { out.set_data( - allocator::malloc(out.itemsize() * b.data_size()), + mallocfn(out.itemsize() * b.data_size()), b.data_size(), b.strides(), b.flags()); @@ -76,7 +76,7 @@ inline void set_ternary_op_output_data( if (!((a.flags().row_contiguous && maybe_donate(a)) || (b.flags().row_contiguous && maybe_donate(b)) || (c.flags().row_contiguous && maybe_donate(c)))) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); } break; } diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index a27a1f45c..b19fc98ed 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -7,19 +7,22 @@ namespace mlx::core { -inline void set_unary_output_data(const array& in, array& out) { +inline void set_unary_output_data( + const array& in, + array& 
out, + std::function mallocfn = allocator::malloc) { if (in.flags().contiguous) { if (is_donatable(in, out)) { out.copy_shared_buffer(in); } else { out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), + mallocfn(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); } } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(mallocfn(out.nbytes())); } } diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 6ef2da907..dcc82a7f2 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -68,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() { next_free_ = next_free_->next; b->buf.data = static_cast(data_) + i * small_block_size; b->buf.size = small_block_size; + b->buf.managed = true; return &b->buf; } @@ -94,16 +95,13 @@ CudaAllocator::CudaAllocator() CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); memory_limit_ = total * 0.95; max_pool_size_ = memory_limit_; -#if CUDART_VERSION >= 13000 - cudaMemLocation loc; - loc.id = 0; - loc.type = cudaMemLocationTypeNone; - cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged); + int loc = 0; + cudaDeviceGetDefaultMemPool(&cuda_pool_, loc); + // TODO need a strategy for that uint64_t threshold = UINT64_MAX; cudaMemPoolSetAttribute( cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold); -#endif } Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) { @@ -133,12 +131,13 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) { } lock.unlock(); if (!buf) { - buf = new CudaBuffer{nullptr, size}; + bool managed = stream == nullptr; + buf = new CudaBuffer{nullptr, size, managed}; cudaError_t err; - if (stream != nullptr && cuda_pool_ != nullptr) { - err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream); - } else { + if (managed) { err = cudaMallocManaged(&buf->data, size); + } else { + err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream); } if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { throw std::runtime_error(fmt::format( @@ -266,7 +265,17 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - return static_cast(ptr_)->data; + auto& cbuf = *static_cast(ptr_); + if (!cbuf.managed) { + void* new_data; + CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size)); + cbuf.managed = true; + CHECK_CUDA_ERROR( + cudaMemcpy(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault)); + CHECK_CUDA_ERROR(cudaFree(cbuf.data)); + cbuf.data = new_data; + } + return cbuf.data; } } // namespace allocator diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h index eb69db3b4..a0f7d38db 100644 --- a/mlx/backend/cuda/allocator.h +++ b/mlx/backend/cuda/allocator.h @@ -18,8 +18,14 @@ using allocator::Buffer; struct CudaBuffer { void* data; size_t size; + bool managed; }; +template +T* gpu_ptr(Buffer buf) { + return static_cast(static_cast(buf.ptr())->data); +} + class SmallSizePool { private: union Block { diff --git a/mlx/backend/cuda/arange.cu b/mlx/backend/cuda/arange.cu index 1c1f96a72..a8e406ba1 100644 --- a/mlx/backend/cuda/arange.cu +++ b/mlx/backend/cuda/arange.cu @@ -57,7 +57,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { num_blocks, block_dims, 0, - out.data(), + gpu_ptr(out), out.data_size(), static_cast(start_), static_cast(start_ + step_) - static_cast(start_)); diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 04bfe33ce..8dd77cc73 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ 
b/mlx/backend/cuda/arg_reduce.cu @@ -173,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { num_blocks, block_dim(), 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), out.size(), const_param(shape), const_param(in_strides), diff --git a/mlx/backend/cuda/binary/binary.cuh b/mlx/backend/cuda/binary/binary.cuh index 20bb199ec..47847e4d9 100644 --- a/mlx/backend/cuda/binary/binary.cuh +++ b/mlx/backend/cuda/binary/binary.cuh @@ -292,9 +292,9 @@ void binary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -310,9 +310,9 @@ void binary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -339,9 +339,9 @@ void binary_op_gpu_inplace( num_blocks, block_dims, 0, - a.data(), - b.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), out.data_size()); }); } @@ -365,7 +365,11 @@ void binary_op_gpu( auto& a = inputs[0]; auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); + auto& encoder = cu::get_command_encoder(s); + + set_binary_op_output_data(a, b, out, bopt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); binary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index cd0fe2c46..edec2678f 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace( auto& out_a = outputs[0]; auto& out_b = outputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out_a, bopt); - set_binary_op_output_data(a, b, out_b, bopt); + auto& encoder = cu::get_command_encoder(s); + set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); + set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); if (out_a.size() == 0) { return; } - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out_a); @@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out_a.data(), - out_b.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out_a), + gpu_ptr(out_b), rest, const_param(shape), const_param(a_strides), @@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - out_a.data(), - out_b.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out_a), + gpu_ptr(out_b), rest, const_param(shape), const_param(a_strides), @@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace( num_blocks, block_dims, 0, - a.data(), - b.data(), - out_a.data(), - out_b.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out_a), + gpu_ptr(out_b), out_a.data_size()); }); } diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index e65de63e0..4f7a971ca 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { if (out_.size() == 0) { return; } + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 2); array in = inputs[0]; array wt = inputs[1]; array out = out_; - 
out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); Dtype dtype = out.dtype(); - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); - // Search cache. ConvCacheKey cache_key{ encoder.device().cuda_device(), diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu index 11a78a7ab..b2c151706 100644 --- a/mlx/backend/cuda/conv/gemm_conv.cu +++ b/mlx/backend/cuda/conv/gemm_conv.cu @@ -86,7 +86,7 @@ array unfold_inputs_nd( int mat_N, ConvParams& params) { array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {}); - unfolded.set_data(allocator::malloc(unfolded.nbytes())); + unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream())); encoder.add_temporary(unfolded); int filter_size = params.C; @@ -118,8 +118,8 @@ array unfold_inputs_nd( num_blocks, block_dims, 0, - in.data(), - unfolded.data(), + gpu_ptr(in), + gpu_ptr(unfolded), filter_size, out_pixels, params); diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu index 7ceb58166..788c640ff 100644 --- a/mlx/backend/cuda/conv/gemm_grouped_conv.cu +++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu @@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd( int mat_N, ConvParams& params) { array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {}); - unfolded.set_data(allocator::malloc(unfolded.nbytes())); + unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream())); encoder.add_temporary(unfolded); int filter_size = params.C; @@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd( num_blocks, block_dims, 0, - in.data(), - unfolded.data(), + gpu_ptr(in), + gpu_ptr(unfolded), filter_size, out_pixels, params); diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 17e414f73..e36bcbdf6 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -5,6 +5,22 @@ namespace mlx::core { +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + auto& encoder = cu::get_command_encoder(s); + bool donated = set_copy_output_data(in, out, ctype, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. 
+ return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + void copy_gpu_inplace( const array& in, array& out, diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 9c2aa9838..6f51734b1 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -77,8 +77,8 @@ void copy_contiguous( num_blocks, block_dims, 0, - in.data() + in_offset, - out.data() + out_offset, + gpu_ptr(in) + in_offset, + gpu_ptr(out) + out_offset, out.data_size()); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 6ac42751a..fec110a86 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -106,8 +106,8 @@ void copy_general( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; + const InType* in_ptr = gpu_ptr(in) + offset_in; + OutType* out_ptr = gpu_ptr(out) + offset_out; int ndim = shape.size(); size_t data_size = 1; for (auto& s : shape) diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 7a7f0dca5..96cf6fb5c 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -69,8 +69,8 @@ void copy_general_dynamic( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; + const InType* in_ptr = gpu_ptr(in) + offset_in; + OutType* out_ptr = gpu_ptr(out) + offset_out; int ndim = shape.size(); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { @@ -90,8 +90,8 @@ void copy_general_dynamic( const_param(shape), const_param(strides_in), const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); + gpu_ptr(dynamic_offset_in), + gpu_ptr(dynamic_offset_out)); }); } else { // ndim >= 4 auto [num_blocks, block_dims] = get_launch_args(out, large()); @@ -107,8 +107,8 @@ void copy_general_dynamic( const_param(strides_in), const_param(strides_out), ndim, - dynamic_offset_in.data(), - dynamic_offset_out.data()); + gpu_ptr(dynamic_offset_in), + gpu_ptr(dynamic_offset_out)); } }); }); diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index ce8bb1b78..42a027ec5 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -92,8 +92,8 @@ void copy_general_input( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; + const InType* in_ptr = gpu_ptr(in) + offset_in; + OutType* out_ptr = gpu_ptr(out) + offset_out; int ndim = shape.size(); int work_per_thread = 1; auto dim0 = ndim > 0 ? shape.back() : 1; diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp index 20280f2be..e93778056 100644 --- a/mlx/backend/cuda/cudnn_utils.cpp +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -133,13 +133,13 @@ bool prepare_cudnn_plan( F&& execute) { int workspace_size = plan.getWorkspaceSize(); array workspace( - workspace_size > 0 ? allocator::malloc(workspace_size) + workspace_size > 0 ? 
cu::malloc_async(workspace_size, encoder.stream()) : allocator::Buffer(nullptr), {workspace_size}, uint8); auto args = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace.data()) + .setWorkspacePointer(gpu_ptr(workspace)) .setDataPointers(num_args, data_ptrs) .setUids(num_args, uids) .build(); diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h index c35c5cac9..b28249678 100644 --- a/mlx/backend/cuda/cudnn_utils.h +++ b/mlx/backend/cuda/cudnn_utils.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/utils.h" #include "mlx/dtype_utils.h" @@ -23,7 +24,7 @@ class CommandEncoder; // Return pointer alignment of |x|'s data. inline uint8_t get_alignment(const array& x) { uint8_t alignment = 1; - uintptr_t address = reinterpret_cast(x.data()); + uintptr_t address = reinterpret_cast(gpu_ptr(x)); for (; alignment < 32; alignment *= 2) { if (address % (alignment * 2)) { return alignment; @@ -56,7 +57,7 @@ inline std::array vector_key(const Vec& vec) { // Helpers used by get_data_ptrs to get pointers. inline void* get_data_ptr(const array& arr) { - return const_cast(arr.data()); + return const_cast(gpu_ptr(arr)); } template >> diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index ee1778fd8..958516746 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -279,6 +279,7 @@ void CustomKernel::eval_gpu( std::vector& outputs) { nvtx3::scoped_range r("CustomKernel::eval_gpu"); auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); std::vector copies; @@ -288,7 +289,7 @@ void CustomKernel::eval_gpu( copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } } @@ -356,7 +357,6 @@ void CustomKernel::eval_gpu( dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); // Call the kernel - auto& encoder = cu::get_command_encoder(s); for (const auto& in : checked_inputs) { encoder.set_input_array(in); } diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu index 4d2658534..309ed423c 100644 --- a/mlx/backend/cuda/distributed.cu +++ b/mlx/backend/cuda/distributed.cu @@ -15,8 +15,10 @@ void AllReduce::eval_gpu( assert(inputs.size() == 1); assert(outputs.size() == 1); - auto set_input_output = - [s = stream()](const array& in, array& out) -> std::pair { + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + auto set_input_output = [&](const array& in, + array& out) -> std::pair { if (!in.flags().row_contiguous) { copy_gpu(in, out, CopyType::General, s); return {out, out}; @@ -24,19 +26,17 @@ void AllReduce::eval_gpu( out.copy_shared_buffer(in); return {in, out}; } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); return {in, out}; } }; auto [input, output] = set_input_output(inputs[0], outputs[0]); - auto& encoder = cu::get_command_encoder(stream()); encoder.set_input_array(input); encoder.set_output_array(output); auto capture = encoder.capture_context(); - auto& s = stream(); switch (reduce_type_) { case Sum: diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 041784913..60ca2ccae 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ 
b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -241,7 +241,7 @@ void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) { CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); - auto* bias_ptr = bias.data(); + auto* bias_ptr = gpu_ptr(bias); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, @@ -278,9 +278,9 @@ void CublasGemm::run( execute( encoder, - out.data(), - a.data(), - b.data(), + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), nullptr, alpha); } @@ -321,10 +321,10 @@ void CublasGemm::run( execute( encoder, - out.data(), - a.data(), - b.data(), - c.data(), + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), alpha, beta); } @@ -374,7 +374,7 @@ void CublasGemm::execute( {static_cast(heuristic_.workspaceSize)}, int8); encoder.add_temporary(workspace); - workspace_ptr = workspace.data(); + workspace_ptr = gpu_ptr(workspace); } auto capture = encoder.capture_context(); diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp index 70df21fda..67b8c2754 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp @@ -25,9 +25,10 @@ void CublasGemm::run_batched( for (size_t i = 0; i < nbatch; ++i) { execute( encoder, - out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_, - a.data() + a.itemsize() * a_it.loc, - b.data() + b.itemsize() * b_it.loc, + gpu_ptr(out) + + out.itemsize() * i * batch_shape.back() * M_ * N_, + gpu_ptr(a) + a.itemsize() * a_it.loc, + gpu_ptr(b) + b.itemsize() * b_it.loc, nullptr, alpha); a_it.step(); @@ -60,10 +61,11 @@ void CublasGemm::run_batched( for (size_t i = 0; i < nbatch; ++i) { execute( encoder, - out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_, - a.data() + a.itemsize() * a_it.loc, - b.data() + b.itemsize() * b_it.loc, - c.data() + c.itemsize() * c_it.loc, + gpu_ptr(out) + + out.itemsize() * i * batch_shape.back() * M_ * N_, + gpu_ptr(a) + a.itemsize() * a_it.loc, + gpu_ptr(b) + b.itemsize() * b_it.loc, + gpu_ptr(c) + c.itemsize() * c_it.loc, alpha, beta); a_it.step(); diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu index 03aca99d6..c3b6fd379 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu @@ -183,10 +183,10 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -200,10 +200,10 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -219,7 +219,7 @@ void CublasGemm::run_batched( encoder.set_input_array(b); encoder.set_output_array(out); - auto a_pointers = pointers.data(); + auto a_pointers = gpu_ptr(pointers); auto b_pointers = a_pointers + batch_count; auto out_pointers = b_pointers + batch_count; execute( @@ -271,11 +271,11 @@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -290,11 +290,11 
@@ void CublasGemm::run_batched( num_blocks, block_dims, 0, - pointers.data(), - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(pointers), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), item_size, const_param(batch_shape), const_param(a_batch_strides), @@ -312,7 +312,7 @@ void CublasGemm::run_batched( encoder.set_input_array(c); encoder.set_output_array(out); - auto a_pointers = pointers.data(); + auto a_pointers = gpu_ptr(pointers); auto b_pointers = a_pointers + batch_count; auto c_pointers = b_pointers + batch_count; auto out_pointers = c_pointers + batch_count; diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu index e1c755039..6966a583b 100644 --- a/mlx/backend/cuda/gemms/gemv.cu +++ b/mlx/backend/cuda/gemms/gemv.cu @@ -149,13 +149,13 @@ void gemv( auto vec_strides = const_param(b_batch_strides); if (M == 1) { - mat = b.data(); - vec = a.data(); + mat = gpu_ptr(b); + vec = gpu_ptr(a); rows = N; std::swap(mat_strides, vec_strides); } else { - mat = a.data(); - vec = b.data(); + mat = gpu_ptr(a); + vec = gpu_ptr(b); rows = M; } uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; @@ -177,7 +177,7 @@ void gemv( 0, mat, vec, - out.data(), + gpu_ptr(out), rows, cols); } else { @@ -189,7 +189,7 @@ void gemv( 0, mat, vec, - out.data(), + gpu_ptr(out), rows, cols, const_param(batch_shape), diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 762a101c2..5c468d2ba 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -31,7 +31,7 @@ void append_indices_arg( int idx_ndim) { SmallVector indices(nidx); for (int i = 0; i < nidx; ++i) { - indices[i] = inputs[i + 1].data(); + indices[i] = gpu_ptr(inputs[i + 1]); } args.append(std::move(indices)); SmallVector indices_shape(nidx * idx_ndim); diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index e2fd0c8b8..801a699d5 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -31,7 +31,7 @@ struct KernelArgs { } void append(const array& a) { - append(reinterpret_cast(a.data())); + append(reinterpret_cast(gpu_ptr(a.buffer()))); } template diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index d889cd590..bd5f7ef80 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -9,6 +9,7 @@ #include #include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device/utils.cuh" #include diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index c1dc2833a..30cc6b837 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -280,10 +280,10 @@ void LayerNorm::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - b.data(), - out.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(b), + gpu_ptr(out), eps_, axis_size, w_stride, @@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(g), + gpu_ptr(gx), + gpu_ptr(gw_temp), eps_, axis_size, w_stride); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index e7641c05d..a3841a957 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { n_rows, block_dim(), 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size); }); }); diff --git 
a/mlx/backend/cuda/quantized/affine_quantize.cu b/mlx/backend/cuda/quantized/affine_quantize.cu index a64597a88..da4b34135 100644 --- a/mlx/backend/cuda/quantized/affine_quantize.cu +++ b/mlx/backend/cuda/quantized/affine_quantize.cu @@ -262,10 +262,10 @@ void affine_quantize( num_blocks, block_dims, 0, - w.data(), - wq.data(), - scales.data(), - biases.data(), + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + gpu_ptr(biases), w.size()); }); }); @@ -318,10 +318,10 @@ void affine_dequantize( num_blocks, block_dims, 0, - wq.data(), - scales.data(), - biases.data(), - w.data(), + gpu_ptr(wq), + gpu_ptr(scales), + gpu_ptr(biases), + gpu_ptr(w), w.size()); }); }); diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 0f979dfb0..45c61baf5 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -156,9 +156,9 @@ void fp_quantize( num_blocks, block_dims, 0, - w.data(), - wq.data(), - scales.data(), + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), w.size()); } else { throw std::runtime_error( @@ -202,9 +202,9 @@ void fp_dequantize( num_blocks, block_dims, 0, - wq.data(), - scales.data(), - w.data(), + gpu_ptr(wq), + gpu_ptr(scales), + gpu_ptr(w), w.size()); } else { throw std::runtime_error( diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 58710834f..f221abc50 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu( auto scales = ensure_row_contiguous(inputs[1], enc, s); auto& w = outputs[0]; - w.set_data(allocator::malloc(w.nbytes())); + w.set_data(cu::malloc_async(w.nbytes(), enc.stream())); if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous(inputs[2], enc, s); @@ -72,8 +72,8 @@ void fast::Quantize::eval_gpu( auto& wq = outputs[0]; auto& scales = outputs[1]; - wq.set_data(allocator::malloc(wq.nbytes())); - scales.set_data(allocator::malloc(scales.nbytes())); + wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream())); + scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream())); if (mode_ == QuantizationMode::Affine) { auto& biases = outputs[2]; biases.set_data(allocator::malloc(biases.nbytes())); diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 26a3eb8b7..68dcacdc7 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -143,7 +143,9 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { uint32_t elems_per_key = out.size() / num_keys; uint32_t bytes_per_key = out.itemsize() * elems_per_key; - out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); if (out.size() == 0) { return; } @@ -152,8 +154,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { uint32_t half_size = out_per_key / 2; bool odd = out_per_key % 2; - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(keys); encoder.set_output_array(out); dim3 grid_dims{num_keys, half_size + odd}; @@ -171,8 +171,8 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { grid, block, 0, - keys.data(), - out.data(), + gpu_ptr(keys), + gpu_ptr(out), grid_dims, odd, bytes_per_key); @@ -182,8 +182,8 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { grid, block, 0, - keys.data(), - out.data(), + gpu_ptr(keys), + 
gpu_ptr(out), grid_dims, odd, bytes_per_key, diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index b815597bd..685bbefa8 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -66,7 +66,7 @@ void all_reduce( Reduce::ReduceType reduce_type) { constexpr int N_READS = 8; - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); auto get_args = [](size_t size, int N) { int threads = std::min(512UL, (size + N - 1) / N); @@ -100,14 +100,15 @@ void all_reduce( Dtype dt = in.dtype(); // Cub doesn't like const pointers for load (sigh). - void* indata = const_cast(in.data()); + void* indata = const_cast(gpu_ptr(in)); // Large array so allocate an intermediate and accumulate there std::tie(blocks, threads, block_step) = get_args(insize, N_READS); encoder.set_input_array(in); if (blocks > 1) { array intermediate({blocks}, out.dtype(), nullptr, {}); - intermediate.set_data(allocator::malloc(intermediate.nbytes())); + intermediate.set_data( + cu::malloc_async(intermediate.nbytes(), encoder.stream())); encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); dispatch_all_types(dt, [&](auto type_tag) { @@ -122,14 +123,14 @@ void all_reduce( threads, 0, static_cast(indata), - intermediate.data(), + gpu_ptr(intermediate), block_step, insize); }); }); // Set the input for the next step and recalculate the blocks - indata = intermediate.data(); + indata = gpu_ptr(intermediate); dt = intermediate.dtype(); insize = intermediate.size(); std::tie(blocks, threads, block_step) = get_args(insize, N_READS); @@ -149,7 +150,7 @@ void all_reduce( threads, 0, static_cast(indata), - out.data(), + gpu_ptr(out), block_step, insize); }); diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index feafc2fb0..fb81079a5 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -250,7 +250,7 @@ void col_reduce_looped( const cu::ColReduceArgs& args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. - allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); @@ -261,7 +261,7 @@ void col_reduce_looped( using T = cuda_type_t; using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); + T* indata = const_cast(gpu_ptr(in)); constexpr int N_READS = 4; constexpr int BM = 32; @@ -276,7 +276,7 @@ void col_reduce_looped( blocks, 0, indata, - out.data(), + gpu_ptr(out), static_cast(args)); }); }); @@ -293,7 +293,7 @@ void col_reduce_small( const cu::ColReduceArgs& args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. 
- allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); @@ -312,8 +312,8 @@ void col_reduce_small( grid, block, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), static_cast(args), out.size()); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 8c0d380f5..40e3d6859 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -28,7 +28,7 @@ void init_reduce( Reduce::ReduceType reduce_type) { // Allocate if needed if (out.data_shared_ptr() == nullptr) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); } encoder.set_output_array(out); @@ -42,7 +42,7 @@ void init_reduce( dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); grid.x = (grid.x + 1023) / 1024; encoder.add_kernel_node( - kernel, grid, block, 0, out.data(), out.size()); + kernel, grid, block, 0, gpu_ptr(out), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index d993bacbb..0323c3184 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -5,6 +5,7 @@ #include #include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/utils.cuh" #include @@ -92,9 +93,10 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { inline void allocate_same_layout( array& out, const array& in, - const std::vector& axes) { + const std::vector& axes, + cu::CommandEncoder& encoder) { if (in.flags().row_contiguous) { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); return; } @@ -133,7 +135,7 @@ inline void allocate_same_layout( fl.col_contiguous = cc; fl.contiguous = true; out.set_data( - allocator::malloc(out.nbytes()), + cu::malloc_async(out.nbytes(), encoder.stream()), data_size, final_strides, fl, diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 1ae46d0a3..ea99e1132 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -238,7 +238,7 @@ void row_reduce_simple( const ReductionPlan& plan) { // Allocate data for the output using in's layout to avoid elem_to_loc in the // kernel. - allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); // TODO: If out.size() < 1024 which will be a common case then write this in // 2 passes. Something like 32 * out.size() and then do a warp reduce. @@ -268,10 +268,10 @@ void row_reduce_simple( kernel = cu::row_reduce_simple; } - T* indata = const_cast(in.data()); + T* indata = const_cast(gpu_ptr(in)); int size = plan.shape.back(); encoder.add_kernel_node( - kernel, grid, block, 0, indata, out.data(), out.size(), size); + kernel, grid, block, 0, indata, gpu_ptr(out), out.size(), size); }); }); } @@ -286,7 +286,7 @@ void row_reduce_looped( cu::RowReduceArgs args) { // Allocate data for the output using in's layout to access them as // contiguously as possible. 
- allocate_same_layout(out, in, axes); + allocate_same_layout(out, in, axes, encoder); encoder.set_input_array(in); encoder.set_output_array(out); @@ -315,7 +315,7 @@ void row_reduce_looped( }); encoder.add_kernel_node( - kernel, grid, block, 0, in.data(), out.data(), args); + kernel, grid, block, 0, gpu_ptr(in), gpu_ptr(out), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index f0cce445c..8d5d02238 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -223,9 +223,9 @@ void RMSNorm::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - out.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(out), eps_, axis_size, w_stride); @@ -318,11 +318,11 @@ void RMSNormVJP::eval_gpu( n_rows, block_dim(), 0, - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(g), + gpu_ptr(gx), + gpu_ptr(gw_temp), eps_, axis_size, w_stride); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 84de58a63..67285e884 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -340,9 +340,9 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), scale_, std::log2(base_), mat_size, @@ -357,10 +357,10 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), scale_, mat_size, dims, @@ -381,10 +381,10 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), scale_, std::log2(base_), strides, @@ -408,9 +408,9 @@ void RoPE::eval_gpu( grid, block, 0, - (donated ? out : in).data(), - out.data(), - offset.data(), + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), scale_, std::log2(base_), strides, diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index be6968702..151fca041 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -513,11 +513,11 @@ void sdpa_vector_1pass_fallback( grid_dim, block_dim, 0, - q.data(), - k.data(), - v.data(), - o.data(), - sinks ? (*sinks).data() : nullptr, + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + gpu_ptr(o), + sinks ? gpu_ptr(*sinks) : nullptr, params); }); }); @@ -602,13 +602,13 @@ void sdpa_vector_2pass_fallback( grid_dim, block_dim, 0, - q.data(), - k.data(), - v.data(), - sinks ? (*sinks).data() : nullptr, - intermediate.data(), - sums.data(), - maxs.data(), + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + sinks ? 
gpu_ptr(*sinks) : nullptr, + gpu_ptr(intermediate), + gpu_ptr(sums), + gpu_ptr(maxs), params); } @@ -629,10 +629,10 @@ void sdpa_vector_2pass_fallback( grid_dim, block_dim, 0, - intermediate.data(), - sums.data(), - maxs.data(), - o.data(), + gpu_ptr(intermediate), + gpu_ptr(sums), + gpu_ptr(maxs), + gpu_ptr(o), params); } }); diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 4d07aed5c..d05cdb94a 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -415,8 +415,8 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { in.data_size() / axis_size, block_dim, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size); } else { constexpr int BM = WARP_SIZE; @@ -445,8 +445,8 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { num_blocks, block_dim, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size, stride, stride_blocks); diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index cb8e7f35a..e065d0b89 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -152,8 +152,8 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { n_rows, block_dim(), 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), axis_size); }); }); diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index b9f52762c..5becbd496 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -91,10 +91,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs( nullptr, size, - in.data(), - discard.data(), - indices.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(discard), + gpu_ptr(indices), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, @@ -115,16 +115,16 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { cu::thrust_policy(stream), thrust::counting_iterator(0), thrust::counting_iterator(indices.data_size()), - thrust::device_pointer_cast(indices.data()), + thrust::device_pointer_cast(gpu_ptr(indices)), ModOp{static_cast(nsort)}); CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs( - temp.data(), + gpu_ptr(temp), size, - in.data(), - discard.data(), - indices.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(discard), + gpu_ptr(indices), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, @@ -137,8 +137,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys( nullptr, size, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, @@ -156,10 +156,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { // Start capturing after allocations auto capture = encoder.capture_context(); CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys( - temp.data(), + gpu_ptr(temp), size, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), in.data_size(), in.data_size() / nsort, offsets, diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 84ae996aa..b1cd99f7d 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -168,10 +168,10 @@ void ternary_op_gpu_inplace( num_blocks, block_dims, 0, - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), out.data_size()); }); } else { @@ -211,10 +211,10 @@ void ternary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, 
block_dims, 0, - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -231,10 +231,10 @@ void ternary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - a.data(), - b.data(), - c.data(), - out.data(), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(c), + gpu_ptr(out), rest, const_param(shape), const_param(a_strides), @@ -256,7 +256,10 @@ void ternary_op_gpu( auto& b = inputs[1]; auto& c = inputs[2]; auto topt = get_ternary_op_type(a, b, c); - set_ternary_op_output_data(a, b, c, out, topt); + auto& encoder = cu::get_command_encoder(s); + set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); ternary_op_gpu_inplace(inputs, out, s); } diff --git a/mlx/backend/cuda/unary/unary.cuh b/mlx/backend/cuda/unary/unary.cuh index 8f4a02d50..2ffbf1eea 100644 --- a/mlx/backend/cuda/unary/unary.cuh +++ b/mlx/backend/cuda/unary/unary.cuh @@ -158,8 +158,8 @@ void unary_op_gpu_inplace( num_blocks, block_dims, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), out.data_size()); } else { using IdxT = std::conditional_t; @@ -182,8 +182,8 @@ void unary_op_gpu_inplace( {num_blocks_x, num_blocks_y}, block_dims, 0, - in.data(), - out.data(), + gpu_ptr(in), + gpu_ptr(out), rest, const_param(shape), const_param(strides), @@ -207,7 +207,10 @@ void unary_op_gpu( array& out, const char* op, const Stream& s) { - set_unary_output_data(inputs[0], out); + auto& encoder = cu::get_command_encoder(s); + set_unary_output_data(inputs[0], out, [&](auto n) { + return cu::malloc_async(n, encoder.stream()); + }); unary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 9e95a84ef..2d5ae59ef 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -7,6 +7,8 @@ #include #include #include +#include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" namespace mlx::core { @@ -15,6 +17,16 @@ class Device; } +template +T* gpu_ptr(array& arr) { + return cu::gpu_ptr(arr.buffer()); +} + +template +const T* gpu_ptr(const array& arr) { + return cu::gpu_ptr(arr.buffer()); +} + struct Dtype; // Throw exception if the cuda API does not succeed. diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp index 472ee486b..f8caf09cd 100644 --- a/mlx/backend/gpu/copy.cpp +++ b/mlx/backend/gpu/copy.cpp @@ -7,18 +7,7 @@ namespace mlx::core { -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. 
- return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& in, array& out, CopyType ctype) { copy_gpu(in, out, ctype, out.primitive().stream()); diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index d9c568e52..c58e7d3c2 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -10,6 +10,19 @@ namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + void copy_gpu_inplace( const array& in, array& out,
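
For context, the heart of this refactor is that an array now records an integer byte offset into a shared buffer instead of caching a raw data pointer, and pointers are recomputed on access, so the underlying CUDA allocation can later be migrated (for example from a stream-ordered pool allocation to managed memory the first time a host-visible pointer is requested) without invalidating existing views. Below is a minimal, standalone sketch of that idea; the names SharedBuffer, ArrayView, and host_visible_ptr are illustrative only (not MLX's actual classes), error checking is omitted, and the allocation policy is deliberately simplified relative to the patch.

    // sketch.cu -- compile with: nvcc -std=c++17 sketch.cu
    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>
    #include <memory>

    struct SharedBuffer {
      void* data{nullptr};
      size_t size{0};
      bool managed{false};  // true once the pointer is host-visible

      // Lazily migrate a device-only allocation to managed memory the first
      // time the host asks for a raw pointer (same idea as the patch's
      // Buffer::raw_ptr()).
      void* host_visible_ptr() {
        if (!managed) {
          void* new_data = nullptr;
          cudaMallocManaged(&new_data, size);
          cudaMemcpy(new_data, data, size, cudaMemcpyDefault);
          cudaFree(data);
          data = new_data;
          managed = true;
        }
        return data;
      }
    };

    struct ArrayView {
      std::shared_ptr<SharedBuffer> buf;
      int64_t offset{0};  // byte offset, replacing a cached data_ptr

      template <typename T>
      T* data() {
        // The pointer is recomputed on every access, so a later buffer
        // migration cannot leave a view holding a stale address.
        return reinterpret_cast<T*>(
            static_cast<char*>(buf->host_visible_ptr()) + offset);
      }
    };

    int main() {
      auto buf = std::make_shared<SharedBuffer>();
      buf->size = 16 * sizeof(float);
      cudaMallocManaged(&buf->data, buf->size);
      buf->managed = true;

      ArrayView whole{buf, 0};
      // A view starting at element 8, sharing the same buffer.
      ArrayView tail{buf, 8 * static_cast<int64_t>(sizeof(float))};

      whole.data<float>()[8] = 42.0f;
      std::printf("%f\n", tail.data<float>()[0]);  // prints 42.000000
      return 0;
    }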