mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
|
||||
cats = [
|
||||
"complexfloating",
|
||||
"floating",
|
||||
"inexact",
|
||||
"signedinteger",
|
||||
"unsignedinteger",
|
||||
"integer",
|
||||
"number",
|
||||
"generic",
|
||||
"bool_",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"float16",
|
||||
"float32",
|
||||
"complex64",
|
||||
]
|
||||
|
||||
for a in cats:
|
||||
for b in cats:
|
||||
self.assertEqual(
|
||||
mx.issubdtype(getattr(mx, a), getattr(mx, b)),
|
||||
np.issubdtype(getattr(np, a), getattr(np, b)),
|
||||
f"mx and np don't aggree on {a}, {b}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user