mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Some fixes to typing (#1371)
* some fixes to typing
* fix module reference
* comment
This commit is contained in:
		@@ -234,7 +234,7 @@ def glorot_uniform(
 | 
			
		||||
 | 
			
		||||
def he_normal(
 | 
			
		||||
    dtype: mx.Dtype = mx.float32,
 | 
			
		||||
) -> Callable[[mx.array, str, float], mx.array]:
 | 
			
		||||
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
 | 
			
		||||
    r"""Build a He normal initializer.
 | 
			
		||||
 | 
			
		||||
    This initializer samples from a normal distribution with a standard
 | 
			
		||||
@@ -292,7 +292,7 @@ def he_normal(
 | 
			
		||||
 | 
			
		||||
def he_uniform(
 | 
			
		||||
    dtype: mx.Dtype = mx.float32,
 | 
			
		||||
) -> Callable[[mx.array, str, float], mx.array]:
 | 
			
		||||
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
 | 
			
		||||
    r"""A He uniform (Kaiming uniform) initializer.
 | 
			
		||||
 | 
			
		||||
    This initializer samples from a uniform distribution with a range
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,7 @@
 | 
			
		||||
# Copyright © 2023 Apple Inc.
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import textwrap
 | 
			
		||||
from typing import Any, Callable, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
@@ -7,42 +9,6 @@ import mlx.core as mx
 | 
			
		||||
from mlx.utils import tree_flatten, tree_unflatten
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
    """Recursively map the leaves of ``value`` while filtering its entries.

    Args:
        model: Object handed to ``filter_fn`` and ``is_leaf_fn`` as their
            first argument. When descending into a ``Module`` that module
            takes over this role for its own children.
        value_key: Key under which ``value`` is stored. Dict keys and list
            indices are appended to it as ``"{value_key}.{k}"``.
        value: The object to process.
        filter_fn: ``filter_fn(owner, key, child)``; decides whether a child
            is visited.
        map_fn: Applied to every value that ``is_leaf_fn`` accepts.
        is_leaf_fn: ``is_leaf_fn(owner, key, value)``; decides whether
            ``value`` is a leaf.

    Returns:
        ``map_fn(value)`` for a leaf. For a ``Module`` a dict holding only
        the children accepted by ``filter_fn``. For a plain dict or list a
        container of the same length in which rejected entries are replaced
        by ``{}``.

    Raises:
        RuntimeError: If ``value`` is not a leaf, ``Module``, dict or list.
    """
    if is_leaf_fn(model, value_key, value):
        return map_fn(value)

    # Checked before the plain-dict case because Module subclasses dict.
    if isinstance(value, Module):
        children = {}
        for name, child in value.items():
            # Rejected children are dropped entirely; keys restart without
            # a prefix because ``value`` becomes the new owner.
            if filter_fn(value, name, child):
                children[name] = _unwrap(
                    value, name, child, filter_fn, map_fn, is_leaf_fn
                )
        return children

    if isinstance(value, dict):
        mapped = {}
        for name, child in value.items():
            path = f"{value_key}.{name}"
            if filter_fn(model, path, child):
                mapped[name] = _unwrap(
                    model, path, child, filter_fn, map_fn, is_leaf_fn
                )
            else:
                # Keep the key so the structure is preserved.
                mapped[name] = {}
        return mapped

    if isinstance(value, list):
        mapped = []
        for index, child in enumerate(value):
            path = f"{value_key}.{index}"
            if filter_fn(model, path, child):
                mapped.append(
                    _unwrap(model, path, child, filter_fn, map_fn, is_leaf_fn)
                )
            else:
                # Keep the slot so list positions stay aligned.
                mapped.append({})
        return mapped

    raise RuntimeError("Unexpected leaf found while traversing the module")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Module(dict):
 | 
			
		||||
    """Base class for building neural networks with MLX.
 | 
			
		||||
 | 
			
		||||
@@ -151,7 +117,7 @@ class Module(dict):
 | 
			
		||||
        self,
 | 
			
		||||
        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
 | 
			
		||||
        strict: bool = True,
 | 
			
		||||
    ) -> "Module":
 | 
			
		||||
    ) -> Module:
 | 
			
		||||
        """
 | 
			
		||||
        Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
 | 
			
		||||
 | 
			
		||||
@@ -266,9 +232,9 @@ class Module(dict):
 | 
			
		||||
 | 
			
		||||
    def filter_and_map(
 | 
			
		||||
        self,
 | 
			
		||||
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
 | 
			
		||||
        filter_fn: Callable[[Module, str, Any], bool],
 | 
			
		||||
        map_fn: Optional[Callable] = None,
 | 
			
		||||
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
 | 
			
		||||
        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        """Recursively filter the contents of the module using ``filter_fn``,
 | 
			
		||||
        namely only select keys and values where ``filter_fn`` returns true.
 | 
			
		||||
@@ -323,7 +289,7 @@ class Module(dict):
 | 
			
		||||
 | 
			
		||||
        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
 | 
			
		||||
 | 
			
		||||
    def update(self, parameters: dict) -> "Module":
 | 
			
		||||
    def update(self, parameters: dict) -> Module:
 | 
			
		||||
        """Replace the parameters of this Module with the provided ones in the
 | 
			
		||||
        dict of dicts and lists.
 | 
			
		||||
 | 
			
		||||
@@ -371,8 +337,8 @@ class Module(dict):
 | 
			
		||||
    def apply(
 | 
			
		||||
        self,
 | 
			
		||||
        map_fn: Callable[[mx.array], mx.array],
 | 
			
		||||
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
 | 
			
		||||
    ) -> "Module":
 | 
			
		||||
        filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
 | 
			
		||||
    ) -> Module:
 | 
			
		||||
        """Map all the parameters using the provided ``map_fn`` and immediately
 | 
			
		||||
        update the module with the mapped parameters.
 | 
			
		||||
 | 
			
		||||
@@ -391,7 +357,7 @@ class Module(dict):
 | 
			
		||||
        self.update(self.filter_and_map(filter_fn, map_fn))
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def update_modules(self, modules: dict) -> "Module":
 | 
			
		||||
    def update_modules(self, modules: dict) -> Module:
 | 
			
		||||
        """Replace the child modules of this :class:`Module` instance with the
 | 
			
		||||
        provided ones in the dict of dicts and lists.
 | 
			
		||||
 | 
			
		||||
@@ -432,9 +398,7 @@ class Module(dict):
 | 
			
		||||
        apply(self, modules)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def apply_to_modules(
 | 
			
		||||
        self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
 | 
			
		||||
    ) -> "Module":
 | 
			
		||||
    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
 | 
			
		||||
        """Apply a function to all the modules in this instance (including this
 | 
			
		||||
        instance).
 | 
			
		||||
 | 
			
		||||
@@ -489,7 +453,7 @@ class Module(dict):
 | 
			
		||||
        recurse: bool = True,
 | 
			
		||||
        keys: Optional[Union[str, List[str]]] = None,
 | 
			
		||||
        strict: bool = False,
 | 
			
		||||
    ) -> "Module":
 | 
			
		||||
    ) -> Module:
 | 
			
		||||
        """Freeze the Module's parameters or some of them. Freezing a parameter means not
 | 
			
		||||
        computing gradients for it.
 | 
			
		||||
 | 
			
		||||
@@ -544,7 +508,7 @@ class Module(dict):
 | 
			
		||||
        recurse: bool = True,
 | 
			
		||||
        keys: Optional[Union[str, List[str]]] = None,
 | 
			
		||||
        strict: bool = False,
 | 
			
		||||
    ) -> "Module":
 | 
			
		||||
    ) -> Module:
 | 
			
		||||
        """Unfreeze the Module's parameters or some of them.
 | 
			
		||||
 | 
			
		||||
        This function is idempotent ie unfreezing a model that is not frozen is
 | 
			
		||||
@@ -588,7 +552,7 @@ class Module(dict):
 | 
			
		||||
            _unfreeze_impl("", self)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def train(self, mode: bool = True) -> "Module":
 | 
			
		||||
    def train(self, mode: bool = True) -> Module:
 | 
			
		||||
        """Set the model in or out of training mode.
 | 
			
		||||
 | 
			
		||||
        Training mode only applies to certain layers. For example
 | 
			
		||||
