diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index f2243f60f..089f7c425 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -334,7 +334,9 @@ void Sort::eval_cpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; // Copy input to output - CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0) + ? CopyType::Vector + : CopyType::General; copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); @@ -426,7 +428,9 @@ void Partition::eval_cpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; // Copy input to output - CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0) + ? CopyType::Vector + : CopyType::General; copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream());