mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2 * skip GPU test on linux * fix contiguity bug * address comments * Update mlx/backend/metal/fft.cpp * Update mlx/backend/metal/fft.cpp * fix bug in synch --------- Co-authored-by: Alex Barron <abarron22@apple.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -9,58 +9,49 @@ import numpy as np
|
||||
|
||||
|
||||
class TestFFT(mlx_tests.MLXTestCase):
|
||||
def check_mx_np(self, op, a_np, axes, s):
|
||||
with self.subTest(op=op, axes=axes, s=s):
|
||||
op_np = getattr(np.fft, op)
|
||||
op_mx = getattr(mx.fft, op)
|
||||
out_np = op_np(a_np, s=s, axes=axes)
|
||||
a_mx = mx.array(a_np)
|
||||
out_mx = op_mx(a_mx, s=s, axes=axes)
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
|
||||
out_np = op_np(a_np, **kwargs)
|
||||
a_mx = mx.array(a_np)
|
||||
out_mx = op_mx(a_mx, **kwargs)
|
||||
np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fft(self):
|
||||
def check_mx_np(op_mx, op_np, a_np, **kwargs):
|
||||
out_np = op_np(a_np, **kwargs)
|
||||
a_mx = mx.array(a_np)
|
||||
out_mx = op_mx(a_mx, **kwargs)
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)
|
||||
|
||||
# Check with slicing and padding
|
||||
r = np.random.rand(100).astype(np.float32)
|
||||
i = np.random.rand(100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120)
|
||||
|
||||
# Check different axes
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||
check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0)
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1)
|
||||
|
||||
# Check real fft
|
||||
a_np = np.random.rand(100).astype(np.float32)
|
||||
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||
check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120)
|
||||
|
||||
# Check real inverse
|
||||
r = np.random.rand(100, 100).astype(np.float32)
|
||||
i = np.random.rand(100, 100).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
|
||||
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
|
||||
check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120)
|
||||
|
||||
def test_fftn(self):
|
||||
with mx.stream(mx.cpu):
|
||||
@@ -85,7 +76,65 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
self.check_mx_np(op, x, axes=ax, s=s)
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
|
||||
def test_fft_powers_of_two(self):
|
||||
shape = (16, 4, 8)
|
||||
# np.fft.fft always uses double precision complex128
|
||||
# mx.fft.fft only supports single precision complex64
|
||||
# hence the fairly tolerant equality checks.
|
||||
atol = 1e-4
|
||||
rtol = 1e-4
|
||||
np.random.seed(7)
|
||||
for k in range(4, 12):
|
||||
r = np.random.rand(*shape, 2**k).astype(np.float32)
|
||||
i = np.random.rand(*shape, 2**k).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol)
|
||||
|
||||
r = np.random.rand(*shape, 32).astype(np.float32)
|
||||
i = np.random.rand(*shape, 32).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
for axis in range(4):
|
||||
self.check_mx_np(
|
||||
mx.fft.fft, np.fft.fft, a_np, atol=atol, rtol=rtol, axis=axis
|
||||
)
|
||||
|
||||
r = np.random.rand(4, 8).astype(np.float32)
|
||||
i = np.random.rand(4, 8).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
def test_fft_contiguity(self):
|
||||
r = np.random.rand(4, 8).astype(np.float32)
|
||||
i = np.random.rand(4, 8).astype(np.float32)
|
||||
a_np = r + 1j * i
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
# non-contiguous in the FFT dim
|
||||
out_mx = mx.fft.fft(a_mx[:, ::2])
|
||||
out_np = np.fft.fft(a_np[:, ::2])
|
||||
np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# non-contiguous not in the FFT dim
|
||||
out_mx = mx.fft.fft(a_mx[::2])
|
||||
out_np = np.fft.fft(a_np[::2])
|
||||
np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
|
||||
|
||||
out_mx = mx.broadcast_to(mx.reshape(mx.transpose(a_mx), (4, 8, 1)), (4, 8, 16))
|
||||
out_np = np.broadcast_to(np.reshape(np.transpose(a_np), (4, 8, 1)), (4, 8, 16))
|
||||
np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)
|
||||
|
||||
out2_mx = mx.fft.fft(mx.abs(out_mx) + 4)
|
||||
out2_np = np.fft.fft(np.abs(out_np) + 4)
|
||||
np.testing.assert_allclose(out2_mx, out2_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
b_np = np.array([[0, 1, 2, 3]])
|
||||
out_mx = mx.abs(mx.fft.fft(mx.tile(mx.reshape(mx.array(b_np), (1, 4)), (4, 1))))
|
||||
out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
|
||||
np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user