Improve repeat using broadcasting and reshape (#318)

This commit is contained in:
Angelos Katharopoulos 2023-12-29 21:40:20 -08:00 committed by GitHub
parent 930b159885
commit a020a2d49d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -735,17 +735,18 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
return arr;
}
std::vector<int> new_shape(arr.shape());
new_shape[axis] *= repeats;
// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
std::vector<int> shape(arr.shape());
shape.insert(shape.begin() + axis + 1, repeats);
array out = expand_dims(arr, axis + 1, s);
out = broadcast_to(out, shape, s);
std::vector<array> repeated_arrays;
repeated_arrays.reserve(repeats);
// Reshape back into a contiguous array where S_axis is now S_axis * repeats
shape.erase(shape.begin() + axis + 1);
shape[axis] *= repeats;
out = reshape(out, shape, s);
for (int i = 0; i < repeats; ++i) {
repeated_arrays.push_back(expand_dims(arr, -1, s));
}
array repeated = concatenate(repeated_arrays, axis + 1, s);
return reshape(repeated, new_shape, s);
return out;
}
array repeat(const array& arr, int repeats, StreamOrDevice s) {