mlx.nn.Module.load_weights

Module.load_weights(file_or_weights: Union[str, List[Tuple[str, array]]], strict: bool = True)
Update the model’s weights from a .npz file or from a list of (name, array) pairs.

Parameters:
- file_or_weights (str or list(tuple(str, mx.array))) – The path to the weights .npz file, or a list of pairs of parameter names and arrays.
- strict (bool, optional) – If True, checks that the provided weights exactly match the parameters of the model. Otherwise, only the weights actually contained in the model are loaded and shapes are not checked. Default: True.

Example
    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(10, 10)

    # Load from file
    model.load_weights("weights.npz")

    # Load from list
    weights = [
        ("weight", mx.random.uniform(shape=(10, 10))),
        ("bias", mx.zeros((10,))),
    ]
    model.load_weights(weights)

    # Missing weight
    weights = [
        ("weight", mx.random.uniform(shape=(10, 10))),
    ]

    # Raises a ValueError exception
    model.load_weights(weights)

    # Ok, only updates the weight but not the bias
    model.load_weights(weights, strict=False)