Metal FFT for powers of 2 up to 2048 (#915)

* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Alex Barron 2024-04-12 05:40:06 +01:00 committed by GitHub
parent ae18326533
commit 2e7c02d5cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 431 additions and 31 deletions

View File

@ -0,0 +1,57 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
return bandwidths
def time_fft():
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
if __name__ == "__main__":
time_fft()

View File

@ -1,12 +1,106 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
throw std::runtime_error("[FFT] NYI for Metal backend.");
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
in.dtype() != complex64 || out.dtype() != complex64) {
// Could also fallback to CPU implementation here.
throw std::runtime_error(
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
}
size_t n = in.shape(axes_[0]);
if (!is_power_of_2(n) || n > 2048 || n < 4) {
throw std::runtime_error(
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [this, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
x.flags().col_contiguous;
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axes_[0]);
for (int axis = 0; axis < x.ndim(); axis++) {
if (axis == axes_[0]) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(axis);
}
}
auto flags = x.flags();
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
f_stride *= x.shape(i);
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
b_stride *= x.shape(ri);
}
// This is probably over-conservative
flags.contiguous = false;
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(inputs[0]);
// TODO: allow donation here
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
in_contiguous.data_size(),
in_contiguous.strides(),
in_contiguous.flags());
// We use n / 4 threads by default since radix-4
// is the largest single threaded radix butterfly
// we currently implement.
size_t m = n / 4;
size_t batch = in.size() / in.shape(axes_[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
{
std::ostringstream kname;
kname << "fft_" << n;
auto kernel = d.get_kernel(kname.str());
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@ -21,6 +21,7 @@ set(
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
"random"

View File

@ -0,0 +1,195 @@
// Copyright © 2024 Apple Inc.
// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_math>
#include <metal_common>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
float2 complex_mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x - a.y * b.y;
c.y = a.x * b.y + a.y * b.x;
return c;
}
float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2*p);
float2 twiddle;
twiddle.x = metal::fast::cos(theta);
twiddle.y = metal::fast::sin(theta);
return twiddle;
}
// single threaded radix2 implemetation
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
float2 z = complex_mul(x_1, twiddle);
float2 y_0 = x_0 + z;
float2 y_1 = x_0 - z;
int j = (i << 1) - k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
}
// single threaded radix4 implemetation
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2*m];
float2 x_3 = read_buf[i + 3*m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
// e^a * e^b = e^(a + b)
float2 twiddle_2 = complex_mul(twiddle, twiddle);
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
x_1 = complex_mul(x_1, twiddle);
x_2 = complex_mul(x_2, twiddle_2);
x_3 = complex_mul(x_3, twiddle_3);
float2 minus_i;
minus_i.x = 0;
minus_i.y = -1;
// Hard coded twiddle factors for DFT4
float2 z_0 = x_0 + x_2;
float2 z_1 = x_0 - x_2;
float2 z_2 = x_1 + x_3;
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
float2 y_0 = z_0 + z_2;
float2 y_1 = z_1 + z_3;
float2 y_2 = z_0 - z_2;
float2 y_3 = z_1 - z_3;
int j = ((i - k) << 2) + k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2*p] = y_2;
write_buf[j + 3*p] = y_3;
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
//
// At each step we use n / 4 threads, each performing
// a single-threaded radix-4 or radix-2 DFT.
//
// We provide the number of radix-2 and radix-4
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2 *in [[buffer(0)]],
device float2 * out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
int i = thread_position_in_grid.y;
// The number of the threads we're using for each DFT
int m = threads_per_grid.y;
// Allocate 2 shared memory buffers for Stockham.
// We alternate reading from one and writing to the other at each radix step.
threadgroup float2 shared_in[n];
threadgroup float2 shared_out[n];
// Pointers to facilitate Stockham buffer swapping
threadgroup float2* read_buf = shared_in;
threadgroup float2* write_buf = shared_out;
threadgroup float2* tmp;
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m*2, read_buf, write_buf);
radix2(i + m, p, m*2, read_buf, write_buf);
p *= 2;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
for (size_t r = 0; r < radix_4_steps; r++) {
radix4(i, p, m, read_buf, write_buf);
p *= 4;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
}
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] \
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2.
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1)
instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5)

