# Copyright © 2023 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


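# These tests round-trip arrays through the serialization formats MLX supports
# (.npy, .npz, .safetensors, .gguf) and check interoperability with NumPy
# where the format allows it.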
class TestLoad(mlx_tests.MLXTestCase):
    dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float16",
        "complex64",
    ]

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name
        if not os.path.isdir(cls.test_dir):
            os.mkdir(cls.test_dir)

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

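    # .npy round-trip: arrays saved with mx.save should load back with both
    # mx.load and np.load, and vice versa, for every supported dtype.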
    def test_save_and_load(self):
        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
                        save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        mx.save(save_file_mlx, save_arr_mlx)
                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        load_arr_mlx_mlx = mx.load(save_file_mlx)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        load_arr_npy_mlx = mx.load(save_file_npy)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

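    # .safetensors round-trip, including the optional string-valued metadata
    # returned by mx.load(..., return_metadata=True).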
    def test_save_and_load_safetensors(self):
        test_file = os.path.join(self.test_dir, "test.safetensors")

        # Metadata values must be strings; anything else should raise
        with self.assertRaises(Exception):
            mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
        res = mx.load(test_file, return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})

        for dt in self.dtypes + ["bfloat16"]:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
                        )
                        save_dict = {
                            "test": (
                                mx.random.normal(shape=shape, dtype=getattr(mx, dt))
                                if dt in ["float32", "float16", "bfloat16"]
                                else mx.ones(shape, dtype=getattr(mx, dt))
                            )
                        }

                        with open(save_file_mlx, "wb") as f:
                            mx.save_safetensors(f, save_dict)
                        with open(save_file_mlx, "rb") as f:
                            load_dict = mx.load(f)

                        self.assertTrue("test" in load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

    def test_save_and_load_gguf(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
        for dt in supported_dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
                        )
                        save_dict = {
                            "test": (
                                mx.random.normal(shape=shape, dtype=getattr(mx, dt))
                                if dt in ["float32", "float16", "bfloat16"]
                                else mx.ones(shape, dtype=getattr(mx, dt))
                            )
                        }

                        mx.save_gguf(save_file_mlx, save_dict)
                        load_dict = mx.load(save_file_mlx)

                        self.assertTrue("test" in load_dict)
                        self.assertTrue(
                            mx.array_equal(load_dict["test"], save_dict["test"])
                        )

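    # The safetensors blob below is hand-assembled: an 8-byte little-endian
    # header length (0x48 = 72), a JSON header describing a 10-element
    # F8_E4M3 tensor (space-padded to the declared length), then the 10 raw
    # FP8 payload bytes.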
    def test_load_f8_e4m3(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        expected = [
            0,
            mx.nan,
            mx.nan,
            -0.875,
            0.4375,
            -0.005859,
            -1.25,
            -1.25,
            -1.5,
            -0.0039,
        ]
        expected = mx.array(expected, dtype=mx.bfloat16)
        contents = b'H\x00\x00\x00\x00\x00\x00\x00{"tensor":{"dtype":"F8_E4M3","shape":[10],"data_offsets":[0,10]}}       \x00\x7f\xff\xb6.\x83\xba\xba\xbc\x82'
        with tempfile.NamedTemporaryFile(suffix=".safetensors") as f:
            f.write(contents)
            f.seek(0)
            out = mx.load(f)["tensor"]
            self.assertTrue(mx.allclose(out, expected, equal_nan=True))

    def test_save_and_load_gguf_metadata_basic(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
        metadata = {}

        # Empty works
        mx.save_gguf(save_file_mlx, save_dict, metadata)

        # Loads without the metadata
        load_dict = mx.load(save_file_mlx)
        self.assertTrue("test" in load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))

        # Loads empty metadata
        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertTrue("test" in load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
        self.assertEqual(len(meta_load_dict), 0)

        # Loads string metadata
        metadata = {"meta": "data"}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertTrue("test" in load_dict)
        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
        self.assertEqual(len(meta_load_dict), 1)
        self.assertTrue("meta" in meta_load_dict)
        self.assertEqual(meta_load_dict["meta"], "data")

    def test_save_and_load_gguf_metadata_arrays(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}

        # Test scalars and one dimensional arrays
        for t in [
            mx.uint8,
            mx.int8,
            mx.uint16,
            mx.int16,
            mx.uint32,
            mx.int32,
            mx.uint64,
            mx.int64,
            mx.float32,
        ]:
            for shape in [(), (2,)]:
                arr = mx.random.uniform(shape=shape).astype(t)
                metadata = {"meta": arr}
                mx.save_gguf(save_file_mlx, save_dict, metadata)
                _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
                self.assertEqual(len(meta_load_dict), 1)
                self.assertTrue("meta" in meta_load_dict)
                self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr))
                self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype)

        # Unsupported metadata dtypes should raise
        for t in [mx.float16, mx.bfloat16, mx.complex64]:
            with self.assertRaises(ValueError):
                arr = mx.array(1, t)
                metadata = {"meta": arr}
                mx.save_gguf(save_file_mlx, save_dict, metadata)

    def test_save_and_load_gguf_metadata_mixed(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
        save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}

        # Test string and array
        arr = mx.array(1.5)
        metadata = {"meta1": arr, "meta2": "data"}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 2)
        self.assertTrue("meta1" in meta_load_dict)
        self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr))
        self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype)
        self.assertTrue("meta2" in meta_load_dict)
        self.assertEqual(meta_load_dict["meta2"], "data")

        # Test list of strings
        metadata = {"meta": ["data1", "data2", "data345"]}
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 1)
        self.assertEqual(meta_load_dict["meta"], metadata["meta"])

        # Test a combination of stuff
        metadata = {
            "meta1": ["data1", "data2", "data345"],
            "meta2": mx.array([1, 2, 3, 4]),
            "meta3": "data",
            "meta4": mx.array(1.5),
        }
        mx.save_gguf(save_file_mlx, save_dict, metadata)
        _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
        self.assertEqual(len(meta_load_dict), 4)
        for k, v in metadata.items():
            if isinstance(v, mx.array):
                self.assertTrue(mx.array_equal(meta_load_dict[k], v))
            else:
                self.assertEqual(meta_load_dict[k], v)

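    # Same .npy round-trips as test_save_and_load, but through open file
    # objects instead of file paths.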
    def test_save_and_load_fs(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.npy"
                        )
                        save_file_npy = os.path.join(
                            self.test_dir, f"npy_{dt}_{i}_fs.npy"
                        )

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        with open(save_file_mlx, "wb") as f:
                            mx.save(f, save_arr_mlx)

                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        with open(save_file_mlx, "rb") as f:
                            load_arr_mlx_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        with open(save_file_npy, "rb") as f:
                            load_arr_npy_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_savez_and_loadz(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
                save_file_mlx_uncomp = os.path.join(
                    self.test_dir, f"mlx_{dt}_uncomp.npz"
                )
                save_file_npy_uncomp = os.path.join(
                    self.test_dir, f"npy_{dt}_uncomp.npz"
                )
                save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
                save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")

                # Make dictionary of multiple arrays
                save_arrs_npy = {
                    f"save_arr_{i}": np.random.uniform(
                        0.0, 32.0, size=shapes[i]
                    ).astype(getattr(np, dt))
                    for i in range(len(shapes))
                }
                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}

                # Save as npz files
                np.savez(save_file_npy_uncomp, **save_arrs_npy)
                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
                np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)

                for save_file_npy, save_file_mlx in (
                    (save_file_npy_uncomp, save_file_mlx_uncomp),
                    (save_file_npy_comp, save_file_mlx_comp),
                ):
                    # Load array saved by mlx as mlx array
                    load_arr_mlx_mlx = mx.load(save_file_mlx)
                    for k, v in load_arr_mlx_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by numpy as mlx arrays
                    load_arr_npy_mlx = mx.load(save_file_npy)
                    for k, v in load_arr_npy_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load array saved by mlx as numpy array
                    load_arr_mlx_npy = np.load(save_file_mlx)
                    for k, v in load_arr_mlx_npy.items():
                        self.assertTrue(np.array_equal(save_arrs_npy[k], v))

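    # Arrays that are not row-contiguous (broadcast or transposed views)
    # should still serialize correctly in every format.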
    def test_non_contiguous(self):
        a = mx.broadcast_to(mx.array([1, 2]), [4, 2])

        save_file = os.path.join(self.test_dir, "a.npy")
        mx.save(save_file, a)
        aload = mx.load(save_file)
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.safetensors")
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        # safetensors and gguf only work with row contiguous
        # make sure col contiguous is handled properly
        save_file = os.path.join(self.test_dir, "a.safetensors")
        a = mx.arange(4).reshape(2, 2).T
        mx.save_safetensors(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

        save_file = os.path.join(self.test_dir, "a.gguf")
        mx.save_gguf(save_file, {"a": a})
        aload = mx.load(save_file)["a"]
        self.assertTrue(mx.array_equal(a, aload))

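    # The buffer read by mx.load should be donated to the following binary
    # op, so peak memory is the same with and without the extra multiply.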
    def test_load_donation(self):
        x = mx.random.normal((1024,))
        mx.eval(x)
        save_file = os.path.join(self.test_dir, "donation.npy")
        mx.save(save_file, x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.load(save_file)
        mx.eval(y)
        mx.synchronize()
        load_only = mx.get_peak_memory()
        y = mx.load(save_file) * scale
        mx.eval(y)
        mx.synchronize()
        load_with_binary = mx.get_peak_memory()

        self.assertEqual(load_only, load_with_binary)


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()