// Copyright © 2023-2024 Apple Inc. #include #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" namespace mlx::core { namespace { template void copy_single(const array& src, array& dst) { auto val = static_cast(src.data()[0]); auto dst_ptr = dst.data(); for (int i = 0; i < dst.size(); ++i) { dst_ptr[i] = val; } } template void copy_vector(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); } template void copy_general_dim1( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); stride_t src_idx = i_offset; stride_t dst_idx = 0; for (int i = 0; i < data_shape[0]; ++i) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += i_strides[0]; } } template inline void copy_general_dim1(const array& src, array& dst) { return copy_general_dim1( src, dst, src.shape(), src.strides(), 0); } template void copy_general_dim2( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); stride_t src_idx = i_offset; stride_t dst_idx = 0; for (int i = 0; i < data_shape[0]; ++i) { for (int j = 0; j < data_shape[1]; ++j) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += i_strides[1]; } src_idx += i_strides[0] - i_strides[1] * data_shape[1]; } } template inline void copy_general_dim2(const array& src, array& dst) { return copy_general_dim2( src, dst, src.shape(), src.strides(), 0); } template void copy_general_dim3( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); stride_t src_idx = i_offset; stride_t dst_idx = 0; for (int i = 0; i < data_shape[0]; ++i) { for (int j = 0; j < data_shape[1]; ++j) { for (int k = 0; k < data_shape[2]; ++k) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += i_strides[2]; } src_idx += i_strides[1] - i_strides[2] * data_shape[2]; } src_idx += i_strides[0] - i_strides[1] * data_shape[1]; } } template inline void copy_general_dim3(const array& src, array& dst) { return copy_general_dim3( src, dst, src.shape(), src.strides(), 0); } template void copy_general_dim4( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); stride_t src_idx = i_offset; stride_t dst_idx = 0; for (int i = 0; i < data_shape[0]; ++i) { for (int j = 0; j < data_shape[1]; ++j) { for (int k = 0; k < data_shape[2]; ++k) { for (int ii = 0; ii < data_shape[3]; ++ii) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += i_strides[3]; } src_idx += i_strides[2] - i_strides[3] * data_shape[3]; } src_idx += i_strides[1] - i_strides[2] * data_shape[2]; } src_idx += i_strides[0] - i_strides[1] * data_shape[1]; } } template inline void copy_general_dim4(const array& src, array& dst) { return copy_general_dim4( src, dst, src.shape(), src.strides(), 0); } template void copy_general( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, int64_t i_offset) { switch (src.ndim()) { case 1: copy_general_dim1( src, dst, data_shape, i_strides, i_offset); return; case 2: copy_general_dim2( src, dst, data_shape, i_strides, i_offset); return; case 3: copy_general_dim3( src, dst, data_shape, i_strides, i_offset); return; case 4: copy_general_dim4( src, dst, data_shape, i_strides, i_offset); return; } auto src_ptr = src.data() + i_offset; auto dst_ptr = dst.data(); for (size_t i = 0; i < dst.size(); ++i) { stride_t src_elem = elem_to_loc(i, data_shape, i_strides); dst_ptr[i] = static_cast(src_ptr[src_elem]); } } template inline void copy_general(const array& src, array& dst) { return copy_general( src, dst, src.shape(), src.strides(), 0); } template inline void copy_general( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, const std::vector& o_strides, int64_t i_offset, int64_t o_offset) { return copy_general( src, dst, data_shape, i_strides, i_offset); } template inline void copy_general_general_dims( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, const std::vector& o_strides, stride_t i_offset, stride_t o_offset) { if constexpr (D > 1) { int axis = src.ndim() - D; auto stride_src = i_strides[axis]; auto stride_dst = o_strides[axis]; auto N = data_shape[axis]; for (int i = 0; i < N; i++) { copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); i_offset += stride_src; o_offset += stride_dst; } } else { int axis = src.ndim() - 1; auto stride_src = i_strides[axis]; auto stride_dst = o_strides[axis]; auto N = data_shape[axis]; const SrcT* src_ptr = src.data() + i_offset; DstT* dst_ptr = dst.data() + o_offset; for (int i = 0; i < N; i++) { *dst_ptr = static_cast(*src_ptr); src_ptr += stride_src; dst_ptr += stride_dst; } } } template void copy_general_general( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, const std::vector& o_strides, stride_t i_offset, stride_t o_offset) { switch (src.ndim()) { case 1: copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 2: copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 3: copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 4: copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 5: copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; } int size = std::accumulate( data_shape.begin() - 5, data_shape.end(), 1, std::multiplies()); for (int i = 0; i < src.size(); i += size) { stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides); stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides); copy_general_general_dims( src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset); } } template inline void copy_general_general(const array& src, array& dst) { return copy_general_general( src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); } template void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (ctype) { case CopyType::Scalar: copy_single(src, dst); return; case CopyType::Vector: copy_vector(src, dst); return; case CopyType::General: copy_general(src, dst, std::forward(args)...); return; case CopyType::GeneralGeneral: copy_general_general(src, dst, std::forward(args)...); } } template void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (dst.dtype()) { case bool_: copy(src, dst, ctype, std::forward(args)...); break; case uint8: copy(src, dst, ctype, std::forward(args)...); break; case uint16: copy(src, dst, ctype, std::forward(args)...); break; case uint32: copy(src, dst, ctype, std::forward(args)...); break; case uint64: copy(src, dst, ctype, std::forward(args)...); break; case int8: copy(src, dst, ctype, std::forward(args)...); break; case int16: copy(src, dst, ctype, std::forward(args)...); break; case int32: copy(src, dst, ctype, std::forward(args)...); break; case int64: copy(src, dst, ctype, std::forward(args)...); break; case float16: copy(src, dst, ctype, std::forward(args)...); break; case float32: copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; case complex64: copy(src, dst, ctype, std::forward(args)...); break; } } template inline void copy_inplace_dispatch( const array& src, array& dst, CopyType ctype, Args&&... args) { switch (src.dtype()) { case bool_: copy(src, dst, ctype, std::forward(args)...); break; case uint8: copy(src, dst, ctype, std::forward(args)...); break; case uint16: copy(src, dst, ctype, std::forward(args)...); break; case uint32: copy(src, dst, ctype, std::forward(args)...); break; case uint64: copy(src, dst, ctype, std::forward(args)...); break; case int8: copy(src, dst, ctype, std::forward(args)...); break; case int16: copy(src, dst, ctype, std::forward(args)...); break; case int32: copy(src, dst, ctype, std::forward(args)...); break; case int64: copy(src, dst, ctype, std::forward(args)...); break; case float16: copy(src, dst, ctype, std::forward(args)...); break; case float32: copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; case complex64: copy(src, dst, ctype, std::forward(args)...); break; } } } // namespace void copy_inplace(const array& src, array& dst, CopyType ctype) { return copy_inplace_dispatch(src, dst, ctype); } void copy(const array& src, array& dst, CopyType ctype) { // Allocate the output switch (ctype) { case CopyType::Vector: if (src.is_donatable() && src.itemsize() == dst.itemsize()) { dst.copy_shared_buffer(src); } else { auto size = src.data_size(); dst.set_data( allocator::malloc_or_wait(size * dst.itemsize()), size, src.strides(), src.flags()); } break; case CopyType::Scalar: case CopyType::General: case CopyType::GeneralGeneral: dst.set_data(allocator::malloc_or_wait(dst.nbytes())); break; } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } copy_inplace(src, dst, ctype); } template void copy_inplace( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, const std::vector& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype) { switch (ctype) { case CopyType::General: case CopyType::GeneralGeneral: return copy_inplace_dispatch( src, dst, ctype, data_shape, i_strides, o_strides, i_offset, o_offset); case CopyType::Scalar: case CopyType::Vector: return copy_inplace_dispatch(src, dst, ctype); } } template <> void copy_inplace( const array& src, array& dst, const std::vector& data_shape, const std::vector& i_strides, const std::vector& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype) { switch (ctype) { case CopyType::General: case CopyType::GeneralGeneral: return copy_inplace_dispatch( src, dst, ctype, data_shape, i_strides, o_strides, i_offset, o_offset); case CopyType::Scalar: case CopyType::Vector: return copy_inplace_dispatch(src, dst, ctype); } } } // namespace mlx::core