Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Feature complete Metal FFT (#1102)

* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format

Co-authored-by: Alex Barron <abarron22@apple.com>

parent 0e585b4409
commit 27d70c7d9d
@ -3,6 +3,8 @@
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@ -16,40 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
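For reference, a hypothetical restatement of the bandwidth metric above; the actual bytes_per_fft, ms_per_s and bytes_per_gb constants are defined earlier in this script and are assumed here to amount to 16 bytes per point (one complex64 read plus one write).

# Hypothetical restatement, assuming bytes_per_fft = 16
# (one complex64 element read and one written per FFT point).
def effective_bandwidth_gb(runtime_ms, system_size, bytes_per_fft=16):
    return system_size * bytes_per_fft / (runtime_ms / 1e3) / 1e9

# e.g. 2**26 points processed in 5 ms -> ~215 GB/s
print(effective_bandwidth_gb(5.0, 2**26))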
|
||||
|
||||
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
|
||||
def fft_mlx(x):
|
||||
if dim == 1:
|
||||
out = mx.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = mx.fft.fft2(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
def fft_mps(x):
|
||||
if dim == 1:
|
||||
out = torch.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = torch.fft.fft2(x)
|
||||
torch.mps.synchronize()
|
||||
return out
|
||||
|
||||
return bandwidths
|
||||
bandwidths = []
|
||||
for n in fft_sizes:
|
||||
batch_size = system_size // n**dim
|
||||
shape = [batch_size] + [n for _ in range(dim)]
|
||||
if backend == "mlx":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = mx.array(x_np)
|
||||
mx.eval(x)
|
||||
fft = fft_mlx
|
||||
elif backend == "mps":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = torch.tensor(x_np, device="mps")
|
||||
torch.mps.synchronize()
|
||||
fft = fft_mps
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
|
||||
print(n, bandwidth)
|
||||
bandwidths.append(bandwidth)
|
||||
|
||||
return np.array(bandwidths)
|
||||
|
||||
|
||||
def time_fft():
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
x = np.array(range(2, 512))
|
||||
system_size = int(2**26)
|
||||
|
||||
print("MLX GPU")
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
print("MPS GPU")
|
||||
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
|
||||
|
||||
print("CPU")
|
||||
system_size = int(2**20)
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
x = np.array(x)
|
||||
|
||||
all_indices = x - x[0]
|
||||
radix_2to13 = (
|
||||
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
bluesteins = (
|
||||
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
|
||||
for indices, name in [
|
||||
(all_indices, "All"),
|
||||
(radix_2to13, "Radix 2-13"),
|
||||
(bluesteins, "Bluestein's"),
|
||||
]:
|
||||
# plot bandwidths
|
||||
print(name)
|
||||
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
|
||||
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
|
||||
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
|
||||
plt.title(f"MLX FFT Benchmark -- {name}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"{name}.png")
|
||||
plt.clf()
|
||||
|
||||
av_gpu_bandwidth = np.mean(gpu_bandwidths)
|
||||
av_mps_bandwidth = np.mean(mps_bandwidths)
|
||||
av_cpu_bandwidth = np.mean(cpu_bandwidths)
|
||||
print("Average bandwidths:")
|
||||
print("GPU:", av_gpu_bandwidth)
|
||||
print("MPS:", av_mps_bandwidth)
|
||||
print("CPU:", av_cpu_bandwidth)
|
||||
|
||||
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
|
||||
print("Percent MLX faster than MPS: ", portion_faster * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -64,6 +64,11 @@ if (MLX_METAL_JIT)
|
||||
make_jit_source(unary)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(
|
||||
fft
|
||||
kernels/fft/radix.h
|
||||
kernels/fft/readwrite.h
|
||||
)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
|
@ -1,106 +1,794 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
|
||||
auto& in = inputs[0];
|
||||
#define MAX_STOCKHAM_FFT_SIZE 4096
|
||||
#define MAX_RADER_FFT_SIZE 2048
|
||||
#define MAX_BLUESTEIN_FFT_SIZE 2048
|
||||
// Threadgroup memory batching improves throughput for small n
|
||||
#define MIN_THREADGROUP_MEM_SIZE 256
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
|
||||
in.dtype() != complex64 || out.dtype() != complex64) {
|
||||
// Could also fallback to CPU implementation here.
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
|
||||
inline const std::vector<int> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
|
||||
std::vector<int> prime_factors(int n) {
|
||||
int z = 2;
|
||||
std::vector<int> factors;
|
||||
while (z * z <= n) {
|
||||
if (n % z == 0) {
|
||||
factors.push_back(z);
|
||||
n /= z;
|
||||
} else {
|
||||
z++;
|
||||
}
|
||||
}
|
||||
if (n > 1) {
|
||||
factors.push_back(n);
|
||||
}
|
||||
return factors;
|
||||
}
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
// Forward Declaration
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
// Number of steps for each radix in the Rader decomposition
|
||||
std::vector<int> rader;
|
||||
// Rader factor, 1 if no rader factors
|
||||
int rader_n = 1;
|
||||
int bluestein_n = -1;
|
||||
// Four step FFT
|
||||
bool four_step = false;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> plan(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
while (n % radix == 0) {
|
||||
plan[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
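As a rough illustration, the greedy decomposition above behaves like this Python sketch (illustrative only; the real planner additionally skips radices above 4 for small powers of two, per the comment).

# Illustrative Python model of the greedy radix decomposition above.
SUPPORTED_RADICES = [13, 11, 8, 7, 6, 5, 4, 3, 2]  # ordered by preference

def plan_stockham(n):
    plan = {}
    for radix in SUPPORTED_RADICES:
        while n % radix == 0:
            plan[radix] = plan.get(radix, 0) + 1
            n //= radix
    if n != 1:
        raise ValueError("not decomposable into the supported radices")
    return plan

print(plan_stockham(168))  # 168 = 8 * 7 * 3 -> {8: 1, 7: 1, 3: 1}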
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
// For powers of two we have a fast, no-transpose four step implementation.
|
||||
plan.four_step = true;
|
||||
// Rough heuristic for choosing faster powers of two when we can
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
size_t n = in.shape(axes_[0]);
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
// See if we can use Rader's algorithm to Stockham decompose n - 1
|
||||
auto rader_factors = prime_factors(factor - 1);
|
||||
int last_factor = -1;
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
plan.rader = plan_stockham_fft(factor - 1);
|
||||
plan.rader_n = factor;
|
||||
remaining_n /= factor;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_power_of_2(n) || n > 2048 || n < 4) {
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
|
||||
plan.stockham = plan_stockham_fft(remaining_n);
|
||||
return plan;
|
||||
}
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
|
||||
std::vector<int> steps;
|
||||
auto radices = supported_radices();
|
||||
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
|
||||
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
int radix = radices[i % radices.size()];
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2) {
|
||||
if (used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
// In all other cases use the second smallest radix.
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
|
||||
// Rader
|
||||
int mod_exp(int x, int y, int n) {
|
||||
int out = 1;
|
||||
while (y) {
|
||||
if (y & 1) {
|
||||
out = out * x % n;
|
||||
}
|
||||
y >>= 1;
|
||||
x = x * x % n;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int primitive_root(int n) {
|
||||
auto factors = prime_factors(n - 1);
|
||||
|
||||
for (int r = 2; r < n - 1; r++) {
|
||||
bool found = true;
|
||||
for (int factor : factors) {
|
||||
if (mod_exp(r, (n - 1) / factor, n) == 1) {
|
||||
found = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::tuple<array, array, array> compute_raders_constants(
|
||||
int rader_n,
|
||||
const Stream& s) {
|
||||
int proot = primitive_root(rader_n);
|
||||
// Fermat's little theorem
|
||||
int inv = mod_exp(proot, rader_n - 2, rader_n);
|
||||
std::vector<short> g_q(rader_n - 1);
|
||||
std::vector<short> g_minus_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
g_q[i] = mod_exp(proot, i, rader_n);
|
||||
g_minus_q[i] = mod_exp(inv, i, rader_n);
|
||||
}
|
||||
array g_q_arr(g_q.begin(), {rader_n - 1});
|
||||
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
|
||||
|
||||
std::vector<std::complex<float>> b_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
|
||||
b_q[i] = std::exp(std::complex<float>(0, pi_i));
|
||||
}
|
||||
|
||||
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
|
||||
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
|
||||
auto b_q_fft_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
|
||||
std::ptrdiff_t item_size = b_q_fft.itemsize();
|
||||
size_t fft_size = rader_n - 1;
|
||||
// This FFT is always small (<4096, batch 1) so save some overhead
|
||||
// and do it on the CPU
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ b_q.data(),
|
||||
/* data_out= */ b_q_fft_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
|
||||
}
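For context, a minimal numpy sketch of the Rader decomposition that these constants (g_q, g_minus_q and the FFT of b_q) feed; this is illustrative only and does not mirror the kernel's data layout.

import numpy as np

def primitive_root(n):
    # Smallest generator of the multiplicative group modulo the prime n.
    factors, m, z = set(), n - 1, 2
    while z * z <= m:
        while m % z == 0:
            factors.add(z)
            m //= z
        z += 1
    if m > 1:
        factors.add(m)
    for r in range(2, n):
        if all(pow(r, (n - 1) // f, n) != 1 for f in factors):
            return r

def rader_fft(x):
    # DFT of prime length n via one cyclic convolution of length n - 1.
    n = len(x)
    g = primitive_root(n)
    g_q = [pow(g, q, n) for q in range(n - 1)]         # permutation of 1..n-1
    g_minus_q = [pow(g, -q, n) for q in range(n - 1)]  # inverse permutation
    b_q = np.exp(-2j * np.pi * np.array(g_minus_q) / n)
    conv = np.fft.ifft(np.fft.fft(x[g_q]) * np.fft.fft(b_q))
    out = np.empty(n, dtype=complex)
    out[0] = x.sum()
    out[g_minus_q] = x[0] + conv
    return out

x = np.random.rand(17) + 1j * np.random.rand(17)
assert np.allclose(rader_fft(x), np.fft.fft(x))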
|
||||
|
||||
// Bluestein
|
||||
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
|
||||
// We need to calculate the Bluestein twiddle factors
|
||||
// in double precision for the overall numerical stability
|
||||
// of Bluestein's FFT algorithm to be acceptable.
|
||||
//
|
||||
// Metal doesn't support float64, so instead we
|
||||
// manually implement the required operations on cpu.
|
||||
//
|
||||
// In numpy:
|
||||
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
|
||||
// w_q = np.fft.fft(1/w_k)
|
||||
// return w_k, w_q
|
||||
int length = 2 * n - 1;
|
||||
|
||||
std::vector<std::complex<float>> w_k_vec(n);
|
||||
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
|
||||
|
||||
for (int i = -n + 1; i < n; i++) {
|
||||
double theta = pow(i, 2) * M_PI / (double)n;
|
||||
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
|
||||
if (i >= 0) {
|
||||
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
|
||||
}
|
||||
}
|
||||
|
||||
array w_k({n}, complex64, nullptr, {});
|
||||
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
|
||||
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
|
||||
|
||||
array w_q({bluestein_n}, complex64, nullptr, {});
|
||||
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
|
||||
auto w_q_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
|
||||
|
||||
std::ptrdiff_t item_size = w_q.itemsize();
|
||||
size_t fft_size = bluestein_n;
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ w_q_vec.data(),
|
||||
/* data_out= */ w_q_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(w_k, w_q);
|
||||
}
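A minimal numpy sketch of the Bluestein chirp-z transform these twiddles are used for (textbook formulation over [0, n) rather than the [-n + 1, n) range above; illustrative only).

import numpy as np

def bluestein_fft(x):
    # Arbitrary-length DFT via a power-of-two-length circular convolution.
    n = len(x)
    bluestein_n = 1 << (2 * n - 2).bit_length()  # next power of two >= 2n - 1
    k = np.arange(n)
    w_k = np.exp(-1j * np.pi * k**2 / n)         # chirp
    a = np.zeros(bluestein_n, dtype=complex)
    a[:n] = x * w_k
    # Chirp filter h[m] = conj(w_|m|) for m in (-n, n), laid out circularly.
    h = np.zeros(bluestein_n, dtype=complex)
    h[:n] = np.conj(w_k)
    if n > 1:
        h[-(n - 1):] = np.conj(w_k[1:])[::-1]
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(h))
    return conv[:n] * w_k

x = np.random.rand(23) + 1j * np.random.rand(23)
assert np.allclose(bluestein_fft(x), np.fft.fft(x))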
|
||||
|
||||
void multi_upload_bluestein_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
// TODO(alexbarron) Implement fused kernels for multi upload Bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
|
||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
|
||||
if (real && !inverse) {
|
||||
// Convert float32->complex64
|
||||
copy_gpu(in, temp, CopyType::General, s);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "conj", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "conj", s);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_n;
|
||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
||||
|
||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/*inverse=*/false,
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
s);
|
||||
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
|
||||
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/* inverse= */ true,
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
slice_gpu(temp1, out, rstarts, strides, s);
|
||||
} else if (real && inverse) {
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
auto inv_n = array({1.0f / n}, {1}, float32);
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
copies.push_back(inv_n);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "mul", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "conj", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "mul", s);
|
||||
copies.push_back(inv_n);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
}
|
||||
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k_broadcast);
|
||||
copies.push_back(w_q_broadcast);
|
||||
copies.push_back(temp);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(pad_temp);
|
||||
copies.push_back(pad_temp1);
|
||||
}
|
||||
|
||||
void four_step_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
if (plan.bluestein_n == -1) {
|
||||
// Fast no transpose implementation for powers of 2.
|
||||
FourStepParams four_step_params = {
|
||||
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
}
|
||||
}
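For reference, the underlying four step (n = n1 * n2) decomposition in numpy; the Metal version avoids the explicit transpose by reading and writing strided, but the arithmetic is the same. Illustrative only.

import numpy as np

def four_step_fft(x, n1, n2):
    # FFT of length n1 * n2 as two passes of smaller FFTs plus twiddles.
    n = n1 * n2
    X = x.reshape(n1, n2)
    X = np.fft.fft(X, axis=0)  # n2 column FFTs of length n1
    X = X * np.exp(-2j * np.pi * np.outer(np.arange(n1), np.arange(n2)) / n)
    X = np.fft.fft(X, axis=1)  # n1 row FFTs of length n2
    return X.T.reshape(n)      # transposed read-out

x = np.random.rand(4096) + 1j * np.random.rand(4096)
assert np.allclose(four_step_fft(x, 16, 256), np.fft.fft(x))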
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
if (four_step_params.required) {
|
||||
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
|
||||
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
|
||||
}
|
||||
|
||||
// Make sure that the array is contiguous and has stride 1 in the FFT dim
|
||||
std::vector<array> copies;
|
||||
auto check_input = [this, &copies, &s](const array& x) {
|
||||
auto check_input = [&axis, &copies, &s](const array& x) {
|
||||
// TODO: Pass the strides to the kernel so
|
||||
// we can avoid the copy when x is not contiguous.
|
||||
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
|
||||
x.flags().col_contiguous;
|
||||
bool no_copy = x.strides()[axis] == 1 &&
|
||||
(x.flags().row_contiguous || x.flags().col_contiguous);
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axes_[0]);
|
||||
for (int axis = 0; axis < x.ndim(); axis++) {
|
||||
if (axis == axes_[0]) {
|
||||
size_t cur_stride = x.shape(axis);
|
||||
for (int a = 0; a < x.ndim(); a++) {
|
||||
if (a == axis) {
|
||||
strides.push_back(1);
|
||||
} else {
|
||||
strides.push_back(cur_stride);
|
||||
cur_stride *= x.shape(axis);
|
||||
cur_stride *= x.shape(a);
|
||||
}
|
||||
}
|
||||
|
||||
auto flags = x.flags();
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
|
||||
f_stride *= x.shape(i);
|
||||
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
|
||||
b_stride *= x.shape(ri);
|
||||
}
|
||||
// This is probably over-conservative
|
||||
flags.contiguous = false;
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(x.shape(), strides);
|
||||
|
||||
flags.col_contiguous = is_row_contiguous;
|
||||
flags.row_contiguous = is_col_contiguous;
|
||||
flags.contiguous = data_size == x_copy.size();
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
|
||||
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(inputs[0]);
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
// real to complex: n -> (n/2)+1
|
||||
// complex to real: (n/2)+1 -> n
|
||||
auto out_strides = in_contiguous.strides();
|
||||
size_t out_data_size = in_contiguous.data_size();
|
||||
if (in.shape(axis) != out.shape(axis)) {
|
||||
for (int i = 0; i < out_strides.size(); i++) {
|
||||
if (out_strides[i] != 1) {
|
||||
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
}
|
||||
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: allow donation here
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
in_contiguous.data_size(),
|
||||
in_contiguous.strides(),
|
||||
in_contiguous.flags());
|
||||
if (!inplace) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
out_data_size,
|
||||
out_strides,
|
||||
in_contiguous.flags());
|
||||
}
|
||||
|
||||
// We use n / 4 threads by default since radix-4
|
||||
// is the largest single threaded radix butterfly
|
||||
// we currently implement.
|
||||
size_t m = n / 4;
|
||||
size_t batch = in.size() / in.shape(axes_[0]);
|
||||
auto radices = supported_radices();
|
||||
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
|
||||
|
||||
// Setup function constants
|
||||
bool power_of_2 = is_power_of_2(fft_size);
|
||||
|
||||
auto make_int = [](int* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
|
||||
};
|
||||
auto make_bool = [](bool* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
|
||||
};
|
||||
|
||||
std::vector<MTLFC> func_consts = {
|
||||
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
|
||||
|
||||
// Start of radix/rader step constants
|
||||
int index = 4;
|
||||
for (int i = 0; i < plan.stockham.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.stockham[i], index));
|
||||
index += 1;
|
||||
}
|
||||
for (int i = 0; i < plan.rader.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.rader[i], index));
|
||||
index += 1;
|
||||
}
|
||||
int elems_per_thread = compute_elems_per_thread(plan);
|
||||
func_consts.push_back(make_int(&elems_per_thread, 2));
|
||||
|
||||
int rader_m = n / plan.rader_n;
|
||||
func_consts.push_back(make_int(&rader_m, 3));
|
||||
|
||||
// The overall number of FFTs we're going to compute for this input
|
||||
int size = out.dtype() == float32 ? out.size() : in.size();
|
||||
if (real && inverse && four_step_params.required) {
|
||||
size = out.size();
|
||||
}
|
||||
int total_batch_size = size / n;
|
||||
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
|
||||
|
||||
// We batch among threadgroups for improved efficiency when n is small
|
||||
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
|
||||
if (four_step_params.required) {
|
||||
// Require a threadgroup batch size of at least 4 for four step FFT
|
||||
// so we can coalesce the memory accesses.
|
||||
threadgroup_batch_size =
|
||||
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
|
||||
}
|
||||
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
|
||||
// FFTs up to 2^20 are currently supported
|
||||
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
|
||||
|
||||
// ceil divide
|
||||
int batch_size =
|
||||
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
|
||||
|
||||
if (real && !four_step_params.required) {
|
||||
// We can perform 2 RFFTs at once so the batch size is halved.
|
||||
batch_size = (batch_size + 2 - 1) / 2;
|
||||
}
|
||||
int out_buffer_size = out.size();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
||||
// Only required by four step
|
||||
int step = -1;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "fft_" << n;
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
std::string inv_string = inverse ? "true" : "false";
|
||||
std::string real_string = real ? "true" : "false";
|
||||
if (plan.bluestein_n > 0) {
|
||||
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
|
||||
<< in_type_str << "_" << out_type_str;
|
||||
} else if (plan.rader_n > 1) {
|
||||
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str;
|
||||
} else if (four_step_params.required) {
|
||||
step = four_step_params.first_step ? 0 : 1;
|
||||
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str << "_" << step << "_" << real_string;
|
||||
} else {
|
||||
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
|
||||
<< out_type_str;
|
||||
}
|
||||
std::string base_name = kname.str();
|
||||
// We use a specialized kernel for each FFT size
|
||||
kname << "_n" << fft_size << "_inv_" << inverse;
|
||||
std::string hash_name = kname.str();
|
||||
auto kernel = get_fft_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
threadgroup_mem_size,
|
||||
in_type_str,
|
||||
out_type_str,
|
||||
step,
|
||||
real,
|
||||
func_consts);
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto group_dims = MTL::Size(1, m, 1);
|
||||
auto grid_dims = MTL::Size(batch, m, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
if (plan.bluestein_n > 0) {
|
||||
// Precomputed twiddle factors for Bluestein's
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k);
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
copies.push_back(g_q);
|
||||
copies.push_back(g_minus_q);
|
||||
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
std::vector<array> copies = {temp1, temp2};
|
||||
auto& d = metal::device(s.device);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
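Conceptually this is just the numpy-style loop of 1D transforms below; the ping-pong temporaries, the real transform on the last axis, and the axis ordering are the GPU-specific additions. Illustrative sketch only.

import numpy as np

def nd_fft(x, axes, inverse=False):
    # ND (i)FFT as a series of 1D transforms, one axis at a time.
    out = np.asarray(x, dtype=complex)
    for ax in (axes if inverse else reversed(axes)):
        out = np.fft.ifft(out, axis=ax) if inverse else np.fft.fft(out, axis=ax)
    return out

x = np.random.rand(4, 5, 6) + 1j * np.random.rand(4, 5, 6)
assert np.allclose(nd_fft(x, [0, 1, 2]), np.fft.fftn(x))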
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
53
mlx/backend/metal/jit/fft.h
Normal file
@ -0,0 +1,53 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
fft<{tg_mem_size}, {in_T}, {out_T}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    constant const int& n,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
)";

constexpr std::string_view rader_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    const device float2* raders_b_q [[buffer(2)]],
    const device short* raders_g_q [[buffer(3)]],
    const device short* raders_g_minus_q [[buffer(4)]],
    constant const int& n,
    constant const int& batch_size,
    constant const int& rader_n,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
)";

constexpr std::string_view bluestein_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    const device float2* w_q [[buffer(2)]],
    const device float2* w_k [[buffer(3)]],
    constant const int& length,
    constant const int& n,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
)";

constexpr std::string_view four_step_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    constant const int& n1,
    constant const int& n2,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
)";
@ -17,6 +17,7 @@ const char* unary();
|
||||
const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* softmax();
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/jit/binary.h"
|
||||
#include "mlx/backend/metal/jit/binary_two.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/fft.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
@ -489,4 +490,51 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_string;
|
||||
if (lib_name.find("bluestein") != std::string::npos) {
|
||||
kernel_string = bluestein_fft_kernel;
|
||||
} else if (lib_name.find("rader") != std::string::npos) {
|
||||
kernel_string = rader_fft_kernel;
|
||||
} else if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_string = four_step_fft_kernel;
|
||||
} else {
|
||||
kernel_string = fft_kernel;
|
||||
}
|
||||
kernel_source << metal::fft();
|
||||
if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type,
|
||||
"step"_a = step,
|
||||
"real"_a = real);
|
||||
} else {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -155,4 +155,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
int wm,
|
||||
int wn);
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -48,6 +48,9 @@ set(
|
||||
binary.h
|
||||
ternary.h
|
||||
copy.h
|
||||
fft.h
|
||||
fft/radix.h
|
||||
fft/readwrite.h
|
||||
softmax.h
|
||||
sort.h
|
||||
scan.h
|
||||
|
486
mlx/backend/metal/kernels/fft.h
Normal file
@ -0,0 +1,486 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_common>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fft/radix.h"
|
||||
#include "mlx/backend/metal/kernels/fft/readwrite.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MAX_RADIX 13
|
||||
// Reached when elems_per_thread_ = 6, max_radix = 13
|
||||
// and some threads have to do 3 radix 6s requiring 18 float2s.
|
||||
#define MAX_OUTPUT_SIZE 18
|
||||
|
||||
// Specialize for a particular value of N at runtime
|
||||
STEEL_CONST bool inv_ [[function_constant(0)]];
|
||||
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
|
||||
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
|
||||
// rader_m = n / rader_n
|
||||
STEEL_CONST int rader_m_ [[function_constant(3)]];
|
||||
// Stockham steps
|
||||
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
|
||||
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
|
||||
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
|
||||
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
|
||||
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
|
||||
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
|
||||
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
|
||||
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
|
||||
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
|
||||
// Rader steps
|
||||
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
|
||||
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
|
||||
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
|
||||
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
|
||||
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
|
||||
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
|
||||
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
|
||||
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
|
||||
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
|
||||
|
||||
// See "radix.h" for radix codelets
|
||||
typedef void (*RadixFunc)(thread float2*, thread float2*);
|
||||
|
||||
// Perform a single radix n butterfly with appropriate twiddles
|
||||
template <int radix, RadixFunc radix_func>
|
||||
METAL_FUNC void radix_butterfly(
|
||||
int i,
|
||||
int p,
|
||||
thread float2* x,
|
||||
thread short* indices,
|
||||
thread float2* y) {
|
||||
// i: the index in the overall DFT that we're processing.
|
||||
// p: the size of the DFTs we're merging at this step.
|
||||
// m: how many threads are working on this DFT.
|
||||
int k, j;
|
||||
|
||||
// Use faster bitwise operations when working with powers of two
|
||||
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
|
||||
if (radix_p_2 && is_power_of_2_) {
|
||||
constexpr short power = __builtin_ctz(radix);
|
||||
k = i & (p - 1);
|
||||
j = ((i - k) << power) + k;
|
||||
} else {
|
||||
k = i % p;
|
||||
j = (i / p) * radix * p + k;
|
||||
}
|
||||
|
||||
// Apply twiddles
|
||||
if (p > 1) {
|
||||
float2 twiddle_1 = get_twiddle(k, radix * p);
|
||||
float2 twiddle = twiddle_1;
|
||||
x[1] = complex_mul(x[1], twiddle);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 2; t < radix; t++) {
|
||||
twiddle = complex_mul(twiddle, twiddle_1);
|
||||
x[t] = complex_mul(x[t], twiddle);
|
||||
}
|
||||
}
|
||||
|
||||
radix_func(x, y);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 0; t < radix; t++) {
|
||||
indices[t] = j + t * p;
|
||||
}
|
||||
}
|
||||
|
||||
// Perform all the radix steps required for a
|
||||
// particular radix size n.
|
||||
template <int radix, RadixFunc radix_func>
|
||||
METAL_FUNC void radix_n_steps(
|
||||
int i,
|
||||
thread int* p,
|
||||
int m,
|
||||
int n,
|
||||
int num_steps,
|
||||
thread float2* inputs,
|
||||
thread short* indices,
|
||||
thread float2* values,
|
||||
threadgroup float2* buf) {
|
||||
int m_r = n / radix;
|
||||
// When combining different sized radices, we have to do
|
||||
// multiple butterflies in a single thread.
|
||||
// E.g. n = 28 = 4 * 7
|
||||
// 4 threads, 7 elems_per_thread
|
||||
// All threads do 1 radix7 butterfly.
|
||||
// 3 threads do 2 radix4 butterflies.
|
||||
// 1 thread does 1 radix4 butterfly.
|
||||
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
|
||||
|
||||
int index = 0;
|
||||
int r_index = 0;
|
||||
for (int s = 0; s < num_steps; s++) {
|
||||
for (int t = 0; t < max_radices_per_thread; t++) {
|
||||
index = i + t * m;
|
||||
if (index < m_r) {
|
||||
for (int r = 0; r < radix; r++) {
|
||||
inputs[r] = buf[index + r * m_r];
|
||||
}
|
||||
radix_butterfly<radix, radix_func>(
|
||||
index, *p, inputs, indices + t * radix, values + t * radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads have read their inputs into thread local mem
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int t = 0; t < max_radices_per_thread; t++) {
|
||||
index = i + t * m;
|
||||
if (index < m_r) {
|
||||
for (int r = 0; r < radix; r++) {
|
||||
r_index = t * radix + r;
|
||||
buf[indices[r_index]] = values[r_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads have written back to threadgroup mem
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
*p *= radix;
|
||||
}
|
||||
}
|
||||
|
||||
#define RADIX_STEP(radix, radix_func, num_steps) \
|
||||
radix_n_steps<radix, radix_func>( \
|
||||
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
|
||||
|
||||
template <bool rader = false>
|
||||
METAL_FUNC void
|
||||
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
|
||||
float2 inputs[MAX_RADIX];
|
||||
short indices[MAX_OUTPUT_SIZE];
|
||||
float2 values[MAX_OUTPUT_SIZE];
|
||||
|
||||
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
|
||||
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
|
||||
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
|
||||
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
|
||||
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
|
||||
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
|
||||
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
|
||||
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
|
||||
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
|
||||
}
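To make the i/p/k/j indexing used by radix_butterfly concrete, here is a pure-Python radix-2 Stockham iteration built on the same index formulas (illustrative; the kernel generalizes this to mixed radices operating in threadgroup memory).

import numpy as np

def stockham_radix2_fft(x):
    # Each pass merges sub-DFTs of size p; k = i % p picks the twiddle and
    # j = (i // p) * 2 * p + k the output slot, as in radix_butterfly above.
    n = len(x)
    buf = np.asarray(x, dtype=complex)
    p = 1
    while p < n:
        out = np.empty(n, dtype=complex)
        m = n // 2  # butterflies per pass
        for i in range(m):
            k = i % p
            j = (i // p) * 2 * p + k
            tw = np.exp(-2j * np.pi * k / (2 * p))
            a, b = buf[i], buf[i + m] * tw
            out[j], out[j + p] = a + b, a - b
        buf, p = out, p * 2
    return buf

x = np.random.rand(16) + 1j * np.random.rand(16)
assert np.allclose(stockham_radix2_fft(x), np.fft.fft(x))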
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-n DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
int fft_idx = elem.z; // Thread index in DFT
|
||||
int m = grid.z; // Threads per DFT
|
||||
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write();
|
||||
}
|
||||
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void rader_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
const device float2* raders_b_q [[buffer(2)]],
|
||||
const device short* raders_g_q [[buffer(3)]],
|
||||
const device short* raders_g_minus_q [[buffer(4)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
constant const int& rader_n,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Use Rader's algorithm to compute fast FFTs
|
||||
// when a prime factor `p` of `n` is greater than 13 but
|
||||
// has `p - 1` Stockham decomposable into prime factors <= 13.
|
||||
//
|
||||
// E.g. n = 102
|
||||
// = 2 * 3 * 17
|
||||
// . = 2 * 3 * RADER(16)
|
||||
// . = 2 * 3 * RADER(4 * 4)
|
||||
//
|
||||
// In numpy:
|
||||
// x_perm = x[g_q]
|
||||
// y = np.fft.fft(x_perm) * b_q
|
||||
// z = np.fft.ifft(y) + x[0]
|
||||
// out = z[g_minus_q]
|
||||
// out[0] = x[1:].sum()
|
||||
//
|
||||
// Where the g_q and g_minus_q are permutations formed
|
||||
// by the multiplicative group modulo N generated by the
|
||||
// primitive root of N and b_q is a constant.
|
||||
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
|
||||
//
|
||||
// Rader's uses fewer operations than Bluestein's and so
|
||||
// is more accurate. It's also faster in most cases.
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = grid.z;
|
||||
|
||||
int fft_idx = elem.z;
|
||||
int tg_idx = elem.y * n;
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
// rader_m = n / rader_n;
|
||||
int rader_m = rader_m_;
|
||||
|
||||
// We have to load two x_0s for each thread since sometimes
|
||||
// elems_per_thread_ crosses a boundary.
|
||||
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
|
||||
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
short x_0_index =
|
||||
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
|
||||
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
|
||||
|
||||
// Do the Rader permutation in shared memory
|
||||
float2 temp[MAX_RADIX];
|
||||
int max_index = n - rader_m - 1;
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short g_q = raders_g_q[index / rader_m];
|
||||
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
buf[index + rader_m] = temp[e];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Rader FFT on x[rader_m:]
|
||||
int p = 1;
|
||||
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
||||
|
||||
// x_1 + ... + x_n is computed for us in the first FFT step so
|
||||
// we save it in the first rader_m indices of the array for later.
|
||||
int x_sum_index = metal::min(fft_idx, rader_m - 1);
|
||||
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
|
||||
|
||||
float2 inv = {1.0f, -1.0f};
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short interleaved_index =
|
||||
index / rader_m + (index % rader_m) * (rader_n - 1);
|
||||
temp[e] = complex_mul(
|
||||
buf[rader_m + interleaved_index],
|
||||
raders_b_q[interleaved_index % (rader_n - 1)]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
buf[rader_m + index] = temp[e] * inv;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Rader IFFT on x[rader_m:]
|
||||
p = 1;
|
||||
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
||||
|
||||
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
|
||||
short diff_index = index / (rader_n - 1) - x_0_index;
|
||||
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
|
||||
}
|
||||
|
||||
// Use the sum of elements that was computed in the first FFT
|
||||
float2 x_sum = buf[x_0_index] + x_0[0];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short g_q_index = index % (rader_n - 1);
|
||||
short g_q = raders_g_minus_q[g_q_index];
|
||||
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
|
||||
buf[out_index] = temp[e];
|
||||
}
|
||||
|
||||
buf[x_0_index * rader_n] = x_sum;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
p = rader_n;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write();
|
||||
}
|
||||
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void bluestein_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
const device float2* w_q [[buffer(2)]],
|
||||
const device float2* w_k [[buffer(3)]],
|
||||
constant const int& length,
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Computes arbitrary length FFTs with Bluestein's algorithm
|
||||
//
|
||||
// In numpy:
|
||||
// bluestein_n = next_power_of_2(2*n - 1)
|
||||
// out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
|
||||
//
|
||||
// Where w_k and w_q are precomputed on CPU in high precision as:
|
||||
// w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
|
||||
// w_q = np.fft.fft(1/w_k[-n:])
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load_padded(length, w_k);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
int fft_idx = elem.z; // Thread index in DFT
|
||||
int m = grid.z; // Threads per DFT
|
||||
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
// fft
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
float2 inv = float2(1.0f, -1.0f);
|
||||
for (int t = 0; t < elems_per_thread_; t++) {
|
||||
int index = fft_idx + t * m;
|
||||
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// ifft
|
||||
p = 1;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_padded(length, w_k);
|
||||
}
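As a hedged illustration of the chirp-z idea this kernel implements, here is a minimal numpy sketch. It builds the chirp constants itself and wraps the negative-index half of the chirp around for the circular convolution, whereas the Metal kernel reads precomputed w_k / w_q buffers laid out with the full chirp at an offset (hence the index + length - 1 read in write_padded); the layout below is for the example only and assumes n >= 2.

import numpy as np

def bluestein_fft(x):
    # Arbitrary-length DFT via a power-of-two circular convolution.
    n = len(x)
    bluestein_n = 1 << (2 * n - 2).bit_length()  # next power of 2 >= 2n - 1
    k = np.arange(n)
    w_k = np.exp(-1j * np.pi * k**2 / n)         # chirp applied to input and output
    # Conjugate chirp laid out for a circular convolution of length bluestein_n
    b = np.zeros(bluestein_n, dtype=complex)
    b[:n] = np.conj(w_k)
    b[-(n - 1):] = np.conj(w_k[1:][::-1])
    w_q = np.fft.fft(b)
    conv = np.fft.ifft(np.fft.fft(w_k * x, bluestein_n) * w_q)
    return w_k * conv[:n]

x = np.random.rand(11) + 1j * np.random.rand(11)
assert np.allclose(bluestein_fft(x), np.fft.fft(x))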
|
||||
|
||||
template <
|
||||
int tg_mem_size,
|
||||
typename in_T,
|
||||
typename out_T,
|
||||
int step,
|
||||
bool real = false>
|
||||
[[kernel]] void four_step_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
constant const int& n1,
|
||||
constant const int& n2,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Fast four step FFT implementation for powers of 2.
|
||||
int overall_n = n1 * n2;
|
||||
int n = step == 0 ? n1 : n2;
|
||||
int stride = step == 0 ? n2 : n1;
|
||||
|
||||
// The number of threads we're using for each DFT
|
||||
int m = grid.z;
|
||||
int fft_idx = elem.z;
|
||||
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
threadgroup float2* buf = &shared_in[elem.y * n];
|
||||
|
||||
using read_writer_t = ReadWriter<in_T, out_T, step, real>;
|
||||
read_writer_t read_writer = read_writer_t(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load_strided(stride, overall_n);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_strided(stride, overall_n);
|
||||
}
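A hedged numpy sketch of the decomposition the two steps implement: step 0 FFTs the columns of the row-major (n1, n2) view and applies the inter-step twiddles on the way out, and step 1 FFTs the rows and writes the result back transposed. The reshape and explicit transpose below are for clarity only and do not mirror the kernel's strided indexing.

import numpy as np

def four_step_fft(x, n1, n2):
    n = n1 * n2
    x = x.reshape(n1, n2)
    # Step 0: length-n1 FFTs over columns (stride n2 in the flat array)
    x = np.fft.fft(x, axis=0)
    # Inter-step twiddles: exp(-2*pi*i * j * k / n) for element (j, k)
    j, k = np.meshgrid(np.arange(n1), np.arange(n2), indexing="ij")
    x = x * np.exp(-2j * np.pi * j * k / n)
    # Step 1: length-n2 FFTs over rows, then a transposed write-out
    x = np.fft.fft(x, axis=1)
    return x.T.reshape(n)

x = np.random.rand(4096) + 1j * np.random.rand(4096)
assert np.allclose(four_step_fft(x, 64, 64), np.fft.fft(x))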
|
@ -1,199 +1,84 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
#include "mlx/backend/metal/kernels/fft.h"
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_math>
|
||||
#define instantiate_fft(tg_mem_size, in_T, out_T) \
|
||||
template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \
|
||||
"_" #out_T)]] [[kernel]] void \
|
||||
fft<tg_mem_size, in_T, out_T>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
constant const int& n, \
|
||||
constant const int& batch_size, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#define instantiate_rader(tg_mem_size, in_T, out_T) \
|
||||
template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \
|
||||
"_" #out_T)]] [[kernel]] void \
|
||||
rader_fft<tg_mem_size, in_T, out_T>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
const device float2* raders_b_q [[buffer(2)]], \
|
||||
const device short* raders_g_q [[buffer(3)]], \
|
||||
const device short* raders_g_minus_q [[buffer(4)]], \
|
||||
constant const int& n, \
|
||||
constant const int& batch_size, \
|
||||
constant const int& rader_n, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
using namespace metal;
|
||||
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
|
||||
template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \
|
||||
"_" #out_T)]] [[kernel]] void \
|
||||
bluestein_fft<tg_mem_size, in_T, out_T>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
const device float2* w_q [[buffer(2)]], \
|
||||
const device float2* w_k [[buffer(3)]], \
|
||||
constant const int& length, \
|
||||
constant const int& n, \
|
||||
constant const int& batch_size, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
float2 complex_mul(float2 a, float2 b) {
|
||||
float2 c;
|
||||
c.x = a.x * b.x - a.y * b.y;
|
||||
c.y = a.x * b.y + a.y * b.x;
|
||||
return c;
|
||||
}
|
||||
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
|
||||
template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \
|
||||
"_" #step "_" #real)]] [[kernel]] void \
|
||||
four_step_fft<tg_mem_size, in_T, out_T, step, real>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
constant const int& n1, \
|
||||
constant const int& n2, \
|
||||
constant const int& batch_size, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
float2 get_twiddle(int k, int p) {
|
||||
float theta = -1.0f * k * M_PI_F / (2 * p);
|
||||
|
||||
float2 twiddle;
|
||||
twiddle.x = metal::fast::cos(theta);
|
||||
twiddle.y = metal::fast::sin(theta);
|
||||
return twiddle;
|
||||
}
|
||||
|
||||
// single threaded radix2 implementation
|
||||
void radix2(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
|
||||
float2 z = complex_mul(x_1, twiddle);
|
||||
|
||||
float2 y_0 = x_0 + z;
|
||||
float2 y_1 = x_0 - z;
|
||||
|
||||
int j = (i << 1) - k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
}
|
||||
|
||||
// single threaded radix4 implementation
|
||||
void radix4(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
float2 x_2 = read_buf[i + 2 * m];
|
||||
float2 x_3 = read_buf[i + 3 * m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
// e^a * e^b = e^(a + b)
|
||||
float2 twiddle_2 = complex_mul(twiddle, twiddle);
|
||||
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
|
||||
|
||||
x_1 = complex_mul(x_1, twiddle);
|
||||
x_2 = complex_mul(x_2, twiddle_2);
|
||||
x_3 = complex_mul(x_3, twiddle_3);
|
||||
|
||||
float2 minus_i;
|
||||
minus_i.x = 0;
|
||||
minus_i.y = -1;
|
||||
|
||||
// Hard coded twiddle factors for DFT4
|
||||
float2 z_0 = x_0 + x_2;
|
||||
float2 z_1 = x_0 - x_2;
|
||||
float2 z_2 = x_1 + x_3;
|
||||
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
|
||||
|
||||
float2 y_0 = z_0 + z_2;
|
||||
float2 y_1 = z_1 + z_3;
|
||||
float2 y_2 = z_0 - z_2;
|
||||
float2 y_3 = z_1 - z_3;
|
||||
|
||||
int j = ((i - k) << 2) + k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
write_buf[j + 2 * p] = y_2;
|
||||
write_buf[j + 3 * p] = y_3;
|
||||
}
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
//
|
||||
// At each step we use n / 4 threads, each performing
|
||||
// a single-threaded radix-4 or radix-2 DFT.
|
||||
//
|
||||
// We provide the number of radix-2 and radix-4
|
||||
// steps at compile time for a ~20% performance boost.
|
||||
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
[[kernel]] void fft(
|
||||
const device float2* in [[buffer(0)]],
|
||||
device float2* out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
||||
uint3 threads_per_grid [[threads_per_grid]]) {
|
||||
// Index of the DFT in batch
|
||||
int batch_idx = thread_position_in_grid.x * n;
|
||||
// The index in the DFT we're working on
|
||||
int i = thread_position_in_grid.y;
|
||||
// The number of threads we're using for each DFT
|
||||
int m = threads_per_grid.y;
|
||||
|
||||
// Allocate 2 shared memory buffers for Stockham.
|
||||
// We alternate reading from one and writing to the other at each radix step.
|
||||
threadgroup float2 shared_in[n];
|
||||
threadgroup float2 shared_out[n];
|
||||
|
||||
// Pointers to facilitate Stockham buffer swapping
|
||||
threadgroup float2* read_buf = shared_in;
|
||||
threadgroup float2* write_buf = shared_out;
|
||||
threadgroup float2* tmp;
|
||||
|
||||
// Copy input into shared memory
|
||||
shared_in[i] = in[batch_idx + i];
|
||||
shared_in[i + m] = in[batch_idx + i + m];
|
||||
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
|
||||
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
|
||||
for (size_t r = 0; r < radix_2_steps; r++) {
|
||||
radix2(i, p, m * 2, read_buf, write_buf);
|
||||
radix2(i + m, p, m * 2, read_buf, write_buf);
|
||||
p *= 2;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
for (size_t r = 0; r < radix_4_steps; r++) {
|
||||
radix4(i, p, m, read_buf, write_buf);
|
||||
p *= 4;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
// Copy shared memory to output
|
||||
out[batch_idx + i] = read_buf[i];
|
||||
out[batch_idx + i + m] = read_buf[i + m];
|
||||
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
|
||||
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
|
||||
}
|
||||
|
||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||
template [[host_name("fft_" #name)]] [[kernel]] void \
|
||||
fft<n, radix_2_steps, radix_4_steps>( \
|
||||
const device float2* in [[buffer(0)]], \
|
||||
device float2* out [[buffer(1)]], \
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||
uint3 threads_per_grid [[threads_per_grid]]);
|
||||
|
||||
// Explicitly define kernels for each power of 2.
|
||||
// clang-format off
|
||||
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
||||
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
|
||||
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
|
||||
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
|
||||
instantiate_fft(512, 512, 1, 4)
|
||||
instantiate_fft(1024, 1024, 0, 5)
|
||||
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
||||
// TODO: implement 4 step FFT for larger n.
|
||||
instantiate_fft(2048, 2048, 1, 5) // clang-format on
|
||||
#define instantiate_ffts(tg_mem_size) \
|
||||
instantiate_fft(tg_mem_size, float2, float2) \
|
||||
instantiate_fft(tg_mem_size, float, float2) \
|
||||
instantiate_fft(tg_mem_size, float2, float) \
|
||||
instantiate_rader(tg_mem_size, float2, float2) \
|
||||
instantiate_rader(tg_mem_size, float, float2) \
|
||||
instantiate_rader(tg_mem_size, float2, float) \
|
||||
instantiate_bluestein(tg_mem_size, float2, float2) \
|
||||
instantiate_bluestein(tg_mem_size, float, float2) \
|
||||
instantiate_bluestein(tg_mem_size, float2, float) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \
|
||||
instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)
|
||||
|
||||
// It's substantially faster to statically define the
|
||||
// threadgroup memory size rather than using
|
||||
// `setThreadgroupMemoryLength` on the compute encoder.
|
||||
// For non-power of 2 sizes we round up the shared memory.
|
||||
instantiate_ffts(256)
|
||||
instantiate_ffts(512)
|
||||
instantiate_ffts(1024)
|
||||
instantiate_ffts(2048)
|
||||
// 4096 is the max that will fit into 32KB of threadgroup memory.
|
||||
instantiate_ffts(4096) // clang-format on
|
||||
|
328
mlx/backend/metal/kernels/fft/radix.h
Normal file
@ -0,0 +1,328 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
/* Radix kernels
|
||||
|
||||
We provide optimized, single threaded Radix codelets
|
||||
for n=2,3,4,5,6,7,8,10,11,12,13.
|
||||
|
||||
For n=2,3,4,5,6 we hand write the codelets.
|
||||
For n=8,10,12 we combine smaller codelets.
|
||||
For n=7,11,13 we use Rader's algorithm which decomposes
|
||||
them into (n-1)=6,10,12 codelets. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_math>
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
|
||||
return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
|
||||
}
|
||||
|
||||
// Complex mul followed by conjugate
|
||||
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
|
||||
return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
|
||||
}
|
||||
|
||||
// Compute an FFT twiddle factor
|
||||
METAL_FUNC float2 get_twiddle(int k, int p) {
|
||||
float theta = -2.0f * k * M_PI_F / p;
|
||||
|
||||
float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
|
||||
return twiddle;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
|
||||
y[0] = x[0] + x[1];
|
||||
y[1] = x[0] - x[1];
|
||||
}
|
||||
|
||||
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
|
||||
float pi_2_3 = -0.8660254037844387;
|
||||
|
||||
float2 a_1 = x[1] + x[2];
|
||||
float2 a_2 = x[1] - x[2];
|
||||
|
||||
y[0] = x[0] + a_1;
|
||||
float2 b_1 = x[0] - 0.5 * a_1;
|
||||
float2 b_2 = pi_2_3 * a_2;
|
||||
|
||||
float2 b_2_j = {-b_2.y, b_2.x};
|
||||
y[1] = b_1 + b_2_j;
|
||||
y[2] = b_1 - b_2_j;
|
||||
}
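A quick numpy check of the constants above (pi_2_3 is -sin(2*pi/3) = -sqrt(3)/2, and b_2_j is b_2 rotated by j):

import numpy as np

x = np.random.rand(3) + 1j * np.random.rand(3)
a1, a2 = x[1] + x[2], x[1] - x[2]
b1 = x[0] - 0.5 * a1
b2j = 1j * (-np.sin(2 * np.pi / 3) * a2)
y = np.array([x[0] + a1, b1 + b2j, b1 - b2j])
assert np.allclose(y, np.fft.fft(x))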
|
||||
|
||||
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
|
||||
float2 z_0 = x[0] + x[2];
|
||||
float2 z_1 = x[0] - x[2];
|
||||
float2 z_2 = x[1] + x[3];
|
||||
float2 z_3 = x[1] - x[3];
|
||||
float2 z_3_i = {z_3.y, -z_3.x};
|
||||
|
||||
y[0] = z_0 + z_2;
|
||||
y[1] = z_1 + z_3_i;
|
||||
y[2] = z_0 - z_2;
|
||||
y[3] = z_1 - z_3_i;
|
||||
}
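The same butterfly in numpy, as a sanity check that the hard-coded +/-1 and -j twiddles reproduce a length-4 DFT:

import numpy as np

x = np.random.rand(4) + 1j * np.random.rand(4)
z0, z1 = x[0] + x[2], x[0] - x[2]
z2, z3 = x[1] + x[3], -1j * (x[1] - x[3])
y = np.array([z0 + z2, z1 + z3, z0 - z2, z1 - z3])
assert np.allclose(y, np.fft.fft(x))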
|
||||
|
||||
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
|
||||
float2 root_5_4 = 0.5590169943749475;
|
||||
float2 sin_2pi_5 = 0.9510565162951535;
|
||||
float2 sin_1pi_5 = 0.5877852522924731;
|
||||
|
||||
float2 a_1 = x[1] + x[4];
|
||||
float2 a_2 = x[2] + x[3];
|
||||
float2 a_3 = x[1] - x[4];
|
||||
float2 a_4 = x[2] - x[3];
|
||||
|
||||
float2 a_5 = a_1 + a_2;
|
||||
float2 a_6 = root_5_4 * (a_1 - a_2);
|
||||
float2 a_7 = x[0] - a_5 / 4;
|
||||
float2 a_8 = a_7 + a_6;
|
||||
float2 a_9 = a_7 - a_6;
|
||||
float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
|
||||
float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
|
||||
float2 a_10_j = {a_10.y, -a_10.x};
|
||||
float2 a_11_j = {a_11.y, -a_11.x};
|
||||
|
||||
y[0] = x[0] + a_5;
|
||||
y[1] = a_8 + a_10_j;
|
||||
y[2] = a_9 + a_11_j;
|
||||
y[3] = a_9 - a_11_j;
|
||||
y[4] = a_8 - a_10_j;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
|
||||
float sin_pi_3 = 0.8660254037844387;
|
||||
float2 a_1 = x[2] + x[4];
|
||||
float2 a_2 = x[0] - a_1 / 2;
|
||||
float2 a_3 = sin_pi_3 * (x[2] - x[4]);
|
||||
float2 a_4 = x[5] + x[1];
|
||||
float2 a_5 = x[3] - a_4 / 2;
|
||||
float2 a_6 = sin_pi_3 * (x[5] - x[1]);
|
||||
float2 a_7 = x[0] + a_1;
|
||||
|
||||
float2 a_3_i = {a_3.y, -a_3.x};
|
||||
float2 a_6_i = {a_6.y, -a_6.x};
|
||||
float2 a_8 = a_2 + a_3_i;
|
||||
float2 a_9 = a_2 - a_3_i;
|
||||
float2 a_10 = x[3] + a_4;
|
||||
float2 a_11 = a_5 + a_6_i;
|
||||
float2 a_12 = a_5 - a_6_i;
|
||||
|
||||
y[0] = a_7 + a_10;
|
||||
y[1] = a_8 - a_11;
|
||||
y[2] = a_9 + a_12;
|
||||
y[3] = a_7 - a_10;
|
||||
y[4] = a_8 + a_11;
|
||||
y[5] = a_9 - a_12;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
|
||||
// Rader's algorithm
|
||||
float2 inv = {1 / 6.0, -1 / 6.0};
|
||||
|
||||
// fft
|
||||
float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
|
||||
radix6(in1, y + 1);
|
||||
|
||||
y[0] = y[1] + x[0];
|
||||
|
||||
// b_q
|
||||
y[1] = complex_mul_conj(y[1], float2(-1, 0));
|
||||
y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
|
||||
y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
|
||||
y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
|
||||
y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
|
||||
y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));
|
||||
|
||||
// ifft
|
||||
radix6(y + 1, x + 1);
|
||||
|
||||
y[1] = x[1] * inv + x[0];
|
||||
y[5] = x[2] * inv + x[0];
|
||||
y[4] = x[3] * inv + x[0];
|
||||
y[6] = x[4] * inv + x[0];
|
||||
y[2] = x[5] * inv + x[0];
|
||||
y[3] = x[6] * inv + x[0];
|
||||
}
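For reference, a small numpy sketch of Rader's reduction for n = 7, the idea radix7 hard-codes: a length-6 cyclic convolution (done here with length-6 FFTs) plus the separately handled DC term. The primitive-root bookkeeping is spelled out below; the codelet instead bakes the permutations and the FFT of the b sequence into the constants above.

import numpy as np

def rader_fft_7(x):
    n, g = 7, 3                                             # 3 is a primitive root mod 7
    g_q = np.array([pow(g, q, n) for q in range(n - 1)])            # 1, 3, 2, 6, 4, 5
    g_inv_q = np.array([pow(g, n - 1 - q, n) for q in range(n - 1)])  # g^-q mod n
    a = x[g_q]                                              # permuted non-zero inputs
    b = np.exp(-2j * np.pi * g_inv_q / n)
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))       # cyclic convolution
    out = np.empty(n, dtype=complex)
    out[0] = x.sum()
    out[g_inv_q] = x[0] + conv
    return out

x = np.random.rand(7) + 1j * np.random.rand(7)
assert np.allclose(rader_fft_7(x), np.fft.fft(x))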
|
||||
|
||||
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
|
||||
float cos_pi_4 = 0.7071067811865476;
|
||||
float2 w_0 = {cos_pi_4, -cos_pi_4};
|
||||
float2 w_1 = {-cos_pi_4, -cos_pi_4};
|
||||
float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
|
||||
radix4(temp, x);
|
||||
radix4(temp + 4, x + 4);
|
||||
|
||||
y[0] = x[0] + x[4];
|
||||
y[4] = x[0] - x[4];
|
||||
float2 x_5 = complex_mul(x[5], w_0);
|
||||
y[1] = x[1] + x_5;
|
||||
y[5] = x[1] - x_5;
|
||||
float2 x_6 = {x[6].y, -x[6].x};
|
||||
y[2] = x[2] + x_6;
|
||||
y[6] = x[2] - x_6;
|
||||
float2 x_7 = complex_mul(x[7], w_1);
|
||||
y[3] = x[3] + x_7;
|
||||
y[7] = x[3] - x_7;
|
||||
}
|
||||
|
||||
template <bool raders_perm>
|
||||
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
|
||||
float2 w[4];
|
||||
w[0] = {0.8090169943749475, -0.5877852522924731};
|
||||
w[1] = {0.30901699437494745, -0.9510565162951535};
|
||||
w[2] = {-w[1].x, w[1].y};
|
||||
w[3] = {-w[0].x, w[0].y};
|
||||
|
||||
if (raders_perm) {
|
||||
float2 temp[10] = {
|
||||
x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
|
||||
radix5(temp, x);
|
||||
radix5(temp + 5, x + 5);
|
||||
} else {
|
||||
float2 temp[10] = {
|
||||
x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
|
||||
radix5(temp, x);
|
||||
radix5(temp + 5, x + 5);
|
||||
}
|
||||
|
||||
y[0] = x[0] + x[5];
|
||||
y[5] = x[0] - x[5];
|
||||
for (int t = 1; t < 5; t++) {
|
||||
float2 a = complex_mul(x[t + 5], w[t - 1]);
|
||||
y[t] = x[t] + a;
|
||||
y[t + 5] = x[t] - a;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
|
||||
// Rader's algorithm
|
||||
float2 inv = {1 / 10.0, -1 / 10.0};
|
||||
|
||||
// fft
|
||||
radix10<true>(x + 1, y + 1);
|
||||
|
||||
y[0] = y[1] + x[0];
|
||||
|
||||
// b_q
|
||||
y[1] = complex_mul_conj(y[1], float2(-1, 0));
|
||||
y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
|
||||
y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
|
||||
y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
|
||||
y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
|
||||
y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
|
||||
y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
|
||||
y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
|
||||
y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
|
||||
y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));
|
||||
|
||||
// ifft
|
||||
radix10<false>(y + 1, x + 1);
|
||||
|
||||
y[1] = x[1] * inv + x[0];
|
||||
y[6] = x[2] * inv + x[0];
|
||||
y[3] = x[3] * inv + x[0];
|
||||
y[7] = x[4] * inv + x[0];
|
||||
y[9] = x[5] * inv + x[0];
|
||||
y[10] = x[6] * inv + x[0];
|
||||
y[5] = x[7] * inv + x[0];
|
||||
y[8] = x[8] * inv + x[0];
|
||||
y[4] = x[9] * inv + x[0];
|
||||
y[2] = x[10] * inv + x[0];
|
||||
}
|
||||
|
||||
template <bool raders_perm>
|
||||
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
|
||||
float2 w[6];
|
||||
float sin_pi_3 = 0.8660254037844387;
|
||||
w[0] = {sin_pi_3, -0.5};
|
||||
w[1] = {0.5, -sin_pi_3};
|
||||
w[2] = {0, -1};
|
||||
w[3] = {-0.5, -sin_pi_3};
|
||||
w[4] = {-sin_pi_3, -0.5};
|
||||
|
||||
if (raders_perm) {
|
||||
float2 temp[12] = {
|
||||
x[0],
|
||||
x[3],
|
||||
x[2],
|
||||
x[11],
|
||||
x[8],
|
||||
x[9],
|
||||
x[1],
|
||||
x[7],
|
||||
x[5],
|
||||
x[10],
|
||||
x[4],
|
||||
x[6]};
|
||||
radix6(temp, x);
|
||||
radix6(temp + 6, x + 6);
|
||||
} else {
|
||||
float2 temp[12] = {
|
||||
x[0],
|
||||
x[2],
|
||||
x[4],
|
||||
x[6],
|
||||
x[8],
|
||||
x[10],
|
||||
x[1],
|
||||
x[3],
|
||||
x[5],
|
||||
x[7],
|
||||
x[9],
|
||||
x[11]};
|
||||
radix6(temp, x);
|
||||
radix6(temp + 6, x + 6);
|
||||
}
|
||||
|
||||
y[0] = x[0] + x[6];
|
||||
y[6] = x[0] - x[6];
|
||||
for (int t = 1; t < 6; t++) {
|
||||
float2 a = complex_mul(x[t + 6], w[t - 1]);
|
||||
y[t] = x[t] + a;
|
||||
y[t + 6] = x[t] - a;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
|
||||
// Rader's algorithm
|
||||
float2 inv = {1 / 12.0, -1 / 12.0};
|
||||
|
||||
// fft
|
||||
radix12<true>(x + 1, y + 1);
|
||||
|
||||
y[0] = y[1] + x[0];
|
||||
|
||||
// b_q
|
||||
y[1] = complex_mul_conj(y[1], float2(-1, 0));
|
||||
y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
|
||||
y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
|
||||
y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
|
||||
y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
|
||||
y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
|
||||
y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
|
||||
y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
|
||||
y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
|
||||
y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
|
||||
y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
|
||||
y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));
|
||||
|
||||
// ifft
|
||||
radix12<false>(y + 1, x + 1);
|
||||
|
||||
y[1] = x[1] * inv + x[0];
|
||||
y[7] = x[2] * inv + x[0];
|
||||
y[10] = x[3] * inv + x[0];
|
||||
y[5] = x[4] * inv + x[0];
|
||||
y[9] = x[5] * inv + x[0];
|
||||
y[11] = x[6] * inv + x[0];
|
||||
y[12] = x[7] * inv + x[0];
|
||||
y[6] = x[8] * inv + x[0];
|
||||
y[3] = x[9] * inv + x[0];
|
||||
y[8] = x[10] * inv + x[0];
|
||||
y[4] = x[11] * inv + x[0];
|
||||
y[2] = x[12] * inv + x[0];
|
||||
}
|
622
mlx/backend/metal/kernels/fft/readwrite.h
Normal file
@ -0,0 +1,622 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fft/radix.h"
|
||||
|
||||
/* FFT helpers for reading and writing from/to device memory.
|
||||
|
||||
For many sizes, GPU FFTs are memory bandwidth bound so
|
||||
read/write performance is important.
|
||||
|
||||
Where possible, we read 128 bits sequentially in each thread,
|
||||
coalesced with accesses from adjacent threads for optimal performance.
|
||||
|
||||
We implement specialized reading/writing for:
|
||||
- FFT
|
||||
- RFFT
|
||||
- IRFFT
|
||||
|
||||
Each with support for:
|
||||
- Contiguous reads
|
||||
- Padded reads
|
||||
- Strided reads
|
||||
*/
|
||||
|
||||
#define MAX_RADIX 13
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <
|
||||
typename in_T,
|
||||
typename out_T,
|
||||
int step = 0,
|
||||
bool four_step_real = false>
|
||||
struct ReadWriter {
|
||||
const device in_T* in;
|
||||
threadgroup float2* buf;
|
||||
device out_T* out;
|
||||
int n;
|
||||
int batch_size;
|
||||
int elems_per_thread;
|
||||
uint3 elem;
|
||||
uint3 grid;
|
||||
int threads_per_tg;
|
||||
bool inv;
|
||||
|
||||
// Used for strided access
|
||||
int strided_device_idx = 0;
|
||||
int strided_shared_idx = 0;
|
||||
|
||||
METAL_FUNC ReadWriter(
|
||||
const device in_T* in_,
|
||||
threadgroup float2* buf_,
|
||||
device out_T* out_,
|
||||
const short n_,
|
||||
const int batch_size_,
|
||||
const short elems_per_thread_,
|
||||
const uint3 elem_,
|
||||
const uint3 grid_,
|
||||
const bool inv_)
|
||||
: in(in_),
|
||||
buf(buf_),
|
||||
out(out_),
|
||||
n(n_),
|
||||
batch_size(batch_size_),
|
||||
elems_per_thread(elems_per_thread_),
|
||||
elem(elem_),
|
||||
grid(grid_),
|
||||
inv(inv_) {
|
||||
// Account for padding on last threadgroup
|
||||
threads_per_tg = elem.x == grid.x - 1
|
||||
? (batch_size - (grid.x - 1) * grid.y) * grid.z
|
||||
: grid.y * grid.z;
|
||||
}
|
||||
|
||||
// ifft(x) = 1/n * conj(fft(conj(x)))
|
||||
METAL_FUNC float2 post_in(float2 elem) const {
|
||||
return inv ? float2(elem.x, -elem.y) : elem;
|
||||
}
|
||||
|
||||
// Handle real (float) inputs for the generic RFFT algorithm
|
||||
METAL_FUNC float2 post_in(float elem) const {
|
||||
return float2(elem, 0);
|
||||
}
|
||||
|
||||
METAL_FUNC float2 pre_out(float2 elem) const {
|
||||
return inv ? float2(elem.x / n, -elem.y / n) : elem;
|
||||
}
|
||||
|
||||
METAL_FUNC float2 pre_out(float2 elem, int length) const {
|
||||
return inv ? float2(elem.x / length, -elem.y / length) : elem;
|
||||
}
|
||||
|
||||
METAL_FUNC bool out_of_bounds() const {
|
||||
// Account for possible extra threadgroups
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
return grid_index >= batch_size;
|
||||
}
|
||||
|
||||
METAL_FUNC void load() const {
|
||||
int batch_idx = elem.x * grid.y * n;
|
||||
short tg_idx = elem.y * grid.z + elem.z;
|
||||
short max_index = grid.y * n - 2;
|
||||
|
||||
// 2 complex64s = 128 bits
|
||||
constexpr int read_width = 2;
|
||||
for (short e = 0; e < (elems_per_thread / read_width); e++) {
|
||||
short index = read_width * tg_idx + read_width * threads_per_tg * e;
|
||||
index = metal::min(index, max_index);
|
||||
// vectorized reads
|
||||
buf[index] = post_in(in[batch_idx + index]);
|
||||
buf[index + 1] = post_in(in[batch_idx + index + 1]);
|
||||
}
|
||||
max_index += 1;
|
||||
if (elems_per_thread % 2 != 0) {
|
||||
short index = tg_idx +
|
||||
read_width * threads_per_tg * (elems_per_thread / read_width);
|
||||
index = metal::min(index, max_index);
|
||||
buf[index] = post_in(in[batch_idx + index]);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void write() const {
|
||||
int batch_idx = elem.x * grid.y * n;
|
||||
short tg_idx = elem.y * grid.z + elem.z;
|
||||
short max_index = grid.y * n - 2;
|
||||
|
||||
constexpr int read_width = 2;
|
||||
for (short e = 0; e < (elems_per_thread / read_width); e++) {
|
||||
short index = read_width * tg_idx + read_width * threads_per_tg * e;
|
||||
index = metal::min(index, max_index);
|
||||
// vectorized writes
|
||||
out[batch_idx + index] = pre_out(buf[index]);
|
||||
out[batch_idx + index + 1] = pre_out(buf[index + 1]);
|
||||
}
|
||||
max_index += 1;
|
||||
if (elems_per_thread % 2 != 0) {
|
||||
short index = tg_idx +
|
||||
read_width * threads_per_tg * (elems_per_thread / read_width);
|
||||
index = metal::min(index, max_index);
|
||||
out[batch_idx + index] = pre_out(buf[index]);
|
||||
}
|
||||
}
|
||||
|
||||
// Padded IO for Bluestein's algorithm
|
||||
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
||||
int fft_idx = elem.z;
|
||||
int m = grid.z;
|
||||
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
if (index < length) {
|
||||
float2 elem = post_in(in[batch_idx + index]);
|
||||
seq_buf[index] = complex_mul(elem, w_k[index]);
|
||||
} else {
|
||||
seq_buf[index] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
||||
int fft_idx = elem.z;
|
||||
int m = grid.z;
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
if (index < length) {
|
||||
float2 elem = seq_buf[index + length - 1] * inv_factor;
|
||||
out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Strided IO for four step FFT
|
||||
METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
|
||||
// Use the batch threadgroup dimension to coalesce memory accesses:
|
||||
// e.g. stride = 12
|
||||
// device | shared mem
|
||||
// 0 1 2 3 | 0 12 - -
|
||||
// - - - - | 1 13 - -
|
||||
// - - - - | 2 14 - -
|
||||
// 12 13 14 15 | 3 15 - -
|
||||
int coalesce_width = grid.y;
|
||||
int tg_idx = elem.y * grid.z + elem.z;
|
||||
int outer_batch_size = stride / coalesce_width;
|
||||
|
||||
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
|
||||
overall_n * (elem.x / outer_batch_size);
|
||||
strided_device_idx = strided_batch_idx +
|
||||
tg_idx / coalesce_width * elems_per_thread * stride +
|
||||
tg_idx % coalesce_width;
|
||||
strided_shared_idx = (tg_idx % coalesce_width) * n +
|
||||
tg_idx / coalesce_width * elems_per_thread;
|
||||
}
|
||||
|
||||
// Four Step FFT First Step
|
||||
METAL_FUNC void load_strided(int stride, int overall_n) {
|
||||
compute_strided_indices(stride, overall_n);
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
buf[strided_shared_idx + e] =
|
||||
post_in(in[strided_device_idx + e * stride]);
|
||||
}
|
||||
}
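In numpy terms, the strided load of the first step picks out a column of the row-major (n1, n2) view of the input, i.e. a slice with stride n2 in device memory; the index math above only adds the coalescing permutation sketched in the diagram:

import numpy as np

n1, n2 = 8, 4
x = np.arange(n1 * n2)
column = x[2::n2]                      # stride = n2 selects column 2
assert np.array_equal(column, x.reshape(n1, n2)[:, 2])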
|
||||
|
||||
METAL_FUNC void write_strided(int stride, int overall_n) {
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
float2 output = buf[strided_shared_idx + e];
|
||||
int combined_idx = (strided_device_idx + e * stride) % overall_n;
|
||||
int ij = (combined_idx / stride) * (combined_idx % stride);
|
||||
// Apply four step twiddles at end of first step
|
||||
float2 twiddle = get_twiddle(ij, overall_n);
|
||||
out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
|
||||
}
|
||||
}
|
||||
};
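The post_in / pre_out pair relies on the conjugation identity noted at the top of the struct; a one-line numpy check:

import numpy as np

x = np.random.rand(16) + 1j * np.random.rand(16)
assert np.allclose(np.fft.ifft(x), np.conj(np.fft.fft(np.conj(x))) / len(x))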
|
||||
|
||||
// Four Step FFT Second Step
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
// Silence compiler warnings
|
||||
(void)stride;
|
||||
(void)overall_n;
|
||||
// Don't invert between steps
|
||||
bool default_inv = inv;
|
||||
inv = false;
|
||||
load();
|
||||
inv = default_inv;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
compute_strided_indices(stride, overall_n);
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
float2 output = buf[strided_shared_idx + e];
|
||||
out[strided_device_idx + e * stride] = pre_out(output, overall_n);
|
||||
}
|
||||
}
|
||||
|
||||
// For RFFT, we interleave batches of two real sequences into one complex one:
|
||||
//
|
||||
// z_k = x_k + j.y_k
|
||||
// X_k = (Z_k + Z_(N-k)*) / 2
|
||||
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
|
||||
//
|
||||
// This roughly doubles the throughput over the regular FFT.
|
||||
template <>
|
||||
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
// We pack two sequences into one for RFFTs
|
||||
return grid_index * 2 >= batch_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::load() const {
|
||||
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
seq_buf[index].x = in[batch_idx + index];
|
||||
seq_buf[index].y = in[batch_idx + index + next_in];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::write() const {
|
||||
short n_over_2 = (n / 2) + 1;
|
||||
|
||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 minus_j = {0, -1};
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n_over_2 - 1);
|
||||
// x_0 = z_0.real
|
||||
// y_0 = z_0.imag
|
||||
if (index == 0) {
|
||||
out[batch_idx + index] = {seq_buf[index].x, 0};
|
||||
out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
|
||||
} else {
|
||||
float2 x_k = seq_buf[index];
|
||||
float2 x_n_minus_k = seq_buf[n - index] * conj;
|
||||
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
|
||||
out[batch_idx + index + next_out] =
|
||||
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
|
||||
}
|
||||
}
|
||||
}
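A numpy check of the two-sequences-in-one-FFT packing described in the comment above, recovering both half spectra from a single complex FFT:

import numpy as np

n = 16
x, y = np.random.rand(n), np.random.rand(n)
Z = np.fft.fft(x + 1j * y)
Z_rev_conj = np.conj(np.r_[Z[0], Z[:0:-1]])   # Z_(N-k)*
X = (Z + Z_rev_conj) / 2
Y = -1j * (Z - Z_rev_conj) / 2
assert np.allclose(X[: n // 2 + 1], np.fft.rfft(x))
assert np.allclose(Y[: n // 2 + 1], np.fft.rfft(y))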
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::load_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
if (index < length) {
|
||||
float2 elem =
|
||||
float2(in[batch_idx + index], in[batch_idx + index + next_in]);
|
||||
seq_buf[index] = complex_mul(elem, w_k[index]);
|
||||
} else {
|
||||
seq_buf[index] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::write_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int length_over_2 = (length / 2) + 1;
|
||||
int batch_idx =
|
||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
|
||||
? 0
|
||||
: length_over_2;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
float2 minus_j = {0, -1};
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
|
||||
int index = metal::min(fft_idx + e * m, length_over_2 - 1);
|
||||
// x_0 = z_0.real
|
||||
// y_0 = z_0.imag
|
||||
if (index == 0) {
|
||||
float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
|
||||
out[batch_idx + index] = float2(elem.x, 0);
|
||||
out[batch_idx + index + next_out] = float2(elem.y, 0);
|
||||
} else {
|
||||
float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
|
||||
float2 x_n_minus_k = complex_mul(
|
||||
w_k[length - index], seq_buf[length - index] * inv_factor);
|
||||
x_n_minus_k *= conj;
|
||||
// The multiplication by w_k must happen before extracting the two packed spectra
|
||||
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
|
||||
out[batch_idx + index + next_out] =
|
||||
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For IRFFT, we do the opposite
|
||||
//
|
||||
// Z_k = X_k + j.Y_k
|
||||
// x_k = Re(Z_k)
|
||||
// y_k = Im(Z_k)
|
||||
template <>
|
||||
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
// We pack two sequences into one for IRFFTs
|
||||
return grid_index * 2 >= batch_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::load() const {
|
||||
short n_over_2 = (n / 2) + 1;
|
||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 plus_j = {0, 1};
|
||||
|
||||
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
|
||||
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
|
||||
float2 x = in[batch_idx + index];
|
||||
float2 y = in[batch_idx + index + next_in];
|
||||
// NumPy forces first input to be real
|
||||
bool first_val = index == 0;
|
||||
// NumPy forces last input on even irffts to be real
|
||||
bool last_val = n % 2 == 0 && index == n_over_2 - 1;
|
||||
if (first_val || last_val) {
|
||||
x = float2(x.x, 0);
|
||||
y = float2(y.x, 0);
|
||||
}
|
||||
seq_buf[index] = x + complex_mul(y, plus_j);
|
||||
seq_buf[index].y = -seq_buf[index].y;
|
||||
if (index > 0 && !last_val) {
|
||||
seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
|
||||
seq_buf[n - index].y = -seq_buf[n - index].y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::write() const {
|
||||
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
out[batch_idx + index] = seq_buf[index].x / n;
|
||||
out[batch_idx + index + next_out] = seq_buf[index].y / -n;
|
||||
}
|
||||
}
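And the inverse direction used here: two half spectra packed as Z = X + j.Y, Hermitian-extended, and inverted with one complex FFT, so the real and imaginary parts of the result are the two real sequences (shown for even n):

import numpy as np

n = 16
x, y = np.random.rand(n), np.random.rand(n)
X, Y = np.fft.rfft(x), np.fft.rfft(y)          # length n/2 + 1 half spectra
X_full = np.r_[X, np.conj(X[-2:0:-1])]         # Hermitian extension (even n)
Y_full = np.r_[Y, np.conj(Y[-2:0:-1])]
z = np.fft.ifft(X_full + 1j * Y_full)
assert np.allclose(z.real, x) and np.allclose(z.imag, y)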
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::load_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int n_over_2 = (n / 2) + 1;
|
||||
int length_over_2 = (length / 2) + 1;
|
||||
|
||||
int batch_idx =
|
||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
|
||||
? 0
|
||||
: length_over_2;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 plus_j = {0, 1};
|
||||
|
||||
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
|
||||
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
|
||||
float2 x = in[batch_idx + index];
|
||||
float2 y = in[batch_idx + index + next_in];
|
||||
if (index < length_over_2) {
|
||||
bool last_val = length % 2 == 0 && index == length_over_2 - 1;
|
||||
if (last_val) {
|
||||
x = float2(x.x, 0);
|
||||
y = float2(y.x, 0);
|
||||
}
|
||||
float2 elem1 = x + complex_mul(y, plus_j);
|
||||
seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
|
||||
if (index > 0 && !last_val) {
|
||||
float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
|
||||
seq_buf[length - index] =
|
||||
complex_mul(elem2 * conj, w_k[length - index]);
|
||||
}
|
||||
} else {
|
||||
short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
|
||||
seq_buf[pad_index] = 0;
|
||||
seq_buf[pad_index + 1] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::write_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = fft_idx + e * m;
|
||||
if (index < length) {
|
||||
float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
|
||||
out[batch_idx + index] = output.x / length;
|
||||
out[batch_idx + index + next_out] = output.y / -length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Four Step RFFT
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
// Silence compiler warnings
|
||||
(void)stride;
|
||||
(void)overall_n;
|
||||
// Don't invert between steps
|
||||
bool default_inv = inv;
|
||||
inv = false;
|
||||
load();
|
||||
inv = default_inv;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
int overall_n_over_2 = overall_n / 2 + 1;
|
||||
int coalesce_width = grid.y;
|
||||
int tg_idx = elem.y * grid.z + elem.z;
|
||||
int outer_batch_size = stride / coalesce_width;
|
||||
|
||||
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
|
||||
overall_n_over_2 * (elem.x / outer_batch_size);
|
||||
strided_device_idx = strided_batch_idx +
|
||||
tg_idx / coalesce_width * elems_per_thread / 2 * stride +
|
||||
tg_idx % coalesce_width;
|
||||
strided_shared_idx = (tg_idx % coalesce_width) * n +
|
||||
tg_idx / coalesce_width * elems_per_thread / 2;
|
||||
for (int e = 0; e < elems_per_thread / 2; e++) {
|
||||
float2 output = buf[strided_shared_idx + e];
|
||||
out[strided_device_idx + e * stride] = output;
|
||||
}
|
||||
|
||||
// Add the final (n/2 + 1)th output element (the Nyquist bin)
|
||||
if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
|
||||
out[strided_batch_idx + overall_n / 2] = buf[n / 2];
|
||||
}
|
||||
}
|
||||
|
||||
// Four Step IRFFT
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
int overall_n_over_2 = overall_n / 2 + 1;
|
||||
auto conj = float2(1, -1);
|
||||
|
||||
compute_strided_indices(stride, overall_n);
|
||||
// Translate indices in terms of N - k
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int device_idx = strided_device_idx + e * stride;
|
||||
int overall_batch = device_idx / overall_n;
|
||||
int overall_index = device_idx % overall_n;
|
||||
if (overall_index < overall_n_over_2) {
|
||||
device_idx -= overall_batch * (overall_n - overall_n_over_2);
|
||||
buf[strided_shared_idx + e] = in[device_idx] * conj;
|
||||
} else {
|
||||
int conj_idx = overall_n - overall_index;
|
||||
device_idx = overall_batch * overall_n_over_2 + conj_idx;
|
||||
buf[strided_shared_idx + e] = in[device_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
// Silence compiler warnings
|
||||
(void)stride;
|
||||
(void)overall_n;
|
||||
bool default_inv = inv;
|
||||
inv = false;
|
||||
load();
|
||||
inv = default_inv;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
compute_strided_indices(stride, overall_n);
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
out[strided_device_idx + e * stride] =
|
||||
pre_out(buf[strided_shared_idx + e], overall_n).x;
|
||||
}
|
||||
}
|
@ -191,4 +191,17 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -134,6 +134,13 @@ bool is_power_of_2(int n) {
|
||||
return ((n & (n - 1)) == 0) && n != 0;
|
||||
}
|
||||
|
||||
int next_power_of_2(int n) {
|
||||
if (is_power_of_2(n)) {
|
||||
return n;
|
||||
}
|
||||
return pow(2, std::ceil(std::log2(n)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -81,7 +81,8 @@ array fft_impl(
|
||||
if (any_greater) {
|
||||
// Pad with zeros
|
||||
auto tmp = zeros(in_shape, a.dtype(), s);
|
||||
in = scatter(tmp, std::vector<array>{}, in, std::vector<int>{}, s);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
in = slice_update(tmp, in, starts, in.shape());
|
||||
}
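The zero padding above mirrors numpy's semantics for an FFT length larger than the input, e.g.:

import numpy as np

x = np.random.rand(10).astype(np.complex64)
padded = np.zeros(16, dtype=np.complex64)
padded[:10] = x
assert np.allclose(np.fft.fft(padded), np.fft.fft(x, n=16))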
|
||||
|
||||
auto out_shape = in_shape;
|
||||
|
@ -16,96 +16,140 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fft(self):
|
||||
with mx.stream(mx.cpu):
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||
|
||||
# Check with slicing and padding
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||
# Check with slicing and padding
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||
|
||||
# Check different axes
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||
# Check different axes
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||
|
||||
# Check real fft
|
||||
a_np = np.random.rand(100).astype(np.float32)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||
# Check real fft
|
||||
a_np = np.random.rand(100).astype(np.float32)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||
|
||||
# Check real inverse
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
|
||||
# Check real inverse
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
|
||||
x = np.fft.rfft(a_np)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
|
||||
|
||||
def test_fftn(self):
|
||||
with mx.stream(mx.cpu):
|
||||
r = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
i = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
a = r + 1j * i
|
||||
r = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
i = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
a = r + 1j * i
|
||||
|
||||
axes = [None, (1, 2), (2, 1), (0, 2)]
|
||||
shapes = [None, (10, 5), (5, 10)]
|
||||
ops = [
|
||||
"fft2",
|
||||
"ifft2",
|
||||
"rfft2",
|
||||
"irfft2",
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
axes = [None, (1, 2), (2, 1), (0, 2)]
|
||||
shapes = [None, (10, 5), (5, 10)]
|
||||
ops = [
|
||||
"fft2",
|
||||
"ifft2",
|
||||
"rfft2",
|
||||
"irfft2",
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
]
|
||||
|
||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
elif op == "irfft2":
|
||||
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
|
||||
elif op == "irfftn":
|
||||
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
|
||||
def _run_ffts(self, shape, atol=1e-4, rtol=1e-4):
|
||||
np.random.seed(9)
|
||||
|
||||
r = np.random.rand(*shape).astype(np.float32)
|
||||
i = np.random.rand(*shape).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, atol=atol, rtol=rtol)
|
||||
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
|
||||
|
||||
ia_np = np.fft.rfft(a_np)
|
||||
self.check_mx_np(
|
||||
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
|
||||
)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fft_shared_mem(self):
|
||||
nums = np.concatenate(
|
||||
[
|
||||
# small radix
|
||||
np.arange(2, 14),
|
||||
# powers of 2
|
||||
[2**k for k in range(4, 13)],
|
||||
# stockham
|
||||
[3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11],
|
||||
# rader
|
||||
[17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982],
|
||||
# bluestein
|
||||
[47, 83, 17 * 17],
|
||||
# large stockham
|
||||
[3159, 3645, 3969, 4004],
|
||||
]
|
||||
)
|
||||
for batch_size in (1, 3, 32):
|
||||
for num in nums:
|
||||
atol = 1e-4 if num < 1025 else 1e-3
|
||||
self._run_ffts((batch_size, num), atol=atol)
|
||||
|
||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
@unittest.skip("Too slow for CI but useful for local testing.")
|
||||
def test_fft_exhaustive(self):
|
||||
nums = range(2, 4097)
|
||||
for batch_size in (1, 3, 32):
|
||||
for num in nums:
|
||||
print(num)
|
||||
atol = 1e-4 if num < 1025 else 1e-3
|
||||
self._run_ffts((batch_size, num), atol=atol)
|
||||
|
||||
def test_fft_powers_of_two(self):
|
||||
shape = (16, 4, 8)
|
||||
# np.fft.fft always uses double precision complex128
|
||||
# mx.fft.fft only supports single precision complex64
|
||||
# hence the fairly tolerant equality checks.
|
||||
atol = 1e-4
|
||||
rtol = 1e-4
|
||||
np.random.seed(7)
|
||||
for k in range(4, 12):
|
||||
r = np.random.rand(*shape, 2**k).astype(np.float32)
|
||||
i = np.random.rand(*shape, 2**k).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
|
||||
def test_fft_big_powers_of_two(self):
|
||||
# TODO: improve precision on big powers of two on GPU
|
||||
for k in range(12, 17):
|
||||
self._run_ffts((3, 2**k), atol=1e-3)
|
||||
|
||||
r = np.random.rand(*shape, 32).astype(np.float32)
|
||||
i = np.random.rand(*shape, 32).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
for axis in range(4):
|
||||
self.check_mx_np(
|
||||
mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
|
||||
)
|
||||
for k in range(17, 20):
|
||||
self._run_ffts((3, 2**k), atol=1e-2)
|
||||
|
||||
r = np.random.rand(4, 8).astype(np.float32)
|
||||
i = np.random.rand(4, 8).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
a_mx = mx.array(a_np)
|
||||
def test_fft_large_numbers(self):
|
||||
numbers = [
|
||||
1037, # prime > 2048
|
||||
18247, # medium size prime factors
|
||||
1259 * 11, # large prime factors
|
||||
7883, # large prime
|
||||
3**8, # large stockham decomposable
|
||||
3109, # bluestein
|
||||
4006, # large rader
|
||||
]
|
||||
for large_num in numbers:
|
||||
self._run_ffts((1, large_num), atol=1e-3)
|
||||
|
||||
def test_fft_contiguity(self):
|
||||
r = np.random.rand(4, 8).astype(np.float32)
|
||||
|
@ -7,8 +7,6 @@
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test fft basics") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
array x(1.0);
|
||||
CHECK_THROWS(fft::fft(x));
|
||||
CHECK_THROWS(fft::ifft(x));
|
||||
@ -94,13 +92,9 @@ TEST_CASE("test fft basics") {
|
||||
CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
|
||||
CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
|
||||
}
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test real ffts") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto x = array({1.0});
|
||||
auto y = fft::rfft(x);
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
@ -124,14 +118,9 @@ TEST_CASE("test real ffts") {
|
||||
CHECK_EQ(y.size(), 2);
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fftn") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto x = zeros({5, 5, 5});
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
|
||||
@ -204,8 +193,6 @@ TEST_CASE("test fftn") {
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
}
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fft with provided shape") {
|
||||
@ -234,9 +221,6 @@ TEST_CASE("test fft with provided shape") {
|
||||
}
|
||||
|
||||
TEST_CASE("test fft vmap") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto x = reshape(arange(8), {2, 4});
|
||||
auto y = vmap(fft_fn)(x);
|
||||
@ -252,14 +236,9 @@ TEST_CASE("test fft vmap") {
|
||||
|
||||
y = vmap(rfft_fn, 1, 1)(x);
|
||||
CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|
||||
TEST_CASE("test fft grads") {
|
||||
auto device = default_device();
|
||||
set_default_device(Device::cpu);
|
||||
|
||||
// Regular
|
||||
auto fft_fn = [](array x) { return fft::fft(x); };
|
||||
auto cotangent = astype(arange(10), complex64);
|
||||
@ -328,6 +307,4 @@ TEST_CASE("test fft grads") {
|
||||
zeros({5, 8}))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
|
||||
set_default_device(device);
|
||||
}
|
||||
|