Added support for python copy (#335)

* Added support for python copy

* precommit changes

* removed `_compiled_call_impl` line

* added tests and suggested changes

* ACK changes
This commit is contained in:
toji
2024-01-04 10:29:40 +05:30
committed by GitHub
parent 0d31128a44
commit d2467c320d
3 changed files with 20 additions and 0 deletions

View File

@@ -854,6 +854,19 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
def test_deepcopy(self):
import copy
layer = nn.Linear(input_dims=4, output_dims=8)
layer_copy = copy.deepcopy(layer)
# Verify that the copied layer is not the same object as the original layer
self.assertIsNot(layer_copy, layer)
# Verify that the copied layer has the same attributes as the original layer
self.assertEqual(layer_copy.input_dims, layer.input_dims)
self.assertEqual(layer_copy.output_dims, layer.output_dims)
if __name__ == "__main__":
unittest.main()