@@ -1,11 +1,13 @@
 // Copyright © 2024 Apple Inc.
 #include <cassert>
 #include <complex>
+#include <iostream>
 #include <map>
 #include <numeric>
 #include <set>

 #include "mlx/3rdparty/pocketfft.h"
+#include "mlx/backend/common/transpose.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
@@ -27,7 +29,7 @@ using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
 // For strided reads/writes, coalesce at least this many complex64s
 #define MIN_COALESCE_WIDTH 4

-inline const std::vector<int> supported_radices() {
+inline constexpr std::array<int, 9> supported_radices() {
   // Ordered by preference in decomposition.
   return {13, 11, 8, 7, 6, 5, 4, 3, 2};
 }
@@ -49,6 +51,35 @@ std::vector<int> prime_factors(int n) {
   return factors;
 }

+int next_fast_n(int n) {
+  return next_power_of_2(n);
+}
+
+std::vector<int> stockham_decompose(int n) {
+  auto radices = supported_radices();
+  std::vector<int> steps(radices.size(), 0);
+  int orig_n = n;
+
+  for (int i = 0; i < radices.size(); i++) {
+    int radix = radices[i];
+
+    // Manually tuned radices for powers of 2
+    if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
+      continue;
+    }
+
+    while (n % radix == 0) {
+      steps[i] += 1;
+      n /= radix;
+      if (n == 1) {
+        return steps;
+      }
+    }
+  }
+
+  return {};
+}
+
 struct FourStepParams {
   bool required = false;
   bool first_step = true;
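For intuition: `stockham_decompose` greedily divides out radices in the preference order above, so a non-empty result is a complete mixed-radix factorization of `n`, and an empty vector signals that `n` contains a prime factor outside the supported set. A minimal standalone sketch of the same idea (radix table inlined, power-of-2 tuning omitted, illustration only):

```cpp
#include <array>
#include <cstdio>
#include <vector>

// Hypothetical standalone mirror of stockham_decompose, for illustration.
std::vector<int> decompose(int n) {
  constexpr std::array<int, 9> radices{13, 11, 8, 7, 6, 5, 4, 3, 2};
  std::vector<int> steps(radices.size(), 0);
  for (size_t i = 0; i < radices.size(); i++) {
    while (n % radices[i] == 0) {
      steps[i]++;
      n /= radices[i];
      if (n == 1) {
        return steps; // fully factored over the supported radices
      }
    }
  }
  return {}; // n has an unsupported prime factor
}

int main() {
  // 120 = 8 * 5 * 3 -> one step each of radix 8, 5, and 3.
  for (int s : decompose(120)) {
    std::printf("%d ", s);
  }
  // prints: 0 0 1 0 0 1 0 1 0
}
```

The per-radix step counts are kept positional rather than as a list of factors because they later become Metal function constants, one per radix slot.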
@@ -65,9 +96,10 @@ void fft_op(
     bool real,
     const FourStepParams four_step_params,
     bool inplace,
+    metal::Device& d,
     const Stream& s);

-struct FFTPlan {
+struct OldFFTPlan {
   int n = 0;
   // Number of steps for each radix in the Stockham decomposition
   std::vector<int> stockham;
@@ -82,9 +114,104 @@ struct FFTPlan {
   int n2 = 0;
 };

-int next_fast_n(int n) {
-  return next_power_of_2(n);
-}
+class FFTPlan {
+ public:
+  enum FFTType {
+    UNSUPPORTED,
+    NOOP,
+    STOCKHAM,
+    RADER,
+    BLUESTEIN,
+    MULTIUPLOAD_BLUESTEIN,
+    SMALL_FOUR_STEP,
+    LARGE_FOUR_STEP
+  };
+
+  FFTPlan(int n) : n_(n) {
+    // NOOP
+    if (n == 1) {
+      type_ = NOOP;
+    }
+
+    // Too large for Stockham so do four step fft for powers of 2
+    else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
+      if (n <= 1 << 20) {
+        type_ = SMALL_FOUR_STEP;
+        n2_ = n > 65536 ? 1024 : 64;
+        n1_ = n / n2_;
+        steps1_ = stockham_decompose(n1_);
+        steps2_ = stockham_decompose(n2_);
+      } else {
+        type_ = LARGE_FOUR_STEP;
+      }
+    }
+
+    // Too large and not power of 2 so do multi-upload Bluestein fft
+    else if (n > MAX_STOCKHAM_FFT_SIZE) {
+      type_ = MULTIUPLOAD_BLUESTEIN;
+      bluestein_n_ = next_fast_n(2 * n - 1);
+    }
+
+    // Stockham fft
+    else if (auto steps = stockham_decompose(n); steps.size() > 0) {
+      type_ = STOCKHAM;
+      steps_ = steps;
+    }
+
+    // Add Rader but for now simply fall back to Bluestein when Stockham not
+    // possible
+    else if (n > MAX_BLUESTEIN_FFT_SIZE) {
+      type_ = MULTIUPLOAD_BLUESTEIN;
+      bluestein_n_ = next_fast_n(2 * n - 1);
+    } else {
+      type_ = BLUESTEIN;
+      bluestein_n_ = next_fast_n(2 * n - 1);
+      steps_ = stockham_decompose(bluestein_n_);
+    }
+  }
+
+  FFTType type() const {
+    return type_;
+  }
+
+  int size() const {
+    return n_;
+  }
+
+  const std::vector<int>& steps() const {
+    return steps_;
+  }
+
+  int first_size() const {
+    return n1_;
+  }
+
+  const std::vector<int>& first_steps() const {
+    return steps1_;
+  }
+
+  int second_size() const {
+    return n2_;
+  }
+
+  const std::vector<int>& second_steps() const {
+    return steps2_;
+  }
+
+  int bluestein_size() const {
+    return bluestein_n_;
+  }
+
+ private:
+  int n_;
+  FFTType type_;
+  std::vector<int> steps_;
+  int n1_;
+  std::vector<int> steps1_;
+  int n2_;
+  std::vector<int> steps2_;
+  int bluestein_n_;
+};

 std::vector<int> plan_stockham_fft(int n) {
   auto radices = supported_radices();
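The new `FFTPlan` both classifies the size and precomputes the decompositions the kernels need. As a rough map of the decision chain, a hypothetical standalone mirror (the two `MAX_*` thresholds below are assumed placeholder values for illustration; the real constants are defined elsewhere in this file):

```cpp
#include <cstdio>

// Assumed placeholders, not values confirmed by this diff.
constexpr int MAX_STOCKHAM_FFT_SIZE = 4096;
constexpr int MAX_BLUESTEIN_FFT_SIZE = 2048;

bool is_pow2(int n) { return (n & (n - 1)) == 0; }

// Decomposability over the supported radix set (tuning cases omitted).
bool stockham_decomposable(int n) {
  for (int r : {13, 11, 8, 7, 6, 5, 4, 3, 2})
    while (n % r == 0) n /= r;
  return n == 1;
}

const char* classify(int n) {
  if (n == 1) return "NOOP";
  if (n > MAX_STOCKHAM_FFT_SIZE && is_pow2(n))
    return n <= (1 << 20) ? "SMALL_FOUR_STEP" : "LARGE_FOUR_STEP";
  if (n > MAX_STOCKHAM_FFT_SIZE) return "MULTIUPLOAD_BLUESTEIN";
  if (stockham_decomposable(n)) return "STOCKHAM";
  return n > MAX_BLUESTEIN_FFT_SIZE ? "MULTIUPLOAD_BLUESTEIN" : "BLUESTEIN";
}

int main() {
  for (int n : {1, 120, 127, 8192, 1 << 21, 5000}) {
    std::printf("%8d -> %s\n", n, classify(n));
  }
}
```

Note that `RADER` and `UNSUPPORTED` are declared but never assigned by the constructor yet; the Rader path is explicitly deferred to Bluestein for now.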
@@ -110,15 +237,12 @@ std::vector<int> plan_stockham_fft(int n) {
   throw std::runtime_error("Unplannable");
 }

-FFTPlan plan_fft(int n) {
+OldFFTPlan plan_fft(int n) {
   auto radices = supported_radices();
-  std::set<int> radices_set(radices.begin(), radices.end());

-  FFTPlan plan;
+  OldFFTPlan plan;
   plan.n = n;
   plan.rader = std::vector<int>(radices.size(), 0);
-  auto factors = prime_factors(n);
-  int remaining_n = n;

   // Four Step FFT when N is too large for shared mem.
   if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
@@ -128,16 +252,20 @@ FFTPlan plan_fft(int n) {
     plan.n2 = n > 65536 ? 1024 : 64;
     plan.n1 = n / plan.n2;
     return plan;
-  } else if (n > MAX_STOCKHAM_FFT_SIZE) {
+  }
+
+  if (n > MAX_STOCKHAM_FFT_SIZE) {
     // Otherwise we use a multi-upload Bluestein's
     plan.four_step = true;
     plan.bluestein_n = next_fast_n(2 * n - 1);
     return plan;
   }

+  int remaining_n = n;
+  auto factors = prime_factors(n);
   for (int factor : factors) {
     // Make sure the factor is a supported radix
-    if (radices_set.find(factor) == radices_set.end()) {
+    if (std::find(radices.begin(), radices.end(), factor) == radices.end()) {
       // We only support a single Rader factor currently
       // TODO(alexbarron) investigate weirdness with large
       // Rader sizes -- possibly a compiler issue?
@@ -154,7 +282,7 @@ FFTPlan plan_fft(int n) {
       for (int rf : rader_factors) {
         // We don't nest Rader's algorithm so if `factor - 1`
         // isn't Stockham decomposable we give up and do Bluestein's.
-        if (radices_set.find(rf) == radices_set.end()) {
+        if (std::find(radices.begin(), radices.end(), rf) == radices.end()) {
           plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
           plan.bluestein_n = next_fast_n(2 * n - 1);
           plan.stockham = plan_stockham_fft(plan.bluestein_n);
@@ -172,7 +300,7 @@ FFTPlan plan_fft(int n) {
   return plan;
 }

-int compute_elems_per_thread(FFTPlan plan) {
+int compute_elems_per_thread(OldFFTPlan plan) {
   // Heuristics for selecting an efficient number
   // of threads to use for a particular mixed-radix FFT.
   auto n = plan.n;
@@ -355,9 +483,11 @@ void multi_upload_bluestein_fft(
     size_t axis,
     bool inverse,
     bool real,
-    FFTPlan& plan,
+    OldFFTPlan& plan,
     std::vector<array>& copies,
     const Stream& s) {
+  auto& d = metal::device(s.device);
+
   // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
   // algorithm
   int n = inverse ? out.shape(axis) : in.shape(axis);
@@ -420,6 +550,7 @@ void multi_upload_bluestein_fft(
       /*real=*/false,
       FourStepParams(),
       /*inplace=*/false,
+      d,
       s);
   copies.push_back(pad_temp1);
@@ -435,6 +566,7 @@ void multi_upload_bluestein_fft(
       /* real= */ false,
       FourStepParams(),
       /*inplace=*/true,
+      d,
       s);

   int offset = plan.bluestein_n - (2 * n - 1);
@@ -480,7 +612,7 @@ void four_step_fft(
     size_t axis,
     bool inverse,
     bool real,
-    FFTPlan& plan,
+    OldFFTPlan& plan,
     std::vector<array>& copies,
     const Stream& s,
     bool in_place) {
@@ -493,7 +625,15 @@ void four_step_fft(
     auto temp_shape = (real && inverse) ? out.shape() : in.shape();
     array temp(temp_shape, complex64, nullptr, {});
     fft_op(
-        in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
+        in,
+        temp,
+        axis,
+        inverse,
+        real,
+        four_step_params,
+        /*inplace=*/false,
+        d,
+        s);
     four_step_params.first_step = false;
     fft_op(
         temp,
@@ -503,6 +643,7 @@ void four_step_fft(
         real,
         four_step_params,
         /*inplace=*/in_place,
+        d,
         s);
     copies.push_back(temp);
   } else {
@@ -518,9 +659,8 @@ void fft_op(
     bool real,
     const FourStepParams four_step_params,
     bool inplace,
+    metal::Device& d,
     const Stream& s) {
-  auto& d = metal::device(s.device);
-
   size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
   if (n == 1) {
     out.copy_shared_buffer(in);
@@ -755,57 +895,517 @@ void fft_op(
   d.add_temporaries(std::move(copies), s.index);
 }

-void fft_op(
+inline int compute_elems_per_thread(int n, const std::vector<int>& steps) {
+  auto radices = supported_radices();
+  std::set<int> used_radices;
+  for (int i = 0; i < steps.size(); i++) {
+    if (steps[i] > 0) {
+      used_radices.insert(radices[i % radices.size()]);
+    }
+  }
+
+  // Manual tuning for 7/11/13
+  if (used_radices.find(7) != used_radices.end() &&
+      (used_radices.find(11) != used_radices.end() ||
+       used_radices.find(13) != used_radices.end())) {
+    return 7;
+  } else if (
+      used_radices.find(11) != used_radices.end() &&
+      used_radices.find(13) != used_radices.end()) {
+    return 11;
+  }
+
+  // TODO(alexbarron) Some really weird stuff is going on
+  // for certain `elems_per_thread` on large composite n.
+  // Possibly a compiler issue?
+  if (n == 3159)
+    return 13;
+  if (n == 3645)
+    return 5;
+  if (n == 3969)
+    return 7;
+  if (n == 1982)
+    return 5;
+
+  if (used_radices.size() == 1) {
+    return *(used_radices.begin());
+  }
+  if (used_radices.size() == 2 &&
+      (used_radices.find(11) != used_radices.end() ||
+       used_radices.find(13) != used_radices.end())) {
+    return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
+  }
+
+  // In all other cases use the second smallest radix.
+  return *(++used_radices.begin());
+}
+
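The fallback rule at the end picks the second smallest radix actually used, trading per-thread work against thread count. A tiny check of that rule (a hypothetical standalone mirror of the tail only; the 7/11/13 special cases and the hardcoded sizes are omitted):

```cpp
#include <cstdio>
#include <set>

// Hypothetical mirror of the fallback branches, for illustration.
int fallback_elems_per_thread(const std::set<int>& used_radices) {
  if (used_radices.size() == 1) {
    return *used_radices.begin();
  }
  return *(++used_radices.begin()); // second smallest used radix
}

int main() {
  // 1024 = 8^3 * 2 -> used radices {2, 8} -> 8 elements per thread.
  std::printf("%d\n", fallback_elems_per_thread({2, 8}));
  // 120 = 8 * 5 * 3 -> used radices {3, 5, 8} -> 5 elements per thread.
  std::printf("%d\n", fallback_elems_per_thread({3, 5, 8}));
}
```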
+inline array ensure_fastest_moving_axis(
+    const array& x,
+    int axis,
+    metal::Device& d,
+    const Stream& s) {
+  // The axis is already with a stride of 1 so check that we have no overlaps
+  // and broadcasting and avoid the copy.
+  if (x.strides(axis) == 1) {
+    // This is a fairly strict test perhaps consider relaxing it in the future.
+    if (x.flags().row_contiguous || x.flags().col_contiguous) {
+      return x;
+    }
+  }
+
+  // To make it the fastest moving axis simply transpose it, then copy it and
+  // then transpose it back.
+
+  // Transpose it
+  std::vector<int> axes(x.ndim(), 0);
+  for (int ax = 0; ax < axes.size(); ax++) {
+    axes[ax] = (ax < axis) ? ax : ax + 1;
+  }
+  axes.back() = axis;
+  Shape xtshape;
+  xtshape.reserve(axes.size());
+  for (auto ax : axes) {
+    xtshape.push_back(x.shape(ax));
+  }
+  array xt(xtshape, x.dtype(), nullptr, {});
+  transpose(x, xt, axes);
+
+  // Copy it
+  array xtc(xt.shape(), x.dtype(), nullptr, {});
+  copy_gpu(
+      xt,
+      xtc,
+      xt.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      s);
+  d.add_temporary(xtc, s.index);
+
+  // Transpose it back
+  for (int ax = 0; ax < axes.size(); ax++) {
+    axes[ax] = (ax < axis) ? ax : ((ax == axis) ? axes.size() - 1 : ax - 1);
+  }
+  array y(x.shape(), x.dtype(), nullptr, {});
+  transpose(xtc, y, axes);
+
+  return y;
+}
+
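The two index loops build a permutation that moves `axis` to the last position and its exact inverse, so the transpose-copy-transpose round trip returns an array with the original logical shape but stride 1 along `axis`. A quick standalone check of the arithmetic for `ndim = 4`, `axis = 1`:

```cpp
#include <cstdio>
#include <vector>

int main() {
  int ndim = 4, axis = 1;
  // Forward permutation: bring `axis` to the fastest-moving (last) slot.
  std::vector<int> fwd(ndim);
  for (int ax = 0; ax < ndim; ax++) {
    fwd[ax] = (ax < axis) ? ax : ax + 1;
  }
  fwd.back() = axis; // fwd = {0, 2, 3, 1}

  // Inverse permutation: restore the original axis order.
  std::vector<int> inv(ndim);
  for (int ax = 0; ax < ndim; ax++) {
    inv[ax] = (ax < axis) ? ax : ((ax == axis) ? ndim - 1 : ax - 1);
  }
  // inv = {0, 3, 1, 2}; composing the two permutations is the identity.
  for (int ax = 0; ax < ndim; ax++) {
    std::printf("%d -> %d\n", ax, fwd[inv[ax]]); // prints i -> i
  }
}
```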
+inline void prepare_output_array(const array& in, array& out, int axis) {
+  // Prepare the output array such that it matches the input in terms of
+  // stride ordering. Namely we might have moved `axis` around in the `in`
+  // array. We must do the same in `out`. The difference is that we don't have
+  // to copy anything because `out` contains garbage at the moment.
+
+  if (in.flags().row_contiguous && out.flags().row_contiguous) {
+    return;
+  }
+
+  std::vector<int> axes(out.ndim(), 0);
+  for (int ax = 0; ax < axes.size(); ax++) {
+    axes[ax] = (ax < axis) ? ax : ax + 1;
+  }
+  axes.back() = axis;
+  as_transposed(out, axes);
+}
+
+void fft_stockham_inplace(
+    const FFTPlan& plan,
+    const array& in_,
+    array& out,
+    size_t axis,
+    bool inverse,
+    bool real,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the input and output arrays such that `axis` has stride 1.
+  // Possibly copy the input but never the output as it doesn't have anything
+  // useful in it yet.
+  array in = ensure_fastest_moving_axis(in_, axis, d, s);
+  prepare_output_array(in, out, axis);
+
+  // Prepare the arguments for stockham fft
+  int n = plan.size();
+  bool power_of_2 = is_power_of_2(n);
+  int total_batch_size =
+      out.dtype() == float32 ? out.size() / n : in.size() / n;
+  auto& steps = plan.steps();
+  int elems_per_thread = compute_elems_per_thread(n, steps);
+  int threads_per_fft = ceildiv(n, elems_per_thread);
+  int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
+  int tg_mem_size = next_power_of_2(tg_batch_size * n);
+  int batch_size = ceildiv(total_batch_size, tg_batch_size);
+  batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
+  std::vector<MTLFC> func_consts = {
+      {&inverse, MTL::DataType::DataTypeBool, 0},
+      {&power_of_2, MTL::DataType::DataTypeBool, 1},
+      {&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
+  for (int i = 0; i < steps.size(); i++) {
+    func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
+  }
+
+  // Get the kernel
+  auto in_type = in.dtype() == float32 ? "float" : "float2";
+  auto out_type = out.dtype() == float32 ? "float" : "float2";
+  std::string hash_name;
+  std::string kname;
+  kname.reserve(64);
+  hash_name.reserve(64);
+  concatenate(kname, "fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
+  concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
+  auto template_def =
+      get_template_definition(kname, "fft", tg_mem_size, in_type, out_type);
+  auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
+
+  // Launch it
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
+  compute_encoder.set_bytes(n, 2);
+  compute_encoder.set_bytes(total_batch_size, 3);
+
+  MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
+  MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
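A worked example of the launch arithmetic, assuming `MIN_THREADGROUP_MEM_SIZE` is 1024 (a placeholder value; the real constant is defined elsewhere in this file) and `n = 120`, for which the heuristic above yields 5 elements per thread:

```cpp
#include <algorithm>
#include <cstdio>

// Assumed placeholder value, used here only to make the sketch runnable.
constexpr int MIN_THREADGROUP_MEM_SIZE = 1024;

int ceildiv(int a, int b) { return (a + b - 1) / b; }
int next_pow2(int n) { int p = 1; while (p < n) p <<= 1; return p; }

int main() {
  int n = 120;                 // FFT size along the transformed axis
  int total_batch_size = 1000; // how many 1D FFTs the batch contains
  int elems_per_thread = 5;    // heuristic result for 120 = 8 * 5 * 3
  int threads_per_fft = ceildiv(n, elems_per_thread);            // 24
  int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1); // 8
  int tg_mem_size = next_pow2(tg_batch_size * n);                // 1024
  int batch_size = ceildiv(total_batch_size, tg_batch_size);     // 125
  std::printf(
      "threads/fft=%d tg_batch=%d tg_mem=%d grid_x=%d\n",
      threads_per_fft, tg_batch_size, tg_mem_size, batch_size);
}
```

Each threadgroup processes `tg_batch_size` FFTs out of shared memory sized `tg_mem_size`, and the grid's x dimension covers the batch; for real transforms the batch is halved because the kernel packs two RFFTs into one complex transform.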
+void fft_four_step_inplace(
+    const FFTPlan& plan,
+    const array& in_,
+    array& out,
+    size_t axis,
+    bool inverse,
+    bool real,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the input and output arrays such that `axis` has stride 1.
+  // Possibly copy the input but never the output as it doesn't have anything
+  // useful in it yet.
+  array in = ensure_fastest_moving_axis(in_, axis, d, s);
+  prepare_output_array(in, out, axis);
+
+  // Also prepare the intermediate array for the four-step fft which is
+  // implemented with 2 kernel calls.
+  array intermediate(
+      (real && inverse) ? out.shape() : in.shape(), complex64, nullptr, {});
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+  prepare_output_array(in, intermediate, axis);
+  d.add_temporary(intermediate, s.index);
+
+  // Make the two calls
+  for (int step = 0; step < 2; step++) {
+    // Create the parameters
+    int n1 = plan.first_size();
+    int n2 = plan.second_size();
+    int n = (step == 0) ? n1 : n2;
+    bool power_of_2 = true;
+    int total_batch_size =
+        out.dtype() == float32 ? out.size() / n : in.size() / n;
+    auto& steps = (step == 0) ? plan.first_steps() : plan.second_steps();
+    int elems_per_thread = compute_elems_per_thread(n, steps);
+    int threads_per_fft = ceildiv(n, elems_per_thread);
+    int tg_batch_size =
+        std::max(MIN_THREADGROUP_MEM_SIZE / n, MIN_COALESCE_WIDTH);
+    int tg_mem_size = next_power_of_2(tg_batch_size * n);
+    int batch_size = ceildiv(total_batch_size, tg_batch_size);
+    std::vector<MTLFC> func_consts = {
+        {&inverse, MTL::DataType::DataTypeBool, 0},
+        {&power_of_2, MTL::DataType::DataTypeBool, 1},
+        {&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
+    for (int i = 0; i < steps.size(); i++) {
+      func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
+    }
+
+    // Get the kernel
+    auto to_type = [](const array& x) {
+      return x.dtype() == float32 ? "float" : "float2";
+    };
+    auto in_type = step == 0 ? to_type(in) : to_type(intermediate);
+    auto out_type = step == 0 ? to_type(intermediate) : to_type(out);
+    std::string hash_name;
+    std::string kname;
+    kname.reserve(64);
+    hash_name.reserve(64);
+    concatenate(
+        kname,
+        "four_step_mem_",
+        tg_mem_size,
+        "_",
+        in_type,
+        "_",
+        out_type,
+        "_",
+        step,
+        (real ? "_true" : "_false"));
+    concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
+    auto template_def = get_template_definition(
+        kname, "four_step_fft", tg_mem_size, in_type, out_type, step, real);
+    auto kernel =
+        get_fft_kernel(d, kname, hash_name, func_consts, template_def);
+
+    // Launch it
+    auto& compute_encoder = d.get_command_encoder(s.index);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array((step == 0) ? in : intermediate, 0);
+    compute_encoder.set_output_array((step == 0) ? intermediate : out, 1);
+    compute_encoder.set_bytes(n1, 2);
+    compute_encoder.set_bytes(n2, 3);
+    compute_encoder.set_bytes(total_batch_size, 4);
+
+    MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
+    MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+  }
+}
+
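The two dispatches implement the textbook four-step factorization of an `n = n1 * n2` DFT: length-`n1` sub-FFTs over one index, a twiddle multiplication by `exp(-2*pi*i*j2*k1/n)`, then length-`n2` sub-FFTs over the other index with a transposed output ordering. A naive CPU reference of the same factorization, checked against a direct DFT (standalone, nothing assumed from this file):

```cpp
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

using cplx = std::complex<double>;
const double pi = std::acos(-1.0);

// Direct O(n^2) DFT, used both as the sub-transform and as the reference.
std::vector<cplx> dft(const std::vector<cplx>& x) {
  int n = x.size();
  std::vector<cplx> X(n);
  for (int k = 0; k < n; k++)
    for (int j = 0; j < n; j++)
      X[k] += x[j] * std::polar(1.0, -2 * pi * j * k / n);
  return X;
}

int main() {
  int n1 = 4, n2 = 8, n = n1 * n2;
  std::vector<cplx> x(n);
  for (int j = 0; j < n; j++) x[j] = {std::cos(0.3 * j), std::sin(0.1 * j)};

  // Steps 1 and 2: length-n1 DFTs down the columns, then twiddles.
  std::vector<cplx> a(n);
  for (int j2 = 0; j2 < n2; j2++) {
    std::vector<cplx> col(n1);
    for (int j1 = 0; j1 < n1; j1++) col[j1] = x[j1 * n2 + j2];
    auto C = dft(col);
    for (int k1 = 0; k1 < n1; k1++)
      a[k1 * n2 + j2] = C[k1] * std::polar(1.0, -2 * pi * j2 * k1 / n);
  }

  // Steps 3 and 4: length-n2 DFTs along the rows, transposed output index.
  std::vector<cplx> X(n);
  for (int k1 = 0; k1 < n1; k1++) {
    std::vector<cplx> row(a.begin() + k1 * n2, a.begin() + (k1 + 1) * n2);
    auto R = dft(row);
    for (int k2 = 0; k2 < n2; k2++) X[k2 * n1 + k1] = R[k2];
  }

  // Compare against the direct DFT of x.
  auto ref = dft(x);
  double err = 0;
  for (int k = 0; k < n; k++) err = std::max(err, std::abs(X[k] - ref[k]));
  std::printf("max error = %.2e\n", err); // ~1e-13
}
```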
+void fft_bluestein(
+    const FFTPlan& plan,
+    const array& in_,
+    array& out,
+    size_t axis,
+    bool inverse,
+    bool real,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the input and output arrays such that `axis` has stride 1.
+  // Possibly copy the input but never the output as it doesn't have anything
+  // useful in it yet.
+  array in = ensure_fastest_moving_axis(in_, axis, d, s);
+  prepare_output_array(in, out, axis);
+
+  // Prepare the arguments for bluestein fft
+  int n = plan.bluestein_size();
+  bool power_of_2 = true;
+  int total_batch_size = out.dtype() == float32 ? out.size() / plan.size()
+                                                : in.size() / plan.size();
+  auto& steps = plan.steps();
+  int elems_per_thread = compute_elems_per_thread(n, steps);
+  int threads_per_fft = ceildiv(n, elems_per_thread);
+  int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
+  int tg_mem_size = next_power_of_2(tg_batch_size * n);
+  int batch_size = ceildiv(total_batch_size, tg_batch_size);
+  batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
+  std::vector<MTLFC> func_consts = {
+      {&inverse, MTL::DataType::DataTypeBool, 0},
+      {&power_of_2, MTL::DataType::DataTypeBool, 1},
+      {&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
+  for (int i = 0; i < steps.size(); i++) {
+    func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
+  }
+
+  // Get the kernel
+  auto in_type = in.dtype() == float32 ? "float" : "float2";
+  auto out_type = out.dtype() == float32 ? "float" : "float2";
+  std::string hash_name;
+  std::string kname;
+  kname.reserve(64);
+  hash_name.reserve(64);
+  concatenate(
+      kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
+  concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
+  auto template_def = get_template_definition(
+      kname, "bluestein_fft", tg_mem_size, in_type, out_type);
+  auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
+
+  // Get the bluestein constants
+  auto [w_k, w_q] =
+      compute_bluestein_constants(plan.size(), plan.bluestein_size());
+  d.add_temporary(w_k, s.index);
+  d.add_temporary(w_q, s.index);
+
+  // Launch it
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
+  compute_encoder.set_input_array(w_q, 2);
+  compute_encoder.set_input_array(w_k, 3);
+  compute_encoder.set_bytes(plan.size(), 4);
+  compute_encoder.set_bytes(n, 5);
+  compute_encoder.set_bytes(total_batch_size, 6);
+
+  MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
+  MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
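For reference, Bluestein's identity `j*k = (j^2 + k^2 - (j-k)^2) / 2` rewrites a length-`n` DFT as a chirp-modulated linear convolution, which is why a decomposable FFT of size `next_fast_n(2n - 1)` suffices for arbitrary `n` and why the chirp-derived constants `w_k` and `w_q` are precomputed. A direct (non-FFT) numerical check of the identity, independent of this file's kernels and of the exact convention `compute_bluestein_constants` uses:

```cpp
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

using cplx = std::complex<double>;

int main() {
  const int n = 5; // any length works, including primes
  const double pi = std::acos(-1.0);
  std::vector<cplx> x(n);
  for (int j = 0; j < n; j++) x[j] = {std::cos(1.7 * j), std::sin(0.4 * j)};

  // Chirp w(m) = exp(-i * pi * m^2 / n).
  auto w = [&](int m) { return std::polar(1.0, -pi * double(m) * m / n); };

  double err = 0;
  for (int k = 0; k < n; k++) {
    cplx ref = 0; // direct DFT bin
    for (int j = 0; j < n; j++)
      ref += x[j] * std::polar(1.0, -2 * pi * j * k / n);

    // Bluestein: X[k] = w(k) * sum_j (x[j] * w(j)) * conj(w(j - k)),
    // i.e. a linear convolution of x[j] * w(j) with the conjugate chirp.
    cplx acc = 0;
    for (int j = 0; j < n; j++) acc += x[j] * w(j) * std::conj(w(j - k));
    err = std::max(err, std::abs(w(k) * acc - ref));
  }
  std::printf("max error = %.2e\n", err); // ~1e-15
}
```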
+void fft_multi_upload_bluestein(
+    const FFTPlan& plan,
+    const array& in_,
+    array& out,
+    size_t axis,
+    bool inverse,
+    bool real,
+    metal::Device& d,
+    const Stream& s) {
+  // Get Bluestein's constants using the CPU (this is done in the submission
+  // thread which is pretty bad).
+  auto [w_k, w_q] =
+      compute_bluestein_constants(plan.size(), plan.bluestein_size());
+  d.add_temporary(w_k, s.index);
+  d.add_temporary(w_q, s.index);
+
+  // Prepare the input
+  auto in_shape = inverse ? out.shape() : in_.shape();
+  array in(std::move(in_shape), complex64, nullptr, {});
+  if (real && !inverse) {
+    copy_gpu(
+        in_,
+        in,
+        in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+        s);
+    d.add_temporary(in, s.index);
+  } else if (real && inverse) {
+    int back_offset = plan.size() % 2 == 0 ? 2 : 1;
+    auto slice_shape = in.shape();
+    slice_shape[axis] -= back_offset;
+    array slice_temp(slice_shape, complex64, nullptr, {});
+    array conj_temp(in.shape(), complex64, nullptr, {});
+    Shape rstarts(in.ndim(), 0);
+    Shape rstrides(in.ndim(), 1);
+    rstarts[axis] = in.shape(axis) - back_offset;
+    rstrides[axis] = -1;
+    unary_op_gpu({in_}, conj_temp, "Conjugate", s);
+    slice_gpu(in_, slice_temp, rstarts, rstrides, s);
+    concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s);
+    d.add_temporary(conj_temp, s.index);
+  } else if (inverse) {
+    unary_op_gpu({in_}, in, "Conjugate", s);
+    d.add_temporary(in, s.index);
+  } else {
+    in.copy_shared_buffer(in_);
+  }
+
+  // Multiply with the chirp w_k
+  Strides b_strides(in.ndim(), 0);
+  b_strides[axis] = 1;
+  array w_k_broadcast(in.shape(), complex64, nullptr, {});
+  w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
+  array x(in.shape(), complex64, nullptr, {});
+  binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
+  d.add_temporary(x, s.index);
+
+  // Pad
+  auto padded_shape = out.shape();
+  padded_shape[axis] = plan.bluestein_size();
+  array padded_x(padded_shape, complex64, nullptr, {});
+  auto zero = array(complex64_t{0.0f, 0.0f});
+  pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s);
+  d.add_temporary(zero, s.index);
+  d.add_temporary(padded_x, s.index);
+
+  // First fft
+}
+
+void fft_op_inplace(
     const array& in,
     array& out,
     size_t axis,
     bool inverse,
     bool real,
-    bool inplace,
+    metal::Device& d,
     const Stream& s) {
-  fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
+  // Get the FFT size and plan it
+  auto plan =
+      FFTPlan(out.dtype() == float32 ? out.shape(axis) : in.shape(axis));
+
+  switch (plan.type()) {
+    case FFTPlan::NOOP:
+      std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
+      break;
+    case FFTPlan::STOCKHAM:
+      return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
+    case FFTPlan::SMALL_FOUR_STEP:
+      return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
+    case FFTPlan::BLUESTEIN:
+      return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
+    case FFTPlan::UNSUPPORTED: {
+      std::string msg;
+      concatenate(msg, "FFT of size ", plan.size(), " not supported");
+      throw std::runtime_error(msg);
+    }
+    default:
+      std::cout << "----- NYI ----" << std::endl;
+      break;
+  }
 }

-void nd_fft_op(
+void nd_fft_op_inplace(
     const array& in,
     array& out,
     const std::vector<size_t>& axes,
     bool inverse,
     bool real,
+    metal::Device& d,
     const Stream& s) {
-  // Perform ND FFT on GPU as a series of 1D FFTs
-  auto temp_shape = inverse ? in.shape() : out.shape();
-  array temp1(temp_shape, complex64, nullptr, {});
-  array temp2(temp_shape, complex64, nullptr, {});
-  std::vector<array> temp_arrs = {temp1, temp2};
-  for (int i = axes.size() - 1; i >= 0; i--) {
-    int reverse_index = axes.size() - i - 1;
-    // For 5D and above, we don't want to reallocate our two temporary arrays
-    bool inplace = reverse_index >= 3 && i != 0;
-    // Opposite order for fft vs ifft
-    int index = inverse ? reverse_index : i;
-    size_t axis = axes[index];
-    // Mirror np.fft.(i)rfftn and perform a real transform
-    // only on the final axis.
-    bool step_real = (real && index == axes.size() - 1);
-    auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
-    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
-    array& out_arr = i == 0 ? out : temp_arrs[i % 2];
-    fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
-  }
+  // We are going to make and possibly reuse some intermediate arrays that will
+  // hold the intermediate fft results.
+  auto shape = inverse ? in.shape() : out.shape();
+  std::vector<array> intermediates;
+  intermediates.reserve(2);

-  auto& d = metal::device(s.device);
-  d.add_temporaries(std::move(temp_arrs), s.index);
+  // Utility to return either in or one of the intermediates.
+  auto get_input_array = [&](int step) -> const array& {
+    // The first step so use the input array
+    if (step == 0) {
+      return in;
+    }
+
+    return intermediates[(step - 1) % 2];
+  };
+
+  // Utility to return either out or one of the intermediates. It also informs
+  // us if we should allocate memory for that output or there is already some
+  // allocated.
+  auto get_output_array = [&](int step) -> array& {
+    // It is the final step so return the output array
+    if (step == axes.size() - 1) {
+      return out;
+    }
+
+    // We already have made an array that we can use so go ahead and use it and
+    // don't reallocate the memory.
+    if (step % 2 < intermediates.size()) {
+      return intermediates[step % 2];
+    }
+
+    array x(shape, complex64, nullptr, {});
+    x.set_data(allocator::malloc(x.nbytes()));
+    intermediates.emplace_back(std::move(x));
+    d.add_temporary(intermediates.back(), s.index);
+
+    return intermediates.back();
+  };
+
+  // Perform ND FFT on GPU as a series of 1D FFTs
+  for (int step = 0; step < axes.size(); step++) {
+    auto x = get_input_array(step);
+    auto y = get_output_array(step);
+    auto step_axis = axes[inverse ? step : axes.size() - step - 1];
+    auto step_real = real && (inverse ? step == axes.size() - 1 : step == 0);
+    fft_op_inplace(x, y, step_axis, inverse, step_real, d, s);
+  }
 }
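The two lambdas give a simple ping-pong schedule: at most two intermediates are ever allocated (hence `reserve(2)`), step `k` reads what step `k - 1` wrote, and the final step writes straight to `out`. Tracing the buffer indices for a 3-axis transform (illustrative only):

```cpp
#include <cstdio>

int main() {
  int num_axes = 3;
  for (int step = 0; step < num_axes; step++) {
    // Mirrors get_input_array: in for step 0, else intermediates[(step-1)%2].
    const char* src = step == 0 ? "in" : (step - 1) % 2 ? "tmp1" : "tmp0";
    // Mirrors get_output_array: out for the last step, else intermediates[step%2].
    const char* dst =
        step == num_axes - 1 ? "out" : step % 2 ? "tmp1" : "tmp0";
    std::printf("step %d: %s -> %s\n", step, src, dst);
  }
  // step 0: in -> tmp0
  // step 1: tmp0 -> tmp1
  // step 2: tmp1 -> out
}
```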

 void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
+  auto& d = metal::device(s.device);
   auto& in = inputs[0];

+  // The FFT ops above have the *_inplace suffix. This means that the memory
+  // needs to be already allocated in the output array. Similar to
+  // copy_gpu_inplace and so on.
+  //
+  // Even though we allocate the memory, we do not necessarily want the
+  // contiguous strides so the *_inplace ops may change the strides and flags
+  // of the array but will not reallocate the memory.
+
   out.set_data(allocator::malloc(out.nbytes()));

   if (axes_.size() > 1) {
-    nd_fft_op(in, out, axes_, inverse_, real_, s);
+    nd_fft_op_inplace(in, out, axes_, inverse_, real_, d, s);
   } else {
-    fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
+    fft_op_inplace(in, out, axes_[0], inverse_, real_, d, s);
   }
 }