diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 7c4ccc85e..c22d0adc2 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -444,6 +444,6 @@ class Module(dict): state = self.__dict__.copy() state.pop("_compiled_call_impl", None) return state - + def __setstate__(self, state): - self.__dict__.update(state) \ No newline at end of file + self.__dict__.update(state)