mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	throw for invalid case and remove test (#1575)
This commit is contained in:
		| @@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::fftn(a, axes.value(), s); |           return fft::fftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[fft2] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::fftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::fftn(a, s); |           return fft::fftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::ifftn(a, axes.value(), s); |           return fft::ifftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[ifft2] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::ifftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::ifftn(a, s); |           return fft::ifftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::fftn(a, axes.value(), s); |           return fft::fftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[fftn] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::fftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::fftn(a, s); |           return fft::fftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::ifftn(a, axes.value(), s); |           return fft::ifftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[ifftn] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::ifftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::ifftn(a, s); |           return fft::ifftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::rfftn(a, axes.value(), s); |           return fft::rfftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[rfft2] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::rfftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::rfftn(a, s); |           return fft::rfftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::irfftn(a, axes.value(), s); |           return fft::irfftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[irfft2] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::irfftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::irfftn(a, s); |           return fft::irfftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::rfftn(a, axes.value(), s); |           return fft::rfftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[rfftn] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::rfftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::rfftn(a, s); |           return fft::rfftn(a, s); | ||||||
|         } |         } | ||||||
| @@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) { | |||||||
|         } else if (axes.has_value()) { |         } else if (axes.has_value()) { | ||||||
|           return fft::irfftn(a, axes.value(), s); |           return fft::irfftn(a, axes.value(), s); | ||||||
|         } else if (n.has_value()) { |         } else if (n.has_value()) { | ||||||
|           std::vector<int> axes_(n.value().size()); |           throw std::invalid_argument( | ||||||
|           std::iota(axes_.begin(), axes_.end(), -n.value().size()); |               "[irfftn] `axes` should not be `None` if `s` is not `None`."); | ||||||
|           return fft::irfftn(a, n.value(), axes_, s); |  | ||||||
|         } else { |         } else { | ||||||
|           return fft::irfftn(a, s); |           return fft::irfftn(a, s); | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase): | |||||||
|         ] |         ] | ||||||
|  |  | ||||||
|         for op, ax, s in itertools.product(ops, axes, shapes): |         for op, ax, s in itertools.product(ops, axes, shapes): | ||||||
|  |             if ax is None and s is not None: | ||||||
|  |                 continue | ||||||
|             x = a |             x = a | ||||||
|             if op in ["rfft2", "rfftn"]: |             if op in ["rfft2", "rfftn"]: | ||||||
|                 x = r |                 x = r | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun