mlx.utils.tree_map

mlx.utils.tree_map(fn, tree, *rest)

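Applies fn to each leaf of the Python tree passed as tree and returns a new tree with the same structure containing the results. Lists, tuples, and dicts are treated as containers; every other value is a leaf.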

If rest is provided, each tree in rest is assumed to be a superset of tree, that is, to contain at least the same structure. Its corresponding leaves are passed to fn as extra positional arguments. In that respect, tree_map() is closer to itertools.starmap() than to map().
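To make the multi-tree traversal concrete, here is a simplified pure-Python sketch. tree_map_sketch is a hypothetical helper written for illustration, not MLX's source. It assumes lists, tuples, and dicts are the only containers (namedtuples are not handled). Each tree in rest is indexed by the position or key taken from tree, which is why the extra trees only need to be supersets: keys that exist only in rest are never visited.

```python
from typing import Any, Callable


def tree_map_sketch(fn: Callable, tree: Any, *rest: Any) -> Any:
    """Hypothetical illustration of tree_map's traversal (not MLX source)."""
    if isinstance(tree, (list, tuple)):
        # Recurse into each element, indexing the extra trees by position.
        return type(tree)(
            [tree_map_sketch(fn, child, *(r[i] for r in rest))
             for i, child in enumerate(tree)]
        )
    if isinstance(tree, dict):
        # Recurse into each value, indexing the extra trees by key.
        return {
            k: tree_map_sketch(fn, child, *(r[k] for r in rest))
            for k, child in tree.items()
        }
    # Anything else is a leaf: call fn with it and the matching leaves of rest.
    return fn(tree, *rest)


tree = {"w": [1, 2], "b": 3}
# "unused" is not part of tree's structure, so it is ignored.
other = {"w": [10, 20], "b": 30, "unused": 99}
summed = tree_map_sketch(lambda x, y: x + y, tree, other)
print(summed)  # {'w': [11, 22], 'b': 33}
```

Like itertools.starmap(), each call to fn receives one argument per input, gathered from matching positions in the trees.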

import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x*x, model.parameters()))

Parameters:
  • fn (Callable) – The function applied to each leaf. It is called with a leaf of tree followed by the corresponding leaf of each tree in rest.

  • tree (Any) – The main Python tree. Its structure determines which leaves are visited and the structure of the result.

  • rest (Tuple[Any]) – Extra trees traversed together with tree. Each must contain at least the structure of tree.

Returns:

A Python tree with the same structure as tree, whose leaves are the values returned by fn.