mlx/python/tests/test_fft.py
2024-02-14 14:14:58 -08:00

93 lines
3.2 KiB
Python

# Copyright © 2023 Apple Inc.
import itertools
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestFFT(mlx_tests.MLXTestCase):
    """Compare ``mx.fft`` transforms against ``np.fft`` as the reference."""

    def check_mx_np(self, op, a_np, axes, s):
        """Check that ``mx.fft.<op>`` matches ``np.fft.<op>`` on ``a_np``.

        Args:
            op: Name of a function present in both ``np.fft`` and ``mx.fft``
                (e.g. ``"fftn"``); both are called with ``s`` and ``axes``.
            a_np: NumPy input array; it is converted with ``mx.array`` for
                the MLX call.
            axes: Forwarded unchanged as ``axes`` to both implementations.
            s: Forwarded unchanged as ``s`` to both implementations.
        """
        with self.subTest(op=op, axes=axes, s=s):
            op_np = getattr(np.fft, op)
            op_mx = getattr(mx.fft, op)
            out_np = op_np(a_np, s=s, axes=axes)
            out_mx = op_mx(mx.array(a_np), s=s, axes=axes)
            # Unlike assertTrue(np.allclose(...)), assert_allclose rejects
            # shape mismatches instead of broadcasting them away, and it
            # reports the offending elements when the check fails.
            np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-6)

    def test_fft(self):
        """Check the 1-D transforms ``fft``, ``rfft``, ``ifft`` and ``irfft``."""

        def check_1d(op_mx, op_np, a_np, **kwargs):
            # Named differently from the ``check_mx_np`` method: this helper
            # takes the two callables directly and forwards arbitrary kwargs
            # (``n``, ``axis``) rather than an op name plus ``s``/``axes``.
            with self.subTest(op=op_np.__name__, **kwargs):
                out_np = op_np(a_np, **kwargs)
                out_mx = op_mx(mx.array(a_np), **kwargs)
                np.testing.assert_allclose(
                    out_mx, out_np, atol=1e-5, rtol=1e-6
                )

        with mx.stream(mx.cpu):
            # Complex 1-D input at its natural length.
            r = np.random.rand(100).astype(np.float32)
            i = np.random.rand(100).astype(np.float32)
            a_np = r + 1j * i
            check_1d(mx.fft.fft, np.fft.fft, a_np)

            # Check with slicing (n < 100) and padding (n > 100).
            check_1d(mx.fft.fft, np.fft.fft, a_np, n=80)
            check_1d(mx.fft.fft, np.fft.fft, a_np, n=120)

            # Check different axes of a 2-D input.
            r = np.random.rand(100, 100).astype(np.float32)
            i = np.random.rand(100, 100).astype(np.float32)
            a_np = r + 1j * i
            check_1d(mx.fft.fft, np.fft.fft, a_np, axis=0)
            check_1d(mx.fft.fft, np.fft.fft, a_np, axis=1)

            # Check real forward fft, including slicing and padding.
            a_np = np.random.rand(100).astype(np.float32)
            check_1d(mx.fft.rfft, np.fft.rfft, a_np)
            check_1d(mx.fft.rfft, np.fft.rfft, a_np, n=80)
            check_1d(mx.fft.rfft, np.fft.rfft, a_np, n=120)

            # Check inverse transforms (complex ifft and real irfft) on a
            # 2-D complex input, along the default axis.
            r = np.random.rand(100, 100).astype(np.float32)
            i = np.random.rand(100, 100).astype(np.float32)
            a_np = r + 1j * i
            check_1d(mx.fft.ifft, np.fft.ifft, a_np)
            check_1d(mx.fft.ifft, np.fft.ifft, a_np, n=80)
            check_1d(mx.fft.ifft, np.fft.ifft, a_np, n=120)
            check_1d(mx.fft.irfft, np.fft.irfft, a_np)
            check_1d(mx.fft.irfft, np.fft.irfft, a_np, n=80)
            check_1d(mx.fft.irfft, np.fft.irfft, a_np, n=120)

    def test_fftn(self):
        """Check the 2-D and N-D transforms over several axes and sizes."""
        with mx.stream(mx.cpu):
            r = np.random.randn(8, 8, 8).astype(np.float32)
            i = np.random.randn(8, 8, 8).astype(np.float32)
            a = r + 1j * i
            # None exercises each op's default axes; the tuples mix
            # adjacent, reversed and non-adjacent axis choices.
            axes = [None, (1, 2), (2, 1), (0, 2)]
            # None keeps the default sizes; the tuples request sizes both
            # above and below the input's extent of 8.
            shapes = [None, (10, 5), (5, 10)]
            ops = [
                "fft2",
                "ifft2",
                "rfft2",
                "irfft2",
                "fftn",
                "ifftn",
                "rfftn",
                "irfftn",
            ]
            for op, ax, s in itertools.product(ops, axes, shapes):
                # Forward real transforms get the real array; every other
                # op (including irfft2/irfftn) gets the complex one.
                x = r if op in ("rfft2", "rfftn") else a
                self.check_mx_np(op, x, axes=ax, s=s)
if __name__ == "__main__":
    # Executed directly (not imported): hand this module's TestCases to the
    # stdlib runner. ``module="__main__"`` is unittest.main's default.
    unittest.main(module="__main__")