docs for checkpoint + a few more tests

This commit is contained in:
Awni Hannun
2024-03-05 15:34:46 -08:00
parent 1368bce280
commit a5827d0384
6 changed files with 38 additions and 8 deletions

View File

@@ -1487,10 +1487,22 @@ class TestNNUtils(mlx_tests.MLXTestCase):
lin = nn.Linear(2, 2)
x = mx.array([0.1, 0.2])
lin.my_attr = "hello"
expected_y = lin(x)
y = nn.utils.checkpoint(lin)(x)
clin = nn.utils.checkpoint(lin)
y = clin(x)
self.assertTrue(mx.allclose(expected_y, y))
# Check get/set attribute
self.assertEqual(clin.my_attr, "hello")
clin.my_attr = "bye"
self.assertEqual(clin.my_attr, "bye")
self.assertTrue(isinstance(clin, nn.Linear))
self.assertEqual(repr(clin), repr(lin))
if __name__ == "__main__":
unittest.main()