mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 10:18:10 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array or list(array): An array or list of arrays with at least three dimensions.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a,
|
||||
R"pbdoc(
|
||||
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||
of another.
|
||||
|
||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||
True
|
||||
>>> mx.issubdtype(ints.dtype, mx.floating)
|
||||
False
|
||||
|
||||
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
||||
>>> mx.issubdtype(floats.dtype, mx.integer)
|
||||
False
|
||||
>>> mx.issubdtype(floats.dtype, mx.floating)
|
||||
True
|
||||
|
||||
Similar types of different sizes are not subdtypes of each other:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.float32)
|
||||
False
|
||||
>>> mx.issubdtype(mx.float32, mx.float64)
|
||||
False
|
||||
|
||||
but both are subtypes of `floating`:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.floating)
|
||||
True
|
||||
>>> mx.issubdtype(mx.float32, mx.floating)
|
||||
True
|
||||
|
||||
For convenience, dtype-like objects are allowed too:
|
||||
|
||||
>>> mx.issubdtype(mx.float32, mx.inexact)
|
||||
True
|
||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||
False
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
}
|
||||
|
Reference in New Issue
Block a user