Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -10,15 +10,13 @@
#include "mlx/linalg.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
namespace {
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
const auto result = svd(a, s);
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
const auto result = mx::linalg::svd(a, s);
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
}
} // namespace
@@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) {
m.def(
"norm",
[](const array& a,
[](const mx::array& a,
const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims,
const StreamOrDevice stream) {
const mx::StreamOrDevice stream) {
std::optional<std::vector<int>> axis = std::nullopt;
if (auto pv = std::get_if<int>(&axis_); pv) {
axis = std::vector<int>{*pv};
@@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) {
}
if (std::holds_alternative<std::monostate>(ord_)) {
return norm(a, axis, keepdims, stream);
return mx::linalg::norm(a, axis, keepdims, stream);
} else {
if (auto pv = std::get_if<std::string>(&ord_); pv) {
return norm(a, *pv, axis, keepdims, stream);
return mx::linalg::norm(a, *pv, axis, keepdims, stream);
}
double ord;
if (auto pv = std::get_if<int>(&ord_); pv) {
@@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) {
} else {
ord = std::get<double>(ord_);
}
return norm(a, ord, axis, keepdims, stream);
return mx::linalg::norm(a, ord, axis, keepdims, stream);
}
},
nb::arg(),
@@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"qr",
&qr,
&mx::linalg::qr,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"inv",
&inv,
&mx::linalg::inv,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"tri_inv",
&tri_inv,
&mx::linalg::tri_inv,
"a"_a,
"upper"_a,
nb::kw_only(),
@@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"cholesky",
&cholesky,
&mx::linalg::cholesky,
"a"_a,
"upper"_a = false,
nb::kw_only(),
@@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"cholesky_inv",
&cholesky_inv,
&mx::linalg::cholesky_inv,
"a"_a,
"upper"_a = false,
nb::kw_only(),
@@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"pinv",
&pinv,
&mx::linalg::pinv,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"cross",
&cross,
&mx::linalg::cross,
"a"_a,
"b"_a,
"axis"_a = -1,
@@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"eigvalsh",
&eigvalsh,
&mx::linalg::eigvalsh,
"a"_a,
"UPLO"_a = "L",
nb::kw_only(),
@@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"eigh",
[](const array& a, const std::string UPLO, StreamOrDevice s) {
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
// TODO avoid cast?
auto result = eigh(a, UPLO, s);
auto result = mx::linalg::eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second);
},
"a"_a,