diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 2a3c1e660..b7173deb7 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -62,7 +62,7 @@ def tree_map_with_path( tree: Any, *rest: Any, is_leaf: Optional[Callable] = None, - path: Any = None, + path: Optional[Any] = None, ) -> Any: """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and returns a new collection with the results. @@ -74,8 +74,9 @@ def tree_map_with_path( fn (callable): The function that processes the leaves of the tree. tree (Any): The main Python tree that will be iterated upon. rest (tuple[Any]): Extra trees to be iterated together with ``tree``. - is_leaf (callable, optional): An optional callable that returns ``True`` + is_leaf (Optional[Callable]): An optional callable that returns ``True`` if the passed object is considered a leaf or ``False`` otherwise. + path (Optional[Any]): Prefix will be added to the result. Returns: A Python tree with the new values returned by ``fn``.