add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
def test_issubdtype(self):
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
cats = [
"complexfloating",
"floating",
"inexact",
"signedinteger",
"unsignedinteger",
"integer",
"number",
"generic",
"bool_",
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"complex64",
]
for a in cats:
for b in cats:
self.assertEqual(
mx.issubdtype(getattr(mx, a), getattr(mx, b)),
np.issubdtype(getattr(np, a), getattr(np, b)),
f"mx and np don't aggree on {a}, {b}",
)
if __name__ == "__main__":
unittest.main()