diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index dcf079457..7c4ccc85e 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -439,3 +439,11 @@ class Module(dict): def eval(self): self.train(False) + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_compiled_call_impl", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) \ No newline at end of file