diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 7c05906cf..c38fff661 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -84,7 +84,7 @@ class TestLoad(mlx_tests.MLXTestCase): with open(save_file_mlx, "wb") as f: mx.save_safetensor(f, save_dict) with open(save_file_mlx, "rb") as f: - load_dict = mx.load_safetensor(f) + load_dict = mx.load(f) self.assertTrue("test" in load_dict) self.assertTrue(