Feature complete Metal FFT (#1102)

* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-06-06 12:57:25 -07:00
committed by GitHub
parent 0e585b4409
commit 27d70c7d9d
17 changed files with 2601 additions and 367 deletions

View File

@@ -16,96 +16,140 @@ class TestFFT(mlx_tests.MLXTestCase):
np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)
def test_fft(self):
with mx.stream(mx.cpu):
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
# Check with slicing and padding
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
# Check with slicing and padding
r = np.random.rand(100).astype(np.float32)
i = np.random.rand(100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
# Check different axes
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
# Check different axes
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
# Check real fft
a_np = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
# Check real fft
a_np = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
# Check real inverse
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
# Check real inverse
r = np.random.rand(100, 100).astype(np.float32)
i = np.random.rand(100, 100).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
x = np.fft.rfft(a_np)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
def test_fftn(self):
with mx.stream(mx.cpu):
r = np.random.randn(8, 8, 8).astype(np.float32)
i = np.random.randn(8, 8, 8).astype(np.float32)
a = r + 1j * i
r = np.random.randn(8, 8, 8).astype(np.float32)
i = np.random.randn(8, 8, 8).astype(np.float32)
a = r + 1j * i
axes = [None, (1, 2), (2, 1), (0, 2)]
shapes = [None, (10, 5), (5, 10)]
ops = [
"fft2",
"ifft2",
"rfft2",
"irfft2",
"fftn",
"ifftn",
"rfftn",
"irfftn",
axes = [None, (1, 2), (2, 1), (0, 2)]
shapes = [None, (10, 5), (5, 10)]
ops = [
"fft2",
"ifft2",
"rfft2",
"irfft2",
"fftn",
"ifftn",
"rfftn",
"irfftn",
]
for op, ax, s in itertools.product(ops, axes, shapes):
x = a
if op in ["rfft2", "rfftn"]:
x = r
elif op == "irfft2":
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
elif op == "irfftn":
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
mx_op = getattr(mx.fft, op)
np_op = getattr(np.fft, op)
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
def _run_ffts(self, shape, atol=1e-4, rtol=1e-4):
np.random.seed(9)
r = np.random.rand(*shape).astype(np.float32)
i = np.random.rand(*shape).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, atol=atol, rtol=rtol)
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
ia_np = np.fft.rfft(a_np)
self.check_mx_np(
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
)
self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol)
def test_fft_shared_mem(self):
nums = np.concatenate(
[
# small radix
np.arange(2, 14),
# powers of 2
[2**k for k in range(4, 13)],
# stockham
[3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11],
# rader
[17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982],
# bluestein
[47, 83, 17 * 17],
# large stockham
[3159, 3645, 3969, 4004],
]
)
for batch_size in (1, 3, 32):
for num in nums:
atol = 1e-4 if num < 1025 else 1e-3
self._run_ffts((batch_size, num), atol=atol)
for op, ax, s in itertools.product(ops, axes, shapes):
x = a
if op in ["rfft2", "rfftn"]:
x = r
mx_op = getattr(mx.fft, op)
np_op = getattr(np.fft, op)
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
@unittest.skip("Too slow for CI but useful for local testing.")
def test_fft_exhaustive(self):
nums = range(2, 4097)
for batch_size in (1, 3, 32):
for num in nums:
print(num)
atol = 1e-4 if num < 1025 else 1e-3
self._run_ffts((batch_size, num), atol=atol)
def test_fft_powers_of_two(self):
shape = (16, 4, 8)
# np.fft.fft always uses double precision complex128
# mx.fft.fft only supports single precision complex64
# hence the fairly tolerant equality checks.
atol = 1e-4
rtol = 1e-4
np.random.seed(7)
for k in range(4, 12):
r = np.random.rand(*shape, 2**k).astype(np.float32)
i = np.random.rand(*shape, 2**k).astype(np.float32)
a_np = r + 1j * i
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
def test_fft_big_powers_of_two(self):
# TODO: improve precision on big powers of two on GPU
for k in range(12, 17):
self._run_ffts((3, 2**k), atol=1e-3)
r = np.random.rand(*shape, 32).astype(np.float32)
i = np.random.rand(*shape, 32).astype(np.float32)
a_np = r + 1j * i
for axis in range(4):
self.check_mx_np(
mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
)
for k in range(17, 20):
self._run_ffts((3, 2**k), atol=1e-2)
r = np.random.rand(4, 8).astype(np.float32)
i = np.random.rand(4, 8).astype(np.float32)
a_np = r + 1j * i
a_mx = mx.array(a_np)
def test_fft_large_numbers(self):
numbers = [
1037, # prime > 2048
18247, # medium size prime factors
1259 * 11, # large prime factors
7883, # large prime
3**8, # large stockham decomposable
3109, # bluestein
4006, # large rader
]
for large_num in numbers:
self._run_ffts((1, large_num), atol=1e-3)
def test_fft_contiguity(self):
r = np.random.rand(4, 8).astype(np.float32)