mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added support for python copy (#335)
* Added support for python copy * precommit changes * removed `_compiled_call_impl` line * added tests and suggested changes * ACK changes
This commit is contained in:
@@ -854,6 +854,19 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_deepcopy(self):
|
||||
import copy
|
||||
|
||||
layer = nn.Linear(input_dims=4, output_dims=8)
|
||||
layer_copy = copy.deepcopy(layer)
|
||||
|
||||
# Verify that the copied layer is not the same object as the original layer
|
||||
self.assertIsNot(layer_copy, layer)
|
||||
|
||||
# Verify that the copied layer has the same attributes as the original layer
|
||||
self.assertEqual(layer_copy.input_dims, layer.input_dims)
|
||||
self.assertEqual(layer_copy.output_dims, layer.output_dims)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user