More tmp fft changes

commit be57a16a80
parent 1704809f29
Author: Angelos Katharopoulos
Date:   2025-04-30 22:29:22 -07:00

3 changed files with 99 additions and 5 deletions


@@ -1,5 +1,7 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/backend/common/utils.h"

namespace mlx::core {
@@ -28,4 +30,28 @@ void transpose(const array& in, array& out, const std::vector<int>& axes) {
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

void as_transposed(array& out, const std::vector<int>& axes) {
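  // View `out` in place as the permutation `axes` of a contiguous array: only
  // the strides and contiguity flags are rewritten, no data is moved.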
  assert(out.data_size() == out.size() && out.flags().contiguous);

  // Calculate the contiguous strides.
  Strides strides(out.ndim(), 1);
  for (int i = out.ndim() - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * out.shape(i);
  }

  // Calculate the new strides for transposing.
  Strides new_strides;
  new_strides.reserve(out.ndim());
  for (auto ax : axes) {
    new_strides.push_back(strides[ax]);
  }
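
  // Recompute the data size and the row/column contiguity flags for the
  // permuted strides.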
  auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
  auto flags = out.flags();
  flags.row_contiguous = rc;
  flags.col_contiguous = cc;
  out.copy_shared_buffer(out, new_strides, flags, ds);
}

} // namespace mlx::core


@@ -7,5 +7,6 @@
namespace mlx::core {

void transpose(const array& in, array& out, const std::vector<int>& axes);
void as_transposed(array& out, const std::vector<int>& axes);

} // namespace mlx::core


@@ -771,17 +771,84 @@ void fft_op(
  d.add_temporaries(std::move(copies), s.index);
}

inline array ensure_fastest_moving_axis(
    const array& x,
    int axis,
    metal::Device& d,
    const Stream& s) {
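  // Return an array equivalent to `x` in which `axis` has a stride of 1,
  // copying the data only when the current layout makes that necessary.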
  // If the axis already has a stride of 1, check that there are no overlaps
  // or broadcast dimensions so that we can avoid the copy.
  if (x.strides(axis) == 1) {
    // This is a fairly strict test; perhaps consider relaxing it in the
    // future.
    if (x.flags().row_contiguous || x.flags().col_contiguous) {
      return x;
    }
  }

  // To make `axis` the fastest moving axis, transpose it to the back, copy
  // the result, and then transpose it back into place.

  // Transpose it
  std::vector<int> axes(x.ndim(), 0);
  for (int ax = 0; ax < axes.size(); ax++) {
    axes[ax] = (ax < axis) ? ax : ax + 1;
  }
  axes.back() = axis;
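  // E.g. for x.ndim() == 4 and axis == 1 this gives axes == {0, 2, 3, 1},
  // i.e. `axis` is moved to the end.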

  Shape xtshape;
  xtshape.reserve(axes.size());
  for (auto ax : axes) {
    xtshape.push_back(x.shape(ax));
  }
  array xt(xtshape, x.dtype(), nullptr, {});
  transpose(x, xt, axes);

  // Copy it
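  // A row-contiguous source can be copied with the fast vector copy;
  // otherwise fall back to the general copy.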
  array xtc(xt.shape(), x.dtype(), nullptr, {});
  copy_gpu(
      xt,
      xtc,
      xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      s);
  d.add_temporary(xtc, s.index);

  // Transpose it back
  for (int ax = 0; ax < axes.size(); ax++) {
    axes[ax] = (ax < axis) ? ax : ((ax == axis) ? axes.size() - 1 : ax - 1);
  }
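  // Continuing the example, the inverse of {0, 2, 3, 1} is {0, 3, 1, 2},
  // which moves the last axis back to position `axis`.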
  array y(x.shape(), x.dtype(), nullptr, {});
  transpose(xtc, y, axes);

  return y;
}

inline void prepare_output_array(const array& in, array& out, int axis) {
  // Prepare the output array so that it matches the input in terms of stride
  // ordering. Namely, we may have moved `axis` around in the `in` array and
  // we must do the same in `out`. The difference is that we don't have to
  // copy anything, because `out` contains garbage at the moment.
  if (in.flags().row_contiguous && out.flags().row_contiguous) {
    return;
  }

  std::vector<int> axes(out.ndim(), 0);
  for (int ax = 0; ax < axes.size(); ax++) {
    axes[ax] = (ax < axis) ? ax : ax + 1;
  }
  axes.back() = axis;
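  // Same axis-to-the-back permutation as in ensure_fastest_moving_axis, but
  // applied to `out` purely by rewriting its strides.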
  as_transposed(out, axes);
}

void fft_stockham_inplace(
    const array& in_,
    array& out,
    size_t axis,
    bool inverse,
    bool real,
    metal::Device& d,
    const Stream& s) {
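  // Make `axis` the fastest moving axis of the input and give the output a
  // matching stride ordering.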
  array in = ensure_fastest_moving_axis(in_, axis, d, s);
  prepare_output_array(in, out, axis);
}

void fft_op_inplace(
@@ -790,7 +857,7 @@ void fft_op_inplace(
    size_t axis,
    bool inverse,
    bool real,
    metal::Device& d,
    const Stream& s) {
  // Get the FFT size and plan it
  size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
@@ -811,7 +878,7 @@ void nd_fft_op_inplace(
    const std::vector<size_t>& axes,
    bool inverse,
    bool real,
    metal::Device& d,
    const Stream& s) {
  // We are going to make, and possibly reuse, some intermediate arrays to
  // hold the intermediate FFT results.