More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -189,13 +189,10 @@ array slice_update(
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
int axis,
StreamOrDevice s = {});
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
std::vector<array>
split(const array& a, const Shape& indices, StreamOrDevice s = {});
/** A vector of coordinate arrays from coordinate vectors. */
std::vector<array> meshgrid(
@@ -253,8 +250,8 @@ array moveaxis(
array pad(
const array& a,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size,
const Shape& low_pad_size,
const Shape& high_pad_size,
const array& pad_value = array(0),
const std::string mode = "constant",
StreamOrDevice s = {});
@@ -1453,7 +1450,11 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
array roll(const array& a, int shift, StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
array roll(
const array& a,
int shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
array roll(
const array& a,