Remove "using namespace mlx::core" in python/src (#1689)

Author: Cheng
Date: 2024-12-12 08:45:39 +09:00
Committed by: GitHub
Parent: f3dfa36a3a
Commit: 0bf19037ca
22 changed files with 1423 additions and 1302 deletions
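
The diff below applies one mechanical change throughout python/src: every MLX symbol that previously resolved through the file-wide "using namespace mlx::core" directive is now written with an explicit mx:: qualifier. A minimal sketch of the pattern, assuming the bindings declare the alias as "namespace mx = mlx::core;" (the alias declaration itself is not part of the excerpt shown here):

#include "mlx/mlx.h"

// Assumed alias declaration; the excerpt below only shows the qualified call sites.
namespace mx = mlx::core;

// Before, under "using namespace mlx::core;", this would have read
// "array to_float32(const array& a) { return astype(a, float32); }".
mx::array to_float32(const mx::array& a) {
  return mx::astype(a, mx::float32);
}

Compared with keeping a per-file using-directive, the short alias leaves core MLX names visually distinct from nanobind (nb::) symbols and binding-local helpers at little cost in verbosity.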


@@ -16,7 +16,7 @@ enum PyScalarT {
 
 namespace nanobind {
 template <>
-struct ndarray_traits<float16_t> {
+struct ndarray_traits<mx::float16_t> {
   static constexpr bool is_complex = false;
   static constexpr bool is_float = true;
   static constexpr bool is_bool = false;
@@ -36,21 +36,21 @@ int check_shape_dim(int64_t dim) {
 }
 
 template <typename T>
-array nd_array_to_mlx_contiguous(
+mx::array nd_array_to_mlx_contiguous(
     nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
-    const Shape& shape,
-    Dtype dtype) {
+    const mx::Shape& shape,
+    mx::Dtype dtype) {
   // Make a copy of the numpy buffer
   // Get buffer ptr pass to array constructor
   auto data_ptr = nd_array.data();
-  return array(static_cast<const T*>(data_ptr), shape, dtype);
+  return mx::array(static_cast<const T*>(data_ptr), shape, dtype);
 }
 
-array nd_array_to_mlx(
+mx::array nd_array_to_mlx(
     nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
-    std::optional<Dtype> dtype) {
+    std::optional<mx::Dtype> dtype) {
   // Compute the shape and size
-  Shape shape;
+  mx::Shape shape;
   for (int i = 0; i < nd_array.ndim(); i++) {
     shape.push_back(check_shape_dim(nd_array.shape(i)));
   }
@@ -59,49 +59,49 @@ array nd_array_to_mlx(
   // Copy data and make array
   if (type == nb::dtype<bool>()) {
     return nd_array_to_mlx_contiguous<bool>(
-        nd_array, shape, dtype.value_or(bool_));
+        nd_array, shape, dtype.value_or(mx::bool_));
   } else if (type == nb::dtype<uint8_t>()) {
     return nd_array_to_mlx_contiguous<uint8_t>(
-        nd_array, shape, dtype.value_or(uint8));
+        nd_array, shape, dtype.value_or(mx::uint8));
   } else if (type == nb::dtype<uint16_t>()) {
     return nd_array_to_mlx_contiguous<uint16_t>(
-        nd_array, shape, dtype.value_or(uint16));
+        nd_array, shape, dtype.value_or(mx::uint16));
   } else if (type == nb::dtype<uint32_t>()) {
     return nd_array_to_mlx_contiguous<uint32_t>(
-        nd_array, shape, dtype.value_or(uint32));
+        nd_array, shape, dtype.value_or(mx::uint32));
   } else if (type == nb::dtype<uint64_t>()) {
     return nd_array_to_mlx_contiguous<uint64_t>(
-        nd_array, shape, dtype.value_or(uint64));
+        nd_array, shape, dtype.value_or(mx::uint64));
   } else if (type == nb::dtype<int8_t>()) {
     return nd_array_to_mlx_contiguous<int8_t>(
-        nd_array, shape, dtype.value_or(int8));
+        nd_array, shape, dtype.value_or(mx::int8));
   } else if (type == nb::dtype<int16_t>()) {
     return nd_array_to_mlx_contiguous<int16_t>(
-        nd_array, shape, dtype.value_or(int16));
+        nd_array, shape, dtype.value_or(mx::int16));
   } else if (type == nb::dtype<int32_t>()) {
     return nd_array_to_mlx_contiguous<int32_t>(
-        nd_array, shape, dtype.value_or(int32));
+        nd_array, shape, dtype.value_or(mx::int32));
   } else if (type == nb::dtype<int64_t>()) {
     return nd_array_to_mlx_contiguous<int64_t>(
-        nd_array, shape, dtype.value_or(int64));
-  } else if (type == nb::dtype<float16_t>()) {
-    return nd_array_to_mlx_contiguous<float16_t>(
-        nd_array, shape, dtype.value_or(float16));
+        nd_array, shape, dtype.value_or(mx::int64));
+  } else if (type == nb::dtype<mx::float16_t>()) {
+    return nd_array_to_mlx_contiguous<mx::float16_t>(
+        nd_array, shape, dtype.value_or(mx::float16));
   } else if (type == nb::bfloat16) {
-    return nd_array_to_mlx_contiguous<bfloat16_t>(
-        nd_array, shape, dtype.value_or(bfloat16));
+    return nd_array_to_mlx_contiguous<mx::bfloat16_t>(
+        nd_array, shape, dtype.value_or(mx::bfloat16));
   } else if (type == nb::dtype<float>()) {
     return nd_array_to_mlx_contiguous<float>(
-        nd_array, shape, dtype.value_or(float32));
+        nd_array, shape, dtype.value_or(mx::float32));
   } else if (type == nb::dtype<double>()) {
     return nd_array_to_mlx_contiguous<double>(
-        nd_array, shape, dtype.value_or(float32));
+        nd_array, shape, dtype.value_or(mx::float32));
   } else if (type == nb::dtype<std::complex<float>>()) {
-    return nd_array_to_mlx_contiguous<complex64_t>(
-        nd_array, shape, dtype.value_or(complex64));
+    return nd_array_to_mlx_contiguous<mx::complex64_t>(
+        nd_array, shape, dtype.value_or(mx::complex64));
   } else if (type == nb::dtype<std::complex<double>>()) {
-    return nd_array_to_mlx_contiguous<complex128_t>(
-        nd_array, shape, dtype.value_or(complex64));
+    return nd_array_to_mlx_contiguous<mx::complex128_t>(
+        nd_array, shape, dtype.value_or(mx::complex64));
   } else {
     throw std::invalid_argument("Cannot convert numpy array to mlx array.");
   }
@@ -109,7 +109,7 @@ array nd_array_to_mlx(
 
 template <typename T, typename... NDParams>
 nb::ndarray<NDParams...> mlx_to_nd_array_impl(
-    array a,
+    mx::array a,
     std::optional<nb::dlpack::dtype> t = {}) {
   {
     nb::gil_scoped_release nogil;
@@ -126,48 +126,48 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
 }
 
 template <typename... NDParams>
-nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
+nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
   switch (a.dtype()) {
-    case bool_:
+    case mx::bool_:
      return mlx_to_nd_array_impl<bool, NDParams...>(a);
-    case uint8:
+    case mx::uint8:
      return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
-    case uint16:
+    case mx::uint16:
      return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
-    case uint32:
+    case mx::uint32:
      return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
-    case uint64:
+    case mx::uint64:
      return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
-    case int8:
+    case mx::int8:
      return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
-    case int16:
+    case mx::int16:
      return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
-    case int32:
+    case mx::int32:
      return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
-    case int64:
+    case mx::int64:
      return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
-    case float16:
-      return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
-    case bfloat16:
+    case mx::float16:
+      return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a);
+    case mx::bfloat16:
      throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
-    case float32:
+    case mx::float32:
      return mlx_to_nd_array_impl<float, NDParams...>(a);
-    case complex64:
+    case mx::complex64:
      return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
     default:
      throw nb::type_error("type cannot be converted to NumPy.");
   }
 }
 
-nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
+nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) {
   return mlx_to_nd_array<nb::numpy>(a);
 }
 
-nb::ndarray<> mlx_to_dlpack(const array& a) {
+nb::ndarray<> mlx_to_dlpack(const mx::array& a) {
   return mlx_to_nd_array<>(a);
 }
 
-nb::object to_scalar(array& a) {
+nb::object to_scalar(mx::array& a) {
   if (a.size() != 1) {
     throw std::invalid_argument(
         "[convert] Only length-1 arrays can be converted to Python scalars.");
@@ -177,31 +177,31 @@ nb::object to_scalar(array& a) {
     a.eval();
   }
   switch (a.dtype()) {
-    case bool_:
+    case mx::bool_:
      return nb::cast(a.item<bool>());
-    case uint8:
+    case mx::uint8:
      return nb::cast(a.item<uint8_t>());
-    case uint16:
+    case mx::uint16:
      return nb::cast(a.item<uint16_t>());
-    case uint32:
+    case mx::uint32:
      return nb::cast(a.item<uint32_t>());
-    case uint64:
+    case mx::uint64:
      return nb::cast(a.item<uint64_t>());
-    case int8:
+    case mx::int8:
      return nb::cast(a.item<int8_t>());
-    case int16:
+    case mx::int16:
      return nb::cast(a.item<int16_t>());
-    case int32:
+    case mx::int32:
      return nb::cast(a.item<int32_t>());
-    case int64:
+    case mx::int64:
      return nb::cast(a.item<int64_t>());
-    case float16:
-      return nb::cast(static_cast<float>(a.item<float16_t>()));
-    case float32:
+    case mx::float16:
+      return nb::cast(static_cast<float>(a.item<mx::float16_t>()));
+    case mx::float32:
      return nb::cast(a.item<float>());
-    case bfloat16:
-      return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
-    case complex64:
+    case mx::bfloat16:
+      return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
+    case mx::complex64:
      return nb::cast(a.item<std::complex<float>>());
     default:
      throw nb::type_error("type cannot be converted to Python scalar.");
@@ -209,7 +209,7 @@ nb::object to_scalar(array& a) {
 }
 
 template <typename T, typename U = T>
-nb::list to_list(array& a, size_t index, int dim) {
+nb::list to_list(mx::array& a, size_t index, int dim) {
   nb::list pl;
   auto stride = a.strides()[dim];
   for (int i = 0; i < a.shape(dim); ++i) {
@@ -223,7 +223,7 @@ nb::list to_list(array& a, size_t index, int dim) {
   return pl;
 }
 
-nb::object tolist(array& a) {
+nb::object tolist(mx::array& a) {
   if (a.ndim() == 0) {
     return to_scalar(a);
   }
@@ -232,31 +232,31 @@ nb::object tolist(array& a) {
     a.eval();
   }
   switch (a.dtype()) {
-    case bool_:
+    case mx::bool_:
      return to_list<bool>(a, 0, 0);
-    case uint8:
+    case mx::uint8:
      return to_list<uint8_t>(a, 0, 0);
-    case uint16:
+    case mx::uint16:
      return to_list<uint16_t>(a, 0, 0);
-    case uint32:
+    case mx::uint32:
      return to_list<uint32_t>(a, 0, 0);
-    case uint64:
+    case mx::uint64:
      return to_list<uint64_t>(a, 0, 0);
-    case int8:
+    case mx::int8:
      return to_list<int8_t>(a, 0, 0);
-    case int16:
+    case mx::int16:
      return to_list<int16_t>(a, 0, 0);
-    case int32:
+    case mx::int32:
      return to_list<int32_t>(a, 0, 0);
-    case int64:
+    case mx::int64:
      return to_list<int64_t>(a, 0, 0);
-    case float16:
-      return to_list<float16_t, float>(a, 0, 0);
-    case float32:
+    case mx::float16:
+      return to_list<mx::float16_t, float>(a, 0, 0);
+    case mx::float32:
      return to_list<float>(a, 0, 0);
-    case bfloat16:
-      return to_list<bfloat16_t, float>(a, 0, 0);
-    case complex64:
+    case mx::bfloat16:
+      return to_list<mx::bfloat16_t, float>(a, 0, 0);
+    case mx::complex64:
      return to_list<std::complex<float>>(a, 0, 0);
     default:
      throw nb::type_error("data type cannot be converted to Python list.");
@@ -279,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
 template <typename T>
 PyScalarT validate_shape(
     T list,
-    const Shape& shape,
+    const mx::Shape& shape,
     int idx,
     bool& all_python_primitive_elements) {
   if (idx >= shape.size()) {
@@ -307,9 +307,9 @@ PyScalarT validate_shape(
           shape,
           idx + 1,
           all_python_primitive_elements);
-    } else if (nb::isinstance<array>(l)) {
+    } else if (nb::isinstance<mx::array>(l)) {
       all_python_primitive_elements = false;
-      auto arr = nb::cast<array>(l);
+      auto arr = nb::cast<mx::array>(l);
       if (arr.ndim() + idx + 1 == shape.size() &&
           std::equal(
              arr.shape().cbegin(),
@@ -347,7 +347,7 @@ PyScalarT validate_shape(
 }
 
 template <typename T>
-void get_shape(T list, Shape& shape) {
+void get_shape(T list, mx::Shape& shape) {
   shape.push_back(check_shape_dim(nb::len(list)));
   if (shape.back() > 0) {
     auto l = list.begin();
@@ -355,8 +355,8 @@ void get_shape(T list, Shape& shape) {
       return get_shape(nb::cast<nb::list>(*l), shape);
     } else if (nb::isinstance<nb::tuple>(*l)) {
       return get_shape(nb::cast<nb::tuple>(*l), shape);
-    } else if (nb::isinstance<array>(*l)) {
-      auto arr = nb::cast<array>(*l);
+    } else if (nb::isinstance<mx::array>(*l)) {
+      auto arr = nb::cast<mx::array>(*l);
       for (int i = 0; i < arr.ndim(); i++) {
         shape.push_back(arr.shape(i));
       }
@@ -366,54 +366,55 @@ void get_shape(T list, Shape& shape) {
 }
 
 template <typename T>
-array array_from_list_impl(
+mx::array array_from_list_impl(
     T pl,
     const PyScalarT& inferred_type,
-    std::optional<Dtype> specified_type,
-    const Shape& shape) {
+    std::optional<mx::Dtype> specified_type,
+    const mx::Shape& shape) {
   // Make the array
   switch (inferred_type) {
     case pybool: {
      std::vector<bool> vals;
      fill_vector(pl, vals);
-      return array(vals.begin(), shape, specified_type.value_or(bool_));
+      return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_));
     }
     case pyint: {
-      auto dtype = specified_type.value_or(int32);
-      if (dtype == int64) {
+      auto dtype = specified_type.value_or(mx::int32);
+      if (dtype == mx::int64) {
        std::vector<int64_t> vals;
        fill_vector(pl, vals);
-        return array(vals.begin(), shape, dtype);
-      } else if (dtype == uint64) {
+        return mx::array(vals.begin(), shape, dtype);
+      } else if (dtype == mx::uint64) {
        std::vector<uint64_t> vals;
        fill_vector(pl, vals);
-        return array(vals.begin(), shape, dtype);
-      } else if (dtype == uint32) {
+        return mx::array(vals.begin(), shape, dtype);
+      } else if (dtype == mx::uint32) {
        std::vector<uint32_t> vals;
        fill_vector(pl, vals);
-        return array(vals.begin(), shape, dtype);
-      } else if (issubdtype(dtype, inexact)) {
+        return mx::array(vals.begin(), shape, dtype);
+      } else if (mx::issubdtype(dtype, mx::inexact)) {
        std::vector<float> vals;
        fill_vector(pl, vals);
-        return array(vals.begin(), shape, dtype);
+        return mx::array(vals.begin(), shape, dtype);
      } else {
        std::vector<int> vals;
        fill_vector(pl, vals);
-        return array(vals.begin(), shape, dtype);
      }
    }
+        return mx::array(vals.begin(), shape, dtype);
+      }
+    }
     case pyfloat: {
      std::vector<float> vals;
      fill_vector(pl, vals);
-      return array(vals.begin(), shape, specified_type.value_or(float32));
+      return mx::array(
+          vals.begin(), shape, specified_type.value_or(mx::float32));
     }
     case pycomplex: {
      std::vector<std::complex<float>> vals;
      fill_vector(pl, vals);
-      return array(
-          reinterpret_cast<complex64_t*>(vals.data()),
+      return mx::array(
+          reinterpret_cast<mx::complex64_t*>(vals.data()),
          shape,
-          specified_type.value_or(complex64));
+          specified_type.value_or(mx::complex64));
     }
     default: {
      std::ostringstream msg;
@@ -425,9 +426,9 @@ array array_from_list_impl(
 }
 
 template <typename T>
-array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
+mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
   // Compute the shape
-  Shape shape;
+  mx::Shape shape;
   get_shape(pl, shape);
 
   // Validate the shape and type
@@ -440,30 +441,31 @@ array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
   }
 
   // `pl` contains mlx arrays
-  std::vector<array> arrays;
+  std::vector<mx::array> arrays;
   for (auto l : pl) {
     arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
   }
-  return stack(arrays);
+  return mx::stack(arrays);
 }
 
-array array_from_list(nb::list pl, std::optional<Dtype> dtype) {
+mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) {
   return array_from_list_impl(pl, dtype);
 }
 
-array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) {
+mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
   return array_from_list_impl(pl, dtype);
 }
 
-array create_array(ArrayInitType v, std::optional<Dtype> t) {
+mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
   if (auto pv = std::get_if<nb::bool_>(&v); pv) {
-    return array(nb::cast<bool>(*pv), t.value_or(bool_));
+    return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
   } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
-    return array(nb::cast<int>(*pv), t.value_or(int32));
+    return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
   } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
-    return array(nb::cast<float>(*pv), t.value_or(float32));
+    return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
   } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
-    return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
+    return mx::array(
+        static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
   } else if (auto pv = std::get_if<nb::list>(&v); pv) {
     return array_from_list(*pv, t);
   } else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
@@ -472,10 +474,10 @@ array create_array(ArrayInitType v, std::optional<Dtype> t) {
                  nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
              pv) {
     return nd_array_to_mlx(*pv, t);
-  } else if (auto pv = std::get_if<array>(&v); pv) {
-    return astype(*pv, t.value_or((*pv).dtype()));
+  } else if (auto pv = std::get_if<mx::array>(&v); pv) {
+    return mx::astype(*pv, t.value_or((*pv).dtype()));
   } else {
     auto arr = to_array_with_accessor(std::get<nb::object>(v));
-    return astype(arr, t.value_or(arr.dtype()));
+    return mx::astype(arr, t.value_or(arr.dtype()));
   }
 }