diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py
index 4020ff2aa..e33a24d91 100644
--- a/python/mlx/nn/init.py
+++ b/python/mlx/nn/init.py
@@ -234,7 +234,7 @@ def glorot_uniform(
 
 def he_normal(
     dtype: mx.Dtype = mx.float32,
-) -> Callable[[mx.array, str, float], mx.array]:
+) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
     r"""Build a He normal initializer.
 
     This initializer samples from a normal distribution with a standard
@@ -292,7 +292,7 @@ def he_normal(
 
 def he_uniform(
     dtype: mx.Dtype = mx.float32,
-) -> Callable[[mx.array, str, float], mx.array]:
+) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
     r"""A He uniform (Kaiming uniform) initializer.
 
     This initializer samples from a uniform distribution with a range
diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py
index a756ace5c..80764cf68 100644
--- a/python/mlx/nn/layers/base.py
+++ b/python/mlx/nn/layers/base.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from __future__ import annotations
+
 import textwrap
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -7,42 +9,6 @@ import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
 
 
-def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
-    if is_leaf_fn(model, value_key, value):
-        return map_fn(value)
-
-    elif isinstance(value, Module):
-        return {
-            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
-            for k, v in value.items()
-            if filter_fn(value, k, v)
-        }
-
-    elif isinstance(value, dict):
-        nd = {}
-        for k, v in value.items():
-            tk = f"{value_key}.{k}"
-            nd[k] = (
-                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
-                if filter_fn(model, tk, v)
-                else {}
-            )
-        return nd
-
-    elif isinstance(value, list):
-        nl = []
-        for i, vi in enumerate(value):
-            tk = f"{value_key}.{i}"
-            nl.append(
-                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
-                if filter_fn(model, tk, vi)
-                else {}
-            )
-        return nl
-
-    raise RuntimeError("Unexpected leaf found while traversing the module")
-
-
 class Module(dict):
     """Base class for building neural networks with MLX.
 
@@ -151,7 +117,7 @@ class Module(dict):
         self,
         file_or_weights: Union[str, List[Tuple[str, mx.array]]],
         strict: bool = True,
-    ) -> "Module":
+    ) -> Module:
         """
         Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
 
@@ -266,9 +232,9 @@ class Module(dict):
 
     def filter_and_map(
         self,
-        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
+        filter_fn: Callable[[Module, str, Any], bool],
         map_fn: Optional[Callable] = None,
-        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
+        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
     ):
         """Recursively filter the contents of the module using ``filter_fn``,
         namely only select keys and values where ``filter_fn`` returns true.
@@ -323,7 +289,7 @@ class Module(dict):
 
         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
 
-    def update(self, parameters: dict) -> "Module":
+    def update(self, parameters: dict) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
 
@@ -371,8 +337,8 @@ class Module(dict):
     def apply(
         self,
         map_fn: Callable[[mx.array], mx.array],
-        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
-    ) -> "Module":
+        filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
+    ) -> Module:
         """Map all the parameters using the provided ``map_fn`` and immediately
         update the module with the mapped parameters.
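Note on the hunks above: adding `from __future__ import annotations` makes every annotation in `base.py` lazily evaluated, which is what lets the quoted `"Module"` forward references become plain `Module`, and the `Literal["fan_in", "fan_out"]` return type makes the initializers' fan mode machine-checkable. A minimal sketch of what the tighter annotation buys (assuming the initializer's documented `(array, mode, gain)` call shape):

```python
# Sketch only, not part of the patch. With Literal["fan_in", "fan_out"],
# a static checker rejects misspelled modes that the old `str` accepted.
import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.he_normal()
w = init_fn(mx.zeros((8, 4)), "fan_in")  # OK: an allowed Literal value
# init_fn(mx.zeros((8, 4)), "fan_avg")   # now flagged by mypy/pyright
print(w.shape)  # (8, 4)
```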
@@ -391,7 +357,7 @@ class Module(dict):
         self.update(self.filter_and_map(filter_fn, map_fn))
         return self
 
-    def update_modules(self, modules: dict) -> "Module":
+    def update_modules(self, modules: dict) -> Module:
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.
 
@@ -432,9 +398,7 @@ class Module(dict):
         apply(self, modules)
         return self
 
-    def apply_to_modules(
-        self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
-    ) -> "Module":
+    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
         """Apply a function to all the modules in this instance (including
         this instance).
 
@@ -489,7 +453,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ) -> "Module":
+    ) -> Module:
         """Freeze the Module's parameters or some of them. Freezing a
         parameter means not computing gradients for it.
 
@@ -544,7 +508,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ) -> "Module":
+    ) -> Module:
         """Unfreeze the Module's parameters or some of them.
 
         This function is idempotent, i.e., unfreezing a model that is not frozen is
@@ -588,7 +552,7 @@ class Module(dict):
         _unfreeze_impl("", self)
         return self
 
-    def train(self, mode: bool = True) -> "Module":
+    def train(self, mode: bool = True) -> Module:
         """Set the model in or out of training mode.
 
         Training mode only applies to certain layers. For example
@@ -608,7 +572,7 @@ class Module(dict):
         self.apply_to_modules(_set_train)
         return self
 
-    def eval(self) -> "Module":
+    def eval(self) -> Module:
         """Set the model to evaluation mode.
 
         See :func:`train`.
@@ -637,3 +601,39 @@ class Module(dict):
                 return True
 
         self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
+
+
+def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
+    if is_leaf_fn(model, value_key, value):
+        return map_fn(value)
+
+    elif isinstance(value, Module):
+        return {
+            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
+            for k, v in value.items()
+            if filter_fn(value, k, v)
+        }
+
+    elif isinstance(value, dict):
+        nd = {}
+        for k, v in value.items():
+            tk = f"{value_key}.{k}"
+            nd[k] = (
+                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, v)
+                else {}
+            )
+        return nd
+
+    elif isinstance(value, list):
+        nl = []
+        for i, vi in enumerate(value):
+            tk = f"{value_key}.{i}"
+            nl.append(
+                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, vi)
+                else {}
+            )
+        return nl
+
+    raise RuntimeError("Unexpected leaf found while traversing the module")
diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py
index 93ae4d8c2..d51feced7 100644
--- a/python/mlx/nn/layers/pooling.py
+++ b/python/mlx/nn/layers/pooling.py
@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):
 
     def __init__(
         self,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[Union[int, Tuple[int, int]]] = 0,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Union[int, Tuple[int]] = 0,
     ):
         super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
 
@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):
 
     def __init__(
         self,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[Union[int, Tuple[int, int]]] = 0,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Optional[Union[int, Tuple[int]]] = None,
+        padding: Union[int, Tuple[int]] = 0,
     ):
         super().__init__(mx.mean, 0, kernel_size, stride, padding)
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index b8d727d88..48c2ce13a 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -12,7 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
-    class_predicate: Optional[callable] = None,
+    class_predicate: Optional[Callable] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 55b5a68cc..ebf05d8ff 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Literal
+from typing import Literal, Optional
 
 import mlx.core as mx
 
@@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):
 def cross_entropy(
     logits: mx.array,
     targets: mx.array,
-    weights: mx.array = None,
+    weights: Optional[mx.array] = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
     reduction: Reduction = "none",
@@ -117,7 +117,7 @@ def cross_entropy(
 def binary_cross_entropy(
     inputs: mx.array,
     targets: mx.array,
-    weights: mx.array = None,
+    weights: Optional[mx.array] = None,
     with_logits: bool = True,
     reduction: Reduction = "mean",
 ) -> mx.array:
diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py
index f651ce92e..8c5e4d462 100644
--- a/python/mlx/nn/utils.py
+++ b/python/mlx/nn/utils.py
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from functools import wraps
-from typing import Callable
+from typing import Callable, Optional
 
 import mlx.core as mx
 
@@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):
     return wrapped_value_grad_fn
 
 
-def checkpoint(module: Module, fn: Callable = None):
+def checkpoint(module: Module, fn: Optional[Callable] = None):
     """Transform the passed callable to one that performs gradient
     checkpointing with respect to the trainable parameters of the module (and
     the callable's inputs).
 
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index a997a031d..1b37bcc26 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -4,6 +4,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
+from mlx.nn import Module
 from mlx.utils import tree_map, tree_reduce
 
@@ -17,7 +18,7 @@ class Optimizer:
         self._state = {"step": mx.array(0, mx.uint64)}
         self._schedulers = {k: v for k, v in (schedulers or {}).items()}
 
-    def update(self, model: "mlx.nn.Module", gradients: dict):
+    def update(self, model: Module, gradients: dict):
         """Apply the gradients to the parameters of the model and update the
         model with the new parameters.
 
diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index 14b23a41e..39b9ed21a 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -1,10 +1,10 @@
 # Copyright © 2023 Apple Inc.
 
 from collections import defaultdict
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 
 def tree_map(
-    fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None
+    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
 ) -> Any:
     """Applies ``fn`` to the leaves of the Python tree ``tree`` and
     returns a new collection with the results.
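Most of the Python-side hunks above fix the same class of bug: a parameter annotated `Callable` (or `mx.array`) with a default of `None`. PEP 484 no longer treats a `None` default as implicitly `Optional`, so strict type checkers reject the old signatures. A small sketch of the call shapes the corrected `tree_map` signature describes (`is_leaf` omitted or supplied):

```python
# Sketch only: `is_leaf` defaults to None (now spelled Optional[Callable]);
# when supplied, it stops recursion at custom leaf types.
from mlx.utils import tree_map

tree = {"a": [1, 2], "b": [3]}
doubled = tree_map(lambda x: 2 * x, tree)                            # per int
lengths = tree_map(len, tree, is_leaf=lambda t: isinstance(t, list))
print(doubled)  # {'a': [2, 4], 'b': [6]}
print(lengths)  # {'a': 2, 'b': 1}
```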
@@ -59,8 +59,8 @@ def tree_map
 def tree_map_with_path(
     fn: Callable,
     tree: Any,
-    *rest: Tuple[Any],
-    is_leaf: Callable = None,
+    *rest: Any,
+    is_leaf: Optional[Callable] = None,
     path: Any = None,
 ) -> Any:
     """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 9213a9988..caf7ce6df 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 
 #include "mlx/backend/metal/metal.h"
 #include "python/src/buffer.h"
@@ -113,6 +114,7 @@ void init_array(nb::module_& m) {
       .def("__hash__", [](const Dtype& t) { return static_cast(t.val); });
 
+  m.attr("bool_") = nb::cast(bool_);
   m.attr("uint8") = nb::cast(uint8);
   m.attr("uint16") = nb::cast(uint16);
@@ -177,7 +179,7 @@ void init_array(nb::module_& m) {
       .export_values();
 
   nb::class_(
       m,
-      "_ArrayAt",
+      "ArrayAt",
       R"pbdoc(
       A helper object to apply updates at specific indices.
       )pbdoc")
@@ -195,7 +197,7 @@ void init_array(nb::module_& m) {
 
   nb::class_(
       m,
-      "_ArrayIterator",
+      "ArrayIterator",
       R"pbdoc(
       A helper object to iterate over the 1st dimension of an array.
       )pbdoc")
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index f389cc2c5..b7b891933 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -229,34 +229,35 @@ void init_fast(nb::module_& parent_module) {
         Returns:
           Callable ``metal_kernel``.
 
-        .. code-block:: python
+        Example:
 
-          def exp_elementwise(a: mx.array):
-              source = """
-                  uint elem = thread_position_in_grid.x;
-                  T tmp = inp[elem];
-                  out[elem] = metal::exp(tmp);
-              """
+          .. code-block:: python
 
-              kernel = mx.fast.metal_kernel(
-                  name="myexp",
-                  source=source
-              )
-              outputs = kernel(
-                  inputs={"inp": a},
-                  template={"T": mx.float32},
-                  grid=(a.size, 1, 1),
-                  threadgroup=(256, 1, 1),
-                  output_shapes={"out": a.shape},
-                  output_dtypes={"out": a.dtype},
-                  verbose=True,
-              )
-              return outputs["out"]
+            def exp_elementwise(a: mx.array):
+                source = '''
+                    uint elem = thread_position_in_grid.x;
+                    T tmp = inp[elem];
+                    out[elem] = metal::exp(tmp);
+                '''
 
-          a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
-          b = exp_elementwise(a)
-          assert mx.allclose(b, mx.exp(a))
+                kernel = mx.fast.metal_kernel(
+                    name="myexp",
+                    source=source
+                )
+                outputs = kernel(
+                    inputs={"inp": a},
+                    template={"T": mx.float32},
+                    grid=(a.size, 1, 1),
+                    threadgroup=(256, 1, 1),
+                    output_shapes={"out": a.shape},
+                    output_dtypes={"out": a.dtype},
+                    verbose=True,
+                )
+                return outputs["out"]
+
+            a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
+            b = exp_elementwise(a)
+            assert mx.allclose(b, mx.exp(a))
       )pbdoc")
       .def(
           "__call__",
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 235e4f828..c175ebbfa 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Matrix or vector norm.
 
@@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {
           a (array): Input array.  If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
            unless ``ord`` is ``None``.
            If both ``axis`` and ``ord`` are ``None``, the 2-norm of
            ``a.flatten`` will be returned.
-          ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
+          ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
            If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
            along the given ``axis``. Default: ``None``.
          axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
@@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
+          "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array]"),
       R"pbdoc(
         The QR factorization of the input matrix.
 
@@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
+          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         The Singular Value Decomposition (SVD) of the input matrix.
 
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 3e4aa1093..7b3dace34 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -1360,7 +1360,7 @@ void init_ops(nb::module_& m) {
       "dtype"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
+          "def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
   m.def(
       "linspace",
       [](Scalar start,
@@ -2695,7 +2695,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Concatenate the arrays along the given axis.
 
@@ -2723,7 +2723,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         See :func:`concatenate`.
       )pbdoc");
@@ -2743,7 +2743,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Stacks the arrays along a new axis.
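The `qr` and `svd` edits above also repair invalid syntax: `(array, array)` is not a valid return annotation, and `tuple(array, array)` would be a call, so both sigs now use the subscripted `tuple[...]` form that the `quantize` sig below already uses. A quick sketch of the call sites those annotations describe (shapes illustrative; the explicit CPU stream reflects that these factorizations run on the CPU backend):

```python
# Sketch only: unpacking matches the corrected signatures,
# tuple[array, array] for qr and tuple[array, array, array] for svd.
import mlx.core as mx

a = mx.random.normal(shape=(4, 4))
q, r = mx.linalg.qr(a, stream=mx.cpu)
u, s, vt = mx.linalg.svd(a, stream=mx.cpu)
print(q.shape, r.shape, s.shape)  # (4, 4) (4, 4) (4,)
```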
@@ -2770,7 +2770,7 @@ void init_ops(nb::module_& m) {
       "indexing"_a = "xy",
       "stream"_a = nb::none(),
       nb::sig(
-          "def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
+          "def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate multidimensional coordinate grids from 1-D coordinate arrays
 
@@ -2889,7 +2889,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Pad an array with a constant value
 
@@ -3291,7 +3291,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         2D convolution over an input with several channels
 
@@ -3361,7 +3361,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         3D convolution over an input with several channels
 
@@ -3460,7 +3460,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         General convolution over an input with several channels
 
@@ -3560,7 +3560,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]"),
+          "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
       R"pbdoc(
         Load array(s) from a binary file.
 
@@ -3594,7 +3594,7 @@ void init_ops(nb::module_& m) {
       "arrays"_a,
       "metadata"_a = nb::none(),
       nb::sig(
-          "def save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)"),
+          "def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
       R"pbdoc(
         Save array(s) to a binary file in ``.safetensors`` format.
 
@@ -3615,7 +3615,7 @@ void init_ops(nb::module_& m) {
       "arrays"_a,
       "metadata"_a = nb::none(),
       nb::sig(
-          "def save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])"),
+          "def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
       R"pbdoc(
         Save array(s) to a binary file in ``.gguf`` format.
 
@@ -3769,7 +3769,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
 
@@ -3924,7 +3924,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Compute the tensor dot product along the specified axes.
 
@@ -4046,7 +4046,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Matrix multiplication with block masking.
 
@@ -4189,7 +4189,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Return the sum along a specified diagonal in the given array.
 
@@ -4218,7 +4218,7 @@ void init_ops(nb::module_& m) {
       "arys"_a,
       "stream"_a = nb::none(),
       nb::sig(
-          "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
+          "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
       R"pbdoc(
         Convert all arrays to have at least one dimension.
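The `block_masked_mm` fix above is more than cosmetic: the old sig string rendered the three masks as required positional arguments, while in the op they default to `None` and any subset may be passed. A sketch under that reading (shapes illustrative; one 64x64 block per operand, so each mask is 1x1):

```python
# Sketch only: the mask arguments are Optional and default to None.
import mlx.core as mx

a = mx.random.normal(shape=(64, 64))
b = mx.random.normal(shape=(64, 64))
out_mask = mx.array([[True]])  # ceil(M/block) x ceil(N/block) == 1 x 1
c = mx.block_masked_mm(a, b, block_size=64, mask_out=out_mask)
print(c.shape)  # (64, 64)
```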
@@ -4240,7 +4240,7 @@ void init_ops(nb::module_& m) {
       "arys"_a,
       "stream"_a = nb::none(),
       nb::sig(
-          "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
+          "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
       R"pbdoc(
         Convert all arrays to have at least two dimensions.
 
@@ -4262,7 +4262,7 @@ void init_ops(nb::module_& m) {
       "arys"_a,
       "stream"_a = nb::none(),
       nb::sig(
-          "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
+          "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
       R"pbdoc(
         Convert all arrays to have at least three dimensions.
 
@@ -4511,7 +4511,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the Walsh-Hadamard transform along the final axis.
 
@@ -4575,7 +4575,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the Einstein summation convention on the operands.
 
diff --git a/python/src/random.cpp b/python/src/random.cpp
index 21e242524..13055b1fc 100644
--- a/python/src/random.cpp
+++ b/python/src/random.cpp
@@ -93,7 +93,7 @@ void init_random(nb::module_& parent_module) {
       "num"_a = 2,
       "stream"_a = nb::none(),
       nb::sig(
-          "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
+          "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Split a PRNG key into sub keys.
 
@@ -321,7 +321,7 @@ void init_random(nb::module_& parent_module) {
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate values from a truncated normal distribution.
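The `random.split` fix above removes a stray closing parenthesis from the sig string, and `truncated_normal` gains a proper `Optional[Dtype]` annotation in place of the malformed `dtype: float32`. A sketch of the calls those signatures describe:

```python
# Sketch only: split one PRNG key into two sub-keys (num defaults to 2),
# then sample with an explicit dtype and key.
import mlx.core as mx

key = mx.random.key(0)
k1, k2 = mx.random.split(key)
x = mx.random.truncated_normal(-1.0, 1.0, shape=(3,), dtype=mx.float32, key=k1)
print(x.shape)  # (3,)
```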
diff --git a/python/src/stream.cpp b/python/src/stream.cpp
index 95ddd20ed..81260abb7 100644
--- a/python/src/stream.cpp
+++ b/python/src/stream.cpp
@@ -4,6 +4,7 @@
 #include
 #include
+#include
 #include
 
 #include "mlx/stream.h"
@@ -56,8 +57,8 @@ void init_stream(nb::module_& m) {
             os << s;
             return os.str();
           })
-      .def("__eq__", [](const Stream& s1, const Stream& s2) {
-        return s1 == s2;
+      .def("__eq__", [](const Stream& s, const nb::object& other) {
+        return nb::isinstance(other) && s == nb::cast(other);
       });
 
   nb::implicitly_convertible();
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index e793ed598..32c5b94b8 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -178,7 +178,7 @@ auto py_value_and_grad(
       msg << error_msg_tag << " The return value of the function "
           << "whose gradient we want to compute should be either a "
           << "scalar array or a tuple with the first value being a "
-          << "scalar array (Union[array, Tuple[array, Any, ...]]); but "
+          << "scalar array (Union[array, tuple[array, Any, ...]]); but "
          << type_name_str(py_value_out) << " was returned.";
      throw std::invalid_argument(msg.str());
    }
@@ -197,7 +197,7 @@ auto py_value_and_grad(
       msg << error_msg_tag << " The return value of the function "
           << "whose gradient we want to compute should be either a "
           << "scalar array or a tuple with the first value being a "
-          << "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
+          << "scalar array (Union[array, tuple[array, Any, ...]]); but it "
           << "was a tuple with the first value being of type "
           << type_name_str(ret[0]) << " .";
       throw std::invalid_argument(msg.str());
@@ -973,13 +973,13 @@ void init_transforms(nb::module_& m) {
       .def(
           nb::init(),
           "f"_a,
-          nb::sig("def __init__(self, f: callable)"))
+          nb::sig("def __init__(self, f: Callable)"))
       .def("__call__", &PyCustomFunction::call_impl)
       .def(
           "vjp",
           &PyCustomFunction::set_vjp,
           "f"_a,
-          nb::sig("def vjp(self, f_vjp: callable)"),
+          nb::sig("def vjp(self, f: Callable)"),
          R"pbdoc(
            Define a custom vjp for the wrapped function.
 
@@ -1001,7 +1001,7 @@ void init_transforms(nb::module_& m) {
           "jvp",
           &PyCustomFunction::set_jvp,
           "f"_a,
-          nb::sig("def jvp(self, f_jvp: callable)"),
+          nb::sig("def jvp(self, f: Callable)"),
           R"pbdoc(
             Define a custom jvp for the wrapped function.
 
@@ -1021,7 +1021,7 @@ void init_transforms(nb::module_& m) {
           "vmap",
           &PyCustomFunction::set_vmap,
           "f"_a,
-          nb::sig("def vmap(self, f_vmap: callable)"),
+          nb::sig("def vmap(self, f: Callable)"),
           R"pbdoc(
             Define a custom vectorization transformation for the wrapped function.
 
@@ -1116,7 +1116,7 @@ void init_transforms(nb::module_& m) {
       "primals"_a,
       "tangents"_a,
       nb::sig(
-          "def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
+          "def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
       R"pbdoc(
         Compute the Jacobian-vector product.
 
@@ -1124,7 +1124,7 @@ void init_transforms(nb::module_& m) {
         at ``primals`` with the ``tangents``.
 
         Args:
-            fun (callable): A function which takes a variable number of :class:`array`
+            fun (Callable): A function which takes a variable number of :class:`array`
              and returns a single :class:`array` or list of :class:`array`.
            primals (list(array)): A list of :class:`array` at which to
              evaluate the Jacobian.
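The `Stream.__eq__` change above loosens the second operand to an arbitrary Python object and checks `isinstance` first, so comparing a stream against a non-stream evaluates to `False` rather than depending on nanobind's failed-overload fallback. A sketch of the intended behavior (not a test from this patch):

```python
# Sketch only: equality with a non-Stream should simply be False.
import mlx.core as mx

s = mx.default_stream(mx.cpu)
print(s == s)      # True
print(s == "cpu")  # False under the new __eq__
```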
@@ -1155,7 +1155,7 @@ void init_transforms(nb::module_& m) {
       "primals"_a,
       "cotangents"_a,
       nb::sig(
-          "def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
+          "def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
       R"pbdoc(
         Compute the vector-Jacobian product.
 
@@ -1163,7 +1163,7 @@ void init_transforms(nb::module_& m) {
         function ``fun`` evaluated at ``primals``.
 
         Args:
-            fun (callable): A function which takes a variable number of :class:`array`
+            fun (Callable): A function which takes a variable number of :class:`array`
              and returns a single :class:`array` or list of :class:`array`.
            primals (list(array)): A list of :class:`array` at which to
              evaluate the Jacobian.
@@ -1189,7 +1189,7 @@ void init_transforms(nb::module_& m) {
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector{},
       nb::sig(
-          "def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
+          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the value and gradient of ``fun``.
 
@@ -1221,7 +1221,7 @@ void init_transforms(nb::module_& m) {
             (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
 
         Args:
-            fun (callable): A function which takes a variable number of
+            fun (Callable): A function which takes a variable number of
              :class:`array` or trees of :class:`array` and returns a scalar
              output :class:`array` or a tuple the first element of which
              should be a scalar :class:`array`.
@@ -1235,7 +1235,7 @@ void init_transforms(nb::module_& m) {
            no gradients for keyword arguments by default.
 
         Returns:
-            callable: A function which returns a tuple where the first element
+            Callable: A function which returns a tuple where the first element
              is the output of `fun` and the second element is the gradients w.r.t.
              the loss.
       )pbdoc");
@@ -1257,12 +1257,12 @@ void init_transforms(nb::module_& m) {
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector{},
       nb::sig(
-          "def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
+          "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the gradient of ``fun``.
 
         Args:
-            fun (callable): A function which takes a variable number of
+            fun (Callable): A function which takes a variable number of
              :class:`array` or trees of :class:`array` and returns a scalar
              output :class:`array`.
            argnums (int or list(int), optional): Specify the index (or indices)
@@ -1275,7 +1275,7 @@ void init_transforms(nb::module_& m) {
            no gradients for keyword arguments by default.
 
         Returns:
-            callable: A function which has the same input arguments as ``fun`` and
+            Callable: A function which has the same input arguments as ``fun`` and
              returns the gradient(s).
       )pbdoc");
   m.def(
@@ -1289,12 +1289,12 @@ void init_transforms(nb::module_& m) {
       "in_axes"_a = 0,
       "out_axes"_a = 0,
       nb::sig(
-          "def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
+          "def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"),
       R"pbdoc(
         Returns a vectorized version of ``fun``.
         Args:
-            fun (callable): A function which takes a variable number of
+            fun (Callable): A function which takes a variable number of
              :class:`array` or a tree of :class:`array` and returns a variable
              number of :class:`array` or a tree of :class:`array`.
            in_axes (int, optional): An integer or a valid prefix tree of the
@@ -1307,7 +1307,7 @@ void init_transforms(nb::module_& m) {
            Defaults to ``0``.
 
         Returns:
-            callable: The vectorized function.
+            Callable: The vectorized function.
       )pbdoc");
   m.def(
       "export_to_dot",
@@ -1367,11 +1367,13 @@ void init_transforms(nb::module_& m) {
       "inputs"_a = nb::none(),
       "outputs"_a = nb::none(),
       "shapeless"_a = false,
+      nb::sig(
+          "def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
       R"pbdoc(
         Returns a compiled function which produces the same output as ``fun``.
 
         Args:
-            fun (callable): A function which takes a variable number of
+            fun (Callable): A function which takes a variable number of
              :class:`array` or trees of :class:`array` and returns a variable
              number of :class:`array` or trees of :class:`array`.
            inputs (list or dict, optional): These inputs will be captured during
@@ -1392,7 +1394,7 @@ void init_transforms(nb::module_& m) {
            ``shapeless`` set to ``True``. Default: ``False``
 
         Returns:
-            callable: A compiled function which has the same input arguments
+            Callable: A compiled function which has the same input arguments
              as ``fun`` and returns the same output(s).
       )pbdoc");
   m.def(