mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	throw for invalid case and remove test (#1575)
This commit is contained in:
		| @@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::fftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::fftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[fft2] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::fftn(a, s); | ||||
|         } | ||||
| @@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::ifftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::ifftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[ifft2] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::ifftn(a, s); | ||||
|         } | ||||
| @@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::fftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::fftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[fftn] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::fftn(a, s); | ||||
|         } | ||||
| @@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::ifftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::ifftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[ifftn] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::ifftn(a, s); | ||||
|         } | ||||
| @@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::rfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::rfftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[rfft2] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::rfftn(a, s); | ||||
|         } | ||||
| @@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::irfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::irfftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[irfft2] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::irfftn(a, s); | ||||
|         } | ||||
| @@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::rfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::rfftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[rfftn] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::rfftn(a, s); | ||||
|         } | ||||
| @@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) { | ||||
|         } else if (axes.has_value()) { | ||||
|           return fft::irfftn(a, axes.value(), s); | ||||
|         } else if (n.has_value()) { | ||||
|           std::vector<int> axes_(n.value().size()); | ||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); | ||||
|           return fft::irfftn(a, n.value(), axes_, s); | ||||
|           throw std::invalid_argument( | ||||
|               "[irfftn] `axes` should not be `None` if `s` is not `None`."); | ||||
|         } else { | ||||
|           return fft::irfftn(a, s); | ||||
|         } | ||||
|   | ||||
| @@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase): | ||||
|         ] | ||||
|  | ||||
|         for op, ax, s in itertools.product(ops, axes, shapes): | ||||
|             if ax is None and s is not None: | ||||
|                 continue | ||||
|             x = a | ||||
|             if op in ["rfft2", "rfftn"]: | ||||
|                 x = r | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun