Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng
2024-12-12 08:45:39 +09:00
committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
22 changed files with 1423 additions and 1302 deletions

View File

@@ -4,22 +4,24 @@
#include "mlx/ops.h"
#include "python/src/convert.h"
array to_array(
mx::array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype /* = std::nullopt */) {
std::optional<mx::Dtype> dtype /* = std::nullopt */) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32);
auto out_t = dtype.value_or(mx::int32);
// bool_ is an exception and is always promoted
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
return mx::array(
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32);
return array(
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
auto out_t = dtype.value_or(mx::float32);
return mx::array(
nb::cast<float>(*pv),
mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else if (auto pv = std::get_if<array>(&v); pv) {
return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
return *pv;
} else if (auto pv = std::get_if<
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
@@ -30,7 +32,7 @@ array to_array(
}
}
std::pair<array, array> to_arrays(
std::pair<mx::array, mx::array> to_arrays(
const ScalarOrArray& a,
const ScalarOrArray& b) {
// Four cases:
@@ -39,15 +41,15 @@ std::pair<array, array> to_arrays(
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<array>(x) ||
return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<nb::object>(x) &&
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
};
auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<array>(&x); px) {
if (auto px = std::get_if<mx::array>(&x); px) {
return *px;
} else {
return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__"));
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
}
};
@@ -66,11 +68,11 @@ std::pair<array, array> to_arrays(
}
}
array to_array_with_accessor(nb::object obj) {
if (nb::isinstance<array>(obj)) {
return nb::cast<array>(obj);
mx::array to_array_with_accessor(nb::object obj) {
if (nb::isinstance<mx::array>(obj)) {
return nb::cast<mx::array>(obj);
} else if (nb::hasattr(obj, "__mlx_array__")) {
return nb::cast<array>(obj.attr("__mlx_array__")());
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()