post nanobind docs fixes and some updates (#889)

* post nanobind docs fixes and some updates
* one more doc nit
* fix for stubs and latex

This commit is contained in:
parent be98f4ab6b
commit 1e16331d9c
.circleci/config.yml

@@ -44,7 +44,7 @@ jobs:
           name: Generate package stubs
           command: |
             echo "stubs"
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
@@ -94,7 +94,7 @@ jobs:
           name: Generate package stubs
           command: |
             source env/bin/activate
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
@@ -159,7 +159,7 @@ jobs:
           name: Generate package stubs
           command: |
             source env/bin/activate
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
       - run:
           name: Build Python package
           command: |
@@ -216,7 +216,7 @@ jobs:
             << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
             pip install . -v
-            python -m nanobind.stubgen -m mlx.core -r -O python
+            python setup.py generate_stubs
             << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
             python -m build --wheel
docs/src/conf.py

@@ -36,9 +36,10 @@ intersphinx_mapping = {
 templates_path = ["_templates"]
 html_static_path = ["_static"]
 source_suffix = ".rst"
-master_doc = "index"
+main_doc = "index"
 highlight_language = "python"
 pygments_style = "sphinx"
 add_module_names = False

 # -- Options for HTML output -------------------------------------------------

@@ -62,11 +63,19 @@ htmlhelp_basename = "mlx_doc"


 def setup(app):
     wrapped = app.registry.documenters["function"].can_document_member
+    from sphinx.util import inspect

     def nanobind_function_patch(member: Any, *args, **kwargs) -> bool:
         return "nanobind.nb_func" in str(type(member)) or wrapped(
             member, *args, **kwargs
         )

+    wrapped_isfunc = inspect.isfunction
     app.registry.documenters["function"].can_document_member = nanobind_function_patch

+    def isfunc(obj):
+        type_name = str(type(obj))
+        if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
+            return True
+        return wrapped_isfunc(obj)
+
+    inspect.isfunction = isfunc


 # -- Options for LaTeX output ------------------------------------------------

 latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
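The patch works because nanobind-bound callables are not Python functions, so Sphinx's autodoc type checks miss them; the fix keys on the type name instead. A quick way to see this from a REPL (a sketch; the exact repr can vary by nanobind version):

    import mlx.core as mx

    print(type(mx.add))                              # a nanobind wrapper type, not <class 'function'>
    print("nanobind.nb_func" in str(type(mx.add)))   # True, which is what the patch matches on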
docs/src/index.rst

@@ -62,6 +62,7 @@ are the CPU and GPU.
    python/ops
    python/random
    python/transforms
+   python/fast
    python/fft
    python/linalg
    python/metal
docs/src/python/array.rst

@@ -10,9 +10,12 @@ Array
    array
    array.astype
+   array.at
    array.item
    array.tolist
    array.dtype
+   array.itemsize
+   array.nbytes
    array.ndim
    array.shape
    array.size
@@ -23,14 +26,24 @@ Array
    array.argmax
    array.argmin
    array.cos
+   array.cummax
+   array.cummin
+   array.cumprod
+   array.cumsum
+   array.diag
+   array.diagonal
    array.dtype
    array.exp
+   array.flatten
    array.log
+   array.log10
    array.log1p
+   array.log2
    array.logsumexp
    array.max
    array.mean
    array.min
+   array.moveaxis
    array.prod
    array.reciprocal
    array.reshape
@@ -40,6 +53,8 @@ Array
    array.split
    array.sqrt
    array.square
    array.squeeze
+   array.swapaxes
    array.sum
    array.transpose
+   array.T
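The newly documented methods mirror the free functions of the same name. A small sketch of how they are called (outputs omitted):

    import mlx.core as mx

    a = mx.array([[1, 3], [2, 4]])
    a.cummax(axis=1)     # running maximum along each row
    a.diagonal()         # same as mx.diagonal(a)
    a.swapaxes(0, 1)     # same as a.T for a 2-D array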
docs/src/python/data_types.rst

@@ -44,9 +44,15 @@ The default floating point type is ``float32`` and the default integer type is
    * - ``int64``
      - 8
      - 64-bit signed integer
+   * - ``bfloat16``
+     - 2
+     - 16-bit brain float (e8, m7)
    * - ``float16``
      - 2
-     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
+     - 16-bit IEEE float (e5, m10)
    * - ``float32``
      - 4
      - 32-bit float
+   * - ``complex64``
+     - 8
+     - 64-bit complex float
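The e/m notation gives exponent and mantissa bits: ``bfloat16`` keeps ``float32``'s exponent range but has only 7 mantissa bits, while ``float16`` trades range for 10 mantissa bits. A small sketch of the precision difference (assumed REPL session):

    import mlx.core as mx

    x = 1.0 + 2**-10                       # needs 10 mantissa bits
    mx.array(x, dtype=mx.float16).item()   # 1.0009765625, represented exactly
    mx.array(x, dtype=mx.bfloat16).item()  # rounds to 1.0 with only 7 mantissa bits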
docs/src/python/fast.rst (new file, 14 lines)
@@ -0,0 +1,14 @@
+.. _fast:
+
+Fast
+====
+
+.. currentmodule:: mlx.core.fast
+
+.. autosummary::
+   :toctree: _autosummary
+
+   rms_norm
+   layer_norm
+   rope
+   scaled_dot_product_attention
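These are fused implementations of operations that also exist as ``mlx.nn`` layers. A minimal sketch of calling one directly (shapes chosen arbitrarily):

    import mlx.core as mx

    x = mx.random.normal((2, 8, 16))      # (batch, sequence, feature)
    w = mx.ones((16,))                    # per-feature scale
    y = mx.fast.rms_norm(x, w, eps=1e-5)  # fused RMS normalization over the last axis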
docs/src/python/ops.rst

@@ -38,6 +38,10 @@ Operations
    conv_general
    cos
    cosh
+   cummax
+   cummin
+   cumprod
+   cumsum
    dequantize
    diag
    diagonal
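The cumulative reductions scan along an axis instead of collapsing it. A quick sketch:

    import mlx.core as mx

    a = mx.array([1, 3, 2, 5, 4])
    mx.cumsum(a)   # array([1, 4, 6, 11, 15], dtype=int32)
    mx.cummax(a)   # array([1, 3, 3, 5, 5], dtype=int32)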
python/mlx/nn/init.py

@@ -156,7 +156,7 @@ def glorot_normal(
     (``fan_out``) units according to:

     .. math::
-        \sigma = \gamma \sqrt{\frac{2.0}{\text{fan_in} + \text{fan_out}}}
+        \sigma = \gamma \sqrt{\frac{2.0}{\text{fan\_in} + \text{fan\_out}}}

     For more details see the original reference: `Understanding the difficulty
     of training deep feedforward neural networks
@@ -199,7 +199,7 @@ def glorot_uniform(
     units according to:

     .. math::
-        \sigma = \gamma \sqrt{\frac{6.0}{\text{fan_in} + \text{fan_out}}}
+        \sigma = \gamma \sqrt{\frac{6.0}{\text{fan\_in} + \text{fan\_out}}}

     For more details see the original reference: `Understanding the difficulty
     of training deep feedforward neural networks
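(The change only escapes the underscores so LaTeX renders ``fan_in``/``fan_out`` literally; the initializers behave as before.) A minimal usage sketch, with the fans inferred from the array's shape:

    import mlx.core as mx
    import mlx.nn as nn

    init_fn = nn.init.glorot_uniform()
    w = init_fn(mx.zeros((32, 64)))  # fan_in and fan_out are taken from the shape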
python/mlx/nn/layers/pooling.py

@@ -166,7 +166,7 @@ class MaxPool1d(_Pool1d):
            \text{input}(N_i, \text{stride} \times t + m, C_j),

        where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
-       \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.
+       \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.

        Args:
            kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -205,7 +205,7 @@ class AvgPool1d(_Pool1d):
            \text{input}(N_i, \text{stride} \times t + m, C_j),

        where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
-       \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.
+       \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.

        Args:
            kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -246,8 +246,8 @@ class MaxPool2d(_Pool2d):
            \text{stride[1]} \times w + n, C_j),
            \end{aligned}

-       where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
-       :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
+       where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
+       :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

        The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

@@ -295,8 +295,8 @@ class AvgPool2d(_Pool2d):
            \text{stride[1]} \times w + n, C_j),
            \end{aligned}

-       where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
-       :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
+       where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
+       :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

        The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
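A worked instance of the output-size formula (a sketch; MLX pooling layers take channels-last ``(N, L, C)`` input):

    import mlx.core as mx
    import mlx.nn as nn

    # L_out = floor((L + 2*padding - kernel_size) / stride) + 1
    #       = floor((32 + 2*1 - 3) / 2) + 1 = 16
    pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
    y = pool(mx.zeros((4, 32, 8)))
    print(y.shape)   # (4, 16, 8)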
python/mlx/nn/layers/recurrent.py

@@ -103,12 +103,12 @@ class GRU(Module):

    .. math::

-       \begin{align*}
+       \begin{aligned}
        r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
        z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
        n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
        h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
-       \end{align*}
+       \end{aligned}

    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
    whether the input is batched or not. Returns the hidden state at each
@@ -206,14 +206,14 @@ class LSTM(Module):
    Concretely, for each element of the sequence, this layer computes:

    .. math::
-       \begin{align*}
+       \begin{aligned}
        i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
        f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
        g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
        o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
        c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
        h_{t + 1} &= o_t \text{tanh}(c_{t + 1})
-       \end{align*}
+       \end{aligned}

    The hidden state :math:`h` and cell state :math:`c` have shape ``NH``
    or ``H``, depending on whether the input is batched or not.
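(``align*`` is a top-level display environment and fails inside an already-display ``.. math::`` block; ``aligned`` is the nestable variant, hence the swap.) A minimal usage sketch of one of the layers whose docstring changed:

    import mlx.core as mx
    import mlx.nn as nn

    rnn = nn.GRU(input_size=16, hidden_size=32)
    x = mx.random.normal((8, 10, 16))   # (batch, time, feature)
    h = rnn(x)                          # hidden state at every step: (8, 10, 32)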
python/mlx/nn/losses.py

@@ -343,10 +343,9 @@ def smooth_l1_loss(

    .. math::

-       l =
-       \begin{cases}
-         0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
-         |x - y| - 0.5 \beta, & & \text{otherwise}
+       l = \begin{cases}
+         0.5 (x - y)^2, & \text{if } (x - y) < \beta \\
+         |x - y| - 0.5 \beta, & \text{otherwise}
        \end{cases}

    Args:
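The fix drops the stray second ``&`` column that broke the ``cases`` layout. Numerically (a sketch, with the default ``beta = 1.0``):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([0.5, 2.0])
    y = mx.array([0.0, 0.0])
    nn.losses.smooth_l1_loss(x, y, reduction="none")
    # |x - y| <  beta: 0.5 * 0.5**2      = 0.125
    # |x - y| >= beta: |2.0 - 0.0| - 0.5 = 1.5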
python/src/array.cpp

@@ -569,11 +569,10 @@ void init_array(nb::module_& m) {

        .. note::

-          Python in place updates for all array frameworks map to
-          assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
-          x[idx] + y``. As a result, assigning to the same index ignores
-          all but one updates. Using ``x.at[idx].add(y)`` will correctly
-          apply all the updates to all indices.
+          Regular in-place updates map to assignment. For instance ``x[idx] += y``
+          maps to ``x[idx] = x[idx] + y``. As a result, assigning to the
+          same index ignores all but one update. Using ``x.at[idx].add(y)``
+          will correctly apply all updates to all indices.

        .. list-table::
           :header-rows: 1
@@ -591,7 +590,18 @@ void init_array(nb::module_& m) {
           * - ``x = x.at[idx].maximum(y)``
             - ``x[idx] = mx.maximum(x[idx], y)``
           * - ``x = x.at[idx].minimum(y)``
             - ``x[idx] = mx.minimum(x[idx], y)``
+
+       Example:
+           >>> a = mx.array([0, 0])
+           >>> idx = mx.array([0, 1, 0, 1])
+           >>> a[idx] += 1
+           >>> a
+           array([1, 1], dtype=int32)
+           >>>
+           >>> a = mx.array([0, 0])
+           >>> a.at[idx].add(1)
+           array([2, 2], dtype=int32)
      )pbdoc")
      .def(
          "__len__",
python/src/fast.cpp

@@ -118,12 +118,13 @@ void init_fast(nb::module_& parent_module) {
        A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.

        Supports:
-       * [Multi-Head Attention](https://arxiv.org/abs/1706.03762)
-       * [Grouped Query Attention](https://arxiv.org/abs/2305.13245)
-       * [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
+
+       * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
+       * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
+       * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_

        Note: The softmax operation is performed in ``float32`` regardless of
-       input precision.
+       the input precision.

        Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
        and ``v`` inputs should not be pre-tiled to match ``q``.
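A sketch of a grouped-query call, where ``k``/``v`` carry fewer heads than ``q`` and are left untiled as the note requires (shapes are illustrative):

    import mlx.core as mx

    B, L, D = 1, 128, 64
    q = mx.random.normal((B, 8, L, D))   # 8 query heads
    k = mx.random.normal((B, 2, L, D))   # 2 shared key/value heads, not pre-tiled
    v = mx.random.normal((B, 2, L, D))
    o = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)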
python/src/ops.cpp

@@ -47,7 +47,7 @@ void init_ops(nb::module_& m) {
      "shape"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
-     nb::sig("def reshape(a: array, /, shape: List[int], *, stream: "
+     nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: "
              "Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Reshape an array while preserving the size.
@@ -115,8 +115,9 @@ void init_ops(nb::module_& m) {
      "axis"_a = nb::none(),
      nb::kw_only(),
      "stream"_a = nb::none(),
-     nb::sig("def squeeze(a: array, /, axis: Union[None, int, List[int]] = "
-             "None, *, stream: Union[None, Stream, Device] = None) -> array"),
+     nb::sig(
+         "def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = "
+         "None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Remove length one axes from an array.

@@ -143,7 +144,7 @@ void init_ops(nb::module_& m) {
      "axis"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
-     nb::sig("def expand_dims(a: array, /, axis: Union[int, List[int]], "
+     nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], "
              "*, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Add a size one dimension at the given axis.
@@ -1148,78 +1149,36 @@ void init_ops(nb::module_& m) {
        Returns:
            array: Bases of ``a`` raised to powers in ``b``.
      )pbdoc");
-  m.def(
-      "arange",
-      [](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
-        Dtype dtype =
-            dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
-
-        return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
-      },
-      "stop"_a,
-      "dtype"_a = nb::none(),
-      "stream"_a = nb::none());
-  m.def(
-      "arange",
-      [](Scalar start,
-         Scalar stop,
-         std::optional<Dtype> dtype_,
-         StreamOrDevice s) {
-        Dtype dtype = dtype_.has_value()
-            ? dtype_.value()
-            : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
-        return arange(
-            scalar_to_double(start), scalar_to_double(stop), dtype, s);
-      },
-      "start"_a,
-      "stop"_a,
-      "dtype"_a = nb::none(),
-      "stream"_a = nb::none());
-  m.def(
-      "arange",
-      [](Scalar stop,
-         Scalar step,
-         std::optional<Dtype> dtype_,
-         StreamOrDevice s) {
-        Dtype dtype = dtype_.has_value()
-            ? dtype_.value()
-            : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
-
-        return arange(
-            0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
-      },
-      "stop"_a,
-      "step"_a,
-      "dtype"_a = nb::none(),
-      "stream"_a = nb::none());
   m.def(
       "arange",
       [](Scalar start,
          Scalar stop,
-         Scalar step,
-         std::optional<Dtype> dtype_,
+         const std::optional<Scalar>& step,
+         const std::optional<Dtype>& dtype_,
          StreamOrDevice s) {
-        // Determine the final dtype based on input types
-        Dtype dtype = dtype_.has_value()
-            ? dtype_.value()
+        Dtype dtype = dtype_
+            ? *dtype_
             : promote_types(
                   scalar_to_dtype(start),
-                  promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
-
+                  step ? promote_types(
+                             scalar_to_dtype(stop), scalar_to_dtype(*step))
+                       : scalar_to_dtype(stop));
         return arange(
             scalar_to_double(start),
             scalar_to_double(stop),
-            scalar_to_double(step),
+            step ? scalar_to_double(*step) : 1.0,
             dtype,
             s);
       },
       "start"_a,
       "stop"_a,
-      "step"_a,
+      "step"_a = nb::none(),
+      nb::kw_only(),
       "dtype"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generates ranges of numbers.
@@ -1244,6 +1203,30 @@ void init_ops(nb::module_& m) {
        This can lead to unexpected results for example if `start + step`
        is a fractional value and the `dtype` is integral.
      )pbdoc");
+  m.def(
+      "arange",
+      [](Scalar stop,
+         const std::optional<Scalar>& step,
+         const std::optional<Dtype>& dtype_,
+         StreamOrDevice s) {
+        Dtype dtype = dtype_ ? *dtype_
+            : step
+            ? promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step))
+            : scalar_to_dtype(stop);
+        return arange(
+            0.0,
+            scalar_to_double(stop),
+            step ? scalar_to_double(*step) : 1.0,
+            dtype,
+            s);
+      },
+      "stop"_a,
+      "step"_a = nb::none(),
+      nb::kw_only(),
+      "dtype"_a = nb::none(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
   m.def(
       "linspace",
       [](Scalar start,
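The refactor collapses four ``arange`` overloads into two, with ``step`` optional and, per the ``nb::kw_only()`` placement, ``dtype`` keyword-only. The call patterns it supports (a sketch):

    import mlx.core as mx

    mx.arange(5)                      # stop only: 0, 1, 2, 3, 4
    mx.arange(2, 5)                   # start and stop
    mx.arange(0, 1, 0.25)             # explicit step promotes the dtype to float
    mx.arange(5, dtype=mx.float16)    # dtype is passed by keyword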
@@ -1367,7 +1350,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Construct an array with the given value.

@@ -1400,7 +1383,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Construct an array of zeros.

@@ -1446,7 +1429,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Construct an array of ones.

@@ -1686,7 +1669,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        An `and` reduction over the given axes.

@@ -1715,7 +1698,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        An `or` reduction over the given axes.

@@ -1942,7 +1925,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Transpose the dimensions of the array.

@@ -1968,7 +1951,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Sum reduce the array over the given axes.
@@ -1997,7 +1980,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A product reduction over the given axes.
@@ -2026,7 +2009,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A `min` reduction over the given axes.

@@ -2055,7 +2038,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A `max` reduction over the given axes.

@@ -2084,7 +2067,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A `log-sum-exp` reduction over the given axes.

@@ -2119,7 +2102,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Compute the mean(s) over the given axes.

@@ -2150,7 +2133,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Compute the variance(s) over the given axes.

@@ -2186,7 +2169,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Split an array along a given axis.

@@ -2432,7 +2415,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Broadcast an array to the given shape.

@@ -2455,7 +2438,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform the softmax along the given axis.

@@ -2677,7 +2660,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Create a view into the array with the given shape and strides.

@@ -3078,7 +3061,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        General convolution over an input with several channels

@@ -3471,7 +3454,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Compute the tensor dot product along the specified axes.

@@ -3539,7 +3522,7 @@ void init_ops(nb::module_& m) {
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Construct an array by repeating ``a`` the number of times given by ``reps``.
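These ``nb::sig`` strings only affect the rendered docs and the generated ``.pyi`` stubs, not runtime behavior; ``Sequence[int]`` documents that tuples and other sequences are accepted wherever a list is. For example (a sketch):

    import mlx.core as mx

    x = mx.zeros((2, 3))
    mx.sum(x, axis=(0, 1))        # a tuple satisfies Sequence[int]
    mx.transpose(x, axes=(1, 0))  # same for axes
    mx.reshape(x, (3, 2))         # and for shapes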
python/src/random.cpp

@@ -92,6 +92,8 @@ void init_random(nb::module_& parent_module) {
      "key"_a,
      "num"_a = 2,
      "stream"_a = nb::none(),
+     nb::sig(
+         "def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
      R"pbdoc(
        Split a PRNG key into sub keys.

@@ -125,6 +127,8 @@ void init_random(nb::module_& parent_module) {
      "dtype"_a.none() = float32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
+     nb::sig(
+         "def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Generate uniformly distributed random numbers.

@@ -159,6 +163,8 @@ void init_random(nb::module_& parent_module) {
      "scale"_a = 1.0,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
+     nb::sig(
+         "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Generate normally distributed random numbers.

@@ -190,6 +196,8 @@ void init_random(nb::module_& parent_module) {
      "dtype"_a.none() = int32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
+     nb::sig(
+         "def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Generate random integers from the given interval.

@@ -225,6 +233,8 @@ void init_random(nb::module_& parent_module) {
      "shape"_a = nb::none(),
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
+     nb::sig(
+         "def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Generate Bernoulli random values.

@@ -266,6 +276,8 @@ void init_random(nb::module_& parent_module) {
      "dtype"_a.none() = float32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
+     nb::sig(
+         "def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Generate values from a truncated normal distribution.

@@ -298,6 +310,8 @@ void init_random(nb::module_& parent_module) {
      "dtype"_a.none() = float32,
      "stream"_a = nb::none(),
      "key"_a = nb::none(),
+     nb::sig(
+         "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Sample from the standard Gumbel distribution.

@@ -338,6 +352,8 @@ void init_random(nb::module_& parent_module) {
      "num_samples"_a = nb::none(),
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
+     nb::sig(
+         "def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Sample from a categorical distribution.
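The new signatures document the keyed, splittable PRNG all of these functions share. A typical keyed workflow (a sketch):

    import mlx.core as mx

    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)               # two independent subkeys
    a = mx.random.uniform(shape=(2, 2), key=k1)
    b = mx.random.normal((2, 2), key=k2)        # reusing a key would repeat the draw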
setup.py (15 lines changed)
@@ -134,9 +134,18 @@ class GenerateStubs(Command):
         pass

     def run(self) -> None:
-        subprocess.run(
-            ["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
-        )
+        out_path = "python/mlx/core"
+        stub_cmd = [
+            "python",
+            "-m",
+            "nanobind.stubgen",
+            "-m",
+            "mlx.core",
+        ]
+        subprocess.run(stub_cmd + ["-r", "-O", out_path])
+        # Run again without recursive to specify output file name
+        subprocess.run(["rm", f"{out_path}/mlx.pyi"])
+        subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])


 # Read the content of README.md
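With this command in place, the CI jobs above only need the one entry point. The assumed developer workflow after an editable build (paths as in the diff):

    import subprocess

    # Build the extension in-place, then regenerate both stub layouts:
    # python/mlx/core/**/*.pyi (recursive pass) and python/mlx/core/__init__.pyi.
    subprocess.run(["pip", "install", "-e", "."], check=True)
    subprocess.run(["python", "setup.py", "generate_stubs"], check=True)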