diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index e1507c631..97a5ae4d4 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/utils.h"

@@ -93,9 +94,17 @@ CudaAllocator::CudaAllocator()
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
+#if CUDART_VERSION >= 13000
+  cudaMemLocation loc;
+  loc.id = 0;
+  loc.type = cudaMemLocationTypeNone;
+  cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
+  // TODO set that.
+  // uint64_t threshold = UINT64_MAX;
+#endif
 }

-Buffer CudaAllocator::malloc(size_t size) {
+Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   // Find available buffer from cache.
   auto orig_size = size;
   std::unique_lock lock(mutex_);
@@ -123,7 +132,12 @@ Buffer CudaAllocator::malloc(size_t size) {
   lock.unlock();
   if (!buf) {
     buf = new CudaBuffer{nullptr, size};
-    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    cudaError_t err;
+    if (stream != nullptr && cuda_pool_ != nullptr) {
+      err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
+    } else {
+      err = cudaMallocManaged(&buf->data, size);
+    }
     if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
       throw std::runtime_error(fmt::format(
           "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -141,6 +155,14 @@ Buffer CudaAllocator::malloc(size_t size) {
   return Buffer{buf};
 }

+Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
+  return malloc_impl(size, stream);
+}
+
+Buffer CudaAllocator::malloc(size_t size) {
+  return malloc_impl(size, nullptr);
+}
+
 void CudaAllocator::free(Buffer buffer) {
   auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
   if (!buf) {
@@ -220,6 +242,16 @@ CudaAllocator& allocator() {
   return *allocator_;
 }

+Buffer malloc_async(size_t size, cudaStream_t stream) {
+  auto buffer = allocator().malloc_async(size, stream);
+  if (size && !buffer.ptr()) {
+    std::ostringstream msg;
+    msg << "[malloc_async] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+  return buffer;
+}
+
 } // namespace cu

 namespace allocator {
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index 81b3dde59..eb69db3b4 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -5,6 +5,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/buffer_cache.h"

+#include <cuda_runtime.h>
 #include
 #include
 #include
@@ -45,6 +46,7 @@ class SmallSizePool {

 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
+  Buffer malloc_async(size_t size, cudaStream_t stream);
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
@@ -58,6 +60,7 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();

  private:
+  Buffer malloc_impl(size_t size, cudaStream_t stream);
   void cuda_free(CudaBuffer* buf);

   CudaAllocator();
@@ -70,8 +73,11 @@ class CudaAllocator : public allocator::Allocator {
   size_t active_memory_{0};
   size_t peak_memory_{0};
   SmallSizePool scalar_pool_;
+  cudaMemPool_t cuda_pool_{nullptr};
 };

 CudaAllocator& allocator();

+Buffer malloc_async(size_t size, cudaStream_t stream);
+
 } // namespace mlx::core::cu
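Note on the constructor change above: with CUDART >= 13 it looks up the default managed memory pool but leaves the release threshold unset (the commented-out UINT64_MAX hints at the intent). What follows is a minimal sketch, not part of the patch, of how a pool's release threshold can be set with the stock CUDA runtime API; only the UINT64_MAX value and the idea of a pool handle come from the diff, the function names are illustrative.

#include <cuda_runtime.h>
#include <cstdint>

// Keep freed blocks cached in the pool instead of returning them to the
// driver/OS at the next synchronization point (the default threshold is 0).
void set_release_threshold(cudaMemPool_t pool) {
  uint64_t threshold = UINT64_MAX; // value hinted at in the TODO above
  cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
}

// Pre-CUDA-13 relative: the per-device default pool used by cudaMallocAsync
// (device memory only, unlike the managed pool requested above).
cudaMemPool_t default_pool_for(int device) {
  cudaMemPool_t pool = nullptr;
  cudaDeviceGetDefaultMemPool(&pool, device);
  return pool;
}

Until something along these lines is wired in, the pool will try to release unused memory back to the OS at each synchronization point, which mostly matters for workloads that free and re-allocate large buffers between syncs.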
diff --git a/mlx/backend/cuda/arange.cu b/mlx/backend/cuda/arange.cu
index a28a245db..1c1f96a72 100644
--- a/mlx/backend/cuda/arange.cu
+++ b/mlx/backend/cuda/arange.cu
@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc(out.nbytes()));
-
   auto& encoder = cu::get_command_encoder(stream());
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   encoder.set_output_array(out);

   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
index 21cd677a8..04bfe33ce 100644
--- a/mlx/backend/cuda/arg_reduce.cu
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("ArgReduce::eval_gpu");
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));

   // Prepare the shapes, strides and axis arguments.
   Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   int32_t ndim = shape.size();

   // ArgReduce.
-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu
index 158d3de6e..17e414f73 100644
--- a/mlx/backend/cuda/copy.cu
+++ b/mlx/backend/cuda/copy.cu
@@ -87,8 +87,8 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index d18092328..c04973484 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/stream.h"
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index fc2380fc3..041784913 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -370,7 +370,7 @@ void CublasGemm::execute(
   // Ensure workspace is 256-byte aligned
   int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
   array workspace(
-      allocator::malloc(nbytes),
+      cu::malloc_async(nbytes, encoder.stream()),
       {static_cast<int>(heuristic_.workspaceSize)},
       int8);
   encoder.add_temporary(workspace);
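The call-site pattern above (and in the files below) is always the same: fetch the command encoder first, then allocate the output with cu::malloc_async on the encoder's stream. The reason is that cudaMallocFromPoolAsync is stream-ordered: the returned pointer is only guaranteed to be usable by work ordered after the allocation on that stream (or on streams that synchronize with it), and cudaFreeAsync is likewise ordered after preceding work. A standalone sketch of that contract, independent of MLX's encoder machinery (the kernel and sizes are made up for illustration):

#include <cuda_runtime.h>

__global__ void fill_kernel(float* p, float v, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    p[i] = v;
  }
}

void stream_ordered_roundtrip(cudaMemPool_t pool, cudaStream_t stream) {
  const size_t n = 1 << 20;
  float* p = nullptr;
  // The allocation is enqueued on `stream`; no host synchronization is needed
  // before launching work on the same stream.
  cudaMallocFromPoolAsync(
      reinterpret_cast<void**>(&p), n * sizeof(float), pool, stream);
  fill_kernel<<<(n + 255) / 256, 256, 0, stream>>>(p, 1.0f, n);
  // The free is also stream-ordered, so the block cannot be reclaimed before
  // the kernel above has completed in stream order.
  cudaFreeAsync(p, stream);
}

This is also why the plain malloc() path falls back to cudaMallocManaged when no stream is supplied: callers without a stream in hand still get a managed pointer that is valid without any stream ordering.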
diff --git a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
index 41ab9c8bd..03aca99d6 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
+++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
@@ -163,7 +163,7 @@ void CublasGemm::run_batched(

   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(void*) * 3),
+      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
       {batch_count * 3},
       uint64);

@@ -251,7 +251,7 @@ void CublasGemm::run_batched(

   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
+      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
       {batch_count * 4},
       uint64);

diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp
index 1d867e063..762a101c2 100644
--- a/mlx/backend/cuda/indexing.cpp
+++ b/mlx/backend/cuda/indexing.cpp
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() > 0);
   const auto& src = inputs[0];

-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(idx_dtype),
       nidx);

-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx_ndim,
       large ? "int64_t" : "int32_t");

-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   const auto& src = inputs[0];
   const auto& idx = inputs[1];

-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(out.dtype()),
       dtype_to_string(idx.dtype()));

-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx.flags().row_contiguous,
       large ? "int64_t" : "int32_t");

-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }
diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index 369b2547e..c1dc2833a 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
   nvtx3::scoped_range r("LayerNorm::eval_gpu");
   auto& s = stream();
   auto& out = outputs[0];
+  auto& encoder = cu::get_command_encoder(s);

   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
     if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
       out.copy_shared_buffer(x);
     } else {
       out.set_data(
-          allocator::malloc(x.data_size() * x.itemsize()),
+          cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
           x.data_size(),
           x.strides(),
           x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;

-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_input_array(b);
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc(gx.nbytes()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
       g_in_gw = true;
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
       encoder.add_temporary(gw_temp);
     }
   }
diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu
index b90c300d0..e7641c05d 100644
--- a/mlx/backend/cuda/logsumexp.cu
+++ b/mlx/backend/cuda/logsumexp.cu
@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto in = ensure_contiguous(inputs[0]);

   if (in.flags().row_contiguous) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else {
     auto n = in.shape(-1);
     auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     flags.col_contiguous = col_contig;
     out.set_data(
-        allocator::malloc(in.nbytes() / n),
+        cu::malloc_async(in.nbytes() / n, encoder.stream()),
         in.data_size() / n,
         std::move(strides),
         flags);
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 8a79b172b..8ccf3c466 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));

   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
       c.data_size() == out.shape(-1)) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     gemm_and_bias(
         encoder,
         M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto sty = c.strides()[c.ndim() - 1];
   if (sty == 1 && stx == c.shape(-1)) {
     ldc = stx;
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else if (sty == 1 && stx == 0) {
     ldc = 0;
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else {
     // Copy C into out and set C to out
     ldc = c.shape(-1);
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index bc879c6f8..f0cce445c 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
   nvtx3::scoped_range r("RMSNorm::eval_gpu");
   auto& s = stream();
   auto& out = outputs[0];
+  auto& encoder = cu::get_command_encoder(s);

   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
     if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
       out.copy_shared_buffer(x);
     } else {
       out.set_data(
-          allocator::malloc(x.data_size() * x.itemsize()),
+          cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
           x.data_size(),
           x.strides(),
           x.flags());
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
   int32_t n_rows = x.data_size() / axis_size;
   int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;

-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_output_array(out);
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc(gx.nbytes()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
   if (!g_in_gx && donate_g) {
     gw_temp.copy_shared_buffer(g);
   } else {
-    gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
+    gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
     encoder.add_temporary(gw_temp);
   }
 }
diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu
index bac67cf90..84de58a63 100644
--- a/mlx/backend/cuda/rope.cu
+++ b/mlx/backend/cuda/rope.cu
@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
   nvtx3::scoped_range r("RoPE::eval_gpu");
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
   auto& in = inputs[0];
   auto& offset = inputs[1];
   auto& out = outputs[0];
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc(out.nbytes()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
   bool single = in.flags().row_contiguous && B == 1 && T == 1;
   bool with_freqs = inputs.size() == 3;

-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(donated ? out : in);
   encoder.set_input_array(offset);
   if (with_freqs) {
diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 825da5cd3..be6968702 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});

-  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
-  sums.set_data(allocator::malloc(sums.nbytes()));
-  maxs.set_data(allocator::malloc(maxs.nbytes()));
+  intermediate.set_data(
+      cu::malloc_async(intermediate.nbytes(), encoder.stream()));
+  sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
+  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));

   encoder.add_temporary(intermediate);
   encoder.add_temporary(sums);
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
   };

   o.set_data(
-      allocator::malloc(o.nbytes()),
+      cu::malloc_async(o.nbytes(), encoder.stream()),
       o.size(),
       {str_oB, str_oH, str_oL, str_oD},
       flags);
diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu
index 56d4ae275..4d07aed5c 100644
--- a/mlx/backend/cuda/scan.cu
+++ b/mlx/backend/cuda/scan.cu
@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto in = inputs[0];
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);

   if (in.flags().contiguous && in.strides()[axis_] != 0) {
     if (in.is_donatable() && in.itemsize() == out.itemsize()) {
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc(in.data_size() * out.itemsize()),
+          cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
           in.data_size(),
           in.strides(),
           in.flags());
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   int32_t axis_size = in.shape(axis_);
   bool contiguous = in.strides()[axis_] == 1;

-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
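Scan above (and RoPE, Softmax, and the norm kernels earlier) all follow the same donate-or-allocate convention: reuse the input's buffer when it can be donated, otherwise allocate on the encoder's stream; scratch arrays are additionally registered with encoder.add_temporary so they outlive the enqueued kernels. A condensed pseudocode sketch of that convention follows; it uses only names that appear in this diff except for the hypothetical helpers alloc_output and make_scratch, and it is a sketch of the pattern, not code from the patch.

// Hypothetical helpers illustrating the convention; not part of the patch.
void alloc_output(const array& in, array& out, cu::CommandEncoder& encoder) {
  if (in.is_donatable() && in.itemsize() == out.itemsize()) {
    // The input buffer is dead after this op, so the output can alias it.
    out.copy_shared_buffer(in);
  } else {
    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
  }
}

array make_scratch(
    size_t nbytes, Shape shape, Dtype dtype, cu::CommandEncoder& encoder) {
  // Allocate scratch on the encoder's stream and keep it alive until the
  // encoder's work has run by registering it as a temporary.
  array tmp(cu::malloc_async(nbytes, encoder.stream()), shape, dtype);
  encoder.add_temporary(tmp);
  return tmp;
}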
diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp
index 93241936b..18cf1e02c 100644
--- a/mlx/backend/cuda/slicing.cpp
+++ b/mlx/backend/cuda/slicing.cpp
@@ -23,14 +23,15 @@ void concatenate_gpu(
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));

   auto strides = out.strides();
   auto flags = out.flags();
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;

-  auto concurrent = cu::get_command_encoder(s).concurrent_context();
+  auto concurrent = encoder.concurrent_context();
   for (int i = 0; i < inputs.size(); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
     return std::make_tuple(
         false, std::move(source), std::vector<std::string>{kernel_name});
   });
+  auto& encoder = cu::get_command_encoder(s);

   // Prepare output.
   array offset({1}, int64, nullptr, {});
   bool donate = indices.is_donatable() &&
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc(offset.itemsize()));
+    offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
   }

-  auto& encoder = cu::get_command_encoder(s);
   encoder.add_temporary(offset);
   encoder.set_input_array(indices);
   encoder.set_output_array(offset);
diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu
index d808bce38..cb8e7f35a 100644
--- a/mlx/backend/cuda/softmax.cu
+++ b/mlx/backend/cuda/softmax.cu
@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Softmax::eval_gpu");
   assert(inputs.size() == 1);
   auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);

   // Make sure that the last dimension is contiguous.
-  auto set_output = [&s, &out](const array& x) {
+  auto set_output = [&s, &out, &encoder](const array& x) {
     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc(x.data_size() * x.itemsize()),
+            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   int axis_size = in.shape().back();
   int n_rows = in.data_size() / axis_size;

-  auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index 95a2d263c..b9f52762c 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
     in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
-    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    out = array(
+        cu::malloc_async(out.nbytes(), encoder.stream()),
+        in.shape(),
+        out.dtype());
     encoder.add_temporary(out);
   } else {
     out.set_data(
-        allocator::malloc(in.data_size() * out.itemsize()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -70,12 +73,18 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
       thrust::make_counting_iterator(0), OffsetTransform{nsort});
   if (argsort) {
     // Indices in the sorted dimension.
-    array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    array indices(
+        cu::malloc_async(out.nbytes(), encoder.stream()),
+        in.shape(),
+        out.dtype());
     encoder.add_temporary(indices);

     // In argsort though we don't need the result of sorted values, the
     // API requires us to provide an array to store it.
-    array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
+    array discard(
+        cu::malloc_async(in.nbytes(), encoder.stream()),
+        in.shape(),
+        in.dtype());
     encoder.add_temporary(discard);

     size_t size;
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         sizeof(Type) * 8,
         stream));

-    array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+    array temp(
+        cu::malloc_async(size, encoder.stream()),
+        {static_cast<int>(size)},
+        uint8);
     encoder.add_temporary(temp);

     // Start capturing after allocations
@@ -135,7 +147,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         sizeof(Type) * 8,
         stream));

-    array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+    array temp(
+        cu::malloc_async(size, encoder.stream()),
+        {static_cast<int>(size)},
+        uint8);
     encoder.add_temporary(temp);

     // Start capturing after allocations
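For context on gpu_sort above: CUB's device-wide sorts are two-phase. A first call with a null temp-storage pointer only writes the required scratch size into `size`, which is why the code can allocate `temp` with cu::malloc_async before the second, real call runs on the stream. A self-contained sketch of that pattern with plain cudaMallocAsync/cudaFreeAsync standing in for MLX's allocator (key type and sizes are illustrative, not taken from the patch):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

void radix_sort_keys(const float* d_in, float* d_out, int n, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: with d_temp == nullptr, CUB only reports the scratch size.
  cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_in, d_out, n, 0, sizeof(float) * 8, stream);
  // Phase 2: allocate the scratch on the same stream and do the actual sort.
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_in, d_out, n, 0, sizeof(float) * 8, stream);
  cudaFreeAsync(d_temp, stream);
}

gpu_sort wraps the scratch in an array and registers it via encoder.add_temporary instead of freeing it directly, but the two-phase sizing call is the same.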