add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) {
Returns:
array or list(array): An array or list of arrays with at least three dimensions.
)pbdoc");
m.def(
"issubdtype",
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
""_a,
""_a,
R"pbdoc(
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
of another.
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
>>> mx.issubdtype(ints.dtype, mx.integer)
True
>>> mx.issubdtype(ints.dtype, mx.floating)
False
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
>>> mx.issubdtype(floats.dtype, mx.integer)
False
>>> mx.issubdtype(floats.dtype, mx.floating)
True
Similar types of different sizes are not subdtypes of each other:
>>> mx.issubdtype(mx.float64, mx.float32)
False
>>> mx.issubdtype(mx.float32, mx.float64)
False
but both are subtypes of `floating`:
>>> mx.issubdtype(mx.float64, mx.floating)
True
>>> mx.issubdtype(mx.float32, mx.floating)
True
For convenience, dtype-like objects are allowed too:
>>> mx.issubdtype(mx.float32, mx.inexact)
True
>>> mx.issubdtype(mx.signedinteger, mx.floating)
False
)pbdoc");
m.def(
"issubdtype",
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
""_a,
""_a);
m.def(
"issubdtype",
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
""_a,
""_a);
m.def(
"issubdtype",
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
&issubdtype),
""_a,
""_a);
}