feat: add support for saving safetensors in save_weights (#497)

* feat: add save safetensors support in module save_weights

* chore: checking missing changes

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Anchen 2024-01-19 06:19:33 -08:00 committed by GitHub
parent c4ec836523
commit f6feb61f92
2 changed files with 37 additions and 8 deletions

python/mlx/nn/layers/base.py

@@ -96,11 +96,11 @@ class Module(dict):
         strict: bool = True,
     ):
         """
-        Update the model's weights from a ``.npz`` or a list.
+        Update the model's weights from a ``.npz`` file, a ``.safetensors`` file, or a list.

         Args:
             file_or_weights (str or list(tuple(str, mx.array))): The path to
-                the weights ``.npz`` file or a list of pairs of parameter names
+                the weights file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
                 and arrays.
             strict (bool, optional): If ``True`` then checks that the provided
                 weights exactly match the parameters of the model. Otherwise,
@@ -118,6 +118,9 @@ class Module(dict):
             # Load from file
             model.load_weights("weights.npz")

+            # Load from .safetensors file
+            model.load_weights("weights.safetensors")
+
             # Load from list
             weights = [
                 ("weight", mx.random.uniform(shape=(10, 10))),
@@ -166,9 +169,20 @@
     def save_weights(self, file: str):
         """
-        Save the model's weights to a ``.npz`` file.
+        Save the model's weights to a file. The saving method is determined by the file extension:
+        - ``.npz`` will use :func:`mx.savez`
+        - ``.safetensors`` will use :func:`mx.save_safetensors`
         """
-        mx.savez(file, **dict(tree_flatten(self.parameters())))
+        params_dict = dict(tree_flatten(self.parameters()))
+
+        if file.endswith(".npz"):
+            mx.savez(file, **params_dict)
+        elif file.endswith(".safetensors"):
+            mx.save_safetensors(file, params_dict)
+        else:
+            raise ValueError(
+                "Unsupported file extension. Use '.npz' or '.safetensors'."
+            )

     @staticmethod
     def is_module(value):

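For reference, a minimal usage sketch of the updated API (the model definition mirrors the tests below; file names are illustrative, and the standard mlx.nn import is assumed):

import mlx.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

# The file extension picks the serializer:
#   .npz         -> mx.savez
#   .safetensors -> mx.save_safetensors
model.save_weights("model.npz")
model.save_weights("model.safetensors")

# load_weights accepts either format (or a list of (name, array) pairs)
model.load_weights("model.safetensors")

# Any other extension raises a ValueError
try:
    model.save_weights("model.bin")
except ValueError as e:
    print(e)  # Unsupported file extension. Use '.npz' or '.safetensors'.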

@@ -54,16 +54,31 @@ class TestBase(mlx_tests.MLXTestCase):
             m.apply_to_modules(assert_training)

-    def test_io(self):
+    def test_save_npz_weights(self):
         def make_model():
             return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

         m = make_model()
         tdir = tempfile.TemporaryDirectory()
-        file = os.path.join(tdir.name, "model.npz")
-        m.save_weights(file)
+        npz_file = os.path.join(tdir.name, "model.npz")
+        m.save_weights(npz_file)
         m_load = make_model()
-        m_load.load_weights(file)
+        m_load.load_weights(npz_file)
         tdir.cleanup()
         eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
         self.assertTrue(all(tree_flatten(eq_tree)))

+    def test_save_safetensors_weights(self):
+        def make_model():
+            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
+
+        m = make_model()
+        tdir = tempfile.TemporaryDirectory()
+        safetensors_file = os.path.join(tdir.name, "model.safetensors")
+        m.save_weights(safetensors_file)
+        m_load = make_model()
+        m_load.load_weights(safetensors_file)
+        tdir.cleanup()
+        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())