diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 646f5f2dc..094c89326 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -54,6 +54,8 @@ class Module(dict): mx.eval(model.parameters()) """ + __call__: Callable + def __init__(self): """Should be called by the subclasses of ``Module``.""" self._no_grad = set()