mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
@@ -9,24 +9,23 @@
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_fft(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
||||
m.def(
|
||||
"fft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::fft(a, n.value(), axis, s);
|
||||
return mx::fft::fft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::fft(a, axis, s);
|
||||
return mx::fft::fft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::ifft(a, n.value(), axis, s);
|
||||
return mx::fft::ifft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::ifft(a, axis, s);
|
||||
return mx::fft::ifft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
return mx::fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
return mx::fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
return mx::fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
return mx::fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
return mx::fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
return mx::fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
return mx::fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
return mx::fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::rfft(a, n.value(), axis, s);
|
||||
return mx::fft::rfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::rfft(a, axis, s);
|
||||
return mx::fft::rfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::irfft(a, n.value(), axis, s);
|
||||
return mx::fft::irfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::irfft(a, axis, s);
|
||||
return mx::fft::irfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
return mx::fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
return mx::fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
return mx::fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
return mx::fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
return mx::fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
return mx::fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
return mx::fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
return mx::fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
|
||||
Reference in New Issue
Block a user