checkpoint module's __call__

This commit is contained in:
Awni Hannun
2024-03-05 08:39:25 -08:00
parent cbefd9129e
commit 8918a437bb
2 changed files with 26 additions and 12 deletions

View File

@@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
)
class TestNNUtils(mlx_tests.MLXTestCase):
def test_checkpoint(self):
lin = nn.Linear(2, 2)
x = mx.array([0.1, 0.2])
expected_y = lin(x)
y = nn.utils.checkpoint(lin)(x)
self.assertTrue(mx.allclose(expected_y, y))
if __name__ == "__main__":
unittest.main()