Some fixes to typing (#1371)

* some fixes to typing

* fix module reference

* comment
This commit is contained in:
Awni Hannun 2024-08-28 11:16:19 -07:00 committed by GitHub
parent bd47e1f066
commit 291cf40aca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 152 additions and 145 deletions

View File

@ -234,7 +234,7 @@ def glorot_uniform(
def he_normal( def he_normal(
dtype: mx.Dtype = mx.float32, dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]: ) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
r"""Build a He normal initializer. r"""Build a He normal initializer.
This initializer samples from a normal distribution with a standard This initializer samples from a normal distribution with a standard
@ -292,7 +292,7 @@ def he_normal(
def he_uniform( def he_uniform(
dtype: mx.Dtype = mx.float32, dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]: ) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
r"""A He uniform (Kaiming uniform) initializer. r"""A He uniform (Kaiming uniform) initializer.
This initializer samples from a uniform distribution with a range This initializer samples from a uniform distribution with a range

View File

@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from __future__ import annotations
import textwrap import textwrap
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
@ -7,42 +9,6 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_flatten, tree_unflatten
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
    """Recursively filter and map ``value`` while mirroring its nesting.

    Args:
        model: The object handed to ``filter_fn`` and ``is_leaf_fn`` as
            their first argument. When recursion enters a :class:`Module`,
            that module takes over this role for its own entries.
        value_key: The key under which ``value`` was found. Entries of
            plain dicts and lists are addressed as ``"{value_key}.{k}"``
            and ``"{value_key}.{i}"`` respectively.
        value: The object to unwrap.
        filter_fn: Called as ``filter_fn(model, key, value)``; an entry is
            only descended into when it returns true.
        map_fn: Applied to every value that ``is_leaf_fn`` accepts.
        is_leaf_fn: Called as ``is_leaf_fn(model, key, value)``; a true
            result stops the recursion at ``value``.

    Returns:
        ``map_fn(value)`` for a leaf, a ``dict`` for a :class:`Module` or
        a plain dict, and a ``list`` for a list. Module entries rejected by
        ``filter_fn`` are left out; rejected dict and list entries are kept
        in place as an empty dict ``{}``.

    Raises:
        RuntimeError: If ``value`` is not a leaf, a :class:`Module`, a
            dict or a list.
    """
    if is_leaf_fn(model, value_key, value):
        return map_fn(value)

    # Module derives from dict, so it has to be recognised before the
    # generic dict branch below.
    if isinstance(value, Module):
        kept = {}
        for name, child in value.items():
            if filter_fn(value, name, child):
                kept[name] = _unwrap(value, name, child, filter_fn, map_fn, is_leaf_fn)
        return kept

    def _entry(path, item):
        # Rejected container entries keep their slot as an empty dict so
        # the output has the same keys / length as the input container.
        if filter_fn(model, path, item):
            return _unwrap(model, path, item, filter_fn, map_fn, is_leaf_fn)
        return {}

    if isinstance(value, dict):
        return {k: _entry(f"{value_key}.{k}", v) for k, v in value.items()}

    if isinstance(value, list):
        return [_entry(f"{value_key}.{i}", vi) for i, vi in enumerate(value)]

    raise RuntimeError("Unexpected leaf found while traversing the module")
class Module(dict): class Module(dict):
"""Base class for building neural networks with MLX. """Base class for building neural networks with MLX.
@ -151,7 +117,7 @@ class Module(dict):
self, self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]], file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True, strict: bool = True,
) -> "Module": ) -> Module:
""" """
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list. Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@ -266,9 +232,9 @@ class Module(dict):
def filter_and_map( def filter_and_map(
self, self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool], filter_fn: Callable[[Module, str, Any], bool],
map_fn: Optional[Callable] = None, map_fn: Optional[Callable] = None,
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None, is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
): ):
"""Recursively filter the contents of the module using ``filter_fn``, """Recursively filter the contents of the module using ``filter_fn``,
namely only select keys and values where ``filter_fn`` returns true. namely only select keys and values where ``filter_fn`` returns true.
@ -323,7 +289,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> "Module": def update(self, parameters: dict) -> Module:
"""Replace the parameters of this Module with the provided ones in the """Replace the parameters of this Module with the provided ones in the
dict of dicts and lists. dict of dicts and lists.
@ -371,8 +337,8 @@ class Module(dict):
def apply( def apply(
self, self,
map_fn: Callable[[mx.array], mx.array], map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None, filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
) -> "Module": ) -> Module:
"""Map all the parameters using the provided ``map_fn`` and immediately """Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters. update the module with the mapped parameters.
@ -391,7 +357,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn)) self.update(self.filter_and_map(filter_fn, map_fn))
return self return self
def update_modules(self, modules: dict) -> "Module": def update_modules(self, modules: dict) -> Module:
"""Replace the child modules of this :class:`Module` instance with the """Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists. provided ones in the dict of dicts and lists.
@ -432,9 +398,7 @@ class Module(dict):
apply(self, modules) apply(self, modules)
return self return self
def apply_to_modules( def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
"""Apply a function to all the modules in this instance (including this """Apply a function to all the modules in this instance (including this
instance). instance).
@ -489,7 +453,7 @@ class Module(dict):
recurse: bool = True, recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None, keys: Optional[Union[str, List[str]]] = None,
strict: bool = False, strict: bool = False,
) -> "Module": ) -> Module:
"""Freeze the Module's parameters or some of them. Freezing a parameter means not """Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it. computing gradients for it.
@ -544,7 +508,7 @@ class Module(dict):
recurse: bool = True, recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None, keys: Optional[Union[str, List[str]]] = None,
strict: bool = False, strict: bool = False,
) -> "Module": ) -> Module:
"""Unfreeze the Module's parameters or some of them. """Unfreeze the Module's parameters or some of them.
This function is idempotent ie unfreezing a model that is not frozen is This function is idempotent ie unfreezing a model that is not frozen is
@ -588,7 +552,7 @@ class Module(dict):
_unfreeze_impl("", self) _unfreeze_impl("", self)
return self return self
def train(self, mode: bool = True) -> "Module": def train(self, mode: bool = True) -> Module:
"""Set the model in or out of training mode. """Set the model in or out of training mode.
Training mode only applies to certain layers. For example Training mode only applies to certain layers. For example
@ -608,7 +572,7 @@ class Module(dict):
self.apply_to_modules(_set_train) self.apply_to_modules(_set_train)
return self return self
def eval(self) -> "Module": def eval(self) -> Module:
"""Set the model to evaluation mode. """Set the model to evaluation mode.
See :func:`train`. See :func:`train`.
@ -637,3 +601,39 @@ class Module(dict):
return True return True
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
    """Recursively filter and map ``value`` while mirroring its nesting.

    Args:
        model: The object handed to ``filter_fn`` and ``is_leaf_fn`` as
            their first argument. When recursion enters a :class:`Module`,
            that module takes over this role for its own entries.
        value_key: The key under which ``value`` was found. Entries of
            plain dicts and lists are addressed as ``"{value_key}.{k}"``
            and ``"{value_key}.{i}"`` respectively.
        value: The object to unwrap.
        filter_fn: Called as ``filter_fn(model, key, value)``; an entry is
            only descended into when it returns true.
        map_fn: Applied to every value that ``is_leaf_fn`` accepts.
        is_leaf_fn: Called as ``is_leaf_fn(model, key, value)``; a true
            result stops the recursion at ``value``.

    Returns:
        ``map_fn(value)`` for a leaf, a ``dict`` for a :class:`Module` or
        a plain dict, and a ``list`` for a list. Module entries rejected by
        ``filter_fn`` are left out; rejected dict and list entries are kept
        in place as an empty dict ``{}``.

    Raises:
        RuntimeError: If ``value`` is not a leaf, a :class:`Module`, a
            dict or a list.
    """
    if is_leaf_fn(model, value_key, value):
        return map_fn(value)

    # Module derives from dict, so it has to be recognised before the
    # generic dict branch below.
    if isinstance(value, Module):
        kept = {}
        for name, child in value.items():
            if filter_fn(value, name, child):
                kept[name] = _unwrap(value, name, child, filter_fn, map_fn, is_leaf_fn)
        return kept

    def _entry(path, item):
        # Rejected container entries keep their slot as an empty dict so
        # the output has the same keys / length as the input container.
        if filter_fn(model, path, item):
            return _unwrap(model, path, item, filter_fn, map_fn, is_leaf_fn)
        return {}

    if isinstance(value, dict):
        return {k: _entry(f"{value_key}.{k}", v) for k, v in value.items()}

    if isinstance(value, list):
        return [_entry(f"{value_key}.{i}", vi) for i, vi in enumerate(value)]

    raise RuntimeError("Unexpected leaf found while traversing the module")

