Add tile op (#438)

This commit is contained in:
Diogo
2024-01-13 02:03:16 -05:00
committed by GitHub
parent 1b71487e1f
commit 2e29d0815b
7 changed files with 105 additions and 3 deletions

View File

@@ -753,6 +753,36 @@ array repeat(const array& arr, int repeats, StreamOrDevice s) {
return repeat(flatten(arr, s), repeats, 0, s);
}
array tile(
const array& arr,
std::vector<int> reps,
StreamOrDevice s /* = {} */) {
auto shape = arr.shape();
if (reps.size() < shape.size()) {
reps.insert(reps.begin(), shape.size() - reps.size(), 1);
}
if (reps.size() > shape.size()) {
shape.insert(shape.begin(), reps.size() - shape.size(), 1);
}
std::vector<int> expand_shape;
std::vector<int> broad_shape;
std::vector<int> final_shape;
for (int i = 0; i < shape.size(); i++) {
if (reps[i] != 1) {
expand_shape.push_back(1);
broad_shape.push_back(reps[i]);
}
expand_shape.push_back(shape[i]);
broad_shape.push_back(shape[i]);
final_shape.push_back(reps[i] * shape[i]);
}
auto x = reshape(arr, expand_shape, s);
x = broadcast_to(x, broad_shape, s);
return reshape(x, final_shape, s);
}
/** Pad an array with a constant value */
array pad(
const array& a,

View File

@@ -218,6 +218,8 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
/** Permutes the dimensions according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
inline array transpose(