diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 0bdc09ce9..ef00721aa 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -527,8 +527,8 @@ void RandomBits::eval(const std::vector& inputs, array& out) { std::pair> Reshape::prepare_reshape( const array& in, const array& out) { - // Special case for empty arrays - if (in.size() == 0) { + // Special case for empty arrays or row contiguous arrays + if (in.size() == 0 || in.flags().row_contiguous) { return {false, out.strides()}; } @@ -570,18 +570,13 @@ void Reshape::shared_buffer_reshape( const std::vector& out_strides, array& out) { auto flags = in.flags(); - if (flags.contiguous && in.data_size() == in.size()) { - size_t f_stride = 1; - size_t b_stride = 1; - flags.col_contiguous = true; - flags.row_contiguous = true; - for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) { - flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1); - f_stride *= out.shape(i); - flags.row_contiguous &= - (out_strides[ri] == b_stride || out.shape(ri) == 1); - b_stride *= out.shape(ri); - } + if (flags.row_contiguous) { + // For row contiguous reshapes: + // - Shallow copy the buffer + // - If reshaping into a vector (all singleton dimensions except one) it + // becomes col contiguous again. + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; } out.copy_shared_buffer(in, out_strides, flags, in.data_size()); }