mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 10:18:10 +08:00
post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates * one more doc nit * fix for stubs and latex
This commit is contained in:
@@ -47,7 +47,7 @@ void init_ops(nb::module_& m) {
|
||||
"shape"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig("def reshape(a: array, /, shape: List[int], *, stream: "
|
||||
nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: "
|
||||
"Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Reshape an array while preserving the size.
|
||||
@@ -115,8 +115,9 @@ void init_ops(nb::module_& m) {
|
||||
"axis"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig("def squeeze(a: array, /, axis: Union[None, int, List[int]] = "
|
||||
"None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
nb::sig(
|
||||
"def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = "
|
||||
"None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Remove length one axes from an array.
|
||||
|
||||
@@ -143,7 +144,7 @@ void init_ops(nb::module_& m) {
|
||||
"axis"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig("def expand_dims(a: array, /, axis: Union[int, List[int]], "
|
||||
nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], "
|
||||
"*, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Add a size one dimension at the given axis.
|
||||
@@ -1148,78 +1149,36 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: Bases of ``a`` raised to powers in ``b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
|
||||
Dtype dtype =
|
||||
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
|
||||
|
||||
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
|
||||
},
|
||||
"stop"_a,
|
||||
"dtype"_a = nb::none(),
|
||||
"stream"_a = nb::none());
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar start,
|
||||
Scalar stop,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
|
||||
return arange(
|
||||
scalar_to_double(start), scalar_to_double(stop), dtype, s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"dtype"_a = nb::none(),
|
||||
"stream"_a = nb::none());
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop,
|
||||
Scalar step,
|
||||
std::optional<Dtype> dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
|
||||
|
||||
return arange(
|
||||
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a,
|
||||
"dtype"_a = nb::none(),
|
||||
"stream"_a = nb::none());
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar start,
|
||||
Scalar stop,
|
||||
Scalar step,
|
||||
std::optional<Dtype> dtype_,
|
||||
const std::optional<Scalar>& step,
|
||||
const std::optional<Dtype>& dtype_,
|
||||
StreamOrDevice s) {
|
||||
// Determine the final dtype based on input types
|
||||
Dtype dtype = dtype_.has_value()
|
||||
? dtype_.value()
|
||||
Dtype dtype = dtype_
|
||||
? *dtype_
|
||||
: promote_types(
|
||||
scalar_to_dtype(start),
|
||||
promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
|
||||
|
||||
step ? promote_types(
|
||||
scalar_to_dtype(stop), scalar_to_dtype(*step))
|
||||
: scalar_to_dtype(stop));
|
||||
return arange(
|
||||
scalar_to_double(start),
|
||||
scalar_to_double(stop),
|
||||
scalar_to_double(step),
|
||||
step ? scalar_to_double(*step) : 1.0,
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"step"_a,
|
||||
"step"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"dtype"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generates ranges of numbers.
|
||||
|
||||
@@ -1244,6 +1203,30 @@ void init_ops(nb::module_& m) {
|
||||
This can lead to unexpected results for example if `start + step`
|
||||
is a fractional value and the `dtype` is integral.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"arange",
|
||||
[](Scalar stop,
|
||||
const std::optional<Scalar>& step,
|
||||
const std::optional<Dtype>& dtype_,
|
||||
StreamOrDevice s) {
|
||||
Dtype dtype = dtype_ ? *dtype_
|
||||
: step
|
||||
? promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step))
|
||||
: scalar_to_dtype(stop);
|
||||
return arange(
|
||||
0.0,
|
||||
scalar_to_double(stop),
|
||||
step ? scalar_to_double(*step) : 1.0,
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"dtype"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
||||
m.def(
|
||||
"linspace",
|
||||
[](Scalar start,
|
||||
@@ -1367,7 +1350,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Construct an array with the given value.
|
||||
|
||||
@@ -1400,7 +1383,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Construct an array of zeros.
|
||||
|
||||
@@ -1446,7 +1429,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Construct an array of ones.
|
||||
|
||||
@@ -1686,7 +1669,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
An `and` reduction over the given axes.
|
||||
|
||||
@@ -1715,7 +1698,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
An `or` reduction over the given axes.
|
||||
|
||||
@@ -1942,7 +1925,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Transpose the dimensions of the array.
|
||||
|
||||
@@ -1968,7 +1951,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Sum reduce the array over the given axes.
|
||||
|
||||
@@ -1997,7 +1980,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
An product reduction over the given axes.
|
||||
|
||||
@@ -2026,7 +2009,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
A `min` reduction over the given axes.
|
||||
|
||||
@@ -2055,7 +2038,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
A `max` reduction over the given axes.
|
||||
|
||||
@@ -2084,7 +2067,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
A `log-sum-exp` reduction over the given axes.
|
||||
|
||||
@@ -2119,7 +2102,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the mean(s) over the given axes.
|
||||
|
||||
@@ -2150,7 +2133,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the variance(s) over the given axes.
|
||||
|
||||
@@ -2186,7 +2169,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Split an array along a given axis.
|
||||
|
||||
@@ -2432,7 +2415,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Broadcast an array to the given shape.
|
||||
|
||||
@@ -2455,7 +2438,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the softmax along the given axis.
|
||||
|
||||
@@ -2677,7 +2660,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Create a view into the array with the given shape and strides.
|
||||
|
||||
@@ -3078,7 +3061,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
General convolution over an input with several channels
|
||||
|
||||
@@ -3471,7 +3454,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the tensor dot product along the specified axes.
|
||||
|
||||
@@ -3539,7 +3522,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Construct an array by repeating ``a`` the number of times given by ``reps``.
|
||||
|
||||
|
Reference in New Issue
Block a user