mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
feat: add support for saving safetensors in the save_weights
(#497)
* feat: add save safetensors support in module save_weights * chore: checking missing changes * Update python/mlx/nn/layers/base.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * chore: update docstring for load_weights --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -54,16 +54,31 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
|
||||
m.apply_to_modules(assert_training)
|
||||
|
||||
def test_io(self):
|
||||
def test_save_npz_weights(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(file)
|
||||
npz_file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(npz_file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(file)
|
||||
m_load.load_weights(npz_file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
def test_save_safetensors_weights(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
safetensors_file = os.path.join(tdir.name, "model.safetensors")
|
||||
m.save_weights(safetensors_file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(safetensors_file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
|
Reference in New Issue
Block a user