Some fixes to typing (#1371)

* some fixes to typing

* fix module reference

* comment
Awni Hannun, 2024-08-28 11:16:19 -07:00 (committed by GitHub)
parent bd47e1f066
commit 291cf40aca
15 changed files with 152 additions and 145 deletions

View File

@@ -234,7 +234,7 @@ def glorot_uniform(
 def he_normal(
     dtype: mx.Dtype = mx.float32,
-) -> Callable[[mx.array, str, float], mx.array]:
+) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
     r"""Build a He normal initializer.
 
     This initializer samples from a normal distribution with a standard
@@ -292,7 +292,7 @@ def he_normal(
 def he_uniform(
     dtype: mx.Dtype = mx.float32,
-) -> Callable[[mx.array, str, float], mx.array]:
+) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
     r"""A He uniform (Kaiming uniform) initializer.
 
     This initializer samples from a uniform distribution with a range
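
For context, a small usage sketch of the tightened signature. With `Literal["fan_in", "fan_out"]` in place of `str`, a type checker can now flag a misspelled mode string; the shapes and gain below are arbitrary illustration, not from the diff:

import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.he_normal()
w = init_fn(mx.zeros((16, 16)))                      # mode defaults to "fan_in"
w_out = init_fn(mx.zeros((16, 16)), "fan_out", 5.0)  # explicit mode and gain
# init_fn(mx.zeros((16, 16)), "fan_avg", 1.0)        # now rejected by type checkers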

View File

@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from __future__ import annotations
+
 import textwrap
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -7,42 +9,6 @@ import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
 
-def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
-    if is_leaf_fn(model, value_key, value):
-        return map_fn(value)
-    elif isinstance(value, Module):
-        return {
-            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
-            for k, v in value.items()
-            if filter_fn(value, k, v)
-        }
-    elif isinstance(value, dict):
-        nd = {}
-        for k, v in value.items():
-            tk = f"{value_key}.{k}"
-            nd[k] = (
-                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
-                if filter_fn(model, tk, v)
-                else {}
-            )
-        return nd
-    elif isinstance(value, list):
-        nl = []
-        for i, vi in enumerate(value):
-            tk = f"{value_key}.{i}"
-            nl.append(
-                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
-                if filter_fn(model, tk, vi)
-                else {}
-            )
-        return nl
-
-    raise RuntimeError("Unexpected leaf found while traversing the module")
-
 class Module(dict):
     """Base class for building neural networks with MLX.
@@ -151,7 +117,7 @@ class Module(dict):
         self,
         file_or_weights: Union[str, List[Tuple[str, mx.array]]],
         strict: bool = True,
-    ) -> "Module":
+    ) -> Module:
         """
         Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@@ -266,9 +232,9 @@ class Module(dict):
     def filter_and_map(
         self,
-        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
+        filter_fn: Callable[[Module, str, Any], bool],
         map_fn: Optional[Callable] = None,
-        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
+        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
     ):
         """Recursively filter the contents of the module using ``filter_fn``,
         namely only select keys and values where ``filter_fn`` returns true.
@@ -323,7 +289,7 @@ class Module(dict):
         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
 
-    def update(self, parameters: dict) -> "Module":
+    def update(self, parameters: dict) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
@@ -371,8 +337,8 @@ class Module(dict):
     def apply(
         self,
         map_fn: Callable[[mx.array], mx.array],
-        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
-    ) -> "Module":
+        filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
+    ) -> Module:
         """Map all the parameters using the provided ``map_fn`` and immediately
         update the module with the mapped parameters.
@@ -391,7 +357,7 @@ class Module(dict):
         self.update(self.filter_and_map(filter_fn, map_fn))
         return self
 
-    def update_modules(self, modules: dict) -> "Module":
+    def update_modules(self, modules: dict) -> Module:
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.
@@ -432,9 +398,7 @@ class Module(dict):
         apply(self, modules)
         return self
 
-    def apply_to_modules(
-        self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
-    ) -> "Module":
+    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
         """Apply a function to all the modules in this instance (including this
         instance).
@@ -489,7 +453,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ) -> "Module":
+    ) -> Module:
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.
@@ -544,7 +508,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ) -> "Module":
+    ) -> Module:
         """Unfreeze the Module's parameters or some of them.
 
         This function is idempotent ie unfreezing a model that is not frozen is
@@ -588,7 +552,7 @@ class Module(dict):
         _unfreeze_impl("", self)
         return self
 
-    def train(self, mode: bool = True) -> "Module":
+    def train(self, mode: bool = True) -> Module:
         """Set the model in or out of training mode.
 
         Training mode only applies to certain layers. For example
@@ -608,7 +572,7 @@ class Module(dict):
         self.apply_to_modules(_set_train)
         return self
 
-    def eval(self) -> "Module":
+    def eval(self) -> Module:
         """Set the model to evaluation mode.
 
         See :func:`train`.
@@ -637,3 +601,39 @@ class Module(dict):
                 return True
 
         self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
+
+
+def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
+    if is_leaf_fn(model, value_key, value):
+        return map_fn(value)
+    elif isinstance(value, Module):
+        return {
+            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
+            for k, v in value.items()
+            if filter_fn(value, k, v)
+        }
+    elif isinstance(value, dict):
+        nd = {}
+        for k, v in value.items():
+            tk = f"{value_key}.{k}"
+            nd[k] = (
+                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, v)
+                else {}
+            )
+        return nd
+    elif isinstance(value, list):
+        nl = []
+        for i, vi in enumerate(value):
+            tk = f"{value_key}.{i}"
+            nl.append(
+                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, vi)
+                else {}
+            )
+        return nl
+
+    raise RuntimeError("Unexpected leaf found while traversing the module")
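
The hunks above replace the quoted "Module" and "mlx.nn.Module" annotations with direct references to Module, which works because of the new `from __future__ import annotations` at the top of the file: it defers evaluation of all annotations, so a class body can name the class itself. A minimal standalone sketch of the mechanism (the Counter class is a made-up example, not part of MLX):

from __future__ import annotations


class Counter:
    def __init__(self) -> None:
        self.n = 0

    # Without the future import, this annotation would have to be the string
    # "Counter", since the class is not yet defined while its body executes.
    def increment(self) -> Counter:
        self.n += 1
        return self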

View File

@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
     def __init__(
         self,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[Union[int, Tuple[int, int]]] = 0,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Union[int, Tuple[int]] = 0,
     ):
         super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
     def __init__(
         self,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[Union[int, Tuple[int, int]]] = 0,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Union[int, Tuple[int]] = 0,
     ):
         super().__init__(mx.mean, 0, kernel_size, stride, padding)
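
The 1-D pooling fix narrows Tuple[int, int] (a pair, which only makes sense for 2-D pooling) to Tuple[int] (a 1-tuple), and drops Optional from padding, whose default is 0 rather than None. A small usage sketch, assuming MLX's (N, L, C) layout for 1-D pooling inputs; the shapes are arbitrary:

import mlx.core as mx
import mlx.nn as nn

pool = nn.MaxPool1d(kernel_size=2, stride=2)  # an int or a 1-tuple like (2,)
x = mx.random.normal((4, 16, 5))              # (batch, length, channels)
y = pool(x)                                   # shape (4, 8, 5)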

View File

@@ -12,7 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
-    class_predicate: Optional[callable] = None,
+    class_predicate: Optional[Callable] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.

View File

@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Literal
+from typing import Literal, Optional
 
 import mlx.core as mx
@@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):
 def cross_entropy(
     logits: mx.array,
     targets: mx.array,
-    weights: mx.array = None,
+    weights: Optional[mx.array] = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
     reduction: Reduction = "none",
@@ -117,7 +117,7 @@ def cross_entropy(
 def binary_cross_entropy(
     inputs: mx.array,
     targets: mx.array,
-    weights: mx.array = None,
+    weights: Optional[mx.array] = None,
     with_logits: bool = True,
     reduction: Reduction = "mean",
 ) -> mx.array:
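
`weights: mx.array = None` is an implicit-Optional default that strict type checkers reject, since None is not an mx.array; wrapping it in Optional makes the contract explicit without changing runtime behavior. A small sketch of the weighted form (the values are arbitrary):

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])
weights = mx.array([1.0, 0.5])  # per-example weights; omitting them is still valid

loss = nn.losses.cross_entropy(logits, targets, weights=weights, reduction="mean")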

View File

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from functools import wraps
-from typing import Callable
+from typing import Callable, Optional
 
 import mlx.core as mx
@@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):
     return wrapped_value_grad_fn
 
-def checkpoint(module: Module, fn: Callable = None):
+def checkpoint(module: Module, fn: Optional[Callable] = None):
     """Transform the passed callable to one that performs gradient
     checkpointing with respect to the trainable parameters of the module (and
     the callable's inputs).
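
The same implicit-Optional fix applies here: `fn` may be omitted, in which case the module's own __call__ is checkpointed. A hedged usage sketch, assuming checkpoint is importable from mlx.nn.utils, the module this hunk touches; the layer and input shapes are arbitrary:

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import checkpoint

layer = nn.Linear(4, 4)

# fn is None here, exercising the Optional default: the module itself is wrapped.
ckpt_layer = checkpoint(layer)
y = ckpt_layer(mx.random.normal((2, 4)))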

View File

@@ -4,6 +4,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
+from mlx.nn import Module
 from mlx.utils import tree_map, tree_reduce
@@ -17,7 +18,7 @@ class Optimizer:
         self._state = {"step": mx.array(0, mx.uint64)}
         self._schedulers = {k: v for k, v in (schedulers or {}).items()}
 
-    def update(self, model: "mlx.nn.Module", gradients: dict):
+    def update(self, model: Module, gradients: dict):
         """Apply the gradients to the parameters of the model and update the
         model with the new parameters.
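
With the direct `from mlx.nn import Module` import, `Optimizer.update` no longer needs the quoted "mlx.nn.Module" forward reference. A typical training-step sketch around update (the model, data, and hyperparameters are placeholders):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(learning_rate=0.1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad = nn.value_and_grad(model, loss_fn)
x, y = mx.random.normal((8, 4)), mx.random.normal((8, 2))

loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)                # the method annotated above
mx.eval(model.parameters(), optimizer.state)  # force the lazy updates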

View File

@@ -1,10 +1,10 @@
 # Copyright © 2023 Apple Inc.
 
 from collections import defaultdict
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 
 def tree_map(
-    fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None
+    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
 ) -> Any:
     """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.
@@ -59,8 +59,8 @@ def tree_map(
 def tree_map_with_path(
     fn: Callable,
     tree: Any,
-    *rest: Tuple[Any],
-    is_leaf: Callable = None,
+    *rest: Any,
+    is_leaf: Optional[Callable] = None,
     path: Any = None,
 ) -> Any:
     """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and