From 2e7c02d5cdb173c777e42128c1590e7d86dc9a55 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Fri, 12 Apr 2024 05:40:06 +0100
Subject: [PATCH] Metal FFT for powers of 2 up to 2048 (#915)

* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron
Co-authored-by: Awni Hannun
Co-authored-by: Awni Hannun
---
 benchmarks/python/fft_bench.py           |  57 ++++++++
 mlx/backend/metal/fft.cpp                |  97 +++++++++++-
 mlx/backend/metal/kernels/CMakeLists.txt |   1 +
 mlx/backend/metal/kernels/fft.metal      | 195 ++++++++++++++++++++++
 mlx/backend/metal/utils.h                |   4 +
 python/tests/test_fft.py                 | 103 ++++++++-----
 6 files changed, 426 insertions(+), 31 deletions(-)
 create mode 100644 benchmarks/python/fft_bench.py
 create mode 100644 mlx/backend/metal/kernels/fft.metal

diff --git a/benchmarks/python/fft_bench.py b/benchmarks/python/fft_bench.py
new file mode 100644
index 000000000..391a28aec
--- /dev/null
+++ b/benchmarks/python/fft_bench.py
@@ -0,0 +1,57 @@
+# Copyright © 2024 Apple Inc.
+
+import matplotlib
+import mlx.core as mx
+import numpy as np
+from time_utils import measure_runtime
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+
+def bandwidth_gb(runtime_ms, system_size):
+    bytes_per_fft = np.dtype(np.complex64).itemsize * 2
+    bytes_per_gb = 1e9
+    ms_per_s = 1e3
+    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
+
+
+def run_bench(system_size):
+    def fft(x):
+        out = mx.fft.fft(x)
+        mx.eval(out)
+        return out
+
+    bandwidths = []
+    for k in range(4, 12):
+        n = 2**k
+        x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
+        x = x.astype(mx.complex64)
+        mx.eval(x)
+        runtime_ms = measure_runtime(fft, x=x)
+        bandwidths.append(bandwidth_gb(runtime_ms, system_size))
+
+    return bandwidths
+
+
+def time_fft():
+
+    with mx.stream(mx.cpu):
+        cpu_bandwidths = run_bench(system_size=int(2**22))
+
+    with mx.stream(mx.gpu):
+        gpu_bandwidths = run_bench(system_size=int(2**29))
+
+    # plot bandwidths
+    x = [2**k for k in range(4, 12)]
+    plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
+    plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
+    plt.title("MLX FFT Benchmark")
+    plt.xlabel("N")
+    plt.ylabel("Bandwidth (GB/s)")
+    plt.legend()
+    plt.savefig("fft_plot.png")
+
+
+if __name__ == "__main__":
+    time_fft()
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index c0844d374..3c851d23f 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -1,12 +1,105 @@
 // Copyright © 2023 Apple Inc.
-
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/mlx.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
 
 void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+
   auto& in = inputs[0];
-  throw std::runtime_error("[FFT] NYI for Metal backend.");
+
+  if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
+      in.dtype() != complex64 || out.dtype() != complex64) {
+    // Could also fall back to the CPU implementation here.
+    throw std::runtime_error(
+        "GPU FFT is only implemented for 1D, forward, complex FFTs.");
+  }
+
+  size_t n = in.shape(axes_[0]);
+
+  if (!is_power_of_2(n) || n > 2048 || n < 4) {
+    throw std::runtime_error(
+        "GPU FFT is only implemented for powers of 2 from 4 to 2048");
+  }
+
+  // Make sure that the array is contiguous and has stride 1 in the FFT dim
+  std::vector<array> copies;
+  auto check_input = [this, &copies, &s](const array& x) {
+    // TODO: Pass the strides to the kernel so
+    // we can avoid the copy when x is not contiguous.
+    bool no_copy = (x.strides()[axes_[0]] == 1 && x.flags().row_contiguous) ||
+        x.flags().col_contiguous;
+    if (no_copy) {
+      return x;
+    } else {
+      array x_copy(x.shape(), x.dtype(), nullptr, {});
+      std::vector<size_t> strides;
+      size_t cur_stride = x.shape(axes_[0]);
+      for (int axis = 0; axis < x.ndim(); axis++) {
+        if (axis == axes_[0]) {
+          strides.push_back(1);
+        } else {
+          strides.push_back(cur_stride);
+          cur_stride *= x.shape(axis);
+        }
+      }
+
+      auto flags = x.flags();
+      size_t f_stride = 1;
+      size_t b_stride = 1;
+      flags.col_contiguous = true;
+      flags.row_contiguous = true;
+      for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
+        flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
+        f_stride *= x.shape(i);
+        flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
+        b_stride *= x.shape(ri);
+      }
+      // This is probably over-conservative
+      flags.contiguous = false;
+
+      x_copy.set_data(
+          allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
+      copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
+      copies.push_back(x_copy);
+      return x_copy;
+    }
+  };
+  const array& in_contiguous = check_input(inputs[0]);
+
+  // TODO: allow donation here
+  out.set_data(
+      allocator::malloc_or_wait(out.nbytes()),
+      in_contiguous.data_size(),
+      in_contiguous.strides(),
+      in_contiguous.flags());
+
+  // We use n / 4 threads by default since radix-4
+  // is the largest single-threaded radix butterfly
+  // we currently implement.
+  size_t m = n / 4;
+  size_t batch = in.size() / in.shape(axes_[0]);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  {
+    std::ostringstream kname;
+    kname << "fft_" << n;
+    auto kernel = d.get_kernel(kname.str());
+
+    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_input_array(in_contiguous, 0);
+    compute_encoder.set_output_array(out, 1);
+
+    auto group_dims = MTL::Size(1, m, 1);
+    auto grid_dims = MTL::Size(batch, m, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index e8ca1356c..2010cb85a 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -21,6 +21,7 @@ set(
   "binary_two"
   "conv"
   "copy"
+  "fft"
   "gemv"
   "quantized"
   "random"
diff --git a/mlx/backend/metal/kernels/fft.metal b/mlx/backend/metal/kernels/fft.metal
new file mode 100644
index 000000000..25ceaab18
--- /dev/null
+++ b/mlx/backend/metal/kernels/fft.metal
@@ -0,0 +1,195 @@
+// Copyright © 2024 Apple Inc.
+
+// Metal FFT using Stockham's algorithm
+//
+// References:
+// - VkFFT (https://github.com/DTolm/VkFFT)
+// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
+
+#include <metal_common>
+#include <metal_math>
+
+
+#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+using namespace metal;
+
+float2 complex_mul(float2 a, float2 b) {
+  float2 c;
+  c.x = a.x * b.x - a.y * b.y;
+  c.y = a.x * b.y + a.y * b.x;
+  return c;
+}
+
+float2 get_twiddle(int k, int p) {
+  float theta = -1.0f * k * M_PI_F / (2 * p);
+
+  float2 twiddle;
+  twiddle.x = metal::fast::cos(theta);
+  twiddle.y = metal::fast::sin(theta);
+  return twiddle;
+}
+
+// single-threaded radix2 implementation
+void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
+  float2 x_0 = read_buf[i];
+  float2 x_1 = read_buf[i + m];
+
+  // The index within this sub-DFT
+  int k = i & (p - 1);
+
+  float2 twiddle = get_twiddle(k, p);
+
+  float2 z = complex_mul(x_1, twiddle);
+
+  float2 y_0 = x_0 + z;
+  float2 y_1 = x_0 - z;
+
+  int j = (i << 1) - k;
+
+  write_buf[j] = y_0;
+  write_buf[j + p] = y_1;
+}
+
+// single-threaded radix4 implementation
+void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
+  float2 x_0 = read_buf[i];
+  float2 x_1 = read_buf[i + m];
+  float2 x_2 = read_buf[i + 2 * m];
+  float2 x_3 = read_buf[i + 3 * m];
+
+  // The index within this sub-DFT
+  int k = i & (p - 1);
+
+  float2 twiddle = get_twiddle(k, p);
+  // e^a * e^b = e^(a + b)
+  float2 twiddle_2 = complex_mul(twiddle, twiddle);
+  float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
+
+  x_1 = complex_mul(x_1, twiddle);
+  x_2 = complex_mul(x_2, twiddle_2);
+  x_3 = complex_mul(x_3, twiddle_3);
+
+  float2 minus_i;
+  minus_i.x = 0;
+  minus_i.y = -1;
+
+  // Hard coded twiddle factors for DFT4
+  float2 z_0 = x_0 + x_2;
+  float2 z_1 = x_0 - x_2;
+  float2 z_2 = x_1 + x_3;
+  float2 z_3 = complex_mul(x_1 - x_3, minus_i);
+
+  float2 y_0 = z_0 + z_2;
+  float2 y_1 = z_1 + z_3;
+  float2 y_2 = z_0 - z_2;
+  float2 y_3 = z_1 - z_3;
+
+  int j = ((i - k) << 2) + k;
+
+  write_buf[j] = y_0;
+  write_buf[j + p] = y_1;
+  write_buf[j + 2 * p] = y_2;
+  write_buf[j + 3 * p] = y_3;
+}
+
+
+// Each FFT is computed entirely in shared GPU memory.
+//
+// N is decomposed into radix-2 and radix-4 DFTs:
+// e.g. 128 = 2 * 4 * 4 * 4
+//
+// At each step we use n / 4 threads, each performing
+// a single-threaded radix-4 or radix-2 DFT.
+//
+// We provide the number of radix-2 and radix-4
+// steps at compile time for a ~20% performance boost.
+template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
+[[kernel]] void fft(
+    const device float2* in [[buffer(0)]],
+    device float2* out [[buffer(1)]],
+    uint3 thread_position_in_grid [[thread_position_in_grid]],
+    uint3 threads_per_grid [[threads_per_grid]]) {
+
+  // Index of the DFT in batch
+  int batch_idx = thread_position_in_grid.x * n;
+  // The index in the DFT we're working on
+  int i = thread_position_in_grid.y;
+  // The number of threads we're using for each DFT
+  int m = threads_per_grid.y;
+
+  // Allocate 2 shared memory buffers for Stockham.
+  // We alternate reading from one and writing to the other at each radix step.
+  threadgroup float2 shared_in[n];
+  threadgroup float2 shared_out[n];
+
+  // Pointers to facilitate Stockham buffer swapping
+  threadgroup float2* read_buf = shared_in;
+  threadgroup float2* write_buf = shared_out;
+  threadgroup float2* tmp;
+
+  // Copy input into shared memory
+  shared_in[i] = in[batch_idx + i];
+  shared_in[i + m] = in[batch_idx + i + m];
+  shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
+  shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  int p = 1;
+
+  for (size_t r = 0; r < radix_2_steps; r++) {
+    radix2(i, p, m * 2, read_buf, write_buf);
+    radix2(i + m, p, m * 2, read_buf, write_buf);
+    p *= 2;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Stockham switch of buffers
+    tmp = write_buf;
+    write_buf = read_buf;
+    read_buf = tmp;
+  }
+
+  for (size_t r = 0; r < radix_4_steps; r++) {
+    radix4(i, p, m, read_buf, write_buf);
+    p *= 4;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Stockham switch of buffers
+    tmp = write_buf;
+    write_buf = read_buf;
+    read_buf = tmp;
+  }
+
+  // Copy shared memory to output
+  out[batch_idx + i] = read_buf[i];
+  out[batch_idx + i + m] = read_buf[i + m];
+  out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
+  out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
+}
+
+#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
+  template [[host_name("fft_" #name)]] \
+  [[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
+      const device float2* in [[buffer(0)]], \
+      device float2* out [[buffer(1)]], \
+      uint3 thread_position_in_grid [[thread_position_in_grid]], \
+      uint3 threads_per_grid [[threads_per_grid]]);
+
+
+// Explicitly define kernels for each power of 2.
+instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
+instantiate_fft(8, 8, 1, 1)
+instantiate_fft(16, 16, 0, 2)
+instantiate_fft(32, 32, 1, 2)
+instantiate_fft(64, 64, 0, 3)
+instantiate_fft(128, 128, 1, 3)
+instantiate_fft(256, 256, 0, 4)
+instantiate_fft(512, 512, 1, 4)
+instantiate_fft(1024, 1024, 0, 5)
+// 2048 is the max that will fit into 32KB of threadgroup memory.
+// TODO: implement 4 step FFT for larger n.
+instantiate_fft(2048, 2048, 1, 5)
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index 0ec315dd5..9c4f3e921 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -130,6 +130,10 @@ inline void debug_set_primitive_buffer_label(
 #endif
 }
 
+bool is_power_of_2(int n) {
+  return ((n & (n - 1)) == 0) && n != 0;
+}
+
 } // namespace
 
 } // namespace mlx::core
diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py
index 14473afa1..6cb4b1618 100644
--- a/python/tests/test_fft.py
+++ b/python/tests/test_fft.py
@@ -9,58 +9,49 @@ import numpy as np
 
 
 class TestFFT(mlx_tests.MLXTestCase):
-    def check_mx_np(self, op, a_np, axes, s):
-        with self.subTest(op=op, axes=axes, s=s):
-            op_np = getattr(np.fft, op)
-            op_mx = getattr(mx.fft, op)
-            out_np = op_np(a_np, s=s, axes=axes)
-            a_mx = mx.array(a_np)
-            out_mx = op_mx(a_mx, s=s, axes=axes)
-            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+    def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
+        out_np = op_np(a_np, **kwargs)
+        a_mx = mx.array(a_np)
+        out_mx = op_mx(a_mx, **kwargs)
+        np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)
 
     def test_fft(self):
-        def check_mx_np(op_mx, op_np, a_np, **kwargs):
-            out_np = op_np(a_np, **kwargs)
-            a_mx = mx.array(a_np)
-            out_mx = op_mx(a_mx, **kwargs)
-            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
-
         with mx.stream(mx.cpu):
             r = np.random.rand(100).astype(np.float32)
             i = np.random.rand(100).astype(np.float32)
             a_np = r + 1j * i
-            check_mx_np(mx.fft.fft, np.fft.fft, a_np)
+            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
 
             # Check with slicing and padding
             r = np.random.rand(100).astype(np.float32)
             i = np.random.rand(100).astype(np.float32)
             a_np = r + 1j * i
-            check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
-            check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
+            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
+            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
 
             # Check different axes
             r = np.random.rand(100, 100).astype(np.float32)
             i = np.random.rand(100, 100).astype(np.float32)
             a_np = r + 1j * i
-            check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
-            check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
+            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
+            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
 
             # Check real fft
             a_np = np.random.rand(100).astype(np.float32)
-            check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
-            check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
-            check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
+            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
+            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
+            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
 
             # Check real inverse
             r = np.random.rand(100, 100).astype(np.float32)
             i = np.random.rand(100, 100).astype(np.float32)
             a_np = r + 1j * i
-            check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
-            check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
-            check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
-            check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
-            check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
-            check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
+            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
+            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
+            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
+            self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
+            self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
+            self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
 
     def test_fftn(self):
         with mx.stream(mx.cpu):
@@ -85,7 +76,61 @@ class TestFFT(mlx_tests.MLXTestCase):
                 x = a
                 if op in ["rfft2", "rfftn"]:
                     x = r
-                self.check_mx_np(op, x, axes=ax, s=s)
+                mx_op = getattr(mx.fft, op)
+                np_op = getattr(np.fft, op)
+                self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
+
+    def test_fft_powers_of_two(self):
+        shape = (16, 4, 8)
+        # np.fft.fft always uses double precision complex128
+        # mx.fft.fft only supports single precision complex64
+        # hence the fairly tolerant equality checks.
+        atol = 1e-4
+        rtol = 1e-4
+        np.random.seed(7)
+        for k in range(4, 12):
+            r = np.random.rand(*shape, 2**k).astype(np.float32)
+            i = np.random.rand(*shape, 2**k).astype(np.float32)
+            a_np = r + 1j * i
+            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
+
+        r = np.random.rand(*shape, 32).astype(np.float32)
+        i = np.random.rand(*shape, 32).astype(np.float32)
+        a_np = r + 1j * i
+        for axis in range(4):
+            self.check_mx_np(
+                mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
+            )
+
+    def test_fft_contiguity(self):
+        r = np.random.rand(4, 8).astype(np.float32)
+        i = np.random.rand(4, 8).astype(np.float32)
+        a_np = r + 1j * i
+        a_mx = mx.array(a_np)
+
+        # non-contiguous in the FFT dim
+        out_mx = mx.fft.fft(a_mx[:, ::2])
+        out_np = np.fft.fft(a_np[:, ::2])
+        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
+
+        # non-contiguous not in the FFT dim
+        out_mx = mx.fft.fft(a_mx[::2])
+        out_np = np.fft.fft(a_np[::2])
+        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
+
+        out_mx = mx.broadcast_to(mx.reshape(mx.transpose(a_mx), (4, 8, 1)), (4, 8, 16))
+        out_np = np.broadcast_to(np.reshape(np.transpose(a_np), (4, 8, 1)), (4, 8, 16))
+        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
+
+        out2_mx = mx.fft.fft(mx.abs(out_mx) + 4)
+        out2_np = np.fft.fft(np.abs(out_np) + 4)
+        np.testing.assert_allclose(out2_mx, out2_np, atol=1e-5, rtol=1e-5)
+
+        b_np = np.array([[0, 1, 2, 3]])
+        out_mx = mx.abs(mx.fft.fft(mx.tile(mx.reshape(mx.array(b_np), (1, 4)), (4, 1))))
+        out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
+        np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
 
 
 if __name__ == "__main__":
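
A quick end-to-end check of the new kernels for anyone applying this patch locally: compare mx.fft.fft on the GPU stream against NumPy for one of the instantiated power-of-two sizes. This is a minimal smoke test, not part of the patch itself; the shape and tolerances below are illustrative and mirror the ones used in test_fft_powers_of_two.

import mlx.core as mx
import numpy as np

# Batched 1D FFT over the last axis (n = 1024 is one of the instantiated sizes).
r = np.random.rand(8, 1024).astype(np.float32)
i = np.random.rand(8, 1024).astype(np.float32)
x = (r + 1j * i).astype(np.complex64)

with mx.stream(mx.gpu):
    out = mx.fft.fft(mx.array(x))
    mx.eval(out)

# complex64 result vs. NumPy's complex128 reference, hence the loose tolerances.
np.testing.assert_allclose(np.fft.fft(x), out, atol=1e-4, rtol=1e-4)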
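On the TODO in fft.metal ("implement 4 step FFT for larger n"): an FFT of length n = n1 * n2 can be assembled from length-n1 and length-n2 sub-FFTs plus one pointwise twiddle multiplication, so sizes beyond 2048 could be split into pieces that each fit in threadgroup memory. Below is a NumPy sketch of that decomposition, assuming the usual Bailey four-step formulation; the helper four_step_fft and the 64 x 64 split are illustrative and not part of the patch.

import numpy as np

def four_step_fft(x, n1, n2):
    # View x as an (n2, n1) matrix; column j1 holds the strided subsequence x[j1::n1].
    m = x.reshape(n2, n1)
    # Step 1: length-n2 FFTs down the columns.
    m = np.fft.fft(m, axis=0)
    # Step 2: twiddle factors exp(-2*pi*1j * j1 * k2 / n) for entry (k2, j1).
    k2 = np.arange(n2).reshape(-1, 1)
    j1 = np.arange(n1).reshape(1, -1)
    m = m * np.exp(-2j * np.pi * j1 * k2 / (n1 * n2))
    # Step 3: length-n1 FFTs along the rows.
    m = np.fft.fft(m, axis=1)
    # Step 4: transpose and flatten; entry (k1, k2) of m.T is X[n2 * k1 + k2].
    return m.T.reshape(-1)

x = np.random.rand(4096) + 1j * np.random.rand(4096)
np.testing.assert_allclose(four_step_fft(x, 64, 64), np.fft.fft(x), atol=1e-6)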