* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
This commit is contained in:
Awni Hannun
2024-07-26 10:40:49 -07:00
committed by GitHub
parent e9e53856d2
commit 7b456fd2c0
6 changed files with 70 additions and 37 deletions

View File

@@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) {
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s);
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
@@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) {
Returns:
array: The transposed array.
)pbdoc");
m.def(
"permute_dims",
[](const array& a,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
},
nb::arg(),
"axes"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`transpose`.
)pbdoc");
m.def(
"sum",
[](const array& a,
@@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) {
Returns:
array: The concatenated array.
)pbdoc");
m.def(
"concat",
[](const std::vector<array>& arrays,
std::optional<int> axis,
StreamOrDevice s) {
if (axis) {
return concatenate(arrays, *axis, s);
} else {
return concatenate(arrays, s);
}
},
nb::arg(),
"axis"_a.none() = 0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`concatenate`.
)pbdoc");
m.def(
"stack",
[](const std::vector<array>& arrays,