diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 840d3b471..eaf7bcade 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -3,6 +3,7 @@ import os import tempfile import unittest +from pathlib import Path import mlx.core as mx import mlx_tests @@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase): load_arr_mlx_npy = np.load(save_file_mlx) self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) + save_file = os.path.join(self.test_dir, f"mlx_path.npy") + save_arr = mx.ones((32,)) + mx.save(Path(save_file), save_arr) + + # Load array saved by mlx as mlx array + load_arr = mx.load(Path(save_file)) + self.assertTrue(mx.array_equal(load_arr, save_arr)) + def test_save_and_load_safetensors(self): test_file = os.path.join(self.test_dir, "test.safetensors") with self.assertRaises(Exception): mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0}) - mx.save_safetensors( - test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"} - ) - res = mx.load(test_file, return_metadata=True) - self.assertEqual(len(res), 2) - self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) + for obj in [str, Path]: + mx.save_safetensors( + obj(test_file), + {"test": mx.ones((2, 2))}, + {"testing": "test", "format": "mlx"}, + ) + res = mx.load(obj(test_file), return_metadata=True) + self.assertEqual(len(res), 2) + self.assertEqual(res[1], {"testing": "test", "format": "mlx"}) for dt in self.dtypes + ["bfloat16"]: with self.subTest(dtype=dt): @@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase): mx.array_equal(load_dict["test"], save_dict["test"]) ) + save_file_mlx = os.path.join(self.test_dir, f"mlx_path_test_fs.gguf") + save_dict = {"test": mx.ones(shape)} + mx.save_gguf(Path(save_file_mlx), save_dict) + load_dict = mx.load(Path(save_file_mlx)) + self.assertTrue("test" in load_dict) + self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) + def test_load_f8_e4m3(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir)