diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index 5fbf949d7..83a9c2a67 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu(
       return x;
     }
     copied = true;
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
-    return x_copy;
+    return contiguous_copy_gpu(x, s);
   };
   bool donate_x = inputs[0].is_donatable();
   bool donate_g = inputs[3].is_donatable();
diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu
index afc52826f..ba5836a33 100644
--- a/mlx/backend/cuda/logsumexp.cu
+++ b/mlx/backend/cuda/logsumexp.cu
@@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       encoder.add_temporary(x_copy);
       return x_copy;
     }
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index b70f61e3d..4110e7eff 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -297,8 +297,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   } else if (stx == 1 && sty == arr.shape(-2)) {
     return std::make_tuple(true, sty, arr);
   } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy_gpu(arr, arr_copy, CopyType::General, s);
+    array arr_copy = contiguous_copy_gpu(arr, s);
     enc.add_temporary(arr_copy);
     return std::make_tuple(false, arr.shape(-1), arr_copy);
   }
diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized.cu
index 4424000d8..204dbd547 100644
--- a/mlx/backend/cuda/quantized.cu
+++ b/mlx/backend/cuda/quantized.cu
@@ -247,8 +247,7 @@ inline array ensure_row_contiguous(
     cu::CommandEncoder& enc,
     const Stream& s) {
   if (!x.flags().row_contiguous) {
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     enc.add_temporary(x_copy);
     return x_copy;
   } else {
diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 8350eebb7..87cb3aedc 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   }
   if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
-    array in_copy(in.shape(), in.dtype(), nullptr, {});
-    copy_gpu(in, in_copy, CopyType::General, s);
+    array in_copy = contiguous_copy_gpu(in, s);
     encoder.add_temporary(in_copy);
     in = in_copy;
     plan = get_reduction_plan(in, axes_);
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index 964bd7d98..66b759b5e 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
       return x;
     }
     copied = true;
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
-    return x_copy;
+    return contiguous_copy_gpu(x, s);
   };
   bool donate_x = inputs[0].is_donatable();
   bool donate_g = inputs[2].is_donatable();
diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu
index 7a26ee161..969264e34 100644
--- a/mlx/backend/cuda/scan.cu
+++ b/mlx/backend/cuda/scan.cu
@@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
           in.flags());
     }
   } else {
-    array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy_gpu(in, arr_copy, CopyType::General, s);
-    in = std::move(arr_copy);
+    in = contiguous_copy_gpu(in, s);
     out.copy_shared_buffer(in);
   }
 
diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu
index 56f67d7f3..53615ae4d 100644
--- a/mlx/backend/cuda/softmax.cu
+++ b/mlx/backend/cuda/softmax.cu
@@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index 379c55706..e8e0f389c 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
   if (!is_segmented_sort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
-    in = array(trans.shape(), trans.dtype(), nullptr, {});
-    copy_gpu(trans, in, CopyType::General, s);
+    in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
     out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
     encoder.add_temporary(out);
diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp
index 6127ac921..4556f7d98 100644
--- a/mlx/backend/gpu/copy.cpp
+++ b/mlx/backend/gpu/copy.cpp
@@ -46,4 +46,10 @@ void copy_gpu_inplace(
       in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
 }
 
+array contiguous_copy_gpu(const array& arr, const Stream& s) {
+  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+  copy_gpu(arr, arr_copy, CopyType::General, s);
+  return arr_copy;
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/gpu/copy.h b/mlx/backend/gpu/copy.h
index 020f579e4..f01fe9fda 100644
--- a/mlx/backend/gpu/copy.h
+++ b/mlx/backend/gpu/copy.h
@@ -43,4 +43,7 @@ void copy_gpu_inplace(
 // Fill the output with the scalar val
 void fill_gpu(const array& val, array& out, const Stream& s);
 
+// Return a contiguous array with same shape that copies the data of |arr|.
+array contiguous_copy_gpu(const array& arr, const Stream& s);
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 9eb6a6385..06d058eae 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
       wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
 
   // Materialize
-  auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
-  copy_gpu(wt_view, wt_transpose, CopyType::General, s);
+  array wt_transpose = contiguous_copy_gpu(wt_view, s);
 
   // Perform gemm
   std::vector<array> copies = {in_unfolded, wt_transpose};
@@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto in = inputs[0];
   auto wt = inputs[1];
   if (!in.flags().row_contiguous) {
-    array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy_gpu(in, arr_copy, CopyType::General, s);
-    copies.push_back(arr_copy);
-    in = arr_copy;
+    in = contiguous_copy_gpu(in, s);
+    copies.push_back(in);
   }
   if (!wt.flags().row_contiguous) {
-    array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
-    copy_gpu(wt, arr_copy, CopyType::General, s);
-    copies.push_back(arr_copy);
-    wt = arr_copy;
+    wt = contiguous_copy_gpu(wt, s);
+    copies.push_back(wt);
   }
 
   // 3D conv
diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp
index e53bc58d9..2cfdcdc8a 100644
--- a/mlx/backend/metal/logsumexp.cpp
+++ b/mlx/backend/metal/logsumexp.cpp
@@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       d.add_temporary(x_copy, s.index);
       return x_copy;
     }
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 55b8be3a9..5a185f416 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
   } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
     return std::make_tuple(true, sty, arr);
   } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy_gpu(arr, arr_copy, CopyType::General, s);
+    array arr_copy = contiguous_copy_gpu(arr, s);
     copies.push_back(arr_copy);
     return std::make_tuple(false, arr.shape(-1), arr_copy);
   }
@@ -43,8 +42,7 @@
 inline array
 ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
   if (!x.flags().row_contiguous) {
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     d.add_temporary(x_copy, s.index);
     return x_copy;
   } else {
@@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
     }
   }
 
-  array x_copy(x.shape(), x.dtype(), nullptr, {});
-  copy_gpu(x, x_copy, CopyType::General, s);
+  array x_copy = contiguous_copy_gpu(x, s);
   d.add_temporary(x_copy, s.index);
   return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
 }
@@ -1894,8 +1891,7 @@ void segmented_mm(
       return std::make_tuple(false, x);
     }
 
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     d.add_temporary(x_copy, s.index);
     return std::make_tuple(true, x_copy);
   };
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index 8674eff72..40e7b5bc8 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
     if (x.flags().row_contiguous) {
       return {x, false};
     }
-
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     return {x_copy, true};
   };
   bool donate_x = inputs[0].is_donatable();
@@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
     if (x.flags().row_contiguous) {
       return {x, false};
     }
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     return {x_copy, true};
   };
   bool donate_x = inputs[0].is_donatable();
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index b6dc8db30..53f1c96f3 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -20,8 +20,7 @@ namespace {
 inline array
 ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
   if (!x.flags().row_contiguous) {
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     d.add_temporary(x_copy, s.index);
     return x_copy;
   } else {
@@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
   if (stride_0 == x.shape(-1) && stride_1 == 1) {
     return x;
   } else {
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     d.add_temporary(x_copy, s.index);
     return x_copy;
   }
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 8cb55ba58..3ae766ba9 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     // input for the axes with stride smaller than the minimum reduction
     // stride.
     if (plan.type == GeneralReduce) {
-      array in_copy(in.shape(), in.dtype(), nullptr, {});
-      copy_gpu(in, in_copy, CopyType::General, s);
+      array in_copy = contiguous_copy_gpu(in, s);
       d.add_temporary(in_copy, s.index);
       in = in_copy;
       plan = get_reduction_plan(in, axes_);
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index eef279d1d..9647d0884 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy_gpu(arr, arr_copy, CopyType::General, s);
+      array arr_copy = contiguous_copy_gpu(arr, s);
       copies.push_back(std::move(arr_copy));
       return copies.back();
     } else {
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index 3c4051105..fd1899108 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
           in.flags());
     }
   } else {
-    array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy_gpu(in, arr_copy, CopyType::General, s);
-    in = std::move(arr_copy);
+    in = contiguous_copy_gpu(in, s);
     out.copy_shared_buffer(in);
   }
 
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 59662b05d..0b1a1848d 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
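
In summary, the change is mechanical: every call site that previously spelled out the two-step copy

    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy_gpu(x, x_copy, CopyType::General, s);

now calls the shared helper added in mlx/backend/gpu/copy.{h,cpp}

    array x_copy = contiguous_copy_gpu(x, s);

where x and s stand in for whichever source array and stream a given site already uses. Temporary tracking (add_temporary, copies.push_back) and buffer donation (copy_shared_buffer) at each site are unchanged by the patch.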