feat: add support for saving safetensors in the save_weights (#497)

* feat: add save safetensors support in module save_weights

* chore: checking missing changes

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Anchen
2024-01-19 06:19:33 -08:00
committed by GitHub
parent c4ec836523
commit f6feb61f92
2 changed files with 37 additions and 8 deletions

View File

@@ -54,16 +54,31 @@ class TestBase(mlx_tests.MLXTestCase):
m.apply_to_modules(assert_training)
def test_io(self):
def test_save_npz_weights(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
file = os.path.join(tdir.name, "model.npz")
m.save_weights(file)
npz_file = os.path.join(tdir.name, "model.npz")
m.save_weights(npz_file)
m_load = make_model()
m_load.load_weights(file)
m_load.load_weights(npz_file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
self.assertTrue(all(tree_flatten(eq_tree)))
def test_save_safetensors_weights(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
safetensors_file = os.path.join(tdir.name, "model.safetensors")
m.save_weights(safetensors_file)
m_load = make_model()
m_load.load_weights(safetensors_file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())