diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index b4108657f..57e84c736 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -7,6 +7,7 @@ with a short description of your contribution(s) below. For example: MLX was developed with contributions from the following individuals: +- Manish Aradwad: Added support to pickle `nn.Module`. - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 646f5f2dc..e36d6649b 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -530,3 +530,9 @@ class Module(dict): See :func:`train`. """ self.train(False) + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 28e72a7e7..2928b6855 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -854,6 +854,19 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertTrue(y.shape, x.shape) self.assertTrue(y.dtype, mx.float16) + def test_deepcopy(self): + import copy + + layer = nn.Linear(input_dims=4, output_dims=8) + layer_copy = copy.deepcopy(layer) + + # Verify that the copied layer is not the same object as the original layer + self.assertIsNot(layer_copy, layer) + + # Verify that the copied layer has the same attributes as the original layer + self.assertEqual(layer_copy.input_dims, layer.input_dims) + self.assertEqual(layer_copy.output_dims, layer.output_dims) + if __name__ == "__main__": unittest.main()