post nanobind docs fixes and some updates (#889)

* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
This commit is contained in:
Awni Hannun
2024-03-24 15:03:27 -07:00
committed by GitHub
parent be98f4ab6b
commit 1e16331d9c
16 changed files with 185 additions and 118 deletions

View File

@@ -569,11 +569,10 @@ void init_array(nb::module_& m) {
.. note::
Python in place updates for all array frameworks map to
assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
x[idx] + y``. As a result, assigning to the same index ignores
all but one updates. Using ``x.at[idx].add(y)`` will correctly
apply all the updates to all indices.
Regular in-place updates map to assignment. For instance ``x[idx] += y``
maps to ``x[idx] = x[idx] + y``. As a result, assigning to the
same index ignores all but one update. Using ``x.at[idx].add(y)``
will correctly apply all updates to all indices.
.. list-table::
:header-rows: 1
@@ -591,7 +590,18 @@ void init_array(nb::module_& m) {
* - ``x = x.at[idx].maximum(y)``
- ``x[idx] = mx.maximum(x[idx], y)``
* - ``x = x.at[idx].minimum(y)``
- ``x[idx] = mx.minimum(x[idx], y)``
- ``x[idx] = mx.minimum(x[idx], y)``
Example:
>>> a = mx.array([0, 0])
>>> idx = mx.array([0, 1, 0, 1])
>>> a[idx] += 1
>>> a
array([1, 1], dtype=int32)
>>>
>>> a = mx.array([0, 0])
>>> a.at[idx].add(1)
array([2, 2], dtype=int32)
)pbdoc")
.def(
"__len__",

View File

@@ -118,12 +118,13 @@ void init_fast(nb::module_& parent_module) {
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
Supports:
* [Multi-Head Attention](https://arxiv.org/abs/1706.03762)
* [Grouped Query Attention](https://arxiv.org/abs/2305.13245)
* [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
* `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
Note: The softmax operation is performed in ``float32`` regardless of
input precision.
the input precision.
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.

View File

@@ -47,7 +47,7 @@ void init_ops(nb::module_& m) {
"shape"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig("def reshape(a: array, /, shape: List[int], *, stream: "
nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: "
"Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Reshape an array while preserving the size.
@@ -115,8 +115,9 @@ void init_ops(nb::module_& m) {
"axis"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig("def squeeze(a: array, /, axis: Union[None, int, List[int]] = "
"None, *, stream: Union[None, Stream, Device] = None) -> array"),
nb::sig(
"def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = "
"None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Remove length one axes from an array.
@@ -143,7 +144,7 @@ void init_ops(nb::module_& m) {
"axis"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig("def expand_dims(a: array, /, axis: Union[int, List[int]], "
nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], "
"*, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Add a size one dimension at the given axis.
@@ -1148,78 +1149,36 @@ void init_ops(nb::module_& m) {
Returns:
array: Bases of ``a`` raised to powers in ``b``.
)pbdoc");
m.def(
"arange",
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
Dtype dtype =
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
},
"stop"_a,
"dtype"_a = nb::none(),
"stream"_a = nb::none());
m.def(
"arange",
[](Scalar start,
Scalar stop,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
return arange(
scalar_to_double(start), scalar_to_double(stop), dtype, s);
},
"start"_a,
"stop"_a,
"dtype"_a = nb::none(),
"stream"_a = nb::none());
m.def(
"arange",
[](Scalar stop,
Scalar step,
std::optional<Dtype> dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_.has_value()
? dtype_.value()
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
return arange(
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
},
"stop"_a,
"step"_a,
"dtype"_a = nb::none(),
"stream"_a = nb::none());
m.def(
"arange",
[](Scalar start,
Scalar stop,
Scalar step,
std::optional<Dtype> dtype_,
const std::optional<Scalar>& step,
const std::optional<Dtype>& dtype_,
StreamOrDevice s) {
// Determine the final dtype based on input types
Dtype dtype = dtype_.has_value()
? dtype_.value()
Dtype dtype = dtype_
? *dtype_
: promote_types(
scalar_to_dtype(start),
promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
step ? promote_types(
scalar_to_dtype(stop), scalar_to_dtype(*step))
: scalar_to_dtype(stop));
return arange(
scalar_to_double(start),
scalar_to_double(stop),
scalar_to_double(step),
step ? scalar_to_double(*step) : 1.0,
dtype,
s);
},
"start"_a,
"stop"_a,
"step"_a,
"step"_a = nb::none(),
nb::kw_only(),
"dtype"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generates ranges of numbers.
@@ -1244,6 +1203,30 @@ void init_ops(nb::module_& m) {
This can lead to unexpected results for example if `start + step`
is a fractional value and the `dtype` is integral.
)pbdoc");
m.def(
"arange",
[](Scalar stop,
const std::optional<Scalar>& step,
const std::optional<Dtype>& dtype_,
StreamOrDevice s) {
Dtype dtype = dtype_ ? *dtype_
: step
? promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step))
: scalar_to_dtype(stop);
return arange(
0.0,
scalar_to_double(stop),
step ? scalar_to_double(*step) : 1.0,
dtype,
s);
},
"stop"_a,
"step"_a = nb::none(),
nb::kw_only(),
"dtype"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
m.def(
"linspace",
[](Scalar start,
@@ -1367,7 +1350,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Construct an array with the given value.
@@ -1400,7 +1383,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
"def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Construct an array of zeros.
@@ -1446,7 +1429,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
"def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Construct an array of ones.
@@ -1686,7 +1669,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
An `and` reduction over the given axes.
@@ -1715,7 +1698,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
An `or` reduction over the given axes.
@@ -1942,7 +1925,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Transpose the dimensions of the array.
@@ -1968,7 +1951,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sum reduce the array over the given axes.
@@ -1997,7 +1980,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
An product reduction over the given axes.
@@ -2026,7 +2009,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A `min` reduction over the given axes.
@@ -2055,7 +2038,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A `max` reduction over the given axes.
@@ -2084,7 +2067,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A `log-sum-exp` reduction over the given axes.
@@ -2119,7 +2102,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the mean(s) over the given axes.
@@ -2150,7 +2133,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the variance(s) over the given axes.
@@ -2186,7 +2169,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Split an array along a given axis.
@@ -2432,7 +2415,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array"),
"def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Broadcast an array to the given shape.
@@ -2455,7 +2438,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
"def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the softmax along the given axis.
@@ -2677,7 +2660,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Create a view into the array with the given shape and strides.
@@ -3078,7 +3061,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
General convolution over an input with several channels
@@ -3471,7 +3454,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
"def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the tensor dot product along the specified axes.
@@ -3539,7 +3522,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
"def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Construct an array by repeating ``a`` the number of times given by ``reps``.

View File

@@ -92,6 +92,8 @@ void init_random(nb::module_& parent_module) {
"key"_a,
"num"_a = 2,
"stream"_a = nb::none(),
nb::sig(
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
R"pbdoc(
Split a PRNG key into sub keys.
@@ -125,6 +127,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate uniformly distributed random numbers.
@@ -159,6 +163,8 @@ void init_random(nb::module_& parent_module) {
"scale"_a = 1.0,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.
@@ -190,6 +196,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate random integers from the given interval.
@@ -225,6 +233,8 @@ void init_random(nb::module_& parent_module) {
"shape"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate Bernoulli random values.
@@ -266,6 +276,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate values from a truncated normal distribution.
@@ -298,6 +310,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = float32,
"stream"_a = nb::none(),
"key"_a = nb::none(),
nb::sig(
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sample from the standard Gumbel distribution.
@@ -338,6 +352,8 @@ void init_random(nb::module_& parent_module) {
"num_samples"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sample from a categorical distribution.