diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 44cbdaf1e5..445ef3df49 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -95,6 +95,9 @@ class TestBase(mlx_tests.MLXTestCase): m.save_weights(npz_file) m_load = make_model() m_load.load_weights(npz_file) + + # Eval before cleanup so model file is unlocked. + mx.eval(m_load.state) tdir.cleanup() eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) @@ -110,6 +113,9 @@ class TestBase(mlx_tests.MLXTestCase): m.save_weights(safetensors_file) m_load = make_model() m_load.load_weights(safetensors_file) + + # Eval before cleanup so model file is unlocked. + mx.eval(m_load.state) tdir.cleanup() eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())