mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-21 04:31:48 +08:00
feat: add support for saving safetensors in the save_weights
(#497)
* feat: add save safetensors support in module save_weights
* chore: checking missing changes
* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -96,11 +96,11 @@ class Module(dict):
|
||||
strict: bool = True,
|
||||
):
|
||||
"""
|
||||
Update the model's weights from a ``.npz`` or a list.
|
||||
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
|
||||
|
||||
Args:
|
||||
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
||||
the weights ``.npz`` file or a list of pairs of parameter names
|
||||
the weights file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
|
||||
and arrays.
|
||||
strict (bool, optional): If ``True`` then checks that the provided
|
||||
weights exactly match the parameters of the model. Otherwise,
|
||||
@@ -118,6 +118,9 @@ class Module(dict):
|
||||
# Load from .npz file
|
||||
model.load_weights("weights.npz")
|
||||
|
||||
# Load from .safetensors file
|
||||
model.load_weights("weights.safetensors")
|
||||
|
||||
# Load from list
|
||||
weights = [
|
||||
("weight", mx.random.uniform(shape=(10, 10))),
|
||||
@@ -166,9 +169,20 @@ class Module(dict):
|
||||
|
||||
def save_weights(self, file: str):
|
||||
"""
|
||||
Save the model's weights to a ``.npz`` file.
|
||||
Save the model's weights to a file. The saving method is determined by the file extension:
|
||||
- ``.npz`` will use :func:`mx.savez`
|
||||
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
||||
"""
|
||||
mx.savez(file, **dict(tree_flatten(self.parameters())))
|
||||
params_dict = dict(tree_flatten(self.parameters()))
|
||||
|
||||
if file.endswith(".npz"):
|
||||
mx.savez(file, **params_dict)
|
||||
elif file.endswith(".safetensors"):
|
||||
mx.save_safetensors(file, params_dict)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported file extension. Use '.npz' or '.safetensors'."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_module(value):
|
||||
|
Reference in New Issue
Block a user