From 56cc858af965ae29ba040d53c5d6d0a321c58485 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 21 Jul 2025 23:30:35 +0900 Subject: [PATCH] Add contiguous_copy_cpu util for copying array (#2397) --- mlx/backend/cpu/copy.cpp | 6 ++++++ mlx/backend/cpu/copy.h | 3 +++ mlx/backend/cpu/distributed.cpp | 7 ++----- mlx/backend/cpu/logsumexp.cpp | 3 +-- mlx/backend/cpu/masked_mm.cpp | 3 +-- mlx/backend/cpu/quantized.cpp | 4 +--- mlx/backend/cpu/scan.cpp | 6 ++---- mlx/backend/cpu/softmax.cpp | 3 +-- 8 files changed, 17 insertions(+), 18 deletions(-) diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp index 47ae3ef4b..f9ff22677 100644 --- a/mlx/backend/cpu/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -377,4 +377,10 @@ void copy_cpu_inplace( }); } +array contiguous_copy_cpu(const array& arr, Stream stream) { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_cpu(arr, arr_copy, CopyType::General, stream); + return arr_copy; +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/copy.h b/mlx/backend/cpu/copy.h index ee303c3d3..007291369 100644 --- a/mlx/backend/cpu/copy.h +++ b/mlx/backend/cpu/copy.h @@ -30,4 +30,7 @@ void copy_cpu_inplace( const std::optional& dynamic_i_offset = std::nullopt, const std::optional& dynamic_o_offset = std::nullopt); +// Return a contiguous array with same shape that copies the data of |arr|. +array contiguous_copy_cpu(const array& arr, Stream stream); + } // namespace mlx::core diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index ac6201552..d641d581b 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -13,9 +13,7 @@ std::pair ensure_row_contiguous(const array& arr, Stream stream) { if (arr.flags().row_contiguous) { return {arr, false}; } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_cpu(arr, arr_copy, CopyType::General, stream); - return {arr_copy, true}; + return {contiguous_copy_cpu(arr, stream), true}; } }; @@ -34,8 +32,7 @@ void AllReduce::eval_cpu( } return in; } else { - array arr_copy(in.shape(), in.dtype(), nullptr, {}); - copy_cpu(in, arr_copy, CopyType::General, s); + array arr_copy = contiguous_copy_cpu(in, s); out.copy_shared_buffer(arr_copy); return arr_copy; } diff --git a/mlx/backend/cpu/logsumexp.cpp b/mlx/backend/cpu/logsumexp.cpp index 3ae9a3cce..d061907f9 100644 --- a/mlx/backend/cpu/logsumexp.cpp +++ b/mlx/backend/cpu/logsumexp.cpp @@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector& inputs, array& out) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { return x; } else { - auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); - copy_cpu(x, x_copy, CopyType::General, s); + array x_copy = contiguous_copy_cpu(x, s); encoder.add_temporary(x_copy); return x_copy; } diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 7f2a7cf4a..c3efb79cd 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { } return std::make_tuple(true, sty, arr, false); } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_cpu(arr, arr_copy, CopyType::General, s); int64_t stx = arr.shape(-1); + array arr_copy = contiguous_copy_cpu(arr, s); return std::make_tuple(false, stx, arr_copy, true); } }; diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index ee61221da..1c02c4e61 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu( if (arr.flags().row_contiguous) { return std::make_pair(arr, false); } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_cpu(arr, arr_copy, CopyType::General, s); - return std::make_pair(arr_copy, true); + return std::make_pair(contiguous_copy_cpu(arr, s), true); } }; diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index a62763fa0..4dc2f50ea 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { // Ensure contiguity auto in = inputs[0]; if (!in.flags().row_contiguous) { - array arr_copy(in.shape(), in.dtype(), nullptr, {}); - copy_cpu(in, arr_copy, CopyType::General, stream()); - in = arr_copy; - encoder.add_temporary(arr_copy); + in = contiguous_copy_cpu(in, stream()); + encoder.add_temporary(in); } out.set_data(allocator::malloc(out.nbytes())); diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp index 8823c7906..4c2941e96 100644 --- a/mlx/backend/cpu/softmax.cpp +++ b/mlx/backend/cpu/softmax.cpp @@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { } return x; } else { - array x_copy(x.shape(), x.dtype(), nullptr, {}); - copy_cpu(x, x_copy, CopyType::General, s); + array x_copy = contiguous_copy_cpu(x, s); out.copy_shared_buffer(x_copy); return x_copy; }