mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
added complex64 and python tests
This commit is contained in:
parent
3f82fe8888
commit
29e43170c4
@ -29,6 +29,8 @@ std::string dtype_to_safetensor_str(Dtype t) {
|
|||||||
return ST_U8;
|
return ST_U8;
|
||||||
} else if (t == bool_) {
|
} else if (t == bool_) {
|
||||||
return ST_BOOL;
|
return ST_BOOL;
|
||||||
|
} else if (t == complex64) {
|
||||||
|
return ST_C64;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("[safetensor] unsupported dtype");
|
throw std::runtime_error("[safetensor] unsupported dtype");
|
||||||
}
|
}
|
||||||
@ -59,6 +61,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
|
|||||||
return uint8;
|
return uint8;
|
||||||
} else if (str == ST_BOOL) {
|
} else if (str == ST_BOOL) {
|
||||||
return bool_;
|
return bool_;
|
||||||
|
} else if (str == ST_C64) {
|
||||||
|
return complex64;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("[safetensor] unsupported dtype " + str);
|
throw std::runtime_error("[safetensor] unsupported dtype " + str);
|
||||||
}
|
}
|
||||||
|
@ -25,4 +25,8 @@ namespace mlx::core {
|
|||||||
#define ST_U16 "U16"
|
#define ST_U16 "U16"
|
||||||
#define ST_U32 "U32"
|
#define ST_U32 "U32"
|
||||||
#define ST_U64 "U64"
|
#define ST_U64 "U64"
|
||||||
|
|
||||||
|
// Note: Complex numbers aren't in the spec yet so this could change -
|
||||||
|
// https://github.com/huggingface/safetensors/issues/389
|
||||||
|
#define ST_C64 "C64"
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
|
def test_save_and_load_safetensor(self):
|
||||||
|
if not os.path.isdir(self.test_dir):
|
||||||
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
|
for dt in self.dtypes + ["bfloat16"]:
|
||||||
|
with self.subTest(dtype=dt):
|
||||||
|
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||||
|
with self.subTest(shape=shape):
|
||||||
|
save_file_mlx = os.path.join(
|
||||||
|
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
|
||||||
|
)
|
||||||
|
save_dict = {
|
||||||
|
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||||
|
if dt in ["float32", "float16", "bfloat16"]
|
||||||
|
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(save_file_mlx, "wb") as f:
|
||||||
|
mx.save_safetensor(f, save_dict)
|
||||||
|
with open(save_file_mlx, "rb") as f:
|
||||||
|
load_dict = mx.load_safetensor(f)
|
||||||
|
|
||||||
|
self.assertTrue("test" in load_dict)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||||
|
)
|
||||||
|
|
||||||
def test_save_and_load_fs(self):
|
def test_save_and_load_fs(self):
|
||||||
|
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
|
Loading…
Reference in New Issue
Block a user