Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -9,24 +9,23 @@
#include "mlx/fft.h"
#include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_fft(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"fft", "mlx.core.fft: Fast Fourier Transforms.");
m.def(
"fft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::fft(a, n.value(), axis, s);
return mx::fft::fft(a, n.value(), axis, s);
} else {
return fft::fft(a, axis, s);
return mx::fft::fft(a, axis, s);
}
},
"a"_a,
@@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"ifft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::ifft(a, n.value(), axis, s);
return mx::fft::ifft(a, n.value(), axis, s);
} else {
return fft::ifft(a, axis, s);
return mx::fft::ifft(a, axis, s);
}
},
"a"_a,
@@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"fft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s);
return mx::fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
return mx::fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[fft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::fftn(a, s);
return mx::fft::fftn(a, s);
}
},
"a"_a,
@@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"ifft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s);
return mx::fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
return mx::fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::ifftn(a, s);
return mx::fft::ifftn(a, s);
}
},
"a"_a,
@@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"fftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s);
return mx::fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
return mx::fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[fftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::fftn(a, s);
return mx::fft::fftn(a, s);
}
},
"a"_a,
@@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"ifftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s);
return mx::fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
return mx::fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::ifftn(a, s);
return mx::fft::ifftn(a, s);
}
},
"a"_a,
@@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"rfft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::rfft(a, n.value(), axis, s);
return mx::fft::rfft(a, n.value(), axis, s);
} else {
return fft::rfft(a, axis, s);
return mx::fft::rfft(a, axis, s);
}
},
"a"_a,
@@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"irfft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::irfft(a, n.value(), axis, s);
return mx::fft::irfft(a, n.value(), axis, s);
} else {
return fft::irfft(a, axis, s);
return mx::fft::irfft(a, axis, s);
}
},
"a"_a,
@@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"rfft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s);
return mx::fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
return mx::fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::rfftn(a, s);
return mx::fft::rfftn(a, s);
}
},
"a"_a,
@@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"irfft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s);
return mx::fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
return mx::fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::irfftn(a, s);
return mx::fft::irfftn(a, s);
}
},
"a"_a,
@@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"rfftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s);
return mx::fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
return mx::fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::rfftn(a, s);
return mx::fft::rfftn(a, s);
}
},
"a"_a,
@@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"irfftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s);
return mx::fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
return mx::fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::irfftn(a, s);
return mx::fft::irfftn(a, s);
}
},
"a"_a,