add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
)
def test_set_dtype(self):
def assert_dtype(layer, dtype):
for k, v in tree_flatten(layer.parameters()):
self.assertEqual(v.dtype, dtype, f"dtype mismatch for {k}")
layer = nn.Linear(input_dims=4, output_dims=8, bias=True)
assert_dtype(layer, mx.float32)
layer.set_dtype(mx.bfloat16)
assert_dtype(layer, mx.bfloat16)
layer.set_dtype(mx.float32, lambda x: False)
assert_dtype(layer, mx.bfloat16)
layer.set_dtype(mx.int32, lambda x: True)
assert_dtype(layer, mx.int32)
layer.set_dtype(mx.int64, predicate=None)
assert_dtype(layer, mx.int64)
layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer))
assert_dtype(layer, mx.int16)
def test_rnn(self):
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
inp = mx.random.normal((2, 25, 5))