added tests and suggested changes

This commit is contained in:
Manish Aradwad 2024-01-02 17:43:17 +00:00
parent 74328a0938
commit 19c0bed4e2
3 changed files with 15 additions and 2 deletions

View File

@ -7,6 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals: MLX was developed with contributions from the following individuals:
- Manish Aradwad: Added support for deepcopy.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
- Juarez Bochi: Fixed bug in cross attention. - Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.

View File

@ -441,8 +441,7 @@ class Module(dict):
self.train(False) self.train(False)
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() return self.__dict__.copy()
return state
def __setstate__(self, state): def __setstate__(self, state):
self.__dict__.update(state) self.__dict__.update(state)

View File

@ -840,6 +840,19 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertTrue(y.shape, x.shape) self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16) self.assertTrue(y.dtype, mx.float16)
def test_deepcopy(self):
import copy
layer = nn.Linear(input_dims=4, output_dims=8)
layer_copy = copy.deepcopy(layer)
# Verify that the copied layer is not the same object as the original layer
self.assertIsNot(layer_copy, layer)
# Verify that the copied layer has the same attributes as the original layer
self.assertEqual(layer_copy.input_dims, layer.input_dims)
self.assertEqual(layer_copy.output_dims, layer.output_dims)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()