diff --git a/python/mlx/extension.py b/python/mlx/extension.py index 842cf667f..afece6d83 100644 --- a/python/mlx/extension.py +++ b/python/mlx/extension.py @@ -77,7 +77,7 @@ class CMakeBuild(build_ext): ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True ) - def run(self): + def run(self) -> None: super().run() # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102 diff --git a/python/mlx/utils.py b/python/mlx/utils.py index e7b61373a..14b23a41e 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -1,8 +1,11 @@ # Copyright © 2023 Apple Inc. from collections import defaultdict +from typing import Any, Callable, Tuple -def tree_map(fn, tree, *rest, is_leaf=None): +def tree_map( + fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None +) -> Any: """Applies ``fn`` to the leaves of the Python tree ``tree`` and returns a new collection with the results. @@ -53,7 +56,13 @@ def tree_map(fn, tree, *rest, is_leaf=None): return fn(tree, *rest) -def tree_map_with_path(fn, tree, *rest, is_leaf=None, path=None): +def tree_map_with_path( + fn: Callable, + tree: Any, + *rest: Tuple[Any], + is_leaf: Callable = None, + path: Any = None, +) -> Any: """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and returns a new collection with the results.