mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
@@ -10,15 +10,13 @@
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
namespace {
|
||||
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto result = svd(a, s);
|
||||
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||
const auto result = mx::linalg::svd(a, s);
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
} // namespace
|
||||
@@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"norm",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::variant<std::monostate, int, double, std::string>& ord_,
|
||||
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
||||
const bool keepdims,
|
||||
const StreamOrDevice stream) {
|
||||
const mx::StreamOrDevice stream) {
|
||||
std::optional<std::vector<int>> axis = std::nullopt;
|
||||
if (auto pv = std::get_if<int>(&axis_); pv) {
|
||||
axis = std::vector<int>{*pv};
|
||||
@@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) {
|
||||
}
|
||||
|
||||
if (std::holds_alternative<std::monostate>(ord_)) {
|
||||
return norm(a, axis, keepdims, stream);
|
||||
return mx::linalg::norm(a, axis, keepdims, stream);
|
||||
} else {
|
||||
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
||||
return norm(a, *pv, axis, keepdims, stream);
|
||||
return mx::linalg::norm(a, *pv, axis, keepdims, stream);
|
||||
}
|
||||
double ord;
|
||||
if (auto pv = std::get_if<int>(&ord_); pv) {
|
||||
@@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
} else {
|
||||
ord = std::get<double>(ord_);
|
||||
}
|
||||
return norm(a, ord, axis, keepdims, stream);
|
||||
return mx::linalg::norm(a, ord, axis, keepdims, stream);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"qr",
|
||||
&qr,
|
||||
&mx::linalg::qr,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"inv",
|
||||
&inv,
|
||||
&mx::linalg::inv,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tri_inv",
|
||||
&tri_inv,
|
||||
&mx::linalg::tri_inv,
|
||||
"a"_a,
|
||||
"upper"_a,
|
||||
nb::kw_only(),
|
||||
@@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cholesky",
|
||||
&cholesky,
|
||||
&mx::linalg::cholesky,
|
||||
"a"_a,
|
||||
"upper"_a = false,
|
||||
nb::kw_only(),
|
||||
@@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cholesky_inv",
|
||||
&cholesky_inv,
|
||||
&mx::linalg::cholesky_inv,
|
||||
"a"_a,
|
||||
"upper"_a = false,
|
||||
nb::kw_only(),
|
||||
@@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"pinv",
|
||||
&pinv,
|
||||
&mx::linalg::pinv,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cross",
|
||||
&cross,
|
||||
&mx::linalg::cross,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
"axis"_a = -1,
|
||||
@@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eigvalsh",
|
||||
&eigvalsh,
|
||||
&mx::linalg::eigvalsh,
|
||||
"a"_a,
|
||||
"UPLO"_a = "L",
|
||||
nb::kw_only(),
|
||||
@@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eigh",
|
||||
[](const array& a, const std::string UPLO, StreamOrDevice s) {
|
||||
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
||||
// TODO avoid cast?
|
||||
auto result = eigh(a, UPLO, s);
|
||||
auto result = mx::linalg::eigh(a, UPLO, s);
|
||||
return nb::make_tuple(result.first, result.second);
|
||||
},
|
||||
"a"_a,
|
||||
|
||||
Reference in New Issue
Block a user