mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
More tmp fft changes
This commit is contained in:
parent
1704809f29
commit
be57a16a80
@ -1,5 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -28,4 +30,28 @@ void transpose(const array& in, array& out, const std::vector<int>& axes) {
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
// Re-stride `out` in place according to the axis permutation `axes`, without
// touching the underlying data.
//
// Preconditions (checked with assert in debug builds):
//   - `out` is contiguous and has no padding (data_size() == size()).
//   - `axes` has exactly one entry per dimension of `out`.
//
// The new stride of dimension `i` is the row-major contiguous stride of
// dimension `axes[i]`.
//
// NOTE(review): only the strides are permuted; the shape of `out` is left as
// is and check_contiguity is evaluated against that unpermuted shape. Confirm
// this is what the callers expect.
void as_transposed(array& out, const std::vector<int>& axes) {
  assert(out.data_size() == out.size() && out.flags().contiguous);
  assert(axes.size() == out.ndim());

  // Row-major contiguous strides: the innermost stride is 1 and each outer
  // stride is the next-inner stride times the extent of that next-inner
  // dimension (shape(i + 1), not shape(i)).
  Strides strides(out.ndim(), 1);
  for (int i = static_cast<int>(out.ndim()) - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * out.shape(i + 1);
  }

  // Permute the contiguous strides to obtain the transposed strides.
  Strides new_strides;
  new_strides.reserve(out.ndim());
  for (auto ax : axes) {
    new_strides.push_back(strides[ax]);
  }

  // Refresh the contiguity flags and data size for the new strides.
  auto [ds, rc, cc] = check_contiguity(out.shape(), new_strides);
  auto flags = out.flags();
  flags.row_contiguous = rc;
  flags.col_contiguous = cc;

  // `out` is passed as its own source: it keeps its buffer and only receives
  // the new strides, flags and data size.
  out.copy_shared_buffer(out, new_strides, flags, ds);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -7,5 +7,6 @@
|
||||
namespace mlx::core {
|
||||
|
||||
// Set up `out` as `in` with its axes permuted by `axes`. The visible tail of
// the definition finishes with out.copy_shared_buffer(in, ...), i.e. `out`
// shares `in`'s buffer with new strides instead of receiving a copy.
void transpose(const array& in, array& out, const std::vector<int>& axes);
|
||||
// Re-stride `out` itself according to the permutation `axes`, keeping its
// buffer. The definition asserts that `out` is contiguous with
// data_size() == size().
void as_transposed(array& out, const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -771,17 +771,84 @@ void fft_op(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
inline array prepare_input(
|
||||
// Return an array with the same shape as `x` in which `axis` has a stride
// of 1.
//
// If `x` already has stride 1 along `axis` and is row- or column-contiguous
// it is returned unchanged. Otherwise `axis` is moved to the last position,
// the result is copied on the GPU, and the copy is transposed back so that
// the returned array has the shape of `x` again.
//
// `axis` must be a valid, non-negative axis of `x`.
inline array ensure_fastest_moving_axis(
    const array& x,
    int axis,
    metal::Device& d,
    const Stream& s) {
  // The axis already has a stride of 1: if there can also be no overlaps or
  // broadcasting, skip the copy. This is a fairly strict test; consider
  // relaxing it in the future.
  if (x.strides(axis) == 1 &&
      (x.flags().row_contiguous || x.flags().col_contiguous)) {
    return x;
  }

  const int ndim = static_cast<int>(x.ndim());

  // Forward permutation: keep every other axis in order and move `axis` to
  // the end, e.g. ndim = 4, axis = 1 gives {0, 2, 3, 1}.
  std::vector<int> axes(ndim);
  for (int ax = 0; ax + 1 < ndim; ax++) {
    axes[ax] = (ax < axis) ? ax : ax + 1;
  }
  axes.back() = axis;

  Shape xtshape;
  xtshape.reserve(axes.size());
  for (auto ax : axes) {
    xtshape.push_back(x.shape(ax));
  }

  // Transposed array over `x` with `axis` as the last dimension.
  array xt(xtshape, x.dtype(), nullptr, {});
  transpose(x, xt, axes);

  // Copy the transposed array into a fresh array.
  array xtc(xt.shape(), x.dtype(), nullptr, {});
  copy_gpu(
      xt,
      xtc,
      xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      s);
  // NOTE(review): presumably keeps `xtc` alive until the work queued on the
  // stream has finished; confirm against metal::Device::add_temporary.
  d.add_temporary(xtc, s.index);

  // Inverse permutation: axes before `axis` stay put, `axis` comes from the
  // last position and the axes after it shift down by one, e.g. ndim = 4,
  // axis = 1 gives {0, 3, 1, 2}.
  std::vector<int> inverse_axes(ndim);
  for (int ax = 0; ax < ndim; ax++) {
    if (ax < axis) {
      inverse_axes[ax] = ax;
    } else if (ax == axis) {
      inverse_axes[ax] = ndim - 1;
    } else {
      inverse_axes[ax] = ax - 1;
    }
  }

  // Transpose the copy back to the original axis order.
  array y(x.shape(), x.dtype(), nullptr, {});
  transpose(xtc, y, inverse_axes);

  return y;
}
|
||||
|
||||
inline void prepare_output_array(const array& in, array& out, int axis) {
|
||||
// Prepare the output array such that it matches the input in terms of
|
||||
// stride ordering. Namely we might have moved `axis` around in the `in`
|
||||
// array. We must do the same in `out`. The difference is that we don't have
|
||||
// to copy anything because `out` contains garbage at the moment.
|
||||
|
||||
if (in.flags().row_contiguous && out.flags().row_contiguous) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> axes(out.ndim(), 0);
|
||||
for (int ax = 0; ax < axes.size(); ax++) {
|
||||
axes[ax] = (ax < axis) ? ax : ax + 1;
|
||||
}
|
||||
axes.back() = axis;
|
||||
as_transposed(out, axes);
|
||||
}
|
||||
|
||||
void fft_stockham_inplace(
|
||||
const array& in,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
}
|
||||
|
||||
void fft_op_inplace(
|
||||
@ -790,7 +857,7 @@ void fft_op_inplace(
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device &d,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Get the FFT size and plan it
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
@ -811,7 +878,7 @@ void nd_fft_op_inplace(
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device &d,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// We are going to make and possibly reuse some intermediate arrays that will
|
||||
// hold the intermediate fft results.
|
||||
|
Loading…
Reference in New Issue
Block a user