Mirror of https://github.com/ml-explore/mlx.git
Some fixes to typing (#1371)
* some fixes to typing
* fix module reference
* comment
@@ -234,7 +234,7 @@ def glorot_uniform(

 def he_normal(
     dtype: mx.Dtype = mx.float32,
-) -> Callable[[mx.array, str, float], mx.array]:
+) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
     r"""Build a He normal initializer.

     This initializer samples from a normal distribution with a standard
@@ -292,7 +292,7 @@ def he_normal(

 def he_uniform(
     dtype: mx.Dtype = mx.float32,
-) -> Callable[[mx.array, str, float], mx.array]:
+) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
     r"""A He uniform (Kaiming uniform) initializer.

     This initializer samples from a uniform distribution with a range
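Note: narrowing `str` to `Literal["fan_in", "fan_out"]` lets a type checker reject invalid modes at the call site instead of at runtime. A minimal usage sketch, assuming the mode and gain are passed positionally as the annotation indicates:

import mlx.core as mx
import mlx.nn as nn

# Build the initializer, then sample weights shaped like the input array.
init_fn = nn.init.he_normal()
w = init_fn(mx.zeros((4, 4)), "fan_in", 1.0)    # OK: a valid Literal value
# w = init_fn(mx.zeros((4, 4)), "fan_all", 1.0) # flagged by a type checker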
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.

+from __future__ import annotations
+
 import textwrap
 from typing import Any, Callable, List, Optional, Tuple, Union

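Note: this import is what enables the rest of the diff in this file. With PEP 563 postponed evaluation, annotations are stored as strings and never evaluated at definition time, so the quoted forward references like `"Module"` removed below become unnecessary. A small illustration:

from __future__ import annotations

class Module(dict):
    # The class name is not bound yet while the body executes; without the
    # future import this return annotation would need the quoted "Module".
    def train(self, mode: bool = True) -> Module:
        return self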
@@ -7,42 +9,6 @@ import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten


-def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
-    if is_leaf_fn(model, value_key, value):
-        return map_fn(value)
-
-    elif isinstance(value, Module):
-        return {
-            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
-            for k, v in value.items()
-            if filter_fn(value, k, v)
-        }
-
-    elif isinstance(value, dict):
-        nd = {}
-        for k, v in value.items():
-            tk = f"{value_key}.{k}"
-            nd[k] = (
-                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
-                if filter_fn(model, tk, v)
-                else {}
-            )
-        return nd
-
-    elif isinstance(value, list):
-        nl = []
-        for i, vi in enumerate(value):
-            tk = f"{value_key}.{i}"
-            nl.append(
-                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
-                if filter_fn(model, tk, vi)
-                else {}
-            )
-        return nl
-
-    raise RuntimeError("Unexpected leaf found while traversing the module")
-
-
 class Module(dict):
     """Base class for building neural networks with MLX.

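Note: `_unwrap` is not deleted; a later hunk in this file re-adds it verbatim below the class (the "fix module reference" item in the commit message). The move is safe because names in a function body are resolved when the function is called, not when it is defined, as this hypothetical sketch (not MLX code) illustrates:

def uses_forward_name(value):
    # Module is looked up at call time, so this definition may appear
    # either before or after the class statement without a NameError.
    return isinstance(value, Module)

class Module(dict):
    pass

print(uses_forward_name(Module()))  # True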
@@ -151,7 +117,7 @@ class Module(dict):
         self,
         file_or_weights: Union[str, List[Tuple[str, mx.array]]],
         strict: bool = True,
-    ) -> "Module":
+    ) -> Module:
         """
         Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@@ -266,9 +232,9 @@ class Module(dict):

     def filter_and_map(
         self,
-        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
+        filter_fn: Callable[[Module, str, Any], bool],
         map_fn: Optional[Callable] = None,
-        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
+        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
     ):
         """Recursively filter the contents of the module using ``filter_fn``,
         namely only select keys and values where ``filter_fn`` returns true.
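Note: with `Module` usable directly in annotations, a typed filter no longer needs the `"mlx.nn.Module"` string form. A minimal sketch under the assumption that the default leaf test treats arrays as leaves:

import mlx.core as mx
import mlx.nn as nn

def keep_arrays(module: nn.Module, key: str, value) -> bool:
    # Keep only array leaves; for a Linear these are "weight" and "bias".
    return isinstance(value, mx.array)

model = nn.Linear(4, 4)
arrays = model.filter_and_map(keep_arrays)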
@@ -323,7 +289,7 @@ class Module(dict):
         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

-    def update(self, parameters: dict) -> "Module":
+    def update(self, parameters: dict) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
@@ -371,8 +337,8 @@ class Module(dict):
     def apply(
         self,
         map_fn: Callable[[mx.array], mx.array],
-        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
-    ) -> "Module":
+        filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
+    ) -> Module:
         """Map all the parameters using the provided ``map_fn`` and immediately
         update the module with the mapped parameters.
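Note: `apply` is the usual way to map every parameter in place, so the unquoted types flow straight into user code. A minimal sketch:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(8, 8)
# Cast every parameter to half precision and update the module in place.
model.apply(lambda x: x.astype(mx.float16))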
@@ -391,7 +357,7 @@ class Module(dict):
         self.update(self.filter_and_map(filter_fn, map_fn))
         return self

-    def update_modules(self, modules: dict) -> "Module":
+    def update_modules(self, modules: dict) -> Module:
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.
@@ -432,9 +398,7 @@ class Module(dict):
         apply(self, modules)
         return self

-    def apply_to_modules(
-        self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
-    ) -> "Module":
+    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
         """Apply a function to all the modules in this instance (including this
         instance).
@@ -489,7 +453,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ) -> "Module":
+    ) -> Module:
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.
@@ -544,7 +508,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ) -> "Module":
+    ) -> Module:
         """Unfreeze the Module's parameters or some of them.

         This function is idempotent ie unfreezing a model that is not frozen is
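Note: since both methods now return `Module` unquoted, calls chain naturally. A sketch following the docstring's own pattern:

import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
# Freeze everything, then re-enable gradients for bias parameters only.
model.freeze().unfreeze(keys="bias")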
@@ -588,7 +552,7 @@ class Module(dict):
         _unfreeze_impl("", self)
         return self

-    def train(self, mode: bool = True) -> "Module":
+    def train(self, mode: bool = True) -> Module:
         """Set the model in or out of training mode.

         Training mode only applies to certain layers. For example
@@ -608,7 +572,7 @@ class Module(dict):
         self.apply_to_modules(_set_train)
         return self

-    def eval(self) -> "Module":
+    def eval(self) -> Module:
         """Set the model to evaluation mode.

         See :func:`train`.
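Note: the train/eval toggles matter for mode-dependent layers such as `Dropout`, as in this brief sketch:

import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
model.train()   # dropout active while training
model.eval()    # dropout disabled for inference
y = model(mx.zeros((1, 4)))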
@@ -637,3 +601,39 @@ class Module(dict):
                 return True

         self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
+
+
+def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
+    if is_leaf_fn(model, value_key, value):
+        return map_fn(value)
+
+    elif isinstance(value, Module):
+        return {
+            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
+            for k, v in value.items()
+            if filter_fn(value, k, v)
+        }
+
+    elif isinstance(value, dict):
+        nd = {}
+        for k, v in value.items():
+            tk = f"{value_key}.{k}"
+            nd[k] = (
+                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, v)
+                else {}
+            )
+        return nd
+
+    elif isinstance(value, list):
+        nl = []
+        for i, vi in enumerate(value):
+            tk = f"{value_key}.{i}"
+            nl.append(
+                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, vi)
+                else {}
+            )
+        return nl
+
+    raise RuntimeError("Unexpected leaf found while traversing the module")
@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):

     def __init__(
         self,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[Union[int, Tuple[int, int]]] = 0,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Union[int, Tuple[int]] = 0,
     ):
         super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):

     def __init__(
         self,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[Union[int, Tuple[int, int]]] = 0,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Union[int, Tuple[int]] = 0,
     ):
         super().__init__(mx.mean, 0, kernel_size, stride, padding)
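Note: the corrected annotations say that a 1-d pool takes a scalar or a 1-tuple, not a pair, and `padding` drops a misleading `Optional` since `None` was never a valid value. A short sketch of both accepted forms:

import mlx.core as mx
import mlx.nn as nn

pool = nn.MaxPool1d(kernel_size=2, stride=2)           # int form
pool2 = nn.AvgPool1d(kernel_size=(3,), padding=(1,))   # 1-tuple form
y = pool(mx.zeros((1, 8, 4)))  # (batch, length, channels) -> (1, 4, 4)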
@@ -12,7 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
-    class_predicate: Optional[callable] = None,
+    class_predicate: Optional[Callable] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
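Note: `Optional[callable]` mixed the runtime builtin with a typing construct; `typing.Callable` is the annotation form. A usage sketch, assuming the predicate receives the module path and the module instance (an assumption about the call convention, not stated in this hunk):

import mlx.nn as nn

def only_linear(path: str, module: nn.Module) -> bool:
    # Quantize Linear layers only; leave everything else in full precision.
    return isinstance(module, nn.Linear)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
nn.quantize(model, group_size=64, bits=4, class_predicate=only_linear)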
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.

 import math
-from typing import Literal
+from typing import Literal, Optional

 import mlx.core as mx
@@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):
 def cross_entropy(
     logits: mx.array,
     targets: mx.array,
-    weights: mx.array = None,
+    weights: Optional[mx.array] = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
     reduction: Reduction = "none",
@@ -117,7 +117,7 @@ def cross_entropy(
 def binary_cross_entropy(
     inputs: mx.array,
     targets: mx.array,
-    weights: mx.array = None,
+    weights: Optional[mx.array] = None,
     with_logits: bool = True,
     reduction: Reduction = "mean",
 ) -> mx.array:
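Note: `weights: mx.array = None` is wrong under strict type checking because `None` is not an `mx.array`; `Optional` states the actual contract. For example:

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])

loss = nn.losses.cross_entropy(logits, targets)  # weights omitted (None)
w = mx.array([1.0, 0.5])
weighted = nn.losses.cross_entropy(logits, targets, weights=w)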
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from functools import wraps
-from typing import Callable
+from typing import Callable, Optional

 import mlx.core as mx
@@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):
     return wrapped_value_grad_fn


-def checkpoint(module: Module, fn: Callable = None):
+def checkpoint(module: Module, fn: Optional[Callable] = None):
     """Transform the passed callable to one that performs gradient
     checkpointing with respect to the trainable parameters of the module (and
     the callable's inputs).
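Note: the same `Optional` fix applies to the `fn=None` default here. A minimal sketch of the call, assuming that when `fn` is omitted the module's own forward pass is checkpointed and that the helper lives in `mlx.nn.utils` (both assumptions, not stated in this hunk):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import checkpoint

model = nn.Linear(4, 4)
# With fn=None the module's __call__ itself is the checkpointed function.
ckpt_forward = checkpoint(model)
y = ckpt_forward(mx.zeros((1, 4)))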
@@ -4,6 +4,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union

 import mlx.core as mx
+from mlx.nn import Module
 from mlx.utils import tree_map, tree_reduce

@@ -17,7 +18,7 @@ class Optimizer:
         self._state = {"step": mx.array(0, mx.uint64)}
         self._schedulers = {k: v for k, v in (schedulers or {}).items()}

-    def update(self, model: "mlx.nn.Module", gradients: dict):
+    def update(self, model: Module, gradients: dict):
         """Apply the gradients to the parameters of the model and update the
         model with the new parameters.
@@ -1,10 +1,10 @@
 # Copyright © 2023 Apple Inc.
 from collections import defaultdict
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Optional, Tuple


 def tree_map(
-    fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None
+    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
 ) -> Any:
     """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.
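Note: the `*rest` fix is subtle. An annotation on a star parameter types each extra argument individually, so `*rest: Tuple[Any]` claimed every extra tree was itself a tuple; `*rest: Any` is what was meant. Usage sketch:

import mlx.core as mx
from mlx.utils import tree_map

params = {"w": mx.ones((2, 2)), "b": mx.zeros((2,))}
grads = {"w": mx.ones((2, 2)), "b": mx.ones((2,))}

# grads arrives through *rest; each extra tree is typed as Any.
updated = tree_map(lambda p, g: p - 0.1 * g, params, grads)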
@@ -59,8 +59,8 @@ def tree_map(
 def tree_map_with_path(
     fn: Callable,
     tree: Any,
-    *rest: Tuple[Any],
-    is_leaf: Callable = None,
+    *rest: Any,
+    is_leaf: Optional[Callable] = None,
     path: Any = None,
 ) -> Any:
     """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and