added complex64 and python tests

This commit is contained in:
dc-dc-dc 2023-12-20 11:58:37 -05:00
parent 3f82fe8888
commit 29e43170c4
3 changed files with 35 additions and 0 deletions

View File

@ -29,6 +29,8 @@ std::string dtype_to_safetensor_str(Dtype t) {
return ST_U8;
} else if (t == bool_) {
return ST_BOOL;
} else if (t == complex64) {
return ST_C64;
} else {
throw std::runtime_error("[safetensor] unsupported dtype");
}
@ -59,6 +61,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
return uint8;
} else if (str == ST_BOOL) {
return bool_;
} else if (str == ST_C64) {
return complex64;
} else {
throw std::runtime_error("[safetensor] unsupported dtype " + str);
}

View File

@ -25,4 +25,8 @@ namespace mlx::core {
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"
// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core

View File

@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase):
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_save_and_load_safetensor(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
}
with open(save_file_mlx, "wb") as f:
mx.save_safetensor(f, save_dict)
with open(save_file_mlx, "rb") as f:
load_dict = mx.load_safetensor(f)
self.assertTrue("test" in load_dict)
self.assertTrue(
mx.array_equal(load_dict["test"], save_dict["test"])
)
def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir):