Feature complete Metal FFT (#1102)

* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Alex Barron committed 2024-06-06 12:57:25 -07:00
commit 27d70c7d9d (parent 0e585b4409)
17 changed files with 2601 additions and 367 deletions


@@ -3,6 +3,8 @@
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime

matplotlib.use("Agg")
@@ -16,40 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb


def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    def fft_mlx(x):
        if dim == 1:
            out = mx.fft.fft(x)
        elif dim == 2:
            out = mx.fft.fft2(x)
        mx.eval(out)
        return out

    def fft_mps(x):
        if dim == 1:
            out = torch.fft.fft(x)
        elif dim == 2:
            out = torch.fft.fft2(x)
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n for _ in range(dim)]
        if backend == "mlx":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        elif backend == "mps":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps
        else:
            raise NotImplementedError()
        runtime_ms = measure_runtime(fft, x=x)
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)


def time_fft():
    x = np.array(range(2, 512))
    system_size = int(2**26)

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    print("MPS GPU")
    mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")

    print("CPU")
    system_size = int(2**20)
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    x = np.array(x)

    all_indices = x - x[0]
    radix_2to13 = (
        np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
    )
    bluesteins = (
        np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
    )

    for indices, name in [
        (all_indices, "All"),
        (radix_2to13, "Radix 2-13"),
        (bluesteins, "Bluestein's"),
    ]:
        # plot bandwidths
        print(name)
        plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
        plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
        plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        plt.clf()

    av_gpu_bandwidth = np.mean(gpu_bandwidths)
    av_mps_bandwidth = np.mean(mps_bandwidths)
    av_cpu_bandwidth = np.mean(cpu_bandwidths)
    print("Average bandwidths:")
    print("GPU:", av_gpu_bandwidth)
    print("MPS:", av_mps_bandwidth)
    print("CPU:", av_cpu_bandwidth)

    portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
    print("Percent MLX faster than MPS: ", portion_faster * 100)


if __name__ == "__main__":


@@ -64,6 +64,11 @@ if (MLX_METAL_JIT)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(
  fft
  kernels/fft/radix.h
  kernels/fft/readwrite.h
)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)


@@ -1,106 +1,794 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <complex>
#include <map>
#include <numeric>
#include <set>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h" #include "mlx/mlx.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;

#define MAX_STOCKHAM_FFT_SIZE 4096
#define MAX_RADER_FFT_SIZE 2048
#define MAX_BLUESTEIN_FFT_SIZE 2048
// Threadgroup memory batching improves throughput for small n
#define MIN_THREADGROUP_MEM_SIZE 256
// For strided reads/writes, coalesce at least this many complex64s
#define MIN_COALESCE_WIDTH 4

inline const std::vector<int> supported_radices() {
// Ordered by preference in decomposition.
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
}
std::vector<int> prime_factors(int n) {
int z = 2;
std::vector<int> factors;
while (z * z <= n) {
if (n % z == 0) {
factors.push_back(z);
n /= z;
} else {
z++;
}
}
if (n > 1) {
factors.push_back(n);
}
return factors;
}
struct FourStepParams {
bool required = false;
bool first_step = true;
int n1 = 0;
int n2 = 0;
};
// Forward Declaration
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s);
struct FFTPlan {
int n = 0;
// Number of steps for each radix in the Stockham decomposition
std::vector<int> stockham;
// Number of steps for each radix in the Rader decomposition
std::vector<int> rader;
// Rader factor, 1 if no rader factors
int rader_n = 1;
int bluestein_n = -1;
// Four step FFT
bool four_step = false;
int n1 = 0;
int n2 = 0;
};
int next_fast_n(int n) {
return next_power_of_2(n);
}
std::vector<int> plan_stockham_fft(int n) {
auto radices = supported_radices();
std::vector<int> plan(radices.size(), 0);
int orig_n = n;
if (n == 1) {
return plan;
}
for (int i = 0; i < radices.size(); i++) {
int radix = radices[i];
// Manually tuned radices for powers of 2
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
continue;
}
while (n % radix == 0) {
plan[i] += 1;
n /= radix;
if (n == 1) {
return plan;
}
}
}
throw std::runtime_error("Unplannable");
}
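For intuition, the greedy planning above can be sketched in a few lines of Python. This is an illustration only (it omits the power-of-two tuning in the loop above), and `plan_stockham` is a hypothetical name, not part of this diff:

```python
# Greedy Stockham planning sketch: count steps of each supported radix,
# trying the larger radices first, until n is fully factored.
RADICES = [13, 11, 8, 7, 6, 5, 4, 3, 2]  # same preference order as supported_radices()

def plan_stockham(n):
    plan = {r: 0 for r in RADICES}
    for radix in RADICES:
        while n % radix == 0:
            plan[radix] += 1
            n //= radix
    if n != 1:
        raise ValueError("prime factor > 13; Rader's or Bluestein's is needed")
    return plan

print(plan_stockham(1000))  # 1000 = 8 * 5**3 -> one radix-8 step, three radix-5 steps
```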
FFTPlan plan_fft(int n) {
auto radices = supported_radices();
std::set<int> radices_set(radices.begin(), radices.end());
FFTPlan plan;
plan.n = n;
plan.rader = std::vector<int>(radices.size(), 0);
auto factors = prime_factors(n);
int remaining_n = n;
// Four Step FFT when N is too large for shared mem.
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
// For power's of two we have a fast, no transpose four step implementation.
plan.four_step = true;
// Rough heuristic for choosing faster powers of two when we can
plan.n2 = n > 65536 ? 1024 : 64;
plan.n1 = n / plan.n2;
return plan;
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
// Otherwise we use a multi-upload Bluestein's
plan.four_step = true;
plan.bluestein_n = next_fast_n(2 * n - 1);
return plan;
}

for (int factor : factors) {
// Make sure the factor is a supported radix
if (radices_set.find(factor) == radices_set.end()) {
// We only support a single Rader factor currently
// TODO(alexbarron) investigate weirdness with large
// Rader sizes -- possibly a compiler issue?
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
// See if we can use Rader's algorithm to Stockham decompose n - 1
auto rader_factors = prime_factors(factor - 1);
int last_factor = -1;
for (int rf : rader_factors) {
// We don't nest Rader's algorithm so if `factor - 1`
// isn't Stockham decomposable we give up and do Bluestein's.
if (radices_set.find(rf) == radices_set.end()) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
}
plan.rader = plan_stockham_fft(factor - 1);
plan.rader_n = factor;
remaining_n /= factor;
}
}
plan.stockham = plan_stockham_fft(remaining_n);
return plan;
}
int compute_elems_per_thread(FFTPlan plan) {
// Heuristics for selecting an efficient number
// of threads to use for a particular mixed-radix FFT.
auto n = plan.n;
std::vector<int> steps;
auto radices = supported_radices();
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
std::set<int> used_radices;
for (int i = 0; i < steps.size(); i++) {
int radix = radices[i % radices.size()];
if (steps[i] > 0) {
used_radices.insert(radix);
}
}
// Manual tuning for 7/11/13
if (used_radices.find(7) != used_radices.end() &&
(used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end())) {
return 7;
} else if (
used_radices.find(11) != used_radices.end() &&
used_radices.find(13) != used_radices.end()) {
return 11;
}
// TODO(alexbarron) Some really weird stuff is going on
// for certain `elems_per_thread` on large composite n.
// Possibly a compiler issue?
if (n == 3159)
return 13;
if (n == 3645)
return 5;
if (n == 3969)
return 7;
if (n == 1982)
return 5;
if (used_radices.size() == 1) {
return *(used_radices.begin());
}
if (used_radices.size() == 2) {
if (used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end()) {
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
}
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// In all other cases use the second smallest radix.
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// Rader
int mod_exp(int x, int y, int n) {
int out = 1;
while (y) {
if (y & 1) {
out = out * x % n;
}
y >>= 1;
x = x * x % n;
}
return out;
}
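This is standard square-and-multiply modular exponentiation, i.e. it returns the same value as Python's three-argument `pow`; a quick sanity check (illustrative only):

```python
def mod_exp(x, y, n):
    out = 1
    while y:
        if y & 1:
            out = out * x % n
        y >>= 1
        x = x * x % n
    return out

assert mod_exp(3, 100, 17) == pow(3, 100, 17) == 13
```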
int primitive_root(int n) {
auto factors = prime_factors(n - 1);
for (int r = 2; r < n - 1; r++) {
bool found = true;
for (int factor : factors) {
if (mod_exp(r, (n - 1) / factor, n) == 1) {
found = false;
break;
}
}
if (found) {
return r;
}
}
return -1;
}
std::tuple<array, array, array> compute_raders_constants(
int rader_n,
const Stream& s) {
int proot = primitive_root(rader_n);
// Fermat's little theorem
int inv = mod_exp(proot, rader_n - 2, rader_n);
std::vector<short> g_q(rader_n - 1);
std::vector<short> g_minus_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
g_q[i] = mod_exp(proot, i, rader_n);
g_minus_q[i] = mod_exp(inv, i, rader_n);
}
array g_q_arr(g_q.begin(), {rader_n - 1});
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
std::vector<std::complex<float>> b_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
b_q[i] = std::exp(std::complex<float>(0, pi_i));
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
size_t fft_size = rader_n - 1;
// This FFT is always small (<4096, batch 1) so save some overhead
// and do it on the CPU
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ b_q.data(),
/* data_out= */ b_q_fft_ptr,
/* scale= */ 1.0f);
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
}
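For reference, g_q, g_minus_q, and the FFT of b_q computed above are exactly the ingredients of Rader's prime-length FFT. The NumPy sketch below (hypothetical helper names, not the kernel's fused threadgroup data layout) shows how they combine and checks the result against np.fft.fft:

```python
import numpy as np
from sympy import primefactors  # sympy is already used by the benchmark script

def primitive_root(p):
    # Same search as primitive_root() above.
    for r in range(2, p):
        if all(pow(r, (p - 1) // f, p) != 1 for f in primefactors(p - 1)):
            return r

def rader_fft(x):
    # Rader's FFT for prime length p: a length (p - 1) cyclic convolution.
    p = len(x)
    g = primitive_root(p)
    g_q = [pow(g, q, p) for q in range(p - 1)]                # permutation g^q
    g_minus_q = [pow(g, p - 1 - q, p) for q in range(p - 1)]  # permutation g^-q
    b_q = np.exp(-2j * np.pi * np.array(g_minus_q) / p)
    conv = np.fft.ifft(np.fft.fft(x[g_q]) * np.fft.fft(b_q))  # cyclic convolution
    out = np.empty(p, dtype=complex)
    out[0] = x.sum()
    out[g_minus_q] = x[0] + conv
    return out

x = np.random.rand(17) + 1j * np.random.rand(17)
assert np.allclose(rader_fft(x), np.fft.fft(x))
```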
// Bluestein
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
// We need to calculate the Bluestein twiddle factors
// in double precision for the overall numerical stability
// of Bluestein's FFT algorithm to be acceptable.
//
// Metal doesn't support float64, so instead we
// manually implement the required operations on cpu.
//
// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
int length = 2 * n - 1;
std::vector<std::complex<float>> w_k_vec(n);
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
for (int i = -n + 1; i < n; i++) {
double theta = pow(i, 2) * M_PI / (double)n;
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
if (i >= 0) {
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
}
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
std::ptrdiff_t item_size = w_q.itemsize();
size_t fft_size = bluestein_n;
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ w_q_vec.data(),
/* data_out= */ w_q_ptr,
/* scale= */ 1.0f);
return std::make_tuple(w_k, w_q);
}
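The w_k / w_q pair above plugs into Bluestein's chirp-z identity. A NumPy sketch of the whole transform follows; it is an illustration mirroring the comment above rather than the exact GPU data flow, and `bluestein_fft` is a hypothetical name:

```python
import numpy as np

def bluestein_fft(x):
    # Arbitrary-length DFT as a convolution with a chirp, padded to a power of 2.
    n = len(x)
    bluestein_n = 1 << (2 * n - 2).bit_length()     # next power of 2 >= 2n - 1
    k = np.arange(-n + 1, n)
    w_k = np.exp(-1j * np.pi / n * k**2)            # full-length chirp
    w_q = np.fft.fft(1 / w_k, bluestein_n)          # zero-padded FFT of the inverse chirp
    # w_k[-n:] is the length-n slice uploaded by compute_bluestein_constants.
    conv = np.fft.ifft(np.fft.fft(x * w_k[-n:], bluestein_n) * w_q)
    return w_k[-n:] * conv[n - 1 : 2 * n - 1]

x = np.random.rand(1009) + 1j * np.random.rand(1009)   # 1009 is prime
assert np.allclose(bluestein_fft(x), np.fft.fft(x), atol=1e-6)
```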
void multi_upload_bluestein_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
array temp1(temp_shape, complex64, nullptr, {});
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "conj", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
} else if (inverse) {
unary_op_gpu({in}, temp, "conj", s);
} else {
temp.copy_shared_buffer(in);
}
binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
pad_temp,
pad_temp1,
axis,
/*inverse=*/false,
/*real=*/false,
FourStepParams(),
/*inplace=*/false,
s);
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
fft_op(
pad_temp,
pad_temp1,
axis,
/* inverse= */ true,
/* real= */ false,
FourStepParams(),
/*inplace=*/true,
s);
int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "mul", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "conj", s);
binary_op_gpu({temp, inv_n}, out, "mul", s);
copies.push_back(inv_n);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
// Fast no transpose implementation for powers of 2.
FourStepParams four_step_params = {
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
fft_op(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
}
}
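The two fft_op passes above implement the classic four-step (Bailey) decomposition. A NumPy sketch of the underlying identity, as an illustration only; the Metal kernels avoid the explicit transposes by reading and writing with strides:

```python
import numpy as np

def four_step_fft(x, n1, n2):
    # 1) n2 FFTs of length n1 down the columns
    # 2) twiddle multiplication by exp(-2j*pi*k1*q / (n1*n2))
    # 3) n1 FFTs of length n2 along the rows
    # 4) transposed read-out
    n = n1 * n2
    x2d = np.asarray(x).reshape(n1, n2)
    step1 = np.fft.fft(x2d, axis=0)
    k1 = np.arange(n1)[:, None]
    q = np.arange(n2)[None, :]
    step2 = step1 * np.exp(-2j * np.pi * k1 * q / n)
    step3 = np.fft.fft(step2, axis=1)
    return step3.T.reshape(n)

x = np.random.rand(1 << 16) + 0j
assert np.allclose(four_step_fft(x, 1024, 64), np.fft.fft(x))
```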
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s) {
auto& d = metal::device(s.device);
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
if (n == 1) {
out.copy_shared_buffer(in);
return;
}
if (four_step_params.required) {
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [&axis, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axis] == 1 &&
(x.flags().row_contiguous || x.flags().col_contiguous);
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(a);
}
}
auto flags = x.flags();
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(x.shape(), strides);

flags.col_contiguous = is_row_contiguous;
flags.row_contiguous = is_col_contiguous;
flags.contiguous = data_size == x_copy.size();

x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(in);
// real to complex: n -> (n/2)+1
// complex to real: (n/2)+1 -> n
auto out_strides = in_contiguous.strides();
size_t out_data_size = in_contiguous.data_size();
if (in.shape(axis) != out.shape(axis)) {
for (int i = 0; i < out_strides.size(); i++) {
if (out_strides[i] != 1) {
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
}
}
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
}
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());
}

auto radices = supported_radices();
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;

// Setup function constants
bool power_of_2 = is_power_of_2(fft_size);
auto make_int = [](int* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
};
auto make_bool = [](bool* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
};
std::vector<MTLFC> func_consts = {
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
// Start of radix/rader step constants
int index = 4;
for (int i = 0; i < plan.stockham.size(); i++) {
func_consts.push_back(make_int(&plan.stockham[i], index));
index += 1;
}
for (int i = 0; i < plan.rader.size(); i++) {
func_consts.push_back(make_int(&plan.rader[i], index));
index += 1;
}
int elems_per_thread = compute_elems_per_thread(plan);
func_consts.push_back(make_int(&elems_per_thread, 2));
int rader_m = n / plan.rader_n;
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
int total_batch_size = size / n;
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
// We batch among threadgroups for improved efficiency when n is small
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
if (four_step_params.required) {
// Require a threadgroup batch size of at least 4 for four step FFT
// so we can coalesce the memory accesses.
threadgroup_batch_size =
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
}
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
// FFTs up to 2^20 are currently supported
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
// ceil divide
int batch_size =
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
if (real && !four_step_params.required) {
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
// Only required by four step
int step = -1;
{
std::ostringstream kname;
std::string inv_string = inverse ? "true" : "false";
std::string real_string = real ? "true" : "false";
if (plan.bluestein_n > 0) {
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
<< in_type_str << "_" << out_type_str;
} else if (plan.rader_n > 1) {
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str;
} else if (four_step_params.required) {
step = four_step_params.first_step ? 0 : 1;
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str << "_" << step << "_" << real_string;
} else {
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
<< out_type_str;
}
std::string base_name = kname.str();
// We use a specialized kernel for each FFT size
kname << "_n" << fft_size << "_inv_" << inverse;
std::string hash_name = kname.str();
auto kernel = get_fft_kernel(
d,
base_name,
hash_name,
threadgroup_mem_size,
in_type_str,
out_type_str,
step,
real,
func_consts);
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);

if (plan.bluestein_n > 0) {
// Precomputed twiddle factors for Bluestein's
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
copies.push_back(w_q);
copies.push_back(w_k);
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
copies.push_back(g_q);
copies.push_back(g_minus_q);
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) {
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else {
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
bool inplace,
const Stream& s) {
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
}
void nd_fft_op(
const array& in,
array& out,
const std::vector<size_t>& axes,
bool inverse,
bool real,
const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape();
array temp1(temp_shape, complex64, nullptr, {});
array temp2(temp_shape, complex64, nullptr, {});
std::vector<array> temp_arrs = {temp1, temp2};
for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays
bool inplace = reverse_index >= 3 && i != 0;
// Opposite order for fft vs ifft
int index = inverse ? reverse_index : i;
size_t axis = axes[index];
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
std::vector<array> copies = {temp1, temp2};
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
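nd_fft_op relies on the separability of the multi-dimensional DFT: an N-D transform is just 1-D transforms applied over each requested axis, with the real-to-complex step only on the last axis (mirroring np.fft.rfftn). A small NumPy sketch of that equivalence:

```python
import numpy as np

x = np.random.rand(8, 16, 32)

# Complex N-D FFT == a series of 1-D FFTs, one per axis (order does not matter).
out = x.astype(complex)
for axis in (2, 1, 0):
    out = np.fft.fft(out, axis=axis)
assert np.allclose(out, np.fft.fftn(x))

# rfftn: real transform on the final axis only, complex transforms on the rest.
out_r = np.fft.rfft(x, axis=2)
for axis in (1, 0):
    out_r = np.fft.fft(out_r, axis=axis)
assert np.allclose(out_r, np.fft.rfftn(x))
```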
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
if (axes_.size() > 1) {
nd_fft_op(in, out, axes_, inverse_, real_, s);
} else {
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
}
}
} // namespace mlx::core
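With FFT::eval_gpu dispatching to these paths, the usual mlx.core.fft calls run on the GPU stream for arbitrary sizes; a minimal usage sketch (shapes arbitrary):

```python
import mlx.core as mx

with mx.stream(mx.gpu):
    x = mx.random.uniform(shape=(8, 1023)).astype(mx.complex64)
    y = mx.fft.fft(x)        # non power of 2 length -> Rader's / Bluestein's path
    r = mx.fft.rfft(mx.random.uniform(shape=(8, 4096)))  # real input, power of 2
    mx.eval(y, r)
```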


@@ -0,0 +1,53 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view rader_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
const device float2* raders_b_q [[buffer(2)]],
const device short* raders_g_q [[buffer(3)]],
const device short* raders_g_minus_q [[buffer(4)]],
constant const int& n,
constant const int& batch_size,
constant const int& rader_n,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view bluestein_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
const device float2* w_q [[buffer(2)]],
const device float2* w_k [[buffer(3)]],
constant const int& length,
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view four_step_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
constant const int& n1,
constant const int& n2,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";


@@ -17,6 +17,7 @@ const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* ternary();
const char* scan();
const char* softmax();


@@ -6,6 +6,7 @@
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/fft.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
@@ -489,4 +490,51 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
std::string kernel_string;
if (lib_name.find("bluestein") != std::string::npos) {
kernel_string = bluestein_fft_kernel;
} else if (lib_name.find("rader") != std::string::npos) {
kernel_string = rader_fft_kernel;
} else if (lib_name.find("four_step") != std::string::npos) {
kernel_string = four_step_fft_kernel;
} else {
kernel_string = fft_kernel;
}
kernel_source << metal::fft();
if (lib_name.find("four_step") != std::string::npos) {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type,
"step"_a = step,
"real"_a = real);
} else {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core


@@ -155,4 +155,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
int wm,
int wn);
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts);
} // namespace mlx::core


@@ -48,6 +48,9 @@ set(
binary.h
ternary.h
copy.h
fft.h
fft/radix.h
fft/readwrite.h
softmax.h
sort.h
scan.h


@@ -0,0 +1,486 @@
// Copyright © 2024 Apple Inc.
// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
#define MAX_RADIX 13
// Reached when elems_per_thread_ = 6, max_radix = 13
// and some threads have to do 3 radix 6s requiring 18 float2s.
#define MAX_OUTPUT_SIZE 18
// Specialize for a particular value of N at runtime
STEEL_CONST bool inv_ [[function_constant(0)]];
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
// rader_m = n / rader_n
STEEL_CONST int rader_m_ [[function_constant(3)]];
// Stockham steps
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
// Rader steps
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
// See "radix.h" for radix codelets
typedef void (*RadixFunc)(thread float2*, thread float2*);
// Perform a single radix n butterfly with appropriate twiddles
template <int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly(
int i,
int p,
thread float2* x,
thread short* indices,
thread float2* y) {
// i: the index in the overall DFT that we're processing.
// p: the size of the DFTs we're merging at this step.
// m: how many threads are working on this DFT.
int k, j;
// Use faster bitwise operations when working with powers of two
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
if (radix_p_2 && is_power_of_2_) {
constexpr short power = __builtin_ctz(radix);
k = i & (p - 1);
j = ((i - k) << power) + k;
} else {
k = i % p;
j = (i / p) * radix * p + k;
}
// Apply twiddles
if (p > 1) {
float2 twiddle_1 = get_twiddle(k, radix * p);
float2 twiddle = twiddle_1;
x[1] = complex_mul(x[1], twiddle);
STEEL_PRAGMA_UNROLL
for (int t = 2; t < radix; t++) {
twiddle = complex_mul(twiddle, twiddle_1);
x[t] = complex_mul(x[t], twiddle);
}
}
radix_func(x, y);
STEEL_PRAGMA_UNROLL
for (int t = 0; t < radix; t++) {
indices[t] = j + t * p;
}
}
// Perform all the radix steps required for a
// particular radix size n.
template <int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps(
int i,
thread int* p,
int m,
int n,
int num_steps,
thread float2* inputs,
thread short* indices,
thread float2* values,
threadgroup float2* buf) {
int m_r = n / radix;
// When combining different sized radices, we have to do
// multiple butterflies in a single thread.
// E.g. n = 28 = 4 * 7
// 4 threads, 7 elems_per_thread
// All threads do 1 radix7 butterfly.
// 3 threads do 2 radix4 butterflies.
// 1 thread does 1 radix4 butterfly.
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
int index = 0;
int r_index = 0;
for (int s = 0; s < num_steps; s++) {
for (int t = 0; t < max_radices_per_thread; t++) {
index = i + t * m;
if (index < m_r) {
for (int r = 0; r < radix; r++) {
inputs[r] = buf[index + r * m_r];
}
radix_butterfly<radix, radix_func>(
index, *p, inputs, indices + t * radix, values + t * radix);
}
}
// Wait until all threads have read their inputs into thread local mem
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int t = 0; t < max_radices_per_thread; t++) {
index = i + t * m;
if (index < m_r) {
for (int r = 0; r < radix; r++) {
r_index = t * radix + r;
buf[indices[r_index]] = values[r_index];
}
}
}
// Wait until all threads have written back to threadgroup mem
threadgroup_barrier(mem_flags::mem_threadgroup);
*p *= radix;
}
}
#define RADIX_STEP(radix, radix_func, num_steps) \
radix_n_steps<radix, radix_func>( \
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
template <bool rader = false>
METAL_FUNC void
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
float2 inputs[MAX_RADIX];
short indices[MAX_OUTPUT_SIZE];
float2 values[MAX_OUTPUT_SIZE];
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-n DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load();
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
int fft_idx = elem.z; // Thread index in DFT
int m = grid.z; // Threads per DFT
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
threadgroup float2* buf = &shared_in[tg_idx];
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write();
}
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void rader_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
const device float2* raders_b_q [[buffer(2)]],
const device short* raders_g_q [[buffer(3)]],
const device short* raders_g_minus_q [[buffer(4)]],
constant const int& n,
constant const int& batch_size,
constant const int& rader_n,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Use Rader's algorithm to compute fast FFTs
// when a prime factor `p` of `n` is greater than 13 but
// has `p - 1` Stockham decomposable into prime factors <= 13.
//
// E.g. n = 102
// = 2 * 3 * 17
// . = 2 * 3 * RADER(16)
// . = 2 * 3 * RADER(4 * 4)
//
// In numpy:
// x_perm = x[g_q]
// y = np.fft.fft(x_perm) * b_q
// z = np.fft.ifft(y) + x[0]
// out = z[g_minus_q]
// out[0] = x[1:].sum()
//
// Where the g_q and g_minus_q are permutations formed
// by the group under multiplicative modulo N using the
// primitive root of N and b_q is a constant.
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
//
// Rader's uses fewer operations than Bluestein's and so
// is more accurate. It's also faster in most cases.
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load();
threadgroup_barrier(mem_flags::mem_threadgroup);
// The number of the threads we're using for each DFT
int m = grid.z;
int fft_idx = elem.z;
int tg_idx = elem.y * n;
threadgroup float2* buf = &shared_in[tg_idx];
// rader_m = n / rader_n;
int rader_m = rader_m_;
// We have to load two x_0s for each thread since sometimes
// elems_per_thread_ crosses a boundary.
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
short x_0_index =
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
// Do the Rader permutation in shared memory
float2 temp[MAX_RADIX];
int max_index = n - rader_m - 1;
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short g_q = raders_g_q[index / rader_m];
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
buf[index + rader_m] = temp[e];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Rader FFT on x[rader_m:]
int p = 1;
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
// x_1 + ... + x_n is computed for us in the first FFT step so
// we save it in the first rader_m indices of the array for later.
int x_sum_index = metal::min(fft_idx, rader_m - 1);
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
float2 inv = {1.0f, -1.0f};
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short interleaved_index =
index / rader_m + (index % rader_m) * (rader_n - 1);
temp[e] = complex_mul(
buf[rader_m + interleaved_index],
raders_b_q[interleaved_index % (rader_n - 1)]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
buf[rader_m + index] = temp[e] * inv;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Rader IFFT on x[rader_m:]
p = 1;
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
short diff_index = index / (rader_n - 1) - x_0_index;
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
}
// Use the sum of elements that was computed in the first FFT
float2 x_sum = buf[x_0_index] + x_0[0];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short g_q_index = index % (rader_n - 1);
short g_q = raders_g_minus_q[g_q_index];
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
buf[out_index] = temp[e];
}
buf[x_0_index * rader_n] = x_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
p = rader_n;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write();
}
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void bluestein_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
const device float2* w_q [[buffer(2)]],
const device float2* w_k [[buffer(3)]],
constant const int& length,
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Computes arbitrary length FFTs with Bluestein's algorithm
//
// In numpy:
// bluestein_n = next_power_of_2(2*n - 1)
// out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
//
// Where w_k and w_q are precomputed on CPU in high precision as:
// w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
// w_q = np.fft.fft(1/w_k[-n:])
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load_padded(length, w_k);
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
int fft_idx = elem.z; // Thread index in DFT
int m = grid.z; // Threads per DFT
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
threadgroup float2* buf = &shared_in[tg_idx];
// fft
perform_fft(fft_idx, &p, m, n, buf);
float2 inv = float2(1.0f, -1.0f);
for (int t = 0; t < elems_per_thread_; t++) {
int index = fft_idx + t * m;
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// ifft
p = 1;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_padded(length, w_k);
}
template <
int tg_mem_size,
typename in_T,
typename out_T,
int step,
bool real = false>
[[kernel]] void four_step_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
constant const int& n1,
constant const int& n2,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Fast four step FFT implementation for powers of 2.
int overall_n = n1 * n2;
int n = step == 0 ? n1 : n2;
int stride = step == 0 ? n2 : n1;
// The number of the threads we're using for each DFT
int m = grid.z;
int fft_idx = elem.z;
threadgroup float2 shared_in[tg_mem_size];
threadgroup float2* buf = &shared_in[elem.y * n];
using read_writer_t = ReadWriter<in_T, out_T, step, real>;
read_writer_t read_writer = read_writer_t(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load_strided(stride, overall_n);
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n);
}


@@ -1,199 +1,84 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/fft.h"

#define instantiate_fft(tg_mem_size, in_T, out_T) \
template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
constant const int& n, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);

#define instantiate_rader(tg_mem_size, in_T, out_T) \
template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
rader_fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
const device float2* raders_b_q [[buffer(2)]], \
const device short* raders_g_q [[buffer(3)]], \
const device short* raders_g_minus_q [[buffer(4)]], \
constant const int& n, \
constant const int& batch_size, \
constant const int& rader_n, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);

#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
bluestein_fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
const device float2* w_q [[buffer(2)]], \
const device float2* w_k [[buffer(3)]], \
constant const int& length, \
constant const int& n, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);

#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \
"_" #step "_" #real)]] [[kernel]] void \
four_step_fft<tg_mem_size, in_T, out_T, step, real>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
constant const int& n1, \
constant const int& n2, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2 * p);
float2 twiddle;
twiddle.x = metal::fast::cos(theta);
twiddle.y = metal::fast::sin(theta);
return twiddle;
}
// single threaded radix2 implementation
void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
float2 z = complex_mul(x_1, twiddle);
float2 y_0 = x_0 + z;
float2 y_1 = x_0 - z;
int j = (i << 1) - k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
}
// single threaded radix4 implementation
void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2 * m];
float2 x_3 = read_buf[i + 3 * m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
// e^a * e^b = e^(a + b)
float2 twiddle_2 = complex_mul(twiddle, twiddle);
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
x_1 = complex_mul(x_1, twiddle);
x_2 = complex_mul(x_2, twiddle_2);
x_3 = complex_mul(x_3, twiddle_3);
float2 minus_i;
minus_i.x = 0;
minus_i.y = -1;
// Hard coded twiddle factors for DFT4
float2 z_0 = x_0 + x_2;
float2 z_1 = x_0 - x_2;
float2 z_2 = x_1 + x_3;
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
float2 y_0 = z_0 + z_2;
float2 y_1 = z_1 + z_3;
float2 y_2 = z_0 - z_2;
float2 y_3 = z_1 - z_3;
int j = ((i - k) << 2) + k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2 * p] = y_2;
write_buf[j + 3 * p] = y_3;
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
//
// At each step we use n / 4 threads, each performing
// a single-threaded radix-4 or radix-2 DFT.
//
// We provide the number of radix-2 and radix-4
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2* in [[buffer(0)]],
device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
int i = thread_position_in_grid.y;
// The number of the threads we're using for each DFT
int m = threads_per_grid.y;
// Allocate 2 shared memory buffers for Stockham.
// We alternate reading from one and writing to the other at each radix step.
threadgroup float2 shared_in[n];
threadgroup float2 shared_out[n];
// Pointers to facilitate Stockham buffer swapping
threadgroup float2* read_buf = shared_in;
threadgroup float2* write_buf = shared_out;
threadgroup float2* tmp;
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m * 2, read_buf, write_buf);
radix2(i + m, p, m * 2, read_buf, write_buf);
p *= 2;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
for (size_t r = 0; r < radix_4_steps; r++) {
radix4(i, p, m, read_buf, write_buf);
p *= 4;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
}
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] [[kernel]] void \
fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2.
// clang-format off
#define instantiate_ffts(tg_mem_size) \
instantiate_fft(tg_mem_size, float2, float2) \
instantiate_fft(tg_mem_size, float, float2) \
instantiate_fft(tg_mem_size, float2, float) \
instantiate_rader(tg_mem_size, float2, float2) \
instantiate_rader(tg_mem_size, float, float2) \
instantiate_rader(tg_mem_size, float2, float) \
instantiate_bluestein(tg_mem_size, float2, float2) \
instantiate_bluestein(tg_mem_size, float, float2) \
instantiate_bluestein(tg_mem_size, float2, float) \
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \
instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)

// It's substantially faster to statically define the
// threadgroup memory size rather than using
// `setThreadgroupMemoryLength` on the compute encoder.
// For non-power of 2 sizes we round up the shared memory.
instantiate_ffts(256)
instantiate_ffts(512)
instantiate_ffts(1024)
instantiate_ffts(2048)
// 4096 is the max that will fit into 32KB of threadgroup memory.
instantiate_ffts(4096) // clang-format on

View File

@ -0,0 +1,328 @@
// Copyright © 2024 Apple Inc.
/* Radix kernels
We provide optimized, single-threaded radix codelets
for n = 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13.
For n = 2, 3, 4, 5, 6 we hand-write the codelets.
For n = 8, 10, 12 we combine smaller codelets.
For n = 7, 11, 13 we use Rader's algorithm, which reduces
them to the (n-1) = 6, 10, 12 codelets. */
#pragma once
#include <metal_common>
#include <metal_math>
#include <metal_stdlib>
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
// Complex mul followed by conjugate
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}
// Compute an FFT twiddle factor
METAL_FUNC float2 get_twiddle(int k, int p) {
float theta = -2.0f * k * M_PI_F / p;
float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
return twiddle;
}
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
y[0] = x[0] + x[1];
y[1] = x[0] - x[1];
}
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
float pi_2_3 = -0.8660254037844387;
float2 a_1 = x[1] + x[2];
float2 a_2 = x[1] - x[2];
y[0] = x[0] + a_1;
float2 b_1 = x[0] - 0.5 * a_1;
float2 b_2 = pi_2_3 * a_2;
float2 b_2_j = {-b_2.y, b_2.x};
y[1] = b_1 + b_2_j;
y[2] = b_1 - b_2_j;
}
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
float2 z_0 = x[0] + x[2];
float2 z_1 = x[0] - x[2];
float2 z_2 = x[1] + x[3];
float2 z_3 = x[1] - x[3];
float2 z_3_i = {z_3.y, -z_3.x};
y[0] = z_0 + z_2;
y[1] = z_1 + z_3_i;
y[2] = z_0 - z_2;
y[3] = z_1 - z_3_i;
}
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
float2 root_5_4 = 0.5590169943749475;
float2 sin_2pi_5 = 0.9510565162951535;
float2 sin_1pi_5 = 0.5877852522924731;
float2 a_1 = x[1] + x[4];
float2 a_2 = x[2] + x[3];
float2 a_3 = x[1] - x[4];
float2 a_4 = x[2] - x[3];
float2 a_5 = a_1 + a_2;
float2 a_6 = root_5_4 * (a_1 - a_2);
float2 a_7 = x[0] - a_5 / 4;
float2 a_8 = a_7 + a_6;
float2 a_9 = a_7 - a_6;
float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
float2 a_10_j = {a_10.y, -a_10.x};
float2 a_11_j = {a_11.y, -a_11.x};
y[0] = x[0] + a_5;
y[1] = a_8 + a_10_j;
y[2] = a_9 + a_11_j;
y[3] = a_9 - a_11_j;
y[4] = a_8 - a_10_j;
}
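The hard-coded constants above appear to be sqrt(5)/4 and the sines of 2*pi/5 and pi/5; a quick NumPy check (illustration only):

import numpy as np
# Illustration only: the radix-5 constants match sqrt(5)/4, sin(2*pi/5), sin(pi/5).
assert np.isclose(0.5590169943749475, np.sqrt(5) / 4)
assert np.isclose(0.9510565162951535, np.sin(2 * np.pi / 5))
assert np.isclose(0.5877852522924731, np.sin(np.pi / 5))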
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
float sin_pi_3 = 0.8660254037844387;
float2 a_1 = x[2] + x[4];
float2 a_2 = x[0] - a_1 / 2;
float2 a_3 = sin_pi_3 * (x[2] - x[4]);
float2 a_4 = x[5] + x[1];
float2 a_5 = x[3] - a_4 / 2;
float2 a_6 = sin_pi_3 * (x[5] - x[1]);
float2 a_7 = x[0] + a_1;
float2 a_3_i = {a_3.y, -a_3.x};
float2 a_6_i = {a_6.y, -a_6.x};
float2 a_8 = a_2 + a_3_i;
float2 a_9 = a_2 - a_3_i;
float2 a_10 = x[3] + a_4;
float2 a_11 = a_5 + a_6_i;
float2 a_12 = a_5 - a_6_i;
y[0] = a_7 + a_10;
y[1] = a_8 - a_11;
y[2] = a_9 + a_12;
y[3] = a_7 - a_10;
y[4] = a_8 + a_11;
y[5] = a_9 - a_12;
}
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
// Rader's algorithm
float2 inv = {1 / 6.0, -1 / 6.0};
// fft
float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
radix6(in1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));
// ifft
radix6(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[5] = x[2] * inv + x[0];
y[4] = x[3] * inv + x[0];
y[6] = x[4] * inv + x[0];
y[2] = x[5] * inv + x[0];
y[3] = x[6] * inv + x[0];
}
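The input permutation {1, 3, 2, 6, 4, 5} and the output ordering above are the powers of the generator 3 modulo 7, and the b_q constants appear to be a precomputed length-6 FFT of the permuted twiddles. A NumPy sketch of Rader's decomposition for n = 7 (illustration only; the kernel folds the constant FFT and the conjugate-based inverse into the codelet):

import numpy as np

def rader_dft7(x):
    # Illustration only (not in this commit): Rader's algorithm for p = 7, g = 3.
    p, g = 7, 3
    perm = [pow(g, q, p) for q in range(p - 1)]    # [1, 3, 2, 6, 4, 5], as in radix7
    iperm = [pow(g, -q, p) for q in range(p - 1)]  # [1, 5, 4, 6, 2, 3]: output order
    w = np.exp(-2j * np.pi * np.arange(p) / p)
    a = x[perm]                                    # permuted input
    b = w[iperm]                                   # permuted twiddles
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))  # cyclic convolution
    X = np.empty(p, dtype=complex)
    X[0] = x.sum()
    X[iperm] = x[0] + conv
    return X

x = np.random.rand(7) + 1j * np.random.rand(7)
assert np.allclose(rader_dft7(x), np.fft.fft(x))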
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
float cos_pi_4 = 0.7071067811865476;
float2 w_0 = {cos_pi_4, -cos_pi_4};
float2 w_1 = {-cos_pi_4, -cos_pi_4};
float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
radix4(temp, x);
radix4(temp + 4, x + 4);
y[0] = x[0] + x[4];
y[4] = x[0] - x[4];
float2 x_5 = complex_mul(x[5], w_0);
y[1] = x[1] + x_5;
y[5] = x[1] - x_5;
float2 x_6 = {x[6].y, -x[6].x};
y[2] = x[2] + x_6;
y[6] = x[2] - x_6;
float2 x_7 = complex_mul(x[7], w_1);
y[3] = x[3] + x_7;
y[7] = x[3] - x_7;
}
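radix8 above is one decimation-in-time step: two radix-4 DFTs on the even/odd samples followed by a twiddled butterfly, with w_0 = e^{-i*pi/4} and w_1 = e^{-3i*pi/4}. A generic NumPy check of that structure (illustration only):

import numpy as np
# Illustration only: length-8 DIT step = two length-4 DFTs + twiddled butterfly.
x = np.random.rand(8) + 1j * np.random.rand(8)
e, o = np.fft.fft(x[0::2]), np.fft.fft(x[1::2])
tw = np.exp(-2j * np.pi * np.arange(4) / 8)
assert np.allclose(np.fft.fft(x), np.concatenate([e + tw * o, e - tw * o]))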
template <bool raders_perm>
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
float2 w[4];
w[0] = {0.8090169943749475, -0.5877852522924731};
w[1] = {0.30901699437494745, -0.9510565162951535};
w[2] = {-w[1].x, w[1].y};
w[3] = {-w[0].x, w[0].y};
if (raders_perm) {
float2 temp[10] = {
x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
radix5(temp, x);
radix5(temp + 5, x + 5);
} else {
float2 temp[10] = {
x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
radix5(temp, x);
radix5(temp + 5, x + 5);
}
y[0] = x[0] + x[5];
y[5] = x[0] - x[5];
for (int t = 1; t < 5; t++) {
float2 a = complex_mul(x[t + 5], w[t - 1]);
y[t] = x[t] + a;
y[t + 5] = x[t] - a;
}
}
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
// Rader's algorithm
float2 inv = {1 / 10.0, -1 / 10.0};
// fft
radix10<true>(x + 1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));
// ifft
radix10<false>(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[6] = x[2] * inv + x[0];
y[3] = x[3] * inv + x[0];
y[7] = x[4] * inv + x[0];
y[9] = x[5] * inv + x[0];
y[10] = x[6] * inv + x[0];
y[5] = x[7] * inv + x[0];
y[8] = x[8] * inv + x[0];
y[4] = x[9] * inv + x[0];
y[2] = x[10] * inv + x[0];
}
template <bool raders_perm>
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
float2 w[6];
float sin_pi_3 = 0.8660254037844387;
w[0] = {sin_pi_3, -0.5};
w[1] = {0.5, -sin_pi_3};
w[2] = {0, -1};
w[3] = {-0.5, -sin_pi_3};
w[4] = {-sin_pi_3, -0.5};
if (raders_perm) {
float2 temp[12] = {
x[0],
x[3],
x[2],
x[11],
x[8],
x[9],
x[1],
x[7],
x[5],
x[10],
x[4],
x[6]};
radix6(temp, x);
radix6(temp + 6, x + 6);
} else {
float2 temp[12] = {
x[0],
x[2],
x[4],
x[6],
x[8],
x[10],
x[1],
x[3],
x[5],
x[7],
x[9],
x[11]};
radix6(temp, x);
radix6(temp + 6, x + 6);
}
y[0] = x[0] + x[6];
y[6] = x[0] - x[6];
for (int t = 1; t < 6; t++) {
float2 a = complex_mul(x[t + 6], w[t - 1]);
y[t] = x[t] + a;
y[t + 6] = x[t] - a;
}
}
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
// Rader's algorithm
float2 inv = {1 / 12.0, -1 / 12.0};
// fft
radix12<true>(x + 1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));
// ifft
radix12<false>(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[7] = x[2] * inv + x[0];
y[10] = x[3] * inv + x[0];
y[5] = x[4] * inv + x[0];
y[9] = x[5] * inv + x[0];
y[11] = x[6] * inv + x[0];
y[12] = x[7] * inv + x[0];
y[6] = x[8] * inv + x[0];
y[3] = x[9] * inv + x[0];
y[8] = x[10] * inv + x[0];
y[4] = x[11] * inv + x[0];
y[2] = x[12] * inv + x[0];
}

View File

@ -0,0 +1,622 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
/* FFT helpers for reading and writing from/to device memory.
For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.
Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adjacent threads for optimal performance.
We implement specialized reading/writing for:
- FFT
- RFFT
- IRFFT
Each with support for:
- Contiguous reads
- Padded reads
- Strided reads
*/
#define MAX_RADIX 13
using namespace metal;
template <
typename in_T,
typename out_T,
int step = 0,
bool four_step_real = false>
struct ReadWriter {
const device in_T* in;
threadgroup float2* buf;
device out_T* out;
int n;
int batch_size;
int elems_per_thread;
uint3 elem;
uint3 grid;
int threads_per_tg;
bool inv;
// Used for strided access
int strided_device_idx = 0;
int strided_shared_idx = 0;
METAL_FUNC ReadWriter(
const device in_T* in_,
threadgroup float2* buf_,
device out_T* out_,
const short n_,
const int batch_size_,
const short elems_per_thread_,
const uint3 elem_,
const uint3 grid_,
const bool inv_)
: in(in_),
buf(buf_),
out(out_),
n(n_),
batch_size(batch_size_),
elems_per_thread(elems_per_thread_),
elem(elem_),
grid(grid_),
inv(inv_) {
// Account for padding on last threadgroup
threads_per_tg = elem.x == grid.x - 1
? (batch_size - (grid.x - 1) * grid.y) * grid.z
: grid.y * grid.z;
}
// ifft(x) = 1/n * conj(fft(conj(x)))
METAL_FUNC float2 post_in(float2 elem) const {
return inv ? float2(elem.x, -elem.y) : elem;
}
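The identity used here is easy to confirm in NumPy (illustration only, not part of this file):

import numpy as np
# Illustration only: ifft(x) == conj(fft(conj(x))) / n
x = np.random.rand(16) + 1j * np.random.rand(16)
assert np.allclose(np.fft.ifft(x), np.conj(np.fft.fft(np.conj(x))) / 16)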
// Handle float case for generic RFFT alg
METAL_FUNC float2 post_in(float elem) const {
return float2(elem, 0);
}
METAL_FUNC float2 pre_out(float2 elem) const {
return inv ? float2(elem.x / n, -elem.y / n) : elem;
}
METAL_FUNC float2 pre_out(float2 elem, int length) const {
return inv ? float2(elem.x / length, -elem.y / length) : elem;
}
METAL_FUNC bool out_of_bounds() const {
// Account for possible extra threadgroups
int grid_index = elem.x * grid.y + elem.y;
return grid_index >= batch_size;
}
METAL_FUNC void load() const {
int batch_idx = elem.x * grid.y * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
// 2 complex64s = 128 bits
constexpr int read_width = 2;
for (short e = 0; e < (elems_per_thread / read_width); e++) {
short index = read_width * tg_idx + read_width * threads_per_tg * e;
index = metal::min(index, max_index);
// vectorized reads
buf[index] = post_in(in[batch_idx + index]);
buf[index + 1] = post_in(in[batch_idx + index + 1]);
}
max_index += 1;
if (elems_per_thread % 2 != 0) {
short index = tg_idx +
read_width * threads_per_tg * (elems_per_thread / read_width);
index = metal::min(index, max_index);
buf[index] = post_in(in[batch_idx + index]);
}
}
METAL_FUNC void write() const {
int batch_idx = elem.x * grid.y * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
constexpr int read_width = 2;
for (short e = 0; e < (elems_per_thread / read_width); e++) {
short index = read_width * tg_idx + read_width * threads_per_tg * e;
index = metal::min(index, max_index);
// vectorized reads
out[batch_idx + index] = pre_out(buf[index]);
out[batch_idx + index + 1] = pre_out(buf[index + 1]);
}
max_index += 1;
if (elems_per_thread % 2 != 0) {
short index = tg_idx +
read_width * threads_per_tg * (elems_per_thread / read_width);
index = metal::min(index, max_index);
out[batch_idx + index] = pre_out(buf[index]);
}
}
// Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
threadgroup float2* seq_buf = buf + elem.y * n;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
if (index < length) {
float2 elem = post_in(in[batch_idx + index]);
seq_buf[index] = complex_mul(elem, w_k[index]);
} else {
seq_buf[index] = 0.0;
}
}
}
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
threadgroup float2* seq_buf = buf + elem.y * n;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
if (index < length) {
float2 elem = seq_buf[index + length - 1] * inv_factor;
out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
}
}
}
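For context, Bluestein's algorithm computes an arbitrary-length DFT as a chirp-multiplied, zero-padded circular convolution, which is why these padded loads/stores multiply by w_k. A NumPy sketch (illustration only; the kernel precomputes w_k and the padded FFT size on the host, and the exact constant layout may differ):

import numpy as np

def bluestein_fft(x):
    # Illustration only (not in this commit): arbitrary-length DFT via a
    # chirp multiply and a power-of-two circular convolution.
    n = len(x)
    k = np.arange(n)
    w = np.exp(-1j * np.pi * k * k / n)        # chirp
    m = 1 << (2 * n - 1).bit_length()          # padded power-of-two FFT size
    a = np.zeros(m, dtype=complex)
    a[:n] = x * w
    b = np.zeros(m, dtype=complex)
    b[:n] = np.conj(w)
    b[m - n + 1:] = np.conj(w[1:][::-1])       # wrap the chirp for circular conv
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return w * conv[:n]

x = np.random.rand(23) + 1j * np.random.rand(23)
assert np.allclose(bluestein_fft(x), np.fft.fft(x))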
// Strided IO for four step FFT
METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
// Use the batch threadgroup dimension to coalesce memory accesses:
// e.g. stride = 12
// device | shared mem
// 0 1 2 3 | 0 12 - -
// - - - - | 1 13 - -
// - - - - | 2 14 - -
// 12 13 14 15 | 3 15 - -
int coalesce_width = grid.y;
int tg_idx = elem.y * grid.z + elem.z;
int outer_batch_size = stride / coalesce_width;
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
overall_n * (elem.x / outer_batch_size);
strided_device_idx = strided_batch_idx +
tg_idx / coalesce_width * elems_per_thread * stride +
tg_idx % coalesce_width;
strided_shared_idx = (tg_idx % coalesce_width) * n +
tg_idx / coalesce_width * elems_per_thread;
}
// Four Step FFT First Step
METAL_FUNC void load_strided(int stride, int overall_n) {
compute_strided_indices(stride, overall_n);
for (int e = 0; e < elems_per_thread; e++) {
buf[strided_shared_idx + e] =
post_in(in[strided_device_idx + e * stride]);
}
}
METAL_FUNC void write_strided(int stride, int overall_n) {
for (int e = 0; e < elems_per_thread; e++) {
float2 output = buf[strided_shared_idx + e];
int combined_idx = (strided_device_idx + e * stride) % overall_n;
int ij = (combined_idx / stride) * (combined_idx % stride);
// Apply four step twiddles at end of first step
float2 twiddle = get_twiddle(ij, overall_n);
out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
}
}
};
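The strided loads and stores above implement the four-step (Bailey) decomposition n = n1 * n2: small FFTs along one axis, pointwise twiddles (applied in write_strided at the end of the first step), then small FFTs along the other axis. A NumPy model (illustration only; the axis and stride conventions here are ours and may not match the kernel exactly):

import numpy as np

def four_step_fft(x, n1, n2):
    # Illustration only (not in this commit): four-step FFT for length n1 * n2.
    n = n1 * n2
    a = x.reshape(n2, n1).T                 # a[j1, j2] = x[j1 + n1 * j2]
    b = np.fft.fft(a, axis=1)               # step 1: n1 FFTs of length n2
    j1 = np.arange(n1)[:, None]
    k2 = np.arange(n2)[None, :]
    b = b * np.exp(-2j * np.pi * j1 * k2 / n)  # step 2: twiddle factors
    c = np.fft.fft(b, axis=0)               # step 3: n2 FFTs of length n1
    return c.reshape(-1)                    # X[k2 + n2 * k1] = c[k1, k2]

x = np.random.rand(12 * 16) + 1j * np.random.rand(12 * 16)
assert np.allclose(four_step_fft(x, 12, 16), np.fft.fft(x))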
// Four Step FFT Second Step
template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
int stride,
int overall_n) {
// Silence compiler warnings
(void)stride;
(void)overall_n;
// Don't invert between steps
bool default_inv = inv;
inv = false;
load();
inv = default_inv;
}
template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
int stride,
int overall_n) {
compute_strided_indices(stride, overall_n);
for (int e = 0; e < elems_per_thread; e++) {
float2 output = buf[strided_shared_idx + e];
out[strided_device_idx + e * stride] = pre_out(output, overall_n);
}
}
// For RFFT, we pack pairs of real sequences from the batch into one complex sequence:
//
// z_k = x_k + j.y_k
// X_k = (Z_k + Z_(N-k)*) / 2
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
//
// This roughly doubles the throughput over the regular FFT.
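This two-for-one trick is easy to verify in NumPy (illustration only, not part of this file):

import numpy as np
# Illustration only: pack two real sequences into one complex FFT.
n = 16
x, y = np.random.rand(n), np.random.rand(n)
Z = np.fft.fft(x + 1j * y)
Zc = np.conj(Z[(-np.arange(n)) % n])          # Z_{N-k}^*
assert np.allclose((Z + Zc) / 2, np.fft.fft(x))
assert np.allclose(-1j * (Z - Zc) / 2, np.fft.fft(y))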
template <>
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
int grid_index = elem.x * grid.y + elem.y;
// We pack two sequences into one for RFFTs
return grid_index * 2 >= batch_size;
}
template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
seq_buf[index].x = in[batch_idx + index];
seq_buf[index].y = in[batch_idx + index + next_in];
}
}
template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
short next_out =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
float2 conj = {1, -1};
float2 minus_j = {0, -1};
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
int index = metal::min(fft_idx + e * m, n_over_2 - 1);
// x_0 = z_0.real
// y_0 = z_0.imag
if (index == 0) {
out[batch_idx + index] = {seq_buf[index].x, 0};
out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
} else {
float2 x_k = seq_buf[index];
float2 x_n_minus_k = seq_buf[n - index] * conj;
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
out[batch_idx + index + next_out] =
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
}
}
}
template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
if (index < length) {
float2 elem =
float2(in[batch_idx + index], in[batch_idx + index + next_in]);
seq_buf[index] = complex_mul(elem, w_k[index]);
} else {
seq_buf[index] = 0;
}
}
}
template <>
METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length,
const device float2* w_k) const {
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
? 0
: length_over_2;
float2 conj = {1, -1};
float2 inv_factor = {1.0f / n, -1.0f / n};
float2 minus_j = {0, -1};
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
int index = metal::min(fft_idx + e * m, length_over_2 - 1);
// x_0 = z_0.real
// y_0 = z_0.imag
if (index == 0) {
float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
out[batch_idx + index] = float2(elem.x, 0);
out[batch_idx + index + next_out] = float2(elem.y, 0);
} else {
float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
float2 x_n_minus_k = complex_mul(
w_k[length - index], seq_buf[length - index] * inv_factor);
x_n_minus_k *= conj;
// The multiplication by w_k must happen before the two real outputs are separated
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
out[batch_idx + index + next_out] =
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
}
}
}
// For IRFFT, we do the opposite
//
// Z_k = X_k + j.Y_k
// x_n = Re(z_n)
// y_n = Im(z_n)
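In other words, the two half-spectra are packed into one complex inverse FFT and the two real outputs fall out of the real and imaginary parts. A NumPy model (illustration only; the helper and its Hermitian fill are ours, not the kernel's exact layout):

import numpy as np

def irfft2x(X_half, Y_half, n):
    # Illustration only (not in this commit): two length-n real inverse FFTs
    # from their half-spectra, via one complex inverse FFT.
    full = np.zeros(n, dtype=complex)
    full[: n // 2 + 1] = X_half + 1j * Y_half
    ks = np.arange(1, (n + 1) // 2)
    # Fill k > n/2 using X_{n-k} = X_k^*, Y_{n-k} = Y_k^*
    full[n - ks] = np.conj(X_half[ks]) + 1j * np.conj(Y_half[ks])
    z = np.fft.ifft(full)
    return z.real, z.imag

n = 16
x, y = np.random.rand(n), np.random.rand(n)
xr, yr = irfft2x(np.fft.rfft(x), np.fft.rfft(y), n)
assert np.allclose(xr, x) and np.allclose(yr, y)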
template <>
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
int grid_index = elem.x * grid.y + elem.y;
// We pack two sequences into one for IRFFTs
return grid_index * 2 >= batch_size;
}
template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
short m = grid.z;
short fft_idx = elem.z;
float2 conj = {1, -1};
float2 plus_j = {0, 1};
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
float2 x = in[batch_idx + index];
float2 y = in[batch_idx + index + next_in];
// NumPy forces first input to be real
bool first_val = index == 0;
// NumPy forces last input on even irffts to be real
bool last_val = n % 2 == 0 && index == n_over_2 - 1;
if (first_val || last_val) {
x = float2(x.x, 0);
y = float2(y.x, 0);
}
seq_buf[index] = x + complex_mul(y, plus_j);
seq_buf[index].y = -seq_buf[index].y;
if (index > 0 && !last_val) {
seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
seq_buf[n - index].y = -seq_buf[n - index].y;
}
}
}
template <>
METAL_FUNC void ReadWriter<float2, float>::write() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
short next_out =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
out[batch_idx + index] = seq_buf[index].x / n;
out[batch_idx + index + next_out] = seq_buf[index].y / -n;
}
}
template <>
METAL_FUNC void ReadWriter<float2, float>::load_padded(
int length,
const device float2* w_k) const {
int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
? 0
: length_over_2;
short m = grid.z;
short fft_idx = elem.z;
float2 conj = {1, -1};
float2 plus_j = {0, 1};
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
float2 x = in[batch_idx + index];
float2 y = in[batch_idx + index + next_in];
if (index < length_over_2) {
bool last_val = length % 2 == 0 && index == length_over_2 - 1;
if (last_val) {
x = float2(x.x, 0);
y = float2(y.x, 0);
}
float2 elem1 = x + complex_mul(y, plus_j);
seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
if (index > 0 && !last_val) {
float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
seq_buf[length - index] =
complex_mul(elem2 * conj, w_k[length - index]);
}
} else {
short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
seq_buf[pad_index] = 0;
seq_buf[pad_index + 1] = 0;
}
}
}
template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
short next_out =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
short m = grid.z;
short fft_idx = elem.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
for (int e = 0; e < elems_per_thread; e++) {
int index = fft_idx + e * m;
if (index < length) {
float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
out[batch_idx + index] = output.x / length;
out[batch_idx + index + next_out] = output.y / -length;
}
}
}
// Four Step RFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
int stride,
int overall_n) {
// Silence compiler warnings
(void)stride;
(void)overall_n;
// Don't invert between steps
bool default_inv = inv;
inv = false;
load();
inv = default_inv;
}
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
int stride,
int overall_n) {
int overall_n_over_2 = overall_n / 2 + 1;
int coalesce_width = grid.y;
int tg_idx = elem.y * grid.z + elem.z;
int outer_batch_size = stride / coalesce_width;
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
overall_n_over_2 * (elem.x / outer_batch_size);
strided_device_idx = strided_batch_idx +
tg_idx / coalesce_width * elems_per_thread / 2 * stride +
tg_idx % coalesce_width;
strided_shared_idx = (tg_idx % coalesce_width) * n +
tg_idx / coalesce_width * elems_per_thread / 2;
for (int e = 0; e < elems_per_thread / 2; e++) {
float2 output = buf[strided_shared_idx + e];
out[strided_device_idx + e * stride] = output;
}
// Add on the (n/2 + 1)-th element (the Nyquist bin)
if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
out[strided_batch_idx + overall_n / 2] = buf[n / 2];
}
}
// Four Step IRFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
int stride,
int overall_n) {
int overall_n_over_2 = overall_n / 2 + 1;
auto conj = float2(1, -1);
compute_strided_indices(stride, overall_n);
// Translate indices in terms of N - k
for (int e = 0; e < elems_per_thread; e++) {
int device_idx = strided_device_idx + e * stride;
int overall_batch = device_idx / overall_n;
int overall_index = device_idx % overall_n;
if (overall_index < overall_n_over_2) {
device_idx -= overall_batch * (overall_n - overall_n_over_2);
buf[strided_shared_idx + e] = in[device_idx] * conj;
} else {
int conj_idx = overall_n - overall_index;
device_idx = overall_batch * overall_n_over_2 + conj_idx;
buf[strided_shared_idx + e] = in[device_idx];
}
}
}
template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
int stride,
int overall_n) {
// Silence compiler warnings
(void)stride;
(void)overall_n;
bool default_inv = inv;
inv = false;
load();
inv = default_inv;
}
template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
int stride,
int overall_n) {
compute_strided_indices(stride, overall_n);
for (int e = 0; e < elems_per_thread; e++) {
out[strided_device_idx + e * stride] =
pre_out(buf[strided_shared_idx + e], overall_n).x;
}
}

View File

@ -191,4 +191,17 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@ -134,6 +134,13 @@ bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}
int next_power_of_2(int n) {
if (is_power_of_2(n)) {
return n;
}
return pow(2, std::ceil(std::log2(n)));
}
} // namespace
} // namespace mlx::core

View File

@ -81,7 +81,8 @@ array fft_impl(
if (any_greater) {
// Pad with zeros
auto tmp = zeros(in_shape, a.dtype(), s);
std::vector<int> starts(in.ndim(), 0);
in = slice_update(tmp, in, starts, in.shape());
}
auto out_shape = in_shape;

View File

@ -16,96 +16,140 @@ class TestFFT(mlx_tests.MLXTestCase):
np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)

def test_fft(self):
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
# Check with slicing and padding
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
# Check different axes
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
# Check real fft
a_np = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
# Check real inverse
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
x = np.fft.rfft(a_np)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)

def test_fftn(self):
r = np.random.randn(8, 8, 8).astype(np.float32)
i = np.random.randn(8, 8, 8).astype(np.float32)
a = r + 1j * i
axes = [None, (1, 2), (2, 1), (0, 2)]
shapes = [None, (10, 5), (5, 10)]
ops = [
"fft2",
"ifft2",
"rfft2",
"irfft2",
"fftn",
"ifftn",
"rfftn",
"irfftn",
]
for op, ax, s in itertools.product(ops, axes, shapes):
x = a
if op in ["rfft2", "rfftn"]:
x = r
elif op == "irfft2":
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
elif op == "irfftn":
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
mx_op = getattr(mx.fft, op)
np_op = getattr(np.fft, op)
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)

def _run_ffts(self, shape, atol=1e-4, rtol=1e-4):
np.random.seed(9)
r = np.random.rand(*shape).astype(np.float32)
i = np.random.rand(*shape).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, atol=atol, rtol=rtol)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
ia_np = np.fft.rfft(a_np)
self.check_mx_np(
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol)

def test_fft_shared_mem(self):
nums = np.concatenate(
[
# small radix
np.arange(2, 14),
# powers of 2
[2**k for k in range(4, 13)],
# stockham
[3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11],
# rader
[17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982],
# bluestein
[47, 83, 17 * 17],
# large stockham
[3159, 3645, 3969, 4004],
]
)
for batch_size in (1, 3, 32):
for num in nums:
atol = 1e-4 if num < 1025 else 1e-3
self._run_ffts((batch_size, num), atol=atol)

@unittest.skip("Too slow for CI but useful for local testing.")
def test_fft_exhaustive(self):
nums = range(2, 4097)
for batch_size in (1, 3, 32):
for num in nums:
print(num)
atol = 1e-4 if num < 1025 else 1e-3
self._run_ffts((batch_size, num), atol=atol)

def test_fft_big_powers_of_two(self):
# TODO: improve precision on big powers of two on GPU
for k in range(12, 17):
self._run_ffts((3, 2**k), atol=1e-3)
for k in range(17, 20):
self._run_ffts((3, 2**k), atol=1e-2)

def test_fft_large_numbers(self):
numbers = [
1037, # prime > 2048
18247, # medium size prime factors
1259 * 11, # large prime factors
7883, # large prime
3**8, # large stockham decomposable
3109, # bluestein
4006, # large rader
]
for large_num in numbers:
self._run_ffts((1, large_num), atol=1e-3)

def test_fft_contiguity(self):
r = np.random.rand(4, 8).astype(np.float32)

View File

@ -7,8 +7,6 @@
using namespace mlx::core;
TEST_CASE("test fft basics") {
auto device = default_device();
set_default_device(Device::cpu);
array x(1.0);
CHECK_THROWS(fft::fft(x));
CHECK_THROWS(fft::ifft(x));
@ -94,13 +92,9 @@ TEST_CASE("test fft basics") {
CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
}
set_default_device(device);
}
TEST_CASE("test real ffts") {
auto device = default_device();
set_default_device(Device::cpu);
auto x = array({1.0});
auto y = fft::rfft(x);
CHECK_EQ(y.dtype(), complex64);
@ -124,14 +118,9 @@ TEST_CASE("test real ffts") {
CHECK_EQ(y.size(), 2);
CHECK_EQ(y.dtype(), float32);
CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());
set_default_device(device);
}
TEST_CASE("test fftn") {
auto device = default_device();
set_default_device(Device::cpu);
auto x = zeros({5, 5, 5});
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
@ -204,8 +193,6 @@ TEST_CASE("test fftn") {
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.dtype(), float32);
}
set_default_device(device);
}
TEST_CASE("test fft with provided shape") {
@ -234,9 +221,6 @@ TEST_CASE("test fft with provided shape") {
}
TEST_CASE("test fft vmap") {
auto device = default_device();
set_default_device(Device::cpu);
auto fft_fn = [](array x) { return fft::fft(x); };
auto x = reshape(arange(8), {2, 4});
auto y = vmap(fft_fn)(x);
@ -252,14 +236,9 @@ TEST_CASE("test fft vmap") {
y = vmap(rfft_fn, 1, 1)(x);
CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());
set_default_device(device);
}
TEST_CASE("test fft grads") {
auto device = default_device();
set_default_device(Device::cpu);
// Regular
auto fft_fn = [](array x) { return fft::fft(x); };
auto cotangent = astype(arange(10), complex64);
@ -328,6 +307,4 @@ TEST_CASE("test fft grads") {
zeros({5, 8}))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
set_default_device(device);
}