Add support for repeat (#278)

* add repeat function

* fix styling

* optimizing repeat

* fixed minor issues

* not sure why that folder is there xD

* fixed now for sure

* test repeat not repeat test

* Fixed

---------

Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
This commit is contained in:
Bahaa
2023-12-28 00:11:38 +03:00
committed by GitHub
parent 4417e37ede
commit ff2b58e299
6 changed files with 149 additions and 0 deletions

View File

@@ -718,6 +718,40 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
return stack(arrays, 0, s);
}
/** array repeat with axis */
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
axis = normalize_axis(axis, arr.ndim());
if (repeats < 0) {
throw std::invalid_argument(
"[repeat] Number of repeats cannot be negative");
}
if (repeats == 0) {
return array({}, arr.dtype());
}
if (repeats == 1) {
return arr;
}
std::vector<int> new_shape(arr.shape());
new_shape[axis] *= repeats;
std::vector<array> repeated_arrays;
repeated_arrays.reserve(repeats);
for (int i = 0; i < repeats; ++i) {
repeated_arrays.push_back(expand_dims(arr, -1, s));
}
array repeated = concatenate(repeated_arrays, axis + 1, s);
return reshape(repeated, new_shape, s);
}
array repeat(const array& arr, int repeats, StreamOrDevice s) {
return repeat(flatten(arr, s), repeats, 0, s);
}
/** Pad an array with a constant value */
array pad(
const array& a,

View File

@@ -214,6 +214,10 @@ array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
/** Repeate an array along an axis. */
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
/** Permutes the dimensions according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
inline array transpose(