diff --git a/mlx/backend/common/transpose.cpp b/mlx/backend/common/transpose.cpp
index d0c79011e..f6ddd7289 100644
--- a/mlx/backend/common/transpose.cpp
+++ b/mlx/backend/common/transpose.cpp
@@ -1,5 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
+#include <cassert>
+
 #include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
@@ -28,4 +30,28 @@ void transpose(const array& in, array& out, const std::vector<int>& axes) {
   out.copy_shared_buffer(in, out_strides, flags, in.data_size());
 }
 
+void as_transposed(array& out, const std::vector<int>& axes) {
+  assert(out.data_size() == out.size() && out.flags().contiguous);
+
+  // Calculate the contiguous strides.
+  Strides strides(out.ndim(), 1);
+  for (int i = out.ndim() - 2; i >= 0; i--) {
+    strides[i] = strides[i + 1] * out.shape(i + 1);
+  }
+
+  // Calculate the new strides for transposing.
+  Strides new_strides;
+  new_strides.reserve(out.ndim());
+  for (auto ax : axes) {
+    new_strides.push_back(strides[ax]);
+  }
+
+  auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
+  auto flags = out.flags();
+  flags.row_contiguous = rc;
+  flags.col_contiguous = cc;
+
+  out.copy_shared_buffer(out, new_strides, flags, ds);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/transpose.h b/mlx/backend/common/transpose.h
index 378e80a2a..3ee9758b5 100644
--- a/mlx/backend/common/transpose.h
+++ b/mlx/backend/common/transpose.h
@@ -7,5 +7,6 @@ namespace mlx::core {
 
 void transpose(const array& in, array& out, const std::vector<int>& axes);
+void as_transposed(array& out, const std::vector<int>& axes);
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 60e256174..a34da87ff 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -771,17 +771,84 @@ void fft_op(
   d.add_temporaries(std::move(copies), s.index);
 }
 
-inline array prepare_input(
+inline array ensure_fastest_moving_axis(
+    const array& x,
+    int axis,
+    metal::Device& d,
+    const Stream& s) {
+  // The axis already has a stride of 1, so check that there is no overlap
+  // or broadcasting and avoid the copy.
+  if (x.strides(axis) == 1) {
+    // This is a fairly strict test; consider relaxing it in the future.
+    if (x.flags().row_contiguous || x.flags().col_contiguous) {
+      return x;
+    }
+  }
+
+  // To make it the fastest moving axis, transpose it, then copy it, then
+  // transpose it back.
+
+  // Transpose it
+  std::vector<int> axes(x.ndim(), 0);
+  for (int ax = 0; ax < axes.size(); ax++) {
+    axes[ax] = (ax < axis) ? ax : ax + 1;
+  }
+  axes.back() = axis;
+  Shape xtshape;
+  xtshape.reserve(axes.size());
+  for (auto ax : axes) {
+    xtshape.push_back(x.shape(ax));
+  }
+  array xt(xtshape, x.dtype(), nullptr, {});
+  transpose(x, xt, axes);
+
+  // Copy it
+  array xtc(xt.shape(), x.dtype(), nullptr, {});
+  copy_gpu(
+      xt,
+      xtc,
+      xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      s);
+  d.add_temporary(xtc, s.index);
+
+  // Transpose it back
+  for (int ax = 0; ax < axes.size(); ax++) {
+    axes[ax] = (ax < axis) ? ax : ((ax == axis) ? axes.size() - 1 : ax - 1);
+  }
+  array y(x.shape(), x.dtype(), nullptr, {});
+  transpose(xtc, y, axes);
+
+  return y;
+}
+
+inline void prepare_output_array(const array& in, array& out, int axis) {
+  // Prepare the output array so that it matches the input in terms of
+  // stride ordering. Namely, we might have moved `axis` around in the `in`
+  // array, and we must do the same in `out`. The difference is that we
+  // don't have to copy anything, because `out` contains garbage at the
+  // moment.
+
+  if (in.flags().row_contiguous && out.flags().row_contiguous) {
+    return;
+  }
+
+  std::vector<int> axes(out.ndim(), 0);
+  for (int ax = 0; ax < axes.size(); ax++) {
+    axes[ax] = (ax < axis) ? ax : ax + 1;
+  }
+  axes.back() = axis;
+  as_transposed(out, axes);
+}
 
 void fft_stockham_inplace(
-    const array& in,
+    const array& in_,
     array& out,
     size_t axis,
     bool inverse,
     bool real,
     metal::Device& d,
     const Stream& s) {
-
+  array in = ensure_fastest_moving_axis(in_, axis, d, s);
+  prepare_output_array(in, out, axis);
 }
 
 void fft_op_inplace(
@@ -790,7 +857,7 @@ void fft_op_inplace(
     size_t axis,
     bool inverse,
     bool real,
-    metal::Device &d,
+    metal::Device& d,
     const Stream& s) {
   // Get the FFT size and plan it
   size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
@@ -811,7 +878,7 @@ void nd_fft_op_inplace(
     const std::vector<size_t>& axes,
     bool inverse,
     bool real,
-    metal::Device &d,
+    metal::Device& d,
     const Stream& s) {
   // We are going to make and possibly reuse some intermediate arrays that will
   // hold the intermediate fft results.
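
As a sanity check on the stride math in as_transposed, here is a minimal standalone sketch of the same idea (plain C++ with std::vector standing in for MLX's Shape/Strides types; no MLX dependency, all names local to the sketch): permuting the contiguous strides reinterprets the untouched buffer as its transpose, with no data movement.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A 2x3 row-major buffer: element (i, j) lives at i * 3 + j.
  std::vector<float> buf = {0, 1, 2, 3, 4, 5};
  std::vector<int64_t> shape = {2, 3};

  // Contiguous strides, computed as in as_transposed:
  // strides[ndim - 1] = 1 and strides[i] = strides[i + 1] * shape[i + 1].
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }

  // Permute shape and strides with axes = {1, 0}: the same buffer is now
  // viewed as its 3x2 transpose.
  std::vector<int> axes = {1, 0};
  std::vector<int64_t> tshape, tstrides;
  for (auto ax : axes) {
    tshape.push_back(shape[ax]);
    tstrides.push_back(strides[ax]);
  }

  // Element (i, j) of the view is buf[i * tstrides[0] + j * tstrides[1]].
  for (int64_t i = 0; i < tshape[0]; i++) {
    for (int64_t j = 0; j < tshape[1]; j++) {
      std::cout << buf[i * tstrides[0] + j * tstrides[1]] << " ";
    }
    std::cout << "\n"; // prints: 0 3 / 1 4 / 2 5
  }
  return 0;
}

This only works when the buffer is dense and unshared, which is exactly what the assert at the top of as_transposed enforces before the strides are rewritten in place.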
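Similarly, a small standalone sketch of the permutation bookkeeping in ensure_fastest_moving_axis, with hypothetical example values for ndim and axis: the first loop builds the permutation that moves axis to the last position, and the second loop builds its exact inverse, so transpose / copy / transpose-back restores the original axis order.

#include <cassert>
#include <vector>

int main() {
  const int ndim = 4; // example rank
  const int axis = 1; // example FFT axis

  // Forward permutation: move `axis` to the last position, as in the diff.
  std::vector<int> axes(ndim, 0);
  for (int ax = 0; ax < ndim; ax++) {
    axes[ax] = (ax < axis) ? ax : ax + 1;
  }
  axes.back() = axis; // axes == {0, 2, 3, 1}

  // Inverse permutation: send the last position back to `axis`.
  std::vector<int> inv(ndim, 0);
  for (int ax = 0; ax < ndim; ax++) {
    inv[ax] = (ax < axis) ? ax : ((ax == axis) ? ndim - 1 : ax - 1);
  } // inv == {0, 3, 1, 2}

  // Composing the two permutations is the identity.
  for (int ax = 0; ax < ndim; ax++) {
    assert(axes[inv[ax]] == ax);
  }
  return 0;
}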