mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 04:35:36 +08:00
add load with path tests
This commit is contained in:
parent
3dcb286baf
commit
c2511fd83a
@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
|
save_file = os.path.join(self.test_dir, f"mlx_path.npy")
|
||||||
|
save_arr = mx.ones((32,))
|
||||||
|
mx.save(Path(save_file), save_arr)
|
||||||
|
|
||||||
|
# Load array saved by mlx as mlx array
|
||||||
|
load_arr = mx.load(Path(save_file))
|
||||||
|
self.assertTrue(mx.array_equal(load_arr, save_arr))
|
||||||
|
|
||||||
def test_save_and_load_safetensors(self):
|
def test_save_and_load_safetensors(self):
|
||||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||||
|
|
||||||
mx.save_safetensors(
|
for obj in [str, Path]:
|
||||||
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
mx.save_safetensors(
|
||||||
)
|
obj(test_file),
|
||||||
res = mx.load(test_file, return_metadata=True)
|
{"test": mx.ones((2, 2))},
|
||||||
self.assertEqual(len(res), 2)
|
{"testing": "test", "format": "mlx"},
|
||||||
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
)
|
||||||
|
res = mx.load(obj(test_file), return_metadata=True)
|
||||||
|
self.assertEqual(len(res), 2)
|
||||||
|
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||||
|
|
||||||
for dt in self.dtypes + ["bfloat16"]:
|
for dt in self.dtypes + ["bfloat16"]:
|
||||||
with self.subTest(dtype=dt):
|
with self.subTest(dtype=dt):
|
||||||
@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
save_file_mlx = os.path.join(self.test_dir, f"mlx_path_test_fs.gguf")
|
||||||
|
save_dict = {"test": mx.ones(shape)}
|
||||||
|
mx.save_gguf(Path(save_file_mlx), save_dict)
|
||||||
|
load_dict = mx.load(Path(save_file_mlx))
|
||||||
|
self.assertTrue("test" in load_dict)
|
||||||
|
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||||
|
|
||||||
def test_load_f8_e4m3(self):
|
def test_load_f8_e4m3(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
Loading…
Reference in New Issue
Block a user