diff --git a/mlx/backend/cpu/cholesky.cpp b/mlx/backend/cpu/cholesky.cpp index f6b28a85a..3c5bbbc93 100644 --- a/mlx/backend/cpu/cholesky.cpp +++ b/mlx/backend/cpu/cholesky.cpp @@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) { // The decomposition is computed in place, so just copy the input to the // output. - copy( + copy_cpu( a, factor, a.flags().row_contiguous ? CopyType::Vector : CopyType::General, diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index e5636b3b8..5684f9709 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu( // Fill with zeros std::vector temps; temps.push_back(array(0, conv_dtype)); - copy(temps.back(), in_padded, CopyType::Scalar, stream); + copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = padding_lo[0] * in_padded.strides()[1]; @@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu( in_padded_slice.size(), data_offset); // Copy input values into the slice - copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); + copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); temps.push_back(in_padded_slice); // Make strided view @@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu( // Materialize strided view Shape strided_reshape = {N * oH, wH * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy(in_strided_view, in_strided, CopyType::General, stream); + copy_cpu(in_strided_view, in_strided, CopyType::General, stream); temps.push_back(in_strided); // Check wt dtype and prepare @@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu( wt.size(), 0); gemm_wt = array(wt_transpose.shape(), float32, nullptr, {}); - copy(wt_transpose, gemm_wt, CopyType::General, stream); + copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream); temps.push_back(gemm_wt); } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); - copy(wt, gemm_wt, ctype, stream); + copy_cpu(wt, gemm_wt, ctype, stream); temps.push_back(gemm_wt); } @@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu( // Copy results if needed if (out.dtype() != float32) { - copy_inplace(gemm_out, out, CopyType::Vector, stream); + copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); } encoder.add_temporaries(std::move(temps)); } @@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu( // Fill with zeros std::vector temps; temps.push_back(array(0, conv_dtype)); - copy(temps.back(), in_padded, CopyType::Scalar, stream); + copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = padding_lo[0] * in_padded.strides()[1] + @@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu( temps.push_back(in_padded_slice); // Copy input values into the slice - copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); + copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); // Make strided view Shape strided_shape = {N, oH, oW, wH, wW, C}; @@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu( // Materialize strided view Shape strided_reshape = {N * oH * oW, wH * wW * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy(in_strided_view, in_strided, CopyType::General, stream); + copy_cpu(in_strided_view, in_strided, CopyType::General, stream); temps.push_back(in_strided); // Check wt dtype and prepare @@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu( auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); - copy(wt, gemm_wt, ctype, stream); + copy_cpu(wt, gemm_wt, ctype, stream); temps.push_back(gemm_wt); } @@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu( // Copy results if needed if (out.dtype() != float32) { - copy_inplace(gemm_out, out, CopyType::Vector, stream); + copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); } encoder.add_temporaries(std::move(temps)); } @@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu( // Fill with zeros std::vector temps = {array(0, conv_dtype)}; - copy(temps.back(), in_padded, CopyType::Scalar, stream); + copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded size_t data_offset = 0; @@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu( data_offset); // Copy input values into the slice - copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); + copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream); temps.push_back(in_padded_slice); // Make strided view @@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu( } array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy(in_strided_view, in_strided, CopyType::General, stream); + copy_cpu(in_strided_view, in_strided, CopyType::General, stream); temps.push_back(in_strided); // Check wt dtype and prepare @@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu( auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); - copy(wt, gemm_wt, ctype, stream); + copy_cpu(wt, gemm_wt, ctype, stream); temps.push_back(gemm_wt); } if (flip) { auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {}); - copy(gemm_wt, gemm_wt_, CopyType::Vector, stream); + copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream); temps.push_back(gemm_wt_); // Calculate the total size of the spatial dimensions @@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu( // Copy results if needed if (out.dtype() != float32) { - copy_inplace(gemm_out, out, CopyType::Vector, stream); + copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream); } encoder.add_temporaries(std::move(temps)); } diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp index f9b8595dd..47ae3ef4b 100644 --- a/mlx/backend/cpu/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -295,7 +295,11 @@ inline void copy_inplace_dispatch( } // namespace -void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) { +void copy_cpu_inplace( + const array& src, + array& dst, + CopyType ctype, + Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(src); encoder.set_output_array(dst); @@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) { ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); }); } -void copy(const array& src, array& dst, CopyType ctype, Stream stream) { +void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) { bool donated = set_copy_output_data(src, dst, ctype); if (donated && src.dtype() == dst.dtype()) { // If the output has the same type as the input then there is nothing to @@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) { if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } - copy_inplace(src, dst, ctype, stream); + copy_cpu_inplace(src, dst, ctype, stream); } -void copy_inplace( +void copy_cpu_inplace( const array& src, array& dst, const Shape& data_shape, diff --git a/mlx/backend/cpu/copy.h b/mlx/backend/cpu/copy.h index e930da465..ee303c3d3 100644 --- a/mlx/backend/cpu/copy.h +++ b/mlx/backend/cpu/copy.h @@ -10,10 +10,14 @@ namespace mlx::core { -void copy(const array& src, array& dst, CopyType ctype, Stream stream); -void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream); +void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); +void copy_cpu_inplace( + const array& src, + array& dst, + CopyType ctype, + Stream stream); -void copy_inplace( +void copy_cpu_inplace( const array& src, array& dst, const Shape& data_shape, diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index a3edf8f49..ac6201552 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -14,7 +14,7 @@ std::pair ensure_row_contiguous(const array& arr, Stream stream) { return {arr, false}; } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::General, stream); + copy_cpu(arr, arr_copy, CopyType::General, stream); return {arr_copy, true}; } }; @@ -35,7 +35,7 @@ void AllReduce::eval_cpu( return in; } else { array arr_copy(in.shape(), in.dtype(), nullptr, {}); - copy(in, arr_copy, CopyType::General, s); + copy_cpu(in, arr_copy, CopyType::General, s); out.copy_shared_buffer(arr_copy); return arr_copy; } diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp index c89003fc0..a01295145 100644 --- a/mlx/backend/cpu/eig.cpp +++ b/mlx/backend/cpu/eig.cpp @@ -135,7 +135,7 @@ void Eig::eval_cpu( : array(a.shape(), complex64, nullptr, {}); auto a_copy = array(a.shape(), a.dtype(), nullptr, {}); - copy( + copy_cpu( a, a_copy, a.flags().row_contiguous ? CopyType::Vector : CopyType::General, diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp index 58d3634e8..d457c1fd9 100644 --- a/mlx/backend/cpu/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -196,7 +196,7 @@ void Eigh::eval_cpu( values.set_data(allocator::malloc(values.nbytes())); - copy( + copy_cpu( a, vectors, a.flags().row_contiguous ? CopyType::Vector : CopyType::General, diff --git a/mlx/backend/cpu/hadamard.cpp b/mlx/backend/cpu/hadamard.cpp index b0734ad79..bf7e1dc26 100644 --- a/mlx/backend/cpu/hadamard.cpp +++ b/mlx/backend/cpu/hadamard.cpp @@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector& inputs, array& out) { if (in.flags().row_contiguous && in.is_donatable()) { out.copy_shared_buffer(in); } else { - copy( + copy_cpu( in, out, in.flags().row_contiguous ? CopyType::Vector : CopyType::General, diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 5f99093e5..6daced6fa 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { // Copy src into out (copy allocates memory for out) auto ctype = src.flags().row_contiguous ? CopyType::Vector : CopyType::General; - copy(src, out, ctype, stream()); + copy_cpu(src, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); std::vector inds; @@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { // Copy src into out (copy allocates memory for out) auto ctype = src.flags().row_contiguous ? CopyType::Vector : CopyType::General; - copy(src, out, ctype, stream()); + copy_cpu(src, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(idx); diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp index 2e79addcb..ddc979daa 100644 --- a/mlx/backend/cpu/inverse.cpp +++ b/mlx/backend/cpu/inverse.cpp @@ -115,7 +115,7 @@ void inverse_impl( // (A⁻¹)ᵀ = (Aᵀ)⁻¹ // The inverse is computed in place, so just copy the input to the output. - copy( + copy_cpu( a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General, diff --git a/mlx/backend/cpu/logsumexp.cpp b/mlx/backend/cpu/logsumexp.cpp index 56f0dab9f..3ae9a3cce 100644 --- a/mlx/backend/cpu/logsumexp.cpp +++ b/mlx/backend/cpu/logsumexp.cpp @@ -88,7 +88,7 @@ void LogSumExp::eval_cpu(const std::vector& inputs, array& out) { return x; } else { auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); - copy(x, x_copy, CopyType::General, s); + copy_cpu(x, x_copy, CopyType::General, s); encoder.add_temporary(x_copy); return x_copy; } diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp index 9ac9361a4..5f1507e18 100644 --- a/mlx/backend/cpu/luf.cpp +++ b/mlx/backend/cpu/luf.cpp @@ -31,7 +31,7 @@ void luf_impl( strides[ndim - 1] = M; strides[ndim - 2] = 1; lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags); - copy_inplace( + copy_cpu_inplace( a, lu, a.shape(), diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index fbee6118f..7f2a7cf4a 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -124,20 +124,20 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { if (!expand_all && stx == arr.shape(-1) && sty == 1) { if (do_copy) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::Vector, s); + copy_cpu(arr, arr_copy, CopyType::Vector, s); return std::make_tuple(false, stx, arr_copy, true); } return std::make_tuple(false, stx, arr, false); } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) { if (do_copy) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::Vector, s); + copy_cpu(arr, arr_copy, CopyType::Vector, s); return std::make_tuple(true, sty, arr_copy, true); } return std::make_tuple(true, sty, arr, false); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::General, s); + copy_cpu(arr, arr_copy, CopyType::General, s); int64_t stx = arr.shape(-1); return std::make_tuple(false, stx, arr_copy, true); } @@ -386,7 +386,7 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { return std::make_tuple(true, sty, arr); } else { temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); - copy(arr, temps.back(), CopyType::General, s); + copy_cpu(arr, temps.back(), CopyType::General, s); int64_t stx = arr.shape(-1); return std::make_tuple(false, stx, temps.back()); } @@ -504,7 +504,7 @@ void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { return std::make_tuple(true, sty, x); } else { array xc(x.shape(), x.dtype(), nullptr, {}); - copy(x, xc, CopyType::General, s); + copy_cpu(x, xc, CopyType::General, s); encoder.add_temporary(xc); int64_t stx = x.shape(-1); return std::make_tuple(false, stx, xc); diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index b944aacc0..7997c75ed 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -81,7 +81,7 @@ void matmul_general( return std::make_tuple(true, sty, arr); } else { temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); - copy(arr, temps.back(), CopyType::General, stream); + copy_cpu(arr, temps.back(), CopyType::General, stream); stx = arr.shape(-1); return std::make_tuple(false, stx, temps.back()); } @@ -142,7 +142,7 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); - copy(c, out, ctype, stream()); + copy_cpu(c, out, ctype, stream()); if (inputs[0].shape(-1) == 0) { return; } diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 2a612a2d9..f2cb12fdd 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -22,7 +22,7 @@ void reshape(const array& in, array& out) { auto [copy_necessary, out_strides] = prepare_reshape(in, out); if (copy_necessary) { out.set_data(allocator::malloc(out.nbytes())); - copy_inplace(in, out, CopyType::General, out.primitive().stream()); + copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream()); } else { shared_buffer_reshape(in, out_strides, out); } @@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; - copy(in, out, ctype, stream()); + copy_cpu(in, out, ctype, stream()); } void Concatenate::eval_cpu(const std::vector& inputs, array& out) { @@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { size_t data_offset = strides[axis_] * sizes[i]; out_slice.copy_shared_buffer( out, strides, flags, out_slice.size(), data_offset); - copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); + copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); } } @@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector& inputs, array& out) { (allow_col_major_ && in.flags().col_contiguous))) { out.copy_shared_buffer(in); } else { - copy(in, out, CopyType::General, stream()); + copy_cpu(in, out, CopyType::General, stream()); } } @@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector& inputs, array& out) { } else { ctype = CopyType::General; } - copy(in, out, ctype, stream()); + copy_cpu(in, out, ctype, stream()); } void Pad::eval_cpu(const std::vector& inputs, array& out) { @@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector& inputs, array& out) { assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); // Fill output with val - copy(val, out, CopyType::Scalar, stream()); + copy_cpu(val, out, CopyType::Scalar, stream()); // Find offset for start of input values size_t data_offset = 0; @@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector& inputs, array& out) { out, out.strides(), out.flags(), out_slice.size(), data_offset); // Copy input values into the slice - copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); + copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); } void RandomBits::eval_cpu(const std::vector& inputs, array& out) { @@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc(out.nbytes())); auto [in_offset, donated] = compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); - copy_inplace( + copy_cpu_inplace( /* const array& src = */ in, /* array& dst = */ out, /* const Shape& data_shape = */ out.shape(), @@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu( auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; - copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); auto [out_offset, donated] = compute_dynamic_offset(inputs[2], out.strides(), axes_, stream()); - copy_inplace( + copy_cpu_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const std::vector& data_shape = */ upd.shape(), @@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector& inputs, array& out) { auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; - copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, out_strides] = prepare_slice(out, start_indices_, strides_); // Do copy - copy_inplace( + copy_cpu_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const std::vector& data_shape = */ upd.shape(), @@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector& inputs, array& out) { if (in.dtype() == bool_) { auto in_tmp = array(in.shape(), uint8, nullptr, {}); in_tmp.copy_shared_buffer(in); - copy_inplace(in_tmp, tmp, CopyType::General, stream()); + copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream()); } else { - copy_inplace(in, tmp, CopyType::General, stream()); + copy_cpu_inplace(in, tmp, CopyType::General, stream()); } auto flags = out.flags(); diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp index 9e01d188b..13c7e1132 100644 --- a/mlx/backend/cpu/qrf.cpp +++ b/mlx/backend/cpu/qrf.cpp @@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { strides[in.ndim() - 2] = 1; strides[in.ndim() - 1] = M; in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags); - copy_inplace(a, in, CopyType::GeneralGeneral, stream); + copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream); auto& encoder = cpu::get_command_encoder(stream); q.set_data(allocator::malloc(q.nbytes())); r.set_data(allocator::malloc(r.nbytes())); diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index ee8e56cc0..ee61221da 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -529,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { return arr; } else { temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); - copy(arr, temps.back(), CopyType::General, s); + copy_cpu(arr, temps.back(), CopyType::General, s); return temps.back(); } }; @@ -579,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { return arr; } else { temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); - copy(arr, temps.back(), CopyType::General, s); + copy_cpu(arr, temps.back(), CopyType::General, s); return temps.back(); } }; @@ -713,7 +713,7 @@ void fast::AffineQuantize::eval_cpu( return std::make_pair(arr, false); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy(arr, arr_copy, CopyType::General, s); + copy_cpu(arr, arr_copy, CopyType::General, s); return std::make_pair(arr_copy, true); } }; diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 33addd161..a62763fa0 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -251,7 +251,7 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { auto in = inputs[0]; if (!in.flags().row_contiguous) { array arr_copy(in.shape(), in.dtype(), nullptr, {}); - copy(in, arr_copy, CopyType::General, stream()); + copy_cpu(in, arr_copy, CopyType::General, stream()); in = arr_copy; encoder.add_temporary(arr_copy); } diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp index 41d14f556..8823c7906 100644 --- a/mlx/backend/cpu/softmax.cpp +++ b/mlx/backend/cpu/softmax.cpp @@ -132,7 +132,7 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); - copy(x, x_copy, CopyType::General, s); + copy_cpu(x, x_copy, CopyType::General, s); out.copy_shared_buffer(x_copy); return x_copy; } diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index b00e301b8..f2243f60f 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -335,7 +335,7 @@ void Sort::eval_cpu(const std::vector& inputs, array& out) { // Copy input to output CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; - copy(in, out, ctype, stream()); + copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_output_array(out); @@ -427,7 +427,7 @@ void Partition::eval_cpu(const std::vector& inputs, array& out) { // Copy input to output CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; - copy(in, out, ctype, stream()); + copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_output_array(out); diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 24d93f8e5..08ad444e1 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -31,7 +31,7 @@ void svd_impl( // lapack clobbers the input, so we have to make a copy. array in(a.shape(), a.dtype(), nullptr, {}); - copy( + copy_cpu( a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General,