mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||
)
|
||||
|
||||
def test_set_dtype(self):
|
||||
def assert_dtype(layer, dtype):
|
||||
for k, v in tree_flatten(layer.parameters()):
|
||||
self.assertEqual(v.dtype, dtype, f"dtype mismatch for {k}")
|
||||
|
||||
layer = nn.Linear(input_dims=4, output_dims=8, bias=True)
|
||||
assert_dtype(layer, mx.float32)
|
||||
|
||||
layer.set_dtype(mx.bfloat16)
|
||||
assert_dtype(layer, mx.bfloat16)
|
||||
|
||||
layer.set_dtype(mx.float32, lambda x: False)
|
||||
assert_dtype(layer, mx.bfloat16)
|
||||
|
||||
layer.set_dtype(mx.int32, lambda x: True)
|
||||
assert_dtype(layer, mx.int32)
|
||||
|
||||
layer.set_dtype(mx.int64, predicate=None)
|
||||
assert_dtype(layer, mx.int64)
|
||||
|
||||
layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer))
|
||||
assert_dtype(layer, mx.int16)
|
||||
|
||||
def test_rnn(self):
|
||||
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
|
||||
inp = mx.random.normal((2, 25, 5))
|
||||
|
@@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
|
||||
cats = [
|
||||
"complexfloating",
|
||||
"floating",
|
||||
"inexact",
|
||||
"signedinteger",
|
||||
"unsignedinteger",
|
||||
"integer",
|
||||
"number",
|
||||
"generic",
|
||||
"bool_",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"float16",
|
||||
"float32",
|
||||
"complex64",
|
||||
]
|
||||
|
||||
for a in cats:
|
||||
for b in cats:
|
||||
self.assertEqual(
|
||||
mx.issubdtype(getattr(mx, a), getattr(mx, b)),
|
||||
np.issubdtype(getattr(np, a), getattr(np, b)),
|
||||
f"mx and np don't aggree on {a}, {b}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user