// Copyright © 2023 Apple Inc. #include #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" namespace mlx::core { namespace { template void copy_single(const array& src, array& dst) { auto val = static_cast(src.data()[0]); auto dst_ptr = dst.data(); for (int i = 0; i < dst.size(); ++i) { dst_ptr[i] = val; } } template void copy_vector(const array& src, array& dst) { auto src_ptr = src.data(); auto dst_ptr = dst.data(); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); } template void copy_general_dim1(const array& src, array& dst) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); size_t src_idx = 0; size_t dst_idx = 0; for (size_t i = 0; i < src.shape()[0]; ++i) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += src.strides()[0]; } } template void copy_general_dim2(const array& src, array& dst) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); size_t src_idx = 0; size_t dst_idx = 0; for (size_t i = 0; i < src.shape()[0]; ++i) { for (size_t j = 0; j < src.shape()[1]; ++j) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += src.strides()[1]; } src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; } } template void copy_general_dim3(const array& src, array& dst) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); size_t src_idx = 0; size_t dst_idx = 0; for (size_t i = 0; i < src.shape()[0]; ++i) { for (size_t j = 0; j < src.shape()[1]; ++j) { for (size_t k = 0; k < src.shape()[2]; ++k) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += src.strides()[2]; } src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; } src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; } } template void copy_general_dim4(const array& src, array& dst) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); size_t src_idx = 0; size_t dst_idx = 0; for (size_t i = 0; i < src.shape()[0]; ++i) { for (size_t j = 0; j < src.shape()[1]; ++j) { for (size_t k = 0; k < src.shape()[2]; ++k) { for (size_t ii = 0; ii < src.shape()[3]; ++ii) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); src_idx += src.strides()[3]; } src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3]; } src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; } src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; } } template void copy_general(const array& src, array& dst) { switch (src.ndim()) { case 1: copy_general_dim1(src, dst); return; case 2: copy_general_dim2(src, dst); return; case 3: copy_general_dim3(src, dst); return; case 4: copy_general_dim4(src, dst); return; } auto src_ptr = src.data(); auto dst_ptr = dst.data(); for (size_t i = 0; i < dst.size(); ++i) { size_t src_elem = elem_to_loc(i, src.shape(), src.strides()); dst_ptr[i] = static_cast(src_ptr[src_elem]); } } template inline void copy_general_general_dims( const array& src, array& dst, size_t offset_src, size_t offset_dst) { if constexpr (D > 1) { int axis = src.ndim() - D; auto stride_src = src.strides()[axis]; auto stride_dst = dst.strides()[axis]; auto N = src.shape(axis); for (int i = 0; i < N; i++) { copy_general_general_dims( src, dst, offset_src, offset_dst); offset_src += stride_src; offset_dst += stride_dst; } } else { int axis = src.ndim() - 1; auto stride_src = src.strides()[axis]; auto stride_dst = dst.strides()[axis]; auto N = src.shape(axis); const SrcT* src_ptr = src.data() + offset_src; DstT* dst_ptr = dst.data() + offset_dst; for (int i = 0; i < N; i++) { *dst_ptr = static_cast(*src_ptr); src_ptr += stride_src; dst_ptr += stride_dst; } } } template void copy_general_general(const array& src, array& dst) { switch (src.ndim()) { case 1: copy_general_general_dims(src, dst, 0, 0); return; case 2: copy_general_general_dims(src, dst, 0, 0); return; case 3: copy_general_general_dims(src, dst, 0, 0); return; case 4: copy_general_general_dims(src, dst, 0, 0); return; case 5: copy_general_general_dims(src, dst, 0, 0); return; } int size = std::accumulate( src.shape().begin() - 5, src.shape().end(), 1, std::multiplies()); for (int i = 0; i < src.size(); i += size) { size_t offset_src = elem_to_loc(i, src.shape(), src.strides()); size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides()); copy_general_general_dims(src, dst, offset_src, offset_dst); } } template void copy(const array& src, array& dst, CopyType ctype) { switch (ctype) { case CopyType::Scalar: copy_single(src, dst); return; case CopyType::Vector: copy_vector(src, dst); return; case CopyType::General: copy_general(src, dst); return; case CopyType::GeneralGeneral: copy_general_general(src, dst); } } template void copy(const array& src, array& dst, CopyType ctype) { switch (dst.dtype()) { case bool_: copy(src, dst, ctype); break; case uint8: copy(src, dst, ctype); break; case uint16: copy(src, dst, ctype); break; case uint32: copy(src, dst, ctype); break; case uint64: copy(src, dst, ctype); break; case int8: copy(src, dst, ctype); break; case int16: copy(src, dst, ctype); break; case int32: copy(src, dst, ctype); break; case int64: copy(src, dst, ctype); break; case float16: copy(src, dst, ctype); break; case float32: copy(src, dst, ctype); break; case bfloat16: copy(src, dst, ctype); break; case complex64: copy(src, dst, ctype); break; } } } // namespace void copy_inplace(const array& src, array& dst, CopyType ctype) { switch (src.dtype()) { case bool_: copy(src, dst, ctype); break; case uint8: copy(src, dst, ctype); break; case uint16: copy(src, dst, ctype); break; case uint32: copy(src, dst, ctype); break; case uint64: copy(src, dst, ctype); break; case int8: copy(src, dst, ctype); break; case int16: copy(src, dst, ctype); break; case int32: copy(src, dst, ctype); break; case int64: copy(src, dst, ctype); break; case float16: copy(src, dst, ctype); break; case float32: copy(src, dst, ctype); break; case bfloat16: copy(src, dst, ctype); break; case complex64: copy(src, dst, ctype); break; } } void copy(const array& src, array& dst, CopyType ctype) { // Allocate the output switch (ctype) { case CopyType::Vector: dst.set_data( allocator::malloc_or_wait(src.data_size() * dst.itemsize()), src.data_size(), src.strides(), src.flags()); break; case CopyType::Scalar: case CopyType::General: case CopyType::GeneralGeneral: dst.set_data(allocator::malloc_or_wait(dst.nbytes())); break; } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } copy_inplace(src, dst, ctype); } } // namespace mlx::core