From 27d70c7d9d72323e0de4d920af06ebc443986730 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 6 Jun 2024 12:57:25 -0700 Subject: [PATCH] Feature complete Metal FFT (#1102) * feature complete metal fft * fix contiguity bug * jit fft * simplify rader/bluestein constant computation * remove kernel/utils.h dep * remove bf16.h dep * format --------- Co-authored-by: Alex Barron --- benchmarks/python/fft_bench.py | 110 ++- mlx/backend/metal/CMakeLists.txt | 5 + mlx/backend/metal/fft.cpp | 788 ++++++++++++++++++++-- mlx/backend/metal/jit/fft.h | 53 ++ mlx/backend/metal/jit/includes.h | 1 + mlx/backend/metal/jit_kernels.cpp | 48 ++ mlx/backend/metal/kernels.h | 11 + mlx/backend/metal/kernels/CMakeLists.txt | 3 + mlx/backend/metal/kernels/fft.h | 486 +++++++++++++ mlx/backend/metal/kernels/fft.metal | 267 +++----- mlx/backend/metal/kernels/fft/radix.h | 328 +++++++++ mlx/backend/metal/kernels/fft/readwrite.h | 622 +++++++++++++++++ mlx/backend/metal/nojit_kernels.cpp | 13 + mlx/backend/metal/utils.h | 7 + mlx/fft.cpp | 3 +- python/tests/test_fft.py | 200 +++--- tests/fft_tests.cpp | 23 - 17 files changed, 2601 insertions(+), 367 deletions(-) create mode 100644 mlx/backend/metal/jit/fft.h create mode 100644 mlx/backend/metal/kernels/fft.h create mode 100644 mlx/backend/metal/kernels/fft/radix.h create mode 100644 mlx/backend/metal/kernels/fft/readwrite.h diff --git a/benchmarks/python/fft_bench.py b/benchmarks/python/fft_bench.py index 865d6408f..8f3603f47 100644 --- a/benchmarks/python/fft_bench.py +++ b/benchmarks/python/fft_bench.py @@ -3,6 +3,8 @@ import matplotlib import mlx.core as mx import numpy as np +import sympy +import torch from time_utils import measure_runtime matplotlib.use("Agg") @@ -16,40 +18,100 @@ def bandwidth_gb(runtime_ms, system_size): return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb -def run_bench(system_size): - def fft(x): - out = mx.fft.fft(x) +def run_bench(system_size, fft_sizes, backend="mlx", dim=1): + def fft_mlx(x): + if dim == 1: + out = mx.fft.fft(x) + elif dim == 2: + out = mx.fft.fft2(x) mx.eval(out) return out - bandwidths = [] - for k in range(4, 12): - n = 2**k - x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32) - x = x.astype(mx.complex64) - mx.eval(x) - runtime_ms = measure_runtime(fft, x=x) - bandwidths.append(bandwidth_gb(runtime_ms, system_size)) + def fft_mps(x): + if dim == 1: + out = torch.fft.fft(x) + elif dim == 2: + out = torch.fft.fft2(x) + torch.mps.synchronize() + return out - return bandwidths + bandwidths = [] + for n in fft_sizes: + batch_size = system_size // n**dim + shape = [batch_size] + [n for _ in range(dim)] + if backend == "mlx": + x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64) + x = mx.array(x_np) + mx.eval(x) + fft = fft_mlx + elif backend == "mps": + x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64) + x = torch.tensor(x_np, device="mps") + torch.mps.synchronize() + fft = fft_mps + else: + raise NotImplementedError() + runtime_ms = measure_runtime(fft, x=x) + bandwidth = bandwidth_gb(runtime_ms, np.prod(shape)) + print(n, bandwidth) + bandwidths.append(bandwidth) + + return np.array(bandwidths) def time_fft(): - with mx.stream(mx.cpu): - cpu_bandwidths = run_bench(system_size=int(2**22)) + x = np.array(range(2, 512)) + system_size = int(2**26) + print("MLX GPU") with mx.stream(mx.gpu): - gpu_bandwidths = run_bench(system_size=int(2**29)) + gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x) - # plot bandwidths - x = 
[2**k for k in range(4, 12)] - plt.scatter(x, gpu_bandwidths, color="green", label="GPU") - plt.scatter(x, cpu_bandwidths, color="red", label="CPU") - plt.title("MLX FFT Benchmark") - plt.xlabel("N") - plt.ylabel("Bandwidth (GB/s)") - plt.legend() - plt.savefig("fft_plot.png") + print("MPS GPU") + mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps") + + print("CPU") + system_size = int(2**20) + with mx.stream(mx.cpu): + cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x) + + x = np.array(x) + + all_indices = x - x[0] + radix_2to13 = ( + np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0] + ) + bluesteins = ( + np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0] + ) + + for indices, name in [ + (all_indices, "All"), + (radix_2to13, "Radix 2-13"), + (bluesteins, "Bluestein's"), + ]: + # plot bandwidths + print(name) + plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU") + plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS") + plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU") + plt.title(f"MLX FFT Benchmark -- {name}") + plt.xlabel("N") + plt.ylabel("Bandwidth (GB/s)") + plt.legend() + plt.savefig(f"{name}.png") + plt.clf() + + av_gpu_bandwidth = np.mean(gpu_bandwidths) + av_mps_bandwidth = np.mean(mps_bandwidths) + av_cpu_bandwidth = np.mean(cpu_bandwidths) + print("Average bandwidths:") + print("GPU:", av_gpu_bandwidth) + print("MPS:", av_mps_bandwidth) + print("CPU:", av_cpu_bandwidth) + + portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x) + print("Percent MLX faster than MPS: ", portion_faster * 100) if __name__ == "__main__": diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index c978fe5e5..cbc18bb3a 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -64,6 +64,11 @@ if (MLX_METAL_JIT) make_jit_source(unary) make_jit_source(binary) make_jit_source(binary_two) + make_jit_source( + fft + kernels/fft/radix.h + kernels/fft/readwrite.h + ) make_jit_source(ternary) make_jit_source(softmax) make_jit_source(scan) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 9f64cefd6..394f3c272 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -1,106 +1,794 @@ // Copyright © 2023 Apple Inc. +#include +#include +#include +#include +#include + +#include "mlx/3rdparty/pocketfft.h" +#include "mlx/backend/metal/binary.h" #include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/slicing.h" +#include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/mlx.h" #include "mlx/primitives.h" namespace mlx::core { -void FFT::eval_gpu(const std::vector& inputs, array& out) { - auto& s = out.primitive().stream(); - auto& d = metal::device(s.device); +using MTLFC = std::tuple; - auto& in = inputs[0]; +#define MAX_STOCKHAM_FFT_SIZE 4096 +#define MAX_RADER_FFT_SIZE 2048 +#define MAX_BLUESTEIN_FFT_SIZE 2048 +// Threadgroup memory batching improves throughput for small n +#define MIN_THREADGROUP_MEM_SIZE 256 +// For strided reads/writes, coalesce at least this many complex64s +#define MIN_COALESCE_WIDTH 4 - if (axes_.size() == 0 || axes_.size() > 1 || inverse_ || - in.dtype() != complex64 || out.dtype() != complex64) { - // Could also fallback to CPU implementation here. 
- throw std::runtime_error( - "GPU FFT is only implemented for 1D, forward, complex FFTs."); +inline const std::vector supported_radices() { + // Ordered by preference in decomposition. + return {13, 11, 8, 7, 6, 5, 4, 3, 2}; +} + +std::vector prime_factors(int n) { + int z = 2; + std::vector factors; + while (z * z <= n) { + if (n % z == 0) { + factors.push_back(z); + n /= z; + } else { + z++; + } + } + if (n > 1) { + factors.push_back(n); + } + return factors; +} + +struct FourStepParams { + bool required = false; + bool first_step = true; + int n1 = 0; + int n2 = 0; +}; + +// Forward Declaration +void fft_op( + const array& in, + array& out, + size_t axis, + bool inverse, + bool real, + const FourStepParams four_step_params, + bool inplace, + const Stream& s); + +struct FFTPlan { + int n = 0; + // Number of steps for each radix in the Stockham decomposition + std::vector stockham; + // Number of steps for each radix in the Rader decomposition + std::vector rader; + // Rader factor, 1 if no rader factors + int rader_n = 1; + int bluestein_n = -1; + // Four step FFT + bool four_step = false; + int n1 = 0; + int n2 = 0; +}; + +int next_fast_n(int n) { + return next_power_of_2(n); +} + +std::vector plan_stockham_fft(int n) { + auto radices = supported_radices(); + std::vector plan(radices.size(), 0); + int orig_n = n; + if (n == 1) { + return plan; + } + for (int i = 0; i < radices.size(); i++) { + int radix = radices[i]; + // Manually tuned radices for powers of 2 + if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) { + continue; + } + while (n % radix == 0) { + plan[i] += 1; + n /= radix; + if (n == 1) { + return plan; + } + } + } + throw std::runtime_error("Unplannable"); +} + +FFTPlan plan_fft(int n) { + auto radices = supported_radices(); + std::set radices_set(radices.begin(), radices.end()); + + FFTPlan plan; + plan.n = n; + plan.rader = std::vector(radices.size(), 0); + auto factors = prime_factors(n); + int remaining_n = n; + + // Four Step FFT when N is too large for shared mem. + if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) { + // For power's of two we have a fast, no transpose four step implementation. + plan.four_step = true; + // Rough heuristic for choosing faster powers of two when we can + plan.n2 = n > 65536 ? 1024 : 64; + plan.n1 = n / plan.n2; + return plan; + } else if (n > MAX_STOCKHAM_FFT_SIZE) { + // Otherwise we use a multi-upload Bluestein's + plan.four_step = true; + plan.bluestein_n = next_fast_n(2 * n - 1); + return plan; } - size_t n = in.shape(axes_[0]); + for (int factor : factors) { + // Make sure the factor is a supported radix + if (radices_set.find(factor) == radices_set.end()) { + // We only support a single Rader factor currently + // TODO(alexbarron) investigate weirdness with large + // Rader sizes -- possibly a compiler issue? + if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) { + plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE; + plan.bluestein_n = next_fast_n(2 * n - 1); + plan.stockham = plan_stockham_fft(plan.bluestein_n); + plan.rader = std::vector(radices.size(), 0); + return plan; + } + // See if we can use Rader's algorithm to Stockham decompose n - 1 + auto rader_factors = prime_factors(factor - 1); + int last_factor = -1; + for (int rf : rader_factors) { + // We don't nest Rader's algorithm so if `factor - 1` + // isn't Stockham decomposable we give up and do Bluestein's. 
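        // Worked example (illustration only, mirroring the n = 102 case in the
        // kernel comments): 102 = 2 * 3 * 17. The prime 17 is not a supported
        // radix, but 17 - 1 = 16 = 4 * 4 is Stockham decomposable, so we keep
        // rader_n = 17 and plan the remaining 2 * 3 = 6 with plan_stockham_fft.
        // A prime such as 47 (46 = 2 * 23) instead falls through to Bluestein's.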
+ if (radices_set.find(rf) == radices_set.end()) { + plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE; + plan.bluestein_n = next_fast_n(2 * n - 1); + plan.stockham = plan_stockham_fft(plan.bluestein_n); + plan.rader = std::vector(radices.size(), 0); + return plan; + } + } + plan.rader = plan_stockham_fft(factor - 1); + plan.rader_n = factor; + remaining_n /= factor; + } + } - if (!is_power_of_2(n) || n > 2048 || n < 4) { - throw std::runtime_error( - "GPU FFT is only implemented for the powers of 2 from 4 -> 2048"); + plan.stockham = plan_stockham_fft(remaining_n); + return plan; +} + +int compute_elems_per_thread(FFTPlan plan) { + // Heuristics for selecting an efficient number + // of threads to use for a particular mixed-radix FFT. + auto n = plan.n; + + std::vector steps; + auto radices = supported_radices(); + steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end()); + steps.insert(steps.end(), plan.rader.begin(), plan.rader.end()); + std::set used_radices; + for (int i = 0; i < steps.size(); i++) { + int radix = radices[i % radices.size()]; + if (steps[i] > 0) { + used_radices.insert(radix); + } + } + + // Manual tuning for 7/11/13 + if (used_radices.find(7) != used_radices.end() && + (used_radices.find(11) != used_radices.end() || + used_radices.find(13) != used_radices.end())) { + return 7; + } else if ( + used_radices.find(11) != used_radices.end() && + used_radices.find(13) != used_radices.end()) { + return 11; + } + + // TODO(alexbarron) Some really weird stuff is going on + // for certain `elems_per_thread` on large composite n. + // Possibly a compiler issue? + if (n == 3159) + return 13; + if (n == 3645) + return 5; + if (n == 3969) + return 7; + if (n == 1982) + return 5; + + if (used_radices.size() == 1) { + return *(used_radices.begin()); + } + if (used_radices.size() == 2) { + if (used_radices.find(11) != used_radices.end() || + used_radices.find(13) != used_radices.end()) { + return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2; + } + std::vector radix_vec(used_radices.begin(), used_radices.end()); + return radix_vec[1]; + } + // In all other cases use the second smallest radix. 
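  // Worked example (illustration only): n = 120 decomposes as 8 * 5 * 3, so
  // used_radices = {3, 5, 8} and the second smallest radix, 5, is returned;
  // each thread then handles 5 elements of the 120-point FFT.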
+ std::vector radix_vec(used_radices.begin(), used_radices.end()); + return radix_vec[1]; +} + +// Rader +int mod_exp(int x, int y, int n) { + int out = 1; + while (y) { + if (y & 1) { + out = out * x % n; + } + y >>= 1; + x = x * x % n; + } + return out; +} + +int primitive_root(int n) { + auto factors = prime_factors(n - 1); + + for (int r = 2; r < n - 1; r++) { + bool found = true; + for (int factor : factors) { + if (mod_exp(r, (n - 1) / factor, n) == 1) { + found = false; + break; + } + } + if (found) { + return r; + } + } + return -1; +} + +std::tuple compute_raders_constants( + int rader_n, + const Stream& s) { + int proot = primitive_root(rader_n); + // Fermat's little theorem + int inv = mod_exp(proot, rader_n - 2, rader_n); + std::vector g_q(rader_n - 1); + std::vector g_minus_q(rader_n - 1); + for (int i = 0; i < rader_n - 1; i++) { + g_q[i] = mod_exp(proot, i, rader_n); + g_minus_q[i] = mod_exp(inv, i, rader_n); + } + array g_q_arr(g_q.begin(), {rader_n - 1}); + array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1}); + + std::vector> b_q(rader_n - 1); + for (int i = 0; i < rader_n - 1; i++) { + float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n; + b_q[i] = std::exp(std::complex(0, pi_i)); + } + + array b_q_fft({rader_n - 1}, complex64, nullptr, {}); + b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes())); + auto b_q_fft_ptr = + reinterpret_cast*>(b_q_fft.data()); + std::ptrdiff_t item_size = b_q_fft.itemsize(); + size_t fft_size = rader_n - 1; + // This FFT is always small (<4096, batch 1) so save some overhead + // and do it on the CPU + pocketfft::c2c( + /* shape= */ {fft_size}, + /* stride_in= */ {item_size}, + /* stride_out= */ {item_size}, + /* axes= */ {0}, + /* forward= */ true, + /* data_in= */ b_q.data(), + /* data_out= */ b_q_fft_ptr, + /* scale= */ 1.0f); + return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr); +} + +// Bluestein +std::pair compute_bluestein_constants(int n, int bluestein_n) { + // We need to calculate the Bluestein twiddle factors + // in double precision for the overall numerical stability + // of Bluestein's FFT algorithm to be acceptable. + // + // Metal doesn't support float64, so instead we + // manually implement the required operations on cpu. 
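  // (Illustration: bluestein_n is the padded convolution length
  //  next_power_of_2(2 * n - 1) chosen by plan_fft, e.g. n = 47 gives
  //  bluestein_n = 128; w_k below has length n and w_q length bluestein_n.)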
+ // + // In numpy: + // w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2)) + // w_q = np.fft.fft(1/w_k) + // return w_k, w_q + int length = 2 * n - 1; + + std::vector> w_k_vec(n); + std::vector> w_q_vec(bluestein_n, 0); + + for (int i = -n + 1; i < n; i++) { + double theta = pow(i, 2) * M_PI / (double)n; + w_q_vec[i + n - 1] = std::exp(std::complex(0, theta)); + if (i >= 0) { + w_k_vec[i] = std::exp(std::complex(0, -theta)); + } + } + + array w_k({n}, complex64, nullptr, {}); + w_k.set_data(allocator::malloc_or_wait(w_k.nbytes())); + std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data()); + + array w_q({bluestein_n}, complex64, nullptr, {}); + w_q.set_data(allocator::malloc_or_wait(w_q.nbytes())); + auto w_q_ptr = + reinterpret_cast*>(w_q.data()); + + std::ptrdiff_t item_size = w_q.itemsize(); + size_t fft_size = bluestein_n; + pocketfft::c2c( + /* shape= */ {fft_size}, + /* stride_in= */ {item_size}, + /* stride_out= */ {item_size}, + /* axes= */ {0}, + /* forward= */ true, + /* data_in= */ w_q_vec.data(), + /* data_out= */ w_q_ptr, + /* scale= */ 1.0f); + return std::make_tuple(w_k, w_q); +} + +void multi_upload_bluestein_fft( + const array& in, + array& out, + size_t axis, + bool inverse, + bool real, + FFTPlan& plan, + std::vector copies, + const Stream& s) { + // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's + // algorithm + int n = inverse ? out.shape(axis) : in.shape(axis); + auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); + + // Broadcast w_q and w_k to the batch size + std::vector b_strides(in.ndim(), 0); + b_strides[axis] = 1; + array w_k_broadcast({}, complex64, nullptr, {}); + array w_q_broadcast({}, complex64, nullptr, {}); + w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size()); + w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size()); + + auto temp_shape = inverse ? out.shape() : in.shape(); + array temp(temp_shape, complex64, nullptr, {}); + array temp1(temp_shape, complex64, nullptr, {}); + + if (real && !inverse) { + // Convert float32->complex64 + copy_gpu(in, temp, CopyType::General, s); + } else if (real && inverse) { + int back_offset = n % 2 == 0 ? 
2 : 1; + auto slice_shape = in.shape(); + slice_shape[axis] -= back_offset; + array slice_temp(slice_shape, complex64, nullptr, {}); + array conj_temp(in.shape(), complex64, nullptr, {}); + copies.push_back(slice_temp); + copies.push_back(conj_temp); + + std::vector rstarts(in.ndim(), 0); + std::vector rstrides(in.ndim(), 1); + rstarts[axis] = in.shape(axis) - back_offset; + rstrides[axis] = -1; + unary_op_gpu({in}, conj_temp, "conj", s); + slice_gpu(in, slice_temp, rstarts, rstrides, s); + concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s); + } else if (inverse) { + unary_op_gpu({in}, temp, "conj", s); + } else { + temp.copy_shared_buffer(in); + } + + binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s); + + std::vector> pads; + auto padded_shape = out.shape(); + padded_shape[axis] = plan.bluestein_n; + array pad_temp(padded_shape, complex64, nullptr, {}); + pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s); + + array pad_temp1(padded_shape, complex64, nullptr, {}); + fft_op( + pad_temp, + pad_temp1, + axis, + /*inverse=*/false, + /*real=*/false, + FourStepParams(), + /*inplace=*/false, + s); + + binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s); + + fft_op( + pad_temp, + pad_temp1, + axis, + /* inverse= */ true, + /* real= */ false, + FourStepParams(), + /*inplace=*/true, + s); + + int offset = plan.bluestein_n - (2 * n - 1); + std::vector starts(in.ndim(), 0); + std::vector strides(in.ndim(), 1); + starts[axis] = plan.bluestein_n - offset - n; + slice_gpu(pad_temp1, temp, starts, strides, s); + + binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s); + + if (real && !inverse) { + std::vector rstarts(in.ndim(), 0); + std::vector rstrides(in.ndim(), 1); + slice_gpu(temp1, out, rstarts, strides, s); + } else if (real && inverse) { + std::vector b_strides(in.ndim(), 0); + auto inv_n = array({1.0f / n}, {1}, float32); + array temp_float(out.shape(), out.dtype(), nullptr, {}); + copies.push_back(temp_float); + copies.push_back(inv_n); + + copy_gpu(temp1, temp_float, CopyType::General, s); + binary_op_gpu({temp_float, inv_n}, out, "mul", s); + } else if (inverse) { + auto inv_n = array({1.0f / n}, {1}, complex64); + unary_op_gpu({temp1}, temp, "conj", s); + binary_op_gpu({temp, inv_n}, out, "mul", s); + copies.push_back(inv_n); + } else { + out.copy_shared_buffer(temp1); + } + + copies.push_back(w_k); + copies.push_back(w_q); + copies.push_back(w_k_broadcast); + copies.push_back(w_q_broadcast); + copies.push_back(temp); + copies.push_back(temp1); + copies.push_back(pad_temp); + copies.push_back(pad_temp1); +} + +void four_step_fft( + const array& in, + array& out, + size_t axis, + bool inverse, + bool real, + FFTPlan& plan, + std::vector copies, + const Stream& s) { + auto& d = metal::device(s.device); + + if (plan.bluestein_n == -1) { + // Fast no transpose implementation for powers of 2. + FourStepParams four_step_params = { + /* required= */ true, /* first_step= */ true, plan.n1, plan.n2}; + auto temp_shape = (real && inverse) ? 
out.shape() : in.shape(); + array temp(temp_shape, complex64, nullptr, {}); + fft_op( + in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); + four_step_params.first_step = false; + fft_op( + temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s); + copies.push_back(temp); + } else { + multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s); + } +} + +void fft_op( + const array& in, + array& out, + size_t axis, + bool inverse, + bool real, + const FourStepParams four_step_params, + bool inplace, + const Stream& s) { + auto& d = metal::device(s.device); + + size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis); + if (n == 1) { + out.copy_shared_buffer(in); + return; + } + + if (four_step_params.required) { + // Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows + n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2; } // Make sure that the array is contiguous and has stride 1 in the FFT dim std::vector copies; - auto check_input = [this, &copies, &s](const array& x) { + auto check_input = [&axis, &copies, &s](const array& x) { // TODO: Pass the strides to the kernel so // we can avoid the copy when x is not contiguous. - bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous || - x.flags().col_contiguous; + bool no_copy = x.strides()[axis] == 1 && + (x.flags().row_contiguous || x.flags().col_contiguous); if (no_copy) { return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); std::vector strides; - size_t cur_stride = x.shape(axes_[0]); - for (int axis = 0; axis < x.ndim(); axis++) { - if (axis == axes_[0]) { + size_t cur_stride = x.shape(axis); + for (int a = 0; a < x.ndim(); a++) { + if (a == axis) { strides.push_back(1); } else { strides.push_back(cur_stride); - cur_stride *= x.shape(axis); + cur_stride *= x.shape(a); } } auto flags = x.flags(); - size_t f_stride = 1; - size_t b_stride = 1; - flags.col_contiguous = true; - flags.row_contiguous = true; - for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) { - flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1); - f_stride *= x.shape(i); - flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1); - b_stride *= x.shape(ri); - } - // This is probably over-conservative - flags.contiguous = false; + auto [data_size, is_row_contiguous, is_col_contiguous] = + check_contiguity(x.shape(), strides); + + flags.col_contiguous = is_row_contiguous; + flags.row_contiguous = is_col_contiguous; + flags.contiguous = data_size == x_copy.size(); x_copy.set_data( - allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags); + allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags); copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s); copies.push_back(x_copy); return x_copy; } }; - const array& in_contiguous = check_input(inputs[0]); + const array& in_contiguous = check_input(in); + + // real to complex: n -> (n/2)+1 + // complex to real: (n/2)+1 -> n + auto out_strides = in_contiguous.strides(); + size_t out_data_size = in_contiguous.data_size(); + if (in.shape(axis) != out.shape(axis)) { + for (int i = 0; i < out_strides.size(); i++) { + if (out_strides[i] != 1) { + out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis); + } + } + out_data_size = out_data_size / in.shape(axis) * out.shape(axis); + } + + auto plan = plan_fft(n); + if (plan.four_step) { + four_step_fft(in, out, axis, inverse, real, plan, copies, s); + 
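    // Illustration: a power-of-2 n = 131072 exceeds MAX_STOCKHAM_FFT_SIZE, so
    // plan_fft picks n2 = 1024 and n1 = 128 and the call above runs two
    // shared-memory passes; large non-power-of-2 sizes instead go through
    // multi_upload_bluestein_fft.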
d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } // TODO: allow donation here - out.set_data( - allocator::malloc_or_wait(out.nbytes()), - in_contiguous.data_size(), - in_contiguous.strides(), - in_contiguous.flags()); + if (!inplace) { + out.set_data( + allocator::malloc_or_wait(out.nbytes()), + out_data_size, + out_strides, + in_contiguous.flags()); + } - // We use n / 4 threads by default since radix-4 - // is the largest single threaded radix butterfly - // we currently implement. - size_t m = n / 4; - size_t batch = in.size() / in.shape(axes_[0]); + auto radices = supported_radices(); + int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n; + + // Setup function constants + bool power_of_2 = is_power_of_2(fft_size); + + auto make_int = [](int* a, int i) { + return std::make_tuple(a, MTL::DataType::DataTypeInt, i); + }; + auto make_bool = [](bool* a, int i) { + return std::make_tuple(a, MTL::DataType::DataTypeBool, i); + }; + + std::vector func_consts = { + make_bool(&inverse, 0), make_bool(&power_of_2, 1)}; + + // Start of radix/rader step constants + int index = 4; + for (int i = 0; i < plan.stockham.size(); i++) { + func_consts.push_back(make_int(&plan.stockham[i], index)); + index += 1; + } + for (int i = 0; i < plan.rader.size(); i++) { + func_consts.push_back(make_int(&plan.rader[i], index)); + index += 1; + } + int elems_per_thread = compute_elems_per_thread(plan); + func_consts.push_back(make_int(&elems_per_thread, 2)); + + int rader_m = n / plan.rader_n; + func_consts.push_back(make_int(&rader_m, 3)); + + // The overall number of FFTs we're going to compute for this input + int size = out.dtype() == float32 ? out.size() : in.size(); + if (real && inverse && four_step_params.required) { + size = out.size(); + } + int total_batch_size = size / n; + int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread; + + // We batch among threadgroups for improved efficiency when n is small + int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1); + if (four_step_params.required) { + // Require a threadgroup batch size of at least 4 for four step FFT + // so we can coalesce the memory accesses. + threadgroup_batch_size = + std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH); + } + int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size); + // FFTs up to 2^20 are currently supported + assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE); + + // ceil divide + int batch_size = + (total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size; + + if (real && !four_step_params.required) { + // We can perform 2 RFFTs at once so the batch size is halved. + batch_size = (batch_size + 2 - 1) / 2; + } + int out_buffer_size = out.size(); auto& compute_encoder = d.get_command_encoder(s.index); + auto in_type_str = in.dtype() == float32 ? "float" : "float2"; + auto out_type_str = out.dtype() == float32 ? "float" : "float2"; + // Only required by four step + int step = -1; { std::ostringstream kname; - kname << "fft_" << n; - auto kernel = d.get_kernel(kname.str()); + std::string inv_string = inverse ? "true" : "false"; + std::string real_string = real ? 
"true" : "false"; + if (plan.bluestein_n > 0) { + kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_" + << in_type_str << "_" << out_type_str; + } else if (plan.rader_n > 1) { + kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str + << "_" << out_type_str; + } else if (four_step_params.required) { + step = four_step_params.first_step ? 0 : 1; + kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str + << "_" << out_type_str << "_" << step << "_" << real_string; + } else { + kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_" + << out_type_str; + } + std::string base_name = kname.str(); + // We use a specialized kernel for each FFT size + kname << "_n" << fft_size << "_inv_" << inverse; + std::string hash_name = kname.str(); + auto kernel = get_fft_kernel( + d, + base_name, + hash_name, + threadgroup_mem_size, + in_type_str, + out_type_str, + step, + real, + func_consts); - bool donated = in.data_shared_ptr() == nullptr; compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(in_contiguous, 0); compute_encoder.set_output_array(out, 1); - auto group_dims = MTL::Size(1, m, 1); - auto grid_dims = MTL::Size(batch, m, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + if (plan.bluestein_n > 0) { + // Precomputed twiddle factors for Bluestein's + auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); + copies.push_back(w_q); + copies.push_back(w_k); + + compute_encoder.set_input_array(w_q, 2); // w_q + compute_encoder.set_input_array(w_k, 3); // w_k + compute_encoder->setBytes(&n, sizeof(int), 4); + compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5); + compute_encoder->setBytes(&total_batch_size, sizeof(int), 6); + } else if (plan.rader_n > 1) { + auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s); + copies.push_back(b_q); + copies.push_back(g_q); + copies.push_back(g_minus_q); + + compute_encoder.set_input_array(b_q, 2); + compute_encoder.set_input_array(g_q, 3); + compute_encoder.set_input_array(g_minus_q, 4); + compute_encoder->setBytes(&n, sizeof(int), 5); + compute_encoder->setBytes(&total_batch_size, sizeof(int), 6); + compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7); + } else if (four_step_params.required) { + compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2); + compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3); + compute_encoder->setBytes(&total_batch_size, sizeof(int), 4); + } else { + compute_encoder->setBytes(&n, sizeof(int), 2); + compute_encoder->setBytes(&total_batch_size, sizeof(int), 3); + } + + auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); + auto grid_dims = + MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); + compute_encoder->dispatchThreads(grid_dims, group_dims); } d.get_command_buffer(s.index)->addCompletedHandler( [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); } +void fft_op( + const array& in, + array& out, + size_t axis, + bool inverse, + bool real, + bool inplace, + const Stream& s) { + fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s); +} + +void nd_fft_op( + const array& in, + array& out, + const std::vector& axes, + bool inverse, + bool real, + const Stream& s) { + // Perform ND FFT on GPU as a series of 1D FFTs + auto temp_shape = inverse ? 
in.shape() : out.shape(); + array temp1(temp_shape, complex64, nullptr, {}); + array temp2(temp_shape, complex64, nullptr, {}); + std::vector temp_arrs = {temp1, temp2}; + for (int i = axes.size() - 1; i >= 0; i--) { + int reverse_index = axes.size() - i - 1; + // For 5D and above, we don't want to reallocate our two temporary arrays + bool inplace = reverse_index >= 3 && i != 0; + // Opposite order for fft vs ifft + int index = inverse ? reverse_index : i; + size_t axis = axes[index]; + // Mirror np.fft.(i)rfftn and perform a real transform + // only on the final axis. + bool step_real = (real && index == axes.size() - 1); + int step_shape = inverse ? out.shape(axis) : in.shape(axis); + const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; + array& out_arr = i == 0 ? out : temp_arrs[i % 2]; + fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); + } + + std::vector copies = {temp1, temp2}; + auto& d = metal::device(s.device); + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); +} + +void FFT::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& in = inputs[0]; + + if (axes_.size() > 1) { + nd_fft_op(in, out, axes_, inverse_, real_, s); + } else { + fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s); + } +} + } // namespace mlx::core diff --git a/mlx/backend/metal/jit/fft.h b/mlx/backend/metal/jit/fft.h new file mode 100644 index 000000000..24f908db9 --- /dev/null +++ b/mlx/backend/metal/jit/fft.h @@ -0,0 +1,53 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view fft_kernel = R"( +template [[host_name("{name}")]] [[kernel]] void +fft<{tg_mem_size}, {in_T}, {out_T}>( + const device {in_T}* in [[buffer(0)]], + device {out_T}* out [[buffer(1)]], + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]); +)"; + +constexpr std::string_view rader_fft_kernel = R"( +template [[host_name("{name}")]] [[kernel]] void +rader_fft<{tg_mem_size}, {in_T}, {out_T}>( + const device {in_T}* in [[buffer(0)]], + device {out_T}* out [[buffer(1)]], + const device float2* raders_b_q [[buffer(2)]], + const device short* raders_g_q [[buffer(3)]], + const device short* raders_g_minus_q [[buffer(4)]], + constant const int& n, + constant const int& batch_size, + constant const int& rader_n, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]); +)"; + +constexpr std::string_view bluestein_fft_kernel = R"( +template [[host_name("{name}")]] [[kernel]] void +bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>( + const device {in_T}* in [[buffer(0)]], + device {out_T}* out [[buffer(1)]], + const device float2* w_q [[buffer(2)]], + const device float2* w_k [[buffer(3)]], + constant const int& length, + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]); +)"; + +constexpr std::string_view four_step_fft_kernel = R"( +template [[host_name("{name}")]] [[kernel]] void +four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>( + const device {in_T}* in [[buffer(0)]], + device {out_T}* out [[buffer(1)]], + constant const int& n1, + constant const int& n2, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]); +)"; \ No newline at end of file diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 4bb4ac38d..f6b668512 
100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -17,6 +17,7 @@ const char* unary(); const char* binary(); const char* binary_two(); const char* copy(); +const char* fft(); const char* ternary(); const char* scan(); const char* softmax(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5c84105c9..1175ddbee 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/metal/jit/binary.h" #include "mlx/backend/metal/jit/binary_two.h" #include "mlx/backend/metal/jit/copy.h" +#include "mlx/backend/metal/jit/fft.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/reduce.h" #include "mlx/backend/metal/jit/scan.h" @@ -489,4 +490,51 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_fft_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const int tg_mem_size, + const std::string& in_type, + const std::string& out_type, + int step, + bool real, + const metal::MTLFCList& func_consts) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + std::string kernel_string; + if (lib_name.find("bluestein") != std::string::npos) { + kernel_string = bluestein_fft_kernel; + } else if (lib_name.find("rader") != std::string::npos) { + kernel_string = rader_fft_kernel; + } else if (lib_name.find("four_step") != std::string::npos) { + kernel_string = four_step_fft_kernel; + } else { + kernel_string = fft_kernel; + } + kernel_source << metal::fft(); + if (lib_name.find("four_step") != std::string::npos) { + kernel_source << fmt::format( + kernel_string, + "name"_a = lib_name, + "tg_mem_size"_a = tg_mem_size, + "in_T"_a = in_type, + "out_T"_a = out_type, + "step"_a = step, + "real"_a = real); + } else { + kernel_source << fmt::format( + kernel_string, + "name"_a = lib_name, + "tg_mem_size"_a = tg_mem_size, + "in_T"_a = in_type, + "out_T"_a = out_type); + } + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 94bf609f3..ce99464ef 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -155,4 +155,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( int wm, int wn); +MTL::ComputePipelineState* get_fft_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const int tg_mem_size, + const std::string& in_type, + const std::string& out_type, + int step, + bool real, + const metal::MTLFCList& func_consts); + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 9d99d39a5..f98430eb8 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -48,6 +48,9 @@ set( binary.h ternary.h copy.h + fft.h + fft/radix.h + fft/readwrite.h softmax.h sort.h scan.h diff --git a/mlx/backend/metal/kernels/fft.h b/mlx/backend/metal/kernels/fft.h new file mode 100644 index 000000000..a4869a2ac --- /dev/null +++ b/mlx/backend/metal/kernels/fft.h @@ -0,0 +1,486 @@ +// Copyright © 2024 Apple Inc. 
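// Note (illustrative, based on the host code in fft.cpp and jit_kernels.cpp):
// the kernels in this header are looked up by a base name such as
// "fft_mem_2048_float2_float2" and specialized per FFT size through function
// constants under a hash name like "fft_mem_2048_float2_float2_n2048_inv_0".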
+ +// Metal FFT using Stockham's algorithm +// +// References: +// - VkFFT (https://github.com/DTolm/VkFFT) +// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) + +#include + +#include "mlx/backend/metal/kernels/fft/radix.h" +#include "mlx/backend/metal/kernels/fft/readwrite.h" +#include "mlx/backend/metal/kernels/steel/defines.h" + +using namespace metal; + +#define MAX_RADIX 13 +// Reached when elems_per_thread_ = 6, max_radix = 13 +// and some threads have to do 3 radix 6s requiring 18 float2s. +#define MAX_OUTPUT_SIZE 18 + +// Specialize for a particular value of N at runtime +STEEL_CONST bool inv_ [[function_constant(0)]]; +STEEL_CONST bool is_power_of_2_ [[function_constant(1)]]; +STEEL_CONST int elems_per_thread_ [[function_constant(2)]]; +// rader_m = n / rader_n +STEEL_CONST int rader_m_ [[function_constant(3)]]; +// Stockham steps +STEEL_CONST int radix_13_steps_ [[function_constant(4)]]; +STEEL_CONST int radix_11_steps_ [[function_constant(5)]]; +STEEL_CONST int radix_8_steps_ [[function_constant(6)]]; +STEEL_CONST int radix_7_steps_ [[function_constant(7)]]; +STEEL_CONST int radix_6_steps_ [[function_constant(8)]]; +STEEL_CONST int radix_5_steps_ [[function_constant(9)]]; +STEEL_CONST int radix_4_steps_ [[function_constant(10)]]; +STEEL_CONST int radix_3_steps_ [[function_constant(11)]]; +STEEL_CONST int radix_2_steps_ [[function_constant(12)]]; +// Rader steps +STEEL_CONST int rader_13_steps_ [[function_constant(13)]]; +STEEL_CONST int rader_11_steps_ [[function_constant(14)]]; +STEEL_CONST int rader_8_steps_ [[function_constant(15)]]; +STEEL_CONST int rader_7_steps_ [[function_constant(16)]]; +STEEL_CONST int rader_6_steps_ [[function_constant(17)]]; +STEEL_CONST int rader_5_steps_ [[function_constant(18)]]; +STEEL_CONST int rader_4_steps_ [[function_constant(19)]]; +STEEL_CONST int rader_3_steps_ [[function_constant(20)]]; +STEEL_CONST int rader_2_steps_ [[function_constant(21)]]; + +// See "radix.h" for radix codelets +typedef void (*RadixFunc)(thread float2*, thread float2*); + +// Perform a single radix n butterfly with appropriate twiddles +template +METAL_FUNC void radix_butterfly( + int i, + int p, + thread float2* x, + thread short* indices, + thread float2* y) { + // i: the index in the overall DFT that we're processing. + // p: the size of the DFTs we're merging at this step. + // m: how many threads are working on this DFT. + int k, j; + + // Use faster bitwise operations when working with powers of two + constexpr bool radix_p_2 = (radix & (radix - 1)) == 0; + if (radix_p_2 && is_power_of_2_) { + constexpr short power = __builtin_ctz(radix); + k = i & (p - 1); + j = ((i - k) << power) + k; + } else { + k = i % p; + j = (i / p) * radix * p + k; + } + + // Apply twiddles + if (p > 1) { + float2 twiddle_1 = get_twiddle(k, radix * p); + float2 twiddle = twiddle_1; + x[1] = complex_mul(x[1], twiddle); + + STEEL_PRAGMA_UNROLL + for (int t = 2; t < radix; t++) { + twiddle = complex_mul(twiddle, twiddle_1); + x[t] = complex_mul(x[t], twiddle); + } + } + + radix_func(x, y); + + STEEL_PRAGMA_UNROLL + for (int t = 0; t < radix; t++) { + indices[t] = j + t * p; + } +} + +// Perform all the radix steps required for a +// particular radix size n. 
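// Worked index example for radix_butterfly above (illustration only): with
// radix = 4 and p = 2, butterfly index i = 5 gives k = 5 % 2 = 1 and
// j = (5 / 2) * 4 * 2 + 1 = 17, so its outputs land at 17, 19, 21, 23;
// the power-of-2 path computes the same j as ((i - k) << 2) + k.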
+template +METAL_FUNC void radix_n_steps( + int i, + thread int* p, + int m, + int n, + int num_steps, + thread float2* inputs, + thread short* indices, + thread float2* values, + threadgroup float2* buf) { + int m_r = n / radix; + // When combining different sized radices, we have to do + // multiple butterflies in a single thread. + // E.g. n = 28 = 4 * 7 + // 4 threads, 7 elems_per_thread + // All threads do 1 radix7 butterfly. + // 3 threads do 2 radix4 butterflies. + // 1 thread does 1 radix4 butterfly. + int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix; + + int index = 0; + int r_index = 0; + for (int s = 0; s < num_steps; s++) { + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + inputs[r] = buf[index + r * m_r]; + } + radix_butterfly( + index, *p, inputs, indices + t * radix, values + t * radix); + } + } + + // Wait until all threads have read their inputs into thread local mem + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + r_index = t * radix + r; + buf[indices[r_index]] = values[r_index]; + } + } + } + + // Wait until all threads have written back to threadgroup mem + threadgroup_barrier(mem_flags::mem_threadgroup); + *p *= radix; + } +} + +#define RADIX_STEP(radix, radix_func, num_steps) \ + radix_n_steps( \ + fft_idx, p, m, n, num_steps, inputs, indices, values, buf); + +template +METAL_FUNC void +perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) { + float2 inputs[MAX_RADIX]; + short indices[MAX_OUTPUT_SIZE]; + float2 values[MAX_OUTPUT_SIZE]; + + RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_); + RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_); + RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_); + RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_); + RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_); + RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_); + RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_); + RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_); + RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_); +} + +// Each FFT is computed entirely in shared GPU memory. +// +// N is decomposed into radix-n DFTs: +// e.g. 
128 = 2 * 4 * 4 * 4 +template +[[kernel]] void fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void rader_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* raders_b_q [[buffer(2)]], + const device short* raders_g_q [[buffer(3)]], + const device short* raders_g_minus_q [[buffer(4)]], + constant const int& n, + constant const int& batch_size, + constant const int& rader_n, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Use Rader's algorithm to compute fast FFTs + // when a prime factor `p` of `n` is greater than 13 but + // has `p - 1` Stockham decomposable into to prime factors <= 13. + // + // E.g. n = 102 + // = 2 * 3 * 17 + // . = 2 * 3 * RADER(16) + // . = 2 * 3 * RADER(4 * 4) + // + // In numpy: + // x_perm = x[g_q] + // y = np.fft.fft(x_perm) * b_q + // z = np.fft.ifft(y) + x[0] + // out = z[g_minus_q] + // out[0] = x[1:].sum() + // + // Where the g_q and g_minus_q are permutations formed + // by the group under multiplicative modulo N using the + // primitive root of N and b_q is a constant. + // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm + // + // Rader's uses fewer operations than Bluestein's and so + // is more accurate. It's also faster in most cases. + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The number of the threads we're using for each DFT + int m = grid.z; + + int fft_idx = elem.z; + int tg_idx = elem.y * n; + threadgroup float2* buf = &shared_in[tg_idx]; + + // rader_m = n / rader_n; + int rader_m = rader_m_; + + // We have to load two x_0s for each thread since sometimes + // elems_per_thread_ crosses a boundary. + // E.g. 
with n = 34, rader_n = 17, elems_per_thread_ = 4 + // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + short x_0_index = + metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1); + float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]}; + + // Do the Rader permutation in shared memory + float2 temp[MAX_RADIX]; + int max_index = n - rader_m - 1; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q = raders_g_q[index / rader_m]; + temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[index + rader_m] = temp[e]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader FFT on x[rader_m:] + int p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + // x_1 + ... + x_n is computed for us in the first FFT step so + // we save it in the first rader_m indices of the array for later. + int x_sum_index = metal::min(fft_idx, rader_m - 1); + buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)]; + + float2 inv = {1.0f, -1.0f}; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short interleaved_index = + index / rader_m + (index % rader_m) * (rader_n - 1); + temp[e] = complex_mul( + buf[rader_m + interleaved_index], + raders_b_q[interleaved_index % (rader_n - 1)]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[rader_m + index] = temp[e] * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader IFFT on x[rader_m:] + p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)}; + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1); + short diff_index = index / (rader_n - 1) - x_0_index; + temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index]; + } + + // Use the sum of elements that was computed in the first FFT + float2 x_sum = buf[x_0_index] + x_0[0]; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q_index = index % (rader_n - 1); + short g_q = raders_g_minus_q[g_q_index]; + short out_index = index - g_q_index + g_q + (index / (rader_n - 1)); + buf[out_index] = temp[e]; + } + + buf[x_0_index * rader_n] = x_sum; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + p = rader_n; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void bluestein_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* w_q [[buffer(2)]], + const device float2* w_k [[buffer(3)]], + constant const int& length, + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Computes arbitrary length FFTs with Bluestein's algorithm + // + // In numpy: + // bluestein_n = next_power_of_2(2*n - 1) + // out = w_k * 
np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q) + // + // Where w_k and w_q are precomputed on CPU in high precision as: + // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2)) + // w_q = np.fft.fft(1/w_k[-n:]) + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_padded(length, w_k); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + // fft + perform_fft(fft_idx, &p, m, n, buf); + + float2 inv = float2(1.0f, -1.0f); + for (int t = 0; t < elems_per_thread_; t++) { + int index = fft_idx + t * m; + buf[index] = complex_mul(buf[index], w_q[index]) * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ifft + p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_padded(length, w_k); +} + +template < + int tg_mem_size, + typename in_T, + typename out_T, + int step, + bool real = false> +[[kernel]] void four_step_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n1, + constant const int& n2, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Fast four step FFT implementation for powers of 2. + int overall_n = n1 * n2; + int n = step == 0 ? n1 : n2; + int stride = step == 0 ? n2 : n1; + + // The number of the threads we're using for each DFT + int m = grid.z; + int fft_idx = elem.z; + + threadgroup float2 shared_in[tg_mem_size]; + threadgroup float2* buf = &shared_in[elem.y * n]; + + using read_writer_t = ReadWriter; + read_writer_t read_writer = read_writer_t( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_strided(stride, overall_n); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_strided(stride, overall_n); +} \ No newline at end of file diff --git a/mlx/backend/metal/kernels/fft.metal b/mlx/backend/metal/kernels/fft.metal index 66dc0d22b..05828f34c 100644 --- a/mlx/backend/metal/kernels/fft.metal +++ b/mlx/backend/metal/kernels/fft.metal @@ -1,199 +1,84 @@ // Copyright © 2024 Apple Inc. 
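// Note (illustrative): this file instantiates the template kernels from fft.h
// for static threadgroup memory sizes 256 through 4096; a non-power-of-2 FFT
// such as n = 1000 rounds its shared memory up to the next power of 2 and so
// dispatches the 1024 instantiation.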
-// Metal FFT using Stockham's algorithm -// -// References: -// - VkFFT (https://github.com/DTolm/VkFFT) -// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) +#include "mlx/backend/metal/kernels/fft.h" -#include -#include +#define instantiate_fft(tg_mem_size, in_T, out_T) \ + template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \ + "_" #out_T)]] [[kernel]] void \ + fft( \ + const device in_T* in [[buffer(0)]], \ + device out_T* out [[buffer(1)]], \ + constant const int& n, \ + constant const int& batch_size, \ + uint3 elem [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); -#include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/utils.h" +#define instantiate_rader(tg_mem_size, in_T, out_T) \ + template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \ + "_" #out_T)]] [[kernel]] void \ + rader_fft( \ + const device in_T* in [[buffer(0)]], \ + device out_T* out [[buffer(1)]], \ + const device float2* raders_b_q [[buffer(2)]], \ + const device short* raders_g_q [[buffer(3)]], \ + const device short* raders_g_minus_q [[buffer(4)]], \ + constant const int& n, \ + constant const int& batch_size, \ + constant const int& rader_n, \ + uint3 elem [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); -using namespace metal; +#define instantiate_bluestein(tg_mem_size, in_T, out_T) \ + template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \ + "_" #out_T)]] [[kernel]] void \ + bluestein_fft( \ + const device in_T* in [[buffer(0)]], \ + device out_T* out [[buffer(1)]], \ + const device float2* w_q [[buffer(2)]], \ + const device float2* w_k [[buffer(3)]], \ + constant const int& length, \ + constant const int& n, \ + constant const int& batch_size, \ + uint3 elem [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); -float2 complex_mul(float2 a, float2 b) { - float2 c; - c.x = a.x * b.x - a.y * b.y; - c.y = a.x * b.y + a.y * b.x; - return c; -} +#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \ + template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \ + "_" #step "_" #real)]] [[kernel]] void \ + four_step_fft( \ + const device in_T* in [[buffer(0)]], \ + device out_T* out [[buffer(1)]], \ + constant const int& n1, \ + constant const int& n2, \ + constant const int& batch_size, \ + uint3 elem [[thread_position_in_grid]], \ + uint3 grid [[threads_per_grid]]); -float2 get_twiddle(int k, int p) { - float theta = -1.0f * k * M_PI_F / (2 * p); - - float2 twiddle; - twiddle.x = metal::fast::cos(theta); - twiddle.y = metal::fast::sin(theta); - return twiddle; -} - -// single threaded radix2 implemetation -void radix2( - int i, - int p, - int m, - threadgroup float2* read_buf, - threadgroup float2* write_buf) { - float2 x_0 = read_buf[i]; - float2 x_1 = read_buf[i + m]; - - // The index within this sub-DFT - int k = i & (p - 1); - - float2 twiddle = get_twiddle(k, p); - - float2 z = complex_mul(x_1, twiddle); - - float2 y_0 = x_0 + z; - float2 y_1 = x_0 - z; - - int j = (i << 1) - k; - - write_buf[j] = y_0; - write_buf[j + p] = y_1; -} - -// single threaded radix4 implemetation -void radix4( - int i, - int p, - int m, - threadgroup float2* read_buf, - threadgroup float2* write_buf) { - float2 x_0 = read_buf[i]; - float2 x_1 = read_buf[i + m]; - float2 x_2 = read_buf[i + 2 * m]; - float2 x_3 = read_buf[i + 3 * m]; - - // The index within this sub-DFT - int k = i & (p - 1); - - float2 twiddle = get_twiddle(k, p); - // e^a * e^b = e^(a + b) - float2 twiddle_2 = 
complex_mul(twiddle, twiddle); - float2 twiddle_3 = complex_mul(twiddle, twiddle_2); - - x_1 = complex_mul(x_1, twiddle); - x_2 = complex_mul(x_2, twiddle_2); - x_3 = complex_mul(x_3, twiddle_3); - - float2 minus_i; - minus_i.x = 0; - minus_i.y = -1; - - // Hard coded twiddle factors for DFT4 - float2 z_0 = x_0 + x_2; - float2 z_1 = x_0 - x_2; - float2 z_2 = x_1 + x_3; - float2 z_3 = complex_mul(x_1 - x_3, minus_i); - - float2 y_0 = z_0 + z_2; - float2 y_1 = z_1 + z_3; - float2 y_2 = z_0 - z_2; - float2 y_3 = z_1 - z_3; - - int j = ((i - k) << 2) + k; - - write_buf[j] = y_0; - write_buf[j + p] = y_1; - write_buf[j + 2 * p] = y_2; - write_buf[j + 3 * p] = y_3; -} - -// Each FFT is computed entirely in shared GPU memory. -// -// N is decomposed into radix-2 and radix-4 DFTs: -// e.g. 128 = 2 * 4 * 4 * 4 -// -// At each step we use n / 4 threads, each performing -// a single-threaded radix-4 or radix-2 DFT. -// -// We provide the number of radix-2 and radix-4 -// steps at compile time for a ~20% performance boost. -template -[[kernel]] void fft( - const device float2* in [[buffer(0)]], - device float2* out [[buffer(1)]], - uint3 thread_position_in_grid [[thread_position_in_grid]], - uint3 threads_per_grid [[threads_per_grid]]) { - // Index of the DFT in batch - int batch_idx = thread_position_in_grid.x * n; - // The index in the DFT we're working on - int i = thread_position_in_grid.y; - // The number of the threads we're using for each DFT - int m = threads_per_grid.y; - - // Allocate 2 shared memory buffers for Stockham. - // We alternate reading from one and writing to the other at each radix step. - threadgroup float2 shared_in[n]; - threadgroup float2 shared_out[n]; - - // Pointers to facilitate Stockham buffer swapping - threadgroup float2* read_buf = shared_in; - threadgroup float2* write_buf = shared_out; - threadgroup float2* tmp; - - // Copy input into shared memory - shared_in[i] = in[batch_idx + i]; - shared_in[i + m] = in[batch_idx + i + m]; - shared_in[i + 2 * m] = in[batch_idx + i + 2 * m]; - shared_in[i + 3 * m] = in[batch_idx + i + 3 * m]; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - int p = 1; - - for (size_t r = 0; r < radix_2_steps; r++) { - radix2(i, p, m * 2, read_buf, write_buf); - radix2(i + m, p, m * 2, read_buf, write_buf); - p *= 2; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Stockham switch of buffers - tmp = write_buf; - write_buf = read_buf; - read_buf = tmp; - } - - for (size_t r = 0; r < radix_4_steps; r++) { - radix4(i, p, m, read_buf, write_buf); - p *= 4; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Stockham switch of buffers - tmp = write_buf; - write_buf = read_buf; - read_buf = tmp; - } - - // Copy shared memory to output - out[batch_idx + i] = read_buf[i]; - out[batch_idx + i + m] = read_buf[i + m]; - out[batch_idx + i + 2 * m] = read_buf[i + 2 * m]; - out[batch_idx + i + 3 * m] = read_buf[i + 3 * m]; -} - -#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \ - template [[host_name("fft_" #name)]] [[kernel]] void \ - fft( \ - const device float2* in [[buffer(0)]], \ - device float2* out [[buffer(1)]], \ - uint3 thread_position_in_grid [[thread_position_in_grid]], \ - uint3 threads_per_grid [[threads_per_grid]]); - -// Explicitly define kernels for each power of 2. 
// clang-format off -instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1) -instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2) -instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3) -instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4) -instantiate_fft(512, 512, 1, 4) -instantiate_fft(1024, 1024, 0, 5) -// 2048 is the max that will fit into 32KB of threadgroup memory. -// TODO: implement 4 step FFT for larger n. -instantiate_fft(2048, 2048, 1, 5) // clang-format on +#define instantiate_ffts(tg_mem_size) \ + instantiate_fft(tg_mem_size, float2, float2) \ + instantiate_fft(tg_mem_size, float, float2) \ + instantiate_fft(tg_mem_size, float2, float) \ + instantiate_rader(tg_mem_size, float2, float2) \ + instantiate_rader(tg_mem_size, float, float2) \ + instantiate_rader(tg_mem_size, float2, float) \ + instantiate_bluestein(tg_mem_size, float2, float2) \ + instantiate_bluestein(tg_mem_size, float, float2) \ + instantiate_bluestein(tg_mem_size, float2, float) \ + instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \ + instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \ + instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \ + instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \ + instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \ + instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true) + +// It's substantially faster to statically define the +// threadgroup memory size rather than using +// `setThreadgroupMemoryLength` on the compute encoder. +// For non-power of 2 sizes we round up the shared memory. +instantiate_ffts(256) +instantiate_ffts(512) +instantiate_ffts(1024) +instantiate_ffts(2048) +// 4096 is the max that will fit into 32KB of threadgroup memory. +instantiate_ffts(4096) // clang-format on diff --git a/mlx/backend/metal/kernels/fft/radix.h b/mlx/backend/metal/kernels/fft/radix.h new file mode 100644 index 000000000..bd61eef6d --- /dev/null +++ b/mlx/backend/metal/kernels/fft/radix.h @@ -0,0 +1,328 @@ +// Copyright © 2024 Apple Inc. + +/* Radix kernels + +We provide optimized, single threaded Radix codelets +for n=2,3,4,5,6,7,8,10,11,12,13. + +For n=2,3,4,5,6 we hand write the codelets. +For n=8,10,12 we combine smaller codelets. +For n=7,11,13 we use Rader's algorithm which decomposes +them into (n-1)=6,10,12 codelets. 
*/ + +#pragma once + +#include +#include +#include + +METAL_FUNC float2 complex_mul(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +// Complex mul followed by conjugate +METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x); +} + +// Compute an FFT twiddle factor +METAL_FUNC float2 get_twiddle(int k, int p) { + float theta = -2.0f * k * M_PI_F / p; + + float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)}; + return twiddle; +} + +METAL_FUNC void radix2(thread float2* x, thread float2* y) { + y[0] = x[0] + x[1]; + y[1] = x[0] - x[1]; +} + +METAL_FUNC void radix3(thread float2* x, thread float2* y) { + float pi_2_3 = -0.8660254037844387; + + float2 a_1 = x[1] + x[2]; + float2 a_2 = x[1] - x[2]; + + y[0] = x[0] + a_1; + float2 b_1 = x[0] - 0.5 * a_1; + float2 b_2 = pi_2_3 * a_2; + + float2 b_2_j = {-b_2.y, b_2.x}; + y[1] = b_1 + b_2_j; + y[2] = b_1 - b_2_j; +} + +METAL_FUNC void radix4(thread float2* x, thread float2* y) { + float2 z_0 = x[0] + x[2]; + float2 z_1 = x[0] - x[2]; + float2 z_2 = x[1] + x[3]; + float2 z_3 = x[1] - x[3]; + float2 z_3_i = {z_3.y, -z_3.x}; + + y[0] = z_0 + z_2; + y[1] = z_1 + z_3_i; + y[2] = z_0 - z_2; + y[3] = z_1 - z_3_i; +} + +METAL_FUNC void radix5(thread float2* x, thread float2* y) { + float2 root_5_4 = 0.5590169943749475; + float2 sin_2pi_5 = 0.9510565162951535; + float2 sin_1pi_5 = 0.5877852522924731; + + float2 a_1 = x[1] + x[4]; + float2 a_2 = x[2] + x[3]; + float2 a_3 = x[1] - x[4]; + float2 a_4 = x[2] - x[3]; + + float2 a_5 = a_1 + a_2; + float2 a_6 = root_5_4 * (a_1 - a_2); + float2 a_7 = x[0] - a_5 / 4; + float2 a_8 = a_7 + a_6; + float2 a_9 = a_7 - a_6; + float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4; + float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4; + float2 a_10_j = {a_10.y, -a_10.x}; + float2 a_11_j = {a_11.y, -a_11.x}; + + y[0] = x[0] + a_5; + y[1] = a_8 + a_10_j; + y[2] = a_9 + a_11_j; + y[3] = a_9 - a_11_j; + y[4] = a_8 - a_10_j; +} + +METAL_FUNC void radix6(thread float2* x, thread float2* y) { + float sin_pi_3 = 0.8660254037844387; + float2 a_1 = x[2] + x[4]; + float2 a_2 = x[0] - a_1 / 2; + float2 a_3 = sin_pi_3 * (x[2] - x[4]); + float2 a_4 = x[5] + x[1]; + float2 a_5 = x[3] - a_4 / 2; + float2 a_6 = sin_pi_3 * (x[5] - x[1]); + float2 a_7 = x[0] + a_1; + + float2 a_3_i = {a_3.y, -a_3.x}; + float2 a_6_i = {a_6.y, -a_6.x}; + float2 a_8 = a_2 + a_3_i; + float2 a_9 = a_2 - a_3_i; + float2 a_10 = x[3] + a_4; + float2 a_11 = a_5 + a_6_i; + float2 a_12 = a_5 - a_6_i; + + y[0] = a_7 + a_10; + y[1] = a_8 - a_11; + y[2] = a_9 + a_12; + y[3] = a_7 - a_10; + y[4] = a_8 + a_11; + y[5] = a_9 - a_12; +} + +METAL_FUNC void radix7(thread float2* x, thread float2* y) { + // Rader's algorithm + float2 inv = {1 / 6.0, -1 / 6.0}; + + // fft + float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]}; + radix6(in1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879)); + y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629)); + y[4] = complex_mul_conj(y[4], float2(0, -2.64575131)); + y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629)); + y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879)); + + // ifft + radix6(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[5] = x[2] * inv + x[0]; + y[4] = x[3] * inv + x[0]; + y[6] = x[4] * inv + x[0]; + y[2] = x[5] * inv + x[0]; + y[3] = x[6] * inv + x[0]; +} + +METAL_FUNC void 
radix8(thread float2* x, thread float2* y) { + float cos_pi_4 = 0.7071067811865476; + float2 w_0 = {cos_pi_4, -cos_pi_4}; + float2 w_1 = {-cos_pi_4, -cos_pi_4}; + float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]}; + radix4(temp, x); + radix4(temp + 4, x + 4); + + y[0] = x[0] + x[4]; + y[4] = x[0] - x[4]; + float2 x_5 = complex_mul(x[5], w_0); + y[1] = x[1] + x_5; + y[5] = x[1] - x_5; + float2 x_6 = {x[6].y, -x[6].x}; + y[2] = x[2] + x_6; + y[6] = x[2] - x_6; + float2 x_7 = complex_mul(x[7], w_1); + y[3] = x[3] + x_7; + y[7] = x[3] - x_7; +} + +template +METAL_FUNC void radix10(thread float2* x, thread float2* y) { + float2 w[4]; + w[0] = {0.8090169943749475, -0.5877852522924731}; + w[1] = {0.30901699437494745, -0.9510565162951535}; + w[2] = {-w[1].x, w[1].y}; + w[3] = {-w[0].x, w[0].y}; + + if (raders_perm) { + float2 temp[10] = { + x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } else { + float2 temp[10] = { + x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } + + y[0] = x[0] + x[5]; + y[5] = x[0] - x[5]; + for (int t = 1; t < 5; t++) { + float2 a = complex_mul(x[t + 5], w[t - 1]); + y[t] = x[t] + a; + y[t + 5] = x[t] - a; + } +} + +METAL_FUNC void radix11(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 10.0, -1 / 10.0}; + + // fft + radix10(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649)); + y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656)); + y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479)); + y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150)); + y[6] = complex_mul_conj(y[6], float2(0, -3.31662479)); + y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150)); + y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479)); + y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656)); + y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649)); + + // ifft + radix10(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[6] = x[2] * inv + x[0]; + y[3] = x[3] * inv + x[0]; + y[7] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[10] = x[6] * inv + x[0]; + y[5] = x[7] * inv + x[0]; + y[8] = x[8] * inv + x[0]; + y[4] = x[9] * inv + x[0]; + y[2] = x[10] * inv + x[0]; +} + +template +METAL_FUNC void radix12(thread float2* x, thread float2* y) { + float2 w[6]; + float sin_pi_3 = 0.8660254037844387; + w[0] = {sin_pi_3, -0.5}; + w[1] = {0.5, -sin_pi_3}; + w[2] = {0, -1}; + w[3] = {-0.5, -sin_pi_3}; + w[4] = {-sin_pi_3, -0.5}; + + if (raders_perm) { + float2 temp[12] = { + x[0], + x[3], + x[2], + x[11], + x[8], + x[9], + x[1], + x[7], + x[5], + x[10], + x[4], + x[6]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } else { + float2 temp[12] = { + x[0], + x[2], + x[4], + x[6], + x[8], + x[10], + x[1], + x[3], + x[5], + x[7], + x[9], + x[11]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } + + y[0] = x[0] + x[6]; + y[6] = x[0] - x[6]; + for (int t = 1; t < 6; t++) { + float2 a = complex_mul(x[t + 6], w[t - 1]); + y[t] = x[t] + a; + y[t + 6] = x[t] - a; + } +} + +METAL_FUNC void radix13(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 12.0, -1 / 12.0}; + + // fft + radix12(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669)); 
+ y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823)); + y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161)); + y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690)); + y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267)); + y[7] = complex_mul_conj(y[7], float2(3.60555128, 0)); + y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267)); + y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690)); + y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161)); + y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823)); + y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669)); + + // ifft + radix12(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[7] = x[2] * inv + x[0]; + y[10] = x[3] * inv + x[0]; + y[5] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[11] = x[6] * inv + x[0]; + y[12] = x[7] * inv + x[0]; + y[6] = x[8] * inv + x[0]; + y[3] = x[9] * inv + x[0]; + y[8] = x[10] * inv + x[0]; + y[4] = x[11] * inv + x[0]; + y[2] = x[12] * inv + x[0]; +} \ No newline at end of file diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h new file mode 100644 index 000000000..ab699e136 --- /dev/null +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -0,0 +1,622 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/fft/radix.h" + +/* FFT helpers for reading and writing from/to device memory. + +For many sizes, GPU FFTs are memory bandwidth bound so +read/write performance is important. + +Where possible, we read 128 bits sequentially in each thread, +coalesced with accesses from adajcent threads for optimal performance. + +We implement specialized reading/writing for: + - FFT + - RFFT + - IRFFT + +Each with support for: + - Contiguous reads + - Padded reads + - Strided reads +*/ + +#define MAX_RADIX 13 + +using namespace metal; + +template < + typename in_T, + typename out_T, + int step = 0, + bool four_step_real = false> +struct ReadWriter { + const device in_T* in; + threadgroup float2* buf; + device out_T* out; + int n; + int batch_size; + int elems_per_thread; + uint3 elem; + uint3 grid; + int threads_per_tg; + bool inv; + + // Used for strided access + int strided_device_idx = 0; + int strided_shared_idx = 0; + + METAL_FUNC ReadWriter( + const device in_T* in_, + threadgroup float2* buf_, + device out_T* out_, + const short n_, + const int batch_size_, + const short elems_per_thread_, + const uint3 elem_, + const uint3 grid_, + const bool inv_) + : in(in_), + buf(buf_), + out(out_), + n(n_), + batch_size(batch_size_), + elems_per_thread(elems_per_thread_), + elem(elem_), + grid(grid_), + inv(inv_) { + // Account for padding on last threadgroup + threads_per_tg = elem.x == grid.x - 1 + ? (batch_size - (grid.x - 1) * grid.y) * grid.z + : grid.y * grid.z; + } + + // ifft(x) = 1/n * conj(fft(conj(x))) + METAL_FUNC float2 post_in(float2 elem) const { + return inv ? float2(elem.x, -elem.y) : elem; + } + + // Handle float case for generic RFFT alg + METAL_FUNC float2 post_in(float elem) const { + return float2(elem, 0); + } + + METAL_FUNC float2 pre_out(float2 elem) const { + return inv ? float2(elem.x / n, -elem.y / n) : elem; + } + + METAL_FUNC float2 pre_out(float2 elem, int length) const { + return inv ? 
float2(elem.x / length, -elem.y / length) : elem; + } + + METAL_FUNC bool out_of_bounds() const { + // Account for possible extra threadgroups + int grid_index = elem.x * grid.y + elem.y; + return grid_index >= batch_size; + } + + METAL_FUNC void load() const { + int batch_idx = elem.x * grid.y * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + // 2 complex64s = 128 bits + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + buf[index] = post_in(in[batch_idx + index]); + buf[index + 1] = post_in(in[batch_idx + index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + buf[index] = post_in(in[batch_idx + index]); + } + } + + METAL_FUNC void write() const { + int batch_idx = elem.x * grid.y * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + out[batch_idx + index] = pre_out(buf[index]); + out[batch_idx + index + 1] = pre_out(buf[index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + out[batch_idx + index] = pre_out(buf[index]); + } + } + + // Padded IO for Bluestein's algorithm + METAL_FUNC void load_padded(int length, const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = post_in(in[batch_idx + index]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0.0; + } + } + } + + METAL_FUNC void write_padded(int length, const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + float2 inv_factor = {1.0f / n, -1.0f / n}; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = seq_buf[index + length - 1] * inv_factor; + out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length); + } + } + } + + // Strided IO for four step FFT + METAL_FUNC void compute_strided_indices(int stride, int overall_n) { + // Use the batch threadgroup dimension to coalesce memory accesses: + // e.g. 
stride = 12 + // device | shared mem + // 0 1 2 3 | 0 12 - - + // - - - - | 1 13 - - + // - - - - | 2 14 - - + // 12 13 14 15 | 3 15 - - + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread; + } + + // Four Step FFT First Step + METAL_FUNC void load_strided(int stride, int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + buf[strided_shared_idx + e] = + post_in(in[strided_device_idx + e * stride]); + } + } + + METAL_FUNC void write_strided(int stride, int overall_n) { + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + int combined_idx = (strided_device_idx + e * stride) % overall_n; + int ij = (combined_idx / stride) * (combined_idx % stride); + // Apply four step twiddles at end of first step + float2 twiddle = get_twiddle(ij, overall_n); + out[strided_device_idx + e * stride] = complex_mul(output, twiddle); + } + } +}; + +// Four Step FFT Second Step +template <> +METAL_FUNC void ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = pre_out(output, overall_n); + } +} + +// For RFFT, we interleave batches of two real sequences into one complex one: +// +// z_k = x_k + j.y_k +// X_k = (Z_k + Z_(N-k)*) / 2 +// Y_k = -j * ((Z_k - Z_(N-k)*) / 2) +// +// This roughly doubles the throughput over the regular FFT. +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for RFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + seq_buf[index].x = in[batch_idx + index]; + seq_buf[index].y = in[batch_idx + index + next_in]; + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + short n_over_2 = (n / 2) + 1; + + int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : n_over_2; + + float2 conj = {1, -1}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, n_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + out[batch_idx + index] = {seq_buf[index].x, 0}; + out[batch_idx + index + next_out] = {seq_buf[index].y, 0}; + } else { + float2 x_k = seq_buf[index]; + float2 x_n_minus_k = seq_buf[n - index] * conj; + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = + float2(in[batch_idx + index], in[batch_idx + index + next_in]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + int length_over_2 = (length / 2) + 1; + int batch_idx = + elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 0 + : length_over_2; + + float2 conj = {1, -1}; + float2 inv_factor = {1.0f / n, -1.0f / n}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, length_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor); + out[batch_idx + index] = float2(elem.x, 0); + out[batch_idx + index + next_out] = float2(elem.y, 0); + } else { + float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor); + float2 x_n_minus_k = complex_mul( + w_k[length - index], seq_buf[length - index] * inv_factor); + x_n_minus_k *= conj; + // w_k should happen before this extraction + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +// For IRFFT, we do the opposite +// +// Z_k = X_k + j.Y_k +// x_k = Re(Z_k) +// Y_k = Imag(Z_k) +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for IRFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + short n_over_2 = (n / 2) + 1; + int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : n_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + // NumPy forces first input to be real + bool first_val = index == 0; + // NumPy forces last input on even irffts to be real + bool last_val = n % 2 == 0 && index == n_over_2 - 1; + if (first_val || last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + seq_buf[index] = x + complex_mul(y, plus_j); + seq_buf[index].y = -seq_buf[index].y; + if (index > 0 && !last_val) { + seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[n - index].y = -seq_buf[n - index].y; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + out[batch_idx + index] = seq_buf[index].x / n; + out[batch_idx + index + next_out] = seq_buf[index].y / -n; + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + int n_over_2 = (n / 2) + 1; + int length_over_2 = (length / 2) + 1; + + int batch_idx = + elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 0 + : length_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + if (index < length_over_2) { + bool last_val = length % 2 == 0 && index == length_over_2 - 1; + if (last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + float2 elem1 = x + complex_mul(y, plus_j); + seq_buf[index] = complex_mul(elem1 * conj, w_k[index]); + if (index > 0 && !last_val) { + float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[length - index] = + complex_mul(elem2 * conj, w_k[length - index]); + } + } else { + short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2); + seq_buf[pad_index] = 0; + seq_buf[pad_index + 1] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + float2 inv_factor = {1.0f / n, -1.0f / n}; + for (int e = 0; e < elems_per_thread; e++) { + int index = fft_idx + e * m; + if (index < length) { + float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]); + out[batch_idx + index] = output.x / length; + out[batch_idx + index + next_out] = output.y / -length; + } + } +} + +// Four Step RFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n_over_2 * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread / 2 * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread / 2; + for (int e = 0; e < elems_per_thread / 2; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = output; + } + + // Add on n/2 + 1 element + if (tg_idx == 0 && elem.x % outer_batch_size == 0) { + out[strided_batch_idx + overall_n / 2] = buf[n / 2]; + } +} + +// Four Step IRFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + auto conj = float2(1, -1); + + compute_strided_indices(stride, overall_n); + // Translate indices in terms of N - k + for (int e = 0; e < elems_per_thread; e++) { + int device_idx = strided_device_idx + e * stride; + int overall_batch = device_idx / overall_n; + int overall_index = device_idx % overall_n; + if (overall_index < overall_n_over_2) { + device_idx -= overall_batch * (overall_n - overall_n_over_2); + buf[strided_shared_idx + e] = in[device_idx] * conj; + } else { + int conj_idx = overall_n - overall_index; + device_idx = overall_batch * overall_n_over_2 + conj_idx; + buf[strided_shared_idx + e] = in[device_idx]; + } + } +} + +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + + for (int e = 0; e < elems_per_thread; e++) { + out[strided_device_idx + e * stride] = + pre_out(buf[strided_shared_idx + e], overall_n).x; + } +} diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 4db547e36..d789bf2e9 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -191,4 +191,17 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_fft_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const int tg_mem_size, + const std::string& in_type, + const std::string& out_type, + int step, + bool real, + const metal::MTLFCList& func_consts) { + return 
d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 9c4f3e921..ee2b4a07b 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -134,6 +134,13 @@ bool is_power_of_2(int n) { return ((n & (n - 1)) == 0) && n != 0; } +int next_power_of_2(int n) { + if (is_power_of_2(n)) { + return n; + } + return pow(2, std::ceil(std::log2(n))); +} + } // namespace } // namespace mlx::core diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 95791a57c..67ba37c13 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -81,7 +81,8 @@ array fft_impl( if (any_greater) { // Pad with zeros auto tmp = zeros(in_shape, a.dtype(), s); - in = scatter(tmp, std::vector{}, in, std::vector{}, s); + std::vector starts(in.ndim(), 0); + in = slice_update(tmp, in, starts, in.shape()); } auto out_shape = in_shape; diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 6cb4b1618..d15253a25 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -16,96 +16,140 @@ class TestFFT(mlx_tests.MLXTestCase): np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol) def test_fft(self): - with mx.stream(mx.cpu): - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - self.check_mx_np(mx.fft.fft, np.fft.fft, a_np) + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + self.check_mx_np(mx.fft.fft, np.fft.fft, a_np) - # Check with slicing and padding - r = np.random.rand(100).astype(np.float32) - i = np.random.rand(100).astype(np.float32) - a_np = r + 1j * i - self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) - self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) + # Check with slicing and padding + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) + self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) - # Check different axes - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) - self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) + # Check different axes + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) + self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) - # Check real fft - a_np = np.random.rand(100).astype(np.float32) - self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) - self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) - self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) + # Check real fft + a_np = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) + self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) + self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) - # Check real inverse - r = np.random.rand(100, 100).astype(np.float32) - i = np.random.rand(100, 100).astype(np.float32) - a_np = r + 1j * i - self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) - self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) - self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) - self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) - self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) - self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) + # Check real inverse + r = 
np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) + self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) + self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) + + x = np.fft.rfft(a_np) + self.check_mx_np(mx.fft.irfft, np.fft.irfft, x) def test_fftn(self): - with mx.stream(mx.cpu): - r = np.random.randn(8, 8, 8).astype(np.float32) - i = np.random.randn(8, 8, 8).astype(np.float32) - a = r + 1j * i + r = np.random.randn(8, 8, 8).astype(np.float32) + i = np.random.randn(8, 8, 8).astype(np.float32) + a = r + 1j * i - axes = [None, (1, 2), (2, 1), (0, 2)] - shapes = [None, (10, 5), (5, 10)] - ops = [ - "fft2", - "ifft2", - "rfft2", - "irfft2", - "fftn", - "ifftn", - "rfftn", - "irfftn", + axes = [None, (1, 2), (2, 1), (0, 2)] + shapes = [None, (10, 5), (5, 10)] + ops = [ + "fft2", + "ifft2", + "rfft2", + "irfft2", + "fftn", + "ifftn", + "rfftn", + "irfftn", + ] + + for op, ax, s in itertools.product(ops, axes, shapes): + x = a + if op in ["rfft2", "rfftn"]: + x = r + elif op == "irfft2": + x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s)) + elif op == "irfftn": + x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s)) + mx_op = getattr(mx.fft, op) + np_op = getattr(np.fft, op) + self.check_mx_np(mx_op, np_op, x, axes=ax, s=s) + + def _run_ffts(self, shape, atol=1e-4, rtol=1e-4): + np.random.seed(9) + + r = np.random.rand(*shape).astype(np.float32) + i = np.random.rand(*shape).astype(np.float32) + a_np = r + 1j * i + self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol) + self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, atol=atol, rtol=rtol) + + self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol) + + ia_np = np.fft.rfft(a_np) + self.check_mx_np( + mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1] + ) + self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol) + + def test_fft_shared_mem(self): + nums = np.concatenate( + [ + # small radix + np.arange(2, 14), + # powers of 2 + [2**k for k in range(4, 13)], + # stockham + [3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11], + # rader + [17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982], + # bluestein + [47, 83, 17 * 17], + # large stockham + [3159, 3645, 3969, 4004], ] + ) + for batch_size in (1, 3, 32): + for num in nums: + atol = 1e-4 if num < 1025 else 1e-3 + self._run_ffts((batch_size, num), atol=atol) - for op, ax, s in itertools.product(ops, axes, shapes): - x = a - if op in ["rfft2", "rfftn"]: - x = r - mx_op = getattr(mx.fft, op) - np_op = getattr(np.fft, op) - self.check_mx_np(mx_op, np_op, x, axes=ax, s=s) + @unittest.skip("Too slow for CI but useful for local testing.") + def test_fft_exhaustive(self): + nums = range(2, 4097) + for batch_size in (1, 3, 32): + for num in nums: + print(num) + atol = 1e-4 if num < 1025 else 1e-3 + self._run_ffts((batch_size, num), atol=atol) - def test_fft_powers_of_two(self): - shape = (16, 4, 8) - # np.fft.fft always uses double precision complex128 - # mx.fft.fft only supports single precision complex64 - # hence the fairly tolerant equality checks. 
- atol = 1e-4 - rtol = 1e-4 - np.random.seed(7) - for k in range(4, 12): - r = np.random.rand(*shape, 2**k).astype(np.float32) - i = np.random.rand(*shape, 2**k).astype(np.float32) - a_np = r + 1j * i - self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol) + def test_fft_big_powers_of_two(self): + # TODO: improve precision on big powers of two on GPU + for k in range(12, 17): + self._run_ffts((3, 2**k), atol=1e-3) - r = np.random.rand(*shape, 32).astype(np.float32) - i = np.random.rand(*shape, 32).astype(np.float32) - a_np = r + 1j * i - for axis in range(4): - self.check_mx_np( - mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis - ) + for k in range(17, 20): + self._run_ffts((3, 2**k), atol=1e-2) - r = np.random.rand(4, 8).astype(np.float32) - i = np.random.rand(4, 8).astype(np.float32) - a_np = r + 1j * i - a_mx = mx.array(a_np) + def test_fft_large_numbers(self): + numbers = [ + 1037, # prime > 2048 + 18247, # medium size prime factors + 1259 * 11, # large prime factors + 7883, # large prime + 3**8, # large stockham decomposable + 3109, # bluestein + 4006, # large rader + ] + for large_num in numbers: + self._run_ffts((1, large_num), atol=1e-3) def test_fft_contiguity(self): r = np.random.rand(4, 8).astype(np.float32) diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 5d960ef10..84d2f20c4 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -7,8 +7,6 @@ using namespace mlx::core; TEST_CASE("test fft basics") { - auto device = default_device(); - set_default_device(Device::cpu); array x(1.0); CHECK_THROWS(fft::fft(x)); CHECK_THROWS(fft::ifft(x)); @@ -94,13 +92,9 @@ TEST_CASE("test fft basics") { CHECK(array_equal(y, array(expected_1, {2, 2})).item()); CHECK(array_equal(fft::ifft(y, 1), x).item()); } - set_default_device(device); } TEST_CASE("test real ffts") { - auto device = default_device(); - set_default_device(Device::cpu); - auto x = array({1.0}); auto y = fft::rfft(x); CHECK_EQ(y.dtype(), complex64); @@ -124,14 +118,9 @@ TEST_CASE("test real ffts") { CHECK_EQ(y.size(), 2); CHECK_EQ(y.dtype(), float32); CHECK(array_equal(y, array({0.5f, -0.5f})).item()); - - set_default_device(device); } TEST_CASE("test fftn") { - auto device = default_device(); - set_default_device(Device::cpu); - auto x = zeros({5, 5, 5}); CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument); CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument); @@ -204,8 +193,6 @@ TEST_CASE("test fftn") { CHECK_EQ(y.shape(), std::vector{5, 8}); CHECK_EQ(y.dtype(), float32); } - - set_default_device(device); } TEST_CASE("test fft with provided shape") { @@ -234,9 +221,6 @@ TEST_CASE("test fft with provided shape") { } TEST_CASE("test fft vmap") { - auto device = default_device(); - set_default_device(Device::cpu); - auto fft_fn = [](array x) { return fft::fft(x); }; auto x = reshape(arange(8), {2, 4}); auto y = vmap(fft_fn)(x); @@ -252,14 +236,9 @@ TEST_CASE("test fft vmap") { y = vmap(rfft_fn, 1, 1)(x); CHECK(array_equal(y, fft::rfft(x, 0)).item()); - - set_default_device(device); } TEST_CASE("test fft grads") { - auto device = default_device(); - set_default_device(Device::cpu); - // Regular auto fft_fn = [](array x) { return fft::fft(x); }; auto cotangent = astype(arange(10), complex64); @@ -328,6 +307,4 @@ TEST_CASE("test fft grads") { zeros({5, 8})) .second; CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); - - set_default_device(device); }
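For reference, the radix7, radix11, and radix13 codelets in fft/radix.h use Rader's algorithm: a DFT of odd prime length p is rewritten as a circular convolution of length p - 1, which the composite radix-6/10/12 codelets then handle, and the hard-coded complex_mul_conj constants correspond to the transformed convolution kernel (the b_q values in the comments). The NumPy sketch below is a textbook version of the same idea (the helper names are ours, not the kernels'); for p = 7 with primitive root g = 3 the input permutation it computes, [1, 3, 2, 6, 4, 5], is exactly the load order used in radix7.

import numpy as np

def primitive_root(p):
    # Smallest primitive root of the odd prime p (brute force, fine for small p).
    for g in range(2, p):
        if len({pow(g, k, p) for k in range(1, p)}) == p - 1:
            return g

def rader_dft(x):
    # Length-p DFT, p an odd prime, via one circular convolution of length p - 1.
    x = np.asarray(x, dtype=np.complex128)
    p = len(x)
    g = primitive_root(p)
    perm = np.array([pow(g, q, p) for q in range(p - 1)])               # g^q mod p
    inv_perm = np.array([pow(g, p - 1 - q, p) for q in range(p - 1)])   # g^{-q} mod p
    a = x[perm]                                        # input permuted by powers of g
    b = np.exp(-2j * np.pi * inv_perm / p)             # twiddles permuted by powers of g^{-1}
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))  # circular convolution
    X = np.empty(p, dtype=np.complex128)
    X[0] = x.sum()
    X[inv_perm] = x[0] + conv                          # scatter back through the permutation
    return X

x = np.random.rand(13) + 1j * np.random.rand(13)
assert np.allclose(rader_dft(x), np.fft.fft(x))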
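Lengths with prime factors above 13 fall back to Bluestein's algorithm (the bluestein_fft kernels and the w_k / w_q buffers above): the DFT is expressed as a convolution with a quadratic chirp and evaluated as a circular convolution at a padded power-of-two size. A hedged NumPy sketch of the textbook scheme the kernels appear to follow (the function name and padding choice are ours):

import numpy as np

def bluestein_dft(x):
    # Arbitrary-length DFT via a chirp and a zero-padded power-of-two convolution.
    x = np.asarray(x, dtype=np.complex128)
    n = len(x)
    m = 1 << (2 * n - 1).bit_length()                  # power of two >= 2n - 1
    w = np.exp(-1j * np.pi * np.arange(n) ** 2 / n)    # chirp: w_k = exp(-i*pi*k^2/n)
    a = np.zeros(m, dtype=np.complex128)
    a[:n] = x * w                                      # chirp-modulated input
    b = np.zeros(m, dtype=np.complex128)
    b[:n] = np.conj(w)
    b[m - n + 1:] = np.conj(w[1:])[::-1]               # wrap kernel for circular convolution
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return conv[:n] * w                                # demodulate with the chirp

x = np.random.rand(47) + 1j * np.random.rand(47)       # 47 is prime and > 13
assert np.allclose(bluestein_dft(x), np.fft.fft(x))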
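For sizes above the threadgroup limit, the four_step_fft kernels factor N = n1 * n2 and run two passes of smaller FFTs with a twiddle by exp(-2*pi*i*j*k/N) applied between them, which corresponds to the get_twiddle(ij, overall_n) call in write_strided. A compact NumPy sketch of that factorization, with reshape/transpose standing in for the strided reads and writes (not how the kernels address memory):

import numpy as np

def four_step_fft(x, n1, n2):
    # Cooley-Tukey split N = n1 * n2: column FFTs, twiddle, row FFTs, transpose.
    N = n1 * n2
    y = np.asarray(x, dtype=np.complex128).reshape(n1, n2)
    y = np.fft.fft(y, axis=0)                        # step 1: n2 FFTs of length n1
    j = np.arange(n1)[:, None]
    k = np.arange(n2)[None, :]
    y = y * np.exp(-2j * np.pi * j * k / N)          # twiddles between the two steps
    y = np.fft.fft(y, axis=1)                        # step 2: n1 FFTs of length n2
    return y.T.reshape(N)                            # transposed output ordering

x = np.random.rand(8192) + 1j * np.random.rand(8192)
assert np.allclose(four_step_fft(x, 64, 128), np.fft.fft(x))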
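Finally, readwrite.h packs two real sequences into a single complex FFT for RFFT (and unpacks for IRFFT) using z_k = x_k + j*y_k and the Hermitian symmetry of real-input transforms, which is what roughly doubles RFFT throughput. The extraction formulas quoted in its comments can be checked directly in NumPy:

import numpy as np

n = 16
x = np.random.rand(n)
y = np.random.rand(n)
z = x + 1j * y                                 # pack two real signals into one complex one
Z = np.fft.fft(z)
Z_rev_conj = np.conj(Z[(-np.arange(n)) % n])   # Z_{N-k}^*
X = (Z + Z_rev_conj) / 2                       # X_k = (Z_k + Z_{N-k}^*) / 2
Y = -1j * (Z - Z_rev_conj) / 2                 # Y_k = -j * (Z_k - Z_{N-k}^*) / 2
assert np.allclose(X[: n // 2 + 1], np.fft.rfft(x))
assert np.allclose(Y[: n // 2 + 1], np.fft.rfft(y))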