@@ -608,7 +572,7 @@ class Module(dict):
 | 
			
		||||
        self.apply_to_modules(_set_train)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def eval(self) -> "Module":
 | 
			
		||||
    def eval(self) -> Module:
 | 
			
		||||
        """Set the model to evaluation mode.
 | 
			
		||||
 | 
			
		||||
        See :func:`train`.
 | 
			
		||||
@@ -637,3 +601,39 @@ class Module(dict):
 | 
			
		||||
                return True
 | 
			
		||||
 | 
			
		||||
        self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
    """Walk ``value`` and return a filtered, mapped copy of its structure.

    A value accepted by ``is_leaf_fn(model, value_key, value)`` is returned
    as ``map_fn(value)``. Otherwise the function descends:

    * ``Module``: returns a dict of the children for which
      ``filter_fn(value, key, child)`` is true; the module itself is passed
      on as the owner and the child's own key as the path.
    * ``dict`` / ``list``: returns a dict / list with one entry per original
      entry. The path of each entry is ``"{value_key}.{key_or_index}"``;
      entries rejected by ``filter_fn(model, path, child)`` become ``{}``.

    Raises:
        RuntimeError: If ``value`` is neither a leaf nor one of the
            container types above.
    """
    if is_leaf_fn(model, value_key, value):
        return map_fn(value)

    def visit(suffix, child):
        # Shared handling for dict and list entries: extend the dotted path,
        # then either recurse or leave an empty placeholder.
        path = f"{value_key}.{suffix}"
        if not filter_fn(model, path, child):
            return {}
        return _unwrap(model, path, child, filter_fn, map_fn, is_leaf_fn)

    # Module is tested first because it is itself a dict subclass.
    if isinstance(value, Module):
        return {
            name: _unwrap(value, name, child, filter_fn, map_fn, is_leaf_fn)
            for name, child in value.items()
            if filter_fn(value, name, child)
        }

    if isinstance(value, dict):
        return {name: visit(name, child) for name, child in value.items()}

    if isinstance(value, list):
        return [visit(index, child) for index, child in enumerate(value)]

    raise RuntimeError("Unexpected leaf found while traversing the module")
 | 
			
		||||
 
 | 
			
		||||
@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        kernel_size: Union[int, Tuple[int, int]],
 | 
			
		||||
        stride: Optional[Union[int, Tuple[int, int]]] = None,
 | 
			
		||||
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
 | 
			
		||||
        kernel_size: Union[int, Tuple[int]],
 | 
			
		||||
        stride: Optional[Union[int, Tuple[int]]] = None,
 | 
			
		||||
        padding: Union[int, Tuple[int]] = 0,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
 | 
			
		||||
 | 
			
		||||
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        kernel_size: Union[int, Tuple[int, int]],
 | 
			
		||||
        stride: Optional[Union[int, Tuple[int, int]]] = None,
 | 
			
		||||
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
 | 
			
		||||
        kernel_size: Union[int, Tuple[int]],
 | 
			
		||||
        stride: Optional[Union[int, Tuple[int]]] = None,
 | 
			
		||||
        padding: Union[int, Tuple[int]] = 0,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(mx.mean, 0, kernel_size, stride, padding)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ def quantize(
 | 
			
		||||
    model: Module,
 | 
			
		||||
    group_size: int = 64,
 | 
			
		||||
    bits: int = 4,
 | 
			
		||||
    class_predicate: Optional[callable] = None,
 | 
			
		||||
    class_predicate: Optional[Callable] = None,
 | 
			
		||||
):
 | 
			
		||||
    """Quantize the sub-modules of a module according to a predicate.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
