mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
throw for invalid case and remove test (#1575)
This commit is contained in:
parent
91c0277356
commit
8c34c9dac4
@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::fftn(a, axes.value(), s);
|
return fft::fftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::fftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::fftn(a, s);
|
return fft::fftn(a, s);
|
||||||
}
|
}
|
||||||
@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::ifftn(a, axes.value(), s);
|
return fft::ifftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::ifftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::ifftn(a, s);
|
return fft::ifftn(a, s);
|
||||||
}
|
}
|
||||||
@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::fftn(a, axes.value(), s);
|
return fft::fftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::fftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::fftn(a, s);
|
return fft::fftn(a, s);
|
||||||
}
|
}
|
||||||
@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::ifftn(a, axes.value(), s);
|
return fft::ifftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::ifftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::ifftn(a, s);
|
return fft::ifftn(a, s);
|
||||||
}
|
}
|
||||||
@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::rfftn(a, axes.value(), s);
|
return fft::rfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::rfftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::rfftn(a, s);
|
return fft::rfftn(a, s);
|
||||||
}
|
}
|
||||||
@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::irfftn(a, axes.value(), s);
|
return fft::irfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::irfftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::irfftn(a, s);
|
return fft::irfftn(a, s);
|
||||||
}
|
}
|
||||||
@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::rfftn(a, axes.value(), s);
|
return fft::rfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::rfftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::rfftn(a, s);
|
return fft::rfftn(a, s);
|
||||||
}
|
}
|
||||||
@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::irfftn(a, axes.value(), s);
|
return fft::irfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
std::vector<int> axes_(n.value().size());
|
throw std::invalid_argument(
|
||||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
return fft::irfftn(a, n.value(), axes_, s);
|
|
||||||
} else {
|
} else {
|
||||||
return fft::irfftn(a, s);
|
return fft::irfftn(a, s);
|
||||||
}
|
}
|
||||||
|
@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||||
|
if ax is None and s is not None:
|
||||||
|
continue
|
||||||
x = a
|
x = a
|
||||||
if op in ["rfft2", "rfftn"]:
|
if op in ["rfft2", "rfftn"]:
|
||||||
x = r
|
x = r
|
||||||
|
Loading…
Reference in New Issue
Block a user