mlx.nn.Module.load_weights
Module.load_weights(file_or_weights: Union[str, List[Tuple[str, array]]], strict: bool = True)
Update the model’s weights from a .npz file or a list.

Parameters:
- file_or_weights (str or list(tuple(str, mx.array))) – The path to the weights .npz file, or a list of pairs of parameter names and arrays.
- strict (bool, optional) – If True, checks that the provided weights exactly match the parameters of the model (the same parameter names with the same shapes). Otherwise, only the weights whose names exist in the model are loaded, and shapes are not checked. Default: True.
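To make the two `strict` modes concrete, here is a small pure-Python sketch of the selection rule as described above. It is not MLX's implementation: the helper `select_weights` is hypothetical, and parameters are represented only by their shapes so the sketch runs without MLX. It assumes "exactly match" means no missing names, no extra names, and identical shapes.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def select_weights(
    model_shapes: Dict[str, Shape],
    weights: List[Tuple[str, Shape]],
    strict: bool = True,
) -> List[Tuple[str, Shape]]:
    """Hypothetical helper mirroring the documented `strict` semantics."""
    provided = dict(weights)
    if strict:
        # Strict: the provided names must equal the model's names exactly.
        extra = provided.keys() - model_shapes.keys()
        missing = model_shapes.keys() - provided.keys()
        if extra:
            raise ValueError(f"Received parameters not in model: {sorted(extra)}")
        if missing:
            raise ValueError(f"Missing parameters: {sorted(missing)}")
        # Strict: every shape must match the model's parameter shape.
        for name, shape in model_shapes.items():
            if provided[name] != shape:
                raise ValueError(
                    f"Expected shape {shape} for {name!r}, got {provided[name]}"
                )
        return list(provided.items())
    # Non-strict: keep only names the model has; shapes are not compared.
    return [(name, shape) for name, shape in weights if name in model_shapes]


linear = {"weight": (10, 10), "bias": (10,)}
print(select_weights(linear, [("weight", (10, 10))], strict=False))
```

With `strict=False` the missing `bias` is simply skipped and the call prints `[('weight', (10, 10))]`; the same call with the default `strict=True` raises `ValueError`.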
Example
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(10, 10)

# Load from file (assumes weights.npz exists, e.g. written by model.save_weights)
model.load_weights("weights.npz")

# Load from list
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
    ("bias", mx.zeros((10,))),
]
model.load_weights(weights)

# Missing weight: no entry for "bias"
weights = [
    ("weight", mx.random.uniform(shape=(10, 10))),
]

# Raises a ValueError exception because strict=True by default
try:
    model.load_weights(weights)
except ValueError as e:
    print(e)

# Ok, only updates the weight but not the bias
model.load_weights(weights, strict=False)