View File

@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
def __init__( def __init__(
self, self,
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int, int]]] = None, stride: Optional[Union[int, Tuple[int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0, padding: Union[int, Tuple[int]] = 0,
): ):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
def __init__( def __init__(
self, self,
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int, int]]] = None, stride: Optional[Union[int, Tuple[int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0, padding: Union[int, Tuple[int]] = 0,
): ):
super().__init__(mx.mean, 0, kernel_size, stride, padding) super().__init__(mx.mean, 0, kernel_size, stride, padding)

View File

@ -12,7 +12,7 @@ def quantize(
model: Module, model: Module,
group_size: int = 64, group_size: int = 64,
bits: int = 4, bits: int = 4,
class_predicate: Optional[callable] = None, class_predicate: Optional[Callable] = None,
): ):
"""Quantize the sub-modules of a module according to a predicate. """Quantize the sub-modules of a module according to a predicate.

View File

@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import math import math
from typing import Literal from typing import Literal, Optional
import mlx.core as mx import mlx.core as mx
@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):
def cross_entropy( def cross_entropy(
logits: mx.array, logits: mx.array,
targets: mx.array, targets: mx.array,
weights: mx.array = None, weights: Optional[mx.array] = None,
axis: int = -1, axis: int = -1,
label_smoothing: float = 0.0, label_smoothing: float = 0.0,
reduction: Reduction = "none", reduction: Reduction = "none",
@ -117,7 +117,7 @@ def cross_entropy(
def binary_cross_entropy( def binary_cross_entropy(
inputs: mx.array, inputs: mx.array,
targets: mx.array, targets: mx.array,
weights: mx.array = None, weights: Optional[mx.array] = None,
with_logits: bool = True, with_logits: bool = True,
reduction: Reduction = "mean", reduction: Reduction = "mean",
) -> mx.array: ) -> mx.array:

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable, Optional
import mlx.core as mx import mlx.core as mx
@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):
return wrapped_value_grad_fn return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None): def checkpoint(module: Module, fn: Optional[Callable] = None):
"""Transform the passed callable to one that performs gradient """Transform the passed callable to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and checkpointing with respect to the trainable parameters of the module (and
the callable's inputs). the callable's inputs).

View File

@ -4,6 +4,7 @@ import math
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_map, tree_reduce from mlx.utils import tree_map, tree_reduce
@ -17,7 +18,7 @@ class Optimizer:
self._state = {"step": mx.array(0, mx.uint64)} self._state = {"step": mx.array(0, mx.uint64)}
self._schedulers = {k: v for k, v in (schedulers or {}).items()} self._schedulers = {k: v for k, v in (schedulers or {}).items()}
def update(self, model: "mlx.nn.Module", gradients: dict): def update(self, model: Module, gradients: dict):
"""Apply the gradients to the parameters of the model and update the """Apply the gradients to the parameters of the model and update the
model with the new parameters. model with the new parameters.

View File

@ -1,10 +1,10 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Tuple from typing import Any, Callable, Optional, Tuple
def tree_map( def tree_map(
fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
) -> Any: ) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and """Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results. returns a new collection with the results.
@ -59,8 +59,8 @@ def tree_map(
def tree_map_with_path( def tree_map_with_path(
fn: Callable, fn: Callable,
tree: Any, tree: Any,
*rest: Tuple[Any], *rest: Any,
is_leaf: Callable = None, is_leaf: Optional[Callable] = None,
path: Any = None, path: Any = None,
) -> Any: ) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and

View File

@ -9,6 +9,7 @@
#include <nanobind/stl/string.h> #include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h> #include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h> #include <nanobind/stl/vector.h>
#include <nanobind/typing.h>
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "python/src/buffer.h" #include "python/src/buffer.h"
@ -113,6 +114,7 @@ void init_array(nb::module_& m) {
.def("__hash__", [](const Dtype& t) { .def("__hash__", [](const Dtype& t) {
return static_cast<int64_t>(t.val); return static_cast<int64_t>(t.val);
}); });
m.attr("bool_") = nb::cast(bool_); m.attr("bool_") = nb::cast(bool_);
m.attr("uint8") = nb::cast(uint8); m.attr("uint8") = nb::cast(uint8);
m.attr("uint16") = nb::cast(uint16); m.attr("uint16") = nb::cast(uint16);
@ -177,7 +179,7 @@ void init_array(nb::module_& m) {
.export_values(); .export_values();
nb::class_<ArrayAt>( nb::class_<ArrayAt>(
m, m,
"_ArrayAt", "ArrayAt",
R"pbdoc( R"pbdoc(
A helper object to apply updates at specific indices. A helper object to apply updates at specific indices.
)pbdoc") )pbdoc")
@ -195,7 +197,7 @@ void init_array(nb::module_& m) {
nb::class_<ArrayPythonIterator>( nb::class_<ArrayPythonIterator>(
m, m,
"_ArrayIterator", "ArrayIterator",
R"pbdoc( R"pbdoc(
A helper object to iterate over the 1st dimension of an array. A helper object to iterate over the 1st dimension of an array.
)pbdoc") )pbdoc")

View File

@ -229,14 +229,16 @@ void init_fast(nb::module_& parent_module) {
Returns: Returns:
Callable ``metal_kernel``. Callable ``metal_kernel``.
Example:
.. code-block:: python .. code-block:: python
def exp_elementwise(a: mx.array): def exp_elementwise(a: mx.array):
source = """ source = '''
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = inp[elem]; T tmp = inp[elem];
out[elem] = metal::exp(tmp); out[elem] = metal::exp(tmp);
""" '''
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp", name="myexp",
@ -256,7 +258,6 @@ void init_fast(nb::module_& parent_module) {
a = mx.random.normal(shape=(4, 16)).astype(mx.float16) a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
)pbdoc") )pbdoc")
.def( .def(
"__call__", "__call__",

View File

@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), "def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Matrix or vector norm. Matrix or vector norm.
@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D, a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned. 2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``). ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
along the given ``axis``. Default: ``None``. along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"), "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
R"pbdoc( R"pbdoc(
The QR factorization of the input matrix. The QR factorization of the input matrix.
@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"), "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
R"pbdoc( R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix. The Singular Value Decomposition (SVD) of the input matrix.

View File

@ -1360,7 +1360,7 @@ void init_ops(nb::module_& m) {
"dtype"_a = nb::none(), "dtype"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array")); "def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
m.def( m.def(
"linspace", "linspace",
[](Scalar start, [](Scalar start,
@ -2695,7 +2695,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), "def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Concatenate the arrays along the given axis. Concatenate the arrays along the given axis.
@ -2723,7 +2723,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), "def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
See :func:`concatenate`. See :func:`concatenate`.
)pbdoc"); )pbdoc");
@ -2743,7 +2743,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), "def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Stacks the arrays along a new axis. Stacks the arrays along a new axis.
@ -2770,7 +2770,7 @@ void init_ops(nb::module_& m) {
"indexing"_a = "xy", "indexing"_a = "xy",
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"), "def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Generate multidimensional coordinate grids from 1-D coordinate arrays Generate multidimensional coordinate grids from 1-D coordinate arrays
@ -2889,7 +2889,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Pad an array with a constant value Pad an array with a constant value
@ -3291,7 +3291,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), "def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
2D convolution over an input with several channels 2D convolution over an input with several channels
@ -3361,7 +3361,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), "def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
3D convolution over an input with several channels 3D convolution over an input with several channels
@ -3460,7 +3460,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"), "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
General convolution over an input with several channels General convolution over an input with several channels
@ -3560,7 +3560,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]"), "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
R"pbdoc( R"pbdoc(
Load array(s) from a binary file. Load array(s) from a binary file.
@ -3594,7 +3594,7 @@ void init_ops(nb::module_& m) {
"arrays"_a, "arrays"_a,
"metadata"_a = nb::none(), "metadata"_a = nb::none(),
nb::sig( nb::sig(
"def save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)"), "def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
R"pbdoc( R"pbdoc(
Save array(s) to a binary file in ``.safetensors`` format. Save array(s) to a binary file in ``.safetensors`` format.
@ -3615,7 +3615,7 @@ void init_ops(nb::module_& m) {
"arrays"_a, "arrays"_a,
"metadata"_a = nb::none(), "metadata"_a = nb::none(),
nb::sig( nb::sig(
"def save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])"), "def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
R"pbdoc( R"pbdoc(
Save array(s) to a binary file in ``.gguf`` format. Save array(s) to a binary file in ``.gguf`` format.
@ -3769,7 +3769,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
R"pbdoc( R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element. Quantize the matrix ``w`` using ``bits`` bits per element.
@ -3924,7 +3924,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"), "def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Compute the tensor dot product along the specified axes. Compute the tensor dot product along the specified axes.
@ -4046,7 +4046,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"), "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Matrix multiplication with block masking. Matrix multiplication with block masking.
@ -4189,7 +4189,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Return the sum along a specified diagonal in the given array. Return the sum along a specified diagonal in the given array.
@ -4218,7 +4218,7 @@ void init_ops(nb::module_& m) {
"arys"_a, "arys"_a,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"), "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
R"pbdoc( R"pbdoc(
Convert all arrays to have at least one dimension. Convert all arrays to have at least one dimension.
@ -4240,7 +4240,7 @@ void init_ops(nb::module_& m) {
"arys"_a, "arys"_a,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"), "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
R"pbdoc( R"pbdoc(
Convert all arrays to have at least two dimensions. Convert all arrays to have at least two dimensions.
@ -4262,7 +4262,7 @@ void init_ops(nb::module_& m) {
"arys"_a, "arys"_a,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"), "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
R"pbdoc( R"pbdoc(
Convert all arrays to have at least three dimensions. Convert all arrays to have at least three dimensions.
@ -4511,7 +4511,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"), "def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform the Walsh-Hadamard transform along the final axis. Perform the Walsh-Hadamard transform along the final axis.
@ -4575,7 +4575,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"), "def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform the Einstein summation convention on the operands. Perform the Einstein summation convention on the operands.

View File

@ -93,7 +93,7 @@ void init_random(nb::module_& parent_module) {
"num"_a = 2, "num"_a = 2,
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"), "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Split a PRNG key into sub keys. Split a PRNG key into sub keys.
@ -321,7 +321,7 @@ void init_random(nb::module_& parent_module) {
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Generate values from a truncated normal distribution. Generate values from a truncated normal distribution.

View File

@ -4,6 +4,7 @@
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h> #include <nanobind/stl/variant.h>
#include "mlx/stream.h" #include "mlx/stream.h"
@ -56,8 +57,8 @@ void init_stream(nb::module_& m) {
os << s; os << s;
return os.str(); return os.str();
}) })
.def("__eq__", [](const Stream& s1, const Stream& s2) { .def("__eq__", [](const Stream& s, const nb::object& other) {
return s1 == s2; return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
}); });
nb::implicitly_convertible<Device::DeviceType, Device>(); nb::implicitly_convertible<Device::DeviceType, Device>();

View File

@ -178,7 +178,7 @@ auto py_value_and_grad(
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a " << "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a " << "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but " << "scalar array (Union[array, tuple[array, Any, ...]]); but "
<< type_name_str(py_value_out) << " was returned."; << type_name_str(py_value_out) << " was returned.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -197,7 +197,7 @@ auto py_value_and_grad(
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a " << "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a " << "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it " << "scalar array (Union[array, tuple[array, Any, ...]]); but it "
<< "was a tuple with the first value being of type " << "was a tuple with the first value being of type "
<< type_name_str(ret[0]) << " ."; << type_name_str(ret[0]) << " .";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
@ -973,13 +973,13 @@ void init_transforms(nb::module_& m) {
.def( .def(
nb::init<nb::callable>(), nb::init<nb::callable>(),
"f"_a, "f"_a,
nb::sig("def __init__(self, f: callable)")) nb::sig("def __init__(self, f: Callable)"))
.def("__call__", &PyCustomFunction::call_impl) .def("__call__", &PyCustomFunction::call_impl)
.def( .def(
"vjp", "vjp",
&PyCustomFunction::set_vjp, &PyCustomFunction::set_vjp,
"f"_a, "f"_a,
nb::sig("def vjp(self, f_vjp: callable)"), nb::sig("def vjp(self, f: Callable)"),
R"pbdoc( R"pbdoc(
Define a custom vjp for the wrapped function. Define a custom vjp for the wrapped function.
@ -1001,7 +1001,7 @@ void init_transforms(nb::module_& m) {
"jvp", "jvp",
&PyCustomFunction::set_jvp, &PyCustomFunction::set_jvp,
"f"_a, "f"_a,
nb::sig("def jvp(self, f_jvp: callable)"), nb::sig("def jvp(self, f: Callable)"),
R"pbdoc( R"pbdoc(
Define a custom jvp for the wrapped function. Define a custom jvp for the wrapped function.
@ -1021,7 +1021,7 @@ void init_transforms(nb::module_& m) {
"vmap", "vmap",
&PyCustomFunction::set_vmap, &PyCustomFunction::set_vmap,
"f"_a, "f"_a,
nb::sig("def vmap(self, f_vmap: callable)"), nb::sig("def vmap(self, f: Callable)"),
R"pbdoc( R"pbdoc(
Define a custom vectorization transformation for the wrapped function. Define a custom vectorization transformation for the wrapped function.
@ -1116,7 +1116,7 @@ void init_transforms(nb::module_& m) {
"primals"_a, "primals"_a,
"tangents"_a, "tangents"_a,
nb::sig( nb::sig(
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"), "def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
R"pbdoc( R"pbdoc(
Compute the Jacobian-vector product. Compute the Jacobian-vector product.
@ -1124,7 +1124,7 @@ void init_transforms(nb::module_& m) {
at ``primals`` with the ``tangents``. at ``primals`` with the ``tangents``.
Args: Args:
fun (callable): A function which takes a variable number of :class:`array` fun (Callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`. and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian. evaluate the Jacobian.
@ -1155,7 +1155,7 @@ void init_transforms(nb::module_& m) {
"primals"_a, "primals"_a,
"cotangents"_a, "cotangents"_a,
nb::sig( nb::sig(
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"), "def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
R"pbdoc( R"pbdoc(
Compute the vector-Jacobian product. Compute the vector-Jacobian product.
@ -1163,7 +1163,7 @@ void init_transforms(nb::module_& m) {
function ``fun`` evaluated at ``primals``. function ``fun`` evaluated at ``primals``.
Args: Args:
fun (callable): A function which takes a variable number of :class:`array` fun (Callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`. and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian. evaluate the Jacobian.
@ -1189,7 +1189,7 @@ void init_transforms(nb::module_& m) {
"argnums"_a = nb::none(), "argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{}, "argnames"_a = std::vector<std::string>{},
nb::sig( nb::sig(
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"), "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
R"pbdoc( R"pbdoc(
Returns a function which computes the value and gradient of ``fun``. Returns a function which computes the value and gradient of ``fun``.
@ -1221,7 +1221,7 @@ void init_transforms(nb::module_& m) {
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
Args: Args:
fun (callable): A function which takes a variable number of fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a scalar output :class:`array` or a tuple the first element a scalar output :class:`array` or a tuple the first element
of which should be a scalar :class:`array`. of which should be a scalar :class:`array`.
@ -1235,7 +1235,7 @@ void init_transforms(nb::module_& m) {
no gradients for keyword arguments by default. no gradients for keyword arguments by default.
Returns: Returns:
callable: A function which returns a tuple where the first element Callable: A function which returns a tuple where the first element
is the output of `fun` and the second element is the gradients w.r.t. is the output of `fun` and the second element is the gradients w.r.t.
the loss. the loss.
)pbdoc"); )pbdoc");
@ -1257,12 +1257,12 @@ void init_transforms(nb::module_& m) {
"argnums"_a = nb::none(), "argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{}, "argnames"_a = std::vector<std::string>{},
nb::sig( nb::sig(
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"), "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
R"pbdoc( R"pbdoc(
Returns a function which computes the gradient of ``fun``. Returns a function which computes the gradient of ``fun``.
Args: Args:
fun (callable): A function which takes a variable number of fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a scalar output :class:`array`. a scalar output :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices) argnums (int or list(int), optional): Specify the index (or indices)
@ -1275,7 +1275,7 @@ void init_transforms(nb::module_& m) {
no gradients for keyword arguments by default. no gradients for keyword arguments by default.
Returns: Returns:
callable: A function which has the same input arguments as ``fun`` and Callable: A function which has the same input arguments as ``fun`` and
returns the gradient(s). returns the gradient(s).
)pbdoc"); )pbdoc");
m.def( m.def(
@ -1289,12 +1289,12 @@ void init_transforms(nb::module_& m) {
"in_axes"_a = 0, "in_axes"_a = 0,
"out_axes"_a = 0, "out_axes"_a = 0,
nb::sig( nb::sig(
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"), "def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"),
R"pbdoc( R"pbdoc(
Returns a vectorized version of ``fun``. Returns a vectorized version of ``fun``.
Args: Args:
fun (callable): A function which takes a variable number of fun (Callable): A function which takes a variable number of
:class:`array` or a tree of :class:`array` and returns :class:`array` or a tree of :class:`array` and returns
a variable number of :class:`array` or a tree of :class:`array`. a variable number of :class:`array` or a tree of :class:`array`.
in_axes (int, optional): An integer or a valid prefix tree of the in_axes (int, optional): An integer or a valid prefix tree of the
@ -1307,7 +1307,7 @@ void init_transforms(nb::module_& m) {
Defaults to ``0``. Defaults to ``0``.
Returns: Returns:
callable: The vectorized function. Callable: The vectorized function.
)pbdoc"); )pbdoc");
m.def( m.def(
"export_to_dot", "export_to_dot",
@ -1367,11 +1367,13 @@ void init_transforms(nb::module_& m) {
"inputs"_a = nb::none(), "inputs"_a = nb::none(),
"outputs"_a = nb::none(), "outputs"_a = nb::none(),
"shapeless"_a = false, "shapeless"_a = false,
nb::sig(
"def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
R"pbdoc( R"pbdoc(
Returns a compiled function which produces the same output as ``fun``. Returns a compiled function which produces the same output as ``fun``.
Args: Args:
fun (callable): A function which takes a variable number of fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`. a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during inputs (list or dict, optional): These inputs will be captured during
@ -1392,7 +1394,7 @@ void init_transforms(nb::module_& m) {
``shapeless`` set to ``True``. Default: ``False`` ``shapeless`` set to ``True``. Default: ``False``
Returns: Returns:
callable: A compiled function which has the same input arguments Callable: A compiled function which has the same input arguments
as ``fun`` and returns the same output(s). as ``fun`` and returns the same output(s).
)pbdoc"); )pbdoc");
m.def( m.def(