mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Some fixes to typing (#1371)
* some fixes to typing * fix module reference * comment
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -7,42 +9,6 @@ import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
||||
if is_leaf_fn(model, value_key, value):
|
||||
return map_fn(value)
|
||||
|
||||
elif isinstance(value, Module):
|
||||
return {
|
||||
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
|
||||
for k, v in value.items()
|
||||
if filter_fn(value, k, v)
|
||||
}
|
||||
|
||||
elif isinstance(value, dict):
|
||||
nd = {}
|
||||
for k, v in value.items():
|
||||
tk = f"{value_key}.{k}"
|
||||
nd[k] = (
|
||||
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
||||
if filter_fn(model, tk, v)
|
||||
else {}
|
||||
)
|
||||
return nd
|
||||
|
||||
elif isinstance(value, list):
|
||||
nl = []
|
||||
for i, vi in enumerate(value):
|
||||
tk = f"{value_key}.{i}"
|
||||
nl.append(
|
||||
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
|
||||
if filter_fn(model, tk, vi)
|
||||
else {}
|
||||
)
|
||||
return nl
|
||||
|
||||
raise RuntimeError("Unexpected leaf found while traversing the module")
|
||||
|
||||
|
||||
class Module(dict):
|
||||
"""Base class for building neural networks with MLX.
|
||||
|
||||
@@ -151,7 +117,7 @@ class Module(dict):
|
||||
self,
|
||||
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
|
||||
strict: bool = True,
|
||||
) -> "Module":
|
||||
) -> Module:
|
||||
"""
|
||||
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
|
||||
|
||||
@@ -266,9 +232,9 @@ class Module(dict):
|
||||
|
||||
def filter_and_map(
|
||||
self,
|
||||
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
|
||||
filter_fn: Callable[[Module, str, Any], bool],
|
||||
map_fn: Optional[Callable] = None,
|
||||
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
|
||||
is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
|
||||
):
|
||||
"""Recursively filter the contents of the module using ``filter_fn``,
|
||||
namely only select keys and values where ``filter_fn`` returns true.
|
||||
@@ -323,7 +289,7 @@ class Module(dict):
|
||||
|
||||
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
||||
|
||||
def update(self, parameters: dict) -> "Module":
|
||||
def update(self, parameters: dict) -> Module:
|
||||
"""Replace the parameters of this Module with the provided ones in the
|
||||
dict of dicts and lists.
|
||||
|
||||
@@ -371,8 +337,8 @@ class Module(dict):
|
||||
def apply(
|
||||
self,
|
||||
map_fn: Callable[[mx.array], mx.array],
|
||||
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
|
||||
) -> "Module":
|
||||
filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
|
||||
) -> Module:
|
||||
"""Map all the parameters using the provided ``map_fn`` and immediately
|
||||
update the module with the mapped parameters.
|
||||
|
||||
@@ -391,7 +357,7 @@ class Module(dict):
|
||||
self.update(self.filter_and_map(filter_fn, map_fn))
|
||||
return self
|
||||
|
||||
def update_modules(self, modules: dict) -> "Module":
|
||||
def update_modules(self, modules: dict) -> Module:
|
||||
"""Replace the child modules of this :class:`Module` instance with the
|
||||
provided ones in the dict of dicts and lists.
|
||||
|
||||
@@ -432,9 +398,7 @@ class Module(dict):
|
||||
apply(self, modules)
|
||||
return self
|
||||
|
||||
def apply_to_modules(
|
||||
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
|
||||
) -> "Module":
|
||||
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
|
||||
"""Apply a function to all the modules in this instance (including this
|
||||
instance).
|
||||
|
||||
@@ -489,7 +453,7 @@ class Module(dict):
|
||||
recurse: bool = True,
|
||||
keys: Optional[Union[str, List[str]]] = None,
|
||||
strict: bool = False,
|
||||
) -> "Module":
|
||||
) -> Module:
|
||||
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
|
||||
computing gradients for it.
|
||||
|
||||
@@ -544,7 +508,7 @@ class Module(dict):
|
||||
recurse: bool = True,
|
||||
keys: Optional[Union[str, List[str]]] = None,
|
||||
strict: bool = False,
|
||||
) -> "Module":
|
||||
) -> Module:
|
||||
"""Unfreeze the Module's parameters or some of them.
|
||||
|
||||
This function is idempotent ie unfreezing a model that is not frozen is
|
||||
@@ -588,7 +552,7 @@ class Module(dict):
|
||||
_unfreeze_impl("", self)
|
||||
return self
|
||||
|
||||
def train(self, mode: bool = True) -> "Module":
|
||||
def train(self, mode: bool = True) -> Module:
|
||||
"""Set the model in or out of training mode.
|
||||
|
||||
Training mode only applies to certain layers. For example
|
||||
@@ -608,7 +572,7 @@ class Module(dict):
|
||||
self.apply_to_modules(_set_train)
|
||||
return self
|
||||
|
||||
def eval(self) -> "Module":
|
||||
def eval(self) -> Module:
|
||||
"""Set the model to evaluation mode.
|
||||
|
||||
See :func:`train`.
|
||||
@@ -637,3 +601,39 @@ class Module(dict):
|
||||
return True
|
||||
|
||||
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||
|
||||
|
||||
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
||||
if is_leaf_fn(model, value_key, value):
|
||||
return map_fn(value)
|
||||
|
||||
elif isinstance(value, Module):
|
||||
return {
|
||||
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
|
||||
for k, v in value.items()
|
||||
if filter_fn(value, k, v)
|
||||
}
|
||||
|
||||
elif isinstance(value, dict):
|
||||
nd = {}
|
||||
for k, v in value.items():
|
||||
tk = f"{value_key}.{k}"
|
||||
nd[k] = (
|
||||
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
||||
if filter_fn(model, tk, v)
|
||||
else {}
|
||||
)
|
||||
return nd
|
||||
|
||||
elif isinstance(value, list):
|
||||
nl = []
|
||||
for i, vi in enumerate(value):
|
||||
tk = f"{value_key}.{i}"
|
||||
nl.append(
|
||||
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
|
||||
if filter_fn(model, tk, vi)
|
||||
else {}
|
||||
)
|
||||
return nl
|
||||
|
||||
raise RuntimeError("Unexpected leaf found while traversing the module")
|
||||
|
||||
@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
padding: Optional[Union[int, Tuple[int, int]]] = 0,
|
||||
kernel_size: Union[int, Tuple[int]],
|
||||
stride: Optional[Union[int, Tuple[int]]] = None,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
):
|
||||
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
|
||||
|
||||
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
padding: Optional[Union[int, Tuple[int, int]]] = 0,
|
||||
kernel_size: Union[int, Tuple[int]],
|
||||
stride: Optional[Union[int, Tuple[int]]] = None,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
):
|
||||
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ def quantize(
|
||||
model: Module,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
class_predicate: Optional[callable] = None,
|
||||
class_predicate: Optional[Callable] = None,
|
||||
):
|
||||
"""Quantize the sub-modules of a module according to a predicate.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user