diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 094c89326..3da1993ec 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -96,11 +96,11 @@ class Module(dict): strict: bool = True, ): """ - Update the model's weights from a ``.npz`` or a list. + Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list. Args: file_or_weights (str or list(tuple(str, mx.array))): The path to - the weights ``.npz`` file or a list of pairs of parameter names + the weights file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names and arrays. strict (bool, optional): If ``True`` then checks that the provided weights exactly match the parameters of the model. Otherwise, @@ -118,6 +118,9 @@ class Module(dict): # Load from file model.load_weights("weights.npz") + # Load from .safetensors file + model.load_weights("weights.safetensors") + # Load from list weights = [ ("weight", mx.random.uniform(shape=(10, 10))), @@ -166,9 +169,20 @@ class Module(dict): def save_weights(self, file: str): """ - Save the model's weights to a ``.npz`` file. + Save the model's weights to a file. The saving method is determined by the file extension: + - ``.npz`` will use :func:`mx.savez` + - ``.safetensors`` will use :func:`mx.save_safetensors` """ - mx.savez(file, **dict(tree_flatten(self.parameters()))) + params_dict = dict(tree_flatten(self.parameters())) + + if file.endswith(".npz"): + mx.savez(file, **params_dict) + elif file.endswith(".safetensors"): + mx.save_safetensors(file, params_dict) + else: + raise ValueError( + "Unsupported file extension. Use '.npz' or '.safetensors'." 
+ ) @staticmethod def is_module(value): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9ae8a2cd1..2893af8e7 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -54,16 +54,31 @@ class TestBase(mlx_tests.MLXTestCase): m.apply_to_modules(assert_training) - def test_io(self): + def test_save_npz_weights(self): def make_model(): return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) m = make_model() tdir = tempfile.TemporaryDirectory() - file = os.path.join(tdir.name, "model.npz") - m.save_weights(file) + npz_file = os.path.join(tdir.name, "model.npz") + m.save_weights(npz_file) m_load = make_model() - m_load.load_weights(file) + m_load.load_weights(npz_file) + tdir.cleanup() + + eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) + self.assertTrue(all(tree_flatten(eq_tree))) + + def test_save_safetensors_weights(self): + def make_model(): + return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) + + m = make_model() + tdir = tempfile.TemporaryDirectory() + safetensors_file = os.path.join(tdir.name, "model.safetensors") + m.save_weights(safetensors_file) + m_load = make_model() + m_load.load_weights(safetensors_file) tdir.cleanup() eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())