diff --git a/.circleci/config.yml b/.circleci/config.yml index 6f8e613f8..b2e7794c1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -44,7 +44,7 @@ jobs: name: Generate package stubs command: | echo "stubs" - python -m nanobind.stubgen -m mlx.core -r -O python + python setup.py generate_stubs - run: name: Run Python tests command: | @@ -94,7 +94,7 @@ jobs: name: Generate package stubs command: | source env/bin/activate - python -m nanobind.stubgen -m mlx.core -r -O python + python setup.py generate_stubs - run: name: Run Python tests command: | @@ -159,7 +159,7 @@ jobs: name: Generate package stubs command: | source env/bin/activate - python -m nanobind.stubgen -m mlx.core -r -O python + python setup.py generate_stubs - run: name: Build Python package command: | @@ -216,7 +216,7 @@ jobs: << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ pip install . -v - python -m nanobind.stubgen -m mlx.core -r -O python + python setup.py generate_stubs << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ python -m build --wheel diff --git a/docs/src/conf.py b/docs/src/conf.py index c85ee9510..3348c2f46 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -36,9 +36,10 @@ intersphinx_mapping = { templates_path = ["_templates"] html_static_path = ["_static"] source_suffix = ".rst" -master_doc = "index" +main_doc = "index" highlight_language = "python" pygments_style = "sphinx" +add_module_names = False # -- Options for HTML output ------------------------------------------------- @@ -62,11 +63,19 @@ htmlhelp_basename = "mlx_doc" def setup(app): - wrapped = app.registry.documenters["function"].can_document_member + from sphinx.util import inspect - def nanobind_function_patch(member: Any, *args, **kwargs) -> bool: - return "nanobind.nb_func" in str(type(member)) or wrapped( - member, *args, **kwargs - ) + wrapped_isfunc = inspect.isfunction - app.registry.documenters["function"].can_document_member = nanobind_function_patch + def isfunc(obj): + type_name = str(type(obj)) + if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name: + return True + return wrapped_isfunc(obj) + + inspect.isfunction = isfunc + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")] diff --git a/docs/src/index.rst b/docs/src/index.rst index e54a55b7a..aec2ea0b8 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -62,6 +62,7 @@ are the CPU and GPU. 
python/ops python/random python/transforms + python/fast python/fft python/linalg python/metal diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index e96e7234d..00f97c68f 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -10,9 +10,12 @@ Array array array.astype + array.at array.item array.tolist array.dtype + array.itemsize + array.nbytes array.ndim array.shape array.size @@ -23,14 +26,24 @@ Array array.argmax array.argmin array.cos + array.cummax + array.cummin + array.cumprod + array.cumsum + array.diag + array.diagonal array.dtype array.exp + array.flatten array.log + array.log10 array.log1p + array.log2 array.logsumexp array.max array.mean array.min + array.moveaxis array.prod array.reciprocal array.reshape @@ -40,6 +53,8 @@ Array array.split array.sqrt array.square + array.squeeze + array.swapaxes array.sum array.transpose array.T diff --git a/docs/src/python/data_types.rst b/docs/src/python/data_types.rst index c1b240d86..83991261e 100644 --- a/docs/src/python/data_types.rst +++ b/docs/src/python/data_types.rst @@ -44,9 +44,15 @@ The default floating point type is ``float32`` and the default integer type is * - ``int64`` - 8 - 64-bit signed integer + * - ``bfloat16`` + - 2 + - 16-bit brain float (e8, m7) * - ``float16`` - 2 - - 16-bit float, only available with `ARM C language extensions `_ + - 16-bit IEEE float (e5, m10) * - ``float32`` - 4 - 32-bit float + * - ``complex64`` + - 8 + - 64-bit complex float diff --git a/docs/src/python/fast.rst b/docs/src/python/fast.rst new file mode 100644 index 000000000..26bd62a26 --- /dev/null +++ b/docs/src/python/fast.rst @@ -0,0 +1,14 @@ +.. _fast: + +Fast +==== + +.. currentmodule:: mlx.core.fast + +.. autosummary:: + :toctree: _autosummary + + rms_norm + layer_norm + rope + scaled_dot_product_attention diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 6396bb3c6..462e92a59 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -38,6 +38,10 @@ Operations conv_general cos cosh + cummax + cummin + cumprod + cumsum dequantize diag diagonal diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index 6596ba741..4020ff2aa 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -156,7 +156,7 @@ def glorot_normal( (``fan_out``) units according to: .. math:: - \sigma = \gamma \sqrt{\frac{2.0}{\text{fan_in} + \text{fan_out}}} + \sigma = \gamma \sqrt{\frac{2.0}{\text{fan\_in} + \text{fan\_out}}} For more details see the original reference: `Understanding the difficulty of training deep feedforward neural networks @@ -199,7 +199,7 @@ def glorot_uniform( units according to: .. math:: - \sigma = \gamma \sqrt{\frac{6.0}{\text{fan_in} + \text{fan_out}}} + \sigma = \gamma \sqrt{\frac{6.0}{\text{fan\_in} + \text{fan\_out}}} For more details see the original reference: `Understanding the difficulty of training deep feedforward neural networks diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 9d6267bf9..1bcd2a2dc 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -166,7 +166,7 @@ class MaxPool1d(_Pool1d): \text{input}(N_i, \text{stride} \times t + m, C_j), where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - - \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`. Args: kernel_size (int or tuple(int)): The size of the pooling window kernel. 
@@ -205,7 +205,7 @@ class AvgPool1d(_Pool1d): \text{input}(N_i, \text{stride} \times t + m, C_j), where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} - - \text{kernel_size}}{\text{stride}}\right\rfloor + 1`. + \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`. Args: kernel_size (int or tuple(int)): The size of the pooling window kernel. @@ -246,8 +246,8 @@ class MaxPool2d(_Pool2d): \text{stride[1]} \times w + n, C_j), \end{aligned} - where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, - :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. The parameters ``kernel_size``, ``stride``, ``padding``, can either be: @@ -295,8 +295,8 @@ class AvgPool2d(_Pool2d): \text{stride[1]} \times w + n, C_j), \end{aligned} - where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, - :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. + where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`. The parameters ``kernel_size``, ``stride``, ``padding``, can either be: diff --git a/python/mlx/nn/layers/recurrent.py b/python/mlx/nn/layers/recurrent.py index 6f8a590fa..d578c521c 100644 --- a/python/mlx/nn/layers/recurrent.py +++ b/python/mlx/nn/layers/recurrent.py @@ -103,12 +103,12 @@ class GRU(Module): .. math:: - \begin{align*} + \begin{aligned} r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\ z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\ n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\ h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t - \end{align*} + \end{aligned} The hidden state :math:`h` has shape ``NH`` or ``H`` depending on whether the input is batched or not. Returns the hidden state at each @@ -206,14 +206,14 @@ class LSTM(Module): Concretely, for each element of the sequence, this layer computes: .. math:: - \begin{align*} + \begin{aligned} i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\ f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\ g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\ o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\ c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\ h_{t + 1} &= o_t \text{tanh}(c_{t + 1}) - \end{align*} + \end{aligned} The hidden state :math:`h` and cell state :math:`c` have shape ``NH`` or ``H``, depending on whether the input is batched or not. diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index ee33fde3e..5bbeb1f06 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -343,10 +343,9 @@ def smooth_l1_loss( .. 
math:: - l = - \begin{cases} - 0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\ - |x - y| - 0.5 \beta, & & \text{otherwise} + l = \begin{cases} + 0.5 (x - y)^2, & \text{if } (x - y) < \beta \\ + |x - y| - 0.5 \beta, & \text{otherwise} \end{cases} Args: diff --git a/python/src/array.cpp b/python/src/array.cpp index 196561c16..e16db9568 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -569,11 +569,10 @@ void init_array(nb::module_& m) { .. note:: - Python in place updates for all array frameworks map to - assignment. For instance ``x[idx] += y`` maps to ``x[idx] = - x[idx] + y``. As a result, assigning to the same index ignores - all but one updates. Using ``x.at[idx].add(y)`` will correctly - apply all the updates to all indices. + Regular in-place updates map to assignment. For instance ``x[idx] += y`` + maps to ``x[idx] = x[idx] + y``. As a result, assigning to the + same index ignores all but one update. Using ``x.at[idx].add(y)`` + will correctly apply all updates to all indices. .. list-table:: :header-rows: 1 @@ -591,7 +590,18 @@ void init_array(nb::module_& m) { * - ``x = x.at[idx].maximum(y)`` - ``x[idx] = mx.maximum(x[idx], y)`` * - ``x = x.at[idx].minimum(y)`` - - ``x[idx] = mx.minimum(x[idx], y)`` + - ``x[idx] = mx.minimum(x[idx], y)`` + + Example: + >>> a = mx.array([0, 0]) + >>> idx = mx.array([0, 1, 0, 1]) + >>> a[idx] += 1 + >>> a + array([1, 1], dtype=int32) + >>> + >>> a = mx.array([0, 0]) + >>> a.at[idx].add(1) + array([2, 2], dtype=int32) )pbdoc") .def( "__len__", diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 306ef1bb4..d74eca6e8 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -118,12 +118,13 @@ void init_fast(nb::module_& parent_module) { A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. Supports: - * [Multi-Head Attention](https://arxiv.org/abs/1706.03762) - * [Grouped Query Attention](https://arxiv.org/abs/2305.13245) - * [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + + * `Multi-Head Attention `_ + * `Grouped Query Attention `_ + * `Multi-Query Attention `_ Note: The softmax operation is performed in ``float32`` regardless of - input precision. + the input precision. Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` and ``v`` inputs should not be pre-tiled to match ``q``. diff --git a/python/src/ops.cpp b/python/src/ops.cpp index e40ec5ca3..a68eeb9a7 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -47,7 +47,7 @@ void init_ops(nb::module_& m) { "shape"_a, nb::kw_only(), "stream"_a = nb::none(), - nb::sig("def reshape(a: array, /, shape: List[int], *, stream: " + nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: " "Union[None, Stream, Device] = None) -> array"), R"pbdoc( Reshape an array while preserving the size. @@ -115,8 +115,9 @@ void init_ops(nb::module_& m) { "axis"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(), - nb::sig("def squeeze(a: array, /, axis: Union[None, int, List[int]] = " - "None, *, stream: Union[None, Stream, Device] = None) -> array"), + nb::sig( + "def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = " + "None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Remove length one axes from an array. 
@@ -143,7 +144,7 @@ void init_ops(nb::module_& m) { "axis"_a, nb::kw_only(), "stream"_a = nb::none(), - nb::sig("def expand_dims(a: array, /, axis: Union[int, List[int]], " + nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], " "*, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Add a size one dimension at the given axis. @@ -1148,78 +1149,36 @@ void init_ops(nb::module_& m) { Returns: array: Bases of ``a`` raised to powers in ``b``. )pbdoc"); - m.def( - "arange", - [](Scalar stop, std::optional dtype_, StreamOrDevice s) { - Dtype dtype = - dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop); - - return arange(0.0, scalar_to_double(stop), 1.0, dtype, s); - }, - "stop"_a, - "dtype"_a = nb::none(), - "stream"_a = nb::none()); m.def( "arange", [](Scalar start, Scalar stop, - std::optional dtype_, - StreamOrDevice s) { - Dtype dtype = dtype_.has_value() - ? dtype_.value() - : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop)); - return arange( - scalar_to_double(start), scalar_to_double(stop), dtype, s); - }, - "start"_a, - "stop"_a, - "dtype"_a = nb::none(), - "stream"_a = nb::none()); - m.def( - "arange", - [](Scalar stop, - Scalar step, - std::optional dtype_, - StreamOrDevice s) { - Dtype dtype = dtype_.has_value() - ? dtype_.value() - : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)); - - return arange( - 0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s); - }, - "stop"_a, - "step"_a, - "dtype"_a = nb::none(), - "stream"_a = nb::none()); - m.def( - "arange", - [](Scalar start, - Scalar stop, - Scalar step, - std::optional dtype_, + const std::optional& step, + const std::optional& dtype_, StreamOrDevice s) { // Determine the final dtype based on input types - Dtype dtype = dtype_.has_value() - ? dtype_.value() + Dtype dtype = dtype_ + ? *dtype_ : promote_types( scalar_to_dtype(start), - promote_types(scalar_to_dtype(stop), scalar_to_dtype(step))); - + step ? promote_types( + scalar_to_dtype(stop), scalar_to_dtype(*step)) + : scalar_to_dtype(stop)); return arange( scalar_to_double(start), scalar_to_double(stop), - scalar_to_double(step), + step ? scalar_to_double(*step) : 1.0, dtype, s); }, "start"_a, "stop"_a, - "step"_a, + "step"_a = nb::none(), + nb::kw_only(), "dtype"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + "def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Generates ranges of numbers. @@ -1244,6 +1203,30 @@ void init_ops(nb::module_& m) { This can lead to unexpected results for example if `start + step` is a fractional value and the `dtype` is integral. )pbdoc"); + m.def( + "arange", + [](Scalar stop, + const std::optional& step, + const std::optional& dtype_, + StreamOrDevice s) { + Dtype dtype = dtype_ ? *dtype_ + : step + ? promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step)) + : scalar_to_dtype(stop); + return arange( + 0.0, + scalar_to_double(stop), + step ? 
scalar_to_double(*step) : 1.0, + dtype, + s); + }, + "stop"_a, + "step"_a = nb::none(), + nb::kw_only(), + "dtype"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array")); m.def( "linspace", [](Scalar start, @@ -1367,7 +1350,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + "def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Construct an array with the given value. @@ -1400,7 +1383,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), + "def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Construct an array of zeros. @@ -1446,7 +1429,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), + "def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Construct an array of ones. @@ -1686,7 +1669,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( An `and` reduction over the given axes. @@ -1715,7 +1698,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( An `or` reduction over the given axes. @@ -1942,7 +1925,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + "def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Transpose the dimensions of the array. @@ -1968,7 +1951,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Sum reduce the array over the given axes. 
@@ -1997,7 +1980,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( An product reduction over the given axes. @@ -2026,7 +2009,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A `min` reduction over the given axes. @@ -2055,7 +2038,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A `max` reduction over the given axes. @@ -2084,7 +2067,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A `log-sum-exp` reduction over the given axes. @@ -2119,7 +2102,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + "def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Compute the mean(s) over the given axes. @@ -2150,7 +2133,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Compute the variance(s) over the given axes. @@ -2186,7 +2169,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Split an array along a given axis. 
@@ -2432,7 +2415,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Broadcast an array to the given shape.
@@ -2455,7 +2438,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the softmax along the given axis.
@@ -2677,7 +2660,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Create a view into the array with the given shape and strides.
@@ -3078,7 +3061,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         General convolution over an input with several channels
@@ -3471,7 +3454,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Compute the tensor dot product along the specified axes.
@@ -3539,7 +3522,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Construct an array by repeating ``a`` the number of times given by ``reps``.
diff --git a/python/src/random.cpp b/python/src/random.cpp
index b3cb3aa2f..dde8469d4 100644
--- a/python/src/random.cpp
+++ b/python/src/random.cpp
@@ -92,6 +92,8 @@ void init_random(nb::module_& parent_module) {
       "key"_a,
       "num"_a = 2,
       "stream"_a = nb::none(),
+      nb::sig(
+          "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Split a PRNG key into sub keys.
@@ -125,6 +127,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate uniformly distributed random numbers.
@@ -159,6 +163,8 @@ void init_random(nb::module_& parent_module) {
       "scale"_a = 1.0,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate normally distributed random numbers.
@@ -190,6 +196,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = int32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate random integers from the given interval.
@@ -225,6 +233,8 @@ void init_random(nb::module_& parent_module) {
       "shape"_a = nb::none(),
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate Bernoulli random values.
@@ -266,6 +276,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate values from a truncated normal distribution.
@@ -298,6 +310,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = float32,
       "stream"_a = nb::none(),
       "key"_a = nb::none(),
+      nb::sig(
+          "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None, key: Optional[array] = None) -> array"),
       R"pbdoc(
         Sample from the standard Gumbel distribution.
@@ -338,6 +352,8 @@ void init_random(nb::module_& parent_module) {
       "num_samples"_a = nb::none(),
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Sample from a categorical distribution.
diff --git a/setup.py b/setup.py
index 8bbb2a2f7..136834ff8 100644
--- a/setup.py
+++ b/setup.py
@@ -134,9 +134,18 @@ class GenerateStubs(Command):
         pass

     def run(self) -> None:
-        subprocess.run(
-            ["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
-        )
+        out_path = "python/mlx/core"
+        stub_cmd = [
+            "python",
+            "-m",
+            "nanobind.stubgen",
+            "-m",
+            "mlx.core",
+        ]
+        subprocess.run(stub_cmd + ["-r", "-O", out_path])
+        # Run again without recursive to specify output file name
+        subprocess.run(["rm", f"{out_path}/mlx.pyi"])
+        subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])


 # Read the content of README.md