// Copyright © 2023-2024 Apple Inc. #include #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core { namespace { template void copy_single(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); auto size = dst.size(); auto val = static_cast(src_ptr[0]); std::fill_n(dst_ptr, size, val); } template void copy_vector(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); auto size = src.data_size(); std::copy(src_ptr, src_ptr + size, dst_ptr); } template inline void copy_dims( const SrcT* src, DstT* dst, const Shape& shape, const Strides& i_strides, const Strides& o_strides, int axis) { auto stride_src = i_strides[axis]; auto stride_dst = o_strides[axis]; auto N = shape[axis]; for (int i = 0; i < N; i++) { if constexpr (D > 1) { copy_dims( src, dst, shape, i_strides, o_strides, axis + 1); } else { *dst = static_cast(*src); } src += stride_src; dst += stride_dst; } } template void copy_general_general( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides& o_strides, int64_t i_offset, int64_t o_offset, const std::optional& dynamic_i_offset, const std::optional& dynamic_o_offset) { auto src_ptr = src.data() + i_offset; auto dst_ptr = dst.data() + o_offset; auto i_offset_ptr = dynamic_i_offset ? dynamic_i_offset->data() : nullptr; auto o_offset_ptr = dynamic_o_offset ? dynamic_o_offset->data() : nullptr; auto size = src.size(); if (data_shape.empty()) { auto val = static_cast(*src_ptr); *dst_ptr = val; return; } auto [shape, strides] = collapse_contiguous_dims(data_shape, {i_strides, o_strides}); int ndim = shape.size(); if (ndim < 3) { if (i_offset_ptr) { src_ptr += i_offset_ptr[0]; } if (o_offset_ptr) { dst_ptr += o_offset_ptr[0]; } if (ndim == 1) { copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); } else if (ndim == 2) { copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); } else if (ndim == 3) { copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); } return; } if (i_offset_ptr) { src_ptr += i_offset_ptr[0]; } if (o_offset_ptr) { dst_ptr += o_offset_ptr[0]; } ContiguousIterator in(shape, strides[0], ndim - 3); ContiguousIterator out(shape, strides[1], ndim - 3); auto stride = std::accumulate( shape.end() - 3, shape.end(), 1, std::multiplies()); for (int64_t elem = 0; elem < size; elem += stride) { copy_dims( src_ptr + in.loc, dst_ptr + out.loc, shape, strides[0], strides[1], ndim - 3); in.step(); out.step(); } } template inline void copy_general_general(const array& src, array& dst) { copy_general_general( src, dst, src.shape(), src.strides(), dst.strides(), 0, 0, std::nullopt, std::nullopt); } template void copy_general( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides&, int64_t i_offset, int64_t o_offset, const std::optional& dynamic_i_offset, const std::optional& dynamic_o_offset) { copy_general_general( src, dst, data_shape, i_strides, make_contiguous_strides(data_shape), i_offset, o_offset, dynamic_i_offset, dynamic_o_offset); } template inline void copy_general(const array& src, array& dst) { copy_general_general( src, dst, src.shape(), src.strides(), make_contiguous_strides(src.shape()), 0, 0, std::nullopt, std::nullopt); } template void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (ctype) { case CopyType::Scalar: copy_single(src, dst); return; case CopyType::Vector: copy_vector(src, dst); return; case CopyType::General: copy_general(src, dst, std::forward(args)...); return; case CopyType::GeneralGeneral: copy_general_general(src, dst, std::forward(args)...); return; } } template void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (dst.dtype()) { case bool_: copy(src, dst, ctype, std::forward(args)...); break; case uint8: copy(src, dst, ctype, std::forward(args)...); break; case uint16: copy(src, dst, ctype, std::forward(args)...); break; case uint32: copy(src, dst, ctype, std::forward(args)...); break; case uint64: copy(src, dst, ctype, std::forward(args)...); break; case int8: copy(src, dst, ctype, std::forward(args)...); break; case int16: copy(src, dst, ctype, std::forward(args)...); break; case int32: copy(src, dst, ctype, std::forward(args)...); break; case int64: copy(src, dst, ctype, std::forward(args)...); break; case float16: copy(src, dst, ctype, std::forward(args)...); break; case float32: copy(src, dst, ctype, std::forward(args)...); break; case float64: copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; case complex64: copy(src, dst, ctype, std::forward(args)...); break; } } template inline void copy_inplace_dispatch( const array& src, array& dst, CopyType ctype, Args&&... args) { switch (src.dtype()) { case bool_: copy(src, dst, ctype, std::forward(args)...); break; case uint8: copy(src, dst, ctype, std::forward(args)...); break; case uint16: copy(src, dst, ctype, std::forward(args)...); break; case uint32: copy(src, dst, ctype, std::forward(args)...); break; case uint64: copy(src, dst, ctype, std::forward(args)...); break; case int8: copy(src, dst, ctype, std::forward(args)...); break; case int16: copy(src, dst, ctype, std::forward(args)...); break; case int32: copy(src, dst, ctype, std::forward(args)...); break; case int64: copy(src, dst, ctype, std::forward(args)...); break; case float16: copy(src, dst, ctype, std::forward(args)...); break; case float32: copy(src, dst, ctype, std::forward(args)...); break; case float64: copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; case complex64: copy(src, dst, ctype, std::forward(args)...); break; } } } // namespace void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(src); encoder.set_output_array(dst); encoder.dispatch( [src = array::unsafe_weak_copy(src), dst = array::unsafe_weak_copy(dst), ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); }); } void copy(const array& src, array& dst, CopyType ctype, Stream stream) { bool donated = set_copy_output_data(src, dst, ctype); if (donated && src.dtype() == dst.dtype()) { // If the output has the same type as the input then there is nothing to // copy, just use the buffer. return; } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } copy_inplace(src, dst, ctype, stream); } void copy_inplace( const array& src, array& dst, const Shape& data_shape, const Strides& i_strides, const Strides& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, Stream stream, const std::optional& dynamic_i_offset, /* = std::nullopt */ const std::optional& dynamic_o_offset /* = std::nullopt */) { auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(src); encoder.set_output_array(dst); auto weak_copy_if_set = [](auto x) -> std::optional { if (x) { return array::unsafe_weak_copy(*x); } else { return std::nullopt; } }; encoder.dispatch( [src = array::unsafe_weak_copy(src), dst = array::unsafe_weak_copy(dst), data_shape, i_strides, o_strides, i_offset, o_offset, ctype, dynamic_i_offset = weak_copy_if_set(dynamic_i_offset), dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable { switch (ctype) { case CopyType::General: case CopyType::GeneralGeneral: copy_inplace_dispatch( src, dst, ctype, data_shape, i_strides, o_strides, i_offset, o_offset, dynamic_i_offset, dynamic_o_offset); break; case CopyType::Scalar: case CopyType::Vector: copy_inplace_dispatch(src, dst, ctype); } }); } } // namespace mlx::core