mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
feat: add support for saving safetensors in the save_weights
(#497)
* feat: add save safetensors support in module save_weights
* chore: checking missing changes
* Update python/mlx/nn/layers/base.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* chore: update docstring for load_weights
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
c4ec836523
commit
f6feb61f92
@ -96,11 +96,11 @@ class Module(dict):
|
||||
strict: bool = True,
|
||||
):
|
||||
"""
|
||||
Update the model's weights from a ``.npz`` or a list.
|
||||
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
|
||||
|
||||
Args:
|
||||
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
||||
the weights ``.npz`` file or a list of pairs of parameter names
|
||||
the weights file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
|
||||
and arrays.
|
||||
strict (bool, optional): If ``True`` then checks that the provided
|
||||
weights exactly match the parameters of the model. Otherwise,
|
||||
@ -118,6 +118,9 @@ class Module(dict):
|
||||
# Load from file
|
||||
model.load_weights("weights.npz")
|
||||
|
||||
# Load from .safetensors file
|
||||
model.load_weights("weights.safetensors")
|
||||
|
||||
# Load from list
|
||||
weights = [
|
||||
("weight", mx.random.uniform(shape=(10, 10))),
|
||||
@ -166,9 +169,20 @@ class Module(dict):
|
||||
|
||||
def save_weights(self, file: str):
|
||||
"""
|
||||
Save the model's weights to a ``.npz`` file.
|
||||
Save the model's weights to a file. The saving method is determined by the file extension:
|
||||
- ``.npz`` will use :func:`mx.savez`
|
||||
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
||||
"""
|
||||
mx.savez(file, **dict(tree_flatten(self.parameters())))
|
||||
params_dict = dict(tree_flatten(self.parameters()))
|
||||
|
||||
if file.endswith(".npz"):
|
||||
mx.savez(file, **params_dict)
|
||||
elif file.endswith(".safetensors"):
|
||||
mx.save_safetensors(file, params_dict)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported file extension. Use '.npz' or '.safetensors'."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_module(value):
|
||||
|
@ -54,16 +54,31 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
|
||||
m.apply_to_modules(assert_training)
|
||||
|
||||
def test_io(self):
|
||||
def test_save_npz_weights(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(file)
|
||||
npz_file = os.path.join(tdir.name, "model.npz")
|
||||
m.save_weights(npz_file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(file)
|
||||
m_load.load_weights(npz_file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
def test_save_safetensors_weights(self):
|
||||
def make_model():
|
||||
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
|
||||
|
||||
m = make_model()
|
||||
tdir = tempfile.TemporaryDirectory()
|
||||
safetensors_file = os.path.join(tdir.name, "model.safetensors")
|
||||
m.save_weights(safetensors_file)
|
||||
m_load = make_model()
|
||||
m_load.load_weights(safetensors_file)
|
||||
tdir.cleanup()
|
||||
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
|
Loading…
Reference in New Issue
Block a user