throw for invalid case and remove test (#1575)

This commit is contained in:
Awni Hannun 2024-11-08 12:04:03 -08:00 committed by GitHub
parent 91c0277356
commit 8c34c9dac4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 24 deletions

View File

@ -88,9 +88,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::fftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[fft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::fftn(a, s);
}
@ -125,9 +124,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::ifftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::ifftn(a, s);
}
@ -162,9 +160,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::fftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[fftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::fftn(a, s);
}
@ -200,9 +197,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::ifftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::ifftn(a, s);
}
@ -307,9 +303,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::rfftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::rfftn(a, s);
}
@ -350,9 +345,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::irfftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::irfftn(a, s);
}
@ -393,9 +387,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::rfftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::rfftn(a, s);
}
@ -436,9 +429,8 @@ void init_fft(nb::module_& parent_module) {
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
std::vector<int> axes_(n.value().size());
std::iota(axes_.begin(), axes_.end(), -n.value().size());
return fft::irfftn(a, n.value(), axes_, s);
throw std::invalid_argument(
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::irfftn(a, s);
}

View File

@ -71,6 +71,8 @@ class TestFFT(mlx_tests.MLXTestCase):
]
for op, ax, s in itertools.product(ops, axes, shapes):
if ax is None and s is not None:
continue
x = a
if op in ["rfft2", "rfftn"]:
x = r