From 8c34c9dac4ce1377dfd1275a9bcaffefe014e468 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 8 Nov 2024 12:04:03 -0800 Subject: [PATCH] throw for invalid case and remove test (#1575) --- python/src/fft.cpp | 40 ++++++++++++++++------------------------ python/tests/test_fft.py | 2 ++ 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 3b1007fe2..44e914d05 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::fftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::fftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[fft2] `axes` should not be `None` if `s` is not `None`."); } else { return fft::fftn(a, s); } @@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::ifftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::ifftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[ifft2] `axes` should not be `None` if `s` is not `None`."); } else { return fft::ifftn(a, s); } @@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::fftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::fftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[fftn] `axes` should not be `None` if `s` is not `None`."); } else { return fft::fftn(a, s); } @@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::ifftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::ifftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[ifftn] `axes` should not be `None` if `s` is not `None`."); } else { return fft::ifftn(a, s); } @@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::rfftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::rfftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[rfft2] `axes` should not be `None` if `s` is not `None`."); } else { return fft::rfftn(a, s); } @@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::irfftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::irfftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[irfft2] `axes` should not be `None` if `s` is not `None`."); } else { return fft::irfftn(a, s); } @@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::rfftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::rfftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[rfftn] `axes` should not be `None` if `s` is not `None`."); } else { return fft::rfftn(a, s); } @@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) { } else if (axes.has_value()) { return fft::irfftn(a, axes.value(), s); } else if (n.has_value()) { - std::vector axes_(n.value().size()); - std::iota(axes_.begin(), axes_.end(), -n.value().size()); - return fft::irfftn(a, n.value(), axes_, s); + throw std::invalid_argument( + "[irfftn] `axes` should not be `None` if `s` is not `None`."); } else { return fft::irfftn(a, s); } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index c996b8d47..95f9f7a54 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase): ] for op, ax, s in itertools.product(ops, axes, shapes): + if ax is None and s is not None: + continue x = a if op in ["rfft2", "rfftn"]: x = r