diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp
index 4f9c4ea7b..34b84258b 100644
--- a/mlx/backend/common/copy.cpp
+++ b/mlx/backend/common/copy.cpp
@@ -4,6 +4,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 
@@ -142,29 +143,31 @@ void copy_general(
     const std::vector<int>& data_shape,
     const std::vector<stride_t>& i_strides,
     int64_t i_offset) {
-  switch (src.ndim()) {
+  auto [new_shape, new_strides] = collapse_contiguous_dims(
+      data_shape, std::vector<std::vector<stride_t>>{i_strides});
+  switch (new_shape.size()) {
     case 1:
       copy_general_dim1<SrcT, DstT, stride_t>(
-          src, dst, data_shape, i_strides, i_offset);
+          src, dst, new_shape, new_strides[0], i_offset);
       return;
     case 2:
       copy_general_dim2<SrcT, DstT, stride_t>(
-          src, dst, data_shape, i_strides, i_offset);
+          src, dst, new_shape, new_strides[0], i_offset);
       return;
     case 3:
       copy_general_dim3<SrcT, DstT, stride_t>(
-          src, dst, data_shape, i_strides, i_offset);
+          src, dst, new_shape, new_strides[0], i_offset);
       return;
     case 4:
       copy_general_dim4<SrcT, DstT, stride_t>(
-          src, dst, data_shape, i_strides, i_offset);
+          src, dst, new_shape, new_strides[0], i_offset);
       return;
   }
 
   auto src_ptr = src.data<SrcT>() + i_offset;
   auto dst_ptr = dst.data<DstT>();
   for (size_t i = 0; i < dst.size(); ++i) {
-    stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
+    stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
     dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
   }
 }
@@ -195,10 +198,10 @@ inline void copy_general_general_dims(
     const std::vector<int>& data_shape,
     const std::vector<stride_t>& i_strides,
     const std::vector<stride_t>& o_strides,
-    stride_t i_offset,
-    stride_t o_offset) {
+    int64_t i_offset,
+    int64_t o_offset) {
   if constexpr (D > 1) {
-    int axis = src.ndim() - D;
+    int axis = data_shape.size() - D;
     auto stride_src = i_strides[axis];
     auto stride_dst = o_strides[axis];
     auto N = data_shape[axis];
@@ -209,7 +212,7 @@ inline void copy_general_general_dims(
       o_offset += stride_dst;
     }
   } else {
-    int axis = src.ndim() - 1;
+    int axis = data_shape.size() - 1;
     auto stride_src = i_strides[axis];
     auto stride_dst = o_strides[axis];
     auto N = data_shape[axis];
@@ -230,38 +233,76 @@ void copy_general_general(
     const std::vector<int>& data_shape,
     const std::vector<stride_t>& i_strides,
     const std::vector<stride_t>& o_strides,
-    stride_t i_offset,
-    stride_t o_offset) {
-  switch (src.ndim()) {
+    int64_t i_offset,
+    int64_t o_offset) {
+  auto [new_shape, new_strides] = collapse_contiguous_dims(
+      data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
+  switch (new_shape.size()) {
     case 1:
       copy_general_general_dims<SrcT, DstT, stride_t, 1>(
-          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+          src,
+          dst,
+          new_shape,
+          new_strides[0],
+          new_strides[1],
+          i_offset,
+          o_offset);
       return;
     case 2:
       copy_general_general_dims<SrcT, DstT, stride_t, 2>(
-          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+          src,
+          dst,
+          new_shape,
+          new_strides[0],
+          new_strides[1],
+          i_offset,
+          o_offset);
       return;
     case 3:
       copy_general_general_dims<SrcT, DstT, stride_t, 3>(
-          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+          src,
+          dst,
+          new_shape,
+          new_strides[0],
+          new_strides[1],
+          i_offset,
+          o_offset);
       return;
     case 4:
       copy_general_general_dims<SrcT, DstT, stride_t, 4>(
-          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+          src,
+          dst,
+          new_shape,
+          new_strides[0],
+          new_strides[1],
+          i_offset,
+          o_offset);
       return;
     case 5:
       copy_general_general_dims<SrcT, DstT, stride_t, 5>(
-          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+          src,
+          dst,
+          new_shape,
+          new_strides[0],
+          new_strides[1],
+          i_offset,
+          o_offset);
       return;
   }
 
   int size = std::accumulate(
-      data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
+      new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
   for (int i = 0; i < src.size(); i += size) {
-    stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
-    stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
+    stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
+    stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
     copy_general_general_dims<SrcT, DstT, stride_t, 5>(
-        src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
+        src,
+        dst,
+        new_shape,
+        new_strides[0],
+        new_strides[1],
+        src_offset,
+        dst_offset);
   }
 }
@@ -444,8 +485,17 @@ void copy_inplace(
   }
 }
 
-template <>
-void copy_inplace<int64_t>(
+template void copy_inplace<size_t>(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<size_t>& i_strides,
+    const std::vector<size_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype);
+
+template void copy_inplace<int64_t>(
     const array& src,
     array& dst,
     const std::vector<int>& data_shape,
@@ -453,24 +503,6 @@ void copy_inplace(
     const std::vector<int64_t>& o_strides,
     int64_t i_offset,
     int64_t o_offset,
-    CopyType ctype) {
-  switch (ctype) {
-    case CopyType::General:
-    case CopyType::GeneralGeneral:
-      return copy_inplace_dispatch(
-          src,
-          dst,
-          ctype,
-          data_shape,
-          i_strides,
-          o_strides,
-          i_offset,
-          o_offset);
-
-    case CopyType::Scalar:
-    case CopyType::Vector:
-      return copy_inplace_dispatch(src, dst, ctype);
-  }
-}
+    CopyType ctype);
 
 } // namespace mlx::core
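The copy.cpp changes all hinge on collapsing contiguous dimensions before dispatching, so that more layouts hit the fixed-rank kernels (rank 1 through 5) instead of the per-element elem_to_loc fallback. Below is a minimal single-stride-set sketch of that idea; the real collapse_contiguous_dims in backend/common/utils.h handles several stride vectors at once, and collapse_dims_sketch is a hypothetical name, not part of the patch.

#include <cstddef>
#include <utility>
#include <vector>

// Single-stride-set sketch of contiguous-dim collapsing. Two adjacent
// axes merge when they are laid out back to back in memory, i.e. when
// the outer stride equals (inner stride) * (inner extent).
std::pair<std::vector<int>, std::vector<size_t>> collapse_dims_sketch(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::vector<int> out_shape;
  std::vector<size_t> out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() &&
        out_strides.back() == strides[i] * static_cast<size_t>(shape[i])) {
      // Fold axis i into the previous axis.
      out_shape.back() *= shape[i];
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

// The transpose example from the utils.h comment: shape {2, 2, 2} with
// strides {1, 4, 2} collapses to shape {2, 4} with strides {1, 2}, so
// copy_general can use the rank-2 kernel instead of the generic fallback.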
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index a01717e73..99adfaed4 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -405,7 +405,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
 
   if (copy_necessary) {
-    copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    auto out_strides = make_contiguous_strides<size_t>(in.shape());
+    copy_inplace<size_t>(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        out_strides,
+        0,
+        0,
+        CopyType::General);
   } else {
     shared_buffer_reshape(in, out_strides, out);
   }
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index a1a1207c7..f3082ef74 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
   return elem_to_loc(elem, a.shape(), a.strides());
 }
 
+template <typename stride_t>
+std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
+  std::vector<stride_t> strides(shape.size(), 1);
+  for (int i = shape.size() - 1; i > 0; i--) {
+    strides[i - 1] = strides[i] * shape[i];
+  }
+  return strides;
+}
+
 // Collapse dims that are contiguous to possibly route to a better kernel
 // e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
 // should return {{2, 4}, {{1, 2}}}.
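As a quick sanity check of the new helper, the snippet below copies make_contiguous_strides out of the hunk above into a hypothetical standalone harness. The explicit template argument is required: stride_t appears only in the return type, so it cannot be deduced from the shape parameter.

#include <cassert>
#include <cstddef>
#include <vector>

template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
  // Row-major: the last axis has stride 1, and each earlier axis steps
  // over the product of all later extents.
  std::vector<stride_t> strides(shape.size(), 1);
  for (int i = shape.size() - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}

int main() {
  // Shape {2, 3, 4} -> strides {12, 4, 1}.
  auto s = make_contiguous_strides<size_t>({2, 3, 4});
  assert((s == std::vector<size_t>{12, 4, 1}));
  return 0;
}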
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 772b8bc17..64896c7e1 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -273,7 +273,18 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
 
   if (copy_necessary) {
-    copy_gpu(in, out, CopyType::General);
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    auto out_strides = make_contiguous_strides<size_t>(in.shape());
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        out_strides,
+        0,
+        0,
+        CopyType::General,
+        stream());
   } else {
     shared_buffer_reshape(in, out_strides, out);
   }
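Both Reshape::eval and Reshape::eval_gpu now follow the same pattern: allocate the output, build contiguous strides for it, and run a general strided copy from the input's (possibly non-contiguous) layout. The sketch below shows what such a copy amounts to, with elem_to_loc mirroring the helper in backend/common/utils.h; it is an illustration of the technique, not the actual CPU or Metal kernel, and copy_strided is a hypothetical name.

#include <cstddef>
#include <vector>

// Map a flat element index to a storage offset under the given
// shape/strides, mirroring elem_to_loc in backend/common/utils.h.
inline size_t elem_to_loc(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

// General strided copy: both source and destination may be
// non-contiguous; `size` is the total number of logical elements.
template <typename T>
void copy_strided(
    const T* src,
    T* dst,
    const std::vector<int>& shape,
    const std::vector<size_t>& i_strides,
    const std::vector<size_t>& o_strides,
    size_t size) {
  for (size_t i = 0; i < size; ++i) {
    dst[elem_to_loc(i, shape, o_strides)] =
        src[elem_to_loc(i, shape, i_strides)];
  }
}

Collapsing dims first, as copy.cpp now does, keeps most real workloads out of this slow path by reducing the effective rank to five or fewer.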