Some fixes to typing (#1371)

* some fixes to typing

* fix module reference

* comment
This commit is contained in:
Awni Hannun
2024-08-28 11:16:19 -07:00
committed by GitHub
parent bd47e1f066
commit 291cf40aca
15 changed files with 152 additions and 145 deletions

View File

@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
from __future__ import annotations
import textwrap
from typing import Any, Callable, List, Optional, Tuple, Union
@@ -7,42 +9,6 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
class Module(dict):
"""Base class for building neural networks with MLX.
@@ -151,7 +117,7 @@ class Module(dict):
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
) -> "Module":
) -> Module:
"""
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@@ -266,9 +232,9 @@ class Module(dict):
def filter_and_map(
self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
filter_fn: Callable[[Module, str, Any], bool],
map_fn: Optional[Callable] = None,
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
):
"""Recursively filter the contents of the module using ``filter_fn``,
namely only select keys and values where ``filter_fn`` returns true.
@@ -323,7 +289,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> "Module":
def update(self, parameters: dict) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@@ -371,8 +337,8 @@ class Module(dict):
def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
) -> "Module":
filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
) -> Module:
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.
@@ -391,7 +357,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict) -> "Module":
def update_modules(self, modules: dict) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@@ -432,9 +398,7 @@ class Module(dict):
apply(self, modules)
return self
def apply_to_modules(
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
"""Apply a function to all the modules in this instance (including this
instance).
@@ -489,7 +453,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
) -> "Module":
) -> Module:
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
@@ -544,7 +508,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
) -> "Module":
) -> Module:
"""Unfreeze the Module's parameters or some of them.
This function is idempotent ie unfreezing a model that is not frozen is
@@ -588,7 +552,7 @@ class Module(dict):
_unfreeze_impl("", self)
return self
def train(self, mode: bool = True) -> "Module":
def train(self, mode: bool = True) -> Module:
"""Set the model in or out of training mode.
Training mode only applies to certain layers. For example
@@ -608,7 +572,7 @@ class Module(dict):
self.apply_to_modules(_set_train)
return self
def eval(self) -> "Module":
def eval(self) -> Module:
"""Set the model to evaluation mode.
See :func:`train`.
@@ -637,3 +601,39 @@ class Module(dict):
return True
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")

View File

@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)

View File

@@ -12,7 +12,7 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
class_predicate: Optional[callable] = None,
class_predicate: Optional[Callable] = None,
):
"""Quantize the sub-modules of a module according to a predicate.