mlx.utils.tree_map
- mlx.utils.tree_map(fn, tree, *rest)
Applies `fn` to the leaves of the Python tree `tree` and returns a new collection with the results. If `rest` is provided, every item is assumed to be a superset of `tree` and the corresponding leaves are provided as extra positional arguments to `fn`. In that respect, `tree_map()` is closer to `itertools.starmap()` than to `map()`. For example:

```python
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# Square the parameters
model.update(tree_map(lambda x: x * x, model.parameters()))
```
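The traversal described above can be approximated in a few lines of plain Python. The helper below, `tree_map_sketch`, is a hypothetical, simplified re-creation for illustration only, not MLX's actual implementation. It assumes lists, tuples and dicts are the only containers and treats everything else as a leaf. It shows how the leaves of the extra trees in `rest` are matched to the leaves of `tree` by key or position, and why `rest` only needs to be a superset of `tree`.

```python
from typing import Any, Callable


def tree_map_sketch(fn: Callable, tree: Any, *rest: Any) -> Any:
    """Hypothetical, simplified stand-in for tree_map (not the MLX source).

    Assumption: lists, tuples and dicts are the only containers;
    every other value is treated as a leaf.
    """
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type, taking the i-th child of every extra tree.
        return type(tree)(
            tree_map_sketch(fn, child, *(r[i] for r in rest))
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        # Only the keys of `tree` are visited: extra keys in `rest` are ignored,
        # while a key missing from `rest` raises KeyError.
        return {
            k: tree_map_sketch(fn, child, *(r[k] for r in rest))
            for k, child in tree.items()
        }
    # Leaf: the matching leaves of the extra trees become extra positional arguments.
    return fn(tree, *rest)


# Hypothetical parameters and gradients stored as plain Python floats.
params = {"weight": [3.0, 5.0], "bias": (1.0,)}
grads = {"weight": [2.0, 4.0], "bias": (2.0,), "unused": 7.0}

# An SGD-style update with learning rate 0.5: one leaf from each tree per call.
new_params = tree_map_sketch(lambda p, g: p - 0.5 * g, params, grads)
print(new_params)  # {'weight': [2.0, 3.0], 'bias': (0.0,)}
```

Because only the keys and indices of `tree` are visited, the extra `"unused"` entry in `grads` is skipped, and the input trees are left unmodified since new containers are built for the result.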
- Parameters:
fn (Callable) – The function that processes the leaves of the tree.
tree (Any) – The main Python tree that will be iterated over.
rest (Tuple[Any]) – Extra trees to be iterated together with tree.
- Returns:
A Python tree with the new values returned by `fn`.