2024-03-19 11:12:25 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
#include <numeric>
|
|
|
|
#include <ostream>
|
|
|
|
#include <variant>
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
#include <nanobind/nanobind.h>
|
|
|
|
#include <nanobind/stl/optional.h>
|
|
|
|
#include <nanobind/stl/pair.h>
|
|
|
|
#include <nanobind/stl/string.h>
|
|
|
|
#include <nanobind/stl/tuple.h>
|
|
|
|
#include <nanobind/stl/variant.h>
|
|
|
|
#include <nanobind/stl/vector.h>
|
2023-11-30 02:30:41 +08:00
|
|
|
|
2024-07-26 00:36:44 +08:00
|
|
|
#include "mlx/einsum.h"
|
2023-11-30 02:30:41 +08:00
|
|
|
#include "mlx/ops.h"
|
|
|
|
#include "mlx/utils.h"
|
|
|
|
#include "python/src/load.h"
|
|
|
|
#include "python/src/utils.h"
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
namespace mx = mlx::core;
|
2024-03-19 11:12:25 +08:00
|
|
|
namespace nb = nanobind;
|
|
|
|
using namespace nb::literals;
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
// Numeric scalar argument: holds either an int or a double.
using Scalar = std::variant<int, double>;
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
// Map a Scalar onto the MLX dtype used for it: int32 when it holds an int,
// float32 otherwise (i.e. when it holds a double).
mx::Dtype scalar_to_dtype(Scalar scalar) {
  const bool holds_int = std::holds_alternative<int>(scalar);
  return holds_int ? mx::int32 : mx::float32;
}
|
|
|
|
|
|
|
|
// Return the Scalar's value as a double: an int alternative is widened with
// static_cast, a double alternative comes back unchanged.
double scalar_to_double(Scalar s) {
  return std::visit(
      [](auto value) { return static_cast<double>(value); }, s);
}
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
void init_ops(nb::module_& m) {
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"reshape",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::reshape,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"shape"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2024-03-25 06:03:27 +08:00
|
|
|
nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: "
|
2024-03-19 11:12:25 +08:00
|
|
|
"Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Reshape an array while preserving the size.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
shape (tuple(int)): New shape.
|
2024-01-27 01:27:31 +08:00
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
2023-11-30 02:30:41 +08:00
|
|
|
in which case the default stream of the default device is used.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The reshaped array.
|
|
|
|
)pbdoc");
|
2023-12-17 13:54:37 +08:00
|
|
|
m.def(
|
|
|
|
"flatten",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-12-17 13:54:37 +08:00
|
|
|
int start_axis,
|
|
|
|
int end_axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
const mx::StreamOrDevice& s) {
|
|
|
|
return mx::flatten(a, start_axis, end_axis);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-17 13:54:37 +08:00
|
|
|
"start_axis"_a = 0,
|
|
|
|
"end_axis"_a = -1,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig("def flatten(a: array, /, start_axis: int = 0, end_axis: int = "
|
|
|
|
"-1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-17 13:54:37 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Flatten an array.
|
|
|
|
|
2024-01-27 07:16:46 +08:00
|
|
|
The axes flattened will be between ``start_axis`` and ``end_axis``,
|
|
|
|
inclusive. Negative axes are supported. After converting negative axis to
|
|
|
|
positive, axes outside the valid range will be clamped to a valid value,
|
|
|
|
``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``.
|
|
|
|
|
2023-12-17 13:54:37 +08:00
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
|
|
|
|
end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
|
|
|
in which case the default stream of the default device is used.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The flattened array.
|
2024-01-27 07:16:46 +08:00
|
|
|
|
|
|
|
Example:
|
|
|
|
>>> a = mx.array([[1, 2], [3, 4]])
|
|
|
|
>>> mx.flatten(a)
|
|
|
|
array([1, 2, 3, 4], dtype=int32)
|
|
|
|
>>>
|
|
|
|
>>> mx.flatten(a, start_axis=0, end_axis=-1)
|
|
|
|
array([1, 2, 3, 4], dtype=int32)
|
2023-12-17 13:54:37 +08:00
|
|
|
)pbdoc");
|
2024-12-12 13:51:37 +08:00
|
|
|
m.def(
    "unflatten",
    // Qualified like the other bindings (e.g. ``&mx::reshape``); this file
    // only aliases ``mlx::core`` as ``mx``, so a bare ``unflatten`` does not
    // name the MLX op.
    &mx::unflatten,
    nb::arg(),
    "axis"_a,
    "shape"_a,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Unflatten an axis of an array to a shape.

      Args:
          a (array): Input array.
          axis (int): The axis to unflatten.
          shape (tuple(int)): The shape to unflatten to. At most one
            entry can be ``-1`` in which case the corresponding size will be
            inferred.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

      Returns:
          array: The unflattened array.

      Example:
          >>> a = mx.array([1, 2, 3, 4])
          >>> mx.unflatten(a, 0, (2, -1))
          array([[1, 2], [3, 4]], dtype=int32)
    )pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"squeeze",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (std::holds_alternative<std::monostate>(v)) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::squeeze(a, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else if (auto pv = std::get_if<int>(&v); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::squeeze(a, *pv, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::squeeze(a, std::get<std::vector<int>>(v), s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2024-03-25 06:03:27 +08:00
|
|
|
nb::sig(
|
|
|
|
"def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = "
|
|
|
|
"None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Remove length one axes from an array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or tuple(int), optional): Axes to remove. Defaults
|
2024-03-13 02:46:44 +08:00
|
|
|
to ``None`` in which case all size one axes are removed.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with size one axes removed.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"expand_dims",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::variant<int, std::vector<int>>& v,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (auto pv = std::get_if<int>(&v); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::expand_dims(a, *pv, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::expand_dims(a, std::get<std::vector<int>>(v), s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"axis"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2024-03-25 06:03:27 +08:00
|
|
|
nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], "
|
2024-03-19 11:12:25 +08:00
|
|
|
"*, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Add a size one dimension at the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axes (int or tuple(int)): The index of the inserted dimensions.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array with inserted dimensions.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"abs",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::abs(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise absolute value.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The absolute value of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sign",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::sign(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise sign.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The sign of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"negative",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::negative(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise negation.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The negative of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"add",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::add(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise addition.
|
|
|
|
|
|
|
|
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
|
|
|
|
can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The sum of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"subtract",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::subtract(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise subtraction.
|
|
|
|
|
|
|
|
Subtract one array from another with numpy-style broadcasting semantics. Either or both
|
|
|
|
input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The difference ``a - b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"multiply",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::multiply(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise multiplication.
|
|
|
|
|
|
|
|
Multiply two arrays with numpy-style broadcasting semantics. Either or both
|
|
|
|
input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The multiplication ``a * b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"divide",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::divide(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise division.
|
|
|
|
|
|
|
|
Divide two arrays with numpy-style broadcasting semantics. Either or both
|
|
|
|
input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The quotient ``a / b``.
|
|
|
|
)pbdoc");
|
2024-01-09 08:39:08 +08:00
|
|
|
m.def(
|
|
|
|
"divmod",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-01-09 08:39:08 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::divmod(a, b, s);
|
2024-01-09 08:39:08 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-09 08:39:08 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise quotient and remainder.
|
|
|
|
|
|
|
|
The fuction ``divmod(a, b)`` is equivalent to but faster than
|
|
|
|
``(a // b, a % b)``. The function uses numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
tuple(array, array): The quotient ``a // b`` and remainder ``a % b``.
|
|
|
|
)pbdoc");
|
2023-12-20 12:12:19 +08:00
|
|
|
m.def(
|
|
|
|
"floor_divide",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-12-20 12:12:19 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::floor_divide(a, b, s);
|
2023-12-20 12:12:19 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-20 12:12:19 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise integer division.
|
|
|
|
|
|
|
|
If either array is a floating point type then it is equivalent to
|
|
|
|
calling :func:`floor` after :func:`divide`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The quotient ``a // b``.
|
|
|
|
)pbdoc");
|
2023-12-09 07:08:52 +08:00
|
|
|
m.def(
|
|
|
|
"remainder",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-12-09 07:08:52 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::remainder(a, b, s);
|
2023-12-09 07:08:52 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-09 07:08:52 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise remainder of division.
|
|
|
|
|
|
|
|
Computes the remainder of dividing a with b with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be
|
|
|
|
scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The remainder of ``a // b``.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"equal",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::equal(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise equality.
|
|
|
|
|
|
|
|
Equality comparison on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a == b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"not_equal",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::not_equal(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise not equal.
|
|
|
|
|
|
|
|
Not equal comparison on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a != b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"less",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::less(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise less than.
|
|
|
|
|
|
|
|
Strict less than on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a < b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"less_equal",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::less_equal(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise less than or equal.
|
|
|
|
|
|
|
|
Less than or equal on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a <= b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"greater",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::greater(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise greater than.
|
|
|
|
|
|
|
|
Strict greater than on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a > b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"greater_equal",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::greater_equal(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise greater or equal.
|
|
|
|
|
|
|
|
Greater than or equal on two arrays with numpy-style broadcasting semantics.
|
|
|
|
Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The element-wise comparison ``a >= b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"array_equal",
|
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
bool equal_nan,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::array_equal(a, b, equal_nan, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"equal_nan"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Array equality check.
|
|
|
|
|
|
|
|
Compare two arrays for equality. Returns ``True`` if and only if the arrays
|
|
|
|
have the same shape and their values are equal. The arrays need not have
|
|
|
|
the same type to be considered equal.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
2024-01-26 12:47:06 +08:00
|
|
|
equal_nan (bool): If ``True``, NaNs are considered equal.
|
2023-11-30 02:30:41 +08:00
|
|
|
Defaults to ``False``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: A scalar boolean array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
    "matmul",
    &mx::matmul,
    nb::arg(),
    nb::arg(),
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Matrix multiplication.

      Perform the (possibly batched) matrix multiplication of two arrays. This function supports
      broadcasting for arrays with more than two dimensions.

      - If the first array is 1-D then a 1 is prepended to its shape to make it
        a matrix. Similarly if the second array is 1-D then a 1 is appended to its
        shape to make it a matrix. In either case the singleton dimension is removed
        from the result.
      - A batched matrix multiplication is performed if the arrays have more than
        2 dimensions. The matrix dimensions for the matrix product are the last
        two dimensions of each input.
      - All but the last two dimensions of each input are broadcast with one another using
        standard numpy-style broadcasting semantics.

      Args:
          a (array): Input array.
          b (array): Input array.

      Returns:
          array: The matrix product of ``a`` and ``b``.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"square",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::square(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise square.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The square of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sqrt",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::sqrt(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise square root.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The square root of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"rsqrt",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::rsqrt(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise reciprocal and square root.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: One over the square root of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"reciprocal",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::reciprocal(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise reciprocal.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The reciprocal of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"logical_not",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::logical_not(to_array(a), s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise logical not.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean array containing the logical not of ``a``.
|
|
|
|
)pbdoc");
|
2024-01-08 23:00:05 +08:00
|
|
|
m.def(
|
|
|
|
"logical_and",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) {
|
|
|
|
return mx::logical_and(to_array(a), to_array(b), s);
|
2024-01-08 23:00:05 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-08 23:00:05 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise logical and.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): First input array or scalar.
|
|
|
|
b (array): Second input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean array containing the logical and of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
|
|
|
|
m.def(
|
|
|
|
"logical_or",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) {
|
|
|
|
return mx::logical_or(to_array(a), to_array(b), s);
|
2024-01-08 23:00:05 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-08 23:00:05 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise logical or.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): First input array or scalar.
|
|
|
|
b (array): Second input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean array containing the logical or of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"logaddexp",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::logaddexp(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise log-add-exp.
|
|
|
|
|
|
|
|
This is a numerically stable log-add-exp of two arrays with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
The computation is is a numerically stable version of ``log(exp(a) + exp(b))``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The log-add-exp of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"exp",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::exp(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise exponential.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The exponential of ``a``.
|
|
|
|
)pbdoc");
|
2024-04-09 05:26:01 +08:00
|
|
|
m.def(
|
|
|
|
"expm1",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::expm1(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-04-09 05:26:01 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise exponential minus 1.
|
|
|
|
|
|
|
|
Computes ``exp(x) - 1`` with greater precision for small ``x``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The expm1 of ``a``.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"erf",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::erf(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise error function.
|
|
|
|
|
|
|
|
.. math::
|
2024-01-27 07:16:46 +08:00
|
|
|
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The error function of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"erfinv",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::erfinv(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse of :func:`erf`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse error function of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sin",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::sin(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise sine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The sine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cos",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::cos(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise cosine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The cosine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"tan",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::tan(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise tangent.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The tangent of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"arcsin",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::arcsin(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse sine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse sine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"arccos",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::arccos(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse cosine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse cosine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"arctan",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::arctan(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse tangent.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse tangent of ``a``.
|
|
|
|
)pbdoc");
|
2024-05-08 23:35:15 +08:00
|
|
|
m.def(
|
|
|
|
"arctan2",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::arctan2,
|
2024-05-08 23:35:15 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arctan2(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse tangent of the ratio of two arrays.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
b (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse tangent of the ratio of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"sinh",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::sinh(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise hyperbolic sine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The hyperbolic sine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cosh",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::cosh(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise hyperbolic cosine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The hyperbolic cosine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"tanh",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::tanh(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise hyperbolic tangent.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The hyperbolic tangent of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"arcsinh",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::arcsinh(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse hyperbolic sine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse hyperbolic sine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"arccosh",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::arccosh(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse hyperbolic cosine.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse hyperbolic cosine of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"arctanh",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::arctanh(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise inverse hyperbolic tangent.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The inverse hyperbolic tangent of ``a``.
|
|
|
|
)pbdoc");
|
2024-04-23 02:17:49 +08:00
|
|
|
m.def(
|
|
|
|
"degrees",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::degrees(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-04-23 02:17:49 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def degrees(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Convert angles from radians to degrees.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The angles in degrees.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"radians",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::radians(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-04-23 02:17:49 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def radians(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Convert angles from degrees to radians.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The angles in radians.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"log",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::log(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise natural logarithm.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The natural logarithm of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"log2",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::log2(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise base-2 logarithm.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The base-2 logarithm of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"log10",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::log10(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise base-10 logarithm.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The base-10 logarithm of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"log1p",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::log1p(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise natural log of one plus the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The natural logarithm of one plus ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"stop_gradient",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::stop_gradient,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Stop gradients from being computed.
|
|
|
|
|
|
|
|
The operation is the identity but it prevents gradients from flowing
|
|
|
|
through the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
2024-07-26 00:36:44 +08:00
|
|
|
array:
|
|
|
|
The unchanged input ``a`` but without gradient flowing
|
2023-11-30 02:30:41 +08:00
|
|
|
through it.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sigmoid",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::sigmoid(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise logistic sigmoid.
|
|
|
|
|
|
|
|
The logistic sigmoid function is:
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The logistic sigmoid of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"power",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::power(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise power operation.
|
|
|
|
|
|
|
|
Raise the elements of a to the powers in elements of b with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: Bases of ``a`` raised to powers in ``b``.
|
|
|
|
)pbdoc");
|
2023-12-12 05:42:55 +08:00
|
|
|
m.def(
|
|
|
|
"arange",
|
|
|
|
[](Scalar start,
|
|
|
|
Scalar stop,
|
2024-03-25 06:03:27 +08:00
|
|
|
const std::optional<Scalar>& step,
|
2024-12-12 07:45:39 +08:00
|
|
|
const std::optional<mx::Dtype>& dtype_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-12-12 05:42:55 +08:00
|
|
|
// Determine the final dtype based on input types
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::Dtype dtype = dtype_
|
2024-03-25 06:03:27 +08:00
|
|
|
? *dtype_
|
2024-12-12 07:45:39 +08:00
|
|
|
: mx::promote_types(
|
2023-12-12 05:42:55 +08:00
|
|
|
scalar_to_dtype(start),
|
2024-12-12 07:45:39 +08:00
|
|
|
step ? mx::promote_types(
|
2024-03-25 06:03:27 +08:00
|
|
|
scalar_to_dtype(stop), scalar_to_dtype(*step))
|
|
|
|
: scalar_to_dtype(stop));
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::arange(
|
2023-12-12 05:42:55 +08:00
|
|
|
scalar_to_double(start),
|
|
|
|
scalar_to_double(stop),
|
2024-03-25 06:03:27 +08:00
|
|
|
step ? scalar_to_double(*step) : 1.0,
|
2023-12-12 05:42:55 +08:00
|
|
|
dtype,
|
|
|
|
s);
|
|
|
|
},
|
|
|
|
"start"_a,
|
|
|
|
"stop"_a,
|
2024-03-25 06:03:27 +08:00
|
|
|
"step"_a = nb::none(),
|
2024-03-19 11:12:25 +08:00
|
|
|
"dtype"_a = nb::none(),
|
2024-10-15 03:45:49 +08:00
|
|
|
nb::kw_only(),
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-12 05:42:55 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Generates ranges of numbers.
|
|
|
|
|
|
|
|
Generate numbers in the half-open interval ``[start, stop)`` in
|
|
|
|
increments of ``step``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
start (float or int, optional): Starting value which defaults to ``0``.
|
|
|
|
stop (float or int): Stopping value.
|
|
|
|
step (float or int, optional): Increment which defaults to ``1``.
|
2024-10-15 07:17:03 +08:00
|
|
|
dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
|
2023-12-12 05:42:55 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The range of values.
|
|
|
|
|
|
|
|
Note:
|
|
|
|
Following the Numpy convention the actual increment used to
|
|
|
|
generate numbers is ``dtype(start + step) - dtype(start)``.
|
|
|
|
This can lead to unexpected results for example if `start + step`
|
|
|
|
is a fractional value and the `dtype` is integral.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
2024-03-25 06:03:27 +08:00
|
|
|
m.def(
|
|
|
|
"arange",
|
|
|
|
[](Scalar stop,
|
|
|
|
const std::optional<Scalar>& step,
|
2024-12-12 07:45:39 +08:00
|
|
|
const std::optional<mx::Dtype>& dtype_,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
mx::Dtype dtype = dtype_ ? *dtype_
|
2024-03-25 06:03:27 +08:00
|
|
|
: step
|
2024-12-12 07:45:39 +08:00
|
|
|
? mx::promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step))
|
2024-03-25 06:03:27 +08:00
|
|
|
: scalar_to_dtype(stop);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::arange(
|
2024-03-25 06:03:27 +08:00
|
|
|
0.0,
|
|
|
|
scalar_to_double(stop),
|
|
|
|
step ? scalar_to_double(*step) : 1.0,
|
|
|
|
dtype,
|
|
|
|
s);
|
|
|
|
},
|
|
|
|
"stop"_a,
|
|
|
|
"step"_a = nb::none(),
|
|
|
|
"dtype"_a = nb::none(),
|
2024-10-15 03:45:49 +08:00
|
|
|
nb::kw_only(),
|
2024-03-25 06:03:27 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
2023-12-19 11:57:55 +08:00
|
|
|
m.def(
|
|
|
|
"linspace",
|
2024-01-06 01:37:46 +08:00
|
|
|
[](Scalar start,
|
|
|
|
Scalar stop,
|
|
|
|
int num,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> dtype,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::linspace(
|
2024-01-06 01:37:46 +08:00
|
|
|
scalar_to_double(start),
|
|
|
|
scalar_to_double(stop),
|
|
|
|
num,
|
2024-12-12 07:45:39 +08:00
|
|
|
dtype.value_or(mx::float32),
|
2024-01-06 01:37:46 +08:00
|
|
|
s);
|
2023-12-19 11:57:55 +08:00
|
|
|
},
|
|
|
|
"start"_a,
|
|
|
|
"stop"_a,
|
|
|
|
"num"_a = 50,
|
2024-12-12 07:45:39 +08:00
|
|
|
"dtype"_a.none() = mx::float32,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-19 11:57:55 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
|
2023-12-19 11:57:55 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Args:
|
|
|
|
start (scalar): Starting value.
|
|
|
|
stop (scalar): Stopping value.
|
|
|
|
num (int, optional): Number of samples, defaults to ``50``.
|
|
|
|
dtype (Dtype, optional): Specifies the data type of the output,
|
|
|
|
default to ``float32``.
|
2023-12-19 11:57:55 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Returns:
|
|
|
|
array: The range of values.
|
2023-12-19 11:57:55 +08:00
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"take",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
const std::variant<nb::int_, mx::array>& indices,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::optional<int>& axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-12-10 10:57:38 +08:00
|
|
|
if (auto pv = std::get_if<nb::int_>(&indices); pv) {
|
|
|
|
auto idx = nb::cast<int>(*pv);
|
2024-12-12 07:45:39 +08:00
|
|
|
return axis ? mx::take(a, idx, axis.value(), s) : mx::take(a, idx, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
auto indices_ = std::get<mx::array>(indices);
|
|
|
|
return axis ? mx::take(a, indices_, axis.value(), s)
|
|
|
|
: mx::take(a, indices_, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"indices"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-09-27 06:58:03 +08:00
|
|
|
"def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Take elements along an axis.
|
|
|
|
|
|
|
|
The elements are taken from ``indices`` along the specified axis.
|
|
|
|
If the axis is not specified the array is treated as a flattened
|
|
|
|
1-D array prior to performing the take.
|
|
|
|
|
2024-01-02 13:08:17 +08:00
|
|
|
As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2024-09-27 06:58:03 +08:00
|
|
|
indices (int or array): Integer index or input array with integral type.
|
2023-11-30 02:30:41 +08:00
|
|
|
axis (int, optional): Axis along which to perform the take. If unspecified
|
|
|
|
the array is treated as a flattened 1-D vector.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The indexed values of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"take_along_axis",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
const mx::array& indices,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::optional<int>& axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::take_along_axis(a, indices, axis.value(), s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::take_along_axis(mx::reshape(a, {-1}, s), indices, 0, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"indices"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a.none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Take values along an axis at the specified indices.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
indices (array): Indices array. These should be broadcastable with
|
|
|
|
the input array excluding the `axis` dimension.
|
|
|
|
axis (int or None): Axis in the input to take the values from. If
|
|
|
|
``axis == None`` the array is flattened to 1D prior to the indexing
|
|
|
|
operation.
|
|
|
|
|
|
|
|
Returns:
|
2024-09-24 01:03:38 +08:00
|
|
|
array: The output array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"put_along_axis",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
const mx::array& indices,
|
|
|
|
const mx::array& values,
|
2024-09-24 01:03:38 +08:00
|
|
|
const std::optional<int>& axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-09-24 01:03:38 +08:00
|
|
|
if (axis.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::put_along_axis(a, indices, values, axis.value(), s);
|
2024-09-24 01:03:38 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::reshape(
|
|
|
|
mx::put_along_axis(
|
|
|
|
mx::reshape(a, {-1}, s), indices, values, 0, s),
|
2024-09-24 01:03:38 +08:00
|
|
|
a.shape(),
|
|
|
|
s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"indices"_a,
|
|
|
|
"values"_a,
|
|
|
|
"axis"_a.none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def put_along_axis(a: array, /, indices: array, values: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Put values along an axis at the specified indices.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Destination array.
|
|
|
|
indices (array): Indices array. These should be broadcastable with
|
|
|
|
the input array excluding the `axis` dimension.
|
|
|
|
values (array): Values array. These should be broadcastable with
|
|
|
|
the indices.
|
|
|
|
|
|
|
|
axis (int or None): Axis in the destination to put the values to. If
|
|
|
|
``axis == None`` the destination is flattened prior to the put
|
|
|
|
operation.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"full",
|
|
|
|
[](const std::variant<int, std::vector<int>>& shape,
|
|
|
|
const ScalarOrArray& vals,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> dtype,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (auto pv = std::get_if<int>(&shape); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::full({*pv}, to_array(vals, dtype), s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::full(
|
2023-11-30 02:30:41 +08:00
|
|
|
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"shape"_a,
|
|
|
|
"vals"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"dtype"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Construct an array with the given value.
|
|
|
|
|
|
|
|
Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
|
|
|
|
is an :obj:`array` it must be broadcastable to the given ``shape``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape (int or list(int)): The shape of the output array.
|
|
|
|
vals (float or int or array): Values to fill the array with.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type is inferred from ``vals``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the specified shape and values.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"zeros",
|
|
|
|
[](const std::variant<int, std::vector<int>>& shape,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> dtype,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
auto t = dtype.value_or(mx::float32);
|
2023-11-30 02:30:41 +08:00
|
|
|
if (auto pv = std::get_if<int>(&shape); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::zeros({*pv}, t, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::zeros(std::get<std::vector<int>>(shape), t, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
|
|
|
"shape"_a,
|
2024-12-12 07:45:39 +08:00
|
|
|
"dtype"_a.none() = mx::float32,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Construct an array of zeros.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape (int or list(int)): The shape of the output array.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type defaults to ``float32``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array of zeros with the specified shape.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"zeros_like",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::zeros_like,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
An array of zeros like the input.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): The input to take the shape and type from.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array filled with zeros.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"ones",
|
|
|
|
[](const std::variant<int, std::vector<int>>& shape,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> dtype,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
auto t = dtype.value_or(mx::float32);
|
2023-11-30 02:30:41 +08:00
|
|
|
if (auto pv = std::get_if<int>(&shape); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::ones({*pv}, t, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::ones(std::get<std::vector<int>>(shape), t, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
|
|
|
"shape"_a,
|
2024-12-12 07:45:39 +08:00
|
|
|
"dtype"_a.none() = mx::float32,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Construct an array of ones.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape (int or list(int)): The shape of the output array.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type defaults to ``float32``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array of ones with the specified shape.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"ones_like",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::ones_like,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
An array of ones like the input.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): The input to take the shape and type from.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array filled with ones.
|
|
|
|
)pbdoc");
|
2023-12-12 04:38:17 +08:00
|
|
|
m.def(
|
|
|
|
"eye",
|
|
|
|
[](int n,
|
2023-12-12 07:01:41 +08:00
|
|
|
std::optional<int> m,
|
|
|
|
int k,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> dtype,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::eye(n, m.value_or(n), k, dtype.value_or(mx::float32), s);
|
2023-12-12 04:38:17 +08:00
|
|
|
},
|
|
|
|
"n"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"m"_a = nb::none(),
|
2023-12-12 07:01:41 +08:00
|
|
|
"k"_a = 0,
|
2024-12-12 07:45:39 +08:00
|
|
|
"dtype"_a.none() = mx::float32,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-12 04:38:17 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Create an identity matrix or a general diagonal matrix.
|
2023-12-12 04:38:17 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Args:
|
|
|
|
n (int): The number of rows in the output.
|
|
|
|
m (int, optional): The number of columns in the output. Defaults to n.
|
|
|
|
k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).
|
|
|
|
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to None.
|
2023-12-12 04:38:17 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Returns:
|
|
|
|
array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
|
2023-12-12 07:01:41 +08:00
|
|
|
)pbdoc");
|
2023-12-12 04:38:17 +08:00
|
|
|
m.def(
|
|
|
|
"identity",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](int n, std::optional<mx::Dtype> dtype, mx::StreamOrDevice s) {
|
|
|
|
return mx::identity(n, dtype.value_or(mx::float32), s);
|
2023-12-12 07:01:41 +08:00
|
|
|
},
|
2023-12-12 04:38:17 +08:00
|
|
|
"n"_a,
|
2024-12-12 07:45:39 +08:00
|
|
|
"dtype"_a.none() = mx::float32,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-12 04:38:17 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Create a square identity matrix.
|
2023-12-12 04:38:17 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Args:
|
|
|
|
n (int): The number of rows and columns in the output.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. Defaults to float32.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to None.
|
2023-12-12 04:38:17 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Returns:
|
|
|
|
array: An identity matrix of size n x n.
|
2023-12-12 07:01:41 +08:00
|
|
|
)pbdoc");
|
2023-12-16 09:30:34 +08:00
|
|
|
m.def(
|
|
|
|
"tri",
|
2024-01-06 01:37:46 +08:00
|
|
|
[](int n,
|
|
|
|
std::optional<int> m,
|
|
|
|
int k,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> type,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::tri(n, m.value_or(n), k, type.value_or(mx::float32), s);
|
2023-12-16 09:30:34 +08:00
|
|
|
},
|
|
|
|
"n"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"m"_a = nb::none(),
|
2023-12-16 09:30:34 +08:00
|
|
|
"k"_a = 0,
|
2024-12-12 07:45:39 +08:00
|
|
|
"dtype"_a.none() = mx::float32,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-16 09:30:34 +08:00
|
|
|
R"pbdoc(
|
|
|
|
An array with ones at and below the given diagonal and zeros elsewhere.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
n (int): The number of rows in the output.
|
|
|
|
m (int, optional): The number of cols in the output. Defaults to ``None``.
|
|
|
|
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: Array with its lower triangle filled with ones and zeros elsewhere
|
|
|
|
)pbdoc");
|
|
|
|
// tril(x, k=0): zero out everything above the k-th diagonal.
m.def(
    "tril",
    &mx::tril,
    "x"_a,
    "k"_a = 0,
    nb::kw_only(),
    "stream"_a = nb::none(),
    // Keep the advertised default for `k` in sync with "k"_a = 0 above.
    nb::sig(
        "def tril(x: array, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Zeros the array above the given diagonal.

      Args:
          x (array): input array.
          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
          stream (Stream, optional): Stream or device. Defaults to ``None``.

      Returns:
          array: Array zeroed above the given diagonal
    )pbdoc");
|
2023-12-16 09:30:34 +08:00
|
|
|
// triu(x, k=0): zero out everything below the k-th diagonal.
m.def(
    "triu",
    &mx::triu,
    "x"_a,
    "k"_a = 0,
    nb::kw_only(),
    "stream"_a = nb::none(),
    // Keep the advertised default for `k` in sync with "k"_a = 0 above.
    nb::sig(
        "def triu(x: array, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Zeros the array below the given diagonal.

      Args:
          x (array): input array.
          k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
          stream (Stream, optional): Stream or device. Defaults to ``None``.

      Returns:
          array: Array zeroed below the given diagonal
    )pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"allclose",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::allclose,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"rtol"_a = 1e-5,
|
|
|
|
"atol"_a = 1e-8,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
2024-01-26 12:47:06 +08:00
|
|
|
"equal_nan"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Approximate comparison of two arrays.
|
|
|
|
|
2024-01-26 12:47:06 +08:00
|
|
|
Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
The arrays are considered equal if:
|
|
|
|
|
|
|
|
.. code-block::
|
|
|
|
|
|
|
|
all(abs(a - b) <= (atol + rtol * abs(b)))
|
|
|
|
|
|
|
|
Note unlike :func:`array_equal`, this function supports numpy-style
|
|
|
|
broadcasting.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
b (array): Input array.
|
|
|
|
rtol (float): Relative tolerance.
|
|
|
|
atol (float): Absolute tolerance.
|
2024-01-26 12:47:06 +08:00
|
|
|
equal_nan (bool): If ``True``, NaNs are considered equal.
|
|
|
|
Defaults to ``False``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean output scalar indicating if the arrays are close.
|
|
|
|
)pbdoc");
|
|
|
|
// isclose(a, b, rtol, atol, *, equal_nan): element-wise closeness test.
m.def(
    "isclose",
    &mx::isclose,
    nb::arg(),
    nb::arg(),
    "rtol"_a = 1e-5,
    "atol"_a = 1e-8,
    nb::kw_only(),
    "equal_nan"_a = false,
    "stream"_a = nb::none(),
    nb::sig(
        "def isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
    // Unlike allclose, the result is element-wise (see the summary line), so
    // the Returns section describes an array rather than a scalar.
    R"pbdoc(
      Returns a boolean array where two arrays are element-wise equal within a tolerance.

      Infinite values are considered equal if they have the same sign, NaN values are
      not equal unless ``equal_nan`` is ``True``.

      Two values are considered equal if:

      .. code-block::

          abs(a - b) <= (atol + rtol * abs(b))

      Note unlike :func:`array_equal`, this function supports numpy-style
      broadcasting.

      Args:
          a (array): Input array.
          b (array): Input array.
          rtol (float): Relative tolerance.
          atol (float): Absolute tolerance.
          equal_nan (bool): If ``True``, NaNs are considered equal.
            Defaults to ``False``.

      Returns:
          array: The boolean output array indicating which elements are close.
    )pbdoc");
|
|
|
|
m.def(
|
|
|
|
"all",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
An `and` reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"any",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
An `or` reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"minimum",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::minimum(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise minimum.
|
|
|
|
|
|
|
|
Take the element-wise min of two arrays with numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The min of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"maximum",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::maximum(a, b, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise maximum.
|
|
|
|
|
|
|
|
Take the element-wise max of two arrays with numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The max of ``a`` and ``b``.
|
|
|
|
)pbdoc");
|
2023-12-15 02:00:23 +08:00
|
|
|
m.def(
|
|
|
|
"floor",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::floor(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-15 02:00:23 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise floor.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The floor of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"ceil",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::ceil(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-15 02:00:23 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Element-wise ceil.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The ceil of ``a``.
|
|
|
|
)pbdoc");
|
2024-01-13 03:16:48 +08:00
|
|
|
m.def(
|
|
|
|
"isnan",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::isnan(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def isnan(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-13 03:16:48 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return a boolean array indicating which elements are NaN.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
2024-01-16 11:50:44 +08:00
|
|
|
array: The boolean array indicating which elements are NaN.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"isinf",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::isinf(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def isinf(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-16 11:50:44 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return a boolean array indicating which elements are +/- inifnity.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean array indicating which elements are +/- infinity.
|
2024-01-13 03:16:48 +08:00
|
|
|
)pbdoc");
|
2024-08-14 05:49:28 +08:00
|
|
|
m.def(
|
|
|
|
"isfinite",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::isfinite(to_array(a), s);
|
2024-08-14 05:49:28 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def isfinite(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Return a boolean array indicating which elements are finite.
|
|
|
|
|
|
|
|
An element is finite if it is not infinite or NaN.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The boolean array indicating which elements are finite.
|
|
|
|
)pbdoc");
|
2024-01-16 22:48:07 +08:00
|
|
|
m.def(
|
|
|
|
"isposinf",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::isposinf(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-16 22:48:07 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return a boolean array indicating which elements are positive infinity.
|
2024-01-20 06:06:05 +08:00
|
|
|
|
2024-01-16 22:48:07 +08:00
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
stream (Union[None, Stream, Device]): Optional stream or device.
|
2024-01-20 06:06:05 +08:00
|
|
|
|
2024-01-16 22:48:07 +08:00
|
|
|
Returns:
|
|
|
|
array: The boolean array indicating which elements are positive infinity.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"isneginf",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::isneginf(to_array(a), s);
|
2024-05-10 00:36:02 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-16 22:48:07 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return a boolean array indicating which elements are negative infinity.
|
2024-01-20 06:06:05 +08:00
|
|
|
|
2024-01-16 22:48:07 +08:00
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
stream (Union[None, Stream, Device]): Optional stream or device.
|
2024-01-20 06:06:05 +08:00
|
|
|
|
2024-01-16 22:48:07 +08:00
|
|
|
Returns:
|
|
|
|
array: The boolean array indicating which elements are negative infinity.
|
|
|
|
)pbdoc");
|
2023-12-15 04:59:12 +08:00
|
|
|
m.def(
|
|
|
|
"moveaxis",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::moveaxis,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-15 04:59:12 +08:00
|
|
|
"source"_a,
|
2024-01-02 13:08:17 +08:00
|
|
|
"destination"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-15 04:59:12 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Move an axis to a new position.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
source (int): Specifies the source axis.
|
|
|
|
destination (int): Specifies the destination axis.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array with the axis moved.
|
|
|
|
)pbdoc");
|
|
|
|
// swapaxes(a, axis1, axis2): exchange two axes of `a`.
m.def(
    "swapaxes",
    &mx::swapaxes,
    nb::arg(),
    "axis1"_a,
    "axis2"_a,
    nb::kw_only(),
    "stream"_a = nb::none(),
    // Annotation spacing normalized (`axis1: int`) to match the other
    // signatures; this string is what stubs and introspection show.
    nb::sig(
        "def swapaxes(a: array, /, axis1: int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Swap two axes of an array.

      Args:
          a (array): Input array.
          axis1 (int): Specifies the first axis.
          axis2 (int): Specifies the second axis.

      Returns:
          array: The array with swapped axes.
    )pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"transpose",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::optional<std::vector<int>>& axes,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axes.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::transpose(a, *axes, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::transpose(a, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axes"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Transpose the dimensions of the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axes (list(int), optional): Specifies the source axis for each axis
|
|
|
|
in the new array. The default is to reverse the axes.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The transposed array.
|
|
|
|
)pbdoc");
|
2024-07-27 01:40:49 +08:00
|
|
|
m.def(
|
|
|
|
"permute_dims",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2024-07-27 01:40:49 +08:00
|
|
|
const std::optional<std::vector<int>>& axes,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-07-27 01:40:49 +08:00
|
|
|
if (axes.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::transpose(a, *axes, s);
|
2024-07-27 01:40:49 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::transpose(a, s);
|
2024-07-27 01:40:49 +08:00
|
|
|
}
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"axes"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
See :func:`transpose`.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"sum",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
|
|
|
"array"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Sum reduce the array over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"prod",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
An product reduction over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"min",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::min(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
A `min` reduction over the given axes.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"max",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::max(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
A `max` reduction over the given axes.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"logsumexp",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
A `log-sum-exp` reduction over the given axes.
|
|
|
|
|
|
|
|
The log-sum-exp reduction is a numerically stable version of:
|
|
|
|
|
|
|
|
.. code-block::
|
|
|
|
|
|
|
|
log(sum(exp(a), axis))
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the corresponding axes reduced.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"mean",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Compute the mean(s) over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array of means.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"var",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
int ddof,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
|
|
|
"ddof"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Compute the variance(s) over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
ddof (int, optional): The divisor to compute the variance
|
|
|
|
is ``N - ddof``, defaults to 0.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array of variances.
|
|
|
|
)pbdoc");
|
2024-04-09 05:26:01 +08:00
|
|
|
m.def(
|
|
|
|
"std",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2024-04-09 05:26:01 +08:00
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
int ddof,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::std(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
|
2024-04-09 05:26:01 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
"keepdims"_a = false,
|
|
|
|
"ddof"_a = 0,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Compute the standard deviation(s) over the given axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or
|
|
|
|
axes to reduce over. If unspecified this defaults
|
|
|
|
to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
ddof (int, optional): The divisor to compute the variance
|
|
|
|
is ``N - ddof``, defaults to 0.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array of standard deviations.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"split",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::variant<int, std::vector<int>>& indices_or_sections,
|
|
|
|
int axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::split(a, *pv, axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::split(
|
2023-11-30 02:30:41 +08:00
|
|
|
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"indices_or_sections"_a,
|
|
|
|
"axis"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Split an array along a given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
indices_or_sections (int or list(int)): If ``indices_or_sections``
|
|
|
|
is an integer the array is split into that many sections of equal
|
|
|
|
size. An error is raised if this is not possible. If ``indices_or_sections``
|
|
|
|
is a list, the list contains the indices of the start of each subarray
|
|
|
|
along the given axis.
|
|
|
|
axis (int, optional): Axis to split along, defaults to `0`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
list(array): A list of split arrays.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argmin",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argmin(a, *axis, keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argmin(a, keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Indices of the minimum values along the axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int, optional): Optional axis to reduce over. If unspecified
|
|
|
|
this defaults to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
2024-01-24 14:24:41 +08:00
|
|
|
array: The ``uint32`` array with the indices of the minimum values.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argmax",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
|
|
|
bool keepdims,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argmax(a, *axis, keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argmax(a, keepdims, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Indices of the maximum values along the axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int, optional): Optional axis to reduce over. If unspecified
|
|
|
|
this defaults to reducing over the entire array.
|
|
|
|
keepdims (bool, optional): Keep reduced axes as
|
|
|
|
singleton dimensions, defaults to `False`.
|
|
|
|
|
|
|
|
Returns:
|
2024-01-24 14:24:41 +08:00
|
|
|
array: The ``uint32`` array with the indices of the maximum values.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"sort",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a, std::optional<int> axis, mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::sort(a, *axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::sort(a, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a.none() = -1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Returns a sorted copy of the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to sort over.
|
|
|
|
If ``None``, this sorts over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The sorted array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argsort",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a, std::optional<int> axis, mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argsort(a, *axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argsort(a, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a.none() = -1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Returns the indices that sort the array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to sort over.
|
|
|
|
If ``None``, this sorts over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
|
|
|
|
|
|
|
Returns:
|
2024-01-24 14:24:41 +08:00
|
|
|
array: The ``uint32`` array containing indices that sort the input.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"partition",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
int kth,
|
|
|
|
std::optional<int> axis,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::partition(a, kth, *axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::partition(a, kth, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"kth"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a.none() = -1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Returns a partitioned copy of the array such that the smaller ``kth``
|
|
|
|
elements are first.
|
|
|
|
|
|
|
|
The ordering of the elements in partitions is undefined.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
2023-12-12 07:01:41 +08:00
|
|
|
kth (int): Element at the ``kth`` index will be in its sorted
|
|
|
|
position in the output. All elements before the kth index will
|
|
|
|
be less or equal to the ``kth`` element and all elements after
|
2023-11-30 02:30:41 +08:00
|
|
|
will be greater or equal to the ``kth`` element in the output.
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to partition over.
|
|
|
|
If ``None``, this partitions over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to ``-1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The partitioned array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"argpartition",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
int kth,
|
|
|
|
std::optional<int> axis,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argpartition(a, kth, *axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::argpartition(a, kth, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"kth"_a,
|
2024-05-16 06:25:25 +08:00
|
|
|
"axis"_a.none() = -1,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Returns the indices that partition the array.
|
|
|
|
|
|
|
|
The ordering of the elements within a partition in given by the indices
|
|
|
|
is undefined.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
kth (int): Element index at the ``kth`` position in the output will
|
|
|
|
give the sorted position. All indices before the ``kth`` position
|
2023-12-12 07:01:41 +08:00
|
|
|
will be of elements less or equal to the element at the ``kth``
|
2023-11-30 02:30:41 +08:00
|
|
|
index and all indices after will be of elements greater or equal
|
|
|
|
to the element at the ``kth`` index.
|
2024-01-02 13:08:17 +08:00
|
|
|
axis (int or None, optional): Optional axis to partition over.
|
2023-12-12 07:01:41 +08:00
|
|
|
If ``None``, this partitions over the flattened array.
|
2023-11-30 02:30:41 +08:00
|
|
|
If unspecified, it defaults to ``-1``.
|
|
|
|
|
|
|
|
Returns:
|
2024-07-12 06:59:07 +08:00
|
|
|
array: The ``uint32`` array containing indices that partition the input.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"topk",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
int k,
|
|
|
|
std::optional<int> axis,
|
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::topk(a, k, *axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::topk(a, k, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"k"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a.none() = -1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Returns the ``k`` largest elements from the input along a given axis.
|
|
|
|
|
|
|
|
The elements will not necessarily be in sorted order.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
k (int): ``k`` top elements to be returned
|
2023-12-12 07:01:41 +08:00
|
|
|
axis (int or None, optional): Optional axis to select over.
|
|
|
|
If ``None``, this selects the top ``k`` elements over the
|
2023-11-30 02:30:41 +08:00
|
|
|
flattened array. If unspecified, it defaults to ``-1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The top ``k`` elements from the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"broadcast_to",
|
|
|
|
[](const ScalarOrArray& a,
|
|
|
|
const std::vector<int>& shape,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::broadcast_to(to_array(a), shape, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"shape"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Broadcast an array to the given shape.
|
|
|
|
|
|
|
|
The broadcasting semantics are the same as Numpy.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
shape (list(int)): The shape to broadcast to.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array with the new shape.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"softmax",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool precise,
|
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2024-04-04 23:32:35 +08:00
|
|
|
"precise"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Perform the softmax along the given axis.
|
|
|
|
|
|
|
|
This operation is a numerically stable version of:
|
|
|
|
|
|
|
|
.. code-block::
|
|
|
|
|
|
|
|
exp(a) / sum(exp(a), axis, keepdims=True)
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
axis (int or list(int), optional): Optional axis or axes to compute
|
|
|
|
the softmax over. If unspecified this performs the softmax over
|
|
|
|
the full array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output of the softmax.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"concatenate",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const std::vector<mx::array>& arrays,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::concatenate(arrays, *axis, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::concatenate(arrays, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a.none() = 0,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Concatenate the arrays along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays.
|
|
|
|
axis (int, optional): Optional axis to concatenate along. If
|
|
|
|
unspecified defaults to ``0``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The concatenated array.
|
|
|
|
)pbdoc");
|
2024-07-27 01:40:49 +08:00
|
|
|
m.def(
|
|
|
|
"concat",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const std::vector<mx::array>& arrays,
|
2024-07-27 01:40:49 +08:00
|
|
|
std::optional<int> axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-07-27 01:40:49 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::concatenate(arrays, *axis, s);
|
2024-07-27 01:40:49 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::concatenate(arrays, s);
|
2024-07-27 01:40:49 +08:00
|
|
|
}
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"axis"_a.none() = 0,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-07-27 01:40:49 +08:00
|
|
|
R"pbdoc(
|
|
|
|
See :func:`concatenate`.
|
|
|
|
)pbdoc");
|
2023-12-15 05:21:19 +08:00
|
|
|
m.def(
|
|
|
|
"stack",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const std::vector<mx::array>& arrays,
|
2023-12-15 05:21:19 +08:00
|
|
|
std::optional<int> axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-12-15 05:21:19 +08:00
|
|
|
if (axis.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::stack(arrays, axis.value(), s);
|
2023-12-15 05:21:19 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::stack(arrays, s);
|
2023-12-15 05:21:19 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-15 05:21:19 +08:00
|
|
|
"axis"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-15 05:21:19 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Stacks the arrays along a new axis.
|
2023-12-15 05:21:19 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Args:
|
|
|
|
arrays (list(array)): A list of arrays to stack.
|
|
|
|
axis (int, optional): The axis in the result array along which the
|
|
|
|
input arrays are stacked. Defaults to ``0``.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
2023-12-15 05:21:19 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Returns:
|
|
|
|
array: The resulting stacked array.
|
|
|
|
)pbdoc");
|
2024-04-10 02:43:08 +08:00
|
|
|
m.def(
|
|
|
|
"meshgrid",
|
|
|
|
[](nb::args arrays_,
|
|
|
|
bool sparse,
|
|
|
|
std::string indexing,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
std::vector<mx::array> arrays =
|
|
|
|
nb::cast<std::vector<mx::array>>(arrays_);
|
|
|
|
return mx::meshgrid(arrays, sparse, indexing, s);
|
2024-04-10 02:43:08 +08:00
|
|
|
},
|
|
|
|
"arrays"_a,
|
|
|
|
"sparse"_a = false,
|
|
|
|
"indexing"_a = "xy",
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
|
2024-04-10 02:43:08 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
|
|
|
|
|
|
|
Args:
|
|
|
|
arrays (array): Input arrays.
|
|
|
|
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
|
|
|
array has a single non-zero element. If ``False``, a dense grid is returned.
|
|
|
|
Defaults to ``False``.
|
|
|
|
indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays.
|
|
|
|
Defaults to ``'xy'``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
list(array): The output arrays.
|
|
|
|
)pbdoc");
|
2023-12-28 05:11:38 +08:00
|
|
|
m.def(
|
|
|
|
"repeat",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& array,
|
2023-12-28 05:11:38 +08:00
|
|
|
int repeats,
|
|
|
|
std::optional<int> axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-12-28 05:11:38 +08:00
|
|
|
if (axis.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::repeat(array, repeats, axis.value(), s);
|
2023-12-28 05:11:38 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::repeat(array, repeats, s);
|
2023-12-28 05:11:38 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-28 05:11:38 +08:00
|
|
|
"repeats"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-28 05:11:38 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Repeat an array along a specified axis.
|
2023-12-28 05:11:38 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Args:
|
|
|
|
array (array): Input array.
|
|
|
|
repeats (int): The number of repetitions for each element.
|
|
|
|
axis (int, optional): The axis in which to repeat the array along. If
|
|
|
|
unspecified it uses the flattened array of the input and repeats
|
|
|
|
along axis 0.
|
|
|
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
2023-12-28 05:11:38 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Returns:
|
|
|
|
array: The resulting repeated array.
|
|
|
|
)pbdoc");
|
2023-12-18 12:00:29 +08:00
|
|
|
m.def(
|
|
|
|
"clip",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-12-18 12:00:29 +08:00
|
|
|
const std::optional<ScalarOrArray>& min,
|
|
|
|
const std::optional<ScalarOrArray>& max,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
std::optional<mx::array> min_ = std::nullopt;
|
|
|
|
std::optional<mx::array> max_ = std::nullopt;
|
2023-12-18 12:00:29 +08:00
|
|
|
if (min) {
|
2024-09-15 07:09:09 +08:00
|
|
|
min_ = to_arrays(a, min.value()).second;
|
2023-12-18 12:00:29 +08:00
|
|
|
}
|
|
|
|
if (max) {
|
2024-09-15 07:09:09 +08:00
|
|
|
max_ = to_arrays(a, max.value()).second;
|
2023-12-18 12:00:29 +08:00
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::clip(a, min_, max_, s);
|
2023-12-18 12:00:29 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"a_min"_a.none(),
|
|
|
|
"a_max"_a.none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-18 12:00:29 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Clip the values of the array between the given minimum and maximum.
|
2023-12-18 12:00:29 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge
|
|
|
|
is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``.
|
|
|
|
The input ``a`` and the limits must broadcast with one another.
|
2023-12-18 12:00:29 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
a_min (scalar or array or None): Minimum value to clip to.
|
|
|
|
a_max (scalar or array or None): Maximum value to clip to.
|
2023-12-18 12:00:29 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
Returns:
|
|
|
|
array: The clipped array.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"pad",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::variant<
|
|
|
|
int,
|
|
|
|
std::tuple<int>,
|
|
|
|
std::pair<int, int>,
|
|
|
|
std::vector<std::pair<int, int>>>& pad_width,
|
2024-08-07 02:23:10 +08:00
|
|
|
const std::string mode,
|
2023-11-30 02:30:41 +08:00
|
|
|
const ScalarOrArray& constant_value,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (auto pv = std::get_if<int>(&pad_width); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::pad(a, *pv, to_array(constant_value), mode, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::pad(
|
|
|
|
a, std::get<0>(*pv), to_array(constant_value), mode, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::pad(a, *pv, to_array(constant_value), mode, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
|
|
|
auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);
|
|
|
|
if (v.size() == 1) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::pad(a, v[0], to_array(constant_value), mode, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::pad(a, v, to_array(constant_value), mode, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"pad_width"_a,
|
2024-08-07 02:23:10 +08:00
|
|
|
"mode"_a = "constant",
|
2023-11-30 02:30:41 +08:00
|
|
|
"constant_values"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Pad an array with a constant value
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
|
|
|
|
values to add to the edges of each axis:``((before_1, after_1),
|
|
|
|
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
|
|
|
|
of integers is passed then ``(before_i, after_i)`` are all the same.
|
|
|
|
If a single integer or tuple with a single integer is passed then
|
|
|
|
all axes are extended by the same number on each side.
|
2024-08-07 02:23:10 +08:00
|
|
|
mode: Padding mode. One of the following strings:
|
|
|
|
"constant" (default): Pads with a constant value.
|
|
|
|
"edge": Pads with the edge values of array.
|
2023-11-30 02:30:41 +08:00
|
|
|
constant_value (array or scalar, optional): Optional constant value
|
|
|
|
to pad the edges of the array with.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The padded array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"as_strided",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
std::optional<mx::Shape> shape,
|
|
|
|
std::optional<mx::Strides> strides,
|
2023-11-30 02:30:41 +08:00
|
|
|
size_t offset,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-12-10 03:09:02 +08:00
|
|
|
auto a_shape = (shape) ? *shape : a.shape();
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::Strides a_strides;
|
2023-11-30 02:30:41 +08:00
|
|
|
if (strides) {
|
|
|
|
a_strides = *strides;
|
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
a_strides = mx::Strides(a_shape.size(), 1);
|
2023-11-30 02:30:41 +08:00
|
|
|
for (int i = a_shape.size() - 1; i > 0; i--) {
|
|
|
|
a_strides[i - 1] = a_shape[i] * a_strides[i];
|
|
|
|
}
|
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::as_strided(a, a_shape, a_strides, offset, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"shape"_a = nb::none(),
|
|
|
|
"strides"_a = nb::none(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"offset"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Create a view into the array with the given shape and strides.
|
|
|
|
|
|
|
|
The resulting array will always be as if the provided array was row
|
|
|
|
contiguous regardless of the provided arrays storage order and current
|
|
|
|
strides.
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
Note that this function should be used with caution as it changes
|
|
|
|
the shape and strides of the array directly. This can lead to the
|
|
|
|
resulting array pointing to invalid memory locations which can
|
|
|
|
result into crashes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
shape (list(int), optional): The shape of the resulting array. If
|
|
|
|
None it defaults to ``a.shape()``.
|
|
|
|
strides (list(int), optional): The strides of the resulting array. If
|
|
|
|
None it defaults to the reverse exclusive cumulative product of
|
|
|
|
``a.shape()``.
|
|
|
|
offset (int): Skip that many elements from the beginning of the input
|
|
|
|
array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array which is the strided view of the input.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cumsum",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cumsum(a, *axis, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return the cumulative sum of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative sum
|
|
|
|
over. If unspecified the cumulative sum of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative sum in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
2024-07-26 00:36:44 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cumprod",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cumprod(a, *axis, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return the cumulative product of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative product
|
|
|
|
over. If unspecified the cumulative product of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative product in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
2024-07-26 00:36:44 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cummax",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cummax(a, *axis, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cummax(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return the cumulative maximum of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative maximum
|
|
|
|
over. If unspecified the cumulative maximum of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative maximum in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
2024-07-26 00:36:44 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"cummin",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2023-11-30 02:30:41 +08:00
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (axis) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cummin(a, *axis, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::cummin(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return the cumulative minimum of the elements along the given axis.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
axis (int, optional): Optional axis to compute the cumulative minimum
|
|
|
|
over. If unspecified the cumulative minimum of the flattened array is
|
|
|
|
returned.
|
|
|
|
reverse (bool): Perform the cumulative minimum in reverse.
|
|
|
|
inclusive (bool): The i-th element of the output includes the i-th
|
|
|
|
element of the input.
|
2024-07-26 00:36:44 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
2024-05-10 22:22:20 +08:00
|
|
|
m.def(
|
|
|
|
"conj",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::conjugate(to_array(a), s);
|
2024-05-10 22:22:20 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Return the elementwise complex conjugate of the input.
|
|
|
|
Alias for `mx.conjugate`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
2024-07-26 00:36:44 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2024-05-10 22:22:20 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conjugate",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::conjugate(to_array(a), s);
|
2024-05-10 22:22:20 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Return the elementwise complex conjugate of the input.
|
|
|
|
Alias for `mx.conj`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
2024-07-26 00:36:44 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2024-05-10 22:22:20 +08:00
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
m.def(
|
|
|
|
"convolve",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
const mx::array& v,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::string& mode,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
if (a.ndim() != 1 || v.ndim() != 1) {
|
|
|
|
throw std::invalid_argument("[convolve] Inputs must be 1D.");
|
|
|
|
}
|
|
|
|
|
2024-01-27 07:13:26 +08:00
|
|
|
if (a.size() == 0 || v.size() == 0) {
|
|
|
|
throw std::invalid_argument("[convolve] Inputs cannot be empty.");
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::array in = a.size() < v.size() ? v : a;
|
|
|
|
mx::array wt = a.size() < v.size() ? a : v;
|
|
|
|
wt = mx::slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
in = mx::reshape(in, {1, -1, 1}, s);
|
|
|
|
wt = mx::reshape(wt, {1, -1, 1}, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
int padding = 0;
|
|
|
|
|
|
|
|
if (mode == "full") {
|
|
|
|
padding = wt.size() - 1;
|
|
|
|
} else if (mode == "valid") {
|
|
|
|
padding = 0;
|
|
|
|
} else if (mode == "same") {
|
|
|
|
// Odd sizes use symmetric padding
|
|
|
|
if (wt.size() % 2) {
|
|
|
|
padding = wt.size() / 2;
|
|
|
|
} else { // Even sizes use asymmetric padding
|
|
|
|
int pad_l = wt.size() / 2;
|
|
|
|
int pad_r = std::max(0, pad_l - 1);
|
2024-12-12 07:45:39 +08:00
|
|
|
in = mx::pad(
|
|
|
|
in,
|
|
|
|
{{0, 0}, {pad_l, pad_r}, {0, 0}},
|
|
|
|
mx::array(0),
|
|
|
|
"constant",
|
|
|
|
s);
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
throw std::invalid_argument("[convolve] Invalid mode.");
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::array out = mx::conv1d(
|
2023-11-30 02:30:41 +08:00
|
|
|
in,
|
|
|
|
wt,
|
|
|
|
/*stride = */ 1,
|
|
|
|
/*padding = */ padding,
|
|
|
|
/*dilation = */ 1,
|
|
|
|
/*groups = */ 1,
|
|
|
|
s);
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::reshape(out, {-1}, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"mode"_a = "full",
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
R"(def convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array)"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
The discrete convolution of 1D arrays.
|
|
|
|
|
|
|
|
If ``v`` is longer than ``a``, then they are swapped.
|
|
|
|
The conv filter is flipped following signal processing convention.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): 1D Input array.
|
|
|
|
v (array): 1D Input array.
|
|
|
|
mode (str, optional): {'full', 'valid', 'same'}
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conv1d",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::conv1d,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
1D convolution over an input with several channels
|
|
|
|
|
|
|
|
Args:
|
2024-09-07 10:52:38 +08:00
|
|
|
input (array): Input array of shape ``(N, H, C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, H, C_in)``.
|
|
|
|
stride (int, optional): Kernel stride. Default: ``1``.
|
|
|
|
padding (int, optional): Input padding. Default: ``0``.
|
|
|
|
dilation (int, optional): Kernel dilation. Default: ``1``.
|
|
|
|
groups (int, optional): Input feature groups. Default: ``1``.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conv2d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& input,
|
|
|
|
const mx::array& weight,
|
2023-11-30 02:30:41 +08:00
|
|
|
const std::variant<int, std::pair<int, int>>& stride,
|
|
|
|
const std::variant<int, std::pair<int, int>>& padding,
|
|
|
|
const std::variant<int, std::pair<int, int>>& dilation,
|
|
|
|
int groups,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
std::pair<int, int> stride_pair{1, 1};
|
|
|
|
std::pair<int, int> padding_pair{0, 0};
|
|
|
|
std::pair<int, int> dilation_pair{1, 1};
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
|
|
|
stride_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
stride_pair = std::get<std::pair<int, int>>(stride);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&padding); pv) {
|
|
|
|
padding_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
padding_pair = std::get<std::pair<int, int>>(padding);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&dilation); pv) {
|
|
|
|
dilation_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::conv2d(
|
2023-11-30 02:30:41 +08:00
|
|
|
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2023-11-30 02:30:41 +08:00
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
2D convolution over an input with several channels
|
|
|
|
|
|
|
|
Args:
|
2024-09-07 10:52:38 +08:00
|
|
|
input (array): Input array of shape ``(N, H, W, C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
|
2023-12-12 07:01:41 +08:00
|
|
|
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
kernel strides. All spatial dimensions get the same stride if
|
2023-11-30 02:30:41 +08:00
|
|
|
only one number is specified. Default: ``1``.
|
|
|
|
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
2023-12-12 07:01:41 +08:00
|
|
|
symmetric input padding. All spatial dimensions get the same
|
2023-11-30 02:30:41 +08:00
|
|
|
padding if only one number is specified. Default: ``0``.
|
2023-12-12 07:01:41 +08:00
|
|
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
kernel dilation. All spatial dimensions get the same dilation
|
2023-11-30 02:30:41 +08:00
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
groups (int, optional): input feature groups. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
2024-05-11 21:15:02 +08:00
|
|
|
"conv3d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& input,
|
|
|
|
const mx::array& weight,
|
2024-05-11 21:15:02 +08:00
|
|
|
const std::variant<int, std::tuple<int, int, int>>& stride,
|
|
|
|
const std::variant<int, std::tuple<int, int, int>>& padding,
|
|
|
|
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
|
|
|
int groups,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-05-11 21:15:02 +08:00
|
|
|
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
|
|
|
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
|
|
|
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
|
|
|
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
|
|
|
} else {
|
|
|
|
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&padding); pv) {
|
|
|
|
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
|
|
|
} else {
|
|
|
|
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&dilation); pv) {
|
|
|
|
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
|
|
|
} else {
|
|
|
|
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::conv3d(
|
2024-05-11 21:15:02 +08:00
|
|
|
input,
|
|
|
|
weight,
|
|
|
|
stride_tuple,
|
|
|
|
padding_tuple,
|
|
|
|
dilation_tuple,
|
|
|
|
groups,
|
|
|
|
s);
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-05-11 21:15:02 +08:00
|
|
|
R"pbdoc(
|
|
|
|
3D convolution over an input with several channels
|
|
|
|
|
|
|
|
Note: Only the default ``groups=1`` is currently supported.
|
|
|
|
|
|
|
|
Args:
|
2024-09-07 10:52:38 +08:00
|
|
|
input (array): Input array of shape ``(N, D, H, W, C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
|
|
|
|
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
|
|
|
kernel strides. All spatial dimensions get the same stride if
|
|
|
|
only one number is specified. Default: ``1``.
|
|
|
|
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
|
|
|
symmetric input padding. All spatial dimensions get the same
|
|
|
|
padding if only one number is specified. Default: ``0``.
|
|
|
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
|
|
|
kernel dilation. All spatial dimensions get the same dilation
|
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
groups (int, optional): input feature groups. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conv_transpose1d",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::conv_transpose1d,
|
2024-09-07 10:52:38 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
1D transposed convolution over an input with several channels
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input (array): Input array of shape ``(N, H, C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, H, C_in)``.
|
|
|
|
stride (int, optional): Kernel stride. Default: ``1``.
|
|
|
|
padding (int, optional): Input padding. Default: ``0``.
|
|
|
|
dilation (int, optional): Kernel dilation. Default: ``1``.
|
|
|
|
groups (int, optional): Input feature groups. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conv_transpose2d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& input,
|
|
|
|
const mx::array& weight,
|
2024-09-07 10:52:38 +08:00
|
|
|
const std::variant<int, std::pair<int, int>>& stride,
|
|
|
|
const std::variant<int, std::pair<int, int>>& padding,
|
|
|
|
const std::variant<int, std::pair<int, int>>& dilation,
|
|
|
|
int groups,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-09-07 10:52:38 +08:00
|
|
|
std::pair<int, int> stride_pair{1, 1};
|
|
|
|
std::pair<int, int> padding_pair{0, 0};
|
|
|
|
std::pair<int, int> dilation_pair{1, 1};
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
|
|
|
stride_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
stride_pair = std::get<std::pair<int, int>>(stride);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&padding); pv) {
|
|
|
|
padding_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
padding_pair = std::get<std::pair<int, int>>(padding);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&dilation); pv) {
|
|
|
|
dilation_pair = std::pair<int, int>{*pv, *pv};
|
|
|
|
} else {
|
|
|
|
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::conv_transpose2d(
|
2024-09-07 10:52:38 +08:00
|
|
|
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
2D transposed convolution over an input with several channels
|
|
|
|
|
|
|
|
Note: Only the default ``groups=1`` is currently supported.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input (array): Input array of shape ``(N, H, W, C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
|
|
|
|
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
kernel strides. All spatial dimensions get the same stride if
|
|
|
|
only one number is specified. Default: ``1``.
|
|
|
|
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
symmetric input padding. All spatial dimensions get the same
|
|
|
|
padding if only one number is specified. Default: ``0``.
|
|
|
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
|
|
|
kernel dilation. All spatial dimensions get the same dilation
|
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
groups (int, optional): input feature groups. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"conv_transpose3d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& input,
|
|
|
|
const mx::array& weight,
|
2024-09-07 10:52:38 +08:00
|
|
|
const std::variant<int, std::tuple<int, int, int>>& stride,
|
|
|
|
const std::variant<int, std::tuple<int, int, int>>& padding,
|
|
|
|
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
|
|
|
int groups,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-09-07 10:52:38 +08:00
|
|
|
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
|
|
|
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
|
|
|
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
|
|
|
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
|
|
|
} else {
|
|
|
|
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&padding); pv) {
|
|
|
|
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
|
|
|
} else {
|
|
|
|
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&dilation); pv) {
|
|
|
|
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
|
|
|
} else {
|
|
|
|
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::conv_transpose3d(
|
2024-09-07 10:52:38 +08:00
|
|
|
input,
|
|
|
|
weight,
|
|
|
|
stride_tuple,
|
|
|
|
padding_tuple,
|
|
|
|
dilation_tuple,
|
|
|
|
groups,
|
|
|
|
s);
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
3D transposed convolution over an input with several channels
|
|
|
|
|
|
|
|
Note: Only the default ``groups=1`` is currently supported.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input (array): Input array of shape ``(N, D, H, W, C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
|
2024-05-11 21:15:02 +08:00
|
|
|
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
|
|
|
kernel strides. All spatial dimensions get the same stride if
|
|
|
|
only one number is specified. Default: ``1``.
|
|
|
|
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
|
|
|
symmetric input padding. All spatial dimensions get the same
|
|
|
|
padding if only one number is specified. Default: ``0``.
|
|
|
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
|
|
|
kernel dilation. All spatial dimensions get the same dilation
|
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
groups (int, optional): input feature groups. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
2024-02-29 12:11:16 +08:00
|
|
|
"conv_general",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& input,
|
|
|
|
const mx::array& weight,
|
2024-02-29 12:11:16 +08:00
|
|
|
const std::variant<int, std::vector<int>>& stride,
|
|
|
|
const std::variant<
|
|
|
|
int,
|
|
|
|
std::vector<int>,
|
|
|
|
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
|
|
|
const std::variant<int, std::vector<int>>& kernel_dilation,
|
|
|
|
const std::variant<int, std::vector<int>>& input_dilation,
|
|
|
|
int groups,
|
|
|
|
bool flip,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-02-29 12:11:16 +08:00
|
|
|
std::vector<int> stride_vec;
|
|
|
|
std::vector<int> padding_lo_vec;
|
|
|
|
std::vector<int> padding_hi_vec;
|
|
|
|
std::vector<int> kernel_dilation_vec;
|
|
|
|
std::vector<int> input_dilation_vec;
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
|
|
|
stride_vec.push_back(*pv);
|
|
|
|
} else {
|
|
|
|
stride_vec = std::get<std::vector<int>>(stride);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&padding); pv) {
|
|
|
|
padding_lo_vec.push_back(*pv);
|
|
|
|
padding_hi_vec.push_back(*pv);
|
|
|
|
} else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
|
|
|
|
padding_lo_vec = *pv;
|
|
|
|
padding_hi_vec = *pv;
|
|
|
|
} else {
|
|
|
|
auto [pl, ph] =
|
|
|
|
std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
|
|
|
|
padding_lo_vec = pl;
|
|
|
|
padding_hi_vec = ph;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
|
|
|
|
kernel_dilation_vec.push_back(*pv);
|
|
|
|
} else {
|
|
|
|
kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (auto pv = std::get_if<int>(&input_dilation); pv) {
|
|
|
|
input_dilation_vec.push_back(*pv);
|
|
|
|
} else {
|
|
|
|
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
|
|
|
}
|
|
|
|
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::conv_general(
|
2024-05-14 21:50:09 +08:00
|
|
|
/* array input = */ std::move(input),
|
|
|
|
/* array weight = */ std::move(weight),
|
|
|
|
/* std::vector<int> stride = */ std::move(stride_vec),
|
|
|
|
/* std::vector<int> padding_lo = */ std::move(padding_lo_vec),
|
|
|
|
/* std::vector<int> padding_hi = */ std::move(padding_hi_vec),
|
|
|
|
/* std::vector<int> kernel_dilation = */
|
|
|
|
std::move(kernel_dilation_vec),
|
|
|
|
/* std::vector<int> input_dilation = */
|
|
|
|
std::move(input_dilation_vec),
|
2024-02-29 12:11:16 +08:00
|
|
|
/* int groups = */ groups,
|
|
|
|
/* bool flip = */ flip,
|
|
|
|
s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2024-02-29 12:11:16 +08:00
|
|
|
"stride"_a = 1,
|
|
|
|
"padding"_a = 0,
|
|
|
|
"kernel_dilation"_a = 1,
|
|
|
|
"input_dilation"_a = 1,
|
|
|
|
"groups"_a = 1,
|
|
|
|
"flip"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-02-29 12:11:16 +08:00
|
|
|
R"pbdoc(
|
|
|
|
General convolution over an input with several channels
|
|
|
|
|
|
|
|
Args:
|
2024-09-07 10:52:38 +08:00
|
|
|
input (array): Input array of shape ``(N, ..., C_in)``.
|
|
|
|
weight (array): Weight array of shape ``(C_out, ..., C_in)``.
|
2024-02-29 12:11:16 +08:00
|
|
|
stride (int or list(int), optional): :obj:`list` with kernel strides.
|
|
|
|
All spatial dimensions get the same stride if
|
|
|
|
only one number is specified. Default: ``1``.
|
|
|
|
padding (int, list(int), or tuple(list(int), list(int)), optional):
|
|
|
|
:obj:`list` with input padding. All spatial dimensions get the same
|
|
|
|
padding if only one number is specified. Default: ``0``.
|
|
|
|
kernel_dilation (int or list(int), optional): :obj:`list` with
|
|
|
|
kernel dilation. All spatial dimensions get the same dilation
|
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
input_dilation (int or list(int), optional): :obj:`list` with
|
|
|
|
input dilation. All spatial dimensions get the same dilation
|
|
|
|
if only one number is specified. Default: ``1``
|
|
|
|
groups (int, optional): Input feature groups. Default: ``1``.
|
|
|
|
flip (bool, optional): Flip the order in which the spatial dimensions of
|
|
|
|
the weights are processed. Performs the cross-correlation operator when
|
|
|
|
``flip`` is ``False`` and the convolution operator otherwise.
|
|
|
|
Default: ``False``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The convolved array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
2023-11-30 02:30:41 +08:00
|
|
|
"save",
|
|
|
|
&mlx_save_helper,
|
|
|
|
"file"_a,
|
|
|
|
"arr"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::sig("def save(file: str, arr: array) -> None"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Save the array to a binary file in ``.npy`` format.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
file (str): File to which the array is saved
|
|
|
|
arr (array): Array to be saved.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"savez",
|
2024-03-19 11:12:25 +08:00
|
|
|
[](nb::object file, nb::args args, const nb::kwargs& kwargs) {
|
|
|
|
mlx_savez_helper(file, args, kwargs, /* compressed= */ false);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
|
|
|
"file"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"args"_a,
|
|
|
|
"kwargs"_a,
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
2024-03-19 11:12:25 +08:00
|
|
|
Save several arrays to a binary file in uncompressed ``.npz``
|
|
|
|
format.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
|
|
|
|
x = mx.ones((10, 10))
|
|
|
|
mx.savez("my_path.npz", x=x)
|
|
|
|
|
|
|
|
import mlx.nn as nn
|
|
|
|
from mlx.utils import tree_flatten
|
|
|
|
|
|
|
|
model = nn.TransformerEncoder(6, 128, 4)
|
|
|
|
flat_params = tree_flatten(model.parameters())
|
|
|
|
mx.savez("model.npz", **dict(flat_params))
|
|
|
|
|
|
|
|
Args:
|
|
|
|
file (file, str): Path to file to which the arrays are saved.
|
|
|
|
args (arrays): Arrays to be saved.
|
|
|
|
kwargs (arrays): Arrays to be saved. Each array will be saved
|
|
|
|
with the associated keyword as the output file name.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"savez_compressed",
|
2024-03-19 11:12:25 +08:00
|
|
|
[](nb::object file, nb::args args, const nb::kwargs& kwargs) {
|
2023-11-30 02:30:41 +08:00
|
|
|
mlx_savez_helper(file, args, kwargs, /*compressed=*/true);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"args"_a,
|
|
|
|
"kwargs"_a,
|
|
|
|
nb::sig("def savez_compressed(file: str, *args, **kwargs)"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Save several arrays to a binary file in compressed ``.npz`` format.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
file (file, str): Path to file to which the arrays are saved.
|
|
|
|
args (arrays): Arrays to be saved.
|
|
|
|
kwargs (arrays): Arrays to be saved. Each array will be saved
|
|
|
|
with the associated keyword as the output file name.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"load",
|
|
|
|
&mlx_load_helper,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
"format"_a = nb::none(),
|
2024-01-20 06:06:05 +08:00
|
|
|
"return_metadata"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
2024-01-11 05:22:48 +08:00
|
|
|
Load array(s) from a binary file.
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and
|
|
|
|
``.gguf``.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Args:
|
2023-12-27 18:06:55 +08:00
|
|
|
file (file, str): File in which the array is saved.
|
2024-03-19 11:12:25 +08:00
|
|
|
format (str, optional): Format of the file. If ``None``, the
|
2024-07-26 00:36:44 +08:00
|
|
|
format is inferred from the file extension. Supported formats:
|
|
|
|
``npy``, ``npz``, and ``safetensors``. Default: ``None``.
|
2024-03-19 11:12:25 +08:00
|
|
|
return_metadata (bool, optional): Load the metadata for formats
|
2024-07-26 00:36:44 +08:00
|
|
|
which support matadata. The metadata will be returned as an
|
|
|
|
additional dictionary. Default: ``False``.
|
2023-11-30 02:30:41 +08:00
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array or dict:
|
2024-03-19 11:12:25 +08:00
|
|
|
A single array if loading from a ``.npy`` file or a dict
|
|
|
|
mapping names to arrays if loading from a ``.npz`` or
|
|
|
|
``.safetensors`` file. If ``return_metadata` is ``True`` an
|
|
|
|
additional dictionary of metadata will be returned.
|
2024-01-11 05:22:48 +08:00
|
|
|
|
|
|
|
Warning:
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
When loading unsupported quantization formats from GGUF, tensors
|
|
|
|
will automatically cast to ``mx.float16``
|
2023-12-27 18:06:55 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"save_safetensors",
|
|
|
|
&mlx_save_safetensor_helper,
|
|
|
|
"file"_a,
|
|
|
|
"arrays"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"metadata"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
|
2023-12-27 18:06:55 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Save array(s) to a binary file in ``.safetensors`` format.
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
See the `Safetensors documentation
|
|
|
|
<https://huggingface.co/docs/safetensors/index>`_ for more
|
|
|
|
information on the format.
|
2023-12-27 18:06:55 +08:00
|
|
|
|
|
|
|
Args:
|
2024-01-11 05:22:48 +08:00
|
|
|
file (file, str): File in which the array is saved.
|
2024-03-19 11:12:25 +08:00
|
|
|
arrays (dict(str, array)): The dictionary of names to arrays to
|
|
|
|
be saved. metadata (dict(str, str), optional): The dictionary of
|
|
|
|
metadata to be saved.
|
2024-01-11 05:22:48 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"save_gguf",
|
|
|
|
&mlx_save_gguf_helper,
|
|
|
|
"file"_a,
|
|
|
|
"arrays"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"metadata"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
|
2024-01-11 05:22:48 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Save array(s) to a binary file in ``.gguf`` format.
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
See the `GGUF documentation
|
|
|
|
<https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for
|
2024-01-11 05:22:48 +08:00
|
|
|
more information on the format.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
file (file, str): File in which the array is saved.
|
2024-03-19 11:12:25 +08:00
|
|
|
arrays (dict(str, array)): The dictionary of names to arrays to
|
2024-07-26 00:36:44 +08:00
|
|
|
be saved.
|
|
|
|
metadata (dict(str, Union[array, str, list(str)])): The dictionary
|
|
|
|
of metadata to be saved. The values can be a scalar or 1D
|
2024-03-19 11:12:25 +08:00
|
|
|
obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"where",
|
|
|
|
[](const ScalarOrArray& condition,
|
|
|
|
const ScalarOrArray& x_,
|
|
|
|
const ScalarOrArray& y_,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2023-11-30 02:30:41 +08:00
|
|
|
auto [x, y] = to_arrays(x_, y_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::where(to_array(condition), x, y, s);
|
2023-11-30 02:30:41 +08:00
|
|
|
},
|
|
|
|
"condition"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-11-30 02:30:41 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Select from ``x`` or ``y`` according to ``condition``.
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
The condition and input arrays must be the same shape or
|
|
|
|
broadcastable with each another.
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
condition (array): The condition array.
|
|
|
|
x (array): The input selected from where condition is ``True``.
|
|
|
|
y (array): The input selected from where condition is ``False``.
|
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The output containing elements selected from
|
2024-03-19 11:12:25 +08:00
|
|
|
``x`` and ``y``.
|
2023-11-30 02:30:41 +08:00
|
|
|
)pbdoc");
|
2024-07-26 00:57:37 +08:00
|
|
|
m.def(
|
|
|
|
"nan_to_num",
|
|
|
|
[](const ScalarOrArray& a,
|
|
|
|
float nan,
|
|
|
|
std::optional<float>& posinf,
|
|
|
|
std::optional<float>& neginf,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
return mx::nan_to_num(to_array(a), nan, posinf, neginf, s);
|
2024-07-26 00:57:37 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"nan"_a = 0.0f,
|
|
|
|
"posinf"_a = nb::none(),
|
|
|
|
"neginf"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def nan_to_num(a: Union[scalar, array], nan: float = 0, posinf: Optional[float] = None, neginf: Optional[float] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Replace NaN and Inf values with finite numbers.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
nan (float, optional): Value to replace NaN with. Default: ``0``.
|
|
|
|
posinf (float, optional): Value to replace positive infinities
|
|
|
|
with. If ``None``, defaults to largest finite value for the
|
|
|
|
given data type. Default: ``None``.
|
|
|
|
neginf (float, optional): Value to replace negative infinities
|
|
|
|
with. If ``None``, defaults to the negative of the largest
|
|
|
|
finite value for the given data type. Default: ``None``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: Output array with NaN and Inf replaced.
|
|
|
|
)pbdoc");
|
2023-12-19 03:32:48 +08:00
|
|
|
m.def(
|
|
|
|
"round",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, int decimals, mx::StreamOrDevice s) {
|
|
|
|
return mx::round(to_array(a), decimals, s);
|
2023-12-19 03:32:48 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-19 03:32:48 +08:00
|
|
|
"decimals"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-19 03:32:48 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Round to the given number of decimals.
|
|
|
|
|
2024-01-02 13:08:17 +08:00
|
|
|
Basically performs:
|
2023-12-19 03:32:48 +08:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
s = 10**decimals
|
|
|
|
x = round(x * s) / s
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
decimals (int): Number of decimal places to round to. (default: 0)
|
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: An array of the same type as ``a`` rounded to the
|
2024-03-19 11:12:25 +08:00
|
|
|
given number of decimals.
|
2023-12-19 03:32:48 +08:00
|
|
|
)pbdoc");
|
2023-12-19 15:18:57 +08:00
|
|
|
m.def(
|
|
|
|
"quantized_matmul",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::quantized_matmul,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2023-12-19 15:18:57 +08:00
|
|
|
"scales"_a,
|
|
|
|
"biases"_a,
|
2024-01-04 06:22:36 +08:00
|
|
|
"transpose"_a = true,
|
2023-12-21 08:53:53 +08:00
|
|
|
"group_size"_a = 64,
|
|
|
|
"bits"_a = 4,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-19 15:18:57 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Perform the matrix multiplication with the quantized matrix ``w``. The
|
2023-12-21 08:53:53 +08:00
|
|
|
quantization uses one floating point scale and bias per ``group_size`` of
|
|
|
|
elements. Each element in ``w`` takes ``bits`` bits and is packed in an
|
2023-12-19 15:18:57 +08:00
|
|
|
unsigned 32 bit integer.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
x (array): Input array
|
|
|
|
w (array): Quantized matrix packed in unsigned integers
|
2023-12-21 08:53:53 +08:00
|
|
|
scales (array): The scales to use per ``group_size`` elements of ``w``
|
|
|
|
biases (array): The biases to use per ``group_size`` elements of ``w``
|
2024-01-04 06:22:36 +08:00
|
|
|
transpose (bool, optional): Defines whether to multiply with the
|
|
|
|
transposed ``w`` or not, namely whether we are performing
|
2024-07-26 00:36:44 +08:00
|
|
|
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
2023-12-21 08:53:53 +08:00
|
|
|
group_size (int, optional): The size of the group in ``w`` that
|
2024-07-26 00:36:44 +08:00
|
|
|
shares a scale and bias. Default: ``64``.
|
2023-12-21 08:53:53 +08:00
|
|
|
bits (int, optional): The number of bits occupied by each element in
|
2024-07-26 00:36:44 +08:00
|
|
|
``w``. Default: ``4``.
|
2023-12-19 15:18:57 +08:00
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The result of the multiplication of ``x`` with ``w``.
|
2023-12-19 15:18:57 +08:00
|
|
|
)pbdoc");
|
2023-12-21 06:17:38 +08:00
|
|
|
m.def(
|
|
|
|
"quantize",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::quantize,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-21 08:53:53 +08:00
|
|
|
"group_size"_a = 64,
|
|
|
|
"bits"_a = 4,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
2023-12-21 06:17:38 +08:00
|
|
|
R"pbdoc(
|
2023-12-21 08:53:53 +08:00
|
|
|
Quantize the matrix ``w`` using ``bits`` bits per element.
|
2023-12-21 06:17:38 +08:00
|
|
|
|
2023-12-21 08:53:53 +08:00
|
|
|
Note, every ``group_size`` elements in a row of ``w`` are quantized
|
2023-12-21 06:17:38 +08:00
|
|
|
together. Hence, number of columns of ``w`` should be divisible by
|
2023-12-21 08:53:53 +08:00
|
|
|
``group_size``. In particular, the rows of ``w`` are divided into groups of
|
|
|
|
size ``group_size`` which are quantized together.
|
2023-12-21 06:17:38 +08:00
|
|
|
|
|
|
|
.. warning::
|
|
|
|
|
|
|
|
``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
|
|
|
|
|
|
|
|
Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
|
|
|
|
:math:`w_g` in a row of ``w`` we compute the quantized representation
|
|
|
|
of each element :math:`\hat{w_i}` as follows
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
|
|
|
\begin{aligned}
|
|
|
|
\alpha &= \max_i w_i \\
|
|
|
|
\beta &= \min_i w_i \\
|
|
|
|
s &= \frac{\alpha - \beta}{2^b - 1} \\
|
|
|
|
\hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
|
|
|
|
\end{aligned}
|
|
|
|
|
|
|
|
After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
|
|
|
|
and is packed in an unsigned 32-bit integer from the lower to upper
|
|
|
|
bits. For instance, for 4-bit quantization we fit 8 elements in an
|
|
|
|
unsigned 32 bit integer where the 1st element occupies the 4 least
|
|
|
|
significant bits, the 2nd bits 4-7 etc.
|
|
|
|
|
|
|
|
In order to be able to dequantize the elements of ``w`` we also need to
|
|
|
|
save :math:`s` and :math:`\beta` which are the returned ``scales`` and
|
|
|
|
``biases`` respectively.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
w (array): Matrix to be quantized
|
2023-12-21 08:53:53 +08:00
|
|
|
group_size (int, optional): The size of the group in ``w`` that shares a
|
2024-07-26 00:36:44 +08:00
|
|
|
scale and bias. Default: ``64``.
|
2023-12-21 08:53:53 +08:00
|
|
|
bits (int, optional): The number of bits occupied by each element of
|
2024-07-26 00:36:44 +08:00
|
|
|
``w`` in the returned quantized matrix. Default: ``4``.
|
2023-12-21 06:17:38 +08:00
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
tuple: A tuple containing
|
2023-12-21 06:17:38 +08:00
|
|
|
|
2024-05-21 02:51:47 +08:00
|
|
|
* w_q (array): The quantized version of ``w``
|
|
|
|
* scales (array): The scale to multiply each element with, namely :math:`s`
|
|
|
|
* biases (array): The biases to add to each element, namely :math:`\beta`
|
2023-12-21 06:17:38 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"dequantize",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::dequantize,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2023-12-21 06:17:38 +08:00
|
|
|
"scales"_a,
|
|
|
|
"biases"_a,
|
2023-12-21 08:53:53 +08:00
|
|
|
"group_size"_a = 64,
|
|
|
|
"bits"_a = 4,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2023-12-21 06:17:38 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Dequantize the matrix ``w`` using the provided ``scales`` and
|
2023-12-21 08:53:53 +08:00
|
|
|
``biases`` and the ``group_size`` and ``bits`` configuration.
|
2023-12-21 06:17:38 +08:00
|
|
|
|
|
|
|
Formally, given the notation in :func:`quantize`, we compute
|
|
|
|
:math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
|
|
|
|
:math:`\beta` as follows
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
|
|
|
w_i = s \hat{w_i} - \beta
|
|
|
|
|
|
|
|
Args:
|
|
|
|
w (array): Matrix to be quantized
|
2023-12-21 08:53:53 +08:00
|
|
|
scales (array): The scales to use per ``group_size`` elements of ``w``
|
|
|
|
biases (array): The biases to use per ``group_size`` elements of ``w``
|
|
|
|
group_size (int, optional): The size of the group in ``w`` that shares a
|
2024-07-26 00:36:44 +08:00
|
|
|
scale and bias. Default: ``64``.
|
2023-12-21 08:53:53 +08:00
|
|
|
bits (int, optional): The number of bits occupied by each element in
|
2024-07-26 00:36:44 +08:00
|
|
|
``w``. Default: ``4``.
|
2023-12-21 06:17:38 +08:00
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The dequantized version of ``w``
|
2023-12-21 06:17:38 +08:00
|
|
|
)pbdoc");
|
2024-05-17 06:24:14 +08:00
|
|
|
m.def(
|
2024-06-25 04:36:35 +08:00
|
|
|
"gather_qmm",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::gather_qmm,
|
2024-05-17 06:24:14 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"scales"_a,
|
|
|
|
"biases"_a,
|
|
|
|
"lhs_indices"_a = nb::none(),
|
|
|
|
"rhs_indices"_a = nb::none(),
|
|
|
|
"transpose"_a = true,
|
|
|
|
"group_size"_a = 64,
|
|
|
|
"bits"_a = 4,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-05-22 22:48:34 +08:00
|
|
|
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-05-17 06:24:14 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Perform quantized matrix multiplication with matrix-level gather.
|
|
|
|
|
2024-05-22 22:48:34 +08:00
|
|
|
This operation is the quantized equivalent to :func:`gather_mm`.
|
|
|
|
Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
|
2024-05-17 06:24:14 +08:00
|
|
|
``rhs_indices`` contain flat indices along the batch dimensions (i.e.
|
|
|
|
all but the last two dimensions) of ``x`` and ``w`` respectively.
|
|
|
|
|
|
|
|
Note that ``scales`` and ``biases`` must have the same batch dimensions
|
|
|
|
as ``w`` since they represent the same quantized matrix.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
x (array): Input array
|
|
|
|
w (array): Quantized matrix packed in unsigned integers
|
|
|
|
scales (array): The scales to use per ``group_size`` elements of ``w``
|
|
|
|
biases (array): The biases to use per ``group_size`` elements of ``w``
|
2024-07-26 00:36:44 +08:00
|
|
|
lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
|
|
|
|
rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
|
2024-05-17 06:24:14 +08:00
|
|
|
transpose (bool, optional): Defines whether to multiply with the
|
|
|
|
transposed ``w`` or not, namely whether we are performing
|
2024-07-26 00:36:44 +08:00
|
|
|
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
2024-05-17 06:24:14 +08:00
|
|
|
group_size (int, optional): The size of the group in ``w`` that
|
2024-07-26 00:36:44 +08:00
|
|
|
shares a scale and bias. Default: ``64``.
|
2024-05-17 06:24:14 +08:00
|
|
|
bits (int, optional): The number of bits occupied by each element in
|
2024-07-26 00:36:44 +08:00
|
|
|
``w``. Default: ``4``.
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The result of the multiplication of ``x`` with ``w``
|
2024-05-17 06:24:14 +08:00
|
|
|
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
|
|
|
)pbdoc");
|
2024-01-03 09:15:00 +08:00
|
|
|
m.def(
|
|
|
|
"tensordot",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
|
|
|
const mx::array& b,
|
2024-03-05 01:51:02 +08:00
|
|
|
const std::variant<int, std::vector<std::vector<int>>>& axes,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-03-05 01:51:02 +08:00
|
|
|
if (auto pv = std::get_if<int>(&axes); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::tensordot(a, b, *pv, s);
|
2024-01-03 09:15:00 +08:00
|
|
|
} else {
|
2024-03-05 01:51:02 +08:00
|
|
|
auto& x = std::get<std::vector<std::vector<int>>>(axes);
|
2024-01-03 09:15:00 +08:00
|
|
|
if (x.size() != 2) {
|
|
|
|
throw std::invalid_argument(
|
2024-03-05 01:51:02 +08:00
|
|
|
"[tensordot] axes must be a list of two lists.");
|
2024-01-03 09:15:00 +08:00
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::tensordot(a, b, x[0], x[1], s);
|
2024-01-03 09:15:00 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2024-01-23 13:17:00 +08:00
|
|
|
"axes"_a = 2,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-03 09:15:00 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Compute the tensor dot product along the specified axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
b (array): Input array
|
2024-01-23 13:17:00 +08:00
|
|
|
axes (int or list(list(int)), optional): The number of dimensions to
|
2024-01-03 09:15:00 +08:00
|
|
|
sum over. If an integer is provided, then sum over the last
|
2024-01-23 13:17:00 +08:00
|
|
|
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
2024-01-03 09:15:00 +08:00
|
|
|
``b``. If a list of lists is provided, then sum over the
|
2024-07-26 00:36:44 +08:00
|
|
|
corresponding dimensions of ``a`` and ``b``. Default: 2.
|
2024-01-11 05:22:48 +08:00
|
|
|
|
2024-01-03 09:15:00 +08:00
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The tensor dot product.
|
2024-01-03 09:15:00 +08:00
|
|
|
)pbdoc");
|
2024-01-08 01:01:09 +08:00
|
|
|
m.def(
|
|
|
|
"inner",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::inner,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-08 01:01:09 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
b (array): Input array
|
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The inner product.
|
2024-01-08 01:01:09 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"outer",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::outer,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-08 01:01:09 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Compute the outer product of two 1-D arrays, if the array's passed are not 1-D a flatten op will be run beforehand.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
b (array): Input array
|
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The outer product.
|
2024-01-08 01:01:09 +08:00
|
|
|
)pbdoc");
|
2024-01-13 15:03:16 +08:00
|
|
|
m.def(
|
|
|
|
"tile",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2024-05-14 21:25:25 +08:00
|
|
|
const std::variant<int, std::vector<int>>& reps,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-01-13 15:03:16 +08:00
|
|
|
if (auto pv = std::get_if<int>(&reps); pv) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::tile(a, {*pv}, s);
|
2024-01-13 15:03:16 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::tile(a, std::get<std::vector<int>>(reps), s);
|
2024-01-13 15:03:16 +08:00
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-03-25 06:03:27 +08:00
|
|
|
"def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-13 15:03:16 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Construct an array by repeating ``a`` the number of times given by ``reps``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
reps (int or list(int)): The number of times to repeat ``a`` along each axis.
|
|
|
|
|
|
|
|
Returns:
|
2024-05-21 02:51:47 +08:00
|
|
|
array: The tiled array.
|
2024-01-13 15:03:16 +08:00
|
|
|
)pbdoc");
|
2024-01-18 04:42:39 +08:00
|
|
|
m.def(
|
|
|
|
"addmm",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::addmm,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
2024-01-18 04:42:39 +08:00
|
|
|
"alpha"_a = 1.0f,
|
|
|
|
"beta"_a = 1.0f,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-18 04:42:39 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Matrix multiplication with addition and optional scaling.
|
|
|
|
|
|
|
|
Perform the (possibly batched) matrix multiplication of two arrays and add to the result
|
|
|
|
with optional scaling factors.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
c (array): Input array or scalar.
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
2024-01-20 06:06:05 +08:00
|
|
|
alpha (float, optional): Scaling factor for the
|
2024-01-18 04:42:39 +08:00
|
|
|
matrix product of ``a`` and ``b`` (default: ``1``)
|
|
|
|
beta (float, optional): Scaling factor for ``c`` (default: ``1``)
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: ``alpha * (a @ b) + beta * c``
|
|
|
|
)pbdoc");
|
2024-04-17 05:45:39 +08:00
|
|
|
m.def(
|
|
|
|
"block_masked_mm",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::block_masked_mm,
|
2024-04-17 05:45:39 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"block_size"_a = 64,
|
|
|
|
"mask_out"_a = nb::none(),
|
|
|
|
"mask_lhs"_a = nb::none(),
|
|
|
|
"mask_rhs"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-04-17 05:45:39 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Matrix multiplication with block masking.
|
|
|
|
|
|
|
|
Perform the (possibly batched) matrix multiplication of two arrays and with blocks
|
|
|
|
of size ``block_size x block_size`` optionally masked out.
|
|
|
|
|
|
|
|
Assuming ``a`` with shape (..., `M`, `K`) and b with shape (..., `K`, `N`)
|
|
|
|
|
|
|
|
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
|
|
|
|
|
2024-04-27 13:03:42 +08:00
|
|
|
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
|
|
|
|
2024-04-17 05:45:39 +08:00
|
|
|
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
|
|
|
|
|
|
|
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
2024-07-26 00:36:44 +08:00
|
|
|
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.
|
|
|
|
mask_out (array, optional): Mask for output. Default: ``None``.
|
|
|
|
mask_lhs (array, optional): Mask for ``a``. Default: ``None``.
|
|
|
|
mask_rhs (array, optional): Mask for ``b``. Default: ``None``.
|
2024-04-17 05:45:39 +08:00
|
|
|
|
2024-07-26 00:36:44 +08:00
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2024-05-03 05:03:58 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
2024-05-22 22:48:34 +08:00
|
|
|
"gather_mm",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::gather_mm,
|
2024-05-03 05:03:58 +08:00
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
"lhs_indices"_a = nb::none(),
|
|
|
|
"rhs_indices"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-05-22 22:48:34 +08:00
|
|
|
"def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-05-03 05:03:58 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Matrix multiplication with matrix-level gather.
|
|
|
|
|
2024-05-22 22:48:34 +08:00
|
|
|
Performs a gather of the operands with the given indices followed by a
|
|
|
|
(possibly batched) matrix multiplication of two arrays. This operation
|
|
|
|
is more efficient than explicitly applying a :func:`take` followed by a
|
|
|
|
:func:`matmul`.
|
2024-05-10 00:36:02 +08:00
|
|
|
|
2024-05-22 22:48:34 +08:00
|
|
|
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
|
|
|
|
along the batch dimensions (i.e. all but the last two dimensions) of
|
|
|
|
``a`` and ``b`` respectively.
|
2024-05-03 05:03:58 +08:00
|
|
|
|
2024-05-22 22:48:34 +08:00
|
|
|
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
|
|
|
|
contains indices from the range ``[0, A1 * A2 * ... * AS)``
|
2024-05-03 05:03:58 +08:00
|
|
|
|
2024-05-22 22:48:34 +08:00
|
|
|
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
|
|
|
|
contains indices from the range ``[0, B1 * B2 * ... * BS)``
|
2024-05-03 05:03:58 +08:00
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
b (array): Input array.
|
2024-07-26 00:36:44 +08:00
|
|
|
lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
|
|
|
|
rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
|
2024-05-03 05:03:58 +08:00
|
|
|
|
2024-07-26 00:36:44 +08:00
|
|
|
Returns:
|
|
|
|
array: The output array.
|
2024-04-17 05:45:39 +08:00
|
|
|
)pbdoc");
|
2024-01-31 01:45:48 +08:00
|
|
|
m.def(
|
|
|
|
"diagonal",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::diagonal,
|
2024-01-31 01:45:48 +08:00
|
|
|
"a"_a,
|
|
|
|
"offset"_a = 0,
|
|
|
|
"axis1"_a = 0,
|
|
|
|
"axis2"_a = 1,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-31 01:45:48 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return specified diagonals.
|
|
|
|
|
|
|
|
If ``a`` is 2-D, then a 1-D array containing the diagonal at the given
|
|
|
|
``offset`` is returned.
|
|
|
|
|
|
|
|
If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``
|
|
|
|
determine the 2D subarrays from which diagonals are extracted. The new
|
|
|
|
shape is the original shape with ``axis1`` and ``axis2`` removed and a
|
|
|
|
new dimension inserted at the end corresponding to the diagonal.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
offset (int, optional): Offset of the diagonal from the main diagonal.
|
|
|
|
Can be positive or negative. Default: ``0``.
|
|
|
|
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
|
|
|
|
the diagonals should be taken. Default: ``0``.
|
|
|
|
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
|
|
|
|
the diagonals should be taken. Default: ``1``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The diagonals of the array.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"diag",
|
2024-12-12 07:45:39 +08:00
|
|
|
&mx::diag,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::arg(),
|
2024-01-31 01:45:48 +08:00
|
|
|
"k"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-01-31 01:45:48 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Extract a diagonal or construct a diagonal matrix.
|
|
|
|
If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the
|
|
|
|
:math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is
|
|
|
|
returned.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): 1-D or 2-D input array.
|
|
|
|
k (int, optional): The diagonal to extract or construct.
|
|
|
|
Default: ``0``.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The extracted diagonal or the constructed diagonal matrix.
|
|
|
|
)pbdoc");
|
2024-05-23 06:50:27 +08:00
|
|
|
m.def(
|
|
|
|
"trace",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2024-05-23 06:50:27 +08:00
|
|
|
int offset,
|
|
|
|
int axis1,
|
|
|
|
int axis2,
|
2024-12-12 07:45:39 +08:00
|
|
|
std::optional<mx::Dtype> dtype,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-05-23 06:50:27 +08:00
|
|
|
if (!dtype.has_value()) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::trace(a, offset, axis1, axis2, s);
|
2024-05-23 06:50:27 +08:00
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::trace(a, offset, axis1, axis2, dtype.value(), s);
|
2024-05-23 06:50:27 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"offset"_a = 0,
|
|
|
|
"axis1"_a = 0,
|
|
|
|
"axis2"_a = 1,
|
|
|
|
"dtype"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-05-23 06:50:27 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Return the sum along a specified diagonal in the given array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
offset (int, optional): Offset of the diagonal from the main diagonal.
|
|
|
|
Can be positive or negative. Default: ``0``.
|
|
|
|
axis1 (int, optional): The first axis of the 2-D sub-arrays from which
|
|
|
|
the diagonals should be taken. Default: ``0``.
|
|
|
|
axis2 (int, optional): The second axis of the 2-D sub-arrays from which
|
|
|
|
the diagonals should be taken. Default: ``1``.
|
|
|
|
dtype (Dtype, optional): Data type of the output array. If
|
|
|
|
unspecified the output type is inferred from the input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: Sum of specified diagonal.
|
|
|
|
)pbdoc");
|
2024-02-20 01:40:52 +08:00
|
|
|
m.def(
|
|
|
|
"atleast_1d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const nb::args& arys, mx::StreamOrDevice s) -> nb::object {
|
2024-02-27 03:17:59 +08:00
|
|
|
if (arys.size() == 1) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return nb::cast(mx::atleast_1d(nb::cast<mx::array>(arys[0]), s));
|
2024-02-27 03:17:59 +08:00
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return nb::cast(
|
|
|
|
mx::atleast_1d(nb::cast<std::vector<mx::array>>(arys), s));
|
2024-02-27 03:17:59 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"arys"_a,
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
|
2024-02-20 01:40:52 +08:00
|
|
|
R"pbdoc(
|
2024-02-27 03:17:59 +08:00
|
|
|
Convert all arrays to have at least one dimension.
|
2024-02-20 01:40:52 +08:00
|
|
|
|
2024-02-27 03:17:59 +08:00
|
|
|
Args:
|
|
|
|
*arys: Input arrays.
|
2024-02-20 01:40:52 +08:00
|
|
|
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
|
|
|
|
|
|
|
Returns:
|
2024-02-27 03:17:59 +08:00
|
|
|
array or list(array): An array or list of arrays with at least one dimension.
|
2024-02-20 01:40:52 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"atleast_2d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const nb::args& arys, mx::StreamOrDevice s) -> nb::object {
|
2024-02-27 03:17:59 +08:00
|
|
|
if (arys.size() == 1) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return nb::cast(mx::atleast_2d(nb::cast<mx::array>(arys[0]), s));
|
2024-02-27 03:17:59 +08:00
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return nb::cast(
|
|
|
|
mx::atleast_2d(nb::cast<std::vector<mx::array>>(arys), s));
|
2024-02-27 03:17:59 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"arys"_a,
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
|
2024-02-20 01:40:52 +08:00
|
|
|
R"pbdoc(
|
2024-02-27 03:17:59 +08:00
|
|
|
Convert all arrays to have at least two dimensions.
|
2024-02-20 01:40:52 +08:00
|
|
|
|
2024-02-27 03:17:59 +08:00
|
|
|
Args:
|
|
|
|
*arys: Input arrays.
|
2024-02-20 01:40:52 +08:00
|
|
|
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
|
|
|
|
|
|
|
Returns:
|
2024-02-27 03:17:59 +08:00
|
|
|
array or list(array): An array or list of arrays with at least two dimensions.
|
2024-02-20 01:40:52 +08:00
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"atleast_3d",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const nb::args& arys, mx::StreamOrDevice s) -> nb::object {
|
2024-02-27 03:17:59 +08:00
|
|
|
if (arys.size() == 1) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return nb::cast(mx::atleast_3d(nb::cast<mx::array>(arys[0]), s));
|
2024-02-27 03:17:59 +08:00
|
|
|
}
|
2024-12-12 07:45:39 +08:00
|
|
|
return nb::cast(
|
|
|
|
mx::atleast_3d(nb::cast<std::vector<mx::array>>(arys), s));
|
2024-02-27 03:17:59 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"arys"_a,
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
|
2024-02-20 01:40:52 +08:00
|
|
|
R"pbdoc(
|
2024-02-27 03:17:59 +08:00
|
|
|
Convert all arrays to have at least three dimensions.
|
2024-02-20 01:40:52 +08:00
|
|
|
|
2024-02-27 03:17:59 +08:00
|
|
|
Args:
|
|
|
|
*arys: Input arrays.
|
2024-02-20 01:40:52 +08:00
|
|
|
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
|
|
|
|
|
|
|
Returns:
|
2024-02-27 03:17:59 +08:00
|
|
|
array or list(array): An array or list of arrays with at least three dimensions.
|
2024-02-20 01:40:52 +08:00
|
|
|
)pbdoc");
|
2024-03-26 03:32:59 +08:00
|
|
|
m.def(
|
|
|
|
"issubdtype",
|
2024-08-14 05:49:28 +08:00
|
|
|
[](const nb::object& d1, const nb::object& d2) {
|
|
|
|
auto dispatch_second = [](const auto& t1, const auto& d2) {
|
2024-12-12 07:45:39 +08:00
|
|
|
if (nb::isinstance<mx::Dtype>(d2)) {
|
|
|
|
return mx::issubdtype(t1, nb::cast<mx::Dtype>(d2));
|
|
|
|
} else if (nb::isinstance<mx::Dtype::Category>(d2)) {
|
|
|
|
return mx::issubdtype(t1, nb::cast<mx::Dtype::Category>(d2));
|
2024-08-14 05:49:28 +08:00
|
|
|
} else {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"[issubdtype] Received invalid type for second input.");
|
|
|
|
}
|
|
|
|
};
|
2024-12-12 07:45:39 +08:00
|
|
|
if (nb::isinstance<mx::Dtype>(d1)) {
|
|
|
|
return dispatch_second(nb::cast<mx::Dtype>(d1), d2);
|
|
|
|
} else if (nb::isinstance<mx::Dtype::Category>(d1)) {
|
|
|
|
return dispatch_second(nb::cast<mx::Dtype::Category>(d1), d2);
|
2024-08-14 05:49:28 +08:00
|
|
|
} else {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"[issubdtype] Received invalid type for first input.");
|
|
|
|
}
|
|
|
|
},
|
2024-03-26 03:32:59 +08:00
|
|
|
""_a,
|
|
|
|
""_a,
|
2024-08-14 05:49:28 +08:00
|
|
|
nb::sig(
|
|
|
|
"def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"),
|
2024-03-26 03:32:59 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
|
|
|
of another.
|
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
Args:
|
|
|
|
arg1 (Union[Dtype, DtypeCategory]: First dtype or category.
|
|
|
|
arg2 (Union[Dtype, DtypeCategory]: Second dtype or category.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
bool:
|
|
|
|
A boolean indicating if the first input is a subtype of the
|
|
|
|
second input.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
|
|
|
>>> mx.issubdtype(ints.dtype, mx.integer)
|
|
|
|
True
|
|
|
|
>>> mx.issubdtype(ints.dtype, mx.floating)
|
|
|
|
False
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
|
|
|
>>> mx.issubdtype(floats.dtype, mx.integer)
|
|
|
|
False
|
|
|
|
>>> mx.issubdtype(floats.dtype, mx.floating)
|
|
|
|
True
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
Similar types of different sizes are not subdtypes of each other:
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
>>> mx.issubdtype(mx.float64, mx.float32)
|
|
|
|
False
|
|
|
|
>>> mx.issubdtype(mx.float32, mx.float64)
|
|
|
|
False
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
but both are subtypes of `floating`:
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
>>> mx.issubdtype(mx.float64, mx.floating)
|
|
|
|
True
|
|
|
|
>>> mx.issubdtype(mx.float32, mx.floating)
|
|
|
|
True
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
For convenience, dtype-like objects are allowed too:
|
2024-03-26 03:32:59 +08:00
|
|
|
|
2024-08-14 05:49:28 +08:00
|
|
|
>>> mx.issubdtype(mx.float32, mx.inexact)
|
|
|
|
True
|
|
|
|
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
|
|
|
False
|
2024-03-26 03:32:59 +08:00
|
|
|
)pbdoc");
|
2024-04-27 13:03:42 +08:00
|
|
|
m.def(
|
|
|
|
"bitwise_and",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-04-27 13:03:42 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::bitwise_and(a, b, s);
|
2024-04-27 13:03:42 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise bitwise and.
|
|
|
|
|
|
|
|
Take the bitwise and of two arrays with numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The bitwise and ``a & b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"bitwise_or",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-04-27 13:03:42 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::bitwise_or(a, b, s);
|
2024-04-27 13:03:42 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise bitwise or.
|
|
|
|
|
|
|
|
Take the bitwise or of two arrays with numpy-style broadcasting
|
|
|
|
semantics. Either or both input arrays can also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The bitwise or``a | b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"bitwise_xor",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-04-27 13:03:42 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::bitwise_xor(a, b, s);
|
2024-04-27 13:03:42 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise bitwise xor.
|
|
|
|
|
|
|
|
Take the bitwise exclusive or of two arrays with numpy-style
|
|
|
|
broadcasting semantics. Either or both input arrays can also be
|
|
|
|
scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The bitwise xor ``a ^ b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"left_shift",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-04-27 13:03:42 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::left_shift(a, b, s);
|
2024-04-27 13:03:42 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise left shift.
|
|
|
|
|
|
|
|
Shift the bits of the first input to the left by the second using
|
|
|
|
numpy-style broadcasting semantics. Either or both input arrays can
|
|
|
|
also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The bitwise left shift ``a << b``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"right_shift",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a_,
|
|
|
|
const ScalarOrArray& b_,
|
|
|
|
mx::StreamOrDevice s) {
|
2024-04-27 13:03:42 +08:00
|
|
|
auto [a, b] = to_arrays(a_, b_);
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::right_shift(a, b, s);
|
2024-04-27 13:03:42 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Element-wise right shift.
|
|
|
|
|
|
|
|
Shift the bits of the first input to the right by the second using
|
|
|
|
numpy-style broadcasting semantics. Either or both input arrays can
|
|
|
|
also be scalars.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
b (array): Input array or scalar.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The bitwise right shift ``a >> b``.
|
|
|
|
)pbdoc");
|
2024-06-04 23:05:27 +08:00
|
|
|
m.def(
|
|
|
|
"view",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) {
|
|
|
|
return mx::view(to_array(a), dtype, s);
|
2024-06-04 23:05:27 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"dtype"_a,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
View the array as a different type.
|
|
|
|
|
|
|
|
The output shape changes along the last axis if the input array's
|
|
|
|
type and the input ``dtype`` do not have the same size.
|
|
|
|
|
|
|
|
Note: the view op does not imply that the input and output arrays share
|
|
|
|
their underlying data. The view only gaurantees that the binary
|
|
|
|
representation of each element (or group of elements) is the same.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array or scalar.
|
|
|
|
dtype (Dtype): The data type to change to.
|
|
|
|
|
2024-07-10 11:39:01 +08:00
|
|
|
Returns:
|
|
|
|
array: The array with the new type.
|
|
|
|
)pbdoc");
|
|
|
|
// mx.hadamard_transform binding, forwarded directly to
// mx::hadamard_transform. The input is unnamed (positional-only), ``scale``
// defaults to None, and ``stream`` follows nb::kw_only(). The signature
// string marks this with ``/`` and ``*``.
m.def(
    "hadamard_transform",
    &mx::hadamard_transform,
    nb::arg(),
    "scale"_a = nb::none(),
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def hadamard_transform(a: array, /, scale: Optional[float] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
        Perform the Walsh-Hadamard transform along the final axis.

        Equivalent to:

        .. code-block:: python

           from scipy.linalg import hadamard

           y = (hadamard(len(x)) @ x) * scale

        Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k
        <= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.

        Args:
            a (array): Input array or scalar.
            scale (float, optional): Scale the output by this factor.
              Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.

        Returns:
            array: The transformed array.
    )pbdoc");
|
2024-07-26 00:36:44 +08:00
|
|
|
m.def(
|
|
|
|
"einsum_path",
|
|
|
|
[](const std::string& equation, const nb::args& operands) {
|
2024-12-12 07:45:39 +08:00
|
|
|
auto arrays_list = nb::cast<std::vector<mx::array>>(operands);
|
|
|
|
auto [path, str] = mx::einsum_path(equation, arrays_list);
|
2024-07-26 00:36:44 +08:00
|
|
|
// Convert to list of tuples
|
|
|
|
std::vector<nb::tuple> tuple_path;
|
|
|
|
for (auto& p : path) {
|
|
|
|
tuple_path.push_back(nb::tuple(nb::cast(p)));
|
|
|
|
}
|
|
|
|
return std::make_pair(tuple_path, str);
|
|
|
|
},
|
|
|
|
"subscripts"_a,
|
|
|
|
"operands"_a,
|
|
|
|
nb::sig("def einsum_path(subscripts: str, *operands)"),
|
|
|
|
R"pbdoc(
|
|
|
|
|
|
|
|
Compute the contraction order for the given Einstein summation.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
subscripts (str): The Einstein summation convention equation.
|
|
|
|
*operands (array): The input arrays.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
tuple(list(tuple(int, int)), str):
|
|
|
|
The einsum path and a string containing information about the
|
|
|
|
chosen path.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"einsum",
|
|
|
|
[](const std::string& subscripts,
|
|
|
|
const nb::args& operands,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
|
|
|
auto arrays_list = nb::cast<std::vector<mx::array>>(operands);
|
|
|
|
return mx::einsum(subscripts, arrays_list, s);
|
2024-07-26 00:36:44 +08:00
|
|
|
},
|
|
|
|
"subscripts"_a,
|
|
|
|
"operands"_a,
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
2024-08-29 02:16:19 +08:00
|
|
|
"def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"),
|
2024-07-26 00:36:44 +08:00
|
|
|
R"pbdoc(
|
|
|
|
|
|
|
|
Perform the Einstein summation convention on the operands.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
subscripts (str): The Einstein summation convention equation.
|
|
|
|
*operands (array): The input arrays.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The output array.
|
|
|
|
)pbdoc");
|
2024-10-08 08:21:42 +08:00
|
|
|
m.def(
|
|
|
|
"roll",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const mx::array& a,
|
2024-10-08 08:21:42 +08:00
|
|
|
const IntOrVec& shift,
|
|
|
|
const IntOrVec& axis,
|
2024-12-12 07:45:39 +08:00
|
|
|
mx::StreamOrDevice s) {
|
2024-10-08 08:21:42 +08:00
|
|
|
return std::visit(
|
2024-12-12 07:45:39 +08:00
|
|
|
[&](auto sh, auto ax) -> mx::array {
|
2024-10-08 08:21:42 +08:00
|
|
|
using T = decltype(ax);
|
|
|
|
using V = decltype(sh);
|
|
|
|
|
|
|
|
if constexpr (std::is_same_v<V, std::monostate>) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"[roll] Expected two arguments but only one was given.");
|
|
|
|
} else {
|
|
|
|
if constexpr (std::is_same_v<T, std::monostate>) {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::roll(a, sh, s);
|
2024-10-08 08:21:42 +08:00
|
|
|
} else {
|
2024-12-12 07:45:39 +08:00
|
|
|
return mx::roll(a, sh, ax, s);
|
2024-10-08 08:21:42 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
},
|
|
|
|
shift,
|
|
|
|
axis);
|
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
"shift"_a,
|
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def roll(a: array, shift: Union[int, Tuple[int]], axis: Union[None, int, Tuple[int]] = None, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Roll array elements along a given axis.
|
|
|
|
|
|
|
|
Elements that are rolled beyond the end of the array are introduced at
|
|
|
|
the beggining and vice-versa.
|
|
|
|
|
|
|
|
If the axis is not provided the array is flattened, rolled and then the
|
|
|
|
shape is restored.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array
|
|
|
|
shift (int or tuple(int)): The number of places by which elements
|
|
|
|
are shifted. If positive the array is rolled to the right, if
|
|
|
|
negative it is rolled to the left. If an int is provided but the
|
|
|
|
axis is a tuple then the same value is used for all axes.
|
|
|
|
axis (int or tuple(int), optional): The axis or axes along which to
|
|
|
|
roll the elements.
|
|
|
|
)pbdoc");
|
2024-10-16 07:23:15 +08:00
|
|
|
m.def(
|
|
|
|
"real",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::real(to_array(a), s);
|
2024-10-16 07:23:15 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Returns the real part of a complex array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The real part of ``a``.
|
|
|
|
)pbdoc");
|
|
|
|
m.def(
|
|
|
|
"imag",
|
2024-12-12 07:45:39 +08:00
|
|
|
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
|
|
|
|
return mx::imag(to_array(a), s);
|
2024-10-16 07:23:15 +08:00
|
|
|
},
|
|
|
|
nb::arg(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
|
|
|
R"pbdoc(
|
|
|
|
Returns the imaginary part of a complex array.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a (array): Input array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The imaginary part of ``a``.
|
|
|
|
)pbdoc");
|
2023-11-30 02:30:41 +08:00
|
|
|
}
|