mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-04 05:18:09 +08:00
throw for invalid case and remove test (#1575)
This commit is contained in:
@@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
@@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
@@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::fftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
}
|
||||
@@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::ifftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
}
|
||||
@@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
@@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
@@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::rfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
}
|
||||
@@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) {
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
std::vector<int> axes_(n.value().size());
|
||||
std::iota(axes_.begin(), axes_.end(), -n.value().size());
|
||||
return fft::irfftn(a, n.value(), axes_, s);
|
||||
throw std::invalid_argument(
|
||||
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
}
|
||||
|
@@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
]
|
||||
|
||||
for op, ax, s in itertools.product(ops, axes, shapes):
|
||||
if ax is None and s is not None:
|
||||
continue
|
||||
x = a
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
|
Reference in New Issue
Block a user