Shape and Strides 1 / N (#1645)

* shape and stride type def

* more shape
This commit is contained in:
Awni Hannun
2024-12-05 12:53:43 -08:00
committed by GitHub
parent c5b0928c1f
commit fc88fd9097
6 changed files with 178 additions and 242 deletions

View File

@@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */
array as_strided(
array a,
std::vector<int> shape,
std::vector<size_t> strides,
Shape shape,
Strides strides,
size_t offset,
StreamOrDevice s = {});
@@ -58,31 +58,27 @@ array as_strided(
array copy(array a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */
array full(
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s = {});
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
array full(Shape shape, array vals, StreamOrDevice s = {});
template <typename T>
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
return full(std::move(shape), array(val, dtype), to_stream(s));
}
template <typename T>
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
array full(Shape shape, T val, StreamOrDevice s = {}) {
return full(std::move(shape), array(val), to_stream(s));
}
/** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
return zeros(shape, float32, s);
}
array zeros_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const Shape& shape, StreamOrDevice s = {}) {
return ones(shape, float32, s);
}
array ones_like(const array& a, StreamOrDevice s = {});
@@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
@@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {});
/** Slice an array. */
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
Shape start,
Shape stop,
Shape strides,
StreamOrDevice s = {});
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
Shape start,
Shape stop,
Shape strides,
StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
Shape start,
Shape stop,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
@@ -288,10 +280,7 @@ array pad(
array transpose(const array& a, StreamOrDevice s = {});
/** Broadcast an array to a given shape. */
array broadcast_to(
const array& a,
const std::vector<int>& shape,
StreamOrDevice s = {});
array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
/** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays(
@@ -917,13 +906,13 @@ array gather(
const array& a,
const std::vector<array>& indices,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes,
const Shape& slice_sizes,
StreamOrDevice s = {});
inline array gather(
const array& a,
const array& indices,
int axis,
const std::vector<int>& slice_sizes,
const Shape& slice_sizes,
StreamOrDevice s = {}) {
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
}
@@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** Roll elements along an axis and introduce them on the other side */
array roll(const array& a, int shift, StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
array roll(
const array& a,
int shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
int axis,
StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
const Shape& shift,
const std::vector<int>& axes,
StreamOrDevice s = {});