post nanobind docs fixes and some updates (#889)

* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
Awni Hannun 2024-03-24 15:03:27 -07:00 committed by GitHub
parent be98f4ab6b
commit 1e16331d9c
16 changed files with 185 additions and 118 deletions


@@ -44,7 +44,7 @@ jobs:
           name: Generate package stubs
           command: |
             echo "stubs"
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
@@ -94,7 +94,7 @@ jobs:
           name: Generate package stubs
           command: |
             source env/bin/activate
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
@@ -159,7 +159,7 @@ jobs:
           name: Generate package stubs
           command: |
             source env/bin/activate
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
       - run:
           name: Build Python package
           command: |
@@ -216,7 +216,7 @@ jobs:
             << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
             pip install . -v
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
             << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
             python -m build --wheel


@@ -36,9 +36,10 @@ intersphinx_mapping = {
 templates_path = ["_templates"]
 html_static_path = ["_static"]
 source_suffix = ".rst"
-master_doc = "index"
+main_doc = "index"
 highlight_language = "python"
 pygments_style = "sphinx"
+add_module_names = False

 # -- Options for HTML output -------------------------------------------------
@@ -62,11 +63,19 @@ htmlhelp_basename = "mlx_doc"

 def setup(app):
-    wrapped = app.registry.documenters["function"].can_document_member
-
-    def nanobind_function_patch(member: Any, *args, **kwargs) -> bool:
-        return "nanobind.nb_func" in str(type(member)) or wrapped(
-            member, *args, **kwargs
-        )
-
-    app.registry.documenters["function"].can_document_member = nanobind_function_patch
+    from sphinx.util import inspect
+
+    wrapped_isfunc = inspect.isfunction
+
+    def isfunc(obj):
+        type_name = str(type(obj))
+        if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
+            return True
+        return wrapped_isfunc(obj)
+
+    inspect.isfunction = isfunc
+
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
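Context for this patch: nanobind-bound callables are not Python function objects, so Sphinx's stock ``inspect.isfunction`` rejects them and autodoc would skip their docstrings. A minimal illustration, assuming an importable ``mlx`` build (the output comments are what I'd expect, not captured output):

    import inspect
    import mlx.core as mx

    print(type(mx.abs))                # <class 'nanobind.nb_func'>
    print(inspect.isfunction(mx.abs))  # False, hence the patch above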


@@ -62,6 +62,7 @@ are the CPU and GPU.
    python/ops
    python/random
    python/transforms
+   python/fast
    python/fft
    python/linalg
    python/metal


@@ -10,9 +10,12 @@ Array
    array
    array.astype
+   array.at
    array.item
    array.tolist
    array.dtype
+   array.itemsize
+   array.nbytes
    array.ndim
    array.shape
    array.size
@@ -23,14 +26,24 @@ Array
    array.argmax
    array.argmin
    array.cos
+   array.cummax
+   array.cummin
+   array.cumprod
+   array.cumsum
+   array.diag
+   array.diagonal
    array.dtype
    array.exp
+   array.flatten
    array.log
+   array.log10
    array.log1p
+   array.log2
    array.logsumexp
    array.max
    array.mean
    array.min
+   array.moveaxis
    array.prod
    array.reciprocal
    array.reshape
@@ -40,6 +53,8 @@ Array
    array.split
    array.sqrt
    array.square
+   array.squeeze
+   array.swapaxes
    array.sum
    array.transpose
    array.T
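Each new entry maps to a method or property on ``mx.array``; a quick sketch (values illustrative, output comments are expectations):

    import mlx.core as mx

    a = mx.array([3, 1, 2])
    print(a.cummax())   # array([3, 3, 3], dtype=int32)
    print(a.itemsize)   # 4 (bytes per int32 element)
    print(a.nbytes)     # 12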


@@ -44,9 +44,15 @@ The default floating point type is ``float32`` and the default integer type is
    * - ``int64``
      - 8
      - 64-bit signed integer
+   * - ``bfloat16``
+     - 2
+     - 16-bit brain float (e8, m7)
    * - ``float16``
      - 2
-     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
+     - 16-bit IEEE float (e5, m10)
    * - ``float32``
      - 4
      - 32-bit float
+   * - ``complex64``
+     - 8
+     - 64-bit complex float
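The newly documented dtypes can be requested directly when constructing arrays; a minimal sketch (values illustrative):

    import mlx.core as mx

    x = mx.array([1.0, 2.0], dtype=mx.bfloat16)  # 16-bit brain float
    y = mx.array([1 + 2j], dtype=mx.complex64)   # 64-bit complex
    print(x.dtype, x.itemsize, y.nbytes)         # bfloat16 2 8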

docs/src/python/fast.rst (new file, 14 lines)

@@ -0,0 +1,14 @@
+.. _fast:
+
+Fast
+====
+
+.. currentmodule:: mlx.core.fast
+
+.. autosummary::
+   :toctree: _autosummary
+
+   rms_norm
+   layer_norm
+   rope
+   scaled_dot_product_attention
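A sketch of how the newly documented module is used, assuming signatures matching the bindings in this commit (the ``eps`` values and shapes are illustrative):

    import mlx.core as mx

    x = mx.random.normal((2, 8, 64))
    w = mx.ones((64,))

    y = mx.fast.rms_norm(x, w, eps=1e-5)          # fused RMS norm over the last axis
    z = mx.fast.layer_norm(x, w, None, eps=1e-5)  # weight-only layer norm, no bias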


@@ -38,6 +38,10 @@ Operations
    conv_general
    cos
    cosh
+   cummax
+   cummin
+   cumprod
+   cumsum
    dequantize
    diag
    diagonal


@@ -156,7 +156,7 @@ def glorot_normal(
     (``fan_out``) units according to:

     .. math::
-        \sigma = \gamma \sqrt{\frac{2.0}{\text{fan_in} + \text{fan_out}}}
+        \sigma = \gamma \sqrt{\frac{2.0}{\text{fan\_in} + \text{fan\_out}}}

     For more details see the original reference: `Understanding the difficulty
     of training deep feedforward neural networks
@@ -199,7 +199,7 @@ def glorot_uniform(
     units according to:

     .. math::
-        \sigma = \gamma \sqrt{\frac{6.0}{\text{fan_in} + \text{fan_out}}}
+        \sigma = \gamma \sqrt{\frac{6.0}{\text{fan\_in} + \text{fan\_out}}}

     For more details see the original reference: `Understanding the difficulty
     of training deep feedforward neural networks
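For reference, the ``mlx.nn.init`` functions are factories that return an initializer applied to an existing array; a minimal sketch (shapes illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    init_fn = nn.init.glorot_uniform()   # gain defaults to 1.0
    w = init_fn(mx.zeros((64, 128)))     # fan_in/fan_out inferred from the shape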


@@ -166,7 +166,7 @@ class MaxPool1d(_Pool1d):
             \text{input}(N_i, \text{stride} \times t + m, C_j),

     where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
-    \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.
+    \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.

     Args:
         kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -205,7 +205,7 @@ class AvgPool1d(_Pool1d):
             \text{input}(N_i, \text{stride} \times t + m, C_j),

     where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
-    \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.
+    \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.

     Args:
         kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -246,8 +246,8 @@ class MaxPool2d(_Pool2d):
                 \text{stride[1]} \times w + n, C_j),
         \end{aligned}

-    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
-    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
+    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
+    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

     The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
@@ -295,8 +295,8 @@ class AvgPool2d(_Pool2d):
                 \text{stride[1]} \times w + n, C_j),
         \end{aligned}

-    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
-    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
+    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
+    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

     The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
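These pooling layers take channels-last inputs in MLX; a minimal sketch checking the output-size formula above (shapes illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    pool = nn.MaxPool2d(kernel_size=2, stride=2)  # padding defaults to 0
    x = mx.random.normal((1, 32, 32, 4))          # N, H, W, C
    print(pool(x).shape)  # (1, 16, 16, 4): floor((32 + 0 - 2)/2) + 1 = 16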


@@ -103,12 +103,12 @@ class GRU(Module):
     .. math::

-        \begin{align*}
+        \begin{aligned}
         r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
         z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
         n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
         h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
-        \end{align*}
+        \end{aligned}

     The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
     whether the input is batched or not. Returns the hidden state at each
@@ -206,14 +206,14 @@ class LSTM(Module):
     Concretely, for each element of the sequence, this layer computes:

     .. math::
-        \begin{align*}
+        \begin{aligned}
         i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
         f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
         g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
         o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
         c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
         h_{t + 1} &= o_t \text{tanh}(c_{t + 1})
-        \end{align*}
+        \end{aligned}

     The hidden state :math:`h` and cell state :math:`c` have shape ``NH``
     or ``H``, depending on whether the input is batched or not.
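A minimal usage sketch, assuming the constructor takes ``input_size`` and ``hidden_size`` and that inputs are ``NLD`` (shapes illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    lstm = nn.LSTM(input_size=32, hidden_size=64)
    x = mx.random.normal((8, 10, 32))  # N, L, D
    h, c = lstm(x)                     # hidden and cell states at each step
    print(h.shape, c.shape)            # (8, 10, 64) (8, 10, 64)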


@@ -343,10 +343,9 @@ def smooth_l1_loss(
     .. math::

-        l =
-        \begin{cases}
-          0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
-          |x - y| - 0.5 \beta, & & \text{otherwise}
+        l = \begin{cases}
+          0.5 (x - y)^2, & \text{if } (x - y) < \beta \\
+          |x - y| - 0.5 \beta, & \text{otherwise}
         \end{cases}

     Args:
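A short usage sketch of the loss whose docstring is fixed above (values illustrative; the ``reduction`` argument is passed explicitly rather than relying on its default):

    import mlx.core as mx
    import mlx.nn as nn

    pred = mx.array([1.5, 0.2, -0.3])
    target = mx.array([1.0, 0.0, 0.0])
    loss = nn.losses.smooth_l1_loss(pred, target, beta=1.0, reduction="mean")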


@@ -569,11 +569,10 @@ void init_array(nb::module_& m) {
             .. note::

-               Python in place updates for all array frameworks map to
-               assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
-               x[idx] + y``. As a result, assigning to the same index ignores
-               all but one updates. Using ``x.at[idx].add(y)`` will correctly
-               apply all the updates to all indices.
+               Regular in-place updates map to assignment. For instance ``x[idx] += y``
+               maps to ``x[idx] = x[idx] + y``. As a result, assigning to the
+               same index ignores all but one update. Using ``x.at[idx].add(y)``
+               will correctly apply all updates to all indices.

             .. list-table::
                :header-rows: 1
@@ -592,6 +591,17 @@ void init_array(nb::module_& m) {
                  - ``x[idx] = mx.maximum(x[idx], y)``
                * - ``x = x.at[idx].minimum(y)``
                  - ``x[idx] = mx.minimum(x[idx], y)``
+
+            Example:
+                >>> a = mx.array([0, 0])
+                >>> idx = mx.array([0, 1, 0, 1])
+                >>> a[idx] += 1
+                >>> a
+                array([1, 1], dtype=int32)
+                >>>
+                >>> a = mx.array([0, 0])
+                >>> a.at[idx].add(1)
+                array([2, 2], dtype=int32)
         )pbdoc")
         .def(
             "__len__",


@@ -118,12 +118,13 @@ void init_fast(nb::module_& parent_module) {
         A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.

         Supports:
-        * [Multi-Head Attention](https://arxiv.org/abs/1706.03762)
-        * [Grouped Query Attention](https://arxiv.org/abs/2305.13245)
-        * [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
+
+        * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
+        * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
+        * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_

         Note: The softmax operation is performed in ``float32`` regardless of
-        input precision.
+        the input precision.

         Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
         and ``v`` inputs should not be pre-tiled to match ``q``.
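A sketch of a grouped-query call consistent with the notes above: ``q`` carries more heads than ``k``/``v``, which are passed untiled (head counts and sizes illustrative):

    import mlx.core as mx

    B, L, D = 1, 16, 64
    q = mx.random.normal((B, 8, L, D))  # 8 query heads
    k = mx.random.normal((B, 2, L, D))  # 2 key/value heads, not pre-tiled
    v = mx.random.normal((B, 2, L, D))
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)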


@@ -47,7 +47,7 @@ void init_ops(nb::module_& m) {
       "shape"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
-      nb::sig("def reshape(a: array, /, shape: List[int], *, stream: "
+      nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: "
               "Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Reshape an array while preserving the size.
@@ -115,7 +115,8 @@ void init_ops(nb::module_& m) {
       "axis"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
-      nb::sig("def squeeze(a: array, /, axis: Union[None, int, List[int]] = "
+      nb::sig(
+          "def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = "
           "None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Remove length one axes from an array.
@@ -143,7 +144,7 @@ void init_ops(nb::module_& m) {
       "axis"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
-      nb::sig("def expand_dims(a: array, /, axis: Union[int, List[int]], "
+      nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], "
               "*, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Add a size one dimension at the given axis.
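
The ``List[int]`` to ``Sequence[int]`` signature updates reflect that the bindings accept any integer sequence, tuples included; a quick sketch:

    import mlx.core as mx

    a = mx.arange(6)
    b = mx.reshape(a, (2, 3))      # a tuple works, not just a list
    c = mx.expand_dims(a, (0, 2))  # multiple axes as a tuple -> shape (1, 6, 1)
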
@@ -1148,78 +1149,36 @@ void init_ops(nb::module_& m) {
       Returns:
           array: Bases of ``a`` raised to powers in ``b``.
       )pbdoc");
-  m.def(
-      "arange",
-      [](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
-        Dtype dtype =
-            dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
-        return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
-      },
-      "stop"_a,
-      "dtype"_a = nb::none(),
-      "stream"_a = nb::none());
   m.def(
       "arange",
       [](Scalar start,
          Scalar stop,
-         std::optional<Dtype> dtype_,
-         StreamOrDevice s) {
-        Dtype dtype = dtype_.has_value()
-            ? dtype_.value()
-            : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
-        return arange(
-            scalar_to_double(start), scalar_to_double(stop), dtype, s);
-      },
-      "start"_a,
-      "stop"_a,
-      "dtype"_a = nb::none(),
-      "stream"_a = nb::none());
-  m.def(
-      "arange",
-      [](Scalar stop,
-         Scalar step,
-         std::optional<Dtype> dtype_,
-         StreamOrDevice s) {
-        Dtype dtype = dtype_.has_value()
-            ? dtype_.value()
-            : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
-        return arange(
-            0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
-      },
-      "stop"_a,
-      "step"_a,
-      "dtype"_a = nb::none(),
-      "stream"_a = nb::none());
-  m.def(
-      "arange",
-      [](Scalar start,
-         Scalar stop,
-         Scalar step,
-         std::optional<Dtype> dtype_,
+         const std::optional<Scalar>& step,
+         const std::optional<Dtype>& dtype_,
          StreamOrDevice s) {
         // Determine the final dtype based on input types
-        Dtype dtype = dtype_.has_value()
-            ? dtype_.value()
+        Dtype dtype = dtype_
+            ? *dtype_
             : promote_types(
                   scalar_to_dtype(start),
-                  promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
+                  step ? promote_types(
+                             scalar_to_dtype(stop), scalar_to_dtype(*step))
+                       : scalar_to_dtype(stop));
         return arange(
             scalar_to_double(start),
             scalar_to_double(stop),
-            scalar_to_double(step),
+            step ? scalar_to_double(*step) : 1.0,
             dtype,
             s);
       },
       "start"_a,
       "stop"_a,
-      "step"_a,
+      "step"_a = nb::none(),
+      nb::kw_only(),
       "dtype"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generates ranges of numbers.
@@ -1244,6 +1203,30 @@ void init_ops(nb::module_& m) {
         This can lead to unexpected results for example if `start + step`
         is a fractional value and the `dtype` is integral.
       )pbdoc");
+  m.def(
+      "arange",
+      [](Scalar stop,
+         const std::optional<Scalar>& step,
+         const std::optional<Dtype>& dtype_,
+         StreamOrDevice s) {
+        Dtype dtype = dtype_ ? *dtype_
+            : step
+            ? promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step))
+            : scalar_to_dtype(stop);
+        return arange(
+            0.0,
+            scalar_to_double(stop),
+            step ? scalar_to_double(*step) : 1.0,
+            dtype,
+            s);
+      },
+      "stop"_a,
+      "step"_a = nb::none(),
+      nb::kw_only(),
+      "dtype"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
   m.def(
       "linspace",
       [](Scalar start,
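
The consolidated bindings keep NumPy-style call patterns; a sketch of the overloads after this change (note that ``dtype`` becomes keyword-only):

    import mlx.core as mx

    mx.arange(5)                             # stop only: 0, 1, 2, 3, 4
    mx.arange(1, 10, 2)                      # start, stop, step: 1, 3, 5, 7, 9
    mx.arange(0, 1, 0.25, dtype=mx.float16)  # dtype passed by keyword
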
@@ -1367,7 +1350,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Construct an array with the given value.
@@ -1400,7 +1383,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Construct an array of zeros.
@@ -1446,7 +1429,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Construct an array of ones.
@@ -1686,7 +1669,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         An `and` reduction over the given axes.
@@ -1715,7 +1698,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         An `or` reduction over the given axes.
@@ -1942,7 +1925,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Transpose the dimensions of the array.
@@ -1968,7 +1951,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Sum reduce the array over the given axes.
@@ -1997,7 +1980,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         An product reduction over the given axes.
@@ -2026,7 +2009,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A `min` reduction over the given axes.
@@ -2055,7 +2038,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A `max` reduction over the given axes.
@@ -2084,7 +2067,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A `log-sum-exp` reduction over the given axes.
@@ -2119,7 +2102,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Compute the mean(s) over the given axes.
@@ -2150,7 +2133,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Compute the variance(s) over the given axes.
@@ -2186,7 +2169,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Split an array along a given axis.
@@ -2432,7 +2415,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Broadcast an array to the given shape.
@@ -2455,7 +2438,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the softmax along the given axis.
@@ -2677,7 +2660,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Create a view into the array with the given shape and strides.
@@ -3078,7 +3061,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         General convolution over an input with several channels
@@ -3471,7 +3454,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Compute the tensor dot product along the specified axes.
@@ -3539,7 +3522,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Construct an array by repeating ``a`` the number of times given by ``reps``.


@@ -92,6 +92,8 @@ void init_random(nb::module_& parent_module) {
       "key"_a,
       "num"_a = 2,
       "stream"_a = nb::none(),
+      nb::sig(
+          "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
       R"pbdoc(
         Split a PRNG key into sub keys.
@@ -125,6 +127,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate uniformly distributed random numbers.
@@ -159,6 +163,8 @@ void init_random(nb::module_& parent_module) {
       "scale"_a = 1.0,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate normally distributed random numbers.
@@ -190,6 +196,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = int32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate random integers from the given interval.
@@ -225,6 +233,8 @@ void init_random(nb::module_& parent_module) {
       "shape"_a = nb::none(),
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate Bernoulli random values.
@@ -266,6 +276,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = float32,
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate values from a truncated normal distribution.
@@ -298,6 +310,8 @@ void init_random(nb::module_& parent_module) {
       "dtype"_a.none() = float32,
       "stream"_a = nb::none(),
       "key"_a = nb::none(),
+      nb::sig(
+          "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Sample from the standard Gumbel distribution.
@@ -338,6 +352,8 @@ void init_random(nb::module_& parent_module) {
       "num_samples"_a = nb::none(),
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
+      nb::sig(
+          "def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Sample from a categorical distribution.
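The new signatures surface the key-based API in the stubs; a minimal sketch of explicit key handling matching the ``split`` and ``uniform``/``normal`` signatures above (shapes illustrative):

    import mlx.core as mx

    key = mx.random.key(0)         # seed a PRNG key
    k1, k2 = mx.random.split(key)  # num defaults to 2
    a = mx.random.uniform(shape=(2, 2), key=k1)
    b = mx.random.normal((2, 2), key=k2)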


@@ -134,9 +134,18 @@ class GenerateStubs(Command):
         pass

     def run(self) -> None:
-        subprocess.run(
-            ["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
-        )
+        out_path = "python/mlx/core"
+        stub_cmd = [
+            "python",
+            "-m",
+            "nanobind.stubgen",
+            "-m",
+            "mlx.core",
+        ]
+        subprocess.run(stub_cmd + ["-r", "-O", out_path])
+        # Run again without recursive to specify output file name
+        subprocess.run(["rm", f"{out_path}/mlx.pyi"])
+        subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])

         # Read the content of README.md
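
With this change the CI recipes above and local workflows share one entry point: running ``python setup.py generate_stubs`` from the repository root (with ``mlx`` built and importable) regenerates the recursive stubs under ``python/mlx/core``, then re-runs ``stubgen`` non-recursively so the top-level stub is named ``__init__.pyi`` instead of ``mlx.pyi``, which lets type checkers resolve it alongside the package.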