From f6feb61f92c52452bbdd2cb6dbd34653f2f8b0d3 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 19 Jan 2024 06:19:33 -0800
Subject: [PATCH] feat: add support for saving safetensors in the
`save_weights` (#497)
* feat: add save safetensors support in module save_weights
* chore: checking missing changes
* Update python/mlx/nn/layers/base.py
Co-authored-by: Awni Hannun
* chore: update docstring for load_weights
---------
Co-authored-by: Awni Hannun
---
python/mlx/nn/layers/base.py | 22 ++++++++++++++++++----
python/tests/test_nn.py | 23 +++++++++++++++++++----
2 files changed, 37 insertions(+), 8 deletions(-)
diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py
index 094c89326..3da1993ec 100644
--- a/python/mlx/nn/layers/base.py
+++ b/python/mlx/nn/layers/base.py
@@ -96,11 +96,11 @@ class Module(dict):
strict: bool = True,
):
"""
- Update the model's weights from a ``.npz`` or a list.
+ Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
Args:
file_or_weights (str or list(tuple(str, mx.array))): The path to
- the weights ``.npz`` file or a list of pairs of parameter names
+ the weights file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
and arrays.
strict (bool, optional): If ``True`` then checks that the provided
weights exactly match the parameters of the model. Otherwise,
@@ -118,6 +118,9 @@ class Module(dict):
# Load from file
model.load_weights("weights.npz")
+ # Load from .safetensors file
+ model.load_weights("weights.safetensors")
+
# Load from list
weights = [
("weight", mx.random.uniform(shape=(10, 10))),
@@ -166,9 +169,20 @@ class Module(dict):
def save_weights(self, file: str):
"""
- Save the model's weights to a ``.npz`` file.
+ Save the model's weights to a file. The saving method is determined by the file extension:
+ - ``.npz`` will use :func:`mx.savez`
+ - ``.safetensors`` will use :func:`mx.save_safetensors`
"""
- mx.savez(file, **dict(tree_flatten(self.parameters())))
+ params_dict = dict(tree_flatten(self.parameters()))
+
+ if file.endswith(".npz"):
+ mx.savez(file, **params_dict)
+ elif file.endswith(".safetensors"):
+ mx.save_safetensors(file, params_dict)
+ else:
+ raise ValueError(
+ "Unsupported file extension. Use '.npz' or '.safetensors'."
+ )
@staticmethod
def is_module(value):
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 9ae8a2cd1..2893af8e7 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -54,16 +54,31 @@ class TestBase(mlx_tests.MLXTestCase):
m.apply_to_modules(assert_training)
- def test_io(self):
+ def test_save_npz_weights(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
- file = os.path.join(tdir.name, "model.npz")
- m.save_weights(file)
+ npz_file = os.path.join(tdir.name, "model.npz")
+ m.save_weights(npz_file)
m_load = make_model()
- m_load.load_weights(file)
+ m_load.load_weights(npz_file)
+ tdir.cleanup()
+
+ eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
+ self.assertTrue(all(tree_flatten(eq_tree)))
+
+ def test_save_safetensors_weights(self):
+ def make_model():
+ return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
+
+ m = make_model()
+ tdir = tempfile.TemporaryDirectory()
+ safetensors_file = os.path.join(tdir.name, "model.safetensors")
+ m.save_weights(safetensors_file)
+ m_load = make_model()
+ m_load.load_weights(safetensors_file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())