mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
throw for invalid case and remove test (#1575)
This commit is contained in:
parent
91c0277356
commit
8c34c9dac4
@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
|
@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
]
|
||||
|
||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||
if ax is None and s is not None:
|
||||
continue
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
|
Loading…
Reference in New Issue
Block a user