mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Mlx array accessor (#128)
* Add an accessor to interoperate with custom types * Change the docs to custom signatures
This commit is contained in:
parent
072044e28f
commit
3214629601
@ -458,38 +458,54 @@ void init_array(py::module_& m) {
|
|||||||
m.attr("bfloat16") = py::cast(bfloat16);
|
m.attr("bfloat16") = py::cast(bfloat16);
|
||||||
m.attr("complex64") = py::cast(complex64);
|
m.attr("complex64") = py::cast(complex64);
|
||||||
|
|
||||||
py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
|
auto array_class = py::class_<array>(
|
||||||
.def(
|
m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
|
||||||
py::init([](ScalarOrArray v, std::optional<Dtype> t) {
|
|
||||||
auto arr = to_array(v, t);
|
{
|
||||||
return astype(arr, t.value_or(arr.dtype()));
|
py::options options;
|
||||||
}),
|
options.disable_function_signatures();
|
||||||
"val"_a,
|
|
||||||
"dtype"_a = std::nullopt)
|
array_class.def(
|
||||||
.def(
|
py::init([](std::variant<
|
||||||
py::init([](std::variant<py::list, py::tuple> pl,
|
py::bool_,
|
||||||
std::optional<Dtype> dtype) {
|
py::int_,
|
||||||
if (auto pv = std::get_if<py::list>(&pl); pv) {
|
py::float_,
|
||||||
return array_from_list(*pv, dtype);
|
std::complex<float>,
|
||||||
|
py::list,
|
||||||
|
py::tuple,
|
||||||
|
py::array,
|
||||||
|
py::buffer,
|
||||||
|
py::object> v,
|
||||||
|
std::optional<Dtype> t) {
|
||||||
|
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||||
|
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
||||||
|
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||||
|
return array(py::cast<int>(*pv), t.value_or(int32));
|
||||||
|
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||||
|
return array(py::cast<float>(*pv), t.value_or(float32));
|
||||||
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
|
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||||
|
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
||||||
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||||
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||||
|
return np_array_to_mlx(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||||
|
return np_array_to_mlx(*pv, t);
|
||||||
} else {
|
} else {
|
||||||
auto v = std::get<py::tuple>(pl);
|
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
||||||
return array_from_list(v, dtype);
|
return astype(arr, t.value_or(arr.dtype()));
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
"vals"_a,
|
"val"_a,
|
||||||
"dtype"_a = std::nullopt)
|
"dtype"_a = std::nullopt,
|
||||||
.def(
|
R"pbdoc(
|
||||||
py::init([](py::array np_array, std::optional<Dtype> dtype) {
|
__init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)
|
||||||
return np_array_to_mlx(np_array, dtype);
|
)pbdoc");
|
||||||
}),
|
}
|
||||||
"vals"_a,
|
|
||||||
"dtype"_a = std::nullopt)
|
array_class
|
||||||
.def(
|
|
||||||
py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
|
|
||||||
return np_array_to_mlx(np_buffer, dtype);
|
|
||||||
}),
|
|
||||||
"vals"_a,
|
|
||||||
"dtype"_a = std::nullopt)
|
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
|
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
|
@ -36,6 +36,9 @@ double scalar_to_double(Scalar s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void init_ops(py::module_& m) {
|
void init_ops(py::module_& m) {
|
||||||
|
py::options options;
|
||||||
|
options.disable_function_signatures();
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"reshape",
|
"reshape",
|
||||||
&reshape,
|
&reshape,
|
||||||
@ -45,6 +48,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Reshape an array while preserving the size.
|
Reshape an array while preserving the size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -73,6 +78,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Remove length one axes from an array.
|
Remove length one axes from an array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -100,6 +107,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Add a size one dimension at the given axis.
|
Add a size one dimension at the given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -117,6 +126,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise absolute value.
|
Element-wise absolute value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -133,6 +144,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise sign.
|
Element-wise sign.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -149,6 +162,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise negation.
|
Element-wise negation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -169,6 +184,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise addition.
|
Element-wise addition.
|
||||||
|
|
||||||
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
|
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays
|
||||||
@ -193,6 +210,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise subtraction.
|
Element-wise subtraction.
|
||||||
|
|
||||||
Subtract one array from another with numpy-style broadcasting semantics. Either or both
|
Subtract one array from another with numpy-style broadcasting semantics. Either or both
|
||||||
@ -217,6 +236,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise multiplication.
|
Element-wise multiplication.
|
||||||
|
|
||||||
Multiply two arrays with numpy-style broadcasting semantics. Either or both
|
Multiply two arrays with numpy-style broadcasting semantics. Either or both
|
||||||
@ -241,6 +262,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise division.
|
Element-wise division.
|
||||||
|
|
||||||
Divide two arrays with numpy-style broadcasting semantics. Either or both
|
Divide two arrays with numpy-style broadcasting semantics. Either or both
|
||||||
@ -265,6 +288,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise remainder of division.
|
Element-wise remainder of division.
|
||||||
|
|
||||||
Computes the remainder of dividing a with b with numpy-style
|
Computes the remainder of dividing a with b with numpy-style
|
||||||
@ -290,6 +315,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise equality.
|
Element-wise equality.
|
||||||
|
|
||||||
Equality comparison on two arrays with numpy-style broadcasting semantics.
|
Equality comparison on two arrays with numpy-style broadcasting semantics.
|
||||||
@ -314,6 +341,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise not equal.
|
Element-wise not equal.
|
||||||
|
|
||||||
Not equal comparison on two arrays with numpy-style broadcasting semantics.
|
Not equal comparison on two arrays with numpy-style broadcasting semantics.
|
||||||
@ -338,6 +367,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise less than.
|
Element-wise less than.
|
||||||
|
|
||||||
Strict less than on two arrays with numpy-style broadcasting semantics.
|
Strict less than on two arrays with numpy-style broadcasting semantics.
|
||||||
@ -362,6 +393,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise less than or equal.
|
Element-wise less than or equal.
|
||||||
|
|
||||||
Less than or equal on two arrays with numpy-style broadcasting semantics.
|
Less than or equal on two arrays with numpy-style broadcasting semantics.
|
||||||
@ -386,6 +419,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise greater than.
|
Element-wise greater than.
|
||||||
|
|
||||||
Strict greater than on two arrays with numpy-style broadcasting semantics.
|
Strict greater than on two arrays with numpy-style broadcasting semantics.
|
||||||
@ -410,6 +445,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise greater or equal.
|
Element-wise greater or equal.
|
||||||
|
|
||||||
Greater than or equal on two arrays with numpy-style broadcasting semantics.
|
Greater than or equal on two arrays with numpy-style broadcasting semantics.
|
||||||
@ -438,6 +475,8 @@ void init_ops(py::module_& m) {
|
|||||||
"equal_nan"_a = false,
|
"equal_nan"_a = false,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Array equality check.
|
Array equality check.
|
||||||
|
|
||||||
Compare two arrays for equality. Returns ``True`` if and only if the arrays
|
Compare two arrays for equality. Returns ``True`` if and only if the arrays
|
||||||
@ -462,6 +501,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Matrix multiplication.
|
Matrix multiplication.
|
||||||
|
|
||||||
Perform the (possibly batched) matrix multiplication of two arrays. This function supports
|
Perform the (possibly batched) matrix multiplication of two arrays. This function supports
|
||||||
@ -492,6 +533,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise square.
|
Element-wise square.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -508,6 +551,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise square root.
|
Element-wise square root.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -524,6 +569,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise reciprocal and square root.
|
Element-wise reciprocal and square root.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -540,6 +587,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise reciprocal.
|
Element-wise reciprocal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -558,6 +607,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise logical not.
|
Element-wise logical not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -578,6 +629,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise log-add-exp.
|
Element-wise log-add-exp.
|
||||||
|
|
||||||
This is a numerically stable log-add-exp of two arrays with numpy-style
|
This is a numerically stable log-add-exp of two arrays with numpy-style
|
||||||
@ -600,6 +653,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise exponential.
|
Element-wise exponential.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -616,6 +671,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise error function.
|
Element-wise error function.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
@ -635,6 +692,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse of :func:`erf`.
|
Element-wise inverse of :func:`erf`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -651,6 +710,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise sine.
|
Element-wise sine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -667,6 +728,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise cosine.
|
Element-wise cosine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -683,6 +746,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise tangent.
|
Element-wise tangent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -699,6 +764,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse sine.
|
Element-wise inverse sine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -715,6 +782,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse cosine.
|
Element-wise inverse cosine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -731,6 +800,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse tangent.
|
Element-wise inverse tangent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -747,6 +818,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise hyperbolic sine.
|
Element-wise hyperbolic sine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -763,6 +836,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise hyperbolic cosine.
|
Element-wise hyperbolic cosine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -779,6 +854,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise hyperbolic tangent.
|
Element-wise hyperbolic tangent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -795,6 +872,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse hyperbolic sine.
|
Element-wise inverse hyperbolic sine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -811,6 +890,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse hyperbolic cosine.
|
Element-wise inverse hyperbolic cosine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -827,6 +908,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise inverse hyperbolic tangent.
|
Element-wise inverse hyperbolic tangent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -843,6 +926,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise natural logarithm.
|
Element-wise natural logarithm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -859,6 +944,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise base-2 logarithm.
|
Element-wise base-2 logarithm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -875,6 +962,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise base-10 logarithm.
|
Element-wise base-10 logarithm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -891,6 +980,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise natural log of one plus the array.
|
Element-wise natural log of one plus the array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -907,6 +998,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Stop gradients from being computed.
|
Stop gradients from being computed.
|
||||||
|
|
||||||
The operation is the identity but it prevents gradients from flowing
|
The operation is the identity but it prevents gradients from flowing
|
||||||
@ -927,6 +1020,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise logistic sigmoid.
|
Element-wise logistic sigmoid.
|
||||||
|
|
||||||
The logistic sigmoid function is:
|
The logistic sigmoid function is:
|
||||||
@ -952,6 +1047,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise power operation.
|
Element-wise power operation.
|
||||||
|
|
||||||
Raise the elements of a to the powers in elements of b with numpy-style
|
Raise the elements of a to the powers in elements of b with numpy-style
|
||||||
@ -964,10 +1061,6 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: Bases of ``a`` raised to powers in ``b``.
|
array: Bases of ``a`` raised to powers in ``b``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
{
|
|
||||||
// Disable function signature just for arange which we write manually
|
|
||||||
py::options options;
|
|
||||||
options.disable_function_signatures();
|
|
||||||
m.def(
|
m.def(
|
||||||
"arange",
|
"arange",
|
||||||
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
|
[](Scalar stop, std::optional<Dtype> dtype_, StreamOrDevice s) {
|
||||||
@ -1024,8 +1117,7 @@ void init_ops(py::module_& m) {
|
|||||||
? dtype_.value()
|
? dtype_.value()
|
||||||
: promote_types(
|
: promote_types(
|
||||||
scalar_to_dtype(start),
|
scalar_to_dtype(start),
|
||||||
promote_types(
|
promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)));
|
||||||
scalar_to_dtype(stop), scalar_to_dtype(step)));
|
|
||||||
|
|
||||||
return arange(
|
return arange(
|
||||||
scalar_to_double(start),
|
scalar_to_double(start),
|
||||||
@ -1065,7 +1157,6 @@ void init_ops(py::module_& m) {
|
|||||||
This can lead to unexpected results for example if `start + step`
|
This can lead to unexpected results for example if `start + step`
|
||||||
is a fractional value and the `dtype` is integral.
|
is a fractional value and the `dtype` is integral.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
}
|
|
||||||
m.def(
|
m.def(
|
||||||
"take",
|
"take",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
@ -1085,6 +1176,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Take elements along an axis.
|
Take elements along an axis.
|
||||||
|
|
||||||
The elements are taken from ``indices`` along the specified axis.
|
The elements are taken from ``indices`` along the specified axis.
|
||||||
@ -1121,6 +1214,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Take values along an axis at the specified indices.
|
Take values along an axis at the specified indices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1153,6 +1248,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Construct an array with the given value.
|
Construct an array with the given value.
|
||||||
|
|
||||||
Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
|
Constructs an array of size ``shape`` filled with ``vals``. If ``vals``
|
||||||
@ -1184,6 +1281,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Construct an array of zeros.
|
Construct an array of zeros.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1202,6 +1301,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An array of zeros like the input.
|
An array of zeros like the input.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1227,6 +1328,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Construct an array of ones.
|
Construct an array of ones.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1245,6 +1348,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An array of ones like the input.
|
An array of ones like the input.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1312,6 +1417,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Approximate comparison of two arrays.
|
Approximate comparison of two arrays.
|
||||||
|
|
||||||
The arrays are considered equal if:
|
The arrays are considered equal if:
|
||||||
@ -1347,6 +1454,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An `and` reduction over the given axes.
|
An `and` reduction over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1375,6 +1484,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An `or` reduction over the given axes.
|
An `or` reduction over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1397,8 +1508,11 @@ void init_ops(py::module_& m) {
|
|||||||
"a"_a,
|
"a"_a,
|
||||||
"b"_a,
|
"b"_a,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise minimum.
|
Element-wise minimum.
|
||||||
|
|
||||||
Take the element-wise min of two arrays with numpy-style broadcasting
|
Take the element-wise min of two arrays with numpy-style broadcasting
|
||||||
@ -1423,6 +1537,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Element-wise maximum.
|
Element-wise maximum.
|
||||||
|
|
||||||
Take the element-wise max of two arrays with numpy-style broadcasting
|
Take the element-wise max of two arrays with numpy-style broadcasting
|
||||||
@ -1452,6 +1568,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Transpose the dimensions of the array.
|
Transpose the dimensions of the array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1477,6 +1595,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Sum reduce the array over the given axes.
|
Sum reduce the array over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1505,6 +1625,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An product reduction over the given axes.
|
An product reduction over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1533,6 +1655,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An `min` reduction over the given axes.
|
An `min` reduction over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1561,6 +1685,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
An `max` reduction over the given axes.
|
An `max` reduction over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1589,6 +1715,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
A `log-sum-exp` reduction over the given axes.
|
A `log-sum-exp` reduction over the given axes.
|
||||||
|
|
||||||
The log-sum-exp reduction is a numerically stable version of:
|
The log-sum-exp reduction is a numerically stable version of:
|
||||||
@ -1623,6 +1751,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Compute the mean(s) over the given axes.
|
Compute the mean(s) over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1653,6 +1783,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Compute the variance(s) over the given axes.
|
Compute the variance(s) over the given axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1688,6 +1820,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Split an array along a given axis.
|
Split an array along a given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1721,6 +1855,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Indices of the minimum values along the axis.
|
Indices of the minimum values along the axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1752,6 +1888,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Indices of the maximum values along the axis.
|
Indices of the maximum values along the axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1779,6 +1917,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Returns a sorted copy of the array.
|
Returns a sorted copy of the array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1805,6 +1945,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Returns the indices that sort the array.
|
Returns the indices that sort the array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1832,6 +1974,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Returns a partitioned copy of the array such that the smaller ``kth``
|
Returns a partitioned copy of the array such that the smaller ``kth``
|
||||||
elements are first.
|
elements are first.
|
||||||
|
|
||||||
@ -1866,6 +2010,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Returns the indices that partition the array.
|
Returns the indices that partition the array.
|
||||||
|
|
||||||
The ordering of the elements within a partition in given by the indices
|
The ordering of the elements within a partition in given by the indices
|
||||||
@ -1901,6 +2047,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Returns the ``k`` largest elements from the input along a given axis.
|
Returns the ``k`` largest elements from the input along a given axis.
|
||||||
|
|
||||||
The elements will not necessarily be in sorted order.
|
The elements will not necessarily be in sorted order.
|
||||||
@ -1926,6 +2074,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Broadcast an array to the given shape.
|
Broadcast an array to the given shape.
|
||||||
|
|
||||||
The broadcasting semantics are the same as Numpy.
|
The broadcasting semantics are the same as Numpy.
|
||||||
@ -1948,6 +2098,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Perform the softmax along the given axis.
|
Perform the softmax along the given axis.
|
||||||
|
|
||||||
This operation is a numerically stable version of:
|
This operation is a numerically stable version of:
|
||||||
@ -1982,6 +2134,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Concatenate the arrays along the given axis.
|
Concatenate the arrays along the given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2024,6 +2178,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Pad an array with a constant value
|
Pad an array with a constant value
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2067,6 +2223,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Create a view into the array with the given shape and strides.
|
Create a view into the array with the given shape and strides.
|
||||||
|
|
||||||
The resulting array will always be as if the provided array was row
|
The resulting array will always be as if the provided array was row
|
||||||
@ -2113,6 +2271,8 @@ void init_ops(py::module_& m) {
|
|||||||
"inclusive"_a = true,
|
"inclusive"_a = true,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Return the cumulative sum of the elements along the given axis.
|
Return the cumulative sum of the elements along the given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2145,6 +2305,8 @@ void init_ops(py::module_& m) {
|
|||||||
"inclusive"_a = true,
|
"inclusive"_a = true,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Return the cumulative product of the elements along the given axis.
|
Return the cumulative product of the elements along the given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2177,6 +2339,8 @@ void init_ops(py::module_& m) {
|
|||||||
"inclusive"_a = true,
|
"inclusive"_a = true,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Return the cumulative maximum of the elements along the given axis.
|
Return the cumulative maximum of the elements along the given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2209,6 +2373,8 @@ void init_ops(py::module_& m) {
|
|||||||
"inclusive"_a = true,
|
"inclusive"_a = true,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Return the cumulative minimum of the elements along the given axis.
|
Return the cumulative minimum of the elements along the given axis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2275,6 +2441,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
The discrete convolution of 1D arrays.
|
The discrete convolution of 1D arrays.
|
||||||
|
|
||||||
If ``v`` is longer than ``a``, then they are swapped.
|
If ``v`` is longer than ``a``, then they are swapped.
|
||||||
@ -2301,6 +2469,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
1D convolution over an input with several channels
|
1D convolution over an input with several channels
|
||||||
|
|
||||||
Note: Only the default ``groups=1`` is currently supported.
|
Note: Only the default ``groups=1`` is currently supported.
|
||||||
@ -2360,6 +2530,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
2D convolution over an input with several channels
|
2D convolution over an input with several channels
|
||||||
|
|
||||||
Note: Only the default ``groups=1`` is currently supported.
|
Note: Only the default ``groups=1`` is currently supported.
|
||||||
@ -2390,6 +2562,8 @@ void init_ops(py::module_& m) {
|
|||||||
"retain_graph"_a = true,
|
"retain_graph"_a = true,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
save(file: str, arr: array, / , retain_graph: bool = True)
|
||||||
|
|
||||||
Save the array to a binary file in ``.npy`` format.
|
Save the array to a binary file in ``.npy`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2408,6 +2582,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
savez(file: str, *args, **kwargs)
|
||||||
|
|
||||||
Save several arrays to a binary file in uncompressed ``.npz`` format.
|
Save several arrays to a binary file in uncompressed ``.npz`` format.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -2440,6 +2616,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
savez_compressed(file: str, *args, **kwargs)
|
||||||
|
|
||||||
Save several arrays to a binary file in compressed ``.npz`` format.
|
Save several arrays to a binary file in compressed ``.npz`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2457,6 +2635,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||||
|
|
||||||
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
|
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2481,6 +2661,8 @@ void init_ops(py::module_& m) {
|
|||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Select from ``x`` or ``y`` according to ``condition``.
|
Select from ``x`` or ``y`` according to ``condition``.
|
||||||
|
|
||||||
The condition and input arrays must be the same shape or broadcastable
|
The condition and input arrays must be the same shape or broadcastable
|
||||||
|
@ -15,8 +15,8 @@ namespace py = pybind11;
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||||
using ScalarOrArray =
|
using ScalarOrArray = std::
|
||||||
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
|
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||||
static constexpr std::monostate none{};
|
static constexpr std::monostate none{};
|
||||||
|
|
||||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||||
@ -32,6 +32,14 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
|||||||
return axes;
|
return axes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline array to_array_with_accessor(py::object obj) {
|
||||||
|
if (py::hasattr(obj, "__mlx_array__")) {
|
||||||
|
return obj.attr("__mlx_array__")().cast<array>();
|
||||||
|
} else {
|
||||||
|
return obj.cast<array>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline array to_array(
|
inline array to_array(
|
||||||
const ScalarOrArray& v,
|
const ScalarOrArray& v,
|
||||||
std::optional<Dtype> dtype = std::nullopt) {
|
std::optional<Dtype> dtype = std::nullopt) {
|
||||||
@ -48,7 +56,7 @@ inline array to_array(
|
|||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
return array(static_cast<complex64_t>(*pv), complex64);
|
return array(static_cast<complex64_t>(*pv), complex64);
|
||||||
} else {
|
} else {
|
||||||
return std::get<array>(v);
|
return to_array_with_accessor(std::get<py::object>(v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,13 +68,16 @@ inline std::pair<array, array> to_arrays(
|
|||||||
// - If a is an array but b is not, treat b as a weak python type
|
// - If a is an array but b is not, treat b as a weak python type
|
||||||
// - If b is an array but a is not, treat a as a weak python type
|
// - If b is an array but a is not, treat a as a weak python type
|
||||||
// - If neither is an array convert to arrays but leave their types alone
|
// - If neither is an array convert to arrays but leave their types alone
|
||||||
if (auto pa = std::get_if<array>(&a); pa) {
|
if (auto pa = std::get_if<py::object>(&a); pa) {
|
||||||
if (auto pb = std::get_if<array>(&b); pb) {
|
auto arr_a = to_array_with_accessor(*pa);
|
||||||
return {*pa, *pb};
|
if (auto pb = std::get_if<py::object>(&b); pb) {
|
||||||
|
auto arr_b = to_array_with_accessor(*pb);
|
||||||
|
return {arr_a, arr_b};
|
||||||
}
|
}
|
||||||
return {*pa, to_array(b, pa->dtype())};
|
return {arr_a, to_array(b, arr_a.dtype())};
|
||||||
} else if (auto pb = std::get_if<array>(&b); pb) {
|
} else if (auto pb = std::get_if<py::object>(&b); pb) {
|
||||||
return {to_array(a, pb->dtype()), *pb};
|
auto arr_b = to_array_with_accessor(*pb);
|
||||||
|
return {to_array(a, arr_b.dtype()), arr_b};
|
||||||
} else {
|
} else {
|
||||||
return {to_array(a), to_array(b)};
|
return {to_array(a), to_array(b)};
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user