mlx.nn.Module
- class mlx.nn.Module
Base class for building neural networks with MLX.

All the layers provided in mlx.nn.layers subclass this class, and your models should do the same.

A Module can contain other Module instances or mlx.core.array instances in arbitrary nesting of Python lists or dicts. The Module then allows recursively extracting all the mlx.core.array instances using mlx.nn.Module.parameters().

In addition, the Module has the concept of trainable and non-trainable parameters (called “frozen”). When using mlx.nn.value_and_grad(), the gradients are returned only with respect to the trainable parameters. All arrays in a module are trainable unless they are added to the “frozen” set by calling freeze().

```python
import mlx.core as mx
import mlx.nn as nn

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()

        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        x = self.in_proj(x)
        x = mx.maximum(x, 0)
        return self.out_proj(x)

model = MyMLP(2, 1)

# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())

# Setting a parameter to a new value is as simple as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())
```
- __init__()
Should be called by the subclasses of Module.
Methods
- __init__(): Should be called by the subclasses of Module.
- apply(map_fn[, filter_fn]): Map all the parameters using the provided map_fn and immediately update the module with the mapped parameters.
- apply_to_modules(apply_fn): Apply a function to all the modules in this instance (including this instance).
- children(): Return the direct descendants of this Module instance.
- clear()
- copy()
- eval()
- filter_and_map(filter_fn[, map_fn, is_leaf_fn]): Recursively filter the contents of the module using filter_fn, namely only select keys and values where filter_fn returns true.
- freeze(*[, recurse, keys, strict]): Freeze the Module's parameters or some of them.
- fromkeys(iterable[, value]): Create a new dictionary with keys from iterable and values set to value.
- get(key[, default]): Return the value for key if key is in the dictionary, else default.
- is_module(value)
- items()
- keys()
- leaf_modules(): Return the submodules that do not contain other modules.
- load_weights(file): Load and update the model's weights from a .npz file.
- modules(): Return a list with all the modules in this instance.
- named_modules(): Return a list with all the modules in this instance and their name with dot notation.
- parameters(): Recursively return all the mlx.core.array members of this Module as a dict of dicts and lists.
- pop(k[, d]): Remove the specified key and return its value. If the key is not found, d is returned if given; otherwise KeyError is raised.
- popitem(): Remove and return a (key, value) pair as a 2-tuple.
- save_weights(file): Save the model's weights to a .npz file.
- setdefault(key[, default]): Insert key with a value of default if key is not in the dictionary.
- train([mode])
- trainable_parameter_filter(module, key, value)
- trainable_parameters(): Recursively return all the non-frozen mlx.core.array members of this Module as a dict of dicts and lists.
- unfreeze(*[, recurse, keys, strict]): Unfreeze the Module's parameters or some of them.
- update(parameters): Replace the parameters of this Module with the provided ones in the dict of dicts and lists.
- valid_child_filter(module, key, value)
- valid_parameter_filter(module, key, value)
- values()

Attributes
- training