View File

@ -130,6 +130,10 @@ inline void debug_set_primitive_buffer_label(
#endif
}
bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}
} // namespace
} // namespace mlx::core

View File

@ -9,58 +9,49 @@ import numpy as np
class TestFFT(mlx_tests.MLXTestCase):
def check_mx_np(self, op, a_np, axes, s):
with self.subTest(op=op, axes=axes, s=s):
op_np = getattr(np.fft, op)
op_mx = getattr(mx.fft, op)
out_np = op_np(a_np, s=s, axes=axes)
a_mx = mx.array(a_np)
out_mx = op_mx(a_mx, s=s, axes=axes)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
out_np = op_np(a_np, **kwargs)
a_mx = mx.array(a_np)
out_mx = op_mx(a_mx, **kwargs)
np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)
def test_fft(self):
def check_mx_np(op_mx, op_np, a_np, **kwargs):
out_np = op_np(a_np, **kwargs)
a_mx = mx.array(a_np)
out_mx = op_mx(a_mx, **kwargs)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
with mx.stream(mx.cpu):
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
# Check with slicing and padding
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
# Check different axes
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
# Check real fft
a_np = np.random.rand(100).astype(np.float32)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
# Check real inverse
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
def test_fftn(self):
with mx.stream(mx.cpu):
@ -85,7 +76,65 @@ class TestFFT(mlx_tests.MLXTestCase):
x = a
if op in ["rfft2", "rfftn"]:
x = r
self.check_mx_np(op, x, axes=ax, s=s)
mx_op = getattr(mx.fft, op)
np_op = getattr(np.fft, op)
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
def test_fft_powers_of_two(self):
shape = (16, 4, 8)
# np.fft.fft always uses double precision complex128
# mx.fft.fft only supports single precision complex64
# hence the fairly tolerant equality checks.
atol = 1e-4
rtol = 1e-4
np.random.seed(7)
for k in range(4, 12):
r = np.random.rand(*shape, 2**k).astype(np.float32)
i = np.random.rand(*shape, 2**k).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
r = np.random.rand(*shape, 32).astype(np.float32)
i = np.random.rand(*shape, 32).astype(np.float32)
a_np = r + 1j * i
for axis in range(4):
self.check_mx_np(
mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
)
r = np.random.rand(4, 8).astype(np.float32)
i = np.random.rand(4, 8).astype(np.float32)
a_np = r + 1j * i
a_mx = mx.array(a_np)
def test_fft_contiguity(self):
r = np.random.rand(4, 8).astype(np.float32)
i = np.random.rand(4, 8).astype(np.float32)
a_np = r + 1j * i
a_mx = mx.array(a_np)
# non-contiguous in the FFT dim
out_mx = mx.fft.fft(a_mx[:, ::2])
out_np = np.fft.fft(a_np[:, ::2])
np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
# non-contiguous not in the FFT dim
out_mx = mx.fft.fft(a_mx[::2])
out_np = np.fft.fft(a_np[::2])
np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
out_mx = mx.broadcast_to(mx.reshape(mx.transpose(a_mx), (4, 8, 1)), (4, 8, 16))
out_np = np.broadcast_to(np.reshape(np.transpose(a_np), (4, 8, 1)), (4, 8, 16))
np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
out2_mx = mx.fft.fft(mx.abs(out_mx) + 4)
out2_np = np.fft.fft(np.abs(out_np) + 4)
np.testing.assert_allclose(out2_mx, out2_np, atol=1e-5, rtol=1e-5)
b_np = np.array([[0, 1, 2, 3]])
out_mx = mx.abs(mx.fft.fft(mx.tile(mx.reshape(mx.array(b_np), (1, 4)), (4, 1))))
out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
if __name__ == "__main__":