diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index b57d254e8..2ba38c10d 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -156,6 +156,7 @@ void copy_gpu_inplace( array& out, CopyType ctype, const Stream& s) { + assert(in.shape() == out.shape()); return copy_gpu_inplace( in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); } @@ -167,6 +168,7 @@ void copy_gpu_inplace( int64_t ioffset, CopyType ctype, const Stream& s) { + assert(in.shape() == out.shape()); std::vector ostrides{out.strides().begin(), out.strides().end()}; return copy_gpu_inplace( in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s); diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 824492789..eecf1cef5 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -236,35 +236,21 @@ void multi_block_sort( } // Copy outputs with appropriate strides - array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out; - - if (axis == in.ndim() - 1) { - copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s); - } else { - std::vector strided_out_shape = in.shape(); - int out_axis_shape = strided_out_shape[axis]; - - strided_out_shape.erase(strided_out_shape.begin() + axis); - strided_out_shape.push_back(out_axis_shape); - - std::vector strided_out_str(in.ndim(), 1); - for (int i = in.ndim() - 2; i >= 0; --i) { - strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1]; - } - - strided_out_str.erase(strided_out_str.end() - 1); - strided_out_str.insert(strided_out_str.begin() + axis, 1); - - array strided_out_slice(in.shape(), out.dtype(), nullptr, {}); - strided_out_slice.copy_shared_buffer( - strided_out_arr, - strided_out_str, - strided_out_arr.flags(), - strided_out_arr.size(), - 0); - - copy_gpu_inplace(strided_out_slice, out, CopyType::General, s); + auto strides = out.strides(); + for (int ax = axis + 1; ax < strides.size(); ax++) { + strides[ax] *= out.shape(axis); } + strides[axis] = 1; + copy_gpu_inplace( + (argsort) ? dev_idxs_out : dev_vals_out, + out, + out.shape(), + strides, + out.strides(), + 0, + 0, + (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General, + s); // Clear copies d.get_command_buffer(s.index)->addCompletedHandler(