mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Export / import functions to / from a file (#1642)
* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplificatoin * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
This commit is contained in:
@@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_and_load(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
for dt in self.dtypes:
|
||||
with self.subTest(dtype=dt):
|
||||
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
|
||||
@@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||
|
||||
def test_save_and_load_safetensors(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||
with self.assertRaises(Exception):
|
||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
@@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
|
||||
|
||||
def test_non_contiguous(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
|
||||
|
||||
save_file = os.path.join(self.test_dir, "a.npy")
|
||||
|
Reference in New Issue
Block a user