mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Feature complete Metal FFT (#1102)
* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -16,96 +16,140 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fft(self):
|
||||
with mx.stream(mx.cpu):
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||
|
||||
# Check with slicing and padding
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||
# Check with slicing and padding
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||
|
||||
# Check different axes
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||
# Check different axes
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||
|
||||
# Check real fft
|
||||
a_np = np.random.rand(100).astype(np.float32)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||
# Check real fft
|
||||
a_np = np.random.rand(100).astype(np.float32)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||
|
||||
# Check real inverse
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
|
||||
# Check real inverse
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
|
||||
x = np.fft.rfft(a_np)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
|
||||
|
||||
def test_fftn(self):
|
||||
with mx.stream(mx.cpu):
|
||||
r = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
i = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
a = r + 1j * i
|
||||
r = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
i = np.random.randn(8, 8, 8).astype(np.float32)
|
||||
a = r + 1j * i
|
||||
|
||||
axes = [None, (1, 2), (2, 1), (0, 2)]
|
||||
shapes = [None, (10, 5), (5, 10)]
|
||||
ops = [
|
||||
"fft2",
|
||||
"ifft2",
|
||||
"rfft2",
|
||||
"irfft2",
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
axes = [None, (1, 2), (2, 1), (0, 2)]
|
||||
shapes = [None, (10, 5), (5, 10)]
|
||||
ops = [
|
||||
"fft2",
|
||||
"ifft2",
|
||||
"rfft2",
|
||||
"irfft2",
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
]
|
||||
|
||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
elif op == "irfft2":
|
||||
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
|
||||
elif op == "irfftn":
|
||||
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
|
||||
def _run_ffts(self, shape, atol=1e-4, rtol=1e-4):
|
||||
np.random.seed(9)
|
||||
|
||||
r = np.random.rand(*shape).astype(np.float32)
|
||||
i = np.random.rand(*shape).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, atol=atol, rtol=rtol)
|
||||
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
|
||||
|
||||
ia_np = np.fft.rfft(a_np)
|
||||
self.check_mx_np(
|
||||
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
|
||||
)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fft_shared_mem(self):
|
||||
nums = np.concatenate(
|
||||
[
|
||||
# small radix
|
||||
np.arange(2, 14),
|
||||
# powers of 2
|
||||
[2**k for k in range(4, 13)],
|
||||
# stockham
|
||||
[3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11],
|
||||
# rader
|
||||
[17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982],
|
||||
# bluestein
|
||||
[47, 83, 17 * 17],
|
||||
# large stockham
|
||||
[3159, 3645, 3969, 4004],
|
||||
]
|
||||
)
|
||||
for batch_size in (1, 3, 32):
|
||||
for num in nums:
|
||||
atol = 1e-4 if num < 1025 else 1e-3
|
||||
self._run_ffts((batch_size, num), atol=atol)
|
||||
|
||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
@unittest.skip("Too slow for CI but useful for local testing.")
|
||||
def test_fft_exhaustive(self):
|
||||
nums = range(2, 4097)
|
||||
for batch_size in (1, 3, 32):
|
||||
for num in nums:
|
||||
print(num)
|
||||
atol = 1e-4 if num < 1025 else 1e-3
|
||||
self._run_ffts((batch_size, num), atol=atol)
|
||||
|
||||
def test_fft_powers_of_two(self):
|
||||
shape = (16, 4, 8)
|
||||
# np.fft.fft always uses double precision complex128
|
||||
# mx.fft.fft only supports single precision complex64
|
||||
# hence the fairly tolerant equality checks.
|
||||
atol = 1e-4
|
||||
rtol = 1e-4
|
||||
np.random.seed(7)
|
||||
for k in range(4, 12):
|
||||
r = np.random.rand(*shape, 2**k).astype(np.float32)
|
||||
i = np.random.rand(*shape, 2**k).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
|
||||
def test_fft_big_powers_of_two(self):
|
||||
# TODO: improve precision on big powers of two on GPU
|
||||
for k in range(12, 17):
|
||||
self._run_ffts((3, 2**k), atol=1e-3)
|
||||
|
||||
r = np.random.rand(*shape, 32).astype(np.float32)
|
||||
i = np.random.rand(*shape, 32).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
for axis in range(4):
|
||||
self.check_mx_np(
|
||||
mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
|
||||
)
|
||||
for k in range(17, 20):
|
||||
self._run_ffts((3, 2**k), atol=1e-2)
|
||||
|
||||
r = np.random.rand(4, 8).astype(np.float32)
|
||||
i = np.random.rand(4, 8).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
a_mx = mx.array(a_np)
|
||||
def test_fft_large_numbers(self):
|
||||
numbers = [
|
||||
1037, # 17 * 61: composite with two prime factors (not itself a prime, and below 2048)
|
||||
18247, # medium size prime factors
|
||||
1259 * 11, # large prime factors
|
||||
7883, # large prime
|
||||
3**8, # large stockham decomposable
|
||||
3109, # bluestein
|
||||
4006, # large rader
|
||||
]
|
||||
for large_num in numbers:
|
||||
self._run_ffts((1, large_num), atol=1e-3)
|
||||
|
||||
def test_fft_contiguity(self):
|
||||
r = np.random.rand(4, 8).astype(np.float32)
|
||||
|
Reference in New Issue
Block a user