mlx.utils.tree_map
- tree_map(fn, tree, *rest, is_leaf=None)
Applies `fn` to the leaves of the Python tree `tree` and returns a new collection with the results.

If `rest` is provided, every item in it is assumed to be a superset of `tree`, and the corresponding leaves are provided as extra positional arguments to `fn`. In that respect, `tree_map()` is closer to `itertools.starmap()` than to `map()`.

The keyword argument `is_leaf` decides what constitutes a leaf from `tree`, similar to `tree_flatten()`.

```python
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x * x, model.parameters()))
```
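To make the roles of `rest` and `is_leaf` concrete, here is a small pure-Python sketch of the traversal. `tree_map_sketch` is a hypothetical re-implementation written for illustration, not the MLX source. It assumes that lists, tuples and dicts are the only container types and that everything else is a leaf. Each extra tree is indexed with the same key or position as `tree`, so keys that exist only in a superset are ignored. This is also why the call resembles `itertools.starmap()` over zipped leaves.

```python
from typing import Any, Callable, Optional


def tree_map_sketch(fn: Callable, tree: Any, *rest: Any,
                    is_leaf: Optional[Callable] = None) -> Any:
    """Hypothetical re-implementation for illustration only (not MLX source).

    Assumes lists, tuples and dicts are the container types; anything
    else is treated as a leaf.
    """
    # A user-supplied is_leaf can stop the recursion early.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)
    if isinstance(tree, (list, tuple)):
        # Index each extra tree at the same position and keep the container type.
        return type(tree)(
            tree_map_sketch(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        # Only the keys of `tree` are visited, so extra trees may be supersets.
        return {
            k: tree_map_sketch(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
            for k, child in tree.items()
        }
    # Leaf: call fn with this leaf plus the matching leaves of the extra trees.
    return fn(tree, *rest)


a = {"w": [1, 2], "b": 3}
b = {"w": [10, 20], "b": 30, "extra": 99}  # a superset of `a`
print(tree_map_sketch(lambda x, y: x + y, a, b))
# {'w': [11, 22], 'b': 33}

# Treat tuples as leaves instead of recursing into them.
pairs = {"p": (1, 2), "q": [(3, 4)]}
print(tree_map_sketch(sum, pairs, is_leaf=lambda x: isinstance(x, tuple)))
# {'p': 3, 'q': [7]}
```

Without the `is_leaf` argument in the second call, the sketch would recurse into the tuples and call `sum` on individual integers, which raises a `TypeError`.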
- Parameters:
fn (Callable) – The function that processes the leaves of the tree.
tree (Any) – The main Python tree that will be iterated upon.
rest (Tuple[Any]) – Extra trees to be iterated together with `tree`.
is_leaf (Optional[Callable]) – An optional callable that returns True if the passed object is considered a leaf, or False otherwise.
- Returns:
A Python tree with the new values returned by `fn`.