add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -578,3 +578,26 @@ class Module(dict):
See :func:`train`.
"""
self.train(False)
def set_dtype(
self,
dtype: mx.Dtype,
predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype(
x, mx.floating
),
):
"""Set the dtype of the module's parameters.
Args:
dtype (Dtype): The new dtype.
predicate (typing.Callable, optional): A predicate to select
parameters to cast. By default, only parameters of type
:attr:`floating` will be updated to avoid casting integer
parameters to the new dtype.
"""
if predicate is None:
def predicate(_):
return True
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)