mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -254,7 +254,7 @@ array array_from_list(
|
||||
std::vector<uint32_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (is_floating_point(dtype)) {
|
||||
} else if (issubdtype(dtype, inexact)) {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
@@ -439,6 +439,54 @@ void init_array(nb::module_& m) {
|
||||
m.attr("float32") = nb::cast(float32);
|
||||
m.attr("bfloat16") = nb::cast(bfloat16);
|
||||
m.attr("complex64") = nb::cast(complex64);
|
||||
nb::class_<Dtype::Category>(
|
||||
m,
|
||||
"DtypeCategory",
|
||||
R"pbdoc(
|
||||
Type to hold categories of :class:`dtypes <Dtype>`.
|
||||
|
||||
* :attr:`~mlx.core.generic`
|
||||
|
||||
* :ref:`bool_ <data_types>`
|
||||
* :attr:`~mlx.core.number`
|
||||
|
||||
* :attr:`~mlx.core.integer`
|
||||
|
||||
* :attr:`~mlx.core.unsignedinteger`
|
||||
|
||||
* :ref:`uint8 <data_types>`
|
||||
* :ref:`uint16 <data_types>`
|
||||
* :ref:`uint32 <data_types>`
|
||||
* :ref:`uint64 <data_types>`
|
||||
|
||||
* :attr:`~mlx.core.signedinteger`
|
||||
|
||||
* :ref:`int8 <data_types>`
|
||||
* :ref:`int32 <data_types>`
|
||||
* :ref:`int64 <data_types>`
|
||||
|
||||
* :attr:`~mlx.core.inexact`
|
||||
|
||||
* :attr:`~mlx.core.floating`
|
||||
|
||||
* :ref:`float16 <data_types>`
|
||||
* :ref:`bfloat16 <data_types>`
|
||||
* :ref:`float32 <data_types>`
|
||||
|
||||
* :attr:`~mlx.core.complexfloating`
|
||||
|
||||
* :ref:`complex128 <data_types>`
|
||||
|
||||
See also :func:`~mlx.core.issubdtype`.
|
||||
)pbdoc");
|
||||
m.attr("complexfloating") = nb::cast(complexfloating);
|
||||
m.attr("floating") = nb::cast(floating);
|
||||
m.attr("inexact") = nb::cast(inexact);
|
||||
m.attr("signedinteger") = nb::cast(signedinteger);
|
||||
m.attr("unsignedinteger") = nb::cast(unsignedinteger);
|
||||
m.attr("integer") = nb::cast(integer);
|
||||
m.attr("number") = nb::cast(number);
|
||||
m.attr("generic") = nb::cast(generic);
|
||||
|
||||
nb::class_<ArrayAt>(
|
||||
m,
|
||||
@@ -700,7 +748,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__itruediv__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_floating_point(a.dtype())) {
|
||||
if (!issubdtype(a.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"In place division cannot cast to non-floating point type.");
|
||||
}
|
||||
@@ -852,7 +900,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__invert__",
|
||||
[](const array& a) {
|
||||
if (is_floating_point(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise inversion.");
|
||||
}
|
||||
@@ -866,7 +914,8 @@ void init_array(nb::module_& m) {
|
||||
"__and__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
@@ -881,7 +930,8 @@ void init_array(nb::module_& m) {
|
||||
"__iand__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
@@ -898,7 +948,8 @@ void init_array(nb::module_& m) {
|
||||
"__or__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
@@ -913,7 +964,8 @@ void init_array(nb::module_& m) {
|
||||
"__ior__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
|
@@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array or list(array): An array or list of arrays with at least three dimensions.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a,
|
||||
R"pbdoc(
|
||||
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||
of another.
|
||||
|
||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||
True
|
||||
>>> mx.issubdtype(ints.dtype, mx.floating)
|
||||
False
|
||||
|
||||
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
||||
>>> mx.issubdtype(floats.dtype, mx.integer)
|
||||
False
|
||||
>>> mx.issubdtype(floats.dtype, mx.floating)
|
||||
True
|
||||
|
||||
Similar types of different sizes are not subdtypes of each other:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.float32)
|
||||
False
|
||||
>>> mx.issubdtype(mx.float32, mx.float64)
|
||||
False
|
||||
|
||||
but both are subtypes of `floating`:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.floating)
|
||||
True
|
||||
>>> mx.issubdtype(mx.float32, mx.floating)
|
||||
True
|
||||
|
||||
For convenience, dtype-like objects are allowed too:
|
||||
|
||||
>>> mx.issubdtype(mx.float32, mx.inexact)
|
||||
True
|
||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||
False
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
}
|
||||
|
@@ -56,7 +56,7 @@ inline array to_array(
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(float32);
|
||||
return array(
|
||||
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
|
||||
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), complex64);
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user