From 29e43170c49f6eb34a0d4b9c67a58bf015e8f8df Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Wed, 20 Dec 2023 11:58:37 -0500 Subject: [PATCH] added complex64 and python tests --- mlx/safetensor.cpp | 4 ++++ mlx/safetensor.h | 4 ++++ python/tests/test_load.py | 27 +++++++++++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 37d7ceb22..4647fa37f 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -29,6 +29,8 @@ std::string dtype_to_safetensor_str(Dtype t) { return ST_U8; } else if (t == bool_) { return ST_BOOL; + } else if (t == complex64) { + return ST_C64; } else { throw std::runtime_error("[safetensor] unsupported dtype"); } @@ -59,6 +61,8 @@ Dtype dtype_from_safetensor_str(std::string str) { return uint8; } else if (str == ST_BOOL) { return bool_; + } else if (str == ST_C64) { + return complex64; } else { throw std::runtime_error("[safetensor] unsupported dtype " + str); } diff --git a/mlx/safetensor.h b/mlx/safetensor.h index 5a6fa24b8..6bec469d3 100644 --- a/mlx/safetensor.h +++ b/mlx/safetensor.h @@ -25,4 +25,8 @@ namespace mlx::core { #define ST_U16 "U16" #define ST_U32 "U32" #define ST_U64 "U64" + +// Note: Complex numbers aren't in the spec yet so this could change - +// https://github.com/huggingface/safetensors/issues/389 +#define ST_C64 "C64" } // namespace mlx::core diff --git a/python/tests/test_load.py b/python/tests/test_load.py index e63588d03..7c05906cf 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase): load_arr_mlx_npy = np.load(save_file_mlx) self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy)) + def test_save_and_load_safetensor(self): + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + for dt in self.dtypes + ["bfloat16"]: + with self.subTest(dtype=dt): + for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]): + with self.subTest(shape=shape): + save_file_mlx = os.path.join( + self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" + ) + save_dict = { + "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + } + + with open(save_file_mlx, "wb") as f: + mx.save_safetensor(f, save_dict) + with open(save_file_mlx, "rb") as f: + load_dict = mx.load_safetensor(f) + + self.assertTrue("test" in load_dict) + self.assertTrue( + mx.array_equal(load_dict["test"], save_dict["test"]) + ) + def test_save_and_load_fs(self): if not os.path.isdir(self.test_dir):