2023-12-01 03:12:53 +08:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
#include <numeric>
|
|
|
|
#include <ostream>
|
|
|
|
#include <variant>
|
|
|
|
|
|
|
|
#include <pybind11/iostream.h>
|
|
|
|
#include <pybind11/pybind11.h>
|
|
|
|
#include <pybind11/stl.h>
|
|
|
|
|
|
|
|
#include "mlx/ops.h"
|
|
|
|
#include "mlx/utils.h"
|
|
|
|
#include "python/src/load.h"
|
|
|
|
#include "python/src/utils.h"
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
using namespace py::literals;
|
|
|
|
using namespace mlx::core;
|
|
|
|
|
|
|
|
using Scalar = std::variant<int, double>;
|
|
|
|
|
|
|
|
Dtype scalar_to_dtype(Scalar scalar) {
|
|
|
|
if (std::holds_alternative<int>(scalar)) {
|
|
|
|
return int32;
|
|
|
|
} else {
|
|
|
|
return float32;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Widen whichever alternative `s` holds to a double: a held double is
// returned unchanged, a held int is converted with static_cast.
double scalar_to_double(Scalar s) {
  return std::visit([](auto value) { return static_cast<double>(value); }, s);
}
|
|
|
|
|
|
|
|
void init_ops(py::module_& m) {
|
2023-12-12 05:42:55 +08:00
|
|
|
// Turn off pybind11's auto-generated signature line for the functions
// registered while this object is alive; every docstring below begins with a
// hand-written, Python-style signature instead. pybind11 restores the
// previous option values when `options` is destroyed.
py::options options;
options.disable_function_signatures();
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
// Bind mlx::core::reshape: `a` is positional-only, `shape` is a regular
// argument and `stream` is keyword-only.
m.def(
    "reshape",
    &reshape,
    "a"_a,
    py::pos_only(),
    "shape"_a,
    py::kw_only(),
    "stream"_a = none,
    // Docstring fix: RST inline literals use double backticks; ```None```
    // rendered with stray backticks.
    R"pbdoc(
      reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array

      Reshape an array while preserving the size.

      Args:
          a (array): Input array.
          shape (tuple(int)): New shape.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

      Returns:
          array: The reshaped array.
    )pbdoc");
|
2023-12-17 13:54:37 +08:00
|
|
|
m.def(
|
|
|
|
"flatten",
|
|
|
|
[](const array& a,
|
|
|
|
int start_axis,
|
|
|
|
int end_axis,
|
|
|
|
const StreamOrDevice& s) { return flatten(a, start_axis, end_axis); },
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"start_axis"_a = 0,
|
|
|
|
"end_axis"_a = -1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
|
|
|
flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
|
|
|
Flatten an array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
|
|
|
|
end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
|
|
|
in which case the default stream of the default device is used.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The flattened array.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"squeeze",
|
|
|
|
[](const array& a, const IntOrVec& v, const StreamOrDevice& s) {
|
|
|
|
if (std::holds_alternative<std::monostate>(v)) {
|
|
|
|
return squeeze(a, s);
|
|
|
|
} else if (auto pv = std::get_if<int>(&v); pv) {
|
|
|
|
return squeeze(a, *pv, s);
|
|
|
|
} else {
|
|
|
|
return squeeze(a, std::get<std::vector<int>>(v), s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Remove length one axes from an array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or tuple(int), optional): Axes to remove. Defaults
|
|
|
|
to ```None``` in which case all size one axes are removed.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with size one axes removed.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"expand_dims",
|
|
|
|
[](const array& a,
|
|
|
|
const std::variant<int, std::vector<int>>& v,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (auto pv = std::get_if<int>(&v); pv) {
|
|
|
|
return expand_dims(a, *pv, s);
|
|
|
|
} else {
|
|
|
|
return expand_dims(a, std::get<std::vector<int>>(v), s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Add a size one dimension at the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axes (int or tuple(int)): The index of the inserted dimensions.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array with inserted dimensions.
|
|
|
|
)pbdoc");
|
|
|
|
// Bind mlx::core::abs (qualified to avoid the C library's ::abs).
// `a` is positional-only and `stream` keyword-only.
m.def(
    "abs",
    &mlx::core::abs,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise absolute value.

      Args:
          a (array): Input array.

      Returns:
          array: The absolute value of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::sign; `a` is positional-only and `stream` keyword-only.
m.def(
    "sign",
    &mlx::core::sign,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise sign.

      Args:
          a (array): Input array.

      Returns:
          array: The sign of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::negative; `a` is positional-only and `stream` keyword-only.
m.def(
    "negative",
    &mlx::core::negative,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise negation.

      Args:
          a (array): Input array.

      Returns:
          array: The negative of ``a``.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"add",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return add(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise addition.
|
|
|
|
|
|
|
|
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
|
|
|
|
can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The sum of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"subtract",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return subtract(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise subtraction.
|
|
|
|
|
|
|
|
Subtract one array from another with numpy-style broadcasting semantics. Either or both
|
|
|
|
input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The difference ``a - b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"multiply",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return multiply(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise multiplication.
|
|
|
|
|
|
|
|
Multiply two arrays with numpy-style broadcasting semantics. Either or both
|
|
|
|
input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The multiplication ``a * b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"divide",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return divide(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise division.
|
|
|
|
|
|
|
|
Divide two arrays with numpy-style broadcasting semantics. Either or both
|
|
|
|
input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The quotient ``a / b``.
|
|
|
|
)pbdoc");
|
2023-12-09 07:08:52 +08:00
|
|
|
m.def(
|
|
|
|
"remainder",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return remainder(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-12-09 07:08:52 +08:00
|
|
|
Element-wise remainder of division.
|
|
|
|
|
|
|
|
Computes the remainder of dividing a with b with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be
|
|
|
|
scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The remainder of ``a // b``.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"equal",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return equal(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise equality.
|
|
|
|
|
|
|
|
Equality comparison on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a == b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"not_equal",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return not_equal(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise not equal.
|
|
|
|
|
|
|
|
Not equal comparison on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a != b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"less",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return less(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise less than.
|
|
|
|
|
|
|
|
Strict less than on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a < b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"less_equal",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return less_equal(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise less than or equal.
|
|
|
|
|
|
|
|
Less than or equal on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a <= b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"greater",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return greater(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise greater than.
|
|
|
|
|
|
|
|
Strict greater than on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a > b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"greater_equal",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return greater_equal(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise greater or equal.
|
|
|
|
|
|
|
|
Greater than or equal on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a >= b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"array_equal",
|
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
bool equal_nan,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return array_equal(a, b, equal_nan, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"equal_nan"_a = false,
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Array equality check.
|
|
|
|
|
|
|
|
Compare two arrays for equality. Returns ``True`` if and only if the arrays
|
|
|
|
have the same shape and their values are equal. The arrays need not have
|
|
|
|
the same type to be considered equal.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
equal_nan (bool): If ``True``, NaNs are treated as equal.
|
|
|
|
Defaults to ``False``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: A scalar boolean array.
|
|
|
|
)pbdoc");
|
|
|
|
// Bind mlx::core::matmul; both operands are positional-only and `stream`
// keyword-only.
m.def(
    "matmul",
    &mlx::core::matmul,
    py::arg("a"),
    py::arg("b"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Matrix multiplication.

      Perform the (possibly batched) matrix multiplication of two arrays. This function supports
      broadcasting for arrays with more than two dimensions.

      - If the first array is 1-D then a 1 is prepended to its shape to make it
        a matrix. Similarly if the second array is 1-D then a 1 is appended to its
        shape to make it a matrix. In either case the singleton dimension is removed
        from the result.
      - A batched matrix multiplication is performed if the arrays have more than
        2 dimensions. The matrix dimensions for the matrix product are the last
        two dimensions of each input.
      - All but the last two dimensions of each input are broadcast with one another using
        standard numpy-style broadcasting semantics.

      Args:
          a (array): Input array or scalar.
          b (array): Input array or scalar.

      Returns:
          array: The matrix product of ``a`` and ``b``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::square; `a` is positional-only and `stream` keyword-only.
m.def(
    "square",
    &mlx::core::square,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise square.

      Args:
          a (array): Input array.

      Returns:
          array: The square of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::sqrt (qualified to avoid the C library's ::sqrt).
m.def(
    "sqrt",
    &mlx::core::sqrt,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise square root.

      Args:
          a (array): Input array.

      Returns:
          array: The square root of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::rsqrt; `a` is positional-only and `stream` keyword-only.
m.def(
    "rsqrt",
    &mlx::core::rsqrt,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise reciprocal and square root.

      Args:
          a (array): Input array.

      Returns:
          array: One over the square root of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::reciprocal; `a` is positional-only and `stream` keyword-only.
m.def(
    "reciprocal",
    &mlx::core::reciprocal,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise reciprocal.

      Args:
          a (array): Input array.

      Returns:
          array: The reciprocal of ``a``.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"logical_not",
|
|
|
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
|
|
|
return logical_not(to_array(a), s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise logical not.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean array containing the logical not of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"logaddexp",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return logaddexp(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise log-add-exp.
|
|
|
|
|
|
|
|
This is a numerically stable log-add-exp of two arrays with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
The computation is is a numerically stable version of ``log(exp(a) + exp(b))``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The log-add-exp of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
// Bind mlx::core::exp (qualified to avoid the C library's ::exp).
m.def(
    "exp",
    &mlx::core::exp,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise exponential.

      Args:
          a (array): Input array.

      Returns:
          array: The exponential of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::erf (qualified to avoid the C library's ::erf).
m.def(
    "erf",
    &mlx::core::erf,
    "a"_a,
    py::pos_only(),
    py::kw_only(),
    "stream"_a = none,
    // Docstring fix: the integral's upper limit is x and the integration
    // variable is t (it previously read \int_0^t ... dx).
    R"pbdoc(
      erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise error function.

      .. math::
        \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt

      Args:
          a (array): Input array.

      Returns:
          array: The error function of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::erfinv; `a` is positional-only and `stream` keyword-only.
m.def(
    "erfinv",
    &mlx::core::erfinv,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse of :func:`erf`.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse error function of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::sin (qualified to avoid the C library's ::sin).
m.def(
    "sin",
    &mlx::core::sin,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise sine.

      Args:
          a (array): Input array.

      Returns:
          array: The sine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::cos (qualified to avoid the C library's ::cos).
m.def(
    "cos",
    &mlx::core::cos,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise cosine.

      Args:
          a (array): Input array.

      Returns:
          array: The cosine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::tan (qualified to avoid the C library's ::tan).
m.def(
    "tan",
    &mlx::core::tan,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise tangent.

      Args:
          a (array): Input array.

      Returns:
          array: The tangent of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::arcsin; `a` is positional-only and `stream` keyword-only.
m.def(
    "arcsin",
    &mlx::core::arcsin,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse sine.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse sine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::arccos; `a` is positional-only and `stream` keyword-only.
m.def(
    "arccos",
    &mlx::core::arccos,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse cosine.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse cosine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::arctan; `a` is positional-only and `stream` keyword-only.
m.def(
    "arctan",
    &mlx::core::arctan,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse tangent.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse tangent of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::sinh (qualified to avoid the C library's ::sinh).
m.def(
    "sinh",
    &mlx::core::sinh,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise hyperbolic sine.

      Args:
          a (array): Input array.

      Returns:
          array: The hyperbolic sine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::cosh (qualified to avoid the C library's ::cosh).
m.def(
    "cosh",
    &mlx::core::cosh,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise hyperbolic cosine.

      Args:
          a (array): Input array.

      Returns:
          array: The hyperbolic cosine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::tanh (qualified to avoid the C library's ::tanh).
m.def(
    "tanh",
    &mlx::core::tanh,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise hyperbolic tangent.

      Args:
          a (array): Input array.

      Returns:
          array: The hyperbolic tangent of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::arcsinh; `a` is positional-only and `stream` keyword-only.
m.def(
    "arcsinh",
    &mlx::core::arcsinh,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse hyperbolic sine.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse hyperbolic sine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::arccosh; `a` is positional-only and `stream` keyword-only.
m.def(
    "arccosh",
    &mlx::core::arccosh,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse hyperbolic cosine.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse hyperbolic cosine of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::arctanh; `a` is positional-only and `stream` keyword-only.
m.def(
    "arctanh",
    &mlx::core::arctanh,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise inverse hyperbolic tangent.

      Args:
          a (array): Input array.

      Returns:
          array: The inverse hyperbolic tangent of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::log (qualified to avoid the C library's ::log).
m.def(
    "log",
    &mlx::core::log,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise natural logarithm.

      Args:
          a (array): Input array.

      Returns:
          array: The natural logarithm of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::log2 (qualified to avoid the C library's ::log2).
m.def(
    "log2",
    &mlx::core::log2,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise base-2 logarithm.

      Args:
          a (array): Input array.

      Returns:
          array: The base-2 logarithm of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::log10 (qualified to avoid the C library's ::log10).
m.def(
    "log10",
    &mlx::core::log10,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise base-10 logarithm.

      Args:
          a (array): Input array.

      Returns:
          array: The base-10 logarithm of ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::log1p (qualified to avoid the C library's ::log1p).
m.def(
    "log1p",
    &mlx::core::log1p,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise natural log of one plus the array.

      Args:
          a (array): Input array.

      Returns:
          array: The natural logarithm of one plus ``a``.
    )pbdoc");
|
|
|
|
// Bind mlx::core::stop_gradient; `a` is positional-only and `stream`
// keyword-only.
m.def(
    "stop_gradient",
    &mlx::core::stop_gradient,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Stop gradients from being computed.

      The operation is the identity but it prevents gradients from flowing
      through the array.

      Args:
          a (array): Input array.

      Returns:
          array: The unchanged input ``a`` but without gradient flowing
            through it.
    )pbdoc");
|
|
|
|
// Bind mlx::core::sigmoid; `a` is positional-only and `stream` keyword-only.
m.def(
    "sigmoid",
    &mlx::core::sigmoid,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise logistic sigmoid.

      The logistic sigmoid function is:

      .. math::
        \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

      Args:
          a (array): Input array.

      Returns:
          array: The logistic sigmoid of ``a``.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"power",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return power(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise power operation.
|
|
|
|
|
|
|
|
Raise the elements of a to the powers in elements of b with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: Bases of ``a`` raised to powers in ``b``.
|
|
|
|
)pbdoc");
|
2023-12-12 05:42:55 +08:00
|
|
|
m.def(
|
|
|
|
"arange",
|
|
|
|
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
|
|
|
|
Dtype dtype =
|
|
|
|
dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop);
|
|
|
|
|
|
|
|
return arange(0.0, scalar_to_double(stop), 1.0, dtype, s);
|
|
|
|
},
|
|
|
|
"stop"_a,
|
|
|
|
"dtype"_a = none,
|
|
|
|
"stream"_a = none);
|
|
|
|
m.def(
|
|
|
|
"arange",
|
|
|
|
[](Scalar start,
|
|
|
|
Scalar stop,
|
|
|
|
std::optional<Dtype> dtype_,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
Dtype dtype = dtype_.has_value()
|
|
|
|
? dtype_.value()
|
|
|
|
: promote_types(scalar_to_dtype(start), scalar_to_dtype(stop));
|
|
|
|
return arange(
|
|
|
|
scalar_to_double(start), scalar_to_double(stop), dtype, s);
|
|
|
|
},
|
|
|
|
"start"_a,
|
|
|
|
"stop"_a,
|
|
|
|
"dtype"_a = none,
|
|
|
|
"stream"_a = none);
|
|
|
|
m.def(
|
|
|
|
"arange",
|
|
|
|
[](Scalar stop,
|
|
|
|
Scalar step,
|
|
|
|
std::optional<Dtype> dtype_,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
Dtype dtype = dtype_.has_value()
|
|
|
|
? dtype_.value()
|
|
|
|
: promote_types(scalar_to_dtype(stop), scalar_to_dtype(step));
|
|
|
|
|
|
|
|
return arange(
|
|
|
|
0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s);
|
|
|
|
},
|
|
|
|
"stop"_a,
|
|
|
|
"step"_a,
|
|
|
|
"dtype"_a = none,
|
|
|
|
"stream"_a = none);
|
|
|
|
m.def(
|
|
|
|
"arange",
|
|
|
|
[](Scalar start,
|
|
|
|
Scalar stop,
|
|
|
|
Scalar step,
|
|
|
|
std::optional<Dtype> dtype_,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
// Determine the final dtype based on input types
|
|
|
|
Dtype dtype = dtype_.has_value()
|
|
|
|
? dtype_.value()
|
|
|
|
: promote_types(
|
|
|
|
scalar_to_dtype(start),
|
|
|
|
promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
|
|
|
|
|
|
|
|
return arange(
|
|
|
|
scalar_to_double(start),
|
|
|
|
scalar_to_double(stop),
|
|
|
|
scalar_to_double(step),
|
|
|
|
dtype,
|
|
|
|
s);
|
|
|
|
},
|
|
|
|
"start"_a,
|
|
|
|
"stop"_a,
|
|
|
|
"step"_a,
|
|
|
|
"dtype"_a = none,
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
|
|
|
arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
|
|
|
Generates ranges of numbers.
|
|
|
|
|
|
|
|
Generate numbers in the half-open interval ``[start, stop)`` in
|
|
|
|
increments of ``step``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
start (float or int, optional): Starting value which defaults to ``0``.
|
|
|
|
stop (float or int): Stopping value.
|
|
|
|
step (float or int, optional): Increment which defaults to ``1``.
|
|
|
|
dtype (Dtype, optional): Specifies the data type of the output.
|
|
|
|
If unspecified will default to ``float32`` if any of ``start``,
|
|
|
|
``stop``, or ``step`` are ``float``. Otherwise will default to
|
|
|
|
``int32``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The range of values.
|
|
|
|
|
|
|
|
Note:
|
|
|
|
Following the Numpy convention the actual increment used to
|
|
|
|
generate numbers is ``dtype(start + step) - dtype(start)``.
|
|
|
|
This can lead to unexpected results for example if `start + step`
|
|
|
|
is a fractional value and the `dtype` is integral.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"take",
|
|
|
|
[](const array& a,
|
|
|
|
const array& indices,
|
|
|
|
const std::optional<int>& axis,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis.has_value()) {
|
|
|
|
return take(a, indices, axis.value(), s);
|
|
|
|
} else {
|
|
|
|
return take(a, indices, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"indices"_a,
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Take elements along an axis.
|
|
|
|
|
|
|
|
The elements are taken from ``indices`` along the specified axis.
|
|
|
|
If the axis is not specified the array is treated as a flattened
|
|
|
|
1-D array prior to performing the take.
|
|
|
|
|
|
|
|
As an example, if the ``axis=1`` this is equialent to ``a[:, indices, ...]``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
indices (array): Input array with integral type.
|
|
|
|
axis (int, optional): Axis along which to perform the take. If unspecified
|
|
|
|
the array is treated as a flattened 1-D vector.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The indexed values of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"take_along_axis",
|
|
|
|
[](const array& a,
|
|
|
|
const array& indices,
|
|
|
|
const std::optional<int>& axis,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis.has_value()) {
|
|
|
|
return take_along_axis(a, indices, axis.value(), s);
|
|
|
|
} else {
|
|
|
|
return take_along_axis(reshape(a, {-1}, s), indices, 0, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"indices"_a,
|
|
|
|
"axis"_a,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Take values along an axis at the specified indices.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
indices (array): Indices array. These should be broadcastable with
|
|
|
|
the input array excluding the `axis` dimension.
|
|
|
|
axis (int or None): Axis in the input to take the values from. If
|
|
|
|
``axis == None`` the array is flattened to 1D prior to the indexing
|
|
|
|
operation.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the specified shape and values.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"full",
|
|
|
|
[](const std::variant<int, std::vector<int>>& shape,
|
|
|
|
const ScalarOrArray& vals,
|
|
|
|
std::optional<Dtype> dtype,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (auto pv = std::get_if<int>(&shape); pv) {
|
|
|
|
return full({*pv}, to_array(vals, dtype), s);
|
|
|
|
} else {
|
|
|
|
return full(
|
|
|
|
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"shape"_a,
|
|
|
|
"vals"_a,
|
|
|
|
"dtype"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Construct an array with the given value.
|
|
|
|
|
|
|
|
Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
|
|
|
|
is an :obj:`array` it must be broadcastable to the given ``shape``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape (int or list(int)): The shape of the output array.
|
|
|
|
vals (float or int or array): Values to fill the array with.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type is inferred from ``vals``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the specified shape and values.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"zeros",
|
|
|
|
[](const std::variant<int, std::vector<int>>& shape,
|
|
|
|
std::optional<Dtype> dtype,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
auto t = dtype.value_or(float32);
|
|
|
|
if (auto pv = std::get_if<int>(&shape); pv) {
|
|
|
|
return zeros({*pv}, t, s);
|
|
|
|
} else {
|
|
|
|
return zeros(std::get<std::vector<int>>(shape), t, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"shape"_a,
|
|
|
|
"dtype"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Construct an array of zeros.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape (int or list(int)): The shape of the output array.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type defaults to ``float32``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array of zeros with the specified shape.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
    "zeros_like",
    // Shape and dtype are taken from the input array.
    &mlx::core::zeros_like,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      An array of zeros like the input.

      Args:
          a (array): The input to take the shape and type from.

      Returns:
          array: The output array filled with zeros.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"ones",
|
|
|
|
[](const std::variant<int, std::vector<int>>& shape,
|
|
|
|
std::optional<Dtype> dtype,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
auto t = dtype.value_or(float32);
|
|
|
|
if (auto pv = std::get_if<int>(&shape); pv) {
|
|
|
|
return ones({*pv}, t, s);
|
|
|
|
} else {
|
|
|
|
return ones(std::get<std::vector<int>>(shape), t, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"shape"_a,
|
|
|
|
"dtype"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Construct an array of ones.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape (int or list(int)): The shape of the output array.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type defaults to ``float32``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array of ones with the specified shape.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
    "ones_like",
    // Shape and dtype are taken from the input array.
    &mlx::core::ones_like,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      An array of ones like the input.

      Args:
          a (array): The input to take the shape and type from.

      Returns:
          array: The output array filled with ones.
    )pbdoc");
|
2023-12-12 04:38:17 +08:00
|
|
|
m.def(
|
|
|
|
"eye",
|
|
|
|
[](int n,
|
2023-12-12 07:01:41 +08:00
|
|
|
std::optional<int> m,
|
|
|
|
int k,
|
|
|
|
std::optional<Dtype> dtype,
|
2023-12-12 04:38:17 +08:00
|
|
|
StreamOrDevice s) {
|
2023-12-12 07:01:41 +08:00
|
|
|
return eye(n, m.value_or(n), k, dtype.value_or(float32), s);
|
2023-12-12 04:38:17 +08:00
|
|
|
},
|
|
|
|
"n"_a,
|
|
|
|
"m"_a = py::none(),
|
2023-12-12 07:01:41 +08:00
|
|
|
"k"_a = 0,
|
2023-12-12 04:38:17 +08:00
|
|
|
"dtype"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 07:01:41 +08:00
|
|
|
eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-12-12 04:38:17 +08:00
|
|
|
Create an identity matrix or a general diagonal matrix.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
n (int): The number of rows in the output.
|
|
|
|
m (int, optional): The number of columns in the output. Defaults to n.
|
|
|
|
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
|
|
|
|
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to None.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
2023-12-12 07:01:41 +08:00
|
|
|
)pbdoc");
|
2023-12-12 04:38:17 +08:00
|
|
|
m.def(
|
|
|
|
"identity",
|
2023-12-12 07:01:41 +08:00
|
|
|
[](int n, std::optional<Dtype> dtype, StreamOrDevice s) {
|
|
|
|
return identity(n, dtype.value_or(float32), s);
|
|
|
|
},
|
2023-12-12 04:38:17 +08:00
|
|
|
"n"_a,
|
|
|
|
"dtype"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 07:01:41 +08:00
|
|
|
identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-12-12 04:38:17 +08:00
|
|
|
Create a square identity matrix.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
n (int): The number of rows and columns in the output.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to None.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: An identity matrix of size n x n.
|
2023-12-12 07:01:41 +08:00
|
|
|
)pbdoc");
|
2023-12-16 09:30:34 +08:00
|
|
|
m.def(
|
|
|
|
"tri",
|
|
|
|
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
|
|
|
|
return tri(n, m.value_or(n), k, float32, s);
|
|
|
|
},
|
|
|
|
"n"_a,
|
|
|
|
"m"_a = none,
|
|
|
|
"k"_a = 0,
|
|
|
|
"dtype"_a = float32,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
|
|
|
tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
|
|
|
An array with ones at and below the given diagonal and zeros elsewhere.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
n (int): The number of rows in the output.
|
|
|
|
m (int, optional): The number of cols in the output. Defaults to ``None``.
|
|
|
|
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: Array with its lower triangle filled with ones and zeros elsewhere
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
    "tril",
    &tril,
    "x"_a,
    "k"_a = 0,
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      tril(x: array, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array

      Zeros the array above the given diagonal.

      Args:
          x (array): input array.
          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
          stream (Stream, optional): Stream or device. Defaults to ``None``.

      Returns:
          array: Array zeroed above the given diagonal
    )pbdoc");
|
|
|
|
m.def(
    "triu",
    &triu,
    "x"_a,
    "k"_a = 0,
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      triu(x: array, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array

      Zeros the array below the given diagonal.

      Args:
          x (array): input array.
          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
          stream (Stream, optional): Stream or device. Defaults to ``None``.

      Returns:
          array: Array zeroed below the given diagonal
    )pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
    "allclose",
    &mlx::core::allclose,
    py::arg("a"),
    py::arg("b"),
    py::pos_only(),
    // Tolerance defaults mirror numpy.allclose.
    py::arg("rtol") = 1e-5,
    py::arg("atol") = 1e-8,
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array

      Approximate comparison of two arrays.

      The arrays are considered equal if:

      .. code-block::

        all(abs(a - b) <= (atol + rtol * abs(b)))

      Note unlike :func:`array_equal`, this function supports numpy-style
      broadcasting.

      Args:
          a (array): Input array.
          b (array): Input array.
          rtol (float): Relative tolerance.
          atol (float): Absolute tolerance.

      Returns:
          array: The boolean output scalar indicating if the arrays are close.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"all",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
An `and` reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"any",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
An `or` reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"minimum",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return minimum(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
2023-12-12 05:42:55 +08:00
|
|
|
py::kw_only(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise minimum.
|
|
|
|
|
|
|
|
Take the element-wise min of two arrays with numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The min of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"maximum",
|
|
|
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
|
|
|
return maximum(a, b, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"b"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Element-wise maximum.
|
|
|
|
|
|
|
|
Take the element-wise max of two arrays with numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The max of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
2023-12-15 02:00:23 +08:00
|
|
|
m.def(
    "floor",
    // Qualified to select the mlx overload over the C math floor.
    &mlx::core::floor,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise floor.

      Args:
          a (array): Input array.

      Returns:
          array: The floor of ``a``.
    )pbdoc");
|
|
|
|
m.def(
    "ceil",
    // Qualified to select the mlx overload over the C math ceil.
    &mlx::core::ceil,
    py::arg("a"),
    py::pos_only(),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array

      Element-wise ceil.

      Args:
          a (array): Input array.

      Returns:
          array: The ceil of ``a``.
    )pbdoc");
|
2023-12-15 04:59:12 +08:00
|
|
|
m.def(
    "moveaxis",
    &moveaxis,
    "a"_a,
    py::pos_only(),
    "source"_a,
    // Keyword name must match the documented `destination` parameter; it
    // was previously misspelled, so `destination=...` raised a TypeError.
    "destination"_a,
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array

      Move an axis to a new position.

      Args:
          a (array): Input array.
          source (int): Specifies the source axis.
          destination (int): Specifies the destination axis.

      Returns:
          array: The array with the axis moved.
    )pbdoc");
|
|
|
|
m.def(
    "swapaxes",
    &mlx::core::swapaxes,
    py::arg("a"),
    py::pos_only(),
    py::arg("axis1"),
    py::arg("axis2"),
    py::kw_only(),
    py::arg("stream") = none,
    R"pbdoc(
      swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array

      Swap two axes of an array.

      Args:
          a (array): Input array.
          axis1 (int): Specifies the first axis.
          axis2 (int): Specifies the second axis.

      Returns:
          array: The array with swapped axes.
    )pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"transpose",
|
|
|
|
[](const array& a,
|
|
|
|
const std::optional<std::vector<int>>& axes,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axes.has_value()) {
|
|
|
|
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s);
|
|
|
|
} else {
|
|
|
|
return transpose(a, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axes"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Transpose the dimensions of the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axes (list(int), optional): Specifies the source axis for each axis
|
|
|
|
in the new array. The default is to reverse the axes.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The transposed array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sum",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"array"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Sum reduce the array over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"prod",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
An product reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"min",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
An `min` reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"max",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
An `max` reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"logsumexp",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
A `log-sum-exp` reduction over the given axes.
|
|
|
|
|
|
|
|
The log-sum-exp reduction is a numerically stable version of:
|
|
|
|
|
|
|
|
.. code-block::
|
|
|
|
|
|
|
|
log(sum(exp(a), axis))
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"mean",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Compute the mean(s) over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array of means.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"var",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
int ddof,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
"ddof"_a = 0,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Compute the variance(s) over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
ddof (int, optional): The divisor to compute the variance
|
|
|
|
is ``N - ddof``, defaults to 0.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array of variances.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"split",
|
|
|
|
[](const array& a,
|
|
|
|
const std::variant<int, std::vector<int>>& indices_or_sections,
|
|
|
|
int axis,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
|
|
|
|
return split(a, *pv, axis, s);
|
|
|
|
} else {
|
|
|
|
return split(
|
|
|
|
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"indices_or_sections"_a,
|
|
|
|
"axis"_a = 0,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Split an array along a given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
indices_or_sections (int or list(int)): If ``indices_or_sections``
|
|
|
|
is an integer the array is split into that many sections of equal
|
|
|
|
size. An error is raised if this is not possible. If ``indices_or_sections``
|
|
|
|
is a list, the list contains the indices of the start of each subarray
|
|
|
|
along the given axis.
|
|
|
|
axis (int, optional): Axis to split along, defaults to `0`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
list(array): A list of split arrays.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argmin",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return argmin(a, *axis, keepdims, s);
|
|
|
|
} else {
|
|
|
|
return argmin(a, keepdims, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Indices of the minimum values along the axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int, optional): Optional axis to reduce over. If unspecified
|
|
|
|
this defaults to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the indices of the minimum values.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argmax",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return argmax(a, *axis, keepdims, s);
|
|
|
|
} else {
|
|
|
|
return argmax(a, keepdims, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
"keepdims"_a = false,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Indices of the maximum values along the axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int, optional): Optional axis to reduce over. If unspecified
|
|
|
|
this defaults to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the indices of the minimum values.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sort",
|
|
|
|
[](const array& a, std::optional<int> axis, StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return sort(a, *axis, s);
|
|
|
|
} else {
|
|
|
|
return sort(a, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = -1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
2023-12-12 07:01:41 +08:00
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Returns a sorted copy of the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to sort over.
|
|
|
|
If ``None``, this sorts over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The sorted array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argsort",
|
|
|
|
[](const array& a, std::optional<int> axis, StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return argsort(a, *axis, s);
|
|
|
|
} else {
|
|
|
|
return argsort(a, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = -1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Returns the indices that sort the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to sort over.
|
|
|
|
If ``None``, this sorts over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The indices that sort the input array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"partition",
|
|
|
|
[](const array& a, int kth, std::optional<int> axis, StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return partition(a, kth, *axis, s);
|
|
|
|
} else {
|
|
|
|
return partition(a, kth, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"kth"_a,
|
|
|
|
"axis"_a = -1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Returns a partitioned copy of the array such that the smaller ``kth``
|
|
|
|
elements are first.
|
|
|
|
|
|
|
|
The ordering of the elements in partitions is undefined.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2023-12-12 07:01:41 +08:00
|
|
|
kth (int): Element at the ``kth`` index will be in its sorted
|
|
|
|
position in the output. All elements before the kth index will
|
|
|
|
be less or equal to the ``kth`` element and all elements after
|
2023-11-30 02:30:41 +08:00
|
|
|
will be greater or equal to the ``kth`` element in the output.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to partition over.
|
|
|
|
If ``None``, this partitions over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to ``-1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The partitioned array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argpartition",
|
|
|
|
[](const array& a, int kth, std::optional<int> axis, StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return argpartition(a, kth, *axis, s);
|
|
|
|
} else {
|
|
|
|
return argpartition(a, kth, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"kth"_a,
|
|
|
|
"axis"_a = -1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Returns the indices that partition the array.
|
|
|
|
|
|
|
|
The ordering of the elements within a partition in given by the indices
|
|
|
|
is undefined.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
kth (int): Element index at the ``kth`` position in the output will
|
|
|
|
give the sorted position. All indices before the ``kth`` position
|
2023-12-12 07:01:41 +08:00
|
|
|
will be of elements less or equal to the element at the ``kth``
|
2023-11-30 02:30:41 +08:00
|
|
|
index and all indices after will be of elements greater or equal
|
|
|
|
to the element at the ``kth`` index.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to partiton over.
|
|
|
|
If ``None``, this partitions over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to ``-1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The indices that partition the input array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"topk",
|
|
|
|
[](const array& a, int k, std::optional<int> axis, StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return topk(a, k, *axis, s);
|
|
|
|
} else {
|
|
|
|
return topk(a, k, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"k"_a,
|
|
|
|
"axis"_a = -1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Returns the ``k`` largest elements from the input along a given axis.
|
|
|
|
|
|
|
|
The elements will not necessarily be in sorted order.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
k (int): ``k`` top elements to be returned
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to select over.
|
|
|
|
If ``None``, this selects the top ``k`` elements over the
|
2023-11-30 02:30:41 +08:00
|
|
|
flattened array. If unspecified, it defaults to ``-1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The top ``k`` elements from the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"broadcast_to",
|
|
|
|
[](const ScalarOrArray& a,
|
|
|
|
const std::vector<int>& shape,
|
|
|
|
StreamOrDevice s) { return broadcast_to(to_array(a), shape, s); },
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"shape"_a,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Broadcast an array to the given shape.
|
|
|
|
|
|
|
|
The broadcasting semantics are the same as Numpy.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
shape (list(int)): The shape to broadcast to.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the new shape.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"softmax",
|
|
|
|
[](const array& a, const IntOrVec& axis, StreamOrDevice s) {
|
|
|
|
return softmax(a, get_reduce_axes(axis, a.ndim()), s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = none,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Perform the softmax along the given axis.
|
|
|
|
|
|
|
|
This operation is a numerically stable version of:
|
|
|
|
|
|
|
|
.. code-block::
|
|
|
|
|
|
|
|
exp(a) / sum(exp(a), axis, keepdims=True)
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or axes to compute
|
|
|
|
the softmax over. If unspecified this performs the softmax over
|
|
|
|
the full array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output of the softmax.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"concatenate",
|
|
|
|
[](const std::vector<array>& arrays,
|
|
|
|
std::optional<int> axis,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return concatenate(arrays, *axis, s);
|
|
|
|
} else {
|
|
|
|
return concatenate(arrays, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"arrays"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = 0,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Concatenate the arrays along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays.
|
|
|
|
axis (int, optional): Optional axis to concatenate along. If
|
|
|
|
unspecified defaults to ``0``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The concatenated array.
|
|
|
|
)pbdoc");
|
2023-12-15 05:21:19 +08:00
|
|
|
m.def(
|
|
|
|
"stack",
|
|
|
|
[](const std::vector<array>& arrays,
|
|
|
|
std::optional<int> axis,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis.has_value()) {
|
|
|
|
return stack(arrays, axis.value(), s);
|
|
|
|
} else {
|
|
|
|
return stack(arrays, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"arrays"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = 0,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
|
|
|
stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
|
|
|
Stacks the arrays along a new axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
arrays (list(array)): A list of arrays to stack.
|
|
|
|
axis (int, optional): The axis in the result array along which the
|
2023-12-16 09:30:34 +08:00
|
|
|
input arrays are stacked. Defaults to ``0``.
|
2023-12-15 05:21:19 +08:00
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The resulting stacked array.
|
|
|
|
)pbdoc");
|
2023-12-18 12:00:29 +08:00
|
|
|
m.def(
|
|
|
|
"clip",
|
|
|
|
[](const array& a,
|
|
|
|
const std::optional<ScalarOrArray>& min,
|
|
|
|
const std::optional<ScalarOrArray>& max,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
std::optional<array> min_ = std::nullopt;
|
|
|
|
std::optional<array> max_ = std::nullopt;
|
|
|
|
if (min) {
|
|
|
|
min_ = to_array(min.value());
|
|
|
|
}
|
|
|
|
if (max) {
|
|
|
|
max_ = to_array(max.value());
|
|
|
|
}
|
|
|
|
return clip(a, min_, max_, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"a_min"_a,
|
|
|
|
"a_max"_a,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
|
|
|
clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
|
|
|
Clip the values of the array between the given minimum and maximum.
|
|
|
|
|
|
|
|
If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge
|
|
|
|
is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``.
|
|
|
|
The input ``a`` and the limits must broadcast with one another.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
a_min (scalar or array or None): Minimum value to clip to.
|
|
|
|
a_max (scalar or array or None): Maximum value to clip to.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The clipped array.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"pad",
|
|
|
|
[](const array& a,
|
|
|
|
const std::variant<
|
|
|
|
int,
|
|
|
|
std::tuple<int>,
|
|
|
|
std::pair<int, int>,
|
|
|
|
std::vector<std::pair<int, int>>>& pad_width,
|
|
|
|
const ScalarOrArray& constant_value,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (auto pv = std::get_if<int>(&pad_width); pv) {
|
|
|
|
return pad(a, *pv, to_array(constant_value), s);
|
|
|
|
} else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {
|
|
|
|
return pad(a, std::get<0>(*pv), to_array(constant_value), s);
|
|
|
|
} else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {
|
|
|
|
return pad(a, *pv, to_array(constant_value), s);
|
|
|
|
} else {
|
|
|
|
auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);
|
|
|
|
if (v.size() == 1) {
|
|
|
|
return pad(a, v[0], to_array(constant_value), s);
|
|
|
|
} else {
|
|
|
|
return pad(a, v, to_array(constant_value), s);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"pad_width"_a,
|
|
|
|
"constant_values"_a = 0,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Pad an array with a constant value
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
|
|
|
|
values to add to the edges of each axis:``((before_1, after_1),
|
|
|
|
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
|
|
|
|
of integers is passed then ``(before_i, after_i)`` are all the same.
|
|
|
|
If a single integer or tuple with a single integer is passed then
|
|
|
|
all axes are extended by the same number on each side.
|
|
|
|
constant_value (array or scalar, optional): Optional constant value
|
|
|
|
to pad the edges of the array with.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The padded array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"as_strided",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<std::vector<int>> shape,
|
|
|
|
std::optional<std::vector<size_t>> strides,
|
|
|
|
size_t offset,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
std::vector<int> a_shape = (shape) ? *shape : a.shape();
|
|
|
|
std::vector<size_t> a_strides;
|
|
|
|
if (strides) {
|
|
|
|
a_strides = *strides;
|
|
|
|
} else {
|
|
|
|
std::fill_n(std::back_inserter(a_strides), a_shape.size(), 1);
|
|
|
|
for (int i = a_shape.size() - 1; i > 0; i--) {
|
|
|
|
a_strides[i - 1] = a_shape[i] * a_strides[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return as_strided(a, a_shape, a_strides, offset, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"shape"_a = none,
|
|
|
|
"strides"_a = none,
|
|
|
|
"offset"_a = 0,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Create a view into the array with the given shape and strides.
|
|
|
|
|
|
|
|
The resulting array will always be as if the provided array was row
|
|
|
|
contiguous regardless of the provided arrays storage order and current
|
|
|
|
strides.
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
Note that this function should be used with caution as it changes
|
|
|
|
the shape and strides of the array directly. This can lead to the
|
|
|
|
resulting array pointing to invalid memory locations which can
|
|
|
|
result into crashes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
shape (list(int), optional): The shape of the resulting array. If
|
|
|
|
None it defaults to ``a.shape()``.
|
|
|
|
strides (list(int), optional): The strides of the resulting array. If
|
|
|
|
None it defaults to the reverse exclusive cumulative product of
|
|
|
|
``a.shape()``.
|
|
|
|
offset (int): Skip that many elements from the beginning of the input
|
|
|
|
array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array which is the strided view of the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cumsum",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cumsum(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Return the cumulative sum of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative sum
|
|
|
|
over. If unspecified the cumulative sum of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative sum in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cumprod",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cumprod(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Return the cumulative product of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative product
|
|
|
|
over. If unspecified the cumulative product of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative product in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cummax",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cummax(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Return the cumulative maximum of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative maximum
|
|
|
|
over. If unspecified the cumulative maximum of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative maximum in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cummin",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cummin(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
py::kw_only(),
|
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Return the cumulative minimum of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative minimum
|
|
|
|
over. If unspecified the cumulative minimum of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative minimum in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"convolve",
|
|
|
|
[](const array& a,
|
|
|
|
const array& v,
|
|
|
|
const std::string& mode,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (a.ndim() != 1 || v.ndim() != 1) {
|
|
|
|
throw std::invalid_argument("[convolve] Inputs must be 1D.");
|
|
|
|
}
|
|
|
|
|
|
|
|
array in = a.size() < v.size() ? v : a;
|
|
|
|
array wt = a.size() < v.size() ? a : v;
|
|
|
|
wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s);
|
|
|
|
|
|
|
|
in = reshape(in, {1, -1, 1}, s);
|
|
|
|
wt = reshape(wt, {1, -1, 1}, s);
|
|
|
|
|
|
|
|
int padding = 0;
|
|
|
|
|
|
|
|
if (mode == "full") {
|
|
|
|
padding = wt.size() - 1;
|
|
|
|
} else if (mode == "valid") {
|
|
|
|
padding = 0;
|
|
|
|
} else if (mode == "same") {
|
|
|
|
// Odd sizes use symmetric padding
|
|
|
|
if (wt.size() % 2) {
|
|
|
|
padding = wt.size() / 2;
|
|
|
|
} else { // Even sizes use asymmetric padding
|
|
|
|
int pad_l = wt.size() / 2;
|
|
|
|
int pad_r = std::max(0, pad_l - 1);
|
|
|
|
in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s);
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
throw std::invalid_argument("[convolve] Invalid mode.");
|
|
|
|
}
|
|
|
|
|
|
|
|
array out = conv1d(
|
|
|
|
in,
|
|
|
|
wt,
|
|
|
|
/*stride = */ 1,
|
|
|
|
/*padding = */ padding,
|
|
|
|
/*dilation = */ 1,
|
|
|
|
/*groups = */ 1,
|
|
|
|
s);
|
|
|
|
|
|
|
|
return reshape(out, {-1}, s);
|
|
|
|
},
|
|
|
|
"a"_a,
|
|
|
|
"v"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"mode"_a = "full",
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
The discrete convolution of 1D arrays.
|
|
|
|
|
|
|
|
If ``v`` is longer than ``a``, then they are swapped.
|
|
|
|
The conv filter is flipped following signal processing convention.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): 1D Input array.
|
|
|
|
v (array): 1D Input array.
|
|
|
|
mode (str, optional): {'full', 'valid', 'same'}
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
    "conv1d",
    &conv1d,
    "input"_a,
    "weight"_a,
    py::pos_only(),
    "stride"_a = 1,
    "padding"_a = 0,
    "dilation"_a = 1,
    "groups"_a = 1,
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array

      1D convolution over an input with several channels.

      Note: Only the default ``groups=1`` is currently supported.

      Args:
          input (array): input array of shape ``(N, H, C_in)``
          weight (array): weight array of shape ``(C_out, H, C_in)``
          stride (int, optional): kernel stride. Default: ``1``.
          padding (int, optional): input padding. Default: ``0``.
          dilation (int, optional): kernel dilation. Default: ``1``.
          groups (int, optional): input feature groups. Default: ``1``.

      Returns:
          array: The convolved array.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conv2d",
|
|
|
|
[](const array& input,
|
|
|
|
const array& weight,
|
|
|
|
const std::variant<int, std::pair<int, int>>& stride,
|
|
|
|
const std::variant<int, std::pair<int, int>>& padding,
|
|
|
|
const std::variant<int, std::pair<int, int>>& dilation,
|
|
|
|
int groups,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
std::pair<int, int> stride_pair{1, 1};
|
|
|
|
std::pair<int, int> padding_pair{0, 0};
|
|
|
|
std::pair<int, int> dilation_pair{1, 1};
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
|
|
|
stride_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
stride_pair = std::get<std::pair<int, int>>(stride);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&padding); pv) {
|
|
|
|
padding_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
padding_pair = std::get<std::pair<int, int>>(padding);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&dilation); pv) {
|
|
|
|
dilation_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
|
|
|
}
|
|
|
|
|
|
|
|
return conv2d(
|
|
|
|
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
|
|
|
},
|
|
|
|
"input"_a,
|
|
|
|
"weight"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
2D convolution over an input with several channels
|
|
|
|
|
|
|
|
Note: Only the default ``groups=1`` is currently supported.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input (array): input array of shape ``(N, H, W, C_in)``
|
|
|
|
weight (array): weight array of shape ``(C_out, H, W, C_in)``
|
2023-12-12 07:01:41 +08:00
|
|
|
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
kernel strides. All spatial dimensions get the same stride if
|
2023-11-30 02:30:41 +08:00
|
|
|
only one number is specified. Default: ``1``.
|
|
|
|
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
2023-12-12 07:01:41 +08:00
|
|
|
symmetric input padding. All spatial dimensions get the same
|
2023-11-30 02:30:41 +08:00
|
|
|
padding if only one number is specified. Default: ``0``.
|
2023-12-12 07:01:41 +08:00
|
|
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
kernel dilation. All spatial dimensions get the same dilation
|
2023-11-30 02:30:41 +08:00
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
groups (int, optional): input feature groups. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
    "save",
    &mlx_save_helper,
    "file"_a,
    "arr"_a,
    py::pos_only(),
    "retain_graph"_a = true,
    py::kw_only(),
    R"pbdoc(
      save(file: str, arr: array, /, retain_graph: bool = True)

      Save the array to a binary file in ``.npy`` format.

      Args:
          file (str): File to which the array is saved.
          arr (array): Array to be saved.
          retain_graph (bool, optional): Optional argument to retain graph
            during array evaluation before saving. Default: ``True``.

    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"savez",
|
|
|
|
[](py::object file, py::args args, const py::kwargs& kwargs) {
|
|
|
|
mlx_savez_helper(file, args, kwargs, /*compressed=*/false);
|
|
|
|
},
|
|
|
|
"file"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
savez(file: str, *args, **kwargs)
|
2023-12-12 07:01:41 +08:00
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Save several arrays to a binary file in uncompressed ``.npz`` format.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
|
|
|
|
x = mx.ones((10, 10))
|
|
|
|
mx.savez("my_path.npz", x=x)
|
|
|
|
|
|
|
|
import mlx.nn as nn
|
|
|
|
from mlx.utils import tree_flatten
|
|
|
|
|
|
|
|
model = nn.TransformerEncoder(6, 128, 4)
|
|
|
|
flat_params = tree_flatten(model.parameters())
|
|
|
|
mx.savez("model.npz", **dict(flat_params))
|
|
|
|
|
|
|
|
Args:
|
|
|
|
file (file, str): Path to file to which the arrays are saved.
|
|
|
|
args (arrays): Arrays to be saved.
|
|
|
|
kwargs (arrays): Arrays to be saved. Each array will be saved
|
|
|
|
with the associated keyword as the output file name.
|
|
|
|
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"savez_compressed",
|
|
|
|
[](py::object file, py::args args, const py::kwargs& kwargs) {
|
|
|
|
mlx_savez_helper(file, args, kwargs, /*compressed=*/true);
|
|
|
|
},
|
|
|
|
"file"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
savez_compressed(file: str, *args, **kwargs)
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Save several arrays to a binary file in compressed ``.npz`` format.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
file (file, str): Path to file to which the arrays are saved.
|
|
|
|
args (arrays): Arrays to be saved.
|
|
|
|
kwargs (arrays): Arrays to be saved. Each array will be saved
|
|
|
|
with the associated keyword as the output file name.
|
|
|
|
|
|
|
|
)pbdoc");
|
|
|
|
// load(file, /, *, stream=None): read array(s) from a ``.npy``/``.npz`` file
// by forwarding to mlx_load_helper.
m.def(
    "load",
    &mlx_load_helper,
    "file"_a,
    py::pos_only(),
    py::kw_only(),
    "stream"_a = none,
    R"pbdoc(
      load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]

      Load array(s) from a binary file in ``.npy`` or ``.npz`` format.

      Args:
          file (file, str): File in which the array is saved.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

      Returns:
          array or dict: The loaded array if ``.npy`` file or a dict mapping
          name to array if ``.npz`` file.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"where",
|
|
|
|
[](const ScalarOrArray& condition,
|
|
|
|
const ScalarOrArray& x_,
|
|
|
|
const ScalarOrArray& y_,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
auto [x, y] = to_arrays(x_, y_);
|
|
|
|
return where(to_array(condition), x, y, s);
|
|
|
|
},
|
|
|
|
"condition"_a,
|
|
|
|
"x"_a,
|
|
|
|
"y"_a,
|
|
|
|
py::pos_only(),
|
|
|
|
py::kw_only(),
|
|
|
|
"stream"_a = none,
|
|
|
|
R"pbdoc(
|
2023-12-12 05:42:55 +08:00
|
|
|
where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
Select from ``x`` or ``y`` according to ``condition``.
|
|
|
|
|
|
|
|
The condition and input arrays must be the same shape or broadcastable
|
|
|
|
with each another.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
condition (array): The condition array.
|
|
|
|
x (array): The input selected from where condition is ``True``.
|
|
|
|
y (array): The input selected from where condition is ``False``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
result (array): The output containing elements selected from ``x`` and ``y``.
|
|
|
|
)pbdoc");
|
|
|
|
}
|