# Copyright © 2023 Apple Inc.
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
from typing import Literal
 | 
			
		||||
from typing import Literal, Optional
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
 | 
			
		||||
@@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):
 | 
			
		||||
def cross_entropy(
 | 
			
		||||
    logits: mx.array,
 | 
			
		||||
    targets: mx.array,
 | 
			
		||||
    weights: mx.array = None,
 | 
			
		||||
    weights: Optional[mx.array] = None,
 | 
			
		||||
    axis: int = -1,
 | 
			
		||||
    label_smoothing: float = 0.0,
 | 
			
		||||
    reduction: Reduction = "none",
 | 
			
		||||
@@ -117,7 +117,7 @@ def cross_entropy(
 | 
			
		||||
def binary_cross_entropy(
 | 
			
		||||
    inputs: mx.array,
 | 
			
		||||
    targets: mx.array,
 | 
			
		||||
    weights: mx.array = None,
 | 
			
		||||
    weights: Optional[mx.array] = None,
 | 
			
		||||
    with_logits: bool = True,
 | 
			
		||||
    reduction: Reduction = "mean",
 | 
			
		||||
) -> mx.array:
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
# Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from typing import Callable
 | 
			
		||||
from typing import Callable, Optional
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
 | 
			
		||||
@@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):
 | 
			
		||||
    return wrapped_value_grad_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def checkpoint(module: Module, fn: Callable = None):
 | 
			
		||||
def checkpoint(module: Module, fn: Optional[Callable] = None):
 | 
			
		||||
    """Transform the passed callable to one that performs gradient
 | 
			
		||||
    checkpointing with respect to the trainable parameters of the module (and
 | 
			
		||||
    the callable's inputs).
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import math
 | 
			
		||||
from typing import Callable, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
from mlx.nn import Module
 | 
			
		||||
from mlx.utils import tree_map, tree_reduce
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -17,7 +18,7 @@ class Optimizer:
 | 
			
		||||
        self._state = {"step": mx.array(0, mx.uint64)}
 | 
			
		||||
        self._schedulers = {k: v for k, v in (schedulers or {}).items()}
 | 
			
		||||
 | 
			
		||||
    def update(self, model: "mlx.nn.Module", gradients: dict):
 | 
			
		||||
    def update(self, model: Module, gradients: dict):
 | 
			
		||||
        """Apply the gradients to the parameters of the model and update the
 | 
			
		||||
        model with the new parameters.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,10 @@
 | 
			
		||||
# Copyright © 2023 Apple Inc.
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import Any, Callable, Tuple
 | 
			
		||||
from typing import Any, Callable, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tree_map(
 | 
			
		||||
    fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None
 | 
			
		||||
    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
 | 
			
		||||
) -> Any:
 | 
			
		||||
    """Applies ``fn`` to the leaves of the Python tree ``tree`` and
 | 
			
		||||
    returns a new collection with the results.
 | 
			
		||||
@@ -59,8 +59,8 @@ def tree_map(
 | 
			
		||||
def tree_map_with_path(
 | 
			
		||||
    fn: Callable,
 | 
			
		||||
    tree: Any,
 | 
			
		||||
    *rest: Tuple[Any],
 | 
			
		||||
    is_leaf: Callable = None,
 | 
			
		||||
    *rest: Any,
 | 
			
		||||
    is_leaf: Optional[Callable] = None,
 | 
			
		||||
    path: Any = None,
 | 
			
		||||
) -> Any:
 | 
			
		||||
    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,7 @@
 | 
			
		||||
#include <nanobind/stl/string.h>
 | 
			
		||||
#include <nanobind/stl/variant.h>
 | 
			
		||||
#include <nanobind/stl/vector.h>
 | 
			
		||||
#include <nanobind/typing.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/metal/metal.h"
 | 
			
		||||
#include "python/src/buffer.h"
 | 
			
		||||
@@ -113,6 +114,7 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .def("__hash__", [](const Dtype& t) {
 | 
			
		||||
        return static_cast<int64_t>(t.val);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
  m.attr("bool_") = nb::cast(bool_);
 | 
			
		||||
  m.attr("uint8") = nb::cast(uint8);
 | 
			
		||||
  m.attr("uint16") = nb::cast(uint16);
 | 
			
		||||
@@ -177,7 +179,7 @@ void init_array(nb::module_& m) {
 | 
			
		||||
      .export_values();
 | 
			
		||||
  nb::class_<ArrayAt>(
 | 
			
		||||
      m,
 | 
			
		||||
      "_ArrayAt",
 | 
			
		||||
      "ArrayAt",
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A helper object to apply updates at specific indices.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
@@ -195,7 +197,7 @@ void init_array(nb::module_& m) {
 | 
			
		||||
 | 
			
		||||
  nb::class_<ArrayPythonIterator>(
 | 
			
		||||
      m,
 | 
			
		||||
      "_ArrayIterator",
 | 
			
		||||
      "ArrayIterator",
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A helper object to iterate over the 1st dimension of an array.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
 
 | 
			
		||||
@@ -229,14 +229,16 @@ void init_fast(nb::module_& parent_module) {
 | 
			
		||||
      Returns:
 | 
			
		||||
        Callable ``metal_kernel``.
 | 
			
		||||
 | 
			
		||||
      Example:
 | 
			
		||||
 | 
			
		||||
        .. code-block:: python
 | 
			
		||||
 | 
			
		||||
          def exp_elementwise(a: mx.array):
 | 
			
		||||
            source = """
 | 
			
		||||
              source = '''
 | 
			
		||||
                  uint elem = thread_position_in_grid.x;
 | 
			
		||||
                  T tmp = inp[elem];
 | 
			
		||||
                  out[elem] = metal::exp(tmp);
 | 
			
		||||
            """
 | 
			
		||||
              '''
 | 
			
		||||
 | 
			
		||||
              kernel = mx.fast.metal_kernel(
 | 
			
		||||
                  name="myexp",
 | 
			
		||||
@@ -256,7 +258,6 @@ void init_fast(nb::module_& parent_module) {
 | 
			
		||||
          a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 | 
			
		||||
          b = exp_elementwise(a)
 | 
			
		||||
          assert mx.allclose(b, mx.exp(a))
 | 
			
		||||
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
          "__call__",
 | 
			
		||||
 
 | 
			
		||||
@@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Matrix or vector norm.
 | 
			
		||||
 | 
			
		||||
@@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
          a (array): Input array.  If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
 | 
			
		||||
            unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
 | 
			
		||||
            2-norm of ``a.flatten`` will be returned.
 | 
			
		||||
          ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
 | 
			
		||||
          ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
 | 
			
		||||
            If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
 | 
			
		||||
            along the given ``axis``.  Default: ``None``.
 | 
			
		||||
          axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
 | 
			
		||||
@@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
 | 
			
		||||
          "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        The QR factorization of the input matrix.
 | 
			
		||||
 | 
			
		||||
@@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
 | 
			
		||||
          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        The Singular Value Decomposition (SVD) of the input matrix.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1360,7 +1360,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "dtype"_a = nb::none(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
 | 
			
		||||
          "def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
 | 
			
		||||
  m.def(
 | 
			
		||||
      "linspace",
 | 
			
		||||
      [](Scalar start,
 | 
			
		||||
@@ -2695,7 +2695,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Concatenate the arrays along the given axis.
 | 
			
		||||
 | 
			
		||||
@@ -2723,7 +2723,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        See :func:`concatenate`.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
@@ -2743,7 +2743,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Stacks the arrays along a new axis.
 | 
			
		||||
 | 
			
		||||
@@ -2770,7 +2770,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "indexing"_a = "xy",
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Generate multidimensional coordinate grids from 1-D coordinate arrays
 | 
			
		||||
 | 
			
		||||
@@ -2889,7 +2889,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Pad an array with a constant value
 | 
			
		||||
 | 
			
		||||
@@ -3291,7 +3291,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        2D convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
@@ -3361,7 +3361,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        3D convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
@@ -3460,7 +3460,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        General convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
@@ -3560,7 +3560,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]"),
 | 
			
		||||
          "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Load array(s) from a binary file.
 | 
			
		||||
 | 
			
		||||
@@ -3594,7 +3594,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "arrays"_a,
 | 
			
		||||
      "metadata"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)"),
 | 
			
		||||
          "def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Save array(s) to a binary file in ``.safetensors`` format.
 | 
			
		||||
 | 
			
		||||
@@ -3615,7 +3615,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "arrays"_a,
 | 
			
		||||
      "metadata"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])"),
 | 
			
		||||
          "def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Save array(s) to a binary file in ``.gguf`` format.
 | 
			
		||||
 | 
			
		||||
@@ -3769,7 +3769,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
 | 
			
		||||
          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Quantize the matrix ``w`` using ``bits`` bits per element.
 | 
			
		||||
 | 
			
		||||
@@ -3924,7 +3924,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Compute the tensor dot product along the specified axes.
 | 
			
		||||
 | 
			
		||||
@@ -4046,7 +4046,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Matrix multiplication with block masking.
 | 
			
		||||
 | 
			
		||||
@@ -4189,7 +4189,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Return the sum along a specified diagonal in the given array.
 | 
			
		||||
 | 
			
		||||
@@ -4218,7 +4218,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "arys"_a,
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
 | 
			
		||||
          "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Convert all arrays to have at least one dimension.
 | 
			
		||||
 | 
			
		||||
@@ -4240,7 +4240,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "arys"_a,
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
 | 
			
		||||
          "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Convert all arrays to have at least two dimensions.
 | 
			
		||||
 | 
			
		||||
@@ -4262,7 +4262,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      "arys"_a,
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
 | 
			
		||||
          "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Convert all arrays to have at least three dimensions.
 | 
			
		||||
 | 
			
		||||
@@ -4511,7 +4511,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Perform the Walsh-Hadamard transform along the final axis.
 | 
			
		||||
 | 
			
		||||
@@ -4575,7 +4575,7 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
 | 
			
		||||
      Perform the Einstein summation convention on the operands.
 | 
			
		||||
 
 | 
			
		||||
@@ -93,7 +93,7 @@ void init_random(nb::module_& parent_module) {
 | 
			
		||||
      "num"_a = 2,
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
 | 
			
		||||
          "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Split a PRNG key into sub keys.
 | 
			
		||||
 | 
			
		||||
@@ -321,7 +321,7 @@ void init_random(nb::module_& parent_module) {
 | 
			
		||||
      "key"_a = nb::none(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Generate values from a truncated normal distribution.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@
 | 
			
		||||
 | 
			
		||||
#include <nanobind/nanobind.h>
 | 
			
		||||
#include <nanobind/stl/optional.h>
 | 
			
		||||
#include <nanobind/stl/string.h>
 | 
			
		||||
#include <nanobind/stl/variant.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/stream.h"
 | 
			
		||||
@@ -56,8 +57,8 @@ void init_stream(nb::module_& m) {
 | 
			
		||||
            os << s;
 | 
			
		||||
            return os.str();
 | 
			
		||||
          })
 | 
			
		||||
      .def("__eq__", [](const Stream& s1, const Stream& s2) {
 | 
			
		||||
        return s1 == s2;
 | 
			
		||||
      .def("__eq__", [](const Stream& s, const nb::object& other) {
 | 
			
		||||
        return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
  nb::implicitly_convertible<Device::DeviceType, Device>();
 | 
			
		||||
 
 | 
			
		||||
@@ -178,7 +178,7 @@ auto py_value_and_grad(
 | 
			
		||||
              msg << error_msg_tag << " The return value of the function "
 | 
			
		||||
                  << "whose gradient we want to compute should be either a "
 | 
			
		||||
                  << "scalar array or a tuple with the first value being a "
 | 
			
		||||
                  << "scalar array (Union[array, Tuple[array, Any, ...]]); but "
 | 
			
		||||
                  << "scalar array (Union[array, tuple[array, Any, ...]]); but "
 | 
			
		||||
                  << type_name_str(py_value_out) << " was returned.";
 | 
			
		||||
              throw std::invalid_argument(msg.str());
 | 
			
		||||
            }
 | 
			
		||||
@@ -197,7 +197,7 @@ auto py_value_and_grad(
 | 
			
		||||
              msg << error_msg_tag << " The return value of the function "
 | 
			
		||||
                  << "whose gradient we want to compute should be either a "
 | 
			
		||||
                  << "scalar array or a tuple with the first value being a "
 | 
			
		||||
                  << "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
 | 
			
		||||
                  << "scalar array (Union[array, tuple[array, Any, ...]]); but it "
 | 
			
		||||
                  << "was a tuple with the first value being of type "
 | 
			
		||||
                  << type_name_str(ret[0]) << " .";
 | 
			
		||||
              throw std::invalid_argument(msg.str());
 | 
			
		||||
@@ -973,13 +973,13 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      .def(
 | 
			
		||||
          nb::init<nb::callable>(),
 | 
			
		||||
          "f"_a,
 | 
			
		||||
          nb::sig("def __init__(self, f: callable)"))
 | 
			
		||||
          nb::sig("def __init__(self, f: Callable)"))
 | 
			
		||||
      .def("__call__", &PyCustomFunction::call_impl)
 | 
			
		||||
      .def(
 | 
			
		||||
          "vjp",
 | 
			
		||||
          &PyCustomFunction::set_vjp,
 | 
			
		||||
          "f"_a,
 | 
			
		||||
          nb::sig("def vjp(self, f_vjp: callable)"),
 | 
			
		||||
          nb::sig("def vjp(self, f: Callable)"),
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
            Define a custom vjp for the wrapped function.
 | 
			
		||||
 | 
			
		||||
@@ -1001,7 +1001,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
          "jvp",
 | 
			
		||||
          &PyCustomFunction::set_jvp,
 | 
			
		||||
          "f"_a,
 | 
			
		||||
          nb::sig("def jvp(self, f_jvp: callable)"),
 | 
			
		||||
          nb::sig("def jvp(self, f: Callable)"),
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
            Define a custom jvp for the wrapped function.
 | 
			
		||||
 | 
			
		||||
@@ -1021,7 +1021,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
          "vmap",
 | 
			
		||||
          &PyCustomFunction::set_vmap,
 | 
			
		||||
          "f"_a,
 | 
			
		||||
          nb::sig("def vmap(self, f_vmap: callable)"),
 | 
			
		||||
          nb::sig("def vmap(self, f: Callable)"),
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
            Define a custom vectorization transformation for the wrapped function.
 | 
			
		||||
 | 
			
		||||
@@ -1116,7 +1116,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      "primals"_a,
 | 
			
		||||
      "tangents"_a,
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
 | 
			
		||||
          "def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Compute the Jacobian-vector product.
 | 
			
		||||
 | 
			
		||||
@@ -1124,7 +1124,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
        at ``primals`` with the ``tangents``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fun (callable): A function which takes a variable number of :class:`array`
 | 
			
		||||
            fun (Callable): A function which takes a variable number of :class:`array`
 | 
			
		||||
              and returns a single :class:`array` or list of :class:`array`.
 | 
			
		||||
            primals (list(array)): A list of :class:`array` at which to
 | 
			
		||||
              evaluate the Jacobian.
 | 
			
		||||
@@ -1155,7 +1155,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      "primals"_a,
 | 
			
		||||
      "cotangents"_a,
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
 | 
			
		||||
          "def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Compute the vector-Jacobian product.
 | 
			
		||||
 | 
			
		||||
@@ -1163,7 +1163,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
        function ``fun`` evaluated at ``primals``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          fun (callable): A function which takes a variable number of :class:`array`
 | 
			
		||||
          fun (Callable): A function which takes a variable number of :class:`array`
 | 
			
		||||
            and returns a single :class:`array` or list of :class:`array`.
 | 
			
		||||
          primals (list(array)): A list of :class:`array` at which to
 | 
			
		||||
            evaluate the Jacobian.
 | 
			
		||||
@@ -1189,7 +1189,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      "argnums"_a = nb::none(),
 | 
			
		||||
      "argnames"_a = std::vector<std::string>{},
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
 | 
			
		||||
          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Returns a function which computes the value and gradient of ``fun``.
 | 
			
		||||
 | 
			
		||||
@@ -1221,7 +1221,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
            (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fun (callable): A function which takes a variable number of
 | 
			
		||||
            fun (Callable): A function which takes a variable number of
 | 
			
		||||
              :class:`array` or trees of :class:`array` and returns
 | 
			
		||||
              a scalar output :class:`array` or a tuple the first element
 | 
			
		||||
              of which should be a scalar :class:`array`.
 | 
			
		||||
@@ -1235,7 +1235,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
              no gradients for keyword arguments by default.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            callable: A function which returns a tuple where the first element
 | 
			
		||||
            Callable: A function which returns a tuple where the first element
 | 
			
		||||
            is the output of `fun` and the second element is the gradients w.r.t.
 | 
			
		||||
            the loss.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
@@ -1257,12 +1257,12 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      "argnums"_a = nb::none(),
 | 
			
		||||
      "argnames"_a = std::vector<std::string>{},
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
 | 
			
		||||
          "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Returns a function which computes the gradient of ``fun``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fun (callable): A function which takes a variable number of
 | 
			
		||||
            fun (Callable): A function which takes a variable number of
 | 
			
		||||
              :class:`array` or trees of :class:`array` and returns
 | 
			
		||||
              a scalar output :class:`array`.
 | 
			
		||||
            argnums (int or list(int), optional): Specify the index (or indices)
 | 
			
		||||
@@ -1275,7 +1275,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
              no gradients for keyword arguments by default.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            callable: A function which has the same input arguments as ``fun`` and
 | 
			
		||||
            Callable: A function which has the same input arguments as ``fun`` and
 | 
			
		||||
            returns the gradient(s).
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
@@ -1289,12 +1289,12 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      "in_axes"_a = 0,
 | 
			
		||||
      "out_axes"_a = 0,
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
 | 
			
		||||
          "def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Returns a vectorized version of ``fun``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fun (callable): A function which takes a variable number of
 | 
			
		||||
            fun (Callable): A function which takes a variable number of
 | 
			
		||||
              :class:`array` or a tree of :class:`array` and returns
 | 
			
		||||
              a variable number of :class:`array` or a tree of :class:`array`.
 | 
			
		||||
            in_axes (int, optional): An integer or a valid prefix tree of the
 | 
			
		||||
@@ -1307,7 +1307,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
              Defaults to ``0``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            callable: The vectorized function.
 | 
			
		||||
            Callable: The vectorized function.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "export_to_dot",
 | 
			
		||||
@@ -1367,11 +1367,13 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
      "inputs"_a = nb::none(),
 | 
			
		||||
      "outputs"_a = nb::none(),
 | 
			
		||||
      "shapeless"_a = false,
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Returns a compiled function which produces the same output as ``fun``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fun (callable): A function which takes a variable number of
 | 
			
		||||
            fun (Callable): A function which takes a variable number of
 | 
			
		||||
              :class:`array` or trees of :class:`array` and returns
 | 
			
		||||
              a variable number of :class:`array` or trees of :class:`array`.
 | 
			
		||||
            inputs (list or dict, optional): These inputs will be captured during
 | 
			
		||||
@@ -1392,7 +1394,7 @@ void init_transforms(nb::module_& m) {
 | 
			
		||||
              ``shapeless`` set to ``True``. Default: ``False``
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            callable: A compiled function which has the same input arguments
 | 
			
		||||
            Callable: A compiled function which has the same input arguments
 | 
			
		||||
            as ``fun`` and returns the same output(s).
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user