mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 13:54:44 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||
)
|
||||
|
||||
def test_set_dtype(self):
|
||||
def assert_dtype(layer, dtype):
|
||||
for k, v in tree_flatten(layer.parameters()):
|
||||
self.assertEqual(v.dtype, dtype, f"dtype mismatch for {k}")
|
||||
|
||||
layer = nn.Linear(input_dims=4, output_dims=8, bias=True)
|
||||
assert_dtype(layer, mx.float32)
|
||||
|
||||
layer.set_dtype(mx.bfloat16)
|
||||
assert_dtype(layer, mx.bfloat16)
|
||||
|
||||
layer.set_dtype(mx.float32, lambda x: False)
|
||||
assert_dtype(layer, mx.bfloat16)
|
||||
|
||||
layer.set_dtype(mx.int32, lambda x: True)
|
||||
assert_dtype(layer, mx.int32)
|
||||
|
||||
layer.set_dtype(mx.int64, predicate=None)
|
||||
assert_dtype(layer, mx.int64)
|
||||
|
||||
layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer))
|
||||
assert_dtype(layer, mx.int16)
|
||||
|
||||
def test_rnn(self):
|
||||
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
|
||||
inp = mx.random.normal((2, 25, 5))
|
||||
|
Reference in New Issue
Block a user