mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -578,3 +578,26 @@ class Module(dict):
|
||||
See :func:`train`.
|
||||
"""
|
||||
self.train(False)
|
||||
|
||||
def set_dtype(
|
||||
self,
|
||||
dtype: mx.Dtype,
|
||||
predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype(
|
||||
x, mx.floating
|
||||
),
|
||||
):
|
||||
"""Set the dtype of the module's parameters.
|
||||
|
||||
Args:
|
||||
dtype (Dtype): The new dtype.
|
||||
predicate (typing.Callable, optional): A predicate to select
|
||||
parameters to cast. By default, only parameters of type
|
||||
:attr:`floating` will be updated to avoid casting integer
|
||||
parameters to the new dtype.
|
||||
"""
|
||||
if predicate is None:
|
||||
|
||||
def predicate(_):
|
||||
return True
|
||||
|
||||
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||
|
Reference in New Issue
